summaryrefslogtreecommitdiff
path: root/0001-fix-pattern-conflicts.patch
diff options
context:
space:
mode:
Diffstat (limited to '0001-fix-pattern-conflicts.patch')
-rw-r--r--0001-fix-pattern-conflicts.patch72
1 files changed, 72 insertions, 0 deletions
diff --git a/0001-fix-pattern-conflicts.patch b/0001-fix-pattern-conflicts.patch
new file mode 100644
index 0000000..461a7e6
--- /dev/null
+++ b/0001-fix-pattern-conflicts.patch
@@ -0,0 +1,72 @@
+From d941fd6d6893abd1893d2da25afa7847ec32f132 Mon Sep 17 00:00:00 2001
+From: zhengchenhui <zhengchenhui1@huawei.com>
+Date: Tue, 9 Sep 2025 15:27:43 +0800
+Subject: [PATCH 1/2] fix pattern conflicts.
+
+---
+ python/annc/optimize/graph.py | 6 ++++++
+ python/annc/optimize/rewriter.py | 17 ++++++++++-------
+ python/setup.py | 1 -
+ 3 files changed, 16 insertions(+), 8 deletions(-)
+
+diff --git a/python/annc/optimize/graph.py b/python/annc/optimize/graph.py
+index dae7d6d..1a97d1b 100644
+--- a/python/annc/optimize/graph.py
++++ b/python/annc/optimize/graph.py
+@@ -255,6 +255,12 @@ class Graph:
+
+ self.versions['producer'] = graph_def.versions.producer
+
++ def check_node_exist(self, node: Node) -> bool:
++ for i, n in enumerate(self.nodes):
++ if n.name == node.name:
++ return True
++ return False
++
+ def get_node(self, name: str) -> Node:
+ for node in self.nodes:
+ if node.name == name:
+diff --git a/python/annc/optimize/rewriter.py b/python/annc/optimize/rewriter.py
+index 2e66189..c0f2894 100644
+--- a/python/annc/optimize/rewriter.py
++++ b/python/annc/optimize/rewriter.py
+@@ -46,15 +46,18 @@ class BaseRewriter(ABC):
+ :param users: expected list of user types and names
+ :param expected_num_users: expected number of users (optional)
+ """
+- if user_num is not None and len(node.users) != user_num:
+- raise CheckFailed
+- if len(users) > len(node.users):
+- raise CheckFailed
+- for i, (type, name) in enumerate(users):
+- if type is not None and node.users[i].type != type:
++ if self.graph.check_node_exist(node):
++ if user_num is not None and len(node.users) != user_num:
+ raise CheckFailed
+- if name is not None and node.users[i].name != name:
++ if len(users) > len(node.users):
+ raise CheckFailed
++ for i, (type, name) in enumerate(users):
++ if type is not None and node.users[i].type != type:
++ raise CheckFailed
++ if name is not None and node.users[i].name != name:
++ raise CheckFailed
++ else:
++ raise CheckFailed
+
+ def check_operands(self,
+ node: Node,
+diff --git a/python/setup.py b/python/setup.py
+index 1197745..d32026e 100644
+--- a/python/setup.py
++++ b/python/setup.py
+@@ -34,6 +34,5 @@ setup(name=project,
+ entry_points={
+ 'console_scripts': [
+ f'{project}-opt = {project}.main:opt',
+- f'{project}-apply-tf = scripts.install:tf_install',
+ ],
+ })
+--
+2.33.0
+