-rw-r--r--  .gitignore                                                  |    5
-rw-r--r--  0001-fix-pattern-conflicts.patch                            |   72
-rw-r--r--  0002-Add-graph-optimizer-and-embedding_fused-kernels.patch  | 4297
-rw-r--r--  ANNC.spec                                                   |   83
-rw-r--r--  sources                                                     |    9
-rw-r--r--  x86_64_external_files.patch                                 |  132
6 files changed, 4511 insertions(+), 87 deletions(-)
diff --git a/.gitignore b/.gitignore
index cb59680..04f958b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,8 @@
/annc_external.tar.gz.aa
/annc_external.tar.gz.ab
/ANNC-v0.0.2.tar.gz
+/XNNPACK.tar.gz
+/external.tar.gz.aa
+/external.tar.gz.ab
+/external.tar.gz.ac
+/v3.2.tar.gz
diff --git a/0001-fix-pattern-conflicts.patch b/0001-fix-pattern-conflicts.patch
new file mode 100644
index 0000000..461a7e6
--- /dev/null
+++ b/0001-fix-pattern-conflicts.patch
@@ -0,0 +1,72 @@
+From d941fd6d6893abd1893d2da25afa7847ec32f132 Mon Sep 17 00:00:00 2001
+From: zhengchenhui <zhengchenhui1@huawei.com>
+Date: Tue, 9 Sep 2025 15:27:43 +0800
+Subject: [PATCH 1/2] fix pattern conflicts.
+
+---
+ python/annc/optimize/graph.py | 6 ++++++
+ python/annc/optimize/rewriter.py | 17 ++++++++++-------
+ python/setup.py | 1 -
+ 3 files changed, 16 insertions(+), 8 deletions(-)
+
+diff --git a/python/annc/optimize/graph.py b/python/annc/optimize/graph.py
+index dae7d6d..1a97d1b 100644
+--- a/python/annc/optimize/graph.py
++++ b/python/annc/optimize/graph.py
+@@ -255,6 +255,12 @@ class Graph:
+
+ self.versions['producer'] = graph_def.versions.producer
+
++ def check_node_exist(self, node: Node) -> bool:
++ for n in self.nodes:
++ if n.name == node.name:
++ return True
++ return False
++
+ def get_node(self, name: str) -> Node:
+ for node in self.nodes:
+ if node.name == name:
+diff --git a/python/annc/optimize/rewriter.py b/python/annc/optimize/rewriter.py
+index 2e66189..c0f2894 100644
+--- a/python/annc/optimize/rewriter.py
++++ b/python/annc/optimize/rewriter.py
+@@ -46,15 +46,18 @@ class BaseRewriter(ABC):
+ :param users: expected list of user types and names
+ :param expected_num_users: expected number of users (optional)
+ """
+- if user_num is not None and len(node.users) != user_num:
+- raise CheckFailed
+- if len(users) > len(node.users):
+- raise CheckFailed
+- for i, (type, name) in enumerate(users):
+- if type is not None and node.users[i].type != type:
++ if self.graph.check_node_exist(node):
++ if user_num is not None and len(node.users) != user_num:
+ raise CheckFailed
+- if name is not None and node.users[i].name != name:
++ if len(users) > len(node.users):
+ raise CheckFailed
++ for i, (type, name) in enumerate(users):
++ if type is not None and node.users[i].type != type:
++ raise CheckFailed
++ if name is not None and node.users[i].name != name:
++ raise CheckFailed
++ else:
++ raise CheckFailed
+
+ def check_operands(self,
+ node: Node,
+diff --git a/python/setup.py b/python/setup.py
+index 1197745..d32026e 100644
+--- a/python/setup.py
++++ b/python/setup.py
+@@ -34,6 +34,5 @@ setup(name=project,
+ entry_points={
+ 'console_scripts': [
+ f'{project}-opt = {project}.main:opt',
+- f'{project}-apply-tf = scripts.install:tf_install',
+ ],
+ })
+--
+2.33.0
+
diff --git a/0002-Add-graph-optimizer-and-embedding_fused-kernels.patch b/0002-Add-graph-optimizer-and-embedding_fused-kernels.patch
new file mode 100644
index 0000000..4fd5abd
--- /dev/null
+++ b/0002-Add-graph-optimizer-and-embedding_fused-kernels.patch
@@ -0,0 +1,4297 @@
+From 90ac46231be919e1c07b7d41bc0a8c4b1f1ba41a Mon Sep 17 00:00:00 2001
+From: zhengchenhui <zhengchenhui1@huawei.com>
+Date: Mon, 10 Nov 2025 16:49:10 +0800
+Subject: [PATCH] Add graph optimizer and embedding_fused kernels.
+
+---
+ annc/tensorflow/graph_optimizer/BUILD | 32 +
+ annc/tensorflow/graph_optimizer/graph_opt.cc | 738 ++++++++++++++++++
+ annc/tensorflow/graph_optimizer/graph_opt.h | 165 ++++
+ .../embedding_fused_action_id_gather.cc | 144 ++++
+ .../embedding_fused_action_id_gather_test.cc | 289 +++++++
+ .../kernels/embedding_fused_gather.cc | 90 +++
+ .../kernels/embedding_fused_gather_test.cc | 186 +++++
+ .../kernels/embedding_fused_padding.cc | 126 +++
+ .../kernels/embedding_fused_padding_test.cc | 307 ++++++++
+ .../embedding_fused_sparse_dynamic_stitch.cc | 87 +++
+ ...edding_fused_sparse_dynamic_stitch_test.cc | 108 +++
+ .../kernels/embedding_fused_sparse_reshape.cc | 195 +++++
+ .../embedding_fused_sparse_reshape_test.cc | 281 +++++++
+ .../embedding_fused_sparse_segment_reduce.cc | 165 ++++
+ ...ing_fused_sparse_segment_reduce_nonzero.cc | 159 ++++
+ ...used_sparse_segment_reduce_nonzero_test.cc | 183 +++++
+ ...edding_fused_sparse_segment_reduce_test.cc | 205 +++++
+ .../kernels/embedding_fused_sparse_select.cc | 111 +++
+ .../embedding_fused_sparse_select_test.cc | 182 +++++
+ annc/tensorflow/ops/embedding_fused_ops.cc | 133 ++++
+ annc/tensorflow/tf_annc_optimizer.patch | 221 ++++++
+ 21 files changed, 4107 insertions(+)
+ create mode 100644 annc/tensorflow/graph_optimizer/BUILD
+ create mode 100644 annc/tensorflow/graph_optimizer/graph_opt.cc
+ create mode 100644 annc/tensorflow/graph_optimizer/graph_opt.h
+ create mode 100644 annc/tensorflow/kernels/embedding_fused_action_id_gather.cc
+ create mode 100644 annc/tensorflow/kernels/embedding_fused_action_id_gather_test.cc
+ create mode 100644 annc/tensorflow/kernels/embedding_fused_gather.cc
+ create mode 100644 annc/tensorflow/kernels/embedding_fused_gather_test.cc
+ create mode 100644 annc/tensorflow/kernels/embedding_fused_padding.cc
+ create mode 100644 annc/tensorflow/kernels/embedding_fused_padding_test.cc
+ create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_dynamic_stitch.cc
+ create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_dynamic_stitch_test.cc
+ create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_reshape.cc
+ create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_reshape_test.cc
+ create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce.cc
+ create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_nonzero.cc
+ create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_nonzero_test.cc
+ create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_test.cc
+ create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_select.cc
+ create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_select_test.cc
+ create mode 100644 annc/tensorflow/ops/embedding_fused_ops.cc
+ create mode 100644 annc/tensorflow/tf_annc_optimizer.patch
+
+diff --git a/annc/tensorflow/graph_optimizer/BUILD b/annc/tensorflow/graph_optimizer/BUILD
+new file mode 100644
+index 0000000..39a37f8
+--- /dev/null
++++ b/annc/tensorflow/graph_optimizer/BUILD
+@@ -0,0 +1,32 @@
++package(
++ default_visibility = [
++ "//visibility:public",
++ ],
++ licenses = ["notice"],
++)
++
++cc_library(
++ name = "annc_graph_opt",
++ srcs = glob(["*.cc"]),
++ hdrs = glob(["*.h"]),
++ linkstatic = True,
++ alwayslink = True,
++ visibility = ["//visibility:public"],
++ deps = [
++ "//tensorflow/core/grappler:graph_view",
++ "//tensorflow/core/grappler:grappler_item",
++ "//tensorflow/core/grappler:op_types",
++ ],
++)
++
++cc_binary(
++ name = "libannc_graph_opt.so",
++ srcs = glob(["*.cc", "*.h"]),
++ linkshared = True,
++ visibility = ["//visibility:public"],
++ deps = [
++ "//tensorflow/core/grappler:graph_view",
++ "//tensorflow/core/grappler:grappler_item",
++ "//tensorflow/core/grappler:op_types",
++ ],
++)
+\ No newline at end of file
+diff --git a/annc/tensorflow/graph_optimizer/graph_opt.cc b/annc/tensorflow/graph_optimizer/graph_opt.cc
+new file mode 100644
+index 0000000..9c74489
+--- /dev/null
++++ b/annc/tensorflow/graph_optimizer/graph_opt.cc
+@@ -0,0 +1,738 @@
++#include "graph_opt.h"
++
++using namespace tensorflow;
++using namespace tensorflow::grappler;
++
++namespace annc {
++void update_node_indexes(const GraphDef* graph,
++ std::unordered_map<std::string, int>& node_indexes) {
++ for (int i = 0; i < graph->node_size(); ++i) {
++ node_indexes[graph->node(i).name()] = i;
++ }
++}
++
++void GraphOptimizer::register_rewriter(
++ std::unique_ptr<PatternRewriter> rewriter) {
++ rewriters_.push_back(std::move(rewriter));
++}
++
++void GraphOptimizer::optimize() {
++ update_node_indexes(graph_, node_indexes_);
++ int node_index = 0;
++ const int node_size = graph_->node_size();
++ while (node_index < node_size) {
++ const NodeDef& node = graph_->node(node_index);
++ const std::string& node_name = node.name();
++ for (auto& rewriter : rewriters_) {
++ if (rewriter->match_and_rewrite(&node, graph_, node_indexes_)) {
++ update_node_indexes(graph_, node_indexes_);
++ const std::string new_node_name = node_name + fusion_appendix;
++ node_index = node_indexes_.at(new_node_name);
++ break;
++ }
++ }
++ node_index++;
++ }
++}
++
++std::string get_node_name(const std::string& name) {
++ size_t colon_pos = name.find_last_of(':');
++ std::string node_name = name;
++ if (colon_pos != std::string::npos) {
++ node_name = name.substr(0, colon_pos);
++ }
++ return node_name;
++}
++
++void set_fusedop_attributes(NodeDef* fused,
++ const absl::Span<const absl::string_view> fused_ops,
++ int num_args = 1, float epsilon = 0.0) {
++ auto* attr = fused->mutable_attr();
++ SetAttrValue(fused_ops, &(*attr)["fused_ops"]);
++ SetAttrValue(num_args, &(*attr)["num_args"]);
++ SetAttrValue(epsilon, &(*attr)["epsilon"]); // required only for BatchNorm
++}
++
++const NodeDef* PatternRewriter::get_node(const std::string& name) {
++ const std::string node_name = get_node_name(name);
++ const int node_index = indexes_->at(node_name);
++ return &graph_->node(node_index);
++}
++
++NodeDef* PatternRewriter::get_mutable_node(const std::string& name) {
++ const std::string node_name = get_node_name(name);
++ if (indexes_->find(node_name) == indexes_->end()) return nullptr;
++ const int node_index = indexes_->at(node_name);
++ return graph_->mutable_node(node_index);
++}
++
++NodeDef* PatternRewriter::get_operand(const NodeDef* node,
++ std::string op_type) {
++ for (int i = 0; i < node->input_size(); ++i) {
++ NodeDef* operand = get_mutable_node(node->input(i));
++ if (operand != nullptr && operand->op() == op_type) return operand;
++ }
++ return nullptr;
++}
++
++const NodeDef* PatternRewriter::get_user(const NodeDef* node, int index,
++ const std::string& op_type) {
++ std::string node_name = node->name();
++ if (index) node_name += ":" + std::to_string(index);
++ for (int i = 0; i < graph_->node_size(); ++i) {
++ const NodeDef* candidate = &graph_->node(i);
++ for (int j = 0; j < candidate->input_size(); ++j) {
++ if (candidate->input(j) == node_name && candidate->op() == op_type) {
++ return candidate;
++ }
++ }
++ }
++ return nullptr;
++}
++
++void PatternRewriter::replace_all_users_with(const NodeDef* old_node,
++ int old_index,
++ const NodeDef* new_node,
++ int new_index, GraphDef* graph) {
++ std::string old_name = old_node->name();
++ if (old_index) old_name = old_name + ":" + std::to_string(old_index);
++ std::string new_name = new_node->name();
++ if (new_index) new_name = new_name + ":" + std::to_string(new_index);
++ for (int i = 0; i < graph->node_size(); ++i) {
++ NodeDef* node = graph->mutable_node(i);
++ for (int j = 0; j < node->input_size(); ++j) {
++ if (node->input(j) == old_name) {
++ node->set_input(j, new_name);
++ }
++ }
++ }
++}
++
++class KPFusedSparseDynamicStitchRewriter : public PatternRewriter {
++ public:
++ std::string name() const override { return "KPFusedSparseDynamicStitch"; }
++
++ bool match_and_rewrite(
++ const NodeDef* node, GraphDef* graph,
++ std::unordered_map<std::string, int>& node_indexes) override {
++ graph_ = graph;
++ indexes_ = &node_indexes;
++ CHECK_NODE_OK(node->op() == "ParallelDynamicStitch" &&
++ node->input_size() % 2 == 0)
++ int num_inputs = node->input_size();
++ int num_partitions = num_inputs / 2;
++ // left branch
++ const NodeDef* partition = get_node(node->input(0));
++ CHECK_NODE_OK(partition->op() == "DynamicPartition" &&
++ partition->input_size() == 2)
++ const NodeDef* range = get_node(partition->input(0));
++ CHECK_NODE_OK(range->op() == "Range" && range->input_size() == 3)
++ // Range start=0, delta=1
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(range->input(0)), {0}))
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(range->input(2)), {1}))
++ const NodeDef* size = get_node(range->input(1));
++ CHECK_NODE_OK(IsSize(*size) && size->input_size() == 1)
++ const NodeDef* cast = get_node(partition->input(1));
++ CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
++ const NodeDef* floor_mod = get_node(cast->input(0));
++ CHECK_NODE_OK(floor_mod->op() == "FloorMod" && floor_mod->input_size() == 2)
++ CHECK_NODE_OK(check_const_value<int64_t>(
++ get_mutable_node(floor_mod->input(1)), {num_partitions}))
++
++ CHECK_NODE_OK(check_int_attr(node, "N", {num_partitions}))
++ CHECK_NODE_OK(check_int_attr(partition, "num_partitions", {num_partitions}))
++
++ auto nodes = graph->mutable_node();
++ NodeDef* fused_node = nodes->Add();
++ fused_node->set_name(node->name() + fusion_appendix);
++ fused_node->set_op(name());
++ fused_node->set_device(node->device());
++ fused_node->add_input(size->input(0));
++ // right branch
++ for (int i = num_partitions; i < num_inputs; ++i) {
++ const NodeDef* gather = get_node(node->input(i));
++ CHECK_NODE_OK(gather->op() == "GatherV2" && gather->input_size() == 3)
++ // Gather axis=0
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
++ const NodeDef* partition_1 = get_node(gather->input(1));
++ CHECK_NODE_OK(partition_1->op() == "DynamicPartition" &&
++ partition_1->input_size() == 2)
++ CHECK_NODE_OK(
++ check_int_attr(partition_1, "num_partitions", {num_partitions}))
++ const NodeDef* floor_div = get_node(partition_1->input(0));
++ CHECK_NODE_OK(floor_div->op() == "FloorDiv" &&
++ floor_div->input_size() == 2)
++ CHECK_NODE_OK(check_const_value<int64_t>(
++ get_mutable_node(floor_div->input(1)), {num_partitions}))
++ fused_node->add_input(gather->input(0));
++ }
++ (*fused_node->mutable_attr())["N"].set_i(num_partitions);
++ nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
++
++ VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
++ << fused_node->name();
++ replace_all_users_with(node, 0, fused_node, 0, graph);
++ return true;
++ }
++};
++
++class KPFusedSparseSegmentReduceRewriter : public PatternRewriter {
++ public:
++ std::string name() const override { return "KPFusedSparseSegmentReduce"; }
++
++ bool match_and_rewrite(
++ const NodeDef* node, GraphDef* graph,
++ std::unordered_map<std::string, int>& node_indexes) override {
++ graph_ = graph;
++ indexes_ = &node_indexes;
++ CHECK_NODE_OK(IsStridedSlice(*node) && node->input_size() == 4)
++ const NodeDef* shape = get_node(node->input(0));
++ CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1)
++ const NodeDef* ss_reduce = get_node(shape->input(0));
++ CHECK_NODE_OK(ss_reduce->input_size() == 3)
++ AttrValue combiner;
++ if (ss_reduce->op() == "SparseSegmentMean")
++ combiner.set_i(1);
++ else if (ss_reduce->op() == "SparseSegmentSum")
++ combiner.set_i(0);
++ else
++ return false;
++ const NodeDef* cast = get_node(ss_reduce->input(2));
++ CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
++ const NodeDef* strided_slice = get_node(cast->input(0));
++ CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
++ strided_slice->input_size() == 4)
++
++ // check fusion conditions
++ CHECK_NODE_OK(
++ check_const_shape(get_mutable_node(strided_slice->input(1)), {2}))
++ CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
++ CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
++ CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
++ CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(1)), {1}))
++ CHECK_NODE_OK(check_int_attr(node, "shrink_axis_mask", 1))
++
++ auto nodes = graph->mutable_node();
++ NodeDef* fused_node = nodes->Add();
++ fused_node->set_name(node->name() + fusion_appendix);
++ fused_node->set_op(name());
++ fused_node->set_device(node->device());
++ fused_node->add_input(ss_reduce->input(0));
++ fused_node->add_input(ss_reduce->input(1));
++ fused_node->add_input(strided_slice->input(0));
++ fused_node->add_input(strided_slice->input(1));
++ fused_node->add_input(node->input(1));
++ AddNodeAttr("combiner", combiner, fused_node);
++
++ nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
++
++ VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
++ << fused_node->name();
++ replace_all_users_with(ss_reduce, 0, fused_node, 0, graph);
++ replace_all_users_with(node, 0, fused_node, 1, graph);
++ return true;
++ }
++};
++
++class KPFusedSparseSegmentReduceNonzeroRewriter : public PatternRewriter {
++ public:
++ std::string name() const override {
++ return "KPFusedSparseSegmentReduceNonzero";
++ }
++
++ bool match_and_rewrite(
++ const NodeDef* node, GraphDef* graph,
++ std::unordered_map<std::string, int>& node_indexes) override {
++ graph_ = graph;
++ indexes_ = &node_indexes;
++ CHECK_NODE_OK(node->op() == "GatherND" &&
++ node->input_size() == 2) // output:2
++ const NodeDef* ss_reduce = get_node(node->input(0));
++ CHECK_NODE_OK(ss_reduce->input_size() == 3)
++ AttrValue combiner;
++ if (ss_reduce->op() == "SparseSegmentMean") {
++ combiner.set_i(1);
++ } else if (ss_reduce->op() == "SparseSegmentSum") {
++ combiner.set_i(0);
++ } else {
++ return false;
++ }
++ const NodeDef* where = get_node(node->input(1));
++ CHECK_NODE_OK(where->op() == "Where" && where->input_size() == 1)
++ const NodeDef* cast = get_user(where, 0, "Cast");
++ CHECK_NODE_OK(cast != nullptr) // output: 1
++ const NodeDef* notequal = get_node(where->input(0));
++ CHECK_NODE_OK(IsNotEqual(*notequal) && notequal->input_size() == 2);
++ const NodeDef* zerolike = get_node(notequal->input(1));
++ CHECK_NODE_OK(IsZerosLike(*zerolike) && zerolike->input_size() == 1)
++ const NodeDef* cast_1 = get_node(ss_reduce->input(2));
++ CHECK_NODE_OK(IsCast(*cast_1) && cast_1->input_size() == 1)
++ const NodeDef* strided_slice = get_node(cast->input(0));
++ CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
++ strided_slice->input_size() == 4)
++ const NodeDef* shape = get_user(ss_reduce, 0, "Shape");
++ CHECK_NODE_OK(shape != nullptr)
++ const NodeDef* cast_2 = get_user(shape, 0, "Cast"); // output: 0
++ CHECK_NODE_OK(cast_2 != nullptr)
++
++ CHECK_NODE_OK(
++ check_const_shape(get_mutable_node(strided_slice->input(1)), {2}))
++ CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
++ CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
++ CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
++
++ auto nodes = graph->mutable_node();
++ NodeDef* fused_node = nodes->Add();
++ fused_node->set_name(node->name() + fusion_appendix);
++ fused_node->set_op(name());
++ fused_node->set_device(node->device());
++ fused_node->add_input(ss_reduce->input(0));
++ fused_node->add_input(ss_reduce->input(1));
++ fused_node->add_input(strided_slice->input(0));
++ fused_node->add_input(strided_slice->input(1));
++ AddNodeAttr("combiner", combiner, fused_node);
++
++ nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
++
++ VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
++ << fused_node->name() << "\n";
++ replace_all_users_with(cast_2, 0, fused_node, 0, graph);
++ replace_all_users_with(cast, 0, fused_node, 1, graph);
++ replace_all_users_with(node, 0, fused_node, 2, graph);
++ return true;
++ }
++};
++
++class KPFusedEmbeddingPaddingRewriter : public PatternRewriter {
++ public:
++ std::string name() const override { return "KPFusedEmbeddingPadding"; }
++
++ bool match_and_rewrite(
++ const NodeDef* node, GraphDef* graph,
++ std::unordered_map<std::string, int>& node_indexes) override {
++ graph_ = graph;
++ indexes_ = &node_indexes;
++ CHECK_NODE_OK(IsReshape(*node))
++ const NodeDef* user = get_user(node, 0, "ConcatV2");
++ CHECK_NODE_OK(user != nullptr)
++ const NodeDef* concat = get_node(node->input(0));
++ CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3)
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(concat->input(2)), {0}))
++ const NodeDef* fill = get_node(concat->input(1));
++ CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
++ NodeDef* pack = get_operand(fill, "Pack");
++ CHECK_NODE_OK(pack != nullptr && IsPack(*pack) && pack->input_size() == 2)
++ NodeDef* fill_const = get_operand(fill, "Const");
++ CHECK_NODE_OK(fill_const != nullptr &&
++ check_const_value<int>(fill_const, {0}))
++ NodeDef* sub = get_operand(pack, "Sub");
++ CHECK_NODE_OK(sub != nullptr && IsSub(*sub) && sub->input_size() == 2)
++ const NodeDef* strided_slice = get_node(sub->input(0));
++ CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
++ strided_slice->input_size() == 4)
++ const NodeDef* cast = get_node(strided_slice->input(0));
++ CHECK_NODE_OK(IsCast(*cast))
++
++ auto nodes = graph->mutable_node();
++ NodeDef* fused_node = nodes->Add();
++ fused_node->set_name(node->name() + fusion_appendix);
++ fused_node->set_op(name());
++ fused_node->set_device(node->device());
++ fused_node->add_input(cast->input(0));
++ fused_node->add_input(concat->input(0));
++ fused_node->add_input(sub->input(1));
++ fused_node->add_input(node->input(1));
++ const NodeDef* pack_left = get_node(pack->input(0));
++ const NodeDef* pack_right = get_node(pack->input(1));
++ if (IsConstant(*pack_left) || IsHostConstant(*pack_left)) {
++ fused_node->add_input(pack->input(0));
++ } else if (IsConstant(*pack_right) || IsHostConstant(*pack_right)) {
++ fused_node->add_input(pack->input(1));
++ } else {
++ return false;
++ }
++
++ nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
++
++ VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
++ << fused_node->name();
++ replace_all_users_with(sub, 0, fused_node, 0, graph);
++ replace_all_users_with(node, 0, fused_node, 1, graph);
++ return true;
++ }
++};
++
++class KPFusedEmbeddingPaddingFastRewriter : public PatternRewriter {
++ public:
++ std::string name() const override { return "KPFusedEmbeddingPaddingFast"; }
++
++ bool match_and_rewrite(
++ const NodeDef* node, GraphDef* graph,
++ std::unordered_map<std::string, int>& node_indexes) override {
++ graph_ = graph;
++ indexes_ = &node_indexes;
++ CHECK_NODE_OK(IsStridedSlice(*node) && node->input_size() == 4)
++ CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(1)), {0}))
++ CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(2)), {1}))
++ CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(3)), {1}))
++ CHECK_NODE_OK(check_int_attr(node, "shrink_axis_mask", 1))
++ const NodeDef* shape = get_node(node->input(0));
++ CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1)
++ const NodeDef* reshape = get_node(shape->input(0));
++ CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
++ const NodeDef* concat = get_node(reshape->input(0));
++ CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3)
++ const NodeDef* fill = get_node(concat->input(1));
++ CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
++ const NodeDef* pack = get_node(fill->input(0));
++ CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)
++ const NodeDef* sub = get_node(pack->input(0));
++ CHECK_NODE_OK(IsSub(*sub) && sub->input_size() == 2)
++ const NodeDef* strided_slice = get_node(sub->input(0));
++ CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
++ strided_slice->input_size() == 4)
++ const NodeDef* cast = get_node(strided_slice->input(0));
++ CHECK_NODE_OK(IsCast(*cast))
++
++ auto nodes = graph->mutable_node();
++ NodeDef* fused_node = nodes->Add();
++ fused_node->set_name(node->name() + fusion_appendix);
++ fused_node->set_op(name());
++ fused_node->set_device(node->device());
++ fused_node->add_input(cast->input(0));
++ fused_node->add_input(concat->input(0));
++ fused_node->add_input(sub->input(1));
++ fused_node->add_input(reshape->input(1));
++ const NodeDef* pack_left = get_node(pack->input(0));
++ const NodeDef* pack_right = get_node(pack->input(1));
++ if (IsConstant(*pack_left) || IsHostConstant(*pack_left)) {
++ fused_node->add_input(pack->input(0));
++ } else if (IsConstant(*pack_right) || IsHostConstant(*pack_right)) {
++ fused_node->add_input(pack->input(1));
++ } else {
++ return false;
++ }
++ nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
++
++ VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
++ << fused_node->name();
++ replace_all_users_with(sub, 0, fused_node, 0, graph);
++ replace_all_users_with(node, 0, fused_node, 1, graph);
++ return true;
++ }
++};
++
++class KPFusedSparseSelectRewriter : public PatternRewriter {
++ public:
++ std::string name() const override { return "KPFusedSparseSelect"; }
++
++ bool match_and_rewrite(
++ const NodeDef* node, GraphDef* graph,
++ std::unordered_map<std::string, int>& node_indexes) override {
++ graph_ = graph;
++ indexes_ = &node_indexes;
++ CHECK_NODE_OK(IsConcat(*node) && node->input_size() == 3)
++ const NodeDef* select_0 = get_node(node->input(0));
++ CHECK_NODE_OK(IsSelect(*select_0) && select_0->input_size() == 3)
++ const NodeDef* select_1 = get_node(select_0->input(2));
++ CHECK_NODE_OK(IsSelect(*select_1) && select_1->input_size() == 3)
++ const NodeDef* fill = get_node(select_1->input(1));
++ CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
++ CHECK_NODE_OK(
++ check_const_value<float>(get_mutable_node(fill->input(1)), {1.0f}))
++ const NodeDef* cast = get_node(select_1->input(2));
++ CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
++ const NodeDef* equal = get_node(select_1->input(0));
++ CHECK_NODE_OK(IsEqual(*equal) && equal->input_size() == 2)
++ NodeDef* reshape = get_operand(equal, "Reshape");
++ CHECK_NODE_OK(reshape != nullptr && IsReshape(*reshape))
++ const NodeDef* greater = get_node(cast->input(0));
++ CHECK_NODE_OK(IsGreater(*greater) && greater->input_size() == 2)
++ const NodeDef* reshape_4 = get_node(greater->input(0));
++ CHECK_NODE_OK(IsReshape(*reshape_4) && reshape_4->input_size() == 2)
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(reshape_4->input(1)), {-1, 1}))
++ const NodeDef* equal_1 = get_node(select_0->input(0));
++ CHECK_NODE_OK(IsEqual(*equal_1) && equal_1->input_size() == 2)
++ NodeDef* reshape_1 = get_operand(equal_1, "Reshape");
++ CHECK_NODE_OK(reshape_1 != nullptr && IsReshape(*reshape_1))
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(reshape_1->input(1)), {-1, 1}))
++
++ // right branch
++ const NodeDef* select_2 = get_node(node->input(1));
++ CHECK_NODE_OK(IsSelect(*select_2) && select_2->input_size() == 3)
++ const NodeDef* equal_2 = get_node(select_2->input(0));
++ CHECK_NODE_OK(IsEqual(*equal_2) && equal_2->input_size() == 2)
++ const NodeDef* fill_1 = get_node(select_2->input(2));
++ CHECK_NODE_OK(IsFill(*fill_1) && fill_1->input_size() == 2)
++ CHECK_NODE_OK(
++ check_const_value<float>(get_mutable_node(fill_1->input(1)), {1.0f}))
++ const NodeDef* reshape_2 = get_operand(equal_2, "Reshape");
++ CHECK_NODE_OK(reshape_2 != nullptr && IsReshape(*reshape_2))
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(reshape_2->input(1)), {-1, 1}))
++
++ auto nodes = graph->mutable_node();
++ NodeDef* fused_node = nodes->Add();
++ fused_node->set_name(node->name() + fusion_appendix);
++ fused_node->set_op(name());
++ fused_node->set_device(node->device());
++ fused_node->add_input(reshape->input(0));
++ fused_node->add_input(reshape_1->input(0));
++ fused_node->add_input(reshape_2->input(0));
++ std::vector<const NodeDef*> const_inputs = {equal, equal_1, equal_2,
++ greater};
++ for (const NodeDef* const_node : const_inputs) {
++ const NodeDef* left = get_node(const_node->input(0));
++ const NodeDef* right = get_node(const_node->input(1));
++ if (IsConstant(*left) || IsHostConstant(*left)) {
++ fused_node->add_input(const_node->input(0));
++ } else if (IsConstant(*right) || IsHostConstant(*right)) {
++ fused_node->add_input(const_node->input(1));
++ } else {
++ return false;
++ }
++ }
++ nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
++
++ VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
++ << fused_node->name();
++ replace_all_users_with(reshape, 0, fused_node, 0, graph);
++ replace_all_users_with(select_0, 0, fused_node, 1, graph);
++ replace_all_users_with(node, 0, fused_node, 2, graph);
++ return true;
++ }
++};
++
++class KPFusedGatherRewriter : public PatternRewriter {
++ public:
++ std::string name() const override { return "KPFusedGather"; }
++
++ bool match_and_rewrite(
++ const NodeDef* node, GraphDef* graph,
++ std::unordered_map<std::string, int>& node_indexes) override {
++ graph_ = graph;
++ indexes_ = &node_indexes;
++ CHECK_NODE_OK(node->op() == "GatherV2" && node->input_size() == 3)
++ CHECK_NODE_OK(check_const_value<int>(get_mutable_node(node->input(2)), {0}))
++ const NodeDef* gather = get_node(node->input(0));
++ CHECK_NODE_OK(gather->op() == "GatherV2" &&
++ gather->input_size() == 3) // input:0
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
++ const NodeDef* unique = get_node(node->input(1)); // output:1
++ CHECK_NODE_OK(unique->op() == "Unique" && unique->input_size() == 1)
++ const NodeDef* unique_1 = get_node(unique->input(0));
++ CHECK_NODE_OK(unique_1->op() == "Unique" && unique_1->input_size() == 1)
++ const NodeDef* strided_slice = get_node(unique_1->input(0));
++ CHECK_NODE_OK(IsStridedSlice(*strided_slice)) // input:1 2
++ CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
++ CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
++ CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
++
++ auto nodes = graph->mutable_node();
++ NodeDef* fused_node = nodes->Add();
++ fused_node->set_name(node->name() + fusion_appendix);
++ fused_node->set_op(name());
++ fused_node->set_device(node->device());
++ fused_node->add_input(gather->input(0));
++ fused_node->add_input(strided_slice->input(0));
++ fused_node->add_input(strided_slice->input(1));
++ nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
++
++ VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
++ << fused_node->name();
++ replace_all_users_with(unique_1, 0, fused_node, 0, graph);
++ replace_all_users_with(unique_1, 1, fused_node, 1, graph);
++ replace_all_users_with(node, 0, fused_node, 2, graph);
++ return true;
++ }
++};
++
++class KPFusedSparseReshapeRewriter : public PatternRewriter {
++ public:
++ std::string name() const override { return "KPFusedSparseReshape"; }
++
++ bool match_and_rewrite(
++ const NodeDef* node, GraphDef* graph,
++ std::unordered_map<std::string, int>& node_indexes) override {
++ graph_ = graph;
++ indexes_ = &node_indexes;
++ CHECK_NODE_OK(node->op() == "SparseReshape" && node->input_size() == 3)
++ const NodeDef* concat = get_node(node->input(0));
++ CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3)
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(concat->input(2)), {-1}))
++ const NodeDef* reshape = get_node(concat->input(1));
++ CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(reshape->input(1)), {-1, 1}))
++ const NodeDef* strided_slice = get_node(reshape->input(0));
++ CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
++ strided_slice->input_size() == 4)
++ CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
++ CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
++ CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
++ const NodeDef* cast_1 = get_node(concat->input(0));
++ CHECK_NODE_OK(IsCast(*cast_1) && cast_1->input_size() == 1)
++ const NodeDef* reshape_1 = get_node(cast_1->input(0));
++ CHECK_NODE_OK(IsReshape(*reshape_1) && reshape_1->input_size() == 2)
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(reshape_1->input(1)), {-1, 1}))
++ const NodeDef* range = get_node(reshape_1->input(0));
++ CHECK_NODE_OK(range->op() == "Range" && range->input_size() == 3)
++ // Range start=0, delta=1
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(range->input(0)), {0}))
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(range->input(2)), {1}))
++ const NodeDef* cast = get_node(node->input(1));
++ CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
++ const NodeDef* pack = get_node(cast->input(0));
++ CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)
++ const NodeDef* strided_slice_1 = get_node(pack->input(0));
++ CHECK_NODE_OK(IsStridedSlice(*strided_slice_1) &&
++ strided_slice_1->input_size() == 4)
++ CHECK_NODE_OK(check_const_value<int>(
++ get_mutable_node(strided_slice_1->input(1)), {0}))
++ CHECK_NODE_OK(check_const_value<int>(
++ get_mutable_node(strided_slice_1->input(2)), {1}))
++ CHECK_NODE_OK(check_const_value<int>(
++ get_mutable_node(strided_slice_1->input(3)), {1}))
++ CHECK_NODE_OK(check_int_attr(strided_slice_1, "shrink_axis_mask", 1))
++ const NodeDef* shape = get_node(strided_slice_1->input(0));
++ CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1)
++
++ auto nodes = graph->mutable_node();
++ NodeDef* fused_node = nodes->Add();
++ fused_node->set_name(node->name() + fusion_appendix);
++ fused_node->set_op(name());
++ fused_node->set_device(node->device());
++ fused_node->add_input(shape->input(0));
++ fused_node->add_input(strided_slice->input(1));
++ fused_node->add_input(node->input(2));
++ fused_node->add_input(pack->input(1));
++ nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
++
++ VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
++ << fused_node->name();
++ replace_all_users_with(node, 0, fused_node, 0, graph);
++ replace_all_users_with(node, 1, fused_node, 1, graph);
++ return true;
++ }
++};
++
++class KPFusedEmbeddingActionIdGatherRewriter : public PatternRewriter {
++ public:
++ std::string name() const override { return "KPFusedEmbeddingActionIdGather"; }
++
++ bool match_and_rewrite(
++ const NodeDef* node, GraphDef* graph,
++ std::unordered_map<std::string, int>& node_indexes) override {
++ graph_ = graph;
++ indexes_ = &node_indexes;
++ CHECK_NODE_OK(IsConcat(*node) && node->input_size() == 3)
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(node->input(2)), {-1}))
++ const NodeDef* reshape = get_node(node->input(0));
++ CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
++ const NodeDef* gather = get_node(reshape->input(0));
++ const NodeDef* pack_1 = get_node(reshape->input(1));
++ CHECK_NODE_OK(IsPack(*pack_1) && pack_1->input_size() == 2)
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(pack_1->input(1)), {-1}))
++ CHECK_NODE_OK(gather->op() == "GatherV2" && gather->input_size() == 3)
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
++ const NodeDef* gather_1 = get_node(gather->input(0));
++ CHECK_NODE_OK(gather_1->op() == "GatherV2" && gather_1->input_size() == 3)
++ CHECK_NODE_OK(
++ check_const_value<int>(get_mutable_node(gather_1->input(2)), {0}))
++ const NodeDef* fill = get_node(node->input(1));
++ CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
++ CHECK_NODE_OK(check_const_value<int>(get_mutable_node(fill->input(1)), {0}))
++ const NodeDef* pack = get_node(fill->input(0));
++ CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)
++
++ auto nodes = graph->mutable_node();
++ NodeDef* fused_node = nodes->Add();
++ fused_node->set_name(node->name() + fusion_appendix);
++ fused_node->set_op(name());
++ fused_node->set_device(node->device());
++ fused_node->add_input(gather_1->input(1));
++ fused_node->add_input(gather_1->input(0));
++ fused_node->add_input(gather->input(1));
++ fused_node->add_input(pack->input(0));
++ fused_node->add_input(pack->input(1));
++ nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
++
++ VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
++ << fused_node->name();
++ replace_all_users_with(node, 0, fused_node, 0, graph);
++ return true;
++ }
++};
++
++void run_graph_optimization(GraphDef* graph) {
++ GraphOptimizer optimizer(graph);
++
++ const char* annc_fused_all = getenv("ANNC_FUSED_ALL");
++ const char* annc_fused_sps_stitch = getenv("ANNC_FUSED_SPS_STITCH");
++ const char* annc_fused_sps_reduce = getenv("ANNC_FUSED_SPS_REDUCE");
++ const char* annc_fused_emb_padding = getenv("ANNC_FUSED_EMD_PADDING");
++ const char* annc_fused_emb_padding_fast =
++ getenv("ANNC_FUSED_EMD_PADDING_FAST");
++ const char* annc_fused_sps_select = getenv("ANNC_FUSED_SPS_SELECT");
++ const char* annc_fused_gather = getenv("ANNC_FUSED_GATHER");
++ const char* annc_fused_sps_reshape = getenv("ANNC_FUSED_SPS_RESHAPE");
++ const char* annc_fused_emb_actionid_gather =
++ getenv("ANNC_FUSED_EMB_ACTIONID_GATHER");
++ const char* annc_fused_sps_reduce_nonzero =
++ getenv("ANNC_FUSED_SPS_REDUCE_NONZERO");
++
++ bool enable_all =
++ (annc_fused_all != nullptr) && strcmp(annc_fused_all, "1") == 0;
++
++ // ANNC_FUSED_ALL=1 enables the rewriters below; the last two remain opt-in
++ if (enable_all || (annc_fused_sps_stitch != nullptr &&
++ strcmp(annc_fused_sps_stitch, "1") == 0))
++ optimizer.register_rewriter(
++ std::make_unique<KPFusedSparseDynamicStitchRewriter>());
++ if (enable_all || (annc_fused_sps_reduce != nullptr &&
++ strcmp(annc_fused_sps_reduce, "1") == 0))
++ optimizer.register_rewriter(
++ std::make_unique<KPFusedSparseSegmentReduceRewriter>());
++ if (enable_all || (annc_fused_emb_padding_fast != nullptr &&
++ strcmp(annc_fused_emb_padding_fast, "1") == 0))
++ optimizer.register_rewriter(
++ std::make_unique<KPFusedEmbeddingPaddingFastRewriter>());
++ if (enable_all || (annc_fused_emb_padding != nullptr &&
++ strcmp(annc_fused_emb_padding, "1") == 0))
++ optimizer.register_rewriter(
++ std::make_unique<KPFusedEmbeddingPaddingRewriter>());
++ if (enable_all || (annc_fused_sps_select != nullptr &&
++ strcmp(annc_fused_sps_select, "1") == 0))
++ optimizer.register_rewriter(
++ std::make_unique<KPFusedSparseSelectRewriter>());
++ if (enable_all ||
++ (annc_fused_gather != nullptr && strcmp(annc_fused_gather, "1") == 0))
++ optimizer.register_rewriter(std::make_unique<KPFusedGatherRewriter>());
++ if (enable_all || (annc_fused_sps_reshape != nullptr &&
++ strcmp(annc_fused_sps_reshape, "1") == 0))
++ optimizer.register_rewriter(std::make_unique<KPFusedSparseReshapeRewriter>());
++ if (annc_fused_emb_actionid_gather != nullptr &&
++ strcmp(annc_fused_emb_actionid_gather, "1") == 0)
++ optimizer.register_rewriter(
++ std::make_unique<KPFusedEmbeddingActionIdGatherRewriter>());
++ if (annc_fused_sps_reduce_nonzero != nullptr &&
++ strcmp(annc_fused_sps_reduce_nonzero, "1") == 0)
++ optimizer.register_rewriter(
++ std::make_unique<KPFusedSparseSegmentReduceNonzeroRewriter>());
++ optimizer.optimize();
++}
++} // namespace annc
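
The rewriters above are all gated by environment variables, and ANNC_FUSED_ALL=1 covers every rewriter except KPFusedEmbeddingActionIdGather and KPFusedSparseSegmentReduceNonzero, which must be switched on individually. A minimal sketch of a caller (the driver function name is hypothetical) that enables everything before running the pass in-place:

    #include <cstdlib>

    #include "annc/tensorflow/graph_optimizer/graph_opt.h"

    // Hypothetical driver: enable every fusion, then rewrite the graph in-place.
    void optimize_for_serving(tensorflow::GraphDef* graph) {
      setenv("ANNC_FUSED_ALL", "1", /*overwrite=*/1);
      // Not covered by ANNC_FUSED_ALL; explicit opt-in required:
      setenv("ANNC_FUSED_EMB_ACTIONID_GATHER", "1", 1);
      setenv("ANNC_FUSED_SPS_REDUCE_NONZERO", "1", 1);
      annc::run_graph_optimization(graph);
    }
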
+diff --git a/annc/tensorflow/graph_optimizer/graph_opt.h b/annc/tensorflow/graph_optimizer/graph_opt.h
+new file mode 100644
+index 0000000..0ac0273
+--- /dev/null
++++ b/annc/tensorflow/graph_optimizer/graph_opt.h
+@@ -0,0 +1,165 @@
++#ifndef ANNC_TF_GRAPH_OPT_H_
++#define ANNC_TF_GRAPH_OPT_H_
++#include <cmath>
++#include <memory>
++#include <type_traits>
++#include <unordered_map>
++
++#include "tensorflow/core/grappler/graph_view.h"
++#include "tensorflow/core/grappler/grappler_item.h"
++#include "tensorflow/core/grappler/op_types.h"
++
++using namespace tensorflow;
++using namespace tensorflow::grappler;
++
++namespace annc {
++#define CHECK_NODE_OK(x) \
++ if (!(x)) { \
++ return false; \
++ }
++
++static const std::string fusion_appendix = "/kp_fused";
++
++void update_node_indexes(const GraphDef* graph,
++ std::unordered_map<std::string, int>& node_indexes);
++
++class PatternRewriter {
++ public:
++ PatternRewriter() {}
++ virtual ~PatternRewriter() = default;
++
++ virtual bool match_and_rewrite(
++ const NodeDef* node, GraphDef* graph,
++ std::unordered_map<std::string, int>& node_indexes) = 0;
++
++ virtual std::string name() const { return "PatternRewriter"; };
++
++ const NodeDef* get_node(const std::string& name);
++ NodeDef* get_mutable_node(const std::string& name);
++
++ NodeDef* get_operand(const NodeDef* node, std::string op_type);
++
++ const NodeDef* get_user(const NodeDef* node, int index,
++ const std::string& op_type);
++
++ void replace_all_users_with(const NodeDef* old_node, int old_index,
++ const NodeDef* new_node, int new_index,
++ GraphDef* graph);
++
++ bool check_input_dims(NodeDef* op, const std::string& output_name,
++ int dim_size) {
++ if (op->attr().count("_output_shapes")) {
++ int pos_index = 0;
++ size_t pos = output_name.find_last_of(':');
++ if (pos != std::string::npos) {
++ pos_index = std::stoi(output_name.substr(pos + 1));
++ }
++ const TensorShapeProto& shape =
++ op->attr().at("_output_shapes").list().shape(pos_index);
++ if (shape.dim_size() == dim_size) return true;
++ } else if (op->attr().count("shape")) {
++ const TensorShapeProto& shape = op->attr().at("shape").shape();
++ if (shape.dim_size() == dim_size) return true;
++ }
++ return false;
++ }
++
++ bool check_const_dims(NodeDef* op, int dim_size) {
++ if (!((IsConstant(*op) || IsHostConstant(*op)) &&
++ HasNodeAttr(*op, "value")))
++ return false;
++
++ TensorProto* tensor = (*op->mutable_attr())["value"].mutable_tensor();
++ const auto& shape = tensor->tensor_shape();
++ if (shape.dim_size() != static_cast<int>(dim_size)) return false;
++ return true;
++ }
++
++ bool check_const_shape(NodeDef* op, std::vector<int> dims) {
++ if (!((IsConstant(*op) || IsHostConstant(*op)) &&
++ HasNodeAttr(*op, "value")))
++ return false;
++
++ TensorProto* tensor = (*op->mutable_attr())["value"].mutable_tensor();
++ const auto& shape = tensor->tensor_shape();
++ if (shape.dim_size() != static_cast<int>(dims.size())) return false;
++ for (int i = 0; i < shape.dim_size(); ++i) {
++ if (shape.dim(i).size() != dims[i]) return false;
++ }
++ return true;
++ }
++
++ template <typename T>
++ bool check_const_value(NodeDef* op, std::vector<T> cmp) {
++ if (!((IsConstant(*op) || IsHostConstant(*op)) &&
++ HasNodeAttr(*op, "value")))
++ return false;
++
++ TensorProto* tensor = (*op->mutable_attr())["value"].mutable_tensor();
++ const auto& shape = tensor->tensor_shape();
++ int dim_size = 1;
++ for (int i = 0; i < shape.dim_size(); ++i) {
++ dim_size *= shape.dim(i).size();
++ }
++ if (dim_size < static_cast<int>(cmp.size())) return false;
++
++ if (std::is_same<T, float>::value) {
++ const float* data = tensor->mutable_float_val()->data();
++ if (data == nullptr)
++ data = reinterpret_cast<const float*>(tensor->tensor_content().data());
++ if (data == nullptr) return false;
++ for (int i = 0; i < static_cast<int>(cmp.size()); ++i) {
++ if (std::fabs(data[i] - cmp[i]) >= 1e-5f) return false;
++ }
++ } else if (std::is_same<T, int>::value) {
++ const int* data = tensor->mutable_int_val()->data();
++ if (data == nullptr)
++ data = reinterpret_cast<const int*>(tensor->tensor_content().data());
++ if (data == nullptr) return false;
++ for (int i = 0; i < static_cast<int>(cmp.size()); ++i) {
++ if (data[i] != cmp[i]) return false;
++ }
++ } else if (std::is_same<T, int64_t>::value) {
++ const int64_t* data = tensor->mutable_int64_val()->data();
++ if (data == nullptr)
++ data =
++ reinterpret_cast<const int64_t*>(tensor->tensor_content().data());
++ if (data == nullptr) return false;
++ for (int i = 0; i < static_cast<int>(cmp.size()); ++i) {
++ if (data[i] != cmp[i]) return false;
++ }
++ } else {
++ // unsupported data type
++ return false;
++ }
++ return true;
++ }
++
++ bool check_int_attr(const NodeDef* op, std::string name, int value) {
++ if (HasNodeAttr(*op, name)) {
++ AttrValue attr = op->attr().at(name);
++ if (attr.value_case() == AttrValue::kI && attr.i() == value) return true;
++ }
++ return false;
++ }
++
++ GraphDef* graph_;
++ std::unordered_map<std::string, int>* indexes_;
++};
++
++class GraphOptimizer {
++ public:
++ GraphOptimizer(GraphDef* graph) : graph_(graph) {}
++ virtual ~GraphOptimizer() = default;
++
++ void register_rewriter(std::unique_ptr<PatternRewriter> rewriter);
++
++ void optimize();
++
++ private:
++ GraphDef* graph_;
++ std::unordered_map<std::string, int> node_indexes_;
++ std::vector<std::unique_ptr<PatternRewriter>> rewriters_;
++};
++
++void run_graph_optimization(GraphDef* graph);
++} // namespace annc
++#endif // ANNC_TF_GRAPH_OPT_H_
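
The header above is the entire extension surface: a new fusion is a PatternRewriter subclass that validates the matched subgraph with CHECK_NODE_OK, appends a node named node->name() + fusion_appendix (optimize() relies on that suffix to resume iteration), swaps it into the matched node's slot, and rewires users via replace_all_users_with. A minimal sketch with a hypothetical op name and a deliberately trivial pattern:

    #include "annc/tensorflow/graph_optimizer/graph_opt.h"

    namespace annc {
    // Hypothetical rewriter: matches a lone Identity node and replaces it
    // with a same-arity fused op. Illustrates the expected protocol only.
    class KPIdentityExampleRewriter : public PatternRewriter {
     public:
      std::string name() const override { return "KPIdentityExample"; }

      bool match_and_rewrite(
          const NodeDef* node, GraphDef* graph,
          std::unordered_map<std::string, int>& node_indexes) override {
        graph_ = graph;
        indexes_ = &node_indexes;
        CHECK_NODE_OK(node->op() == "Identity" && node->input_size() == 1)

        auto nodes = graph->mutable_node();
        NodeDef* fused_node = nodes->Add();
        // optimize() looks the new node up under this suffixed name.
        fused_node->set_name(node->name() + fusion_appendix);
        fused_node->set_op(name());
        fused_node->set_device(node->device());
        fused_node->add_input(node->input(0));
        nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
        replace_all_users_with(node, 0, fused_node, 0, graph);
        return true;
      }
    };
    }  // namespace annc
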
+diff --git a/annc/tensorflow/kernels/embedding_fused_action_id_gather.cc b/annc/tensorflow/kernels/embedding_fused_action_id_gather.cc
+new file mode 100644
+index 0000000..db7c92c
+--- /dev/null
++++ b/annc/tensorflow/kernels/embedding_fused_action_id_gather.cc
+@@ -0,0 +1,144 @@
++/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
++
++Licensed under the Apache License, Version 2.0 (the "License");
++you may not use this file except in compliance with the License.
++You may obtain a copy of the License at
++
++ http://www.apache.org/licenses/LICENSE-2.0
++
++Unless required by applicable law or agreed to in writing, software
++distributed under the License is distributed on an "AS IS" BASIS,
++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++See the License for the specific language governing permissions and
++limitations under the License.
++==============================================================================*/
++
++#include "tensorflow/core/framework/op_kernel.h"
++#include "tensorflow/core/framework/tensor.h"
++#include "tensorflow/core/util/work_sharder.h"
++
++namespace tensorflow {
++
++template <typename Tindices>
++static void GatherV2Impl(OpKernelContext* context, const float* params_data, const TensorShape& params_shape,
++ const Tindices* indices_data, const TensorShape& indices_shape, int axis, Tensor* temp) {
++ TensorShape temp_shape;
++ const int P0 = params_shape.dim_size(0);
++ int P1 = 1;
++ for (int d = 0; d < indices_shape.dims(); ++d) {
++ temp_shape.AddDim(indices_shape.dim_size(d));
++ }
++
++ for (int d = 1; d < params_shape.dims(); ++d) {
++ temp_shape.AddDim(params_shape.dim_size(d));
++ P1 *= params_shape.dim_size(d);
++ }
++ OP_REQUIRES_OK(context, context->allocate_temp(DT_FLOAT, temp_shape, temp));
++ VLOG(1) << "temp shape: " << temp->shape().DebugString();
++
++ const int num_indices = indices_shape.num_elements();
++ float* temp_data = temp->flat<float>().data();
++ VLOG(1) << "num_indices : " << num_indices;
++ OP_REQUIRES(context, axis == 0, errors::InvalidArgument("only axis 0 is supported"));
++ const int slice_size = P1;
++ for (int i = 0; i < num_indices; ++i) {
++ Tindices idx = indices_data[i];
++ OP_REQUIRES(context, (idx >= 0 && idx < P0), errors::InvalidArgument("GatherV2 axis=0: index out of range"));
++ std::memcpy(
++ temp_data + i * slice_size, params_data + idx * slice_size, sizeof(float) * slice_size
++ );
++ }
++}
++
++
++template <typename Tindices1, typename Tindices2>
++class KPFusedEmbeddingActionIdGatherOp : public OpKernel {
++public:
++ explicit KPFusedEmbeddingActionIdGatherOp(OpKernelConstruction* context) : OpKernel(context) {}
++
++ void Compute(OpKernelContext* context) override {
++ // Grab the input tensor
++ const Tensor& indices1 = context->input(0);
++ const Tensor& params = context->input(1);
++ const Tensor& indices2 = context->input(2);
++ const Tensor& pack_dim = context->input(3);
++
++ const Tensor& pack = context->input(4);
++
++ VLOG(1) << "indices1 shape: " << indices1.shape().DebugString();
++ VLOG(1) << "params shape: " << params.shape().DebugString();
++ VLOG(1) << "indices2 shape: " << indices2.shape().DebugString();
++ OP_REQUIRES(
++ context,
++ TensorShapeUtils::IsMatrix(indices1.shape()),
++ errors::InvalidArgument("indices1 dims must = 2")
++ );
++ OP_REQUIRES(
++ context,
++ TensorShapeUtils::IsMatrix(indices2.shape()),
++ errors::InvalidArgument("indices2 dims must = 2")
++ );
++ OP_REQUIRES(
++ context,
++ TensorShapeUtils::IsMatrix(params.shape()),
++ errors::InvalidArgument("params dims must = 2")
++ );
++ OP_REQUIRES(
++ context,
++ TensorShapeUtils::IsScalar(pack_dim.shape()),
++ errors::InvalidArgument("pack_dim is scalar")
++ );
++ OP_REQUIRES(
++ context,
++ TensorShapeUtils::IsScalar(pack.shape()),
++ errors::InvalidArgument("pack const is scalar")
++ );
++
++ Tensor temp;
++ GatherV2Impl<Tindices1>(context, params.flat<float>().data(), params.shape(), indices1.flat<Tindices1>().data(),
++ indices1.shape(), 0, &temp);
++ Tensor temp1;
++ GatherV2Impl<Tindices2>(context, temp.flat<float>().data(), temp.shape(), indices2.flat<Tindices2>().data(),
++ indices2.shape(), 0, &temp1);
++ int pack_size = pack_dim.scalar<int32>()();
++ int pack_const = pack.scalar<int32>()();
++ OP_REQUIRES(context, pack_size > 0, errors::InvalidArgument("pack_size must > 0"));
++ int a_reshaped_cols = temp1.NumElements() / pack_size;
++ auto a_reshaped = temp1.shaped<float, 2>({pack_size, a_reshaped_cols});
++ Tensor* output;
++ int output_cols = a_reshaped_cols + pack_const;
++ OP_REQUIRES_OK(context,
++ context->allocate_output(0, TensorShape({pack_size, output_cols}), &output));
++ auto a_reshaped_data = a_reshaped.data();
++ auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
++ const int64 cost_per_unit = a_reshaped_cols + pack_const;
++ auto work = [&](int64 start_row, int64 end_row) {
++ float* base = output->matrix<float>().data();
++ for (int64 row = start_row; row < end_row; ++row) {
++ float* dst_row = base + row * (a_reshaped_cols + pack_const);
++ std::memcpy(
++ dst_row, a_reshaped_data + row * a_reshaped_cols, sizeof(float) * a_reshaped_cols
++ );
++ std::memset(
++ dst_row + a_reshaped_cols, 0, sizeof(float) * pack_const
++ );
++ }
++ };
++ Shard(worker_threads->num_threads, worker_threads->workers, pack_size,
++ cost_per_unit, work);
++ }
++};
++
++#define REGISTER_CPU_KERNEL(Tindices1, Tindices2) \
++ REGISTER_KERNEL_BUILDER(Name("KPFusedEmbeddingActionIdGather") \
++ .Device(DEVICE_CPU) \
++ .TypeConstraint<Tindices1>("Tindices1") \
++ .TypeConstraint<Tindices2>("Tindices2"), \
++ KPFusedEmbeddingActionIdGatherOp<Tindices1, Tindices2>);
++
++REGISTER_CPU_KERNEL(int64, int32)
++REGISTER_CPU_KERNEL(int32, int32)
++REGISTER_CPU_KERNEL(int64, int64)
++REGISTER_CPU_KERNEL(int32, int64)
++
++}
+\ No newline at end of file
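
Before the tests: the kernel above chains two axis-0 gathers, reshapes the result to pack_dim rows, and right-pads each row with pack zero columns. Traced on the NormalCase inputs from the test file that follows:

    // params   = [[1,2],[3,4],[5,6]]    shape {3,2}
    // indices1 = [[0],[2]]   -> gather rows 0,2 of params: [[1,2],[5,6]]
    // indices2 = [[1],[0]]   -> gather rows 1,0 of that:   [[5,6],[1,2]]
    // pack_dim = 2, pack = 1 -> reshape to {2,2}, append one zero column
    // output   = [[5,6,0],[1,2,0]]      shape {2,3}
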
+diff --git a/annc/tensorflow/kernels/embedding_fused_action_id_gather_test.cc b/annc/tensorflow/kernels/embedding_fused_action_id_gather_test.cc
+new file mode 100644
+index 0000000..11067bf
+--- /dev/null
++++ b/annc/tensorflow/kernels/embedding_fused_action_id_gather_test.cc
+@@ -0,0 +1,289 @@
++/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
++ *
++ * Licensed under the Apache License, Version 2.0 (the "License");
++ * you may not use this file except in compliance with the License.
++ * You may obtain a copy of the License at
++ *
++ * http://www.apache.org/licenses/LICENSE-2.0
++ *
++ * Unless required by applicable law or agreed to in writing, software
++ * distributed under the License is distributed on an "AS IS" BASIS,
++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++ * See the License for the specific language governing permissions and
++ * limitations under the License.
++ * ==============================================================================*/
++
++#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
++#include "tensorflow/core/framework/allocator.h"
++#include "tensorflow/core/framework/fake_input.h"
++#include "tensorflow/core/framework/node_def_builder.h"
++#include "tensorflow/core/framework/op_kernel.h"
++#include "tensorflow/core/framework/tensor.h"
++#include "tensorflow/core/framework/types.h"
++#include "tensorflow/core/framework/types.pb.h"
++#include "tensorflow/core/graph/testlib.h"
++#include "tensorflow/core/kernels/ops_testutil.h"
++#include "tensorflow/core/kernels/ops_util.h"
++#include "tensorflow/core/lib/core/status_test_util.h"
++#include "tensorflow/core/lib/gtl/array_slice.h"
++#include "tensorflow/core/lib/random/simple_philox.h"
++#include "tensorflow/core/lib/strings/str_util.h"
++#include "tensorflow/core/platform/test.h"
++#include "tensorflow/core/platform/test_benchmark.h"
++
++namespace tensorflow {
++
++class KPFusedEmbeddingActionIdGatherTest : public OpsTestBase {
++ protected:
++ void MakeOp(DataType indices1_type, DataType indices2_type) {
++ TF_ASSERT_OK(NodeDefBuilder("fused_embedding_action_id_gather",
++ "KPFusedEmbeddingActionIdGather")
++ .Input(FakeInput(indices1_type)) // indices1
++ .Input(FakeInput(DT_FLOAT)) // params
++ .Input(FakeInput(indices2_type)) // indices2
++ .Input(FakeInput(DT_INT32)) // pack_dim
++ .Input(FakeInput(DT_INT32)) // pack
++ .Finalize(node_def()));
++ TF_ASSERT_OK(InitOp());
++ }
++
++ template <typename Tindices1, typename Tindices2>
++ Status FeedAndRun(const std::vector<Tindices1>& indices1_data,
++ const TensorShape& indices1_shape,
++ const std::vector<float>& params_data,
++ const TensorShape& params_shape,
++ const std::vector<Tindices2>& indices2_data,
++ const TensorShape& indices2_shape, int pack_dim_value,
++ int pack_value) {
++ inputs_.clear();
++ input_types_.clear();
++
++ MakeOp(DataTypeToEnum<Tindices1>::v(), DataTypeToEnum<Tindices2>::v());
++ AddInputFromArray<Tindices1>(indices1_shape, indices1_data);
++ AddInputFromArray<float>(params_shape, params_data);
++ AddInputFromArray<Tindices2>(indices2_shape, indices2_data);
++ AddInputFromArray<int32>(TensorShape({}), {pack_dim_value});
++ AddInputFromArray<int32>(TensorShape({}), {pack_value});
++ return RunOpKernel();
++ }
++};
++
++TEST_F(KPFusedEmbeddingActionIdGatherTest, NormalCase) {
++ std::vector<int64> indices1_data = {0, 2};
++ TensorShape indices1_shape({2, 1});
++
++ std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
++ TensorShape params_shape({3, 2});
++
++ std::vector<int32> indices2_data = {1, 0};
++ TensorShape indices2_shape({2, 1});
++
++ int pack_dim_value = 2;
++ int pack_value = 1;
++
++ TF_ASSERT_OK((FeedAndRun<int64, int32>(
++ indices1_data, indices1_shape, params_data, params_shape, indices2_data,
++ indices2_shape, pack_dim_value, pack_value)));
++
++ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
++ test::FillValues<float>(&expected, {5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f});
++ test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
++}
++
++TEST_F(KPFusedEmbeddingActionIdGatherTest, DifferentIndexTypes) {
++ // Tindices1 = int64, Tindices2 = int32
++ {
++ std::vector<int64> indices1 = {0, 2};
++ std::vector<int32> indices2 = {1, 0};
++ TF_ASSERT_OK((FeedAndRun<int64, int32>(indices1, {2, 1},
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
++ {3, 2}, indices2, {2, 1}, 2, 1)));
++ test::ExpectTensorNear<float>(
++ *GetOutput(0),
++ test::AsTensor<float>({5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f}, {2, 3}),
++ 1e-5);
++ }
++
++ // Tindices1 = int32, Tindices2 = int32
++ {
++ std::vector<int32> indices1 = {0, 2};
++ std::vector<int32> indices2 = {1, 0};
++ TF_ASSERT_OK((FeedAndRun<int32, int32>(indices1, {2, 1},
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
++ {3, 2}, indices2, {2, 1}, 2, 1)));
++ test::ExpectTensorNear<float>(
++ *GetOutput(0),
++ test::AsTensor<float>({5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f}, {2, 3}),
++ 1e-5);
++ }
++
++ // Tindices1 = int64, Tindices2 = int64
++ {
++ std::vector<int64> indices1 = {0, 2};
++ std::vector<int64> indices2 = {1, 0};
++ TF_ASSERT_OK((FeedAndRun<int64, int64>(indices1, {2, 1},
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
++ {3, 2}, indices2, {2, 1}, 2, 1)));
++ test::ExpectTensorNear<float>(
++ *GetOutput(0),
++ test::AsTensor<float>({5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f}, {2, 3}),
++ 1e-5);
++ }
++
++ // Tindices1 = int32, Tindices2 = int64
++ {
++ std::vector<int32> indices1 = {0, 2};
++ std::vector<int64> indices2 = {1, 0};
++ TF_ASSERT_OK((FeedAndRun<int32, int64>(indices1, {2, 1},
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
++ {3, 2}, indices2, {2, 1}, 2, 1)));
++ test::ExpectTensorNear<float>(
++ *GetOutput(0),
++ test::AsTensor<float>({5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f}, {2, 3}),
++ 1e-5);
++ }
++}
++
++TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidIndices1Dims) {
++ MakeOp(DT_INT64, DT_INT32);
++
++ std::vector<int64> indices1_data = {0, 2};
++ AddInputFromArray<int64>(TensorShape({2}), indices1_data);
++
++ std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
++ AddInputFromArray<float>(TensorShape({3, 2}), params_data);
++
++ std::vector<int32> indices2_data = {1, 0};
++ AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
++
++ AddInputFromArray<int32>(TensorShape({}), {2});
++ AddInputFromArray<int32>(TensorShape({}), {1});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.ToString(), "indices1 dims must = 2")) << s;
++}
++
++TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidIndices2Dims) {
++ MakeOp(DT_INT64, DT_INT32);
++
++ std::vector<int64> indices1_data = {0, 2};
++ AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
++
++ std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
++ AddInputFromArray<float>(TensorShape({3, 2}), params_data);
++
++ std::vector<int32> indices2_data = {1, 0};
++ AddInputFromArray<int32>(TensorShape({2}), indices2_data);
++
++ AddInputFromArray<int32>(TensorShape({}), {2});
++ AddInputFromArray<int32>(TensorShape({}), {1});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.ToString(), "indices2 dims must = 2")) << s;
++}
++
++TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidParamsDims) {
++ MakeOp(DT_INT64, DT_INT32);
++
++ std::vector<int64> indices1_data = {0, 2};
++ AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
++
++ std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f};
++ AddInputFromArray<float>(TensorShape({4}), params_data);
++
++ std::vector<int32> indices2_data = {1, 0};
++ AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
++
++ AddInputFromArray<int32>(TensorShape({}), {2});
++ AddInputFromArray<int32>(TensorShape({}), {1});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.ToString(), "params dims must = 2")) << s;
++}
++
++TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidPackDimDims) {
++ MakeOp(DT_INT64, DT_INT32);
++
++ std::vector<int64> indices1_data = {0, 2};
++ AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
++
++ std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
++ AddInputFromArray<float>(TensorShape({3, 2}), params_data);
++
++ std::vector<int32> indices2_data = {1, 0};
++ AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
++
++ AddInputFromArray<int32>(TensorShape({1}), {2});
++ AddInputFromArray<int32>(TensorShape({}), {1});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.ToString(), "pack_dim is scalar")) << s;
++}
++
++TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidPackDims) {
++ MakeOp(DT_INT64, DT_INT32);
++
++ std::vector<int64> indices1_data = {0, 2};
++ AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
++
++ std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
++ AddInputFromArray<float>(TensorShape({3, 2}), params_data);
++
++ std::vector<int32> indices2_data = {1, 0};
++ AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
++
++ AddInputFromArray<int32>(TensorShape({}), {2});
++ AddInputFromArray<int32>(TensorShape({1}), {1});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.ToString(), "pack const is scalar")) << s;
++}
++
++TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidPackSize) {
++ MakeOp(DT_INT64, DT_INT32);
++
++ std::vector<int64> indices1_data = {0, 2};
++ AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
++
++ std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
++ AddInputFromArray<float>(TensorShape({3, 2}), params_data);
++
++ std::vector<int32> indices2_data = {1, 0};
++ AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
++
++ AddInputFromArray<int32>(TensorShape({}), {0});
++ AddInputFromArray<int32>(TensorShape({}), {1});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.ToString(), "pack_size must > 0")) << s;
++}
++
++TEST_F(KPFusedEmbeddingActionIdGatherTest, IndexOutOfRange) {
++ MakeOp(DT_INT64, DT_INT32);
++
++ std::vector<int64> indices1_data = {0, 5};
++ AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
++
++ std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
++ AddInputFromArray<float>(TensorShape({3, 2}), params_data);
++
++ std::vector<int32> indices2_data = {1, 0};
++ AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
++
++ AddInputFromArray<int32>(TensorShape({}), {2});
++ AddInputFromArray<int32>(TensorShape({}), {1});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(
++ absl::StrContains(s.ToString(), "GatherV2 axis=0: index out of range"))
++ << s;
++}
++
++} // namespace tensorflow
+diff --git a/annc/tensorflow/kernels/embedding_fused_gather.cc b/annc/tensorflow/kernels/embedding_fused_gather.cc
+new file mode 100644
+index 0000000..c09d1ce
+--- /dev/null
++++ b/annc/tensorflow/kernels/embedding_fused_gather.cc
+@@ -0,0 +1,90 @@
++/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
++
++Licensed under the Apache License, Version 2.0 (the "License");
++you may not use this file except in compliance with the License.
++You may obtain a copy of the License at
++
++ http://www.apache.org/licenses/LICENSE-2.0
++
++Unless required by applicable law or agreed to in writing, software
++distributed under the License is distributed on an "AS IS" BASIS,
++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++See the License for the specific language governing permissions and
++limitations under the License.
++==============================================================================*/
++
++#include "tensorflow/core/framework/op_kernel.h"
++#include "tensorflow/core/framework/tensor.h"
++#include "tensorflow/core/util/work_sharder.h"
++
++using namespace tensorflow;
++
++class KPFusedGather : public OpKernel {
++ public:
++ explicit KPFusedGather(OpKernelConstruction* context) : OpKernel(context) { }
++
++ void Compute(OpKernelContext* context) override {
++ const Tensor& data = context->input(0);
++ const Tensor& keys = context->input(1);
++ const Tensor& begin = context->input(2);
++ VLOG(1) << "Embedding table size: " << data.shape().DebugString();
++ VLOG(1) << "Input key shape: " << keys.shape().DebugString();
++ VLOG(1) << "Slice begin value: " << begin.DebugString();
++
++    OP_REQUIRES(context,
++                TensorShapeUtils::IsMatrix(keys.shape()),
++                errors::InvalidArgument("Input key must be 2D"));
++    OP_REQUIRES(context,
++                TensorShapeUtils::IsMatrix(data.shape()),
++                errors::InvalidArgument("Embedding table shape must be 2D"));
++    OP_REQUIRES(context, begin.NumElements() == 2,
++                errors::InvalidArgument("begin size must equal keys rank (2)"));
++    int32 col = begin.flat<int32>().data()[1];
++    OP_REQUIRES(context, col < keys.dim_size(1),
++                errors::InvalidArgument("slice cols out of keys range"));
++
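++    // Outputs: 0 = unique key values, 1 = per-row index into output 0,
++    // 2 = embedding rows gathered for the unique keys.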
++ Tensor* out_indices = nullptr;
++ OP_REQUIRES_OK(context,
++ context->allocate_output(
++ 1, TensorShape({static_cast<int32>(keys.dim_size(0))}), &out_indices));
++ int32 *out_indices_data = out_indices->flat<int32>().data();
++
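++    // Deduplicate the selected key column in first-occurrence order and
++    // record, for every input row, the position of its key in unique_values.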
++ auto keys_mat = keys.matrix<int64>();
++ std::vector<int64_t> unique_values;
++ std::unordered_map<int64_t, int32_t> value_to_index;
++ int current_index = 0;
++ for (int64_t i = 0; i < keys.dim_size(0); ++i) {
++ auto it = value_to_index.find(keys_mat(i, col));
++ if (it == value_to_index.end()) {
++ value_to_index[keys_mat(i, col)] = current_index;
++ unique_values.push_back(keys_mat(i, col));
++ out_indices_data[i] = current_index;
++ ++current_index;
++ } else {
++ out_indices_data[i] = it->second;
++ }
++ }
++
++ Tensor* out_unique_value = nullptr;
++ OP_REQUIRES_OK(context,
++ context->allocate_output(
++ 0, TensorShape({static_cast<int32>(unique_values.size())}), &out_unique_value));
++ std::memcpy(out_unique_value->data(), unique_values.data(), unique_values.size() * sizeof(int64_t));
++
++ Tensor* out_data = nullptr;
++ int embedding_dims = data.dim_size(1);
++ OP_REQUIRES_OK(context,
++ context->allocate_output(
++ 2, TensorShape({static_cast<int32>(unique_values.size()), embedding_dims}), &out_data));
++
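++    // Copy one embedding row from the table for each unique key.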
++ const float *data_mat = data.flat<float>().data();
++    for (int64_t cur_row = 0;
++         cur_row < static_cast<int64_t>(unique_values.size()); ++cur_row) {
++      int64_t idx = unique_values[cur_row];
++      OP_REQUIRES(context, idx >= 0 && idx < data.dim_size(0),
++                  errors::InvalidArgument("idx out of table range"));
++ const float* src = data_mat + idx * embedding_dims;
++ float* dst = out_data->flat<float>().data() + cur_row * embedding_dims;
++ std::memcpy(dst, src, embedding_dims * sizeof(float));
++ }
++ }
++};
++
++REGISTER_KERNEL_BUILDER(Name("KPFusedGather").Device(DEVICE_CPU),
++ KPFusedGather);
+\ No newline at end of file
+diff --git a/annc/tensorflow/kernels/embedding_fused_gather_test.cc b/annc/tensorflow/kernels/embedding_fused_gather_test.cc
+new file mode 100644
+index 0000000..c947709
+--- /dev/null
++++ b/annc/tensorflow/kernels/embedding_fused_gather_test.cc
+@@ -0,0 +1,186 @@
++/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
++
++Licensed under the Apache License, Version 2.0 (the "License");
++you may not use this file except in compliance with the License.
++You may obtain a copy of the License at
++
++ http://www.apache.org/licenses/LICENSE-2.0
++
++Unless required by applicable law or agreed to in writing, software
++distributed under the License is distributed on an "AS IS" BASIS,
++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++See the License for the specific language governing permissions and
++limitations under the License.
++==============================================================================*/
++
++#include "tensorflow/core/framework/fake_input.h"
++#include "tensorflow/core/framework/node_def_builder.h"
++#include "tensorflow/core/framework/tensor.h"
++#include "tensorflow/core/framework/tensor_testutil.h"
++#include "tensorflow/core/kernels/ops_testutil.h"
++#include "tensorflow/core/lib/core/status_test_util.h"
++
++namespace {
++using tensorflow::AllocatorAttributes;
++using tensorflow::DT_FLOAT;
++using tensorflow::DT_INT32;
++using tensorflow::DT_INT64;
++using tensorflow::int64;
++using tensorflow::int32;
++using tensorflow::NodeDefBuilder;
++using tensorflow::OpsTestBase;
++using tensorflow::Status;
++using tensorflow::Tensor;
++using tensorflow::TensorShape;
++using tensorflow::test::ExpectClose;
++using tensorflow::test::FillValues;
++using tensorflow::test::AsTensor;
++using tensorflow::test::ExpectTensorEqual;
++
++class KPFusedGatherTest : public OpsTestBase {
++ protected:
++ void RunValidCase(const TensorShape& data_shape,
++ const TensorShape& slice_shape,
++ const std::vector<int32>& begin_val,
++ const std::vector<int64>& slice_data,
++ const std::vector<float>& data_data,
++ const std::vector<int64>& expected_unique,
++ const std::vector<int32>& expected_indices,
++ const std::vector<float>& expected_output_data) {
++ TF_EXPECT_OK(NodeDefBuilder("kp_fused_gather", "KPFusedGather")
++ .Input(FakeInput(DT_FLOAT))
++ .Input(FakeInput(DT_INT64))
++ .Input(FakeInput(DT_INT32))
++ .Finalize(node_def()));
++ TF_EXPECT_OK(InitOp());
++
++ AddInputFromArray<float>(data_shape, data_data);
++ AddInputFromArray<int64>(slice_shape, slice_data);
++ AddInputFromArray<int32>(TensorShape({2}), begin_val);
++
++ TF_ASSERT_OK(RunOpKernel());
++
++ const Tensor& out_unique = *GetOutput(0);
++ const Tensor& out_indices = *GetOutput(1);
++ const Tensor& out_data = *GetOutput(2);
++
++    // Verify output 0: unique_values
++ Tensor expected_unique_tensor(
++ allocator(), DT_INT64,
++ TensorShape({static_cast<int64>(expected_unique.size())})
++ );
++ FillValues<int64>(&expected_unique_tensor, expected_unique);
++ ExpectTensorEqual<int64>(expected_unique_tensor, out_unique);
++
++    // Verify output 1: indices
++ Tensor expected_indices_tensor(
++ allocator(), DT_INT32,
++ TensorShape({static_cast<int64_t>(expected_indices.size())})
++ );
++ FillValues<int32>(&expected_indices_tensor, expected_indices);
++ ExpectTensorEqual<int32>(expected_indices_tensor, out_indices);
++
++    // Verify output 2: out_data
++ Tensor expected_data_tensor(allocator(), DT_FLOAT,
++ TensorShape({static_cast<int64>(expected_unique.size()), 12}));
++ FillValues<float>(&expected_data_tensor, expected_output_data);
++    ExpectClose(expected_data_tensor, out_data);  // use ExpectClose for floats
++ }
++
++ Status RunOpExpectFailure(const TensorShape& data_shape,
++ const TensorShape& slice_shape,
++ const std::vector<int32>& begin_val,
++ const std::vector<int64>& slice_data,
++ const std::vector<float>& data_data) {
++ TF_CHECK_OK(NodeDefBuilder("kp_fused_gather", "KPFusedGather")
++ .Input(FakeInput(DT_FLOAT))
++ .Input(FakeInput(DT_INT64))
++ .Input(FakeInput(DT_INT32))
++ .Finalize(node_def()));
++ TF_CHECK_OK(InitOp());
++
++ AddInputFromArray<float>(data_shape, data_data);
++ AddInputFromArray<int64>(slice_shape, slice_data);
++ AddInputFromArray<int32>(TensorShape({2}), begin_val);
++
++ return RunOpKernel();
++ }
++};
++
++// Positive test: normal input
++TEST_F(KPFusedGatherTest, Valid_NormalInput) {
++ RunValidCase(
++ TensorShape({2, 12}), // data shape
++ TensorShape({4, 3}), // slice_input shape
++      {0, 1},              // begin[1] = 1 -> select column 1
++ {1, 1, 3,
++ 0, 1, 5,
++ 1, 0, 7,
++       0, 1, 9},           // slice_input data
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
++ 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f},
++ {1, 0}, // unique values from col=1
++ {0, 0, 1, 0}, // indices mapping
++ {13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, // data[1]
++ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} // data[0]
++ );
++}
++
++// data is not 2-D
++TEST_F(KPFusedGatherTest, Invalid_DataDimsNot2) {
++ std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
++ Status s = RunOpExpectFailure(
++      TensorShape({4}),  // data is not 2-D
++ TensorShape({2, 2}),
++ {0, 0},
++ {0, 1, 2, 3},
++ data
++ );
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "Embedding table shape must be 2D"));
++}
++
++// key is not 2-D
++TEST_F(KPFusedGatherTest, Invalid_SliceInputDimsNot2) {
++ std::vector<float> data(2 * 12, 1.0f);
++ Status s = RunOpExpectFailure(
++ TensorShape({2, 12}),
++ TensorShape({4}), // 1D slice_input
++ {0, 0},
++ {0, 1, 2, 3},
++ data
++ );
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "Input key must be 2D"));
++}
++
++// begin[1] out of column range
++TEST_F(KPFusedGatherTest, Invalid_BeginColOutOfRange) {
++ std::vector<float> data(2 * 12, 1.0f);
++ Status s = RunOpExpectFailure(
++ TensorShape({2, 12}),
++ TensorShape({2, 2}),
++      {0, 2},  // begin[1] = 2, but there are only 2 columns (indices 0 and 1)
++ {0, 1, 2, 3},
++ data
++ );
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(),"slice cols out of keys range"));
++}
++
++// gather index exceeds the number of data rows
++TEST_F(KPFusedGatherTest, Invalid_IndexOutOfRangeInData) {
++ std::vector<float> data(2 * 12, 1.0f);
++ Status s = RunOpExpectFailure(
++ TensorShape({2, 12}),
++ TensorShape({2, 2}),
++ {0, 0},
++ {0, 1,
++       2, 3},  // index 2 exceeds the data rows (only rows 0 and 1 exist)
++ data
++ );
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(),"idx out of table range"));
++}
++
++}  // namespace
+\ No newline at end of file
+diff --git a/annc/tensorflow/kernels/embedding_fused_padding.cc b/annc/tensorflow/kernels/embedding_fused_padding.cc
+new file mode 100644
+index 0000000..a0f14b2
+--- /dev/null
++++ b/annc/tensorflow/kernels/embedding_fused_padding.cc
+@@ -0,0 +1,126 @@
++/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
++
++Licensed under the Apache License, Version 2.0 (the "License");
++you may not use this file except in compliance with the License.
++You may obtain a copy of the License at
++
++ http://www.apache.org/licenses/LICENSE-2.0
++
++Unless required by applicable law or agreed to in writing, software
++distributed under the License is distributed on an "AS IS" BASIS,
++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++See the License for the specific language governing permissions and
++limitations under the License.
++==============================================================================*/
++
++#include <cstring>
++#include <vector>
++
++#include "tensorflow/core/framework/common_shape_fns.h"
++#include "tensorflow/core/framework/shape_inference.h"
++#include "tensorflow/core/framework/types.h"
++#include "tensorflow/core/framework/tensor.h"
++#include "tensorflow/core/framework/op_kernel.h"
++
++namespace tensorflow {
++
++using shape_inference::InferenceContext;
++using shape_inference::ShapeHandle;
++
++class KPFusedEmbeddingPaddingOp : public OpKernel {
++public:
++ explicit KPFusedEmbeddingPaddingOp(OpKernelConstruction* context) : OpKernel(context) {
++ fast_ = (type_string() == "KPFusedEmbeddingPaddingFast");
++ }
++
++ void Compute(OpKernelContext* context) override {
++ // Grab the input tensor
++ const Tensor& origin_shape = context->input(0);
++ const Tensor& input = context->input(1);
++ const Tensor& input_rows = context->input(2);
++ const Tensor& reshape_sizes = context->input(3);
++
++ const Tensor& pack = context->input(4);
++
++ VLOG(1) << "Input shape: " << input.shape().DebugString();
++ OP_REQUIRES(context,
++ TensorShapeUtils::IsVector(origin_shape.shape()),
++ errors::InvalidArgument("origin_shape dims must 1D, not ", origin_shape.shape().DebugString())
++ );
++ OP_REQUIRES(context,
++ origin_shape.NumElements() == 2,
++ errors::InvalidArgument("origin_shape NumElements must == 2, not ", origin_shape.NumElements())
++ );
++ OP_REQUIRES(context,
++ TensorShapeUtils::IsMatrix(input.shape()),
++ errors::InvalidArgument("input dims must 2D, not ", input.shape().DebugString()));
++ OP_REQUIRES(context,
++ TensorShapeUtils::IsScalar(input_rows.shape()),
++ errors::InvalidArgument("input_rows must be a scalar")
++ );
++ OP_REQUIRES(context,
++ TensorShapeUtils::IsVector(reshape_sizes.shape()),
++ errors::InvalidArgument("sizes input must be 1-D, not ", reshape_sizes.shape().DebugString())
++ );
++ OP_REQUIRES(context,
++ reshape_sizes.NumElements() == 2,
++ errors::InvalidArgument("reshape_sizes NumElements must == 2"));
++
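++    // Zero-pads the pooled rows back up to the original table size and views
++    // the result as [-1, reshape_cols]; the fast variant only reports the
++    // resulting row counts without materializing the padded tensor.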
++ int input_rows_value = input_rows.scalar<int32>()();
++ int padding_rows = static_cast<int32>(origin_shape.flat<int64>()(0)) - input_rows_value;
++ auto reshape_cols = reshape_sizes.flat<int32>()(1);
++ int output_rows = padding_rows + input.dim_size(0);
++ int output_cols = input.dim_size(1);
++ OP_REQUIRES(context,
++ padding_rows >= 0,
++ errors::InvalidArgument("Pooling size(", input_rows_value,
++ ") is greater than Input size(", static_cast<int32>(origin_shape.flat<int64>()(0)), ")"));
++ OP_REQUIRES(context,
++ reshape_cols > 0,
++ errors::InvalidArgument("reshape_cols must > 0"));
++ OP_REQUIRES(context,
++ reshape_sizes.flat<int32>()(0) == -1,
++ errors::InvalidArgument("reshape[0] is not -1"));
++ OP_REQUIRES(context,
++ pack.scalar<int32>()() == output_cols,
++ errors::InvalidArgument("pack(", pack.scalar<int32>()(), ") is not equal to embedding dims"));
++
++ Tensor* output0 = nullptr;
++ Tensor* output1 = nullptr;
++ OP_REQUIRES_OK(context,
++ context->allocate_output(0, TensorShape({}),
++ &output0));
++ output0->scalar<int32>()() = padding_rows;
++ OP_REQUIRES(context,
++ output_rows * output_cols % reshape_cols == 0,
++ errors::InvalidArgument("padding cannot reshape to [-1, ", reshape_cols, "]")
++ );
++ int reshape_rows = output_rows * output_cols / reshape_cols;
++ if (fast_) {
++ OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}), &output1));
++ output1->scalar<int32>()() = reshape_rows;
++ return;
++ }
++
++ TensorShape reshaped_shape({reshape_rows, reshape_cols});
++ OP_REQUIRES_OK(context,
++ context->allocate_output(1, reshaped_shape, &output1));
++ float* output_data = output1->flat<float>().data();
++ const float* input_data = input.flat<float>().data();
++ std::memcpy(output_data, input_data, input_rows_value * output_cols * sizeof(float));
++    std::memset(output_data + input_rows_value * output_cols, 0,
++                padding_rows * output_cols * sizeof(float));
++ }
++
++private:
++ bool fast_;
++};
++
++
++REGISTER_KERNEL_BUILDER(Name("KPFusedEmbeddingPadding").Device(DEVICE_CPU),
++ KPFusedEmbeddingPaddingOp);
++
++REGISTER_KERNEL_BUILDER(Name("KPFusedEmbeddingPaddingFast").Device(DEVICE_CPU),
++ KPFusedEmbeddingPaddingOp);
++
++}  // namespace tensorflow
+\ No newline at end of file
+diff --git a/annc/tensorflow/kernels/embedding_fused_padding_test.cc b/annc/tensorflow/kernels/embedding_fused_padding_test.cc
+new file mode 100644
+index 0000000..5137d51
+--- /dev/null
++++ b/annc/tensorflow/kernels/embedding_fused_padding_test.cc
+@@ -0,0 +1,307 @@
++/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
++
++Licensed under the Apache License, Version 2.0 (the "License");
++you may not use this file except in compliance with the License.
++You may obtain a copy of the License at
++
++ http://www.apache.org/licenses/LICENSE-2.0
++
++Unless required by applicable law or agreed to in writing, software
++distributed under the License is distributed on an "AS IS" BASIS,
++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++See the License for the specific language governing permissions and
++limitations under the License.
++==============================================================================*/
++
++#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
++#include "tensorflow/core/framework/allocator.h"
++#include "tensorflow/core/framework/fake_input.h"
++#include "tensorflow/core/framework/node_def_builder.h"
++#include "tensorflow/core/framework/op_kernel.h"
++#include "tensorflow/core/framework/tensor.h"
++#include "tensorflow/core/framework/types.h"
++#include "tensorflow/core/framework/types.pb.h"
++#include "tensorflow/core/graph/testlib.h"
++#include "tensorflow/core/kernels/ops_testutil.h"
++#include "tensorflow/core/kernels/ops_util.h"
++#include "tensorflow/core/lib/core/status_test_util.h"
++#include "tensorflow/core/lib/gtl/array_slice.h"
++#include "tensorflow/core/lib/random/simple_philox.h"
++#include "tensorflow/core/lib/strings/str_util.h"
++#include "tensorflow/core/platform/test.h"
++#include "tensorflow/core/platform/test_benchmark.h"
++
++namespace tensorflow {
++
++class KPFusedEmbeddingPaddingTest : public OpsTestBase {
++ protected:
++ void MakeOp(DataType input_shape_type, DataType pooling_type, DataType reshape_type, DataType const_type) {
++ TF_ASSERT_OK(NodeDefBuilder("fused_padding", "KPFusedEmbeddingPadding")
++ .Input(FakeInput(input_shape_type))
++ .Input(FakeInput(pooling_type))
++ .Input(FakeInput(const_type))
++ .Input(FakeInput(reshape_type))
++ .Input(FakeInput(const_type))
++ .Finalize(node_def()));
++ TF_ASSERT_OK(InitOp());
++ }
++
++ Status FeedAndRun(const int embedding_dims, const int table_size,
++ const int pooling_size, const int reshape_size) {
++ MakeOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
++ AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
++ AddInput<float>(TensorShape({pooling_size, embedding_dims}), [](int i) -> float {
++ return static_cast<float>(i + 1);
++ });
++ AddInputFromArray<int32>(TensorShape({}), {pooling_size});
++ AddInputFromArray<int32>(TensorShape({2}), {-1, reshape_size});
++ AddInputFromArray<int32>(TensorShape({}), {embedding_dims});
++ return RunOpKernel();
++ }
++
++ void MakeFastOp(DataType input_shape_type, DataType pooling_type, DataType reshape_type, DataType const_type) {
++ TF_ASSERT_OK(NodeDefBuilder("fused_padding_fast", "KPFusedEmbeddingPaddingFast")
++ .Input(FakeInput(input_shape_type))
++ .Input(FakeInput(pooling_type))
++ .Input(FakeInput(const_type))
++ .Input(FakeInput(reshape_type))
++ .Input(FakeInput(const_type))
++ .Finalize(node_def()));
++ TF_ASSERT_OK(InitOp());
++ }
++
++ Status FeedAndRunFast(const int embedding_dims, const int table_size,
++ const int pooling_size, const int reshape_size) {
++ MakeFastOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
++ AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
++ AddInput<float>(TensorShape({pooling_size, embedding_dims}), [](int i) -> float {
++ return static_cast<float>(i + 1);
++ });
++ AddInputFromArray<int32>(TensorShape({}), {pooling_size});
++ AddInputFromArray<int32>(TensorShape({2}), {-1, reshape_size});
++ AddInputFromArray<int32>(TensorShape({}), {embedding_dims});
++ return RunOpKernel();
++ }
++};
++
++TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithEmbeddingDims10_0) {
++ // Feed and run
++ const int embedding_dims = 10;
++ const int table_size = 151;
++ const int pooling_size = 151;
++ const int reshape_size = 1510;
++ TF_ASSERT_OK(FeedAndRun(embedding_dims, table_size, pooling_size, reshape_size));
++
++ // Check the output.
++ Tensor expected1(allocator(), DT_INT32, TensorShape({}));
++ Tensor expected2(allocator(), DT_FLOAT, TensorShape({table_size * embedding_dims / reshape_size, reshape_size}));
++ test::FillValues<int32>(&expected1, {table_size - pooling_size});
++ test::FillFn<float>(&expected2, [=](int i) -> float {
++ if (i < pooling_size * embedding_dims) {
++ return static_cast<float>(i + 1);
++ } else {
++ return 0.0f;
++ }
++ });
++ test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
++ test::ExpectTensorNear<float>(expected2, *GetOutput(1), 1e-5);
++}
++
++TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithEmbeddingDims10_1) {
++ // Feed and run
++ const int embedding_dims = 10;
++ const int table_size = 1510;
++ const int pooling_size = 151;
++ const int reshape_size = 1510;
++ TF_ASSERT_OK(FeedAndRun(embedding_dims, table_size, pooling_size, reshape_size));
++
++ // Check the output.
++ Tensor expected1(allocator(), DT_INT32, TensorShape({}));
++ Tensor expected2(allocator(), DT_FLOAT, TensorShape({table_size * embedding_dims / reshape_size, reshape_size}));
++ test::FillValues<int32>(&expected1, {table_size - pooling_size});
++ test::FillFn<float>(&expected2, [=](int i) -> float {
++ if (i < pooling_size * embedding_dims) {
++ return static_cast<float>(i + 1);
++ } else {
++ return 0.0f;
++ }
++ });
++ test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
++ test::ExpectTensorNear<float>(expected2, *GetOutput(1), 1e-5);
++}
++
++TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithEmbeddingDims12_0) {
++ // Feed and run
++ const int embedding_dims = 12;
++ const int table_size = 2;
++ const int pooling_size = 2;
++ const int reshape_size = 24;
++ TF_ASSERT_OK(FeedAndRun(embedding_dims, table_size, pooling_size, reshape_size));
++
++ // Check the output.
++ Tensor expected1(allocator(), DT_INT32, TensorShape({}));
++ Tensor expected2(allocator(), DT_FLOAT, TensorShape({table_size * embedding_dims / reshape_size, reshape_size}));
++ test::FillValues<int32>(&expected1, {table_size - pooling_size});
++ test::FillFn<float>(&expected2, [=](int i) -> float {
++ if (i < pooling_size * embedding_dims) {
++ return static_cast<float>(i + 1);
++ } else {
++ return 0.0f;
++ }
++ });
++ test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
++ test::ExpectTensorNear<float>(expected2, *GetOutput(1), 1e-5);
++}
++
++TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithEmbeddingDims12_1) {
++ // Feed and run
++ const int embedding_dims = 12;
++ const int table_size = 200;
++ const int pooling_size = 2;
++ const int reshape_size = 24;
++ TF_ASSERT_OK(FeedAndRun(embedding_dims, table_size, pooling_size, reshape_size));
++
++ // Check the output.
++ Tensor expected1(allocator(), DT_INT32, TensorShape({}));
++ Tensor expected2(allocator(), DT_FLOAT, TensorShape({table_size * embedding_dims / reshape_size, reshape_size}));
++ test::FillValues<int32>(&expected1, {table_size - pooling_size});
++ test::FillFn<float>(&expected2, [=](int i) -> float {
++ if (i < pooling_size * embedding_dims) {
++ return static_cast<float>(i + 1);
++ } else {
++ return 0.0f;
++ }
++ });
++ test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
++ test::ExpectTensorNear<float>(expected2, *GetOutput(1), 1e-5);
++}
++
++TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingFastWithEmbeddingDims10_0) {
++ // Feed and run
++ const int embedding_dims = 10;
++ const int table_size = 151;
++ const int pooling_size = 151;
++ const int reshape_size = 1510;
++ TF_ASSERT_OK(FeedAndRunFast(embedding_dims, table_size, pooling_size, reshape_size));
++
++ // Check the output.
++ Tensor expected1(allocator(), DT_INT32, TensorShape({}));
++ Tensor expected2(allocator(), DT_INT32, TensorShape({}));
++ test::FillValues<int32>(&expected1, {table_size - pooling_size});
++ test::FillValues<int32>(&expected2, {table_size * embedding_dims / reshape_size});
++ test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
++ test::ExpectTensorEqual<int32>(expected2, *GetOutput(1));
++}
++
++TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingFastWithEmbeddingDims10_1) {
++ // Feed and run
++ const int embedding_dims = 10;
++ const int table_size = 1510;
++ const int pooling_size = 151;
++ const int reshape_size = 1510;
++ TF_ASSERT_OK(FeedAndRunFast(embedding_dims, table_size, pooling_size, reshape_size));
++
++ // Check the output.
++ Tensor expected1(allocator(), DT_INT32, TensorShape({}));
++ Tensor expected2(allocator(), DT_INT32, TensorShape({}));
++ test::FillValues<int32>(&expected1, {table_size - pooling_size});
++ test::FillValues<int32>(&expected2, {table_size * embedding_dims / reshape_size});
++ test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
++ test::ExpectTensorEqual<int32>(expected2, *GetOutput(1));
++}
++
++TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingFastWithEmbeddingDims12_0) {
++ // Feed and run
++ const int embedding_dims = 12;
++ const int table_size = 2;
++ const int pooling_size = 2;
++ const int reshape_size = 24;
++ TF_ASSERT_OK(FeedAndRunFast(embedding_dims, table_size, pooling_size, reshape_size));
++
++ // Check the output.
++ Tensor expected1(allocator(), DT_INT32, TensorShape({}));
++ Tensor expected2(allocator(), DT_INT32, TensorShape({}));
++ test::FillValues<int32>(&expected1, {table_size - pooling_size});
++ test::FillValues<int32>(&expected2, {table_size * embedding_dims / reshape_size});
++ test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
++ test::ExpectTensorEqual<int32>(expected2, *GetOutput(1));
++}
++
++TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingFastWithEmbeddingDims12_1) {
++ // Feed and run
++ const int embedding_dims = 12;
++ const int table_size = 200;
++ const int pooling_size = 2;
++ const int reshape_size = 24;
++ TF_ASSERT_OK(FeedAndRunFast(embedding_dims, table_size, pooling_size, reshape_size));
++
++ // Check the output.
++ Tensor expected1(allocator(), DT_INT32, TensorShape({}));
++ Tensor expected2(allocator(), DT_INT32, TensorShape({}));
++ test::FillValues<int32>(&expected1, {table_size - pooling_size});
++ test::FillValues<int32>(&expected2, {table_size * embedding_dims / reshape_size});
++ test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
++ test::ExpectTensorEqual<int32>(expected2, *GetOutput(1));
++}
++
++TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithUnexpectedReshape) {
++ // Feed and run
++ const int embedding_dims = 12;
++ const int table_size = 200;
++ const int pooling_size = 2;
++ const int reshape_size = 24;
++ MakeOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
++ AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
++ AddInput<float>(TensorShape({pooling_size, embedding_dims}), [](int i) -> float {
++ return static_cast<float>(i + 1);
++ });
++ AddInputFromArray<int32>(TensorShape({}), {pooling_size});
++ AddInputFromArray<int32>(TensorShape({2}), {10, reshape_size});
++ AddInputFromArray<int32>(TensorShape({}), {embedding_dims});
++ Status s = RunOpKernel();
++ EXPECT_TRUE(
++ absl::StrContains(s.ToString(), "reshape[0] is not -1"))
++ << s;
++}
++
++TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithUnexpectedPack) {
++ // Feed and run
++ const int embedding_dims = 12;
++ const int table_size = 200;
++ const int pooling_size = 2;
++ const int reshape_size = 24;
++ MakeOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
++ AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
++ AddInput<float>(TensorShape({pooling_size, embedding_dims}), [](int i) -> float {
++ return static_cast<float>(i + 1);
++ });
++ AddInputFromArray<int32>(TensorShape({}), {pooling_size});
++ AddInputFromArray<int32>(TensorShape({2}), {-1, reshape_size});
++ AddInputFromArray<int32>(TensorShape({}), {10});
++ Status s = RunOpKernel();
++ EXPECT_TRUE(
++ absl::StrContains(s.ToString(), "pack(10) is not equal to embedding dims"))
++ << s;
++}
++
++TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithPoolingSizeGreaterInput) {
++ // Feed and run
++ const int embedding_dims = 12;
++ const int table_size = 200;
++ const int pooling_size = 201;
++ const int reshape_size = 24;
++ MakeOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
++ AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
++ AddInput<float>(TensorShape({pooling_size, embedding_dims}), [](int i) -> float {
++ return static_cast<float>(i + 1);
++ });
++ AddInputFromArray<int32>(TensorShape({}), {pooling_size});
++ AddInputFromArray<int32>(TensorShape({2}), {-1, reshape_size});
++ AddInputFromArray<int32>(TensorShape({}), {embedding_dims});
++ Status s = RunOpKernel();
++ EXPECT_TRUE(
++ absl::StrContains(s.ToString(), "Pooling size(201) is greater than Input size(200)"))
++ << s;
++}
++
++} // end namespace tensorflow
+diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_dynamic_stitch.cc b/annc/tensorflow/kernels/embedding_fused_sparse_dynamic_stitch.cc
+new file mode 100644
+index 0000000..0ccf599
+--- /dev/null
++++ b/annc/tensorflow/kernels/embedding_fused_sparse_dynamic_stitch.cc
+@@ -0,0 +1,87 @@
++/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
++
++Licensed under the Apache License, Version 2.0 (the "License");
++you may not use this file except in compliance with the License.
++You may obtain a copy of the License at
++
++ http://www.apache.org/licenses/LICENSE-2.0
++
++Unless required by applicable law or agreed to in writing, software
++distributed under the License is distributed on an "AS IS" BASIS,
++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++See the License for the specific language governing permissions and
++limitations under the License.
++==============================================================================*/
++
++#include <cstring>
++#include <vector>
++
++#include "tensorflow/core/framework/common_shape_fns.h"
++#include "tensorflow/core/framework/op.h"
++#include "tensorflow/core/framework/op_kernel.h"
++#include "tensorflow/core/framework/tensor.h"
++#include "tensorflow/core/util/work_sharder.h"
++
++using namespace tensorflow;
++
++class KPFusedSparseDynamicStitchOp : public OpKernel {
++public:
++ explicit KPFusedSparseDynamicStitchOp(OpKernelConstruction* context)
++ : OpKernel(context) {}
++
++ void Compute(OpKernelContext* context) override {
++ const Tensor& x = context->input(0);
++ auto x_flat = x.flat<int64>();
++ int64_t num_elems = x_flat.size();
++
++ const int num_inputs = context->num_inputs();
++ const int num_partitions = num_inputs - 1;
++    OP_REQUIRES(context, num_partitions > 1,
++                errors::InvalidArgument("num_partitions must be > 1"));
++ int64_t output_stride = 0;
++ std::vector<const float*> variables(num_partitions);
++ std::vector<int64_t> variable_rows(num_partitions);
++ for (int i = 1; i < num_inputs; ++i) {
++ const Tensor& input_tensor = context->input(i);
++ OP_REQUIRES(context, input_tensor.dims() == 2, errors::InvalidArgument("input dims must == 2"));
++ if (i == 1) {
++ output_stride = input_tensor.dim_size(1);
++ } else {
++ OP_REQUIRES(context, input_tensor.dim_size(1) == output_stride,
++ errors::InvalidArgument("All inputs must have same second dimension"));
++ }
++ variables[i - 1] = context->input(i).flat<float>().data();
++ variable_rows[i - 1] = input_tensor.dim_size(0);
++ }
++
++    OP_REQUIRES(context, output_stride > 0,
++                errors::InvalidArgument("output_stride must be > 0"));
++
++ Tensor* output_tensor = nullptr;
++ OP_REQUIRES_OK(context,
++ context->allocate_output(0, TensorShape({num_elems, output_stride}),
++ &output_tensor));
++    float* output = output_tensor->flat<float>().data();
++
++ const size_t copy_size = output_stride * sizeof(float);
++
++ auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
++ const int64 cost_per_unit = 120; // Actual single cycle execution time
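++    // Keys are assumed to be partitioned round-robin across the tables:
++    // table_id = id % num_partitions, row_id = id / num_partitions.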
++    auto work = [&](int64 start, int64 end) {
++      for (int64 i = start; i < end; ++i) {
++ const int64_t global_id = x_flat(i);
++ const int64_t table_id = global_id % num_partitions;
++ const int64_t row_id = global_id / num_partitions;
++
++ OP_REQUIRES(context, row_id < variable_rows[table_id], errors::InvalidArgument(
++ "row_id out of range."));
++
++ std::memcpy(output + i * output_stride,
++ variables[table_id] + row_id * output_stride, copy_size);
++ }
++ };
++
++ Shard(worker_threads->num_threads, worker_threads->workers, num_elems,
++ cost_per_unit, work);
++ }
++};
++
++REGISTER_KERNEL_BUILDER(Name("KPFusedSparseDynamicStitch").Device(DEVICE_CPU),
++ KPFusedSparseDynamicStitchOp);
+diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_dynamic_stitch_test.cc b/annc/tensorflow/kernels/embedding_fused_sparse_dynamic_stitch_test.cc
+new file mode 100644
+index 0000000..74fdd25
+--- /dev/null
++++ b/annc/tensorflow/kernels/embedding_fused_sparse_dynamic_stitch_test.cc
+@@ -0,0 +1,108 @@
++/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
++ *
++ * Licensed under the Apache License, Version 2.0 (the "License");
++ * you may not use this file except in compliance with the License.
++ * You may obtain a copy of the License at
++ *
++ * http://www.apache.org/licenses/LICENSE-2.0
++ *
++ * Unless required by applicable law or agreed to in writing, software
++ * distributed under the License is distributed on an "AS IS" BASIS,
++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++ * See the License for the specific language governing permissions and
++ * limitations under the License.
++ * ==============================================================================*/
++
++#include <functional>
++#include <memory>
++#include <vector>
++
++#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
++#include "tensorflow/core/framework/allocator.h"
++#include "tensorflow/core/framework/fake_input.h"
++#include "tensorflow/core/framework/node_def_builder.h"
++#include "tensorflow/core/framework/op_kernel.h"
++#include "tensorflow/core/framework/tensor.h"
++#include "tensorflow/core/framework/types.h"
++#include "tensorflow/core/framework/types.pb.h"
++#include "tensorflow/core/graph/testlib.h"
++#include "tensorflow/core/kernels/ops_testutil.h"
++#include "tensorflow/core/kernels/ops_util.h"
++#include "tensorflow/core/lib/core/status_test_util.h"
++#include "tensorflow/core/lib/gtl/array_slice.h"
++#include "tensorflow/core/lib/random/simple_philox.h"
++#include "tensorflow/core/lib/strings/str_util.h"
++#include "tensorflow/core/platform/test.h"
++#include "tensorflow/core/platform/test_benchmark.h"
++
++namespace tensorflow {
++namespace {
++
++class KPFusedSparseDynamicStitchOpTest : public OpsTestBase {
++ protected:
++ void MakeOp(int N) {
++ TF_ASSERT_OK(NodeDefBuilder("kp_fused_sparse_dynamic_stitch",
++ "KPFusedSparseDynamicStitch")
++ .Input(FakeInput(DT_INT64))
++ .Input(FakeInput(N, DT_FLOAT))
++ .Finalize(node_def()));
++ TF_ASSERT_OK(InitOp());
++ }
++};
++
++TEST_F(KPFusedSparseDynamicStitchOpTest, TestTwoTables) {
++ MakeOp(2); // num_partitions = 2
++
++ AddInputFromArray<int64>(TensorShape({4}), {0, 3, 2, 1});
++ AddInputFromArray<float>(TensorShape({3, 2}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
++ AddInputFromArray<float>(TensorShape({2, 2}), {7.0f, 8.0f, 9.0f, 10.0f});
++ TF_ASSERT_OK(RunOpKernel());
++
++ Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 2}));
++ test::FillValues<float>(&expected,
++ {1.0f, 2.0f, 9.0f, 10.0f, 3.0f, 4.0f, 7.0f, 8.0f});
++ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
++}
++
++TEST_F(KPFusedSparseDynamicStitchOpTest, TestDifferentStride) {
++ MakeOp(2);
++
++ AddInputFromArray<int64>(TensorShape({4}), {0, 3, 2, 1});
++ AddInputFromArray<float>(TensorShape({3, 2}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
++ AddInputFromArray<float>(TensorShape({1, 4}), {7.0f, 8.0f, 9.0f, 10.0f});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(),"All inputs must have same second dimension"));
++}
++
++TEST_F(KPFusedSparseDynamicStitchOpTest, TestIndicesOutOfBounds) {
++ MakeOp(2);
++
++ AddInputFromArray<int64>(TensorShape({4}), {0, 6, 2, 1});
++ AddInputFromArray<float>(TensorShape({3, 2}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
++ AddInputFromArray<float>(TensorShape({2, 2}), {7.0f, 8.0f, 9.0f, 10.0f});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(),"row_id out of range"));
++}
++
++TEST_F(KPFusedSparseDynamicStitchOpTest, TestInputDims) {
++ MakeOp(2);
++
++ AddInputFromArray<int64>(TensorShape({4}), {0, 6, 2, 1});
++ AddInputFromArray<float>(TensorShape({3, 2, 1}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
++ AddInputFromArray<float>(TensorShape({2, 2, 1}), {7.0f, 8.0f, 9.0f, 10.0f});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(),"input dims must == 2"));
++}
++
++} // namespace
++} // namespace tensorflow
+diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_reshape.cc b/annc/tensorflow/kernels/embedding_fused_sparse_reshape.cc
+new file mode 100644
+index 0000000..e4acad2
+--- /dev/null
++++ b/annc/tensorflow/kernels/embedding_fused_sparse_reshape.cc
+@@ -0,0 +1,195 @@
++/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
++
++Licensed under the Apache License, Version 2.0 (the "License");
++you may not use this file except in compliance with the License.
++You may obtain a copy of the License at
++
++ http://www.apache.org/licenses/LICENSE-2.0
++
++Unless required by applicable law or agreed to in writing, software
++distributed under the License is distributed on an "AS IS" BASIS,
++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++See the License for the specific language governing permissions and
++limitations under the License.
++==============================================================================*/
++
++#include "tensorflow/core/framework/common_shape_fns.h"
++#include "tensorflow/core/framework/op.h"
++#include "tensorflow/core/framework/op_kernel.h"
++#include "tensorflow/core/framework/tensor.h"
++#include "tensorflow/core/util/work_sharder.h"
++#include "tensorflow/core/kernels/reshape_util.h"
++#include "tensorflow/core/framework/register_types.h"
++#include "tensorflow/core/framework/tensor_util.h"
++#include "tensorflow/core/framework/types.h"
++#include "tensorflow/core/lib/gtl/inlined_vector.h"
++
++using namespace tensorflow;
++
++static void ReshapeKp(OpKernelContext *context, const Tensor &input_indices_in,
++ const Tensor &input_shape_in, const Tensor &target_shape_in,
++ int output_indices_idx, int output_shape_idx) {
++ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices_in.shape()),
++ errors::InvalidArgument(
++ "Input indices should be a matrix but received shape ",
++ input_indices_in.shape().DebugString()));
++ OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()),
++ errors::InvalidArgument(
++ "Input shape should be a vector but received shape ",
++ input_shape_in.shape().DebugString()));
++ OP_REQUIRES(context, TensorShapeUtils::IsVector(target_shape_in.shape()),
++ errors::InvalidArgument(
++ "Target shape should be a vector but received shape ",
++ target_shape_in.shape().DebugString()));
++
++ const int64 input_rank = input_shape_in.NumElements();
++ const int64 output_rank = target_shape_in.NumElements();
++ const TensorShape input_shape(input_shape_in.vec<int64>());
++ const int64 dense_size = input_shape.num_elements();
++ const int64 nnz = input_indices_in.shape().dim_size(0);
++
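++  // Resolve the target shape, inferring at most one -1 dimension from the
++  // total dense size.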
++ TensorShape output_shape;
++ int64 product = 1;
++ int unknown_index = -1;
++ auto target_shape = target_shape_in.vec<int64>();
++ for (int d = 0; d < output_rank; ++d) {
++ const int64 size = target_shape(d);
++ if (size == -1) {
++ OP_REQUIRES(
++ context, unknown_index == -1,
++ errors::InvalidArgument("only one output dimension may be -1, "
++ "not both ",
++ unknown_index, " and ", d));
++ unknown_index = d;
++ output_shape.AddDim(1);
++ } else {
++ OP_REQUIRES(context, size >= 0,
++ errors::InvalidArgument("size ", d,
++ " must be non-negative, not ", size));
++ product *= size;
++ output_shape.AddDim(size);
++ }
++ }
++ if (unknown_index != -1) {
++ OP_REQUIRES(
++ context, product > 0,
++ errors::InvalidArgument("reshape cannot infer the missing "
++ "input size for an empty tensor unless all "
++ "specified input sizes are non-zero"));
++ const int64 missing = dense_size / product;
++ OP_REQUIRES(
++ context, product * missing == dense_size,
++ errors::InvalidArgument(
++ "Input to reshape is a SparseTensor with ", dense_size,
++ " dense values, but the requested shape requires a multiple of ",
++ product, ". input_shape=", input_shape.DebugString(),
++ " output_shape=", output_shape.DebugString()));
++ output_shape.set_dim(unknown_index, missing);
++ }
++
++ OP_REQUIRES(
++ context, output_shape.num_elements() == dense_size,
++ errors::InvalidArgument("Input to reshape is a tensor with ", dense_size,
++ " dense values, but the requested shape has ",
++ output_shape.num_elements(),
++ ". input_shape=", input_shape.DebugString(),
++ " output_shape=", output_shape.DebugString()));
++
++ if (input_shape == output_shape) {
++ context->set_output(output_indices_idx, input_indices_in);
++ context->set_output(output_shape_idx, input_shape_in);
++ return;
++ }
++
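++  // Row-major strides: each sparse index is linearized to a flat id under the
++  // input shape, then re-expanded under the output shape.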
++ gtl::InlinedVector<int64, 8> input_strides(input_rank);
++ if (input_rank > 0) {
++ input_strides[input_rank - 1] = 1;
++ for (int d = input_rank - 2; d >= 0; --d) {
++ input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
++ }
++ }
++
++ gtl::InlinedVector<int64, 8> output_strides(output_rank);
++ if (output_rank > 0) {
++ output_strides[output_rank - 1] = 1;
++ for (int d = output_rank - 2; d >= 0; --d) {
++ output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
++ }
++ }
++
++ Tensor *result_indices = nullptr;
++ OP_REQUIRES_OK(context,
++ context->allocate_output(output_indices_idx,
++ TensorShape({nnz, output_rank}),
++ &result_indices));
++ auto input_ind = input_indices_in.matrix<int64>();
++ auto output_ind = result_indices->matrix<int64>();
++ for (int i = 0; i < nnz; ++i) {
++ int64 id = 0;
++ for (int j = 0; j < input_rank; ++j) {
++ id += input_ind(i, j) * input_strides[j];
++ }
++ for (int j = 0; j < output_rank; ++j) {
++ output_ind(i, j) = id / output_strides[j];
++ id %= output_strides[j];
++ }
++ }
++
++ Tensor *result_shape = nullptr;
++ OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
++ TensorShape({output_rank}),
++ &result_shape));
++ auto output_shape_vec = result_shape->vec<int64>();
++ for (int j = 0; j < output_shape.dims(); ++j) {
++ output_shape_vec(j) = output_shape.dim_size(j);
++ }
++}
++
++class KPFusedSparseReshapeOp : public OpKernel {
++ public:
++ explicit KPFusedSparseReshapeOp(OpKernelConstruction* context) : OpKernel(context) { }
++
++ void Compute(OpKernelContext* context) override {
++ const Tensor& slice_input = context->input(0);
++ const Tensor& begin = context->input(1);
++ const Tensor& new_shape = context->input(2);
++ const Tensor& pack_const = context->input(3);
++
++    OP_REQUIRES(context, slice_input.dims() == 2,
++                errors::InvalidArgument("slice_input dims must == 2"));
++    OP_REQUIRES(context, new_shape.dim_size(0) == 2,
++                errors::InvalidArgument("new_shape dim size must == 2"));
++ OP_REQUIRES(context, pack_const.dims() == 0,
++ errors::InvalidArgument("pack_const must be a scalar"));
++ VLOG(1) << "Input slice_input shape: " << slice_input.shape().DebugString();
++ VLOG(1) << "Input begin value: " << begin.DebugString();
++ VLOG(1) << "Input new_shape value: " << new_shape.DebugString();
++
++ OP_REQUIRES(context, begin.dims() == 1 && begin.dim_size(0) == 2,
++ errors::InvalidArgument("begin must be 1D with at least 2 elements"));
++ int32 col = begin.flat<int32>().data()[1];
++    OP_REQUIRES(context, col < slice_input.dim_size(1),
++                errors::InvalidArgument("begin[1] must < slice_input.dim_size(1)"));
++ int64_t num_rows = slice_input.dim_size(0);
++ auto slice_input_mat = slice_input.matrix<int64>();
++
++ VLOG(1) << "num_rows: " << num_rows;
++ VLOG(1) << "slice_input.dim_size(0): " << slice_input.dim_size(0);
++ VLOG(1) << "slice_input.dim_size(1): " << slice_input.dim_size(1);
++ VLOG(1) << "Column index from begin: " << col;
++
++ Tensor shape_in(DT_INT64, TensorShape({2}));
++ auto tensor_flat = shape_in.flat<int64>();
++ tensor_flat(0) = num_rows;
++ tensor_flat(1) = pack_const.scalar<int64>()();
++
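++    // Build COO indices [row, key] with dense shape [num_rows, pack_const],
++    // then reshape them the same way SparseReshape does.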
++ Tensor indices_in(DT_INT64, TensorShape({num_rows, 2}));
++ auto indices_in_mat = indices_in.matrix<int64>();
++ for (int i = 0; i < num_rows; ++i) {
++ indices_in_mat(i, 0) = i;
++ indices_in_mat(i, 1) = slice_input_mat(i, col);
++ }
++
++ ReshapeKp(context, indices_in, shape_in, new_shape, 0, 1);
++ }
++};
++
++REGISTER_KERNEL_BUILDER(Name("KPFusedSparseReshape").Device(DEVICE_CPU),
++ KPFusedSparseReshapeOp);
+\ No newline at end of file
+diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_reshape_test.cc b/annc/tensorflow/kernels/embedding_fused_sparse_reshape_test.cc
+new file mode 100644
+index 0000000..874d48a
+--- /dev/null
++++ b/annc/tensorflow/kernels/embedding_fused_sparse_reshape_test.cc
+@@ -0,0 +1,281 @@
++/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
++
++Licensed under the Apache License, Version 2.0 (the "License");
++you may not use this file except in compliance with the License.
++You may obtain a copy of the License at
++
++ http://www.apache.org/licenses/LICENSE-2.0
++
++Unless required by applicable law or agreed to in writing, software
++distributed under the License is distributed on an "AS IS" BASIS,
++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++See the License for the specific language governing permissions and
++limitations under the License.
++==============================================================================*/
++
++#include "tensorflow/core/framework/fake_input.h"
++#include "tensorflow/core/framework/node_def_builder.h"
++#include "tensorflow/core/framework/tensor.h"
++#include "tensorflow/core/framework/tensor_testutil.h"
++#include "tensorflow/core/kernels/ops_testutil.h"
++#include "tensorflow/core/lib/core/status_test_util.h"
++
++namespace {
++using tensorflow::AllocatorAttributes;
++using tensorflow::DT_FLOAT;
++using tensorflow::DT_INT32;
++using tensorflow::DT_INT64;
++using tensorflow::int64;
++using tensorflow::int32;
++using tensorflow::NodeDefBuilder;
++using tensorflow::OpsTestBase;
++using tensorflow::Status;
++using tensorflow::Tensor;
++using tensorflow::TensorShape;
++using tensorflow::test::FillValues;
++using tensorflow::test::ExpectTensorEqual;
++
++class KPFusedSparseReshapeTest : public OpsTestBase {
++ protected:
++ void RunValidCase(const TensorShape& slice_shape,
++ const std::vector<int64>& slice_data,
++ const std::vector<int32>& begin_val,
++ const std::vector<int64>& new_shape_val,
++ const std::vector<int64>& pack_const_val,
++ const TensorShape& expected_indices_shape,
++ const std::vector<int64>& expected_shape_val) {
++ TF_EXPECT_OK(NodeDefBuilder("kp_fused_sparse_reshape", "KPFusedSparseReshape")
++ .Input(FakeInput(DT_INT64)) // slice_input
++ .Input(FakeInput(DT_INT32)) // begin
++ .Input(FakeInput(DT_INT64)) // new_shape
++ .Input(FakeInput(DT_INT64)) // pack_const
++ .Finalize(node_def()));
++ TF_EXPECT_OK(InitOp());
++
++ AddInputFromArray<int64>(slice_shape, slice_data);
++ AddInputFromArray<int32>(TensorShape({2}), begin_val);
++ AddInputFromArray<int64>(TensorShape({2}), new_shape_val);
++ AddInputFromArray<int64>(TensorShape({}), pack_const_val);
++
++ TF_ASSERT_OK(RunOpKernel());
++
++    // Output 0: result_indices
++ const Tensor& out_indices = *GetOutput(0);
++ EXPECT_EQ(out_indices.shape(), expected_indices_shape);
++
++    // Output 1: result_shape
++ const Tensor& out_shape = *GetOutput(1);
++ Tensor expected_shape_tensor(DT_INT64,
++ TensorShape({static_cast<int64>(expected_shape_val.size())}));
++ FillValues<int64>(&expected_shape_tensor, expected_shape_val);
++ ExpectTensorEqual<int64>(expected_shape_tensor, out_shape);
++ }
++
++ Status RunOpExpectFailure(const TensorShape& slice_shape,
++ const std::vector<int64>& slice_data,
++ const std::vector<int32>& begin_val,
++ const std::vector<int64>& new_shape_val,
++ const std::vector<int64>& pack_const_val) {
++ TF_CHECK_OK(NodeDefBuilder("kp_fused_sparse_reshape", "KPFusedSparseReshape")
++ .Input(FakeInput(DT_INT64)) // slice_input
++ .Input(FakeInput(DT_INT32)) // begin
++ .Input(FakeInput(DT_INT64)) // new_shape
++ .Input(FakeInput(DT_INT64)) // pack_const
++ .Finalize(node_def()));
++ TF_CHECK_OK(InitOp());
++
++ AddInputFromArray<int64>(slice_shape, slice_data);
++ AddInputFromArray<int32>(TensorShape({static_cast<int64>(begin_val.size())}), begin_val);
++ AddInputFromArray<int64>(TensorShape({static_cast<int64>(new_shape_val.size())}), new_shape_val);
++ AddInputFromArray<int64>(TensorShape({}), pack_const_val);
++
++ return RunOpKernel();
++ }
++};
++
++// ==================== Positive tests ====================
++
++// Normal reshape case
++// pack_const = 2
++TEST_F(KPFusedSparseReshapeTest, Valid_NormalInput) {
++ RunValidCase(
++ TensorShape({4, 2}), // slice_input shape
++ {0, 1,
++ 1, 2,
++ 2, 3,
++       3, 0},               // slice_input data
++      {0, 1},               // begin = (0, 1): select column 1
++      {2, 4},               // new_shape = [2, 4]
++      {2},                  // pack_const = [2]
++      TensorShape({4, 2}),  // expected indices shape
++      {2, 4});              // expected shape
++}
++
++// pack_const = 1
++TEST_F(KPFusedSparseReshapeTest, Valid_PackConst1) {
++ RunValidCase(
++ TensorShape({1, 2}), // slice_input shape
++      {0, 1},               // slice_input data
++      {0, 1},               // begin = (0, 1): select column 1
++      {-1, 1},              // new_shape = [-1, 1]
++      {1},                  // pack_const = [1]
++      TensorShape({1, 2}),  // expected indices shape
++      {1, 1});              // expected shape
++}
++
++// ==================== Negative tests ====================
++
++// Negative case 1: slice_input is not 2-D
++TEST_F(KPFusedSparseReshapeTest, Invalid_SliceInputNot2D) {
++ Status s = RunOpExpectFailure(
++ TensorShape({4}), {0, 1, 2, 3},
++ {0, 0},
++ {2, 2},
++ {4});
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "slice_input dims must == 2"));
++}
++
++// Negative case 2: new_shape dim size is not 2
++TEST_F(KPFusedSparseReshapeTest, Invalid_NewShapeNotLen2) {
++ Status s = RunOpExpectFailure(
++ TensorShape({2, 2}), {0, 1, 1, 0},
++ {0, 0},
++      {4, 2, 1},  // new_shape has one extra element
++ {2});
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "new_shape dim size must == 2"));
++}
++
++// Negative case 3: begin[1] exceeds the slice_input column count
++TEST_F(KPFusedSparseReshapeTest, Invalid_BeginOutOfRange) {
++ Status s = RunOpExpectFailure(
++ TensorShape({2, 2}), {0, 1, 1, 0},
++      {0, 2},  // exceeds the column count
++ {2, 2},
++ {2});
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "begin[1] must < slice_input.dim_size(1)"));
++}
++
++// Negative case 4: target shape has more than one -1
++TEST_F(KPFusedSparseReshapeTest, Invalid_MultipleUnknownDims) {
++ Status s = RunOpExpectFailure(
++ TensorShape({2, 2}), {0, 1, 1, 0},
++ {0, 1},
++      {-1, -1},  // two -1 dimensions
++ {2});
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "only one output dimension may be -1"));
++}
++
++// Negative case 5: the total element count is not evenly divisible when
++// inferring a dimension --> product * missing != dense_size
++TEST_F(KPFusedSparseReshapeTest, Invalid_InferredShapeDoesNotMatch) {
++  TensorShape input_indices_shape({6, 2});  // 6 nonzero elements, rank = 2
++ std::vector<int64> input_indices_data = {
++ 0, 0,
++ 0, 1,
++ 0, 2,
++ 1, 0,
++ 1, 1,
++ 1, 2
++  };  // corresponds to a 2x3 dense tensor
++
++  std::vector<int32> begin_val = {0, 0};       // assumed begin input
++  std::vector<int64> new_shape_val = {-1, 4};  // reshape to ?x4
++ std::vector<int64> pack_const_val = {1};
++
++ Status s = RunOpExpectFailure(
++ input_indices_shape,
++ input_indices_data,
++ begin_val,
++ new_shape_val,
++ pack_const_val);
++
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "Input to reshape is a SparseTensor with"));
++}
++
++// Negative case 6: element count mismatch after reshape --> output_shape.num_elements() != dense_size
++TEST_F(KPFusedSparseReshapeTest, Invalid_SizeMismatch) {
++ Status s = RunOpExpectFailure(
++ TensorShape({2, 2}), {0, 1, 1, 0},
++ {0, 1},
++      {3, 3},  // expects 9 elements, but the input dense size is 4
++ {2});
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "Input to reshape is a tensor with"));
++}
++
++// Negative case 7: target_shape contains a negative value other than -1
++TEST_F(KPFusedSparseReshapeTest, Invalid_NegativeDimNotMinusOne) {
++ Status s = RunOpExpectFailure(
++ TensorShape({2, 2}), {0, 1, 1, 0},
++ {0, 0},
++      {2, -2},  // -2 is invalid
++ {2});
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "size 1 must be non-negative, not -2"))
++ << "Actual error: " << s.message();
++}
++
++// 反例8:target_shape 有 -1,但其他维度乘积为 0
++TEST_F(KPFusedSparseReshapeTest, Invalid_ProductZeroWithUnknownDim) {
++ // dense_size = 0 (empty SparseTensor), target_shape = [-1, 0]
++ // product = 0 -> inferring the missing dimension is not allowed
++ Status s = RunOpExpectFailure(
++ TensorShape({0, 2}), {}, // empty slice_input
++ {0, 0},
++ {-1, 0}, // product = 0
++ {2});
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "reshape cannot infer the missing input size for an empty tensor"))
++ << "Actual error: " << s.message();
++}
++
++// Negative case 9: begin is 1-D but has only 1 element (fewer than the required 2)
++TEST_F(KPFusedSparseReshapeTest, Invalid_BeginRank1ButSize1) {
++ Status s = RunOpExpectFailure(
++ TensorShape({2, 2}), {0, 1, 1, 0},
++ {0}, // begin = [0], length 1
++ {2, 2},
++ {2});
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "begin must be 1D with at least 2 elements"))
++ << "Actual error: " << s.message();
++}
++
++// Negative case 10: begin is 1-D but has 3 elements (more than 2)
++TEST_F(KPFusedSparseReshapeTest, Invalid_BeginRank1ButSize3) {
++ Status s = RunOpExpectFailure(
++ TensorShape({2, 2}), {0, 1, 1, 0},
++ {0, 1, 2}, // begin = [0,1,2], length 3
++ {2, 2},
++ {2});
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "begin must be 1D with at least 2 elements"))
++ << "Actual error: " << s.message();
++}
++
++// Negative case 11: pack_const is 1-D, not the required scalar (0-D)
++TEST_F(KPFusedSparseReshapeTest, Invalid_PackConstIsScalarButExpect1D) {
++ TF_CHECK_OK(NodeDefBuilder("kp_fused_sparse_reshape", "KPFusedSparseReshape")
++ .Input(FakeInput(DT_INT64)) // slice_input
++ .Input(FakeInput(DT_INT32)) // begin
++ .Input(FakeInput(DT_INT64)) // new_shape
++ .Input(FakeInput(DT_INT64)) // pack_const
++ .Finalize(node_def()));
++ TF_CHECK_OK(InitOp());
++
++ AddInputFromArray<int64>(TensorShape({2, 2}), {0, 1, 1, 0});
++ AddInputFromArray<int32>(TensorShape({2}), {0, 1});
++ AddInputFromArray<int64>(TensorShape({2}), {2, 2});
++ AddInputFromArray<int64>(TensorShape({1}), {1}); // pack_const with shape {1} (1-D), not a scalar
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "pack_const must be a scalar"))
++ << "Actual error: " << s.message();
++}
++
++} // namespace
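For reference, the error paths these tests walk (single -1 inference, divisibility, element-count match) follow the standard SparseTensor reshape rules. A minimal standalone sketch of that shape-resolution step, with illustrative names only and no claim to match the shipped kernel line for line:

    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    // Resolve a target shape that may contain one -1, against the dense
    // element count of the input. Mirrors the checks exercised above.
    std::vector<int64_t> ResolveTargetShape(const std::vector<int64_t>& target,
                                            int64_t dense_size) {
      int64_t product = 1;
      int missing = -1;
      for (int d = 0; d < static_cast<int>(target.size()); ++d) {
        if (target[d] == -1) {
          if (missing != -1)
            throw std::invalid_argument("only one output dimension may be -1");
          missing = d;
        } else if (target[d] < 0) {
          throw std::invalid_argument("dimension size must be non-negative");
        } else {
          product *= target[d];
        }
      }
      std::vector<int64_t> shape = target;
      if (missing >= 0) {
        if (product == 0)
          throw std::invalid_argument(
              "reshape cannot infer the missing input size for an empty tensor");
        if (dense_size % product != 0)
          throw std::invalid_argument("inferred dimension does not divide evenly");
        shape[missing] = dense_size / product;
      } else if (product != dense_size) {
        throw std::invalid_argument("element count mismatch after reshape");
      }
      return shape;
    }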
+diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce.cc b/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce.cc
+new file mode 100644
+index 0000000..33bbd31
+--- /dev/null
++++ b/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce.cc
+@@ -0,0 +1,165 @@
++/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
++
++Licensed under the Apache License, Version 2.0 (the "License");
++you may not use this file except in compliance with the License.
++You may obtain a copy of the License at
++
++ http://www.apache.org/licenses/LICENSE-2.0
++
++Unless required by applicable law or agreed to in writing, software
++distributed under the License is distributed on an "AS IS" BASIS,
++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++See the License for the specific language governing permissions and
++limitations under the License.
++==============================================================================*/
++
++#include <arm_neon.h>
++
++#include "tensorflow/core/framework/common_shape_fns.h"
++#include "tensorflow/core/framework/op.h"
++#include "tensorflow/core/framework/op_kernel.h"
++#include "tensorflow/core/framework/tensor.h"
++#include "tensorflow/core/util/work_sharder.h"
++
++using namespace tensorflow;
++
++template <typename Tidx>
++class KPFusedSparseSegmentReduceOp : public OpKernel {
++public:
++ explicit KPFusedSparseSegmentReduceOp(OpKernelConstruction* context)
++ : OpKernel(context) {
++ int combiner_mode;
++ OP_REQUIRES_OK(context, context->GetAttr("combiner", &combiner_mode));
++ OP_REQUIRES(context, combiner_mode == 0 || combiner_mode == 1,
++ errors::InvalidArgument("combiner must be 0 or 1"));
++ is_mean_ = (combiner_mode == 1);
++ }
++
++ void Compute(OpKernelContext* context) override {
++ const Tensor& input_tensor = context->input(0);
++ const Tensor& indices = context->input(1);
++ const Tensor& slice_input = context->input(2);
++ const Tensor& begin = context->input(3);
++ const Tensor& begin_1 = context->input(4);
++
++ OP_REQUIRES(context, input_tensor.dims() == 2, errors::InvalidArgument("input must be 2-D"));
++ OP_REQUIRES(context, slice_input.dims() == 2, errors::InvalidArgument("slice input must be 2-D"));
++ OP_REQUIRES(context, begin.NumElements() == 2, errors::InvalidArgument("begin must have 2 elements"));
++ OP_REQUIRES(context, begin_1.NumElements() == 1, errors::InvalidArgument("begin_1 must have 1 element"));
++ int64_t num_indices = indices.dim_size(0);
++ int64_t embedding_size = input_tensor.dim_size(1);
++ int32 col = begin.flat<int32>().data()[1];
++ int32 out_dim = static_cast<int32>(begin_1.flat<int32>()(0));
++
++ OP_REQUIRES(context, col >= 0 && col < slice_input.dim_size(1),
++ errors::InvalidArgument("Column index out of range"));
++ OP_REQUIRES(context, num_indices == slice_input.dim_size(0),
++ errors::InvalidArgument("indices and slice_input.dim_zie(0) should have same size"));
++
++ auto input_data = input_tensor.matrix<float>().data();
++ auto indices_vec = indices.vec<Tidx>();
++ auto slice_input_mat = slice_input.matrix<int64>();
++
++ // Calculate max segment_id
++ int64 max_seg_id = 0;
++ for (int64 i = 0; i < num_indices; ++i) {
++ int64 seg_id = slice_input_mat(i, col);
++ if (seg_id > max_seg_id) {
++ max_seg_id = seg_id;
++ }
++ }
++ const int64 batch_size = max_seg_id + 1;
++
++ Tensor* output = nullptr;
++ OP_REQUIRES_OK(context,
++ context->allocate_output(
++ 0, TensorShape({batch_size, embedding_size}), &output));
++ output->flat<float>().setZero();
++ Tensor* slice_out = nullptr;
++ OP_REQUIRES_OK(context,
++ context->allocate_output(1, TensorShape({}), &slice_out));
++ slice_out->scalar<int32>()() =
++ (out_dim == 0) ? static_cast<int32>(batch_size)
++ : static_cast<int32>(embedding_size);
++
++ auto output_data = output->matrix<float>().data();
++
++ if (is_mean_) {
++ Tensor counts(DT_INT32, TensorShape({batch_size}));
++ counts.flat<int32>().setZero();
++ auto counts_vec = counts.flat<int32>();
++
++ for (int64 i = 0; i < num_indices; ++i) {
++ const int64 seg_id = slice_input_mat(i, col);
++ const Tidx data_row = indices_vec(i);
++ counts_vec(seg_id) += 1;
++
++ float* output_row = output_data + seg_id * embedding_size;
++ const float* input_data_row = input_data + data_row * embedding_size;
++ int64 j = 0;
++ for (; j + 3 < embedding_size; j += 4) {
++ float32x4_t out = vld1q_f32(output_row + j);
++ float32x4_t data = vld1q_f32(input_data_row + j);
++ out = vaddq_f32(out, data);
++ vst1q_f32(output_row + j, out);
++ }
++
++ for (; j < embedding_size; ++j) {
++ output_row[j] += input_data_row[j];
++ }
++ }
++
++ for (int64_t seg = 0; seg < batch_size; ++seg) {
++ const int32_t count = counts_vec(seg);
++ if (count > 0) {
++ const float inv_count = 1.0f / static_cast<float>(count);
++ const float32x4_t inv_count_vec = vdupq_n_f32(inv_count);
++
++ float* row_start = output_data + seg * embedding_size;
++ int64_t j = 0;
++
++ for (; j + 3 < embedding_size; j += 4) {
++ float32x4_t val = vld1q_f32(row_start + j);
++ val = vmulq_f32(val, inv_count_vec);
++ vst1q_f32(row_start + j, val);
++ }
++
++ for (; j < embedding_size; ++j) {
++ row_start[j] *= inv_count;
++ }
++ }
++ }
++ } else {
++ for (int64 i = 0; i < num_indices; ++i) {
++ const int64 seg_id = slice_input_mat(i, col);
++ const Tidx data_row = indices_vec(i);
++
++ float* output_row = output_data + seg_id * embedding_size;
++ const float* input_data_row = input_data + data_row * embedding_size;
++ int64 j = 0;
++ for (; j + 3 < embedding_size; j += 4) {
++ float32x4_t out = vld1q_f32(output_row + j);
++ float32x4_t data = vld1q_f32(input_data_row + j);
++ out = vaddq_f32(out, data);
++ vst1q_f32(output_row + j, out);
++ }
++
++ for (; j < embedding_size; ++j) {
++ output_row[j] += input_data_row[j];
++ }
++ }
++ }
++ }
++
++private:
++ bool is_mean_;
++};
++
++#define REGISTER_KERNEL(Tidx) \
++ REGISTER_KERNEL_BUILDER(Name("KPFusedSparseSegmentReduce") \
++ .Device(DEVICE_CPU) \
++ .TypeConstraint<Tidx>("Tidx"), \
++ KPFusedSparseSegmentReduceOp<Tidx>);
++REGISTER_KERNEL(int64)
++REGISTER_KERNEL(int32)
++#undef REGISTER_KERNEL
+\ No newline at end of file
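The NEON kernel above is easier to audit against a scalar reference. The sketch below (illustrative names, no sharding or input validation) states the intended semantics: gather rows of data by indices, accumulate into the segment taken from one column of slice_input, and divide by the per-segment count when the combiner is MEAN.

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    std::vector<float> SegmentReduceRef(const std::vector<float>& data,  // [rows, dim], row-major
                                        int64_t dim,
                                        const std::vector<int64_t>& indices,
                                        const std::vector<int64_t>& seg_ids,  // slice_input(:, col)
                                        bool mean) {
      int64_t batch = 0;
      for (int64_t s : seg_ids) batch = std::max(batch, s + 1);
      std::vector<float> out(batch * dim, 0.0f);
      std::vector<int> counts(batch, 0);
      for (size_t i = 0; i < indices.size(); ++i) {
        const float* row = &data[indices[i] * dim];  // gathered input row
        float* acc = &out[seg_ids[i] * dim];         // target segment row
        for (int64_t j = 0; j < dim; ++j) acc[j] += row[j];
        ++counts[seg_ids[i]];
      }
      if (mean) {
        for (int64_t s = 0; s < batch; ++s)
          if (counts[s] > 0)
            for (int64_t j = 0; j < dim; ++j) out[s * dim + j] /= counts[s];
      }
      return out;
    }

The vectorized loops in the kernel are the same accumulation, four floats at a time via vld1q_f32/vaddq_f32/vst1q_f32, with a scalar tail for the remaining elements.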
+diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_nonzero.cc b/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_nonzero.cc
+new file mode 100644
+index 0000000..5266476
+--- /dev/null
++++ b/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_nonzero.cc
+@@ -0,0 +1,159 @@
++/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
++
++Licensed under the Apache License, Version 2.0 (the "License");
++you may not use this file except in compliance with the License.
++You may obtain a copy of the License at
++
++ http://www.apache.org/licenses/LICENSE-2.0
++
++Unless required by applicable law or agreed to in writing, software
++distributed under the License is distributed on an "AS IS" BASIS,
++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++See the License for the specific language governing permissions and
++limitations under the License.
++==============================================================================*/
++
++#include <arm_neon.h>
++
++#include "tensorflow/core/framework/common_shape_fns.h"
++#include "tensorflow/core/framework/op.h"
++#include "tensorflow/core/framework/op_kernel.h"
++#include "tensorflow/core/framework/tensor.h"
++#include "tensorflow/core/util/work_sharder.h"
++#include "absl/container/flat_hash_map.h"
++
++using namespace tensorflow;
++
++template <typename Tidx>
++class KPFusedSparseSegmentReduceNonzeroOp : public OpKernel {
++public:
++ explicit KPFusedSparseSegmentReduceNonzeroOp(OpKernelConstruction* context)
++ : OpKernel(context) {
++ int combiner_mode;
++ OP_REQUIRES_OK(context, context->GetAttr("combiner", &combiner_mode));
++ OP_REQUIRES(context, combiner_mode == 0 || combiner_mode == 1,
++ errors::InvalidArgument("combiner must be 0 or 1"));
++ is_mean_ = (combiner_mode == 1);
++ }
++
++ void Compute(OpKernelContext* context) override {
++ const Tensor& input_tensor = context->input(0);
++ const Tensor& indices = context->input(1);
++ const Tensor& slice_input = context->input(2);
++ const Tensor& begin = context->input(3);
++
++ OP_REQUIRES(context, input_tensor.dims() == 1,
++ errors::InvalidArgument("Input data must be a vector"));
++ OP_REQUIRES(context, slice_input.dims() == 2, errors::InvalidArgument("slice input must be 2-D"));
++ OP_REQUIRES(context, begin.NumElements() == 2, errors::InvalidArgument("begin must have 2 elements"));
++
++ int64 num_indices = indices.dim_size(0);
++ int32 col = begin.flat<int32>().data()[1];
++
++ OP_REQUIRES(context, col >= 0 && col < slice_input.dim_size(1),
++ errors::InvalidArgument("Column index out of range"));
++ OP_REQUIRES(context, num_indices == slice_input.dim_size(0),
++ errors::InvalidArgument("indices and slice_input.dim_zie(0) should have same size"));
++
++ auto input_data = input_tensor.flat<float>();
++ auto indices_vec = indices.vec<Tidx>();
++ auto slice_input_mat = slice_input.matrix<int64>();
++
++ // Calculate max segment_id
++ std::vector<int64> segment_ids(num_indices);
++ int64 max_seg_id = 0;
++ for (int64 i = 0; i < num_indices; ++i) {
++ int64 seg_id = slice_input_mat(i, col);
++ segment_ids[i] = seg_id;
++ if (seg_id > max_seg_id) {
++ max_seg_id = seg_id;
++ }
++ }
++
++ const int64 batch_size = max_seg_id + 1;
++
++ Tensor* output_shape = nullptr;
++ OP_REQUIRES_OK(
++ context, context->allocate_output(0, TensorShape({1}), &output_shape));
++ output_shape->flat<int32>()(0) = static_cast<int32>(batch_size);
++
++ std::vector<std::pair<int64, float>> results;
++ int64 num_nonzero = 0;
++ absl::flat_hash_map<int64, float> segment_sums;
++ absl::flat_hash_map<int64, int32> segment_counts;
++ std::vector<int64> segment_order;
++
++ if (is_mean_) {
++ for (int64 i = 0; i < num_indices; ++i) {
++ const int64 seg_id = segment_ids[i];
++ const Tidx data_row = indices_vec(i);
++
++ if (segment_sums.find(seg_id) == segment_sums.end()) {
++ segment_order.push_back(seg_id);
++ }
++
++ segment_sums[seg_id] += input_data(data_row);
++ segment_counts[seg_id] += 1;
++ }
++
++ for (int64 seg_id : segment_order) {
++ const int32_t count = segment_counts[seg_id];
++ if (count > 0) {
++ const float inv_count = 1.0f / static_cast<float>(count);
++ float value = segment_sums[seg_id];
++ if (value != 0) {
++ results.push_back({seg_id, value * inv_count});
++ num_nonzero++;
++ }
++ }
++ }
++ } else {
++ for (int64 i = 0; i < num_indices; ++i) {
++ const int64 seg_id = segment_ids[i];
++ const Tidx data_row = indices_vec(i);
++
++ if (segment_sums.find(seg_id) == segment_sums.end()) {
++ segment_order.push_back(seg_id);
++ }
++
++ segment_sums[seg_id] += input_data(data_row);
++ }
++
++ for (int64 seg_id : segment_order) {
++ float value = segment_sums[seg_id];
++ if (value != 0) {
++ results.push_back({seg_id, value});
++ num_nonzero++;
++ }
++ }
++ }
++ Tensor* output_indices = nullptr;
++ OP_REQUIRES_OK(context,
++ context->allocate_output(1, TensorShape({num_nonzero, 1}),
++ &output_indices));
++ auto output_indices_data = output_indices->flat<int32>();
++
++ Tensor* output_nonzero = nullptr;
++ OP_REQUIRES_OK(context,
++ context->allocate_output(2, TensorShape({num_nonzero}),
++ &output_nonzero));
++ auto output_nonzero_data = output_nonzero->flat<float>();
++ for (int64 i = 0; i < num_nonzero; ++i) {
++ output_indices_data(i) = static_cast<int32>(results[i].first);
++ output_nonzero_data(i) = results[i].second;
++ }
++
++ }
++
++ private:
++ bool is_mean_;
++};
++
++#define REGISTER_KERNEL(Tidx) \
++ REGISTER_KERNEL_BUILDER(Name("KPFusedSparseSegmentReduceNonzero") \
++ .Device(DEVICE_CPU) \
++ .TypeConstraint<Tidx>("Tidx"), \
++ KPFusedSparseSegmentReduceNonzeroOp<Tidx>);
++REGISTER_KERNEL(int64)
++REGISTER_KERNEL(int32)
++#undef REGISTER_KERNEL
+\ No newline at end of file
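The nonzero variant trades the dense output for (segment, value) pairs. A condensed sketch of that bookkeeping, with illustrative names (the kernel additionally emits the dense batch size as output 0):

    #include <cstdint>
    #include <utility>
    #include <vector>
    #include "absl/container/flat_hash_map.h"

    std::vector<std::pair<int64_t, float>> NonzeroSegmentSums(
        const std::vector<float>& data, const std::vector<int64_t>& indices,
        const std::vector<int64_t>& seg_ids, bool mean) {
      absl::flat_hash_map<int64_t, float> sums;
      absl::flat_hash_map<int64_t, int> counts;
      std::vector<int64_t> order;  // first-seen segment order, as in the kernel
      for (size_t i = 0; i < indices.size(); ++i) {
        if (!sums.contains(seg_ids[i])) order.push_back(seg_ids[i]);
        sums[seg_ids[i]] += data[indices[i]];
        ++counts[seg_ids[i]];
      }
      std::vector<std::pair<int64_t, float>> out;
      for (int64_t s : order) {
        const float v = mean ? sums[s] / counts[s] : sums[s];
        if (v != 0.0f) out.emplace_back(s, v);  // drop zero-sum segments
      }
      return out;
    }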
+diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_nonzero_test.cc b/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_nonzero_test.cc
+new file mode 100644
+index 0000000..ac4a5c3
+--- /dev/null
++++ b/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_nonzero_test.cc
+@@ -0,0 +1,183 @@
++/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
++ *
++ * Licensed under the Apache License, Version 2.0 (the "License");
++ * you may not use this file except in compliance with the License.
++ * You may obtain a copy of the License at
++ *
++ * http://www.apache.org/licenses/LICENSE-2.0
++ *
++ * Unless required by applicable law or agreed to in writing, software
++ * distributed under the License is distributed on an "AS IS" BASIS,
++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++ * See the License for the specific language governing permissions and
++ * limitations under the License.
++ * ==============================================================================*/
++
++#include <functional>
++#include <memory>
++#include <vector>
++
++#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
++#include "tensorflow/core/framework/allocator.h"
++#include "tensorflow/core/framework/fake_input.h"
++#include "tensorflow/core/framework/node_def_builder.h"
++#include "tensorflow/core/framework/op_kernel.h"
++#include "tensorflow/core/framework/tensor.h"
++#include "tensorflow/core/framework/types.h"
++#include "tensorflow/core/framework/types.pb.h"
++#include "tensorflow/core/graph/testlib.h"
++#include "tensorflow/core/kernels/ops_testutil.h"
++#include "tensorflow/core/kernels/ops_util.h"
++#include "tensorflow/core/lib/core/status_test_util.h"
++#include "tensorflow/core/lib/gtl/array_slice.h"
++#include "tensorflow/core/lib/random/simple_philox.h"
++#include "tensorflow/core/lib/strings/str_util.h"
++#include "tensorflow/core/platform/test.h"
++#include "tensorflow/core/platform/test_benchmark.h"
++
++namespace tensorflow {
++namespace {
++
++class KPFusedSparseSegmentReduceNonzeroOpTest : public OpsTestBase {
++ protected:
++ void MakeOp(int combiner_mode) {
++ TF_ASSERT_OK(NodeDefBuilder("kp_fused_sparse_segment_reduce_nonzero",
++ "KPFusedSparseSegmentReduceNonzero")
++ .Input(FakeInput(DT_FLOAT)) // data
++ .Input(FakeInput(DT_INT32)) // indices
++ .Input(FakeInput(DT_INT64)) // slice_input
++ .Input(FakeInput(DT_INT32)) // begin
++ .Attr("combiner", combiner_mode)
++ .Finalize(node_def()));
++ TF_ASSERT_OK(InitOp());
++ }
++};
++
++TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestReduceMean) {
++ MakeOp(1);
++
++ AddInputFromArray<float>(TensorShape({8}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
++ AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
++ AddInputFromArray<int64>(TensorShape({3, 4}),
++ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
++ AddInputFromArray<int32>(TensorShape({2}), {0, 2});
++ TF_ASSERT_OK(RunOpKernel());
++
++ Tensor expected(allocator(), DT_INT32, TensorShape({1}));
++ test::FillValues<int32>(&expected, {4});
++ test::ExpectTensorEqual<int32>(expected, *GetOutput(0)); // output_shape
++
++ Tensor expected_1(allocator(), DT_INT32, TensorShape({2, 1}));
++ test::FillValues<int32>(&expected_1, {2, 3});
++ test::ExpectTensorEqual<int32>(expected_1, *GetOutput(1)); // output_indices
++
++ Tensor expected_2(allocator(), DT_FLOAT, TensorShape({2}));
++ test::FillValues<float>(&expected_2, {2, 2});
++ test::ExpectTensorEqual<float>(expected_2, *GetOutput(2)); // output_nonzero
++}
++
++TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestReduceSum) {
++ MakeOp(0);
++
++ AddInputFromArray<float>(TensorShape({8}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
++ AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
++ AddInputFromArray<int64>(TensorShape({3, 4}),
++ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
++ AddInputFromArray<int32>(TensorShape({2}), {0, 2});
++ TF_ASSERT_OK(RunOpKernel());
++
++ Tensor expected(allocator(), DT_INT32, TensorShape({1}));
++ test::FillValues<int32>(&expected, {4});
++ test::ExpectTensorEqual<int32>(expected, *GetOutput(0)); // output_shape
++
++ Tensor expected_1(allocator(), DT_INT32, TensorShape({2, 1}));
++ test::FillValues<int32>(&expected_1, {2, 3});
++ test::ExpectTensorEqual<int32>(expected_1, *GetOutput(1)); // output_indices
++
++ Tensor expected_2(allocator(), DT_FLOAT, TensorShape({2}));
++ test::FillValues<float>(&expected_2, {4, 2});
++ test::ExpectTensorEqual<float>(expected_2, *GetOutput(2)); // output_nonzero
++}
++
++TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestInvalidData) {
++ MakeOp(0);
++
++ AddInputFromArray<float>(TensorShape({4, 2}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
++ AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
++ AddInputFromArray<int64>(TensorShape({3, 4}),
++ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
++ AddInputFromArray<int32>(TensorShape({2}), {0, 2});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "Input data must be a vector") !=
++ std::string::npos);
++}
++
++TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestInvalidSliceinput) {
++ MakeOp(0);
++
++ AddInputFromArray<float>(TensorShape({8}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
++ AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
++ AddInputFromArray<int64>(TensorShape({3, 4, 1}),
++ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
++ AddInputFromArray<int32>(TensorShape({2}), {0, 2});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "slice input must be 2-D") !=
++ std::string::npos);
++}
++
++TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestInvalidbegin) {
++ MakeOp(0);
++
++ AddInputFromArray<float>(TensorShape({8}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
++ AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
++ AddInputFromArray<int64>(TensorShape({3, 4}),
++ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
++ AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "begin must have 2 elements"));
++}
++
++TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestColsOutOfBounds) {
++ MakeOp(0);
++
++ AddInputFromArray<float>(TensorShape({8}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
++ AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
++ AddInputFromArray<int64>(TensorShape({3, 4}),
++ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
++ AddInputFromArray<int32>(TensorShape({2}), {0, 4});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "Column index out of range"));
++}
++
++TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestIndicesOutOfBounds) {
++ MakeOp(0);
++
++ AddInputFromArray<float>(TensorShape({8}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
++ AddInputFromArray<int32>(TensorShape({2}), {0, 2});
++ AddInputFromArray<int64>(TensorShape({3, 4}),
++ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
++ AddInputFromArray<int32>(TensorShape({2}), {0, 1});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(),
++ "indices and slice_input.dim_zie(0) should have same size"));
++}
++
++} // namespace
++} // namespace tensorflow
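To trace TestReduceMean above: begin = {0, 2} picks column 2 of slice_input, so the segment ids are {2, 2, 3}; indices {0, 2, 1} gather data values 1, 3, 2. Segment 2 averages (1 + 3) / 2 = 2 and segment 3 holds 2, both nonzero, which yields output_indices {2, 3} and output_nonzero {2, 2}, with batch size max_seg_id + 1 = 4. TestReduceSum is the same trace without the division, hence {4, 2}.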
+diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_test.cc b/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_test.cc
+new file mode 100644
+index 0000000..558ab85
+--- /dev/null
++++ b/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_test.cc
+@@ -0,0 +1,205 @@
++/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
++ *
++ * Licensed under the Apache License, Version 2.0 (the "License");
++ * you may not use this file except in compliance with the License.
++ * You may obtain a copy of the License at
++ *
++ * http://www.apache.org/licenses/LICENSE-2.0
++ *
++ * Unless required by applicable law or agreed to in writing, software
++ * distributed under the License is distributed on an "AS IS" BASIS,
++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++ * See the License for the specific language governing permissions and
++ * limitations under the License.
++ * ==============================================================================*/
++
++#include <functional>
++#include <memory>
++#include <vector>
++
++#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
++#include "tensorflow/core/framework/allocator.h"
++#include "tensorflow/core/framework/fake_input.h"
++#include "tensorflow/core/framework/node_def_builder.h"
++#include "tensorflow/core/framework/op_kernel.h"
++#include "tensorflow/core/framework/tensor.h"
++#include "tensorflow/core/framework/types.h"
++#include "tensorflow/core/framework/types.pb.h"
++#include "tensorflow/core/graph/testlib.h"
++#include "tensorflow/core/kernels/ops_testutil.h"
++#include "tensorflow/core/kernels/ops_util.h"
++#include "tensorflow/core/lib/core/status_test_util.h"
++#include "tensorflow/core/lib/gtl/array_slice.h"
++#include "tensorflow/core/lib/random/simple_philox.h"
++#include "tensorflow/core/lib/strings/str_util.h"
++#include "tensorflow/core/platform/test.h"
++#include "tensorflow/core/platform/test_benchmark.h"
++
++namespace tensorflow {
++namespace {
++
++class KPFusedSparseSegmentReduceOpTest : public OpsTestBase {
++ protected:
++ void MakeOp(int combiner_mode) {
++ TF_ASSERT_OK(NodeDefBuilder("kp_fused_sparse_segment_reduce",
++ "KPFusedSparseSegmentReduce")
++ .Input(FakeInput(DT_FLOAT)) // data
++ .Input(FakeInput(DT_INT32)) // indices
++ .Input(FakeInput(DT_INT64)) // slice_input
++ .Input(FakeInput(DT_INT32)) // begin
++ .Input(FakeInput(DT_INT32)) // begin_1
++ .Attr("combiner", combiner_mode)
++ .Finalize(node_def()));
++ TF_ASSERT_OK(InitOp());
++ }
++};
++
++TEST_F(KPFusedSparseSegmentReduceOpTest, TestReduceMean) {
++ MakeOp(1);
++
++ AddInputFromArray<float>(TensorShape({4, 2}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
++ AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
++ AddInputFromArray<int64>(TensorShape({3, 4}),
++ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
++ AddInputFromArray<int32>(TensorShape({2}), {0, 2});
++ AddInputFromArray<int32>(TensorShape({1}), {1});
++
++ TF_ASSERT_OK(RunOpKernel());
++
++ Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 2}));
++ test::FillValues<float>(&expected,
++ {0.0f, 0.0f, 0.0f, 0.0f, 3.0f, 4.0f, 3.0f, 4.0f});
++ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
++
++ Tensor expected_1(allocator(), DT_INT32, TensorShape({}));
++ test::FillValues<int32>(&expected_1, {2});
++ test::ExpectTensorEqual<int32>(expected_1, *GetOutput(1));
++}
++
++TEST_F(KPFusedSparseSegmentReduceOpTest, TestReduceSum) {
++ MakeOp(0);
++
++ AddInputFromArray<float>(TensorShape({4, 2}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
++ AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
++ AddInputFromArray<int64>(TensorShape({3, 4}),
++ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
++ AddInputFromArray<int32>(TensorShape({2}), {0, 2});
++ AddInputFromArray<int32>(TensorShape({1}), {0});
++
++ TF_ASSERT_OK(RunOpKernel());
++
++ Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 2}));
++ test::FillValues<float>(&expected,
++ {0.0f, 0.0f, 0.0f, 0.0f, 6.0f, 8.0f, 3.0f, 4.0f});
++ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
++
++ Tensor expected_1(allocator(), DT_INT32, TensorShape({}));
++ test::FillValues<int32>(&expected_1, {4});
++ test::ExpectTensorEqual<int32>(expected_1, *GetOutput(1));
++}
++
++TEST_F(KPFusedSparseSegmentReduceOpTest, TestColsOutOfBounds) {
++ MakeOp(0);
++
++ AddInputFromArray<float>(TensorShape({4, 2}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
++ AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
++ AddInputFromArray<int64>(TensorShape({3, 4}),
++ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
++ AddInputFromArray<int32>(TensorShape({2}), {0, 5});
++ AddInputFromArray<int32>(TensorShape({1}), {0});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "Column index out of range"));
++}
++
++TEST_F(KPFusedSparseSegmentReduceOpTest, Test) {
++ MakeOp(0);
++
++ AddInputFromArray<float>(TensorShape({4, 2}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
++ AddInputFromArray<int32>(TensorShape({2}),
++ {0, 2}); // num_indices != slice_input.dim_size(0)
++ AddInputFromArray<int64>(TensorShape({3, 4}),
++ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
++ AddInputFromArray<int32>(TensorShape({2}), {0, 2});
++ AddInputFromArray<int32>(TensorShape({1}), {0});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(),
++ "indices and slice_input.dim_zie(0) should have same size"));
++}
++
++TEST_F(KPFusedSparseSegmentReduceOpTest, TestInvalidData) {
++ MakeOp(0);
++
++ AddInputFromArray<float>(
++ TensorShape({4, 2, 1}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}); // data.dims() > 2
++ AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
++ AddInputFromArray<int64>(TensorShape({3, 4}),
++ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
++ AddInputFromArray<int32>(TensorShape({2}), {0, 2});
++ AddInputFromArray<int32>(TensorShape({1}), {0});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "input must be 2-D"));
++}
++
++TEST_F(KPFusedSparseSegmentReduceOpTest, TestInvalidSliceinput) {
++ MakeOp(0);
++
++ AddInputFromArray<float>(TensorShape({4, 2}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
++ AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
++ AddInputFromArray<int64>(
++ TensorShape({3, 4, 1}),
++ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4}); // slice_input.dims() > 2
++ AddInputFromArray<int32>(TensorShape({2}), {0, 2});
++ AddInputFromArray<int32>(TensorShape({1}), {0});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "slice input must be 2-D"));
++}
++
++TEST_F(KPFusedSparseSegmentReduceOpTest, TestInvalidBegin) {
++ MakeOp(0);
++
++ AddInputFromArray<float>(TensorShape({4, 2}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
++ AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
++ AddInputFromArray<int64>(TensorShape({3, 4}),
++ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
++ AddInputFromArray<int32>(TensorShape({3}),
++ {0, 2, 1}); // begin has 3 elements
++ AddInputFromArray<int32>(TensorShape({1}), {0});
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "begin must have 2 elements"));
++}
++
++TEST_F(KPFusedSparseSegmentReduceOpTest, TestInvalidBegin1) {
++ MakeOp(0);
++
++ AddInputFromArray<float>(TensorShape({4, 2}),
++ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
++ AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
++ AddInputFromArray<int64>(TensorShape({3, 4}),
++ {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
++ AddInputFromArray<int32>(TensorShape({2}), {0, 2});
++ AddInputFromArray<int32>(TensorShape({2}), {0, 1}); // begin_1 has 2 elements
++
++ Status s = RunOpKernel();
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "begin_1 must have 1 element"));
++}
++
++} // namespace
++} // namespace tensorflow
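To trace TestReduceMean above: begin = {0, 2} selects column 2 of slice_input, giving segment ids {2, 2, 3}; indices {0, 2, 1} gather data rows [1,2], [5,6], [3,4]. Segment 2 averages ([1,2] + [5,6]) / 2 = [3,4] and segment 3 is [3,4]; segments 0 and 1 stay zero, which matches the expected 4x2 output. begin_1 = {1} selects embedding_size = 2 for the scalar slice output.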
+diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_select.cc b/annc/tensorflow/kernels/embedding_fused_sparse_select.cc
+new file mode 100644
+index 0000000..306a420
+--- /dev/null
++++ b/annc/tensorflow/kernels/embedding_fused_sparse_select.cc
+@@ -0,0 +1,111 @@
++/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
++
++Licensed under the Apache License, Version 2.0 (the "License");
++you may not use this file except in compliance with the License.
++You may obtain a copy of the License at
++
++ http://www.apache.org/licenses/LICENSE-2.0
++
++Unless required by applicable law or agreed to in writing, software
++distributed under the License is distributed on an "AS IS" BASIS,
++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++See the License for the specific language governing permissions and
++limitations under the License.
++==============================================================================*/
++
++#include <vector>
++#include <algorithm>
++
++#include "tensorflow/core/framework/op_kernel.h"
++#include "tensorflow/core/framework/tensor.h"
++#include "tensorflow/core/util/work_sharder.h"
++#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
++#include "tensorflow/core/platform/logging.h"
++
++using namespace tensorflow;
++
++class KPFusedSparseSelect : public OpKernel {
++public:
++ explicit KPFusedSparseSelect(OpKernelConstruction* context) : OpKernel(context) {
++ }
++
++ void Compute(OpKernelContext* context) override {
++ const Tensor& input_a = context->input(0);
++ const Tensor& input_b = context->input(1);
++ const Tensor& input_c = context->input(2);
++ const Tensor& greater = context->input(3);
++ const Tensor& equal1 = context->input(4);
++ const Tensor& equal2 = context->input(5);
++ const Tensor& equal3 = context->input(6);
++
++ int32_t equal1_val = equal1.flat<int32_t>()(0);
++ int32_t equal2_val = equal2.flat<int32_t>()(0);
++ int32_t equal3_val = equal3.flat<int32_t>()(0);
++ VLOG(1) << "equal1_val: " << equal1_val;
++ VLOG(1) << "equal2_val: " << equal2_val;
++ VLOG(1) << "equal3_val: " << equal3_val;
++
++ int32_t greater_val = greater.flat<int32_t>()(0);
++ auto a_flat = input_a.flat<int32_t>();
++ auto b_flat = input_b.flat<int32_t>();
++ auto c_flat = input_c.flat<int32_t>();
++ VLOG(1) << "input_a shape: " << input_a.shape().DebugString();
++ VLOG(1) << "input_b shape: " << input_b.shape().DebugString();
++ VLOG(1) << "input_c shape: " << input_c.shape().DebugString();
++ OP_REQUIRES(context, input_a.NumElements() == input_b.NumElements(),
++ errors::InvalidArgument("Input num elements of a and b must match"));
++ OP_REQUIRES(context, input_a.NumElements() == input_c.NumElements(),
++ errors::InvalidArgument("Input num elements of a and c must match"));
++ auto N = input_a.NumElements();
++
++ Eigen::TensorMap<Eigen::Tensor<const int32_t, 2, Eigen::RowMajor>> a_reshaped_tensor(a_flat.data(), N, 1);
++ Eigen::TensorMap<Eigen::Tensor<const int32_t, 2, Eigen::RowMajor>> b_reshaped_tensor(b_flat.data(), N, 1);
++ Eigen::TensorMap<Eigen::Tensor<const int32_t, 2, Eigen::RowMajor>> c_reshaped_tensor(c_flat.data(), N, 1);
++
++ Tensor* output_x = nullptr;
++ Tensor* output_y = nullptr;
++ Tensor* output_w = nullptr;
++
++ OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({N, 1}), &output_x));
++ OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({N, 1}), &output_y));
++ OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({N, 2}), &output_w));
++
++ Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>> out_x(
++ output_x->flat<float>().data(),
++ output_x->dim_size(0),
++ output_x->dim_size(1)
++ );
++
++ Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>> out_y(
++ output_y->flat<float>().data(),
++ output_y->dim_size(0),
++ output_y->dim_size(1)
++ );
++
++ Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>> out_w(
++ output_w->flat<float>().data(),
++ output_w->dim_size(0),
++ output_w->dim_size(1)
++ );
++
++ auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
++ const int64 cost_per_unit = std::max(N / worker_threads->num_threads, int64(10));
++
++ auto work = [&](int64 start, int64 end) {
++ for (int64 i = start; i < end; i++) {
++ // Greater(bool)+Cast.2406(float) --> 1.0f / 0.0f
++ float a_greater = (a_reshaped_tensor(i, 0) > greater_val) ? 1.0f : 0.0f;
++ float res_equal1 = (b_reshaped_tensor(i, 0) == equal1_val) ? 1.0f : a_greater; // Fill.2409-->1.0f
++ float res_equal2 = (b_reshaped_tensor(i, 0) == equal2_val) ? 1.0f : res_equal1; // Fill.2409-->1.0f
++ out_x(i, 0) = a_reshaped_tensor(i, 0); // Reshape.2401
++ out_y(i, 0) = res_equal2;
++ out_w(i, 0) = res_equal2; // Mul.2419 hard-codes 1.0f * input
++ out_w(i, 1) = 1.0f; // select_2427 is eliminated; use Fill.2422 --> 1.0f directly
++ }
++ };
++ Shard(worker_threads->num_threads, worker_threads->workers, N, cost_per_unit, work);
++ }
++};
++
++REGISTER_KERNEL_BUILDER(Name("KPFusedSparseSelect").Device(DEVICE_CPU),
++ KPFusedSparseSelect);
+\ No newline at end of file
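Stripped of the Eigen maps and sharding, the per-element logic above reduces to the following sketch (the function name is illustrative):

    #include <cstdint>

    // Computes the value written to output_y(i, 0) and output_w(i, 0);
    // output_x(i, 0) is input_a cast to float and output_w(i, 1) is always 1.0f.
    float FusedSelectValue(int32_t a, int32_t b, int32_t greater_val,
                           int32_t equal1_val, int32_t equal2_val) {
      float v = (a > greater_val) ? 1.0f : 0.0f;  // Greater + Cast
      if (b == equal1_val) v = 1.0f;              // first Equal/Select -> Fill 1.0f
      if (b == equal2_val) v = 1.0f;              // second Equal/Select -> Fill 1.0f
      return v;
    }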
+diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_select_test.cc b/annc/tensorflow/kernels/embedding_fused_sparse_select_test.cc
+new file mode 100644
+index 0000000..f956415
+--- /dev/null
++++ b/annc/tensorflow/kernels/embedding_fused_sparse_select_test.cc
+@@ -0,0 +1,182 @@
++/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
++
++Licensed under the Apache License, Version 2.0 (the "License");
++you may not use this file except in compliance with the License.
++You may obtain a copy of the License at
++
++ http://www.apache.org/licenses/LICENSE-2.0
++
++Unless required by applicable law or agreed to in writing, software
++distributed under the License is distributed on an "AS IS" BASIS,
++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++See the License for the specific language governing permissions and
++limitations under the License.
++==============================================================================*/
++
++#include "tensorflow/core/framework/fake_input.h"
++#include "tensorflow/core/framework/node_def_builder.h"
++#include "tensorflow/core/framework/tensor.h"
++#include "tensorflow/core/framework/tensor_testutil.h"
++#include "tensorflow/core/kernels/ops_testutil.h"
++#include "tensorflow/core/lib/core/status_test_util.h"
++
++namespace {
++using tensorflow::AllocatorAttributes;
++using tensorflow::DT_FLOAT;
++using tensorflow::DT_INT32;
++using tensorflow::DT_INT64;
++using tensorflow::int64;
++using tensorflow::int32;
++using tensorflow::NodeDefBuilder;
++using tensorflow::OpsTestBase;
++using tensorflow::Status;
++using tensorflow::Tensor;
++using tensorflow::TensorShape;
++using tensorflow::test::ExpectClose;
++using tensorflow::test::FillValues;
++using tensorflow::test::AsTensor;
++using tensorflow::test::ExpectTensorEqual;
++
++class KPFusedSparseSelectTest : public OpsTestBase {
++ protected:
++ void RunValidCase(
++ const TensorShape& shape,
++ const std::vector<int32>& a_data,
++ const std::vector<int32>& b_data,
++ const std::vector<int32>& c_data,
++ int32_t greater_val,
++ int32_t equal1_val,
++ int32_t equal2_val,
++ const std::vector<float>& expected_y,
++ const std::vector<float>& expected_w_col0) {
++
++ TF_EXPECT_OK(NodeDefBuilder("kp_fused_sparse_select", "KPFusedSparseSelect")
++ .Input(FakeInput(DT_INT32))
++ .Input(FakeInput(DT_INT32))
++ .Input(FakeInput(DT_INT32))
++ .Input(FakeInput(DT_INT32)) // greater
++ .Input(FakeInput(DT_INT32)) // equal1
++ .Input(FakeInput(DT_INT32)) // equal2
++ .Input(FakeInput(DT_INT32)) // equal3
++ .Finalize(node_def()));
++ TF_EXPECT_OK(InitOp());
++
++ AddInputFromArray<int32>(shape, a_data);
++ AddInputFromArray<int32>(shape, b_data);
++ AddInputFromArray<int32>(shape, c_data);
++ AddInputFromArray<int32>(TensorShape({}), {greater_val}); // scalar
++ AddInputFromArray<int32>(TensorShape({}), {equal1_val});
++ AddInputFromArray<int32>(TensorShape({}), {equal2_val});
++ AddInputFromArray<int32>(TensorShape({}), {0}); // equal3_val (unused)
++
++ TF_ASSERT_OK(RunOpKernel());
++
++ const Tensor& out_x = *GetOutput(0);
++ const Tensor& out_y = *GetOutput(1);
++ const Tensor& out_w = *GetOutput(2);
++
++ int32 num_elements = static_cast<int32>(expected_y.size());
++ // Verify output_x: it is simply input_a cast to float
++ std::vector<float> a_data_float(a_data.begin(), a_data.end());
++ ExpectTensorEqual<float>(out_x, AsTensor<float>(a_data_float, {num_elements, 1}));
++
++ // Verify output_y
++ ExpectTensorEqual<float>(out_y, AsTensor<float>(expected_y, {num_elements, 1}));
++ // Verify the first column of output_w
++ auto w_mat = out_w.matrix<float>();
++ for (int i = 0; i < w_mat.dimension(0); ++i) {
++ EXPECT_FLOAT_EQ(w_mat(i, 0), expected_w_col0[i]);
++ EXPECT_FLOAT_EQ(w_mat(i, 1), 1.0f); // the second column must be 1.0
++ }
++ }
++
++ Status RunOpExpectFailure(
++ const TensorShape& shape,
++ const std::vector<int32>& a_data,
++ const std::vector<int32>& b_data,
++ const std::vector<int32>& c_data,
++ int32_t greater_val,
++ int32_t equal1_val,
++ int32_t equal2_val) {
++
++ TF_CHECK_OK(NodeDefBuilder("kp_fused_sparse_select", "KPFusedSparseSelect")
++ .Input(FakeInput(DT_INT32))
++ .Input(FakeInput(DT_INT32))
++ .Input(FakeInput(DT_INT32))
++ .Input(FakeInput(DT_INT32))
++ .Input(FakeInput(DT_INT32))
++ .Input(FakeInput(DT_INT32))
++ .Input(FakeInput(DT_INT32))
++ .Finalize(node_def()));
++ TF_CHECK_OK(InitOp());
++ TensorShape b_shape({static_cast<int64>(b_data.size())});
++ TensorShape c_shape({static_cast<int64>(c_data.size())});
++ AddInputFromArray<int32>(shape, a_data);
++ AddInputFromArray<int32>(b_shape, b_data);
++ AddInputFromArray<int32>(c_shape, c_data);
++ AddInputFromArray<int32>(TensorShape({}), {greater_val});
++ AddInputFromArray<int32>(TensorShape({}), {equal1_val});
++ AddInputFromArray<int32>(TensorShape({}), {equal2_val});
++ AddInputFromArray<int32>(TensorShape({}), {0});
++
++ return RunOpKernel();
++ }
++};
++
++// ==================== Positive tests ====================
++// See fused_embedding_sparse_select_test.py for more positive-case coverage.
++TEST_F(KPFusedSparseSelectTest, Valid_NormalInput) {
++ RunValidCase(
++ TensorShape({3}), // shape
++ {5, 3, 8}, // input_a
++ {1, 2, 1}, // input_b
++ {9, 8, 7}, // input_c (unused)
++ 4, // greater_val
++ 1, // equal1_val
++ 3, // equal2_val
++ {1.0f, 0.0f, 1.0f}, // expected_y
++ {1.0f, 0.0f, 1.0f} // expected_w_col0
++ );
++}
++
++TEST_F(KPFusedSparseSelectTest, Valid_2DInput) {
++ RunValidCase(
++ TensorShape({2, 2}),
++ {6, 3, 8, 2},
++ {2, 1, 3, 4},
++ {0, 0, 0, 0},
++ 5,
++ 2,
++ 3,
++ {1.0f, 0.0f, 1.0f, 0.0f},
++ {1.0f, 0.0f, 1.0f, 0.0f}
++ );
++}
++// ==================== Negative tests ====================
++// Negative case 1: input_a and input_b element counts do not match
++TEST_F(KPFusedSparseSelectTest, Invalid_DimMismatch_AB) {
++ Status s = RunOpExpectFailure(
++ TensorShape({3}), // a has 3 elements
++ {1, 2, 3},
++ {4, 5}, // b has 2 elements -> mismatch
++ {6, 7, 8},
++ 0, 1, 2
++ );
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "Input num elements of a and b must match"));
++}
++
++// Negative case 2: input_a and input_c element counts do not match
++TEST_F(KPFusedSparseSelectTest, Invalid_DimMismatch_AC) {
++ Status s = RunOpExpectFailure(
++ TensorShape({2}),
++ {1, 2},
++ {3, 4},
++ {5}, // c has only 1 element -> mismatch
++ 0, 1, 2
++ );
++ EXPECT_FALSE(s.ok());
++ EXPECT_TRUE(absl::StrContains(s.message(), "Input num elements of a and c must match"));
++}
++
++}
+\ No newline at end of file
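To trace Valid_NormalInput above: greater_val = 4 maps input_a {5, 3, 8} to {1, 0, 1}; input_b {1, 2, 1} equals equal1_val = 1 at positions 0 and 2, which are already 1, and never equals equal2_val = 3, so expected_y stays {1, 0, 1}, mirrored into column 0 of output_w.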
+diff --git a/annc/tensorflow/ops/embedding_fused_ops.cc b/annc/tensorflow/ops/embedding_fused_ops.cc
+new file mode 100644
+index 0000000..bc0ce24
+--- /dev/null
++++ b/annc/tensorflow/ops/embedding_fused_ops.cc
+@@ -0,0 +1,133 @@
++/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
++
++Licensed under the Apache License, Version 2.0 (the "License");
++you may not use this file except in compliance with the License.
++You may obtain a copy of the License at
++
++ http://www.apache.org/licenses/LICENSE-2.0
++
++Unless required by applicable law or agreed to in writing, software
++distributed under the License is distributed on an "AS IS" BASIS,
++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++See the License for the specific language governing permissions and
++limitations under the License.
++==============================================================================*/
++#include <stdio.h>
++
++#include "tensorflow/core/framework/op.h"
++#include "tensorflow/core/framework/types.h"
++#include "tensorflow/core/framework/shape_inference.h"
++#include "tensorflow/core/framework/common_shape_fns.h"
++
++namespace tensorflow {
++
++using shape_inference::DimensionHandle;
++using shape_inference::InferenceContext;
++using shape_inference::ShapeHandle;
++using shape_inference::UnchangedShape;
++
++REGISTER_OP("KPFusedSparseSegmentReduce")
++ .Input("data: float")
++ .Input("indices: Tidx")
++ .Input("slice_input: int64")
++ .Input("begin: int32")
++ .Input("begin_1: int32")
++ .Attr("combiner: int = 1") // 0 for SUM, 1 for MEAN
++ .Attr("Tidx: {int32, int64} = DT_INT32")
++ .Output("output: float")
++ .Output("slice_output: int32")
++ .SetShapeFn(shape_inference::UnknownShape);
++
++REGISTER_OP("KPFusedSparseSegmentReduceNonzero")
++ .Input("data: float")
++ .Input("indices: Tidx")
++ .Input("slice_input: int64")
++ .Input("begin: int32")
++ .Attr("combiner: int = 1") // 0 for SUM, 1 for MEAN
++ .Attr("Tidx: {int32, int64} = DT_INT32")
++ .Output("output_shape: int32")
++ .Output("output_indices: int32")
++ .Output("output_nonzero: float")
++ .SetShapeFn(shape_inference::UnknownShape);
++
++REGISTER_OP("KPFusedEmbeddingPaddingFast")
++ .Input("input0: int64")
++ .Input("input1: float")
++ .Input("input2: int32")
++ .Input("input3: int32")
++ .Input("pack: int32")
++ .Output("output0: int32")
++ .Output("output1: int32")
++ .SetShapeFn([](InferenceContext* c) {
++ ShapeHandle scalar_shape = c->Scalar();
++ c->set_output(0, scalar_shape);
++ c->set_output(1, scalar_shape);
++ return OkStatus();
++ });
++
++REGISTER_OP("KPFusedEmbeddingPadding")
++ .Input("input0: int64")
++ .Input("input1: float")
++ .Input("input2: int32")
++ .Input("input3: int32")
++ .Input("pack: int32")
++ .Output("output0: int32")
++ .Output("output1: float")
++ .SetShapeFn([](InferenceContext* c) {
++ ShapeHandle out;
++ ShapeHandle scalar_shape = c->Scalar();
++ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(3, &out));
++ c->set_output(0, scalar_shape);
++ c->set_output(1, out);
++ return OkStatus();
++ });
++
++REGISTER_OP("KPFusedSparseSelect")
++ .Input("input_a: int32")
++ .Input("input_b: int32")
++ .Input("input_c: int32")
++ .Input("greater: int32")
++ .Input("equal1: int32")
++ .Input("equal2: int32")
++ .Input("equal3: int32")
++ .Output("output_x: float")
++ .Output("output_y: float")
++ .Output("output_w: float")
++ .SetShapeFn(shape_inference::UnknownShape);
++
++REGISTER_OP("KPFusedSparseReshape")
++ .Input("slice_input: int64")
++ .Input("begin: int32")
++ .Input("new_shape: int64")
++ .Input("pack_const: int64")
++ .Output("out_indices: int64")
++ .Output("out_shape: int64")
++ .SetShapeFn(shape_inference::UnknownShape);
++
++REGISTER_OP("KPFusedSparseDynamicStitch")
++ .Input("x: int64")
++ .Input("variables: N * float")
++ .Output("output: float")
++ .Attr("N: int >= 1")
++ .SetShapeFn(shape_inference::UnknownShape);
++
++REGISTER_OP("KPFusedGather")
++ .Input("data: float")
++ .Input("slice_input: int64")
++ .Input("begin: int32")
++ .Output("out_shape: int64")
++ .Output("out_indices: int32")
++ .Output("out_data: float")
++ .SetShapeFn(shape_inference::UnknownShape);
++
++REGISTER_OP("KPFusedEmbeddingActionIdGather")
++ .Input("input0: Tindices1")
++ .Input("input1: float")
++ .Input("input2: Tindices2")
++ .Input("input3: int32")
++ .Input("pack: int32")
++ .Attr("Tindices1: {int32, int64} = DT_INT64")
++ .Attr("Tindices2: {int32, int64} = DT_INT32")
++ .Output("output0: float")
++ .SetShapeFn(shape_inference::UnknownShape);
++} // namespace tensorflow
+\ No newline at end of file
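All of these ops register shape_inference::UnknownShape. Where the output rank is fixed, a tighter shape function could be supplied; for example, a sketch for KPFusedSparseSegmentReduce that would slot into its builder chain (illustrative only, since the segment count is data-dependent and must stay unknown):

    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle data;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &data));
      // [num_segments (unknown at graph time), embedding_size]
      c->set_output(0, c->Matrix(c->UnknownDim(), c->Dim(data, 1)));
      c->set_output(1, c->Scalar());
      return OkStatus();
    });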
+diff --git a/annc/tensorflow/tf_annc_optimizer.patch b/annc/tensorflow/tf_annc_optimizer.patch
+new file mode 100644
+index 0000000..c92330a
+--- /dev/null
++++ b/annc/tensorflow/tf_annc_optimizer.patch
+@@ -0,0 +1,221 @@
++diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
++index 538574360ba..b43e0455802 100644
++--- a/tensorflow/core/BUILD
+++++ b/tensorflow/core/BUILD
++@@ -629,6 +629,7 @@ cc_library(
++ "//tensorflow/core/kernels/linalg",
++ "//tensorflow/core/kernels/sparse:kernels",
++ "//tensorflow/core/kernels/uniform_quant_ops:kernels",
+++ "//tensorflow/core/kernels:embedding_fused_ops",
++ ] + if_mkl([
++ "//tensorflow/core/kernels/mkl:mkl_concat_op",
++ "//tensorflow/core/kernels/mkl:mkl_dequantize_op",
++diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
++index ecd559734ea..97a918ead6d 100644
++--- a/tensorflow/core/grappler/optimizers/BUILD
+++++ b/tensorflow/core/grappler/optimizers/BUILD
++@@ -880,6 +880,18 @@ tf_cuda_cc_test(
++ ],
++ )
++
+++alias(
+++ name = "is_aarch64",
+++ actual = "@local_xla//xla/service/cpu:is_aarch64",
+++ visibility = ["//visibility:public"],
+++)
+++
+++alias(
+++ name = "aarch64_and_annc_disabled",
+++ actual = "@local_xla//xla/service/cpu:aarch64_and_annc_disabled",
+++ visibility = ["//visibility:public"],
+++)
+++
++ tf_kernel_library(
++ name = "remapper",
++ srcs = ["remapper.cc"],
++@@ -904,7 +916,12 @@ tf_kernel_library(
++ "//tensorflow/core/grappler/utils:symbolic_shapes",
++ "//tensorflow/core/grappler/utils:topological_sort",
++ "@com_google_absl//absl/container:flat_hash_set",
++- ] + if_mkl(["//tensorflow/core/graph:mkl_graph_util"]),
+++ ] + if_mkl(["//tensorflow/core/graph:mkl_graph_util"])
+++ + select({
+++ ":aarch64_and_annc_disabled": [],
+++ ":is_aarch64": ["//tensorflow/core/grappler/optimizers/graph_optimizer:annc_graph_opt"],
+++ "//conditions:default": [],
+++ })
++ )
++
++ tf_cuda_cc_test(
++diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
++index 3c37150f496..2fbde836d1d 100644
++--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++++ b/tensorflow/core/grappler/optimizers/remapper.cc
++@@ -52,6 +52,7 @@ limitations under the License.
++ #include "third_party/gpus/cudnn/cudnn.h"
++ #endif // GOOGLE_CUDA
++
+++#include "tensorflow/core/grappler/optimizers/graph_optimizer/graph_opt.h"
++ namespace tensorflow {
++ namespace grappler {
++
++@@ -4596,6 +4597,7 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index,
++ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
++ GraphDef* optimized_graph) {
++ GrapplerItem mutable_item = item;
+++ annc::run_graph_optimization(&mutable_item.graph);
++ Status status;
++ RemapperContext ctx(&mutable_item, &status, cpu_layout_conversion_,
++ xla_auto_clustering_on_);
++diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
++index 22617ef8ab5..f89a084f94c 100644
++--- a/tensorflow/core/kernels/BUILD
+++++ b/tensorflow/core/kernels/BUILD
++@@ -3655,6 +3655,97 @@ tf_kernel_library(
++ ]) + [":fft_impl"],
++ )
++
+++tf_kernel_library(
+++ name = "embedding_fused_action_id_gather_op",
+++ srcs = ["embedding_fused_action_id_gather.cc"],
+++ deps = MATH_DEPS,
+++)
+++
+++tf_kernel_library(
+++ name = "embedding_fused_gather_op",
+++ srcs = ["embedding_fused_gather.cc"],
+++ deps = MATH_DEPS,
+++)
+++
+++tf_kernel_library(
+++ name = "embedding_fused_padding_op",
+++ srcs = ["embedding_fused_padding.cc"],
+++ deps = MATH_DEPS,
+++)
+++
+++tf_kernel_library(
+++ name = "embedding_fused_sparse_dynamic_stitch_op",
+++ srcs = ["embedding_fused_sparse_dynamic_stitch.cc"],
+++ deps = MATH_DEPS,
+++)
+++
+++tf_kernel_library(
+++ name = "embedding_fused_reshape_op",
+++ srcs = ["embedding_fused_sparse_reshape.cc"],
+++ deps = MATH_DEPS + [
+++ ":reshape_util",
+++ ],
+++)
+++
+++tf_kernel_library(
+++ name = "embedding_fused_sparse_segment_reduce_op",
+++ srcs = ["embedding_fused_sparse_segment_reduce.cc"],
+++ deps = MATH_DEPS,
+++)
+++
+++tf_kernel_library(
+++ name = "embedding_fused_sparse_segment_reduce_nonzero_op",
+++ srcs = ["embedding_fused_sparse_segment_reduce_nonzero.cc"],
+++ deps = MATH_DEPS + ["@com_google_absl//absl/container:flat_hash_map"],
+++)
+++
+++tf_kernel_library(
+++ name = "embedding_fused_sparse_select_op",
+++ srcs = ["embedding_fused_sparse_select.cc"],
+++ deps = MATH_DEPS,
+++)
+++
+++cc_library(
+++ name = "embedding_fused_ops",
+++ deps = [
+++ ":embedding_fused_action_id_gather_op",
+++ ":embedding_fused_gather_op",
+++ ":embedding_fused_padding_op",
+++ ":embedding_fused_sparse_dynamic_stitch_op",
+++ ":embedding_fused_reshape_op",
+++ ":embedding_fused_sparse_segment_reduce_op",
+++ ":embedding_fused_sparse_segment_reduce_nonzero_op",
+++ ":embedding_fused_sparse_select_op",
+++ ],
+++)
+++
+++tf_cc_test(
+++ name = "embedding_fused_ops_test",
+++ size = "small",
+++ srcs = [
+++ "embedding_fused_action_id_gather_test.cc",
+++ "embedding_fused_sparse_dynamic_stitch_test.cc",
+++ "embedding_fused_sparse_segment_reduce_test.cc",
+++ "embedding_fused_sparse_segment_reduce_nonzero_test.cc",
+++ "embedding_fused_padding_test.cc",
+++ "embedding_fused_sparse_select_test.cc",
+++ "embedding_fused_gather_test.cc",
+++ "embedding_fused_sparse_reshape_test.cc",
+++ ],
+++ deps = [
+++ ":ops_testutil",
+++ ":ops_util",
+++ ":embedding_fused_ops",
+++ "//tensorflow/core:core_cpu",
+++ "//tensorflow/core:framework",
+++ "//tensorflow/core:lib",
+++ "//tensorflow/core:protos_all_cc",
+++ "//tensorflow/core:test",
+++ "//tensorflow/core:test_main",
+++ "//tensorflow/core:testlib",
+++ ],
+++)
+++
++ tf_kernel_library(
++ name = "reduction_ops",
++ gpu_srcs = ["reduction_gpu_kernels.cu.h"],
++diff --git a/tensorflow/core/ops/BUILD b/tensorflow/core/ops/BUILD
++index 91d80b6c2b5..b00b2a5d027 100644
++--- a/tensorflow/core/ops/BUILD
+++++ b/tensorflow/core/ops/BUILD
++@@ -55,6 +55,7 @@ tf_gen_op_libs(
++ "decode_proto_ops",
++ "encode_proto_ops",
++ "experimental_dataset_ops",
+++ "embedding_fused_ops",
++ "filesystem_ops",
++ "function_ops",
++ "functional_ops",
++@@ -323,6 +324,7 @@ cc_library(
++ ":training_ops_op_lib",
++ ":uniform_quant_ops_op_lib",
++ ":word2vec_ops",
+++ ":embedding_fused_ops_op_lib",
++ ] + select({
++ # Non-tpu platforms don't need tpu dependency.
++ "//tensorflow:chromiumos": [],
++diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD
++index 6e0ea613435..47c346b4e93 100644
++--- a/third_party/xla/xla/service/cpu/BUILD
+++++ b/third_party/xla/xla/service/cpu/BUILD
++@@ -198,6 +198,23 @@ cc_library(
++ ],
++ )
++
+++config_setting(
+++ name = "disable_annc",
+++ define_values = {"disable_annc": "false"},
+++ visibility = ["//visibility:public"],
+++)
+++
+++config_setting(
+++ name = "is_aarch64",
+++ constraint_values = ["@platforms//cpu:aarch64"],
+++)
+++
+++config_setting(
+++ name = "aarch64_and_annc_disabled",
+++ constraint_values = ["@platforms//cpu:aarch64"],
+++ define_values = {"disable_annc": "true"},
+++)
+++
++ cc_library(
++ name = "cpu_compiler_pure",
++ srcs = ["cpu_compiler.cc"],
++
++
+--
+2.33.0
+
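For orientation, the remapper hunk above inserts annc::run_graph_optimization(&mutable_item.graph) before TensorFlow's own pattern remapping runs. The shape of that entry point is roughly as sketched below; the body is illustrative only (the real pass lives in the annc_graph_opt target and rewrites matched embedding subgraphs into the KPFused* ops added by this patch):

    #include "tensorflow/core/framework/graph.pb.h"

    namespace annc {
    // Illustrative signature: mutate the GraphDef in place, replacing matched
    // embedding subgraphs with the fused ops registered by this patch.
    void run_graph_optimization(tensorflow::GraphDef* graph) {
      for (tensorflow::NodeDef& node : *graph->mutable_node()) {
        (void)node;  // pattern matching / rewriting would happen here
      }
    }
    }  // namespace annc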
diff --git a/ANNC.spec b/ANNC.spec
index fca393d..ca85e5a 100644
--- a/ANNC.spec
+++ b/ANNC.spec
@@ -7,24 +7,30 @@
%global max_jobs 16
-%define debug_package %{nil}
+%global debug_package %{nil}
Summary: %{name} is an AI compiler designed to optimize and compile ML model into high-performance executable code that can be executed on various targets.
Name: ANNC
Version: 0.0.2
-Release: 1
+Release: 3
# Package onnxruntime and SafeInt have MIT License.
# Package onnx has Apache License 2.0.
License: MIT and ASL 2.0 and Boost and BSD
-URL: https://gitee.com/openeuler/AI4C
+URL: https://gitee.com/openeuler/ANNC
Source0: %{pkg_version}.tar.gz
-Source1: annc_external.tar.gz.aa
-Source2: annc_external.tar.gz.ab
+Source1: external.tar.gz.aa
+Source2: external.tar.gz.ab
+Source3: external.tar.gz.ac
+Source4: XNNPACK.tar.gz
+Source5: v3.2.tar.gz
%ifarch x86_64
Patch0: x86_64_external_files.patch
%endif
+Patch1: 0001-fix-pattern-conflicts.patch
+Patch2: 0002-Add-graph-optimizer-and-embedding_fused-kernels.patch
BuildRequires: cmake >= 3.9.9
+BuildRequires: make
BuildRequires: gcc
BuildRequires: gcc-c++
BuildRequires: bzip2
@@ -43,21 +49,29 @@ BuildRequires: bazel
%{name} is an AI compiler designed to optimize and compile ML model into high-performance executable code that can be executed on various targets.
%prep
-cat %{SOURCE1} %{SOURCE2} > annc_external.tar.gz
-tar -xzf annc_external.tar.gz -C .
+cat %{SOURCE1} %{SOURCE2} %{SOURCE3} > external.tar.gz
+tar xf external.tar.gz -C .
%ifarch x86_64
-%patch 0 -p0 -d .
+%patch0 -p1 -d .
%endif
-
-%autosetup -S git -n %{pkg_version}
+tar xf %{SOURCE4} -C .
+mkdir proxy
+cp %{SOURCE5} ./proxy
+%setup -q -n %{pkg_version}
+%patch1 -p1
+%patch2 -p1
%build
export ANNC=%{build_dir}
-cd %{build_dir}/annc/service/cpu/xla/libs
-bash xnnpack.sh
-
-export XNNPACK_BASE="$ANNC/annc/service/cpu/xla/libs"
-export XNNPACK_DIR="$XNNPACK_BASE/XNNPACK"
+cd %{_builddir}/XNNPACK/build
+CFLAGS="-fPIC" CXXFLAGS="-fPIC" cmake .. -DXNNPACK_BUILD_BENCHMARKS=OFF \
+ -DXNNPACK_BUILD_TESTS=OFF \
+ -DXNNPACK_LIBRARY_TYPE=shared \
+ -DCMAKE_BUILD_TYPE=Release
+make -j %{max_jobs}
+rm -rf $ANNC/annc/service/cpu/xla/libs/libXNNPACK.so
+cp %{_builddir}/XNNPACK/build/libXNNPACK.so $ANNC/annc/service/cpu/xla/libs
+export XNNPACK_DIR="%{_builddir}/XNNPACK"
CPLUS_INCLUDE_PATH+="$ANNC/annc/service/cpu/xla:"
CPLUS_INCLUDE_PATH+="$ANNC/annc/service/:"
@@ -75,9 +89,12 @@ run_bazel_build() {
--verbose_failures \
--action_env="baila=548" \
--define tflite_with_xnnpack=false \
- --copt="-g" \
- --copt="-DNDBUG" \
- annc/service/cpu:libannc.so
+ --local_ram_resources=512 \
+ --distdir=%{_builddir}/proxy \
+ annc/service/cpu:libannc.so
+ # --copt="-g" \
+ # --copt="-DNDBUG" \
+ # annc/service/cpu:libannc.so
}
fix_action() {
@@ -85,7 +102,7 @@ fix_action() {
external_path=$(find . -name "external" | head -n 1)
if [ -n "$external_path" ]; then
rm -rf $external_path/*
- cp -r %{_builddir}/external/* $external_path
+ cp -LR %{_builddir}/external/* $external_path
else
echo "Not find external directory."
fi
@@ -106,12 +123,17 @@ pushd %{build_dir}/python
%install
install -d %{install_includedir}
install %{build_dir}/annc/service/cpu/kdnn_rewriter.h -t %{install_includedir}
-install %{build_dir}/install/*.patch -t %{install_includedir}
+install %{build_dir}/annc/service/cpu/annc_flags.h -t %{install_includedir}
+install %{build_dir}/annc/service/cpu/xla/*.h -t %{install_includedir}
+install -d %{install_includedir}/bisheng-cpu
+install %{build_dir}/annc/service/bisheng-cpu/*.h -t %{install_includedir}/bisheng-cpu
+cp -r %{build_dir}/annc/tensorflow %{install_includedir}
+cp -r %{build_dir}/install/* %{install_includedir}
install %{build_dir}/python/tensorflow/kernels/* -t %{install_includedir}
install -d %{install_libdir}
output_path=$(find %{build_dir} -type f -name "libannc.so")
install ${output_path} -t %{install_libdir}
-install %{build_dir}/annc/service/cpu/xla/libs/XNNPACK/build/*.so -t %{install_libdir}
+install %{build_dir}/annc/service/cpu/xla/libs/libXNNPACK.so -t %{install_libdir}
pushd %{build_dir}/python
%py3_install
@@ -121,9 +143,26 @@ pushd %{build_dir}/python
%{_libdir}/*
%{python3_sitelib}/*
/usr/bin/annc-opt
-/usr/bin/annc-apply-tf
%changelog
+* Tue Nov 11 2025 Chenhui Zheng <zhengchenhui1@huawei.com> - 0.0.2-3
+- Type:Update
+- ID:NA
+- SUG:NA
+- DEC:Add graph optimizer and embedding_fused kernels.
+
+* Thu Sep 11 2025 Chenhui Zheng <zhengchenhui1@huawei.com> - 0.0.2-2
+- Type:Fix
+- ID:NA
+- SUG:NA
+- DEC:Fix pattern conflicts & missing header files.
+
+* Fri Aug 22 2025 Chenhui Zheng <zhengchenhui1@huawei.com> - 0.0.2-1
+- Type:Update
+- ID:NA
+- SUG:NA
+- DEC:Release v0.0.2
+
* Mon May 12 2025 Chenhui Zheng <zhengchenhui1@huawei.com> - 0.0.1-1
- Type:Init
- ID:NA
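A side note on the %prep hunk above: the external archive ships as three split pieces (Source1 through Source3), presumably to keep individual files under size limits, and is reassembled before extraction. The equivalent standalone commands, assuming the pieces sit in the current directory:

    # Concatenate the pieces in suffix order (.aa, .ab, .ac) to restore
    # the gzipped tar, then extract; tar autodetects the compression.
    cat external.tar.gz.aa external.tar.gz.ab external.tar.gz.ac > external.tar.gz
    tar xf external.tar.gz -C .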
diff --git a/sources b/sources
index 38f805f..5f8a86f 100644
--- a/sources
+++ b/sources
@@ -1,3 +1,6 @@
-871ddedbfbb9aa75c2db497584667f61 ANNC-v0.0.2.tar.gz
-a3f0ec5120fa9b65af527332299c9d46 annc_external.tar.gz.aa
-f548d6ba0ad0163c0aa3df33250e97c6 annc_external.tar.gz.ab
+c727cf97cec102b5399d3f43bcf1dd5a ANNC-v0.0.2.tar.gz
+20e5d643ae5e8981686f54d7838a959a XNNPACK.tar.gz
+2851e140b6c1c07b44cd4db060f17fbd external.tar.gz.aa
+8955b6806f170bbe9a5d6a864e75d2d3 external.tar.gz.ab
+88966e263f2f840215f9df1ac34a0389 external.tar.gz.ac
+19c62a338990388a31cd2ecb918af855 v3.2.tar.gz
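Each line in sources pairs an MD5 checksum with a tarball. A hedged sketch for verifying the entries locally, assuming GNU md5sum ('md5sum -c' expects two spaces between hash and file name, so the single-space format is widened first):

    # Check every '<md5> <file>' pair listed in sources.
    awk '{print $1 "  " $2}' sources | md5sum -c -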
diff --git a/x86_64_external_files.patch b/x86_64_external_files.patch
index 1f77ebc..7bf4fb3 100644
--- a/x86_64_external_files.patch
+++ b/x86_64_external_files.patch
@@ -1,3 +1,73 @@
+diff --git a/external/local_config_cc/BUILD b/external/local_config_cc/BUILD
+index 6c19fae..a0ea3f8 100755
+--- a/external/local_config_cc/BUILD
++++ b/external/local_config_cc/BUILD
+@@ -47,15 +47,15 @@ filegroup(
+ cc_toolchain_suite(
+ name = "toolchain",
+ toolchains = {
+- "aarch64|compiler": ":cc-compiler-aarch64",
+- "aarch64": ":cc-compiler-aarch64",
++ "k8|compiler": ":cc-compiler-k8",
++ "k8": ":cc-compiler-k8",
+ "armeabi-v7a|compiler": ":cc-compiler-armeabi-v7a",
+ "armeabi-v7a": ":cc-compiler-armeabi-v7a",
+ },
+ )
+
+ cc_toolchain(
+- name = "cc-compiler-aarch64",
++ name = "cc-compiler-k8",
+ toolchain_identifier = "local",
+ toolchain_config = ":local",
+ all_files = ":compiler_deps",
+@@ -72,7 +72,7 @@ cc_toolchain(
+
+ cc_toolchain_config(
+ name = "local",
+- cpu = "aarch64",
++ cpu = "k8",
+ compiler = "compiler",
+ toolchain_identifier = "local",
+ host_system_name = "local",
+@@ -80,7 +80,7 @@ cc_toolchain_config(
+ target_libc = "local",
+ abi_version = "local",
+ abi_libc_version = "local",
+- cxx_builtin_include_directories = ["/usr/lib/gcc/aarch64-linux-gnu/10.3.1/include",
++ cxx_builtin_include_directories = ["/usr/lib/gcc/x86_64-linux-gnu/10.3.1/include",
+ "/usr/local/include",
+ "/usr/include",
+ "/root/rpmbuild/BUILD/ANNC-v0.0.2/annc/service/cpu/xla",
+@@ -91,7 +91,7 @@ cc_toolchain_config(
+ "/root/rpmbuild/BUILD/XNNPACK/build/pthreadpool-source/include",
+ "/root/rpmbuild/BUILD/ANNC-v0.0.2/output/e7b069029cc648c50e1b8083cef52b4f/external/local_config_cc",
+ "/usr/include/c++/10.3.1",
+- "/usr/include/c++/10.3.1/aarch64-linux-gnu",
++ "/usr/include/c++/10.3.1/x86_64-linux-gnu",
+ "/usr/include/c++/10.3.1/backward"],
+ tool_paths = {"ar": "/usr/bin/ar",
+ "ld": "/usr/bin/ld",
+diff --git a/external/local_config_cc/builtin_include_directory_paths b/external/local_config_cc/builtin_include_directory_paths
+index 188e6c9..2c3bcc1 100755
+--- a/external/local_config_cc/builtin_include_directory_paths
++++ b/external/local_config_cc/builtin_include_directory_paths
+@@ -4,7 +4,7 @@ changes to it will be reflected in the action cache key. When some of these
+ paths change, Bazel will make sure to rerun the action, even though none of
+ declared action inputs or the action commandline changes.
+
+-/usr/lib/gcc/aarch64-linux-gnu/10.3.1/include
++/usr/lib/gcc/x86_64-linux-gnu/10.3.1/include
+ /usr/local/include
+ /usr/include
+ /root/rpmbuild/BUILD/ANNC-v0.0.2/annc/service/cpu/xla
+@@ -15,5 +15,5 @@ declared action inputs or the action commandline changes.
+ /root/rpmbuild/BUILD/XNNPACK/build/pthreadpool-source/include
+ /root/rpmbuild/BUILD/ANNC-v0.0.2/output/e7b069029cc648c50e1b8083cef52b4f/external/local_config_cc
+ /usr/include/c++/10.3.1
+-/usr/include/c++/10.3.1/aarch64-linux-gnu
++/usr/include/c++/10.3.1/x86_64-linux-gnu
+ /usr/include/c++/10.3.1/backward
diff --git a/external/go_sdk/BUILD.bazel b/x86_64/external/go_sdk/BUILD.bazel
index 9cf6add..511ddbc 100644
--- a/external/go_sdk/BUILD.bazel
@@ -44,68 +114,6 @@ index 81bd76b..560f94f 100644
const version = `go1.18.4`
const defaultGOOS = runtime.GOOS
const defaultGOARCH = runtime.GOARCH
-diff --git a/external/local_config_cc/BUILD b/x86_64/external/local_config_cc/BUILD
-index 51949c4..89f56ce 100755
---- a/external/local_config_cc/BUILD
-+++ b/x86_64/external/local_config_cc/BUILD
-@@ -47,15 +47,15 @@ filegroup(
- cc_toolchain_suite(
- name = "toolchain",
- toolchains = {
-- "aarch64|compiler": ":cc-compiler-aarch64",
-- "aarch64": ":cc-compiler-aarch64",
-+ "k8|compiler": ":cc-compiler-k8",
-+ "k8": ":cc-compiler-k8",
- "armeabi-v7a|compiler": ":cc-compiler-armeabi-v7a",
- "armeabi-v7a": ":cc-compiler-armeabi-v7a",
- },
- )
-
- cc_toolchain(
-- name = "cc-compiler-aarch64",
-+ name = "cc-compiler-k8",
- toolchain_identifier = "local",
- toolchain_config = ":local",
- all_files = ":compiler_deps",
-@@ -72,7 +72,7 @@ cc_toolchain(
-
- cc_toolchain_config(
- name = "local",
-- cpu = "aarch64",
-+ cpu = "k8",
- compiler = "compiler",
- toolchain_identifier = "local",
- host_system_name = "local",
-@@ -80,11 +80,11 @@ cc_toolchain_config(
- target_libc = "local",
- abi_version = "local",
- abi_libc_version = "local",
-- cxx_builtin_include_directories = ["/usr/lib/gcc/aarch64-openEuler-linux/12/include",
-+ cxx_builtin_include_directories = ["/usr/lib/gcc/x86_64-openEuler-linux/12/include",
- "/usr/local/include",
- "/usr/include",
- "/usr/include/c++/12",
-- "/usr/include/c++/12/aarch64-openEuler-linux",
-+ "/usr/include/c++/12/x86_64-openEuler-linux",
- "/usr/include/c++/12/backward"],
- tool_paths = {"ar": "/usr/bin/ar",
- "ld": "/usr/bin/ld",
-diff --git a/external/local_config_cc/builtin_include_directory_paths b/x86_64/external/local_config_cc/builtin_include_directory_paths
-index 711ac34..50fea54 100755
---- a/external/local_config_cc/builtin_include_directory_paths
-+++ b/x86_64/external/local_config_cc/builtin_include_directory_paths
-@@ -4,9 +4,9 @@ changes to it will be reflected in the action cache key. When some of these
- paths change, Bazel will make sure to rerun the action, even though none of
- declared action inputs or the action commandline changes.
-
--/usr/lib/gcc/aarch64-openEuler-linux/12/include
-+/usr/lib/gcc/x86_64-openEuler-linux/12/include
- /usr/local/include
- /usr/include
- /usr/include/c++/12
--/usr/include/c++/12/aarch64-openEuler-linux
-+/usr/include/c++/12/x86_64-openEuler-linux
- /usr/include/c++/12/backward
diff --git a/external/local_config_cc_toolchains/BUILD b/x86_64/external/local_config_cc_toolchains/BUILD
index db5234f..f9c0875 100755
--- a/external/local_config_cc_toolchains/BUILD