summaryrefslogtreecommitdiff
path: root/0001-fix-pattern-conflicts.patch
blob: 461a7e667ae34b8ffa86625dfddbb978dc41b968 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
From d941fd6d6893abd1893d2da25afa7847ec32f132 Mon Sep 17 00:00:00 2001
From: zhengchenhui <zhengchenhui1@huawei.com>
Date: Tue, 9 Sep 2025 15:27:43 +0800
Subject: [PATCH 1/2] fix pattern conflicts.

---
 python/annc/optimize/graph.py    |  6 ++++++
 python/annc/optimize/rewriter.py | 17 ++++++++++-------
 python/setup.py                  |  1 -
 3 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/python/annc/optimize/graph.py b/python/annc/optimize/graph.py
index dae7d6d..1a97d1b 100644
--- a/python/annc/optimize/graph.py
+++ b/python/annc/optimize/graph.py
@@ -255,6 +255,12 @@ class Graph:
 
         self.versions['producer'] = graph_def.versions.producer
 
+    def check_node_exist(self, node: Node) -> bool:
+        for i, n in enumerate(self.nodes):
+            if n.name == node.name:
+                return True
+        return False
+
     def get_node(self, name: str) -> Node:
         for node in self.nodes:
             if node.name == name:
diff --git a/python/annc/optimize/rewriter.py b/python/annc/optimize/rewriter.py
index 2e66189..c0f2894 100644
--- a/python/annc/optimize/rewriter.py
+++ b/python/annc/optimize/rewriter.py
@@ -46,15 +46,18 @@ class BaseRewriter(ABC):
         :param users: expected list of user types and names
         :param expected_num_users: expected number of users (optional)
         """
-        if user_num is not None and len(node.users) != user_num:
-            raise CheckFailed
-        if len(users) > len(node.users):
-            raise CheckFailed
-        for i, (type, name) in enumerate(users):
-            if type is not None and node.users[i].type != type:
+        if self.graph.check_node_exist(node):
+            if user_num is not None and len(node.users) != user_num:
                 raise CheckFailed
-            if name is not None and node.users[i].name != name:
+            if len(users) > len(node.users):
                 raise CheckFailed
+            for i, (type, name) in enumerate(users):
+                if type is not None and node.users[i].type != type:
+                    raise CheckFailed
+                if name is not None and node.users[i].name != name:
+                    raise CheckFailed
+        else:
+            raise CheckFailed
 
     def check_operands(self,
                        node: Node,
diff --git a/python/setup.py b/python/setup.py
index 1197745..d32026e 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -34,6 +34,5 @@ setup(name=project,
       entry_points={
           'console_scripts': [
               f'{project}-opt = {project}.main:opt',
-              f'{project}-apply-tf = scripts.install:tf_install',
           ],
       })
-- 
2.33.0