From d941fd6d6893abd1893d2da25afa7847ec32f132 Mon Sep 17 00:00:00 2001 From: zhengchenhui Date: Tue, 9 Sep 2025 15:27:43 +0800 Subject: [PATCH 1/2] fix pattern conflicts. --- python/annc/optimize/graph.py | 6 ++++++ python/annc/optimize/rewriter.py | 17 ++++++++++------- python/setup.py | 1 - 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/python/annc/optimize/graph.py b/python/annc/optimize/graph.py index dae7d6d..1a97d1b 100644 --- a/python/annc/optimize/graph.py +++ b/python/annc/optimize/graph.py @@ -255,6 +255,12 @@ class Graph: self.versions['producer'] = graph_def.versions.producer + def check_node_exist(self, node: Node) -> bool: + for i, n in enumerate(self.nodes): + if n.name == node.name: + return True + return False + def get_node(self, name: str) -> Node: for node in self.nodes: if node.name == name: diff --git a/python/annc/optimize/rewriter.py b/python/annc/optimize/rewriter.py index 2e66189..c0f2894 100644 --- a/python/annc/optimize/rewriter.py +++ b/python/annc/optimize/rewriter.py @@ -46,15 +46,18 @@ class BaseRewriter(ABC): :param users: expected list of user types and names :param expected_num_users: expected number of users (optional) """ - if user_num is not None and len(node.users) != user_num: - raise CheckFailed - if len(users) > len(node.users): - raise CheckFailed - for i, (type, name) in enumerate(users): - if type is not None and node.users[i].type != type: + if self.graph.check_node_exist(node): + if user_num is not None and len(node.users) != user_num: raise CheckFailed - if name is not None and node.users[i].name != name: + if len(users) > len(node.users): raise CheckFailed + for i, (type, name) in enumerate(users): + if type is not None and node.users[i].type != type: + raise CheckFailed + if name is not None and node.users[i].name != name: + raise CheckFailed + else: + raise CheckFailed def check_operands(self, node: Node, diff --git a/python/setup.py b/python/setup.py index 1197745..d32026e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -34,6 +34,5 @@ setup(name=project, entry_points={ 'console_scripts': [ f'{project}-opt = {project}.main:opt', - f'{project}-apply-tf = scripts.install:tf_install', ], }) -- 2.33.0