1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
|
From d941fd6d6893abd1893d2da25afa7847ec32f132 Mon Sep 17 00:00:00 2001
From: zhengchenhui <zhengchenhui1@huawei.com>
Date: Tue, 9 Sep 2025 15:27:43 +0800
Subject: [PATCH 1/2] fix pattern conflicts.
---
python/annc/optimize/graph.py | 6 ++++++
python/annc/optimize/rewriter.py | 17 ++++++++++-------
python/setup.py | 1 -
3 files changed, 16 insertions(+), 8 deletions(-)
diff --git a/python/annc/optimize/graph.py b/python/annc/optimize/graph.py
index dae7d6d..1a97d1b 100644
--- a/python/annc/optimize/graph.py
+++ b/python/annc/optimize/graph.py
@@ -255,6 +255,12 @@ class Graph:
self.versions['producer'] = graph_def.versions.producer
+ def check_node_exist(self, node: Node) -> bool:
+ for i, n in enumerate(self.nodes):
+ if n.name == node.name:
+ return True
+ return False
+
def get_node(self, name: str) -> Node:
for node in self.nodes:
if node.name == name:
diff --git a/python/annc/optimize/rewriter.py b/python/annc/optimize/rewriter.py
index 2e66189..c0f2894 100644
--- a/python/annc/optimize/rewriter.py
+++ b/python/annc/optimize/rewriter.py
@@ -46,15 +46,18 @@ class BaseRewriter(ABC):
:param users: expected list of user types and names
:param expected_num_users: expected number of users (optional)
"""
- if user_num is not None and len(node.users) != user_num:
- raise CheckFailed
- if len(users) > len(node.users):
- raise CheckFailed
- for i, (type, name) in enumerate(users):
- if type is not None and node.users[i].type != type:
+ if self.graph.check_node_exist(node):
+ if user_num is not None and len(node.users) != user_num:
raise CheckFailed
- if name is not None and node.users[i].name != name:
+ if len(users) > len(node.users):
raise CheckFailed
+ for i, (type, name) in enumerate(users):
+ if type is not None and node.users[i].type != type:
+ raise CheckFailed
+ if name is not None and node.users[i].name != name:
+ raise CheckFailed
+ else:
+ raise CheckFailed
def check_operands(self,
node: Node,
diff --git a/python/setup.py b/python/setup.py
index 1197745..d32026e 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -34,6 +34,5 @@ setup(name=project,
entry_points={
'console_scripts': [
f'{project}-opt = {project}.main:opt',
- f'{project}-apply-tf = scripts.install:tf_install',
],
})
--
2.33.0
|