From ac9d57aa05e912dd7586ff4c9ee938129fe902a3 Mon Sep 17 00:00:00 2001 From: CoprDistGit Date: Thu, 11 Sep 2025 14:13:08 +0000 Subject: automatic import of ANNC --- 0001-fix-pattern-conflicts.patch | 72 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 0001-fix-pattern-conflicts.patch (limited to '0001-fix-pattern-conflicts.patch') diff --git a/0001-fix-pattern-conflicts.patch b/0001-fix-pattern-conflicts.patch new file mode 100644 index 0000000..461a7e6 --- /dev/null +++ b/0001-fix-pattern-conflicts.patch @@ -0,0 +1,72 @@ +From d941fd6d6893abd1893d2da25afa7847ec32f132 Mon Sep 17 00:00:00 2001 +From: zhengchenhui +Date: Tue, 9 Sep 2025 15:27:43 +0800 +Subject: [PATCH 1/2] fix pattern conflicts. + +--- + python/annc/optimize/graph.py | 6 ++++++ + python/annc/optimize/rewriter.py | 17 ++++++++++------- + python/setup.py | 1 - + 3 files changed, 16 insertions(+), 8 deletions(-) + +diff --git a/python/annc/optimize/graph.py b/python/annc/optimize/graph.py +index dae7d6d..1a97d1b 100644 +--- a/python/annc/optimize/graph.py ++++ b/python/annc/optimize/graph.py +@@ -255,6 +255,12 @@ class Graph: + + self.versions['producer'] = graph_def.versions.producer + ++ def check_node_exist(self, node: Node) -> bool: ++ for i, n in enumerate(self.nodes): ++ if n.name == node.name: ++ return True ++ return False ++ + def get_node(self, name: str) -> Node: + for node in self.nodes: + if node.name == name: +diff --git a/python/annc/optimize/rewriter.py b/python/annc/optimize/rewriter.py +index 2e66189..c0f2894 100644 +--- a/python/annc/optimize/rewriter.py ++++ b/python/annc/optimize/rewriter.py +@@ -46,15 +46,18 @@ class BaseRewriter(ABC): + :param users: expected list of user types and names + :param expected_num_users: expected number of users (optional) + """ +- if user_num is not None and len(node.users) != user_num: +- raise CheckFailed +- if len(users) > len(node.users): +- raise CheckFailed +- for i, (type, name) in enumerate(users): +- if type is not None and node.users[i].type != type: ++ if self.graph.check_node_exist(node): ++ if user_num is not None and len(node.users) != user_num: + raise CheckFailed +- if name is not None and node.users[i].name != name: ++ if len(users) > len(node.users): + raise CheckFailed ++ for i, (type, name) in enumerate(users): ++ if type is not None and node.users[i].type != type: ++ raise CheckFailed ++ if name is not None and node.users[i].name != name: ++ raise CheckFailed ++ else: ++ raise CheckFailed + + def check_operands(self, + node: Node, +diff --git a/python/setup.py b/python/setup.py +index 1197745..d32026e 100644 +--- a/python/setup.py ++++ b/python/setup.py +@@ -34,6 +34,5 @@ setup(name=project, + entry_points={ + 'console_scripts': [ + f'{project}-opt = {project}.main:opt', +- f'{project}-apply-tf = scripts.install:tf_install', + ], + }) +-- +2.33.0 + -- cgit v1.2.3