From 90ac46231be919e1c07b7d41bc0a8c4b1f1ba41a Mon Sep 17 00:00:00 2001 From: zhengchenhui Date: Mon, 10 Nov 2025 16:49:10 +0800 Subject: [PATCH] Add graph optimizer and embedding_fused kernels. --- annc/tensorflow/graph_optimizer/BUILD | 32 + annc/tensorflow/graph_optimizer/graph_opt.cc | 738 ++++++++++++++++++ annc/tensorflow/graph_optimizer/graph_opt.h | 165 ++++ .../embedding_fused_action_id_gather.cc | 144 ++++ .../embedding_fused_action_id_gather_test.cc | 289 +++++++ .../kernels/embedding_fused_gather.cc | 90 +++ .../kernels/embedding_fused_gather_test.cc | 186 +++++ .../kernels/embedding_fused_padding.cc | 126 +++ .../kernels/embedding_fused_padding_test.cc | 307 ++++++++ .../embedding_fused_sparse_dynamic_stitch.cc | 87 +++ ...edding_fused_sparse_dynamic_stitch_test.cc | 108 +++ .../kernels/embedding_fused_sparse_reshape.cc | 195 +++++ .../embedding_fused_sparse_reshape_test.cc | 281 +++++++ .../embedding_fused_sparse_segment_reduce.cc | 165 ++++ ...ing_fused_sparse_segment_reduce_nonzero.cc | 159 ++++ ...used_sparse_segment_reduce_nonzero_test.cc | 183 +++++ ...edding_fused_sparse_segment_reduce_test.cc | 205 +++++ .../kernels/embedding_fused_sparse_select.cc | 111 +++ .../embedding_fused_sparse_select_test.cc | 182 +++++ annc/tensorflow/ops/embedding_fused_ops.cc | 133 ++++ annc/tensorflow/tf_annc_optimizer.patch | 221 ++++++ 21 files changed, 4107 insertions(+) create mode 100644 annc/tensorflow/graph_optimizer/BUILD create mode 100644 annc/tensorflow/graph_optimizer/graph_opt.cc create mode 100644 annc/tensorflow/graph_optimizer/graph_opt.h create mode 100644 annc/tensorflow/kernels/embedding_fused_action_id_gather.cc create mode 100644 annc/tensorflow/kernels/embedding_fused_action_id_gather_test.cc create mode 100644 annc/tensorflow/kernels/embedding_fused_gather.cc create mode 100644 annc/tensorflow/kernels/embedding_fused_gather_test.cc create mode 100644 annc/tensorflow/kernels/embedding_fused_padding.cc create mode 100644 annc/tensorflow/kernels/embedding_fused_padding_test.cc create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_dynamic_stitch.cc create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_dynamic_stitch_test.cc create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_reshape.cc create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_reshape_test.cc create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce.cc create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_nonzero.cc create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_nonzero_test.cc create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_test.cc create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_select.cc create mode 100644 annc/tensorflow/kernels/embedding_fused_sparse_select_test.cc create mode 100644 annc/tensorflow/ops/embedding_fused_ops.cc create mode 100644 annc/tensorflow/tf_annc_optimizer.patch diff --git a/annc/tensorflow/graph_optimizer/BUILD b/annc/tensorflow/graph_optimizer/BUILD new file mode 100644 index 0000000..39a37f8 --- /dev/null +++ b/annc/tensorflow/graph_optimizer/BUILD @@ -0,0 +1,32 @@ +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +cc_library( + name = "annc_graph_opt", + srcs = glob(["*.cc"]), + hdrs = glob(["*.h"]), + linkstatic = True, + alwayslink = True, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core/grappler:graph_view", + 
"//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + ], +) + +cc_binary( + name = "libannc_graph_opt.so", + srcs = glob(["*.cc", "*.h"]), + linkshared = True, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core/grappler:graph_view", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + ], +) \ No newline at end of file diff --git a/annc/tensorflow/graph_optimizer/graph_opt.cc b/annc/tensorflow/graph_optimizer/graph_opt.cc new file mode 100644 index 0000000..9c74489 --- /dev/null +++ b/annc/tensorflow/graph_optimizer/graph_opt.cc @@ -0,0 +1,738 @@ +#include "graph_opt.h" + +using namespace tensorflow; +using namespace tensorflow::grappler; + +namespace annc { +void update_node_indexes(const GraphDef* graph, + std::unordered_map& node_indexes) { + for (int i = 0; i < graph->node_size(); ++i) { + node_indexes[graph->node(i).name()] = i; + } +} + +void GraphOptimizer::register_rewriter( + std::unique_ptr rewriter) { + rewriters_.push_back(std::move(rewriter)); +} + +void GraphOptimizer::optimize() { + update_node_indexes(graph_, node_indexes_); + int node_index = 0; + const int node_size = graph_->node_size(); + while (node_index < node_size) { + const NodeDef& node = graph_->node(node_index); + const std::string& node_name = node.name(); + for (auto& rewriter : rewriters_) { + if (rewriter->match_and_rewrite(&node, graph_, node_indexes_)) { + update_node_indexes(graph_, node_indexes_); + const std::string new_node_name = node_name + fusion_appendix; + node_index = node_indexes_.at(new_node_name); + break; + } + } + node_index++; + } +} + +std::string get_node_name(const std::string& name) { + size_t colon_pos = name.find_last_of(':'); + std::string node_name = name; + if (colon_pos != std::string::npos) { + node_name = name.substr(0, colon_pos); + } + return node_name; +} + +void set_fusedop_attributes(NodeDef* fused, + const absl::Span fused_ops, + int num_args = 1, float epsilon = 0.0) { + auto* attr = fused->mutable_attr(); + SetAttrValue(fused_ops, &(*attr)["fused_ops"]); + SetAttrValue(num_args, &(*attr)["num_args"]); + SetAttrValue(epsilon, &(*attr)["epsilon"]); // required only for BatchNorm +} + +const NodeDef* PatternRewriter::get_node(const std::string& name) { + const std::string node_name = get_node_name(name); + const int node_index = indexes_->at(node_name); + return &graph_->node(node_index); +} + +NodeDef* PatternRewriter::get_mutable_node(const std::string& name) { + const std::string node_name = get_node_name(name); + if (indexes_->find(node_name) == indexes_->end()) return nullptr; + const int node_index = indexes_->at(node_name); + return graph_->mutable_node(node_index); +} + +NodeDef* PatternRewriter::get_operand(const NodeDef* node, + std::string op_type) { + for (int i = 0; i < node->input_size(); ++i) { + NodeDef* operand = get_mutable_node(node->input(i)); + if (operand != nullptr && operand->op() == op_type) return operand; + } + return nullptr; +} + +const NodeDef* PatternRewriter::get_user(const NodeDef* node, int index, + const std::string& op_type) { + std::string node_name = node->name(); + if (index) std::string node_name = node_name + ":" + std::to_string(index); + for (int i = 0; i < graph_->node_size(); ++i) { + const NodeDef* node = graph_->mutable_node(i); + for (int j = 0; j < node->input_size(); ++j) { + if (node->input(j) == node_name && node->op() == op_type) { + return node; + } + } + } + return nullptr; +} + +void PatternRewriter::replace_all_users_with(const 
+void PatternRewriter::replace_all_users_with(const NodeDef* old_node,
+                                             int old_index,
+                                             const NodeDef* new_node,
+                                             int new_index, GraphDef* graph) {
+  std::string old_name = old_node->name();
+  if (old_index) old_name = old_name + ":" + std::to_string(old_index);
+  std::string new_name = new_node->name();
+  if (new_index) new_name = new_name + ":" + std::to_string(new_index);
+  for (int i = 0; i < graph->node_size(); ++i) {
+    NodeDef* node = graph->mutable_node(i);
+    for (int j = 0; j < node->input_size(); ++j) {
+      if (node->input(j) == old_name) {
+        node->set_input(j, new_name);
+      }
+    }
+  }
+}
+
+class KPFusedSparseDynamicStitchRewriter : public PatternRewriter {
+ public:
+  std::string name() const override { return "KPFusedSparseDynamicStitch"; }
+
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(node->op() == "ParallelDynamicStitch" &&
+                  node->input_size() % 2 == 0)
+    int num_inputs = node->input_size();
+    int num_partitions = num_inputs / 2;
+    // left branch
+    const NodeDef* partition = get_node(node->input(0));
+    CHECK_NODE_OK(partition->op() == "DynamicPartition" &&
+                  partition->input_size() == 2)
+    const NodeDef* range = get_node(partition->input(0));
+    CHECK_NODE_OK(range->op() == "Range" && range->input_size() == 3)
+    // Range start=0, delta=1
+    CHECK_NODE_OK(check_const_value<int>(get_mutable_node(range->input(0)), {0}))
+    CHECK_NODE_OK(check_const_value<int>(get_mutable_node(range->input(2)), {1}))
+    const NodeDef* size = get_node(range->input(1));
+    CHECK_NODE_OK(IsSize(*size) && size->input_size() == 1)
+    const NodeDef* cast = get_node(partition->input(1));
+    CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
+    const NodeDef* floor_mod = get_node(cast->input(0));
+    CHECK_NODE_OK(floor_mod->op() == "FloorMod" && floor_mod->input_size() == 2)
+    CHECK_NODE_OK(check_const_value<int>(
+        get_mutable_node(floor_mod->input(1)), {num_partitions}))
+
+    CHECK_NODE_OK(check_int_attr(node, "N", num_partitions))
+    CHECK_NODE_OK(check_int_attr(partition, "num_partitions", num_partitions))
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(size->input(0));
+    // right branch
+    for (int i = num_partitions; i < num_inputs; ++i) {
+      const NodeDef* gather = get_node(node->input(i));
+      CHECK_NODE_OK(gather->op() == "GatherV2" && gather->input_size() == 3)
+      // Gather axis=0
+      CHECK_NODE_OK(
+          check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
+      const NodeDef* partition_1 = get_node(gather->input(1));
+      CHECK_NODE_OK(partition_1->op() == "DynamicPartition" &&
+                    partition_1->input_size() == 2)
+      CHECK_NODE_OK(
+          check_int_attr(partition_1, "num_partitions", num_partitions))
+      const NodeDef* floor_div = get_node(partition_1->input(0));
+      CHECK_NODE_OK(floor_div->op() == "FloorDiv" &&
+                    floor_div->input_size() == 2)
+      CHECK_NODE_OK(check_const_value<int>(
+          get_mutable_node(floor_div->input(1)), {num_partitions}))
+      fused_node->add_input(gather->input(0));
+    }
+    (*fused_node->mutable_attr())["N"].set_i(num_partitions);
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(node, 0, fused_node, 0, graph);
+    return true;
+  }
+};
+
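+// KPFusedSparseSegmentReduce folds the StridedSlice(Shape(SparseSegmentX))
+// shape-probe pattern into a single op; the "combiner" attr encodes the
+// reduction: 0 = SparseSegmentSum, 1 = SparseSegmentMean.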
name() const override { return "KPFusedSparseSegmentReduce"; } + + bool match_and_rewrite( + const NodeDef* node, GraphDef* graph, + std::unordered_map& node_indexes) override { + graph_ = graph; + indexes_ = &node_indexes; + CHECK_NODE_OK(IsStridedSlice(*node) && node->input_size() == 4) + const NodeDef* shape = get_node(node->input(0)); + CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1) + const NodeDef* ss_reduce = get_node(shape->input(0)); + CHECK_NODE_OK(ss_reduce->input_size() == 3) + AttrValue combiner; + if (ss_reduce->op() == "SparseSegmentMean") + combiner.set_i(1); + else if (ss_reduce->op() == "SparseSegmentSum") + combiner.set_i(0); + else + return false; + const NodeDef* cast = get_node(ss_reduce->input(2)); + CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1) + const NodeDef* strided_slice = get_node(cast->input(0)); + CHECK_NODE_OK(IsStridedSlice(*strided_slice) && + strided_slice->input_size() == 4) + + // check fusion conditions + CHECK_NODE_OK( + check_const_shape(get_mutable_node(strided_slice->input(1)), {2})) + CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2)) + CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1)) + CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1)) + CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(1)), {1})) + CHECK_NODE_OK(check_int_attr(node, "shrink_axis_mask", 1)) + + auto nodes = graph->mutable_node(); + NodeDef* fused_node = nodes->Add(); + fused_node->set_name(node->name() + fusion_appendix); + fused_node->set_op(name()); + fused_node->set_device(node->device()); + fused_node->add_input(ss_reduce->input(0)); + fused_node->add_input(ss_reduce->input(1)); + fused_node->add_input(strided_slice->input(0)); + fused_node->add_input(strided_slice->input(1)); + fused_node->add_input(node->input(1)); + AddNodeAttr("combiner", combiner, fused_node); + + nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1); + + VLOG(0) << "-- Add node: [" << fused_node->op() << "] " + << fused_node->name(); + replace_all_users_with(ss_reduce, 0, fused_node, 0, graph); + replace_all_users_with(node, 0, fused_node, 1, graph); + return true; + } +}; + +class KPFusedSparseSegmentReduceNonzeroRewriter : public PatternRewriter { + public: + std::string name() const override { + return "KPFusedSparseSegmentReduceNonzero"; + } + + bool match_and_rewrite( + const NodeDef* node, GraphDef* graph, + std::unordered_map& node_indexes) override { + graph_ = graph; + indexes_ = &node_indexes; + CHECK_NODE_OK(node->op() == "GatherND" && + node->input_size() == 2) // output:2 + const NodeDef* ss_reduce = get_node(node->input(0)); + CHECK_NODE_OK(ss_reduce->input_size() == 3) + AttrValue combiner; + if (ss_reduce->op() == "SparseSegmentMean") { + combiner.set_i(1); + } else if (ss_reduce->op() == "SparseSegmentSum") { + combiner.set_i(0); + } else { + return false; + } + const NodeDef* where = get_node(node->input(1)); + CHECK_NODE_OK(where->op() == "Where" && where->input_size() == 1) + const NodeDef* cast = get_user(where, 0, "Cast"); + CHECK_NODE_OK(cast != nullptr) // output: 1 + const NodeDef* notequal = get_node(where->input(0)); + CHECK_NODE_OK(IsNotEqual(*notequal) && notequal->input_size() == 2); + const NodeDef* zerolike = get_node(notequal->input(1)); + CHECK_NODE_OK(IsZerosLike(*zerolike) && zerolike->input_size() == 1) + const NodeDef* cast_1 = get_node(ss_reduce->input(2)); + CHECK_NODE_OK(IsCast(*cast_1) && cast_1->input_size() == 1) + const NodeDef* strided_slice = get_node(cast->input(0)); + 
+class KPFusedSparseSegmentReduceNonzeroRewriter : public PatternRewriter {
+ public:
+  std::string name() const override {
+    return "KPFusedSparseSegmentReduceNonzero";
+  }
+
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(node->op() == "GatherND" &&
+                  node->input_size() == 2)  // output:2
+    const NodeDef* ss_reduce = get_node(node->input(0));
+    CHECK_NODE_OK(ss_reduce->input_size() == 3)
+    AttrValue combiner;
+    if (ss_reduce->op() == "SparseSegmentMean") {
+      combiner.set_i(1);
+    } else if (ss_reduce->op() == "SparseSegmentSum") {
+      combiner.set_i(0);
+    } else {
+      return false;
+    }
+    const NodeDef* where = get_node(node->input(1));
+    CHECK_NODE_OK(where->op() == "Where" && where->input_size() == 1)
+    const NodeDef* cast = get_user(where, 0, "Cast");
+    CHECK_NODE_OK(cast != nullptr)  // output: 1
+    const NodeDef* notequal = get_node(where->input(0));
+    CHECK_NODE_OK(IsNotEqual(*notequal) && notequal->input_size() == 2)
+    const NodeDef* zerolike = get_node(notequal->input(1));
+    CHECK_NODE_OK(IsZerosLike(*zerolike) && zerolike->input_size() == 1)
+    const NodeDef* cast_1 = get_node(ss_reduce->input(2));
+    CHECK_NODE_OK(IsCast(*cast_1) && cast_1->input_size() == 1)
+    const NodeDef* strided_slice = get_node(cast_1->input(0));
+    CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
+                  strided_slice->input_size() == 4)
+    const NodeDef* shape = get_user(ss_reduce, 0, "Shape");
+    CHECK_NODE_OK(shape != nullptr)
+    const NodeDef* cast_2 = get_user(shape, 0, "Cast");  // output: 0
+    CHECK_NODE_OK(cast_2 != nullptr)
+
+    CHECK_NODE_OK(
+        check_const_shape(get_mutable_node(strided_slice->input(1)), {2}))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(ss_reduce->input(0));
+    fused_node->add_input(ss_reduce->input(1));
+    fused_node->add_input(strided_slice->input(0));
+    fused_node->add_input(strided_slice->input(1));
+    AddNodeAttr("combiner", combiner, fused_node);
+
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name() << "\n";
+    replace_all_users_with(cast_2, 0, fused_node, 0, graph);
+    replace_all_users_with(cast, 0, fused_node, 1, graph);
+    replace_all_users_with(node, 0, fused_node, 2, graph);
+    return true;
+  }
+};
+
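+// The two padding rewriters below recognize the common "right-pad a ragged
+// batch of embedding rows with zeros up to a fixed row count" subgraph built
+// from ConcatV2 + Fill(Pack(Sub(target_rows, actual_rows), cols), 0).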
fused_node->op() << "] " + << fused_node->name(); + replace_all_users_with(sub, 0, fused_node, 0, graph); + replace_all_users_with(node, 0, fused_node, 1, graph); + return true; + } +}; + +class KPFusedEmbeddingPaddingFastRewriter : public PatternRewriter { + public: + std::string name() const override { return "KPFusedEmbeddingPaddingFast"; } + + bool match_and_rewrite( + const NodeDef* node, GraphDef* graph, + std::unordered_map& node_indexes) override { + graph_ = graph; + indexes_ = &node_indexes; + CHECK_NODE_OK(IsStridedSlice(*node) && node->input_size() == 4) + CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(1)), {0})) + CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(2)), {1})) + CHECK_NODE_OK(check_const_shape(get_mutable_node(node->input(3)), {1})) + CHECK_NODE_OK(check_int_attr(node, "shrink_axis_mask", 1)) + const NodeDef* shape = get_node(node->input(0)); + CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1) + const NodeDef* reshape = get_node(shape->input(0)); + CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2) + const NodeDef* concat = get_node(reshape->input(0)); + CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3) + const NodeDef* fill = get_node(concat->input(1)); + CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2) + const NodeDef* pack = get_node(fill->input(0)); + CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2) + const NodeDef* sub = get_node(pack->input(0)); + CHECK_NODE_OK(IsSub(*sub) && sub->input_size() == 2) + const NodeDef* strided_slice = get_node(sub->input(0)); + CHECK_NODE_OK(IsStridedSlice(*strided_slice) && + strided_slice->input_size() == 4) + const NodeDef* cast = get_node(strided_slice->input(0)); + CHECK_NODE_OK(IsCast(*cast)) + + auto nodes = graph->mutable_node(); + NodeDef* fused_node = nodes->Add(); + fused_node->set_name(node->name() + fusion_appendix); + fused_node->set_op(name()); + fused_node->set_device(node->device()); + fused_node->add_input(cast->input(0)); + fused_node->add_input(concat->input(0)); + fused_node->add_input(sub->input(1)); + fused_node->add_input(reshape->input(1)); + const NodeDef* pack_left = get_node(pack->input(0)); + const NodeDef* pack_right = get_node(pack->input(1)); + if (IsConstant(*pack_left) || IsHostConstant(*pack_left)) { + fused_node->add_input(pack->input(0)); + } else if (IsConstant(*pack_right) || IsHostConstant(*pack_right)) { + fused_node->add_input(pack->input(1)); + } else { + return false; + } + nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1); + + VLOG(0) << "-- Add node: [" << fused_node->op() << "] " + << fused_node->name(); + replace_all_users_with(sub, 0, fused_node, 0, graph); + replace_all_users_with(node, 0, fused_node, 1, graph); + return true; + } +}; + +class KPFusedSparseSelectRewriter : public PatternRewriter { + public: + std::string name() const override { return "KPFusedSparseSelect"; } + + bool match_and_rewrite( + const NodeDef* node, GraphDef* graph, + std::unordered_map& node_indexes) override { + graph_ = graph; + indexes_ = &node_indexes; + CHECK_NODE_OK(IsConcat(*node) && node->input_size() == 3) + const NodeDef* select_0 = get_node(node->input(0)); + CHECK_NODE_OK(IsSelect(*select_0) && select_0->input_size() == 3) + const NodeDef* select_1 = get_node(select_0->input(2)); + CHECK_NODE_OK(IsSelect(*select_1) && select_1->input_size() == 3) + const NodeDef* fill = get_node(select_1->input(1)); + CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2) + CHECK_NODE_OK( + 
+class KPFusedSparseSelectRewriter : public PatternRewriter {
+ public:
+  std::string name() const override { return "KPFusedSparseSelect"; }
+
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(IsConcat(*node) && node->input_size() == 3)
+    const NodeDef* select_0 = get_node(node->input(0));
+    CHECK_NODE_OK(IsSelect(*select_0) && select_0->input_size() == 3)
+    const NodeDef* select_1 = get_node(select_0->input(2));
+    CHECK_NODE_OK(IsSelect(*select_1) && select_1->input_size() == 3)
+    const NodeDef* fill = get_node(select_1->input(1));
+    CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
+    CHECK_NODE_OK(
+        check_const_value<float>(get_mutable_node(fill->input(1)), {1.0f}))
+    const NodeDef* cast = get_node(select_1->input(2));
+    CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
+    const NodeDef* equal = get_node(select_1->input(0));
+    CHECK_NODE_OK(IsEqual(*equal) && equal->input_size() == 2)
+    NodeDef* reshape = get_operand(equal, "Reshape");
+    CHECK_NODE_OK(reshape != nullptr && IsReshape(*reshape))
+    const NodeDef* greater = get_node(cast->input(0));
+    CHECK_NODE_OK(IsGreater(*greater) && greater->input_size() == 2)
+    const NodeDef* reshape_4 = get_node(greater->input(0));
+    CHECK_NODE_OK(IsReshape(*reshape_4) && reshape_4->input_size() == 2)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(reshape_4->input(1)), {-1, 1}))
+    const NodeDef* equal_1 = get_node(select_0->input(0));
+    CHECK_NODE_OK(IsEqual(*equal_1) && equal_1->input_size() == 2)
+    NodeDef* reshape_1 = get_operand(equal_1, "Reshape");
+    CHECK_NODE_OK(reshape_1 != nullptr && IsReshape(*reshape_1))
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(reshape_1->input(1)), {-1, 1}))
+
+    // right branch
+    const NodeDef* select_2 = get_node(node->input(1));
+    CHECK_NODE_OK(IsSelect(*select_2) && select_2->input_size() == 3)
+    const NodeDef* equal_2 = get_node(select_2->input(0));
+    CHECK_NODE_OK(IsEqual(*equal_2) && equal_2->input_size() == 2)
+    const NodeDef* fill_1 = get_node(select_2->input(2));
+    CHECK_NODE_OK(IsFill(*fill_1) && fill_1->input_size() == 2)
+    CHECK_NODE_OK(
+        check_const_value<float>(get_mutable_node(fill_1->input(1)), {1.0f}))
+    const NodeDef* reshape_2 = get_operand(equal_2, "Reshape");
+    CHECK_NODE_OK(reshape_2 != nullptr && IsReshape(*reshape_2))
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(reshape_2->input(1)), {-1, 1}))
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(reshape->input(0));
+    fused_node->add_input(reshape_1->input(0));
+    fused_node->add_input(reshape_2->input(0));
+    std::vector<const NodeDef*> const_inputs = {equal, equal_1, equal_2,
+                                                greater};
+    for (const NodeDef* const_node : const_inputs) {
+      const NodeDef* left = get_node(const_node->input(0));
+      const NodeDef* right = get_node(const_node->input(1));
+      if (IsConstant(*left) || IsHostConstant(*left)) {
+        fused_node->add_input(const_node->input(0));
+      } else if (IsConstant(*right) || IsHostConstant(*right)) {
+        fused_node->add_input(const_node->input(1));
+      } else {
+        return false;
+      }
+    }
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(reshape, 0, fused_node, 0, graph);
+    replace_all_users_with(select_0, 0, fused_node, 1, graph);
+    replace_all_users_with(node, 0, fused_node, 2, graph);
+    return true;
+  }
+};
+
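+// KPFusedGather folds GatherV2(GatherV2(params, ...), Unique(Unique(keys)))
+// into one op. Output mapping: fused output 0 replaces the inner Unique's
+// values, output 1 its index vector, and output 2 the outer GatherV2 result.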
+class KPFusedGatherRewriter : public PatternRewriter {
+ public:
+  std::string name() const override { return "KPFusedGather"; }
+
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(node->op() == "GatherV2" && node->input_size() == 3)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(node->input(2)), {0}))
+    const NodeDef* gather = get_node(node->input(0));
+    CHECK_NODE_OK(gather->op() == "GatherV2" &&
+                  gather->input_size() == 3)  // input:0
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
+    const NodeDef* unique = get_node(node->input(1));  // output:1
+    CHECK_NODE_OK(unique->op() == "Unique" && unique->input_size() == 1)
+    const NodeDef* unique_1 = get_node(unique->input(0));
+    CHECK_NODE_OK(unique_1->op() == "Unique" && unique_1->input_size() == 1)
+    const NodeDef* strided_slice = get_node(unique_1->input(0));
+    CHECK_NODE_OK(IsStridedSlice(*strided_slice))  // input:1 2
+    CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(gather->input(0));
+    fused_node->add_input(strided_slice->input(0));
+    fused_node->add_input(strided_slice->input(1));
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(unique_1, 0, fused_node, 0, graph);
+    replace_all_users_with(unique_1, 1, fused_node, 1, graph);
+    replace_all_users_with(node, 0, fused_node, 2, graph);
+    return true;
+  }
+};
+
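+// KPFusedSparseReshape matches the Range/Reshape/Concat index-building
+// subgraph feeding a SparseReshape and forwards the raw inputs to one fused
+// op that produces the reshaped indices and shape directly.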
+class KPFusedSparseReshapeRewriter : public PatternRewriter {
+ public:
+  std::string name() const override { return "KPFusedSparseReshape"; }
+
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(node->op() == "SparseReshape" && node->input_size() == 3)
+    const NodeDef* concat = get_node(node->input(0));
+    CHECK_NODE_OK(IsConcat(*concat) && concat->input_size() == 3)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(concat->input(2)), {-1}))
+    const NodeDef* reshape = get_node(concat->input(1));
+    CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(reshape->input(1)), {-1, 1}))
+    const NodeDef* strided_slice = get_node(reshape->input(0));
+    CHECK_NODE_OK(IsStridedSlice(*strided_slice) &&
+                  strided_slice->input_size() == 4)
+    CHECK_NODE_OK(check_int_attr(strided_slice, "shrink_axis_mask", 2))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "begin_mask", 1))
+    CHECK_NODE_OK(check_int_attr(strided_slice, "end_mask", 1))
+    const NodeDef* cast_1 = get_node(concat->input(0));
+    CHECK_NODE_OK(IsCast(*cast_1) && cast_1->input_size() == 1)
+    const NodeDef* reshape_1 = get_node(cast_1->input(0));
+    CHECK_NODE_OK(IsReshape(*reshape_1) && reshape_1->input_size() == 2)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(reshape_1->input(1)), {-1, 1}))
+    const NodeDef* range = get_node(reshape_1->input(0));
+    CHECK_NODE_OK(range->op() == "Range" && range->input_size() == 3)
+    // Range start=0, delta=1
+    CHECK_NODE_OK(check_const_value<int>(get_mutable_node(range->input(0)), {0}))
+    CHECK_NODE_OK(check_const_value<int>(get_mutable_node(range->input(2)), {1}))
+    const NodeDef* cast = get_node(node->input(1));
+    CHECK_NODE_OK(IsCast(*cast) && cast->input_size() == 1)
+    const NodeDef* pack = get_node(cast->input(0));
+    CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)
+    const NodeDef* strided_slice_1 = get_node(pack->input(0));
+    CHECK_NODE_OK(IsStridedSlice(*strided_slice_1) &&
+                  strided_slice_1->input_size() == 4)
+    CHECK_NODE_OK(check_const_value<int>(
+        get_mutable_node(strided_slice_1->input(1)), {0}))
+    CHECK_NODE_OK(check_const_value<int>(
+        get_mutable_node(strided_slice_1->input(2)), {1}))
+    CHECK_NODE_OK(check_const_value<int>(
+        get_mutable_node(strided_slice_1->input(3)), {1}))
+    CHECK_NODE_OK(check_int_attr(strided_slice_1, "shrink_axis_mask", 1))
+    const NodeDef* shape = get_node(strided_slice_1->input(0));
+    CHECK_NODE_OK(IsShape(*shape) && shape->input_size() == 1)
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(shape->input(0));
+    fused_node->add_input(strided_slice->input(1));
+    fused_node->add_input(node->input(2));
+    fused_node->add_input(pack->input(1));
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(node, 0, fused_node, 0, graph);
+    replace_all_users_with(node, 1, fused_node, 1, graph);
+    return true;
+  }
+};
+
+class KPFusedEmbeddingActionIdGatherRewriter : public PatternRewriter {
+ public:
+  std::string name() const override {
+    return "KPFusedEmbeddingActionIdGather";
+  }
+
+  bool match_and_rewrite(
+      const NodeDef* node, GraphDef* graph,
+      std::unordered_map<std::string, int>& node_indexes) override {
+    graph_ = graph;
+    indexes_ = &node_indexes;
+    CHECK_NODE_OK(IsConcat(*node) && node->input_size() == 3)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(node->input(2)), {-1}))
+    const NodeDef* reshape = get_node(node->input(0));
+    CHECK_NODE_OK(IsReshape(*reshape) && reshape->input_size() == 2)
+    const NodeDef* gather = get_node(reshape->input(0));
+    const NodeDef* pack_1 = get_node(reshape->input(1));
+    CHECK_NODE_OK(IsPack(*pack_1) && pack_1->input_size() == 2)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(pack_1->input(1)), {-1}))
+    CHECK_NODE_OK(gather->op() == "GatherV2" && gather->input_size() == 3)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(gather->input(2)), {0}))
+    const NodeDef* gather_1 = get_node(gather->input(0));
+    CHECK_NODE_OK(gather_1->op() == "GatherV2" && gather_1->input_size() == 3)
+    CHECK_NODE_OK(
+        check_const_value<int>(get_mutable_node(gather_1->input(2)), {0}))
+    const NodeDef* fill = get_node(node->input(1));
+    CHECK_NODE_OK(IsFill(*fill) && fill->input_size() == 2)
+    CHECK_NODE_OK(check_const_value<int>(get_mutable_node(fill->input(1)), {0}))
+    const NodeDef* pack = get_node(fill->input(0));
+    CHECK_NODE_OK(IsPack(*pack) && pack->input_size() == 2)
+
+    auto nodes = graph->mutable_node();
+    NodeDef* fused_node = nodes->Add();
+    fused_node->set_name(node->name() + fusion_appendix);
+    fused_node->set_op(name());
+    fused_node->set_device(node->device());
+    fused_node->add_input(gather_1->input(1));
+    fused_node->add_input(gather_1->input(0));
+    fused_node->add_input(gather->input(1));
+    fused_node->add_input(pack->input(0));
+    fused_node->add_input(pack->input(1));
+    nodes->SwapElements(node_indexes.at(node->name()), nodes->size() - 1);
+
+    VLOG(0) << "-- Add node: [" << fused_node->op() << "] "
+            << fused_node->name();
+    replace_all_users_with(node, 0, fused_node, 0, graph);
+    return true;
+  }
+};
+
annc_fused_emb_padding = getenv("ANNC_FUSED_EMD_PADDING"); + const char* annc_fused_emb_padding_fast = + getenv("ANNC_FUSED_EMD_PADDING_FAST"); + const char* annc_fused_sps_select = getenv("ANNC_FUSED_SPS_SELECT"); + const char* annc_fused_gather = getenv("ANNC_FUSED_GATHER"); + const char* annc_fused_sps_reshape = getenv("ANNC_FUSED_SPS_RESHAPE"); + const char* annc_fused_emb_actionid_gather = + getenv("ANNC_FUSED_EMB_ACTIONID_GATHER"); + const char* annc_fused_sps_reduce_nonzero = + getenv("ANNC_FUSED_SPS_REDUCE_NONZERO"); + + bool enable_all = + (annc_fused_all != nullptr) && strcmp(annc_fused_all, "1") == 0; + + // default enable all rewriters + if (enable_all || (annc_fused_sps_stitch != nullptr && + strcmp(annc_fused_sps_stitch, "1") == 0)) + optimizer.register_rewriter( + std::make_unique()); + if (enable_all || (annc_fused_sps_reduce != nullptr && + strcmp(annc_fused_sps_reduce, "1") == 0)) + optimizer.register_rewriter( + std::make_unique()); + if (enable_all || (annc_fused_emb_padding_fast != nullptr && + strcmp(annc_fused_emb_padding_fast, "1") == 0)) + optimizer.register_rewriter( + std::make_unique()); + if (enable_all || (annc_fused_emb_padding != nullptr && + strcmp(annc_fused_emb_padding, "1") == 0)) + optimizer.register_rewriter( + std::make_unique()); + if (enable_all || (annc_fused_sps_select != nullptr && + strcmp(annc_fused_sps_select, "1") == 0)) + optimizer.register_rewriter( + std::make_unique()); + if (enable_all || + (annc_fused_gather != nullptr && strcmp(annc_fused_gather, "1") == 0)) + optimizer.register_rewriter(std::make_unique()); + if (enable_all || (annc_fused_sps_reshape != nullptr && + strcmp(annc_fused_sps_reshape, "1") == 0)) + optimizer.register_rewriter(std::make_unique()); + if (annc_fused_emb_actionid_gather != nullptr && + strcmp(annc_fused_emb_actionid_gather, "1") == 0) + optimizer.register_rewriter( + std::make_unique()); + if (annc_fused_sps_reduce_nonzero != nullptr && + strcmp(annc_fused_sps_reduce_nonzero, "1") == 0) + optimizer.register_rewriter( + std::make_unique()); + optimizer.optimize(); +} +} // namespace annc diff --git a/annc/tensorflow/graph_optimizer/graph_opt.h b/annc/tensorflow/graph_optimizer/graph_opt.h new file mode 100644 index 0000000..0ac0273 --- /dev/null +++ b/annc/tensorflow/graph_optimizer/graph_opt.h @@ -0,0 +1,165 @@ +#ifndef ANNC_TF_GRAPH_OPT_H_ +#define ANNC_TF_GRAPH_OPT_H_ +#include +#include + +#include "tensorflow/core/grappler/graph_view.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" + +using namespace tensorflow; +using namespace tensorflow::grappler; + +namespace annc { +#define CHECK_NODE_OK(x) \ + if (!(x)) { \ + return false; \ + } + +static const std::string fusion_appendix = "/kp_fused"; + +void update_node_indexes(const GraphDef* graph, + std::unordered_map& node_indexes); + +class PatternRewriter { + public: + PatternRewriter() {} + virtual ~PatternRewriter() = default; + + virtual bool match_and_rewrite( + const NodeDef* node, GraphDef* graph, + std::unordered_map& node_indexes) = 0; + + virtual std::string name() const { return "PatternRewriter"; }; + + const NodeDef* get_node(const std::string& name); + NodeDef* get_mutable_node(const std::string& name); + + NodeDef* get_operand(const NodeDef* node, std::string op_type); + + const NodeDef* get_user(const NodeDef* node, int index, + const std::string& op_type); + + void replace_all_users_with(const NodeDef* old_node, int old_index, + const NodeDef* new_node, int new_index, + GraphDef* graph); 
+
+  bool check_input_dims(NodeDef* op, const std::string& output_name,
+                        int dim_size) {
+    if (op->attr().count("_output_shapes")) {
+      int pos_index = 0;
+      size_t pos = output_name.find_last_of(':');
+      if (pos != std::string::npos) {
+        pos_index = std::stoi(output_name.substr(pos + 1));
+      }
+      const TensorShapeProto& shape =
+          op->attr().at("_output_shapes").list().shape(pos_index);
+      if (shape.dim_size() == dim_size) return true;
+    } else if (op->attr().count("shape")) {
+      const TensorShapeProto& shape = op->attr().at("shape").shape();
+      if (shape.dim_size() == dim_size) return true;
+    }
+    return false;
+  }
+
+  bool check_const_dims(NodeDef* op, int dim_size) {
+    if (!((IsConstant(*op) || IsHostConstant(*op)) &&
+          HasNodeAttr(*op, "value")))
+      return false;
+
+    TensorProto* tensor = (*op->mutable_attr())["value"].mutable_tensor();
+    const auto& shape = tensor->tensor_shape();
+    if (shape.dim_size() != dim_size) return false;
+    return true;
+  }
+
+  bool check_const_shape(NodeDef* op, std::vector<int> dims) {
+    if (!((IsConstant(*op) || IsHostConstant(*op)) &&
+          HasNodeAttr(*op, "value")))
+      return false;
+
+    TensorProto* tensor = (*op->mutable_attr())["value"].mutable_tensor();
+    const auto& shape = tensor->tensor_shape();
+    if (shape.dim_size() != static_cast<int>(dims.size())) return false;
+    for (int i = 0; i < shape.dim_size(); ++i) {
+      if (shape.dim(i).size() != dims[i]) return false;
+    }
+    return true;
+  }
+
+  template <typename T>
+  bool check_const_value(NodeDef* op, std::vector<T> cmp) {
+    if (!((IsConstant(*op) || IsHostConstant(*op)) &&
+          HasNodeAttr(*op, "value")))
+      return false;
+
+    TensorProto* tensor = (*op->mutable_attr())["value"].mutable_tensor();
+    const auto& shape = tensor->tensor_shape();
+    int dim_size = 1;
+    for (int i = 0; i < shape.dim_size(); ++i) {
+      dim_size *= shape.dim(i).size();
+    }
+    if (dim_size < static_cast<int>(cmp.size())) return false;
+
+    if (std::is_same<T, float>::value) {
+      const float* data = tensor->mutable_float_val()->data();
+      if (data == nullptr)
+        data = reinterpret_cast<const float*>(tensor->tensor_content().data());
+      if (data == nullptr) return false;
+      for (int i = 0; i < static_cast<int>(cmp.size()); ++i) {
+        if (std::fabs(data[i] - cmp[i]) >= 1e-5f) return false;
+      }
+    } else if (std::is_same<T, int>::value) {
+      const int* data = tensor->mutable_int_val()->data();
+      if (data == nullptr)
+        data = reinterpret_cast<const int*>(tensor->tensor_content().data());
+      if (data == nullptr) return false;
+      for (int i = 0; i < static_cast<int>(cmp.size()); ++i) {
+        if (data[i] != cmp[i]) return false;
+      }
+    } else if (std::is_same<T, int64_t>::value) {
+      const int64_t* data = tensor->mutable_int64_val()->data();
+      if (data == nullptr)
+        data =
+            reinterpret_cast<const int64_t*>(tensor->tensor_content().data());
+      if (data == nullptr) return false;
+      for (int i = 0; i < static_cast<int>(cmp.size()); ++i) {
+        if (data[i] != cmp[i]) return false;
+      }
+    } else {
+      // unsupported data type
+      return false;
+    }
+    return true;
+  }
+
+  bool check_int_attr(const NodeDef* op, std::string name, int value) {
+    if (HasNodeAttr(*op, name)) {
+      AttrValue attr = op->attr().at(name);
+      if (attr.value_case() == AttrValue::kI && attr.i() == value) return true;
+    }
+    return false;
+  }
+
+  GraphDef* graph_;
+  std::unordered_map<std::string, int>* indexes_;
+};
+
+class GraphOptimizer {
+ public:
+  GraphOptimizer(GraphDef* graph) : graph_(graph) {}
+  virtual ~GraphOptimizer() = default;
+
+  void register_rewriter(std::unique_ptr<PatternRewriter> rewriter);
+
+  void optimize();
+
+ private:
+  GraphDef* graph_;
+  std::unordered_map<std::string, int> node_indexes_;
+  std::vector<std::unique_ptr<PatternRewriter>> rewriters_;
+};
+
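+// Minimal driver sketch (MyRewriter stands in for any PatternRewriter
+// subclass; run_graph_optimization below wires up the built-in set):
+//   annc::GraphOptimizer opt(graph_def);
+//   opt.register_rewriter(std::make_unique<MyRewriter>());
+//   opt.optimize();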
+void run_graph_optimization(GraphDef* graph);
+}  // namespace annc
+#endif  // ANNC_TF_GRAPH_OPT_H_
diff --git a/annc/tensorflow/kernels/embedding_fused_action_id_gather.cc b/annc/tensorflow/kernels/embedding_fused_action_id_gather.cc
new file mode 100644
index 0000000..db7c92c
--- /dev/null
+++ b/annc/tensorflow/kernels/embedding_fused_action_id_gather.cc
@@ -0,0 +1,144 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+template <typename Tindices>
+static void GatherV2Impl(OpKernelContext* context, const float* params_data,
+                         const TensorShape& params_shape,
+                         const Tindices* indices_data,
+                         const TensorShape& indices_shape, int axis,
+                         Tensor* temp) {
+  TensorShape temp_shape;
+  const int P0 = params_shape.dim_size(0);
+  int P1 = 1;
+  for (int d = 0; d < indices_shape.dims(); ++d) {
+    temp_shape.AddDim(indices_shape.dim_size(d));
+  }
+
+  for (int d = 1; d < params_shape.dims(); ++d) {
+    temp_shape.AddDim(params_shape.dim_size(d));
+    P1 *= params_shape.dim_size(d);
+  }
+  OP_REQUIRES_OK(context, context->allocate_temp(DT_FLOAT, temp_shape, temp));
+  VLOG(1) << "temp shape: " << temp->shape().DebugString();
+
+  const int num_indices = indices_shape.num_elements();
+  float* temp_data = temp->flat<float>().data();
+  VLOG(1) << "num_indices : " << num_indices;
+  OP_REQUIRES(context, axis == 0,
+              errors::InvalidArgument("axis only support 0"));
+  const int slice_size = P1;
+  for (int i = 0; i < num_indices; ++i) {
+    Tindices idx = indices_data[i];
+    OP_REQUIRES(context, (idx >= 0 && idx < P0),
+                errors::InvalidArgument("GatherV2 axis=0: index out of range"));
+    std::memcpy(temp_data + i * slice_size, params_data + idx * slice_size,
+                sizeof(float) * slice_size);
+  }
+}
+
+template <typename Tindices1, typename Tindices2>
+class KPFusedEmbeddingActionIdGatherOp : public OpKernel {
+ public:
+  explicit KPFusedEmbeddingActionIdGatherOp(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    // Grab the input tensors
+    const Tensor& indices1 = context->input(0);
+    const Tensor& params = context->input(1);
+    const Tensor& indices2 = context->input(2);
+    const Tensor& pack_dim = context->input(3);
+    const Tensor& pack = context->input(4);
+
+    VLOG(1) << "indices1 shape: " << indices1.shape().DebugString();
+    VLOG(1) << "params shape: " << params.shape().DebugString();
+    VLOG(1) << "indices2 shape: " << indices2.shape().DebugString();
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices1.shape()),
+                errors::InvalidArgument("indices1 dims must = 2"));
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices2.shape()),
+                errors::InvalidArgument("indices2 dims must = 2"));
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(params.shape()),
+                errors::InvalidArgument("params dims must = 2"));
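+    // pack_dim gives the output row count; pack gives the number of zero
+    // columns appended after the gathered embedding columns (see the sharded
+    // copy loop below).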
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(pack_dim.shape()),
+                errors::InvalidArgument("pack_dim is scalar"));
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(pack.shape()),
+                errors::InvalidArgument("pack const is scalar"));
+
+    Tensor temp;
+    GatherV2Impl(context, params.flat<float>().data(), params.shape(),
+                 indices1.flat<Tindices1>().data(), indices1.shape(), 0,
+                 &temp);
+    if (!context->status().ok()) return;
+    Tensor temp1;
+    GatherV2Impl(context, temp.flat<float>().data(), temp.shape(),
+                 indices2.flat<Tindices2>().data(), indices2.shape(), 0,
+                 &temp1);
+    if (!context->status().ok()) return;
+    int pack_size = pack_dim.scalar<int32>()();
+    int pack_const = pack.scalar<int32>()();
+    OP_REQUIRES(context, pack_size > 0,
+                errors::InvalidArgument("pack_size must > 0"));
+    int a_reshaped_cols = temp1.NumElements() / pack_size;
+    auto a_reshaped = temp1.shaped<float, 2>({pack_size, a_reshaped_cols});
+    Tensor* output;
+    int output_cols = a_reshaped_cols + pack_const;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0, TensorShape({pack_size, output_cols}), &output));
+    auto a_reshaped_data = a_reshaped.data();
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    const int64 cost_per_unit = a_reshaped_cols + pack_const;
+    auto work = [&](int64 start_row, int64 end_row) {
+      float* base = output->matrix<float>().data();
+      for (int64 row = start_row; row < end_row; ++row) {
+        float* dst_row = base + row * (a_reshaped_cols + pack_const);
+        std::memcpy(dst_row, a_reshaped_data + row * a_reshaped_cols,
+                    sizeof(float) * a_reshaped_cols);
+        std::memset(dst_row + a_reshaped_cols, 0, sizeof(float) * pack_const);
+      }
+    };
+    Shard(worker_threads->num_threads, worker_threads->workers, pack_size,
+          cost_per_unit, work);
+  }
+};
+
+#define REGISTER_CPU_KERNEL(Tindices1, Tindices2)                      \
+  REGISTER_KERNEL_BUILDER(Name("KPFusedEmbeddingActionIdGather")       \
+                              .Device(DEVICE_CPU)                      \
+                              .TypeConstraint<Tindices1>("Tindices1")  \
+                              .TypeConstraint<Tindices2>("Tindices2"), \
+                          KPFusedEmbeddingActionIdGatherOp<Tindices1, Tindices2>);
+
+REGISTER_CPU_KERNEL(int64, int32)
+REGISTER_CPU_KERNEL(int32, int32)
+REGISTER_CPU_KERNEL(int64, int64)
+REGISTER_CPU_KERNEL(int32, int64)
+
+}  // namespace tensorflow
\ No newline at end of file
diff --git a/annc/tensorflow/kernels/embedding_fused_action_id_gather_test.cc b/annc/tensorflow/kernels/embedding_fused_action_id_gather_test.cc
new file mode 100644
index 0000000..11067bf
--- /dev/null
+++ b/annc/tensorflow/kernels/embedding_fused_action_id_gather_test.cc
@@ -0,0 +1,289 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+class KPFusedEmbeddingActionIdGatherTest : public OpsTestBase {
+ protected:
+  void MakeOp(DataType indices1_type, DataType indices2_type) {
+    TF_ASSERT_OK(NodeDefBuilder("fused_embedding_action_id_gather",
+                                "KPFusedEmbeddingActionIdGather")
+                     .Input(FakeInput(indices1_type))  // indices1
+                     .Input(FakeInput(DT_FLOAT))       // params
+                     .Input(FakeInput(indices2_type))  // indices2
+                     .Input(FakeInput(DT_INT32))       // pack_dim
+                     .Input(FakeInput(DT_INT32))       // pack
+                     .Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+
+  template <typename Tindices1, typename Tindices2>
+  Status FeedAndRun(const std::vector<Tindices1>& indices1_data,
+                    const TensorShape& indices1_shape,
+                    const std::vector<float>& params_data,
+                    const TensorShape& params_shape,
+                    const std::vector<Tindices2>& indices2_data,
+                    const TensorShape& indices2_shape, int pack_dim_value,
+                    int pack_value) {
+    inputs_.clear();
+    input_types_.clear();
+
+    MakeOp(DataTypeToEnum<Tindices1>::v(), DataTypeToEnum<Tindices2>::v());
+    AddInputFromArray<Tindices1>(indices1_shape, indices1_data);
+    AddInputFromArray<float>(params_shape, params_data);
+    AddInputFromArray<Tindices2>(indices2_shape, indices2_data);
+    AddInputFromArray<int32>(TensorShape({}), {pack_dim_value});
+    AddInputFromArray<int32>(TensorShape({}), {pack_value});
+    return RunOpKernel();
+  }
+};
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, NormalCase) {
+  std::vector<int64> indices1_data = {0, 2};
+  TensorShape indices1_shape({2, 1});
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  TensorShape params_shape({3, 2});
+
+  std::vector<int32> indices2_data = {1, 0};
+  TensorShape indices2_shape({2, 1});
+
+  int pack_dim_value = 2;
+  int pack_value = 1;
+
+  TF_ASSERT_OK((FeedAndRun<int64, int32>(
+      indices1_data, indices1_shape, params_data, params_shape, indices2_data,
+      indices2_shape, pack_dim_value, pack_value)));
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
+  test::FillValues<float>(&expected, {5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f});
+  test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
+}
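+// Expected output of NormalCase, by hand: indices1 = {0, 2} gathers params
+// rows {0, 2}; indices2 = {1, 0} reorders them to {{5, 6}, {1, 2}}; pack = 1
+// then appends one zero column per row.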
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, DifferentIndexTypes) {
+  // indices1: int64, indices2: int32
+  {
+    std::vector<int64> indices1 = {0, 2};
+    std::vector<int32> indices2 = {1, 0};
+    TF_ASSERT_OK((FeedAndRun<int64, int32>(
+        indices1, {2, 1}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {3, 2},
+        indices2, {2, 1}, 2, 1)));
+    test::ExpectTensorNear<float>(
+        *GetOutput(0),
+        test::AsTensor<float>({5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f}, {2, 3}),
+        1e-5);
+  }
+
+  // indices1: int32, indices2: int32
+  {
+    std::vector<int32> indices1 = {0, 2};
+    std::vector<int32> indices2 = {1, 0};
+    TF_ASSERT_OK((FeedAndRun<int32, int32>(
+        indices1, {2, 1}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {3, 2},
+        indices2, {2, 1}, 2, 1)));
+    test::ExpectTensorNear<float>(
+        *GetOutput(0),
+        test::AsTensor<float>({5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f}, {2, 3}),
+        1e-5);
+  }
+
+  // indices1: int64, indices2: int64
+  {
+    std::vector<int64> indices1 = {0, 2};
+    std::vector<int64> indices2 = {1, 0};
+    TF_ASSERT_OK((FeedAndRun<int64, int64>(
+        indices1, {2, 1}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {3, 2},
+        indices2, {2, 1}, 2, 1)));
+    test::ExpectTensorNear<float>(
+        *GetOutput(0),
+        test::AsTensor<float>({5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f}, {2, 3}),
+        1e-5);
+  }
+
+  // indices1: int32, indices2: int64
+  {
+    std::vector<int32> indices1 = {0, 2};
+    std::vector<int64> indices2 = {1, 0};
+    TF_ASSERT_OK((FeedAndRun<int32, int64>(
+        indices1, {2, 1}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {3, 2},
+        indices2, {2, 1}, 2, 1)));
+    test::ExpectTensorNear<float>(
+        *GetOutput(0),
+        test::AsTensor<float>({5.0f, 6.0f, 0.0f, 1.0f, 2.0f, 0.0f}, {2, 3}),
+        1e-5);
+  }
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidIndices1Dims) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  std::vector<int64> indices1_data = {0, 2};
+  AddInputFromArray<int64>(TensorShape({2}), indices1_data);
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  AddInputFromArray<float>(TensorShape({3, 2}), params_data);
+
+  std::vector<int32> indices2_data = {1, 0};
+  AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
+
+  AddInputFromArray<int32>(TensorShape({}), {2});
+  AddInputFromArray<int32>(TensorShape({}), {1});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.ToString(), "indices1 dims must = 2")) << s;
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidIndices2Dims) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  std::vector<int64> indices1_data = {0, 2};
+  AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  AddInputFromArray<float>(TensorShape({3, 2}), params_data);
+
+  std::vector<int32> indices2_data = {1, 0};
+  AddInputFromArray<int32>(TensorShape({2}), indices2_data);
+
+  AddInputFromArray<int32>(TensorShape({}), {2});
+  AddInputFromArray<int32>(TensorShape({}), {1});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.ToString(), "indices2 dims must = 2")) << s;
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidParamsDims) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  std::vector<int64> indices1_data = {0, 2};
+  AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f};
+  AddInputFromArray<float>(TensorShape({4}), params_data);
+
+  std::vector<int32> indices2_data = {1, 0};
+  AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
+
+  AddInputFromArray<int32>(TensorShape({}), {2});
+  AddInputFromArray<int32>(TensorShape({}), {1});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.ToString(), "params dims must = 2")) << s;
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidPackDimDims) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  std::vector<int64> indices1_data = {0, 2};
+  AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  AddInputFromArray<float>(TensorShape({3, 2}), params_data);
+
+  std::vector<int32> indices2_data = {1, 0};
+  AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
+
+  AddInputFromArray<int32>(TensorShape({1}), {2});
+  AddInputFromArray<int32>(TensorShape({}), {1});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.ToString(), "pack_dim is scalar")) << s;
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidPackDims) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  std::vector<int64> indices1_data = {0, 2};
+  AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  AddInputFromArray<float>(TensorShape({3, 2}), params_data);
+
+  std::vector<int32> indices2_data = {1, 0};
+  AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
+
+  AddInputFromArray<int32>(TensorShape({}), {2});
+  AddInputFromArray<int32>(TensorShape({1}), {1});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.ToString(), "pack const is scalar")) << s;
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, InvalidPackSize) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  std::vector<int64> indices1_data = {0, 2};
+  AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  AddInputFromArray<float>(TensorShape({3, 2}), params_data);
+
+  std::vector<int32> indices2_data = {1, 0};
+  AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
+
+  AddInputFromArray<int32>(TensorShape({}), {0});
+  AddInputFromArray<int32>(TensorShape({}), {1});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.ToString(), "pack_size must > 0")) << s;
+}
+
+TEST_F(KPFusedEmbeddingActionIdGatherTest, IndexOutOfRange) {
+  MakeOp(DT_INT64, DT_INT32);
+
+  std::vector<int64> indices1_data = {0, 5};
+  AddInputFromArray<int64>(TensorShape({2, 1}), indices1_data);
+
+  std::vector<float> params_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  AddInputFromArray<float>(TensorShape({3, 2}), params_data);
+
+  std::vector<int32> indices2_data = {1, 0};
+  AddInputFromArray<int32>(TensorShape({2, 1}), indices2_data);
+
+  AddInputFromArray<int32>(TensorShape({}), {2});
+  AddInputFromArray<int32>(TensorShape({}), {1});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(
+      absl::StrContains(s.ToString(), "GatherV2 axis=0: index out of range"))
+      << s;
+}
+
+}  // namespace tensorflow
diff --git a/annc/tensorflow/kernels/embedding_fused_gather.cc b/annc/tensorflow/kernels/embedding_fused_gather.cc
new file mode 100644
index 0000000..c09d1ce
--- /dev/null
+++ b/annc/tensorflow/kernels/embedding_fused_gather.cc
@@ -0,0 +1,90 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+using namespace tensorflow;
+
+class KPFusedGather : public OpKernel {
+ public:
+  explicit KPFusedGather(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& data = context->input(0);
+    const Tensor& keys = context->input(1);
+    const Tensor& begin = context->input(2);
+    VLOG(1) << "Embedding table size: " << data.shape().DebugString();
+    VLOG(1) << "Input key shape: " << keys.shape().DebugString();
+    VLOG(1) << "Slice begin value: " << begin.DebugString();
+
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(keys.shape()),
+                errors::Internal("Input key must be 2D"));
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(data.shape()),
+                errors::Internal("Embedding table shape must be 2D"));
+    OP_REQUIRES(context, begin.NumElements() == 2,
+                errors::Internal("begin must be same as keys rank"));
+    int32 col = begin.flat<int32>().data()[1];
+    OP_REQUIRES(context, col < keys.dim_size(1),
+                errors::Internal("slice cols out of keys range"));
+
+    Tensor* out_indices = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       1, TensorShape({static_cast<int64>(keys.dim_size(0))}),
+                       &out_indices));
+    int32* out_indices_data = out_indices->flat<int32>().data();
+
+    auto keys_mat = keys.matrix<int64>();
+    std::vector<int64> unique_values;
+    std::unordered_map<int64, int32> value_to_index;
+    int current_index = 0;
+    for (int64_t i = 0; i < keys.dim_size(0); ++i) {
+      auto it = value_to_index.find(keys_mat(i, col));
+      if (it == value_to_index.end()) {
+        value_to_index[keys_mat(i, col)] = current_index;
+        unique_values.push_back(keys_mat(i, col));
+        out_indices_data[i] = current_index;
+        ++current_index;
+      } else {
+        out_indices_data[i] = it->second;
+      }
+    }
+
+    Tensor* out_unique_value = nullptr;
+    OP_REQUIRES_OK(
+        context,
+        context->allocate_output(
+            0, TensorShape({static_cast<int64>(unique_values.size())}),
+            &out_unique_value));
+    std::memcpy(out_unique_value->data(), unique_values.data(),
+                unique_values.size() * sizeof(int64_t));
+
+    Tensor* out_data = nullptr;
+    int embedding_dims = data.dim_size(1);
+    OP_REQUIRES_OK(
+        context,
+        context->allocate_output(
+            2,
+            TensorShape({static_cast<int64>(unique_values.size()),
+                         embedding_dims}),
+            &out_data));
+
+    const float* data_mat = data.flat<float>().data();
+    for (int64_t cur_row = 0;
+         cur_row < static_cast<int64_t>(unique_values.size()); ++cur_row) {
+      int64_t idx = unique_values[cur_row];
+      OP_REQUIRES(context, idx < data.dim_size(0),
+                  errors::Internal("idx out of table range"));
+      const float* src = data_mat + idx * embedding_dims;
+      float* dst = out_data->flat<float>().data() + cur_row * embedding_dims;
+      std::memcpy(dst, src, embedding_dims * sizeof(float));
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("KPFusedGather").Device(DEVICE_CPU),
+                        KPFusedGather);
\ No newline at end of file
diff --git a/annc/tensorflow/kernels/embedding_fused_gather_test.cc b/annc/tensorflow/kernels/embedding_fused_gather_test.cc
new file mode 100644
index 0000000..c947709
--- /dev/null
+++ b/annc/tensorflow/kernels/embedding_fused_gather_test.cc
@@ -0,0 +1,186 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace {
using tensorflow::AllocatorAttributes;
using tensorflow::DT_FLOAT;
using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::int64;
using tensorflow::int32;
using tensorflow::NodeDefBuilder;
using tensorflow::OpsTestBase;
using tensorflow::Status;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::test::ExpectClose;
using tensorflow::test::FillValues;
using tensorflow::test::AsTensor;
using tensorflow::test::ExpectTensorEqual;

class KPFusedGatherTest : public OpsTestBase {
 protected:
  void RunValidCase(const TensorShape& data_shape,
                    const TensorShape& slice_shape,
                    const std::vector<int32>& begin_val,
                    const std::vector<int64>& slice_data,
                    const std::vector<float>& data_data,
                    const std::vector<int64>& expected_unique,
                    const std::vector<int32>& expected_indices,
                    const std::vector<float>& expected_output_data) {
    TF_EXPECT_OK(NodeDefBuilder("kp_fused_gather", "KPFusedGather")
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_INT64))
                     .Input(FakeInput(DT_INT32))
                     .Finalize(node_def()));
    TF_EXPECT_OK(InitOp());

    AddInputFromArray<float>(data_shape, data_data);
    AddInputFromArray<int64>(slice_shape, slice_data);
    AddInputFromArray<int32>(TensorShape({2}), begin_val);

    TF_ASSERT_OK(RunOpKernel());

    const Tensor& out_unique = *GetOutput(0);
    const Tensor& out_indices = *GetOutput(1);
    const Tensor& out_data = *GetOutput(2);

    // Check output 0: unique_values
    Tensor expected_unique_tensor(
        allocator(), DT_INT64,
        TensorShape({static_cast<int64>(expected_unique.size())}));
    FillValues<int64>(&expected_unique_tensor, expected_unique);
    ExpectTensorEqual<int64>(expected_unique_tensor, out_unique);

    // Check output 1: indices
    Tensor expected_indices_tensor(
        allocator(), DT_INT32,
        TensorShape({static_cast<int64>(expected_indices.size())}));
    FillValues<int32>(&expected_indices_tensor, expected_indices);
    ExpectTensorEqual<int32>(expected_indices_tensor, out_indices);

    // Check output 2: out_data
    Tensor expected_data_tensor(
        allocator(), DT_FLOAT,
        TensorShape({static_cast<int64>(expected_unique.size()), 12}));
    FillValues<float>(&expected_data_tensor, expected_output_data);
    ExpectClose(expected_data_tensor, out_data);  // use ExpectClose for float
  }

  Status RunOpExpectFailure(const TensorShape& data_shape,
                            const TensorShape& slice_shape,
                            const std::vector<int32>& begin_val,
                            const std::vector<int64>& slice_data,
                            const std::vector<float>& data_data) {
    TF_CHECK_OK(NodeDefBuilder("kp_fused_gather", "KPFusedGather")
                    .Input(FakeInput(DT_FLOAT))
                    .Input(FakeInput(DT_INT64))
                    .Input(FakeInput(DT_INT32))
                    .Finalize(node_def()));
    TF_CHECK_OK(InitOp());

    AddInputFromArray<float>(data_shape, data_data);
    AddInputFromArray<int64>(slice_shape, slice_data);
    AddInputFromArray<int32>(TensorShape({2}), begin_val);

    return RunOpKernel();
  }
};

// Positive test: normal input
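//
// A worked example of what the fused op computes, inferred from the kernel
// above: with begin = {0, 1} the kernel reads column 1 of slice_input,
// deduplicates it in first-appearance order, and gathers the matching
// embedding rows. Column 1 below is [1, 1, 0, 1], so unique = [1, 0],
// indices = [0, 0, 1, 0], and out_data = [data[1], data[0]].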
TEST_F(KPFusedGatherTest, Valid_NormalInput) {
  RunValidCase(
      TensorShape({2, 12}),  // data shape
      TensorShape({4, 3}),   // slice_input shape
      {0, 1},                // begin[1] = 1 → take column 1
      {1, 1, 3,
       0, 1, 5,
       1, 0, 7,
       0, 1, 9},             // slice_input data
      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
       13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f},
      {1, 0},                // unique values from col=1
      {0, 0, 1, 0},          // indices mapping
      {13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,  // data[1]
       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});         // data[0]
}

// data is not 2-D
TEST_F(KPFusedGatherTest, Invalid_DataDimsNot2) {
  std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
  Status s = RunOpExpectFailure(
      TensorShape({4}),  // data is not two-dimensional
      TensorShape({2, 2}),
      {0, 0},
      {0, 1, 2, 3},
      data);
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "Embedding table shape must be 2D"));
}

// key is not 2-D
TEST_F(KPFusedGatherTest, Invalid_SliceInputDimsNot2) {
  std::vector<float> data(2 * 12, 1.0f);
  Status s = RunOpExpectFailure(
      TensorShape({2, 12}),
      TensorShape({4}),  // 1D slice_input
      {0, 0},
      {0, 1, 2, 3},
      data);
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "Input key must be 2D"));
}

// begin[1] out of column range
TEST_F(KPFusedGatherTest, Invalid_BeginColOutOfRange) {
  std::vector<float> data(2 * 12, 1.0f);
  Status s = RunOpExpectFailure(
      TensorShape({2, 12}),
      TensorShape({2, 2}),
      {0, 2},  // begin[1] = 2, but there are only 2 columns → valid indices 0,1
      {0, 1, 2, 3},
      data);
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "slice cols out of keys range"));
}

// gather index exceeds the number of data rows
TEST_F(KPFusedGatherTest, Invalid_IndexOutOfRangeInData) {
  std::vector<float> data(2 * 12, 1.0f);
  Status s = RunOpExpectFailure(
      TensorShape({2, 12}),
      TensorShape({2, 2}),
      {0, 0},
      {0, 1,
       2, 3},  // index 2 exceeds the data rows (only 0,1)
      data);
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "idx out of table range"));
}

}  // namespace
\ No newline at end of file
diff --git a/annc/tensorflow/kernels/embedding_fused_padding.cc b/annc/tensorflow/kernels/embedding_fused_padding.cc
new file mode 100644
index 0000000..a0f14b2
--- /dev/null
+++ b/annc/tensorflow/kernels/embedding_fused_padding.cc
@@ -0,0 +1,126 @@
/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstring>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

class KPFusedEmbeddingPaddingOp : public OpKernel {
 public:
  explicit KPFusedEmbeddingPaddingOp(OpKernelConstruction* context)
      : OpKernel(context) {
    fast_ = (type_string() == "KPFusedEmbeddingPaddingFast");
  }

  void Compute(OpKernelContext* context) override {
    // Grab the input tensors.
    const Tensor& origin_shape = context->input(0);
    const Tensor& input = context->input(1);
    const Tensor& input_rows = context->input(2);
    const Tensor& reshape_sizes = context->input(3);
    const Tensor& pack = context->input(4);

    VLOG(1) << "Input shape: " << input.shape().DebugString();
    OP_REQUIRES(context, TensorShapeUtils::IsVector(origin_shape.shape()),
                errors::InvalidArgument("origin_shape must be 1-D, not ",
                                        origin_shape.shape().DebugString()));
    OP_REQUIRES(context, origin_shape.NumElements() == 2,
                errors::InvalidArgument("origin_shape NumElements must == 2, not ",
                                        origin_shape.NumElements()));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input.shape()),
                errors::InvalidArgument("input must be 2-D, not ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_rows.shape()),
                errors::InvalidArgument("input_rows must be a scalar"));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(reshape_sizes.shape()),
                errors::InvalidArgument("sizes input must be 1-D, not ",
                                        reshape_sizes.shape().DebugString()));
    OP_REQUIRES(context, reshape_sizes.NumElements() == 2,
                errors::InvalidArgument("reshape_sizes NumElements must == 2"));

    int input_rows_value = input_rows.scalar<int32>()();
    int padding_rows =
        static_cast<int>(origin_shape.flat<int64>()(0)) - input_rows_value;
    auto reshape_cols = reshape_sizes.flat<int32>()(1);
    int output_rows = padding_rows + input.dim_size(0);
    int output_cols = input.dim_size(1);
    OP_REQUIRES(context, padding_rows >= 0,
                errors::InvalidArgument(
                    "Pooling size(", input_rows_value,
                    ") is greater than Input size(",
                    static_cast<int>(origin_shape.flat<int64>()(0)), ")"));
    OP_REQUIRES(context, reshape_cols > 0,
                errors::InvalidArgument("reshape_cols must > 0"));
    OP_REQUIRES(context, reshape_sizes.flat<int32>()(0) == -1,
                errors::InvalidArgument("reshape[0] is not -1"));
    OP_REQUIRES(context, pack.scalar<int32>()() == output_cols,
                errors::InvalidArgument("pack(", pack.scalar<int32>()(),
                                        ") is not equal to embedding dims"));

    Tensor* output0 = nullptr;
    Tensor* output1 = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({}), &output0));
    output0->scalar<int32>()() = padding_rows;
    OP_REQUIRES(context, output_rows * output_cols % reshape_cols == 0,
                errors::InvalidArgument("padding cannot reshape to [-1, ",
                                        reshape_cols, "]"));
    int reshape_rows = output_rows * output_cols / reshape_cols;
    if (fast_) {
      // Fast variant: only the scalar row counts are produced.
      OP_REQUIRES_OK(context,
                     context->allocate_output(1, TensorShape({}), &output1));
      output1->scalar<int32>()() = reshape_rows;
      return;
    }

    TensorShape reshaped_shape({reshape_rows, reshape_cols});
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, reshaped_shape, &output1));
    float* output_data = output1->flat<float>().data();
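    // The padded result is the pooled rows copied verbatim followed by
    // padding_rows of zero rows; both regions are contiguous, so one memcpy
    // plus one memset fills the whole [reshape_rows, reshape_cols] buffer.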
    const float* input_data = input.flat<float>().data();
    std::memcpy(output_data, input_data,
                input_rows_value * output_cols * sizeof(float));
    // Zero-fill the padding region (memset takes an int fill byte).
    std::memset(output_data + input_rows_value * output_cols, 0,
                padding_rows * output_cols * sizeof(float));
  }

 private:
  bool fast_;
};

REGISTER_KERNEL_BUILDER(Name("KPFusedEmbeddingPadding").Device(DEVICE_CPU),
                        KPFusedEmbeddingPaddingOp);

REGISTER_KERNEL_BUILDER(Name("KPFusedEmbeddingPaddingFast").Device(DEVICE_CPU),
                        KPFusedEmbeddingPaddingOp);

}  // namespace tensorflow
\ No newline at end of file
diff --git a/annc/tensorflow/kernels/embedding_fused_padding_test.cc b/annc/tensorflow/kernels/embedding_fused_padding_test.cc
new file mode 100644
index 0000000..5137d51
--- /dev/null
+++ b/annc/tensorflow/kernels/embedding_fused_padding_test.cc
@@ -0,0 +1,307 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {

class KPFusedEmbeddingPaddingTest : public OpsTestBase {
 protected:
  void MakeOp(DataType input_shape_type, DataType pooling_type,
              DataType reshape_type, DataType const_type) {
    TF_ASSERT_OK(NodeDefBuilder("fused_padding", "KPFusedEmbeddingPadding")
                     .Input(FakeInput(input_shape_type))
                     .Input(FakeInput(pooling_type))
                     .Input(FakeInput(const_type))
                     .Input(FakeInput(reshape_type))
                     .Input(FakeInput(const_type))
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }

  Status FeedAndRun(const int embedding_dims, const int table_size,
                    const int pooling_size, const int reshape_size) {
    MakeOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
    AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
    AddInput<float>(TensorShape({pooling_size, embedding_dims}),
                    [](int i) -> float { return static_cast<float>(i + 1); });
    AddInputFromArray<int32>(TensorShape({}), {pooling_size});
    AddInputFromArray<int32>(TensorShape({2}), {-1, reshape_size});
    AddInputFromArray<int32>(TensorShape({}), {embedding_dims});
    return RunOpKernel();
  }

  void MakeFastOp(DataType input_shape_type, DataType pooling_type,
                  DataType reshape_type, DataType const_type) {
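    // Same five inputs as MakeOp above; the Fast kernel validates them the
    // same way but skips the data copy and returns only scalar row counts.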
    TF_ASSERT_OK(NodeDefBuilder("fused_padding_fast", "KPFusedEmbeddingPaddingFast")
                     .Input(FakeInput(input_shape_type))
                     .Input(FakeInput(pooling_type))
                     .Input(FakeInput(const_type))
                     .Input(FakeInput(reshape_type))
                     .Input(FakeInput(const_type))
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }

  Status FeedAndRunFast(const int embedding_dims, const int table_size,
                        const int pooling_size, const int reshape_size) {
    MakeFastOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
    AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
    AddInput<float>(TensorShape({pooling_size, embedding_dims}),
                    [](int i) -> float { return static_cast<float>(i + 1); });
    AddInputFromArray<int32>(TensorShape({}), {pooling_size});
    AddInputFromArray<int32>(TensorShape({2}), {-1, reshape_size});
    AddInputFromArray<int32>(TensorShape({}), {embedding_dims});
    return RunOpKernel();
  }
};

TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithEmbeddingDims10_0) {
  // Feed and run
  const int embedding_dims = 10;
  const int table_size = 151;
  const int pooling_size = 151;
  const int reshape_size = 1510;
  TF_ASSERT_OK(FeedAndRun(embedding_dims, table_size, pooling_size, reshape_size));

  // Check the output.
  Tensor expected1(allocator(), DT_INT32, TensorShape({}));
  Tensor expected2(allocator(), DT_FLOAT,
                   TensorShape({table_size * embedding_dims / reshape_size,
                                reshape_size}));
  test::FillValues<int32>(&expected1, {table_size - pooling_size});
  test::FillFn<float>(&expected2, [=](int i) -> float {
    if (i < pooling_size * embedding_dims) {
      return static_cast<float>(i + 1);
    } else {
      return 0.0f;
    }
  });
  test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
  test::ExpectTensorNear<float>(expected2, *GetOutput(1), 1e-5);
}

TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithEmbeddingDims10_1) {
  // Feed and run
  const int embedding_dims = 10;
  const int table_size = 1510;
  const int pooling_size = 151;
  const int reshape_size = 1510;
  TF_ASSERT_OK(FeedAndRun(embedding_dims, table_size, pooling_size, reshape_size));

  // Check the output.
  Tensor expected1(allocator(), DT_INT32, TensorShape({}));
  Tensor expected2(allocator(), DT_FLOAT,
                   TensorShape({table_size * embedding_dims / reshape_size,
                                reshape_size}));
  test::FillValues<int32>(&expected1, {table_size - pooling_size});
  test::FillFn<float>(&expected2, [=](int i) -> float {
    if (i < pooling_size * embedding_dims) {
      return static_cast<float>(i + 1);
    } else {
      return 0.0f;
    }
  });
  test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
  test::ExpectTensorNear<float>(expected2, *GetOutput(1), 1e-5);
}

TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithEmbeddingDims12_0) {
  // Feed and run
  const int embedding_dims = 12;
  const int table_size = 2;
  const int pooling_size = 2;
  const int reshape_size = 24;
  TF_ASSERT_OK(FeedAndRun(embedding_dims, table_size, pooling_size, reshape_size));

  // Check the output.
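  // With table_size == pooling_size there is no padding (output 0 is 0) and
  // the 2x12 input reflows to [2 * 12 / 24, 24] = [1, 24].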
  Tensor expected1(allocator(), DT_INT32, TensorShape({}));
  Tensor expected2(allocator(), DT_FLOAT,
                   TensorShape({table_size * embedding_dims / reshape_size,
                                reshape_size}));
  test::FillValues<int32>(&expected1, {table_size - pooling_size});
  test::FillFn<float>(&expected2, [=](int i) -> float {
    if (i < pooling_size * embedding_dims) {
      return static_cast<float>(i + 1);
    } else {
      return 0.0f;
    }
  });
  test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
  test::ExpectTensorNear<float>(expected2, *GetOutput(1), 1e-5);
}

TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithEmbeddingDims12_1) {
  // Feed and run
  const int embedding_dims = 12;
  const int table_size = 200;
  const int pooling_size = 2;
  const int reshape_size = 24;
  TF_ASSERT_OK(FeedAndRun(embedding_dims, table_size, pooling_size, reshape_size));

  // Check the output.
  Tensor expected1(allocator(), DT_INT32, TensorShape({}));
  Tensor expected2(allocator(), DT_FLOAT,
                   TensorShape({table_size * embedding_dims / reshape_size,
                                reshape_size}));
  test::FillValues<int32>(&expected1, {table_size - pooling_size});
  test::FillFn<float>(&expected2, [=](int i) -> float {
    if (i < pooling_size * embedding_dims) {
      return static_cast<float>(i + 1);
    } else {
      return 0.0f;
    }
  });
  test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
  test::ExpectTensorNear<float>(expected2, *GetOutput(1), 1e-5);
}

TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingFastWithEmbeddingDims10_0) {
  // Feed and run
  const int embedding_dims = 10;
  const int table_size = 151;
  const int pooling_size = 151;
  const int reshape_size = 1510;
  TF_ASSERT_OK(FeedAndRunFast(embedding_dims, table_size, pooling_size, reshape_size));

  // Check the output.
  Tensor expected1(allocator(), DT_INT32, TensorShape({}));
  Tensor expected2(allocator(), DT_INT32, TensorShape({}));
  test::FillValues<int32>(&expected1, {table_size - pooling_size});
  test::FillValues<int32>(&expected2, {table_size * embedding_dims / reshape_size});
  test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
  test::ExpectTensorEqual<int32>(expected2, *GetOutput(1));
}

TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingFastWithEmbeddingDims10_1) {
  // Feed and run
  const int embedding_dims = 10;
  const int table_size = 1510;
  const int pooling_size = 151;
  const int reshape_size = 1510;
  TF_ASSERT_OK(FeedAndRunFast(embedding_dims, table_size, pooling_size, reshape_size));

  // Check the output.
  Tensor expected1(allocator(), DT_INT32, TensorShape({}));
  Tensor expected2(allocator(), DT_INT32, TensorShape({}));
  test::FillValues<int32>(&expected1, {table_size - pooling_size});
  test::FillValues<int32>(&expected2, {table_size * embedding_dims / reshape_size});
  test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
  test::ExpectTensorEqual<int32>(expected2, *GetOutput(1));
}

TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingFastWithEmbeddingDims12_0) {
  // Feed and run
  const int embedding_dims = 12;
  const int table_size = 2;
  const int pooling_size = 2;
  const int reshape_size = 24;
  TF_ASSERT_OK(FeedAndRunFast(embedding_dims, table_size, pooling_size, reshape_size));

  // Check the output.
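  // Fast variant: output 1 is just the scalar reshape row count
  // (2 * 12 / 24 = 1); no padded tensor is materialized.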
  Tensor expected1(allocator(), DT_INT32, TensorShape({}));
  Tensor expected2(allocator(), DT_INT32, TensorShape({}));
  test::FillValues<int32>(&expected1, {table_size - pooling_size});
  test::FillValues<int32>(&expected2, {table_size * embedding_dims / reshape_size});
  test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
  test::ExpectTensorEqual<int32>(expected2, *GetOutput(1));
}

TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingFastWithEmbeddingDims12_1) {
  // Feed and run
  const int embedding_dims = 12;
  const int table_size = 200;
  const int pooling_size = 2;
  const int reshape_size = 24;
  TF_ASSERT_OK(FeedAndRunFast(embedding_dims, table_size, pooling_size, reshape_size));

  // Check the output.
  Tensor expected1(allocator(), DT_INT32, TensorShape({}));
  Tensor expected2(allocator(), DT_INT32, TensorShape({}));
  test::FillValues<int32>(&expected1, {table_size - pooling_size});
  test::FillValues<int32>(&expected2, {table_size * embedding_dims / reshape_size});
  test::ExpectTensorEqual<int32>(expected1, *GetOutput(0));
  test::ExpectTensorEqual<int32>(expected2, *GetOutput(1));
}

TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithUnexpectReshape) {
  // Feed and run
  const int embedding_dims = 12;
  const int table_size = 200;
  const int pooling_size = 2;
  const int reshape_size = 24;
  MakeOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
  AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
  AddInput<float>(TensorShape({pooling_size, embedding_dims}),
                  [](int i) -> float { return static_cast<float>(i + 1); });
  AddInputFromArray<int32>(TensorShape({}), {pooling_size});
  AddInputFromArray<int32>(TensorShape({2}), {10, reshape_size});
  AddInputFromArray<int32>(TensorShape({}), {embedding_dims});
  Status s = RunOpKernel();
  EXPECT_TRUE(absl::StrContains(s.ToString(), "reshape[0] is not -1")) << s;
}

TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithUnexpectPack) {
  // Feed and run
  const int embedding_dims = 12;
  const int table_size = 200;
  const int pooling_size = 2;
  const int reshape_size = 24;
  MakeOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
  AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
  AddInput<float>(TensorShape({pooling_size, embedding_dims}),
                  [](int i) -> float { return static_cast<float>(i + 1); });
  AddInputFromArray<int32>(TensorShape({}), {pooling_size});
  AddInputFromArray<int32>(TensorShape({2}), {-1, reshape_size});
  AddInputFromArray<int32>(TensorShape({}), {10});
  Status s = RunOpKernel();
  EXPECT_TRUE(
      absl::StrContains(s.ToString(), "pack(10) is not equal to embedding dims"))
      << s;
}

TEST_F(KPFusedEmbeddingPaddingTest, FusedPaddingWithPoolingSizeGreaterInput) {
  // Feed and run
  const int embedding_dims = 12;
  const int table_size = 200;
  const int pooling_size = 201;
  const int reshape_size = 24;
  MakeOp(DT_INT64, DT_FLOAT, DT_INT32, DT_INT32);
  AddInputFromArray<int64>(TensorShape({2}), {table_size, embedding_dims});
  AddInput<float>(TensorShape({pooling_size, embedding_dims}),
                  [](int i) -> float { return static_cast<float>(i + 1); });
  AddInputFromArray<int32>(TensorShape({}), {pooling_size});
  AddInputFromArray<int32>(TensorShape({2}), {-1, reshape_size});
  AddInputFromArray<int32>(TensorShape({}), {embedding_dims});
  Status s = RunOpKernel();
  EXPECT_TRUE(
      absl::StrContains(s.ToString(), "Pooling size(201) is greater than Input size(200)"))
      << s;
}

}  // end namespace tensorflow
diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_dynamic_stitch.cc b/annc/tensorflow/kernels/embedding_fused_sparse_dynamic_stitch.cc
new file mode 100644
index 0000000..0ccf599
--- /dev/null
+++ b/annc/tensorflow/kernels/embedding_fused_sparse_dynamic_stitch.cc
@@ -0,0 +1,87 @@
/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstring>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/util/work_sharder.h"

using namespace tensorflow;

class KPFusedSparseDynamicStitchOp : public OpKernel {
 public:
  explicit KPFusedSparseDynamicStitchOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& x = context->input(0);
    auto x_flat = x.flat<int64>();
    int64_t num_elems = x_flat.size();

    const int num_inputs = context->num_inputs();
    const int num_partitions = num_inputs - 1;
    OP_REQUIRES(context, num_partitions > 1,
                errors::InvalidArgument("num partitions must > 1"));
    int64_t output_stride = 0;
    std::vector<const float*> variables(num_partitions);
    std::vector<int64_t> variable_rows(num_partitions);
    for (int i = 1; i < num_inputs; ++i) {
      const Tensor& input_tensor = context->input(i);
      OP_REQUIRES(context, input_tensor.dims() == 2,
                  errors::InvalidArgument("input dims must == 2"));
      if (i == 1) {
        output_stride = input_tensor.dim_size(1);
      } else {
        OP_REQUIRES(context, input_tensor.dim_size(1) == output_stride,
                    errors::InvalidArgument(
                        "All inputs must have same second dimension"));
      }
      variables[i - 1] = context->input(i).flat<float>().data();
      variable_rows[i - 1] = input_tensor.dim_size(0);
    }

    OP_REQUIRES(context, output_stride > 0,
                errors::InvalidArgument("output_stride must > 0"));

    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({num_elems, output_stride}),
                       &output_tensor));
    float* output = output_tensor->flat<float>().data();

    const size_t copy_size = output_stride * sizeof(float);

    // Each global id selects table (id % num_partitions) and row
    // (id / num_partitions); rows are copied in parallel shards.
    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
    const int64 cost_per_unit = 120;  // Actual single cycle execution time
    auto work = [&](int64_t start, int64_t end) {
      for (int64_t i = start; i < end; ++i) {
        const int64_t global_id = x_flat(i);
        const int64_t table_id = global_id % num_partitions;
        const int64_t row_id = global_id / num_partitions;

        OP_REQUIRES(context, row_id < variable_rows[table_id],
                    errors::InvalidArgument("row_id out of range."));

        std::memcpy(output + i * output_stride,
                    variables[table_id] + row_id * output_stride, copy_size);
      }
    };

    Shard(worker_threads->num_threads, worker_threads->workers, num_elems,
          cost_per_unit, work);
  }
};

REGISTER_KERNEL_BUILDER(Name("KPFusedSparseDynamicStitch").Device(DEVICE_CPU),
                        KPFusedSparseDynamicStitchOp);
diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_dynamic_stitch_test.cc b/annc/tensorflow/kernels/embedding_fused_sparse_dynamic_stitch_test.cc
new file mode 100644
index 0000000..74fdd25
--- /dev/null
+++ b/annc/tensorflow/kernels/embedding_fused_sparse_dynamic_stitch_test.cc
@@ -0,0 +1,108 @@
/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * ==============================================================================*/

#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {
namespace {

class KPFusedSparseDynamicStitchOpTest : public OpsTestBase {
 protected:
  void MakeOp(int N) {
    TF_ASSERT_OK(NodeDefBuilder("kp_fused_sparse_dynamic_stitch",
                                "KPFusedSparseDynamicStitch")
                     .Input(FakeInput(DT_INT64))
                     .Input(FakeInput(N, DT_FLOAT))
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }
};

TEST_F(KPFusedSparseDynamicStitchOpTest, TestTwoTables) {
  MakeOp(2);  // num_partitions = 2

  AddInputFromArray<int64>(TensorShape({4}), {0, 3, 2, 1});
  AddInputFromArray<float>(TensorShape({3, 2}),
                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
  AddInputFromArray<float>(TensorShape({2, 2}), {7.0f, 8.0f, 9.0f, 10.0f});
  TF_ASSERT_OK(RunOpKernel());

  Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 2}));
  test::FillValues<float>(&expected,
                          {1.0f, 2.0f, 9.0f, 10.0f, 3.0f, 4.0f, 7.0f, 8.0f});
  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}

TEST_F(KPFusedSparseDynamicStitchOpTest, TestDifferentStride) {
  MakeOp(2);

  AddInputFromArray<int64>(TensorShape({4}), {0, 3, 2, 1});
  AddInputFromArray<float>(TensorShape({3, 2}),
                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
  AddInputFromArray<float>(TensorShape({1, 4}), {7.0f, 8.0f, 9.0f, 10.0f});

  Status s = RunOpKernel();
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "All inputs must have same second dimension"));
}

TEST_F(KPFusedSparseDynamicStitchOpTest, TestIndicesOutOfBounds) {
  MakeOp(2);

  AddInputFromArray<int64>(TensorShape({4}), {0, 6, 2, 1});
  AddInputFromArray<float>(TensorShape({3, 2}),
                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
  AddInputFromArray<float>(TensorShape({2, 2}), {7.0f, 8.0f, 9.0f, 10.0f});

  Status s = RunOpKernel();
  EXPECT_FALSE(s.ok());
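  // id 6 maps to table 6 % 2 = 0, row 6 / 2 = 3, but table 0 has only
  // 3 rows (0..2), so the kernel must fail.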
  EXPECT_TRUE(absl::StrContains(s.message(), "row_id out of range"));
}

TEST_F(KPFusedSparseDynamicStitchOpTest, TestInputDims) {
  MakeOp(2);

  AddInputFromArray<int64>(TensorShape({4}), {0, 6, 2, 1});
  AddInputFromArray<float>(TensorShape({3, 2, 1}),
                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
  AddInputFromArray<float>(TensorShape({2, 2, 1}), {7.0f, 8.0f, 9.0f, 10.0f});

  Status s = RunOpKernel();
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "input dims must == 2"));
}

}  // namespace
}  // namespace tensorflow
diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_reshape.cc b/annc/tensorflow/kernels/embedding_fused_sparse_reshape.cc
new file mode 100644
index 0000000..e4acad2
--- /dev/null
+++ b/annc/tensorflow/kernels/embedding_fused_sparse_reshape.cc
@@ -0,0 +1,195 @@
/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/util/work_sharder.h"
#include "tensorflow/core/kernels/reshape_util.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

using namespace tensorflow;

static void ReshapeKp(OpKernelContext* context, const Tensor& input_indices_in,
                      const Tensor& input_shape_in, const Tensor& target_shape_in,
                      int output_indices_idx, int output_shape_idx) {
  OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices_in.shape()),
              errors::InvalidArgument(
                  "Input indices should be a matrix but received shape ",
                  input_indices_in.shape().DebugString()));
  OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()),
              errors::InvalidArgument(
                  "Input shape should be a vector but received shape ",
                  input_shape_in.shape().DebugString()));
  OP_REQUIRES(context, TensorShapeUtils::IsVector(target_shape_in.shape()),
              errors::InvalidArgument(
                  "Target shape should be a vector but received shape ",
                  target_shape_in.shape().DebugString()));

  const int64 input_rank = input_shape_in.NumElements();
  const int64 output_rank = target_shape_in.NumElements();
  const TensorShape input_shape(input_shape_in.vec<int64>());
  const int64 dense_size = input_shape.num_elements();
  const int64 nnz = input_indices_in.shape().dim_size(0);

  TensorShape output_shape;
  int64 product = 1;
  int unknown_index = -1;
  auto target_shape = target_shape_in.vec<int64>();
  for (int d = 0; d < output_rank; ++d) {
    const int64 size = target_shape(d);
    if (size == -1) {
      OP_REQUIRES(
          context, unknown_index == -1,
          errors::InvalidArgument("only one output dimension may be -1, "
                                  "not both ", unknown_index, " and ", d));
      unknown_index = d;
      output_shape.AddDim(1);
    } else {
      OP_REQUIRES(context, size >= 0,
                  errors::InvalidArgument("size ", d,
                                          " must be non-negative, not ", size));
      product *= size;
      output_shape.AddDim(size);
    }
  }
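  // Infer the -1 dimension (if any): the product of the explicit sizes must
  // divide the dense element count exactly.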
  if (unknown_index != -1) {
    OP_REQUIRES(
        context, product > 0,
        errors::InvalidArgument("reshape cannot infer the missing "
                                "input size for an empty tensor unless all "
                                "specified input sizes are non-zero"));
    const int64 missing = dense_size / product;
    OP_REQUIRES(
        context, product * missing == dense_size,
        errors::InvalidArgument(
            "Input to reshape is a SparseTensor with ", dense_size,
            " dense values, but the requested shape requires a multiple of ",
            product, ". input_shape=", input_shape.DebugString(),
            " output_shape=", output_shape.DebugString()));
    output_shape.set_dim(unknown_index, missing);
  }

  OP_REQUIRES(
      context, output_shape.num_elements() == dense_size,
      errors::InvalidArgument("Input to reshape is a tensor with ", dense_size,
                              " dense values, but the requested shape has ",
                              output_shape.num_elements(),
                              ". input_shape=", input_shape.DebugString(),
                              " output_shape=", output_shape.DebugString()));

  if (input_shape == output_shape) {
    context->set_output(output_indices_idx, input_indices_in);
    context->set_output(output_shape_idx, input_shape_in);
    return;
  }

  gtl::InlinedVector<int64, 8> input_strides(input_rank);
  if (input_rank > 0) {
    input_strides[input_rank - 1] = 1;
    for (int d = input_rank - 2; d >= 0; --d) {
      input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
    }
  }

  gtl::InlinedVector<int64, 8> output_strides(output_rank);
  if (output_rank > 0) {
    output_strides[output_rank - 1] = 1;
    for (int d = output_rank - 2; d >= 0; --d) {
      output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
    }
  }

  Tensor* result_indices = nullptr;
  OP_REQUIRES_OK(context,
                 context->allocate_output(output_indices_idx,
                                          TensorShape({nnz, output_rank}),
                                          &result_indices));
  auto input_ind = input_indices_in.matrix<int64>();
  auto output_ind = result_indices->matrix<int64>();
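  // Each input index is linearized row-major against the input strides and
  // then decomposed against the output strides, mirroring SparseReshape.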
  for (int i = 0; i < nnz; ++i) {
    int64 id = 0;
    for (int j = 0; j < input_rank; ++j) {
      id += input_ind(i, j) * input_strides[j];
    }
    for (int j = 0; j < output_rank; ++j) {
      output_ind(i, j) = id / output_strides[j];
      id %= output_strides[j];
    }
  }

  Tensor* result_shape = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
                                                   TensorShape({output_rank}),
                                                   &result_shape));
  auto output_shape_vec = result_shape->vec<int64>();
  for (int j = 0; j < output_shape.dims(); ++j) {
    output_shape_vec(j) = output_shape.dim_size(j);
  }
}

class KPFusedSparseReshapeOp : public OpKernel {
 public:
  explicit KPFusedSparseReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& slice_input = context->input(0);
    const Tensor& begin = context->input(1);
    const Tensor& new_shape = context->input(2);
    const Tensor& pack_const = context->input(3);

    OP_REQUIRES(context, slice_input.dims() == 2,
                errors::Internal("slice_input dims must == 2"));
    OP_REQUIRES(context, new_shape.dim_size(0) == 2,
                errors::Internal("new_shape dim size must == 2"));
    OP_REQUIRES(context, pack_const.dims() == 0,
                errors::InvalidArgument("pack_const must be a scalar"));
    VLOG(1) << "Input slice_input shape: " << slice_input.shape().DebugString();
    VLOG(1) << "Input begin value: " << begin.DebugString();
    VLOG(1) << "Input new_shape value: " << new_shape.DebugString();

    OP_REQUIRES(context, begin.dims() == 1 && begin.dim_size(0) == 2,
                errors::InvalidArgument("begin must be 1D with exactly 2 elements"));
    int32 col = begin.flat<int32>().data()[1];
    OP_REQUIRES(context, col < slice_input.dim_size(1),
                errors::Internal("begin[1] must < slice_input.dim_size(1)"));
    int64_t num_rows = slice_input.dim_size(0);
    auto slice_input_mat = slice_input.matrix<int64>();

    VLOG(1) << "num_rows: " << num_rows;
    VLOG(1) << "slice_input.dim_size(0): " << slice_input.dim_size(0);
    VLOG(1) << "slice_input.dim_size(1): " << slice_input.dim_size(1);
    VLOG(1) << "Column index from begin: " << col;

    Tensor shape_in(DT_INT64, TensorShape({2}));
    auto tensor_flat = shape_in.flat<int64>();
    tensor_flat(0) = num_rows;
    tensor_flat(1) = pack_const.scalar<int64>()();

    Tensor indices_in(DT_INT64, TensorShape({num_rows, 2}));
    auto indices_in_mat = indices_in.matrix<int64>();
    for (int i = 0; i < num_rows; ++i) {
      indices_in_mat(i, 0) = i;
      indices_in_mat(i, 1) = slice_input_mat(i, col);
    }

    ReshapeKp(context, indices_in, shape_in, new_shape, 0, 1);
  }
};

REGISTER_KERNEL_BUILDER(Name("KPFusedSparseReshape").Device(DEVICE_CPU),
                        KPFusedSparseReshapeOp);
\ No newline at end of file
diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_reshape_test.cc b/annc/tensorflow/kernels/embedding_fused_sparse_reshape_test.cc
new file mode 100644
index 0000000..874d48a
--- /dev/null
+++ b/annc/tensorflow/kernels/embedding_fused_sparse_reshape_test.cc
@@ -0,0 +1,281 @@
/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
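//
// KPFusedSparseReshape (per the kernel above) builds SparseTensor indices
// [row, slice_input(row, begin[1])] with dense shape [num_rows, pack_const]
// and then reshapes them to new_shape, so the tests below mirror
// SparseReshape's error behavior.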
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace {
using tensorflow::AllocatorAttributes;
using tensorflow::DT_FLOAT;
using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::int64;
using tensorflow::int32;
using tensorflow::NodeDefBuilder;
using tensorflow::OpsTestBase;
using tensorflow::Status;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::test::FillValues;
using tensorflow::test::ExpectTensorEqual;

class KPFusedSparseReshapeTest : public OpsTestBase {
 protected:
  void RunValidCase(const TensorShape& slice_shape,
                    const std::vector<int64>& slice_data,
                    const std::vector<int32>& begin_val,
                    const std::vector<int64>& new_shape_val,
                    const std::vector<int64>& pack_const_val,
                    const TensorShape& expected_indices_shape,
                    const std::vector<int64>& expected_shape_val) {
    TF_EXPECT_OK(NodeDefBuilder("kp_fused_sparse_reshape", "KPFusedSparseReshape")
                     .Input(FakeInput(DT_INT64))  // slice_input
                     .Input(FakeInput(DT_INT32))  // begin
                     .Input(FakeInput(DT_INT64))  // new_shape
                     .Input(FakeInput(DT_INT64))  // pack_const
                     .Finalize(node_def()));
    TF_EXPECT_OK(InitOp());

    AddInputFromArray<int64>(slice_shape, slice_data);
    AddInputFromArray<int32>(TensorShape({2}), begin_val);
    AddInputFromArray<int64>(TensorShape({2}), new_shape_val);
    AddInputFromArray<int64>(TensorShape({}), pack_const_val);

    TF_ASSERT_OK(RunOpKernel());

    // Output 0: result_indices
    const Tensor& out_indices = *GetOutput(0);
    EXPECT_EQ(out_indices.shape(), expected_indices_shape);

    // Output 1: result_shape
    const Tensor& out_shape = *GetOutput(1);
    Tensor expected_shape_tensor(
        DT_INT64, TensorShape({static_cast<int64>(expected_shape_val.size())}));
    FillValues<int64>(&expected_shape_tensor, expected_shape_val);
    ExpectTensorEqual<int64>(expected_shape_tensor, out_shape);
  }

  Status RunOpExpectFailure(const TensorShape& slice_shape,
                            const std::vector<int64>& slice_data,
                            const std::vector<int32>& begin_val,
                            const std::vector<int64>& new_shape_val,
                            const std::vector<int64>& pack_const_val) {
    TF_CHECK_OK(NodeDefBuilder("kp_fused_sparse_reshape", "KPFusedSparseReshape")
                    .Input(FakeInput(DT_INT64))  // slice_input
                    .Input(FakeInput(DT_INT32))  // begin
                    .Input(FakeInput(DT_INT64))  // new_shape
                    .Input(FakeInput(DT_INT64))  // pack_const
                    .Finalize(node_def()));
    TF_CHECK_OK(InitOp());

    AddInputFromArray<int64>(slice_shape, slice_data);
    AddInputFromArray<int32>(TensorShape({static_cast<int64>(begin_val.size())}), begin_val);
    AddInputFromArray<int64>(TensorShape({static_cast<int64>(new_shape_val.size())}), new_shape_val);
    AddInputFromArray<int64>(TensorShape({}), pack_const_val);

    return RunOpKernel();
  }
};

// ==================== Positive tests ====================

// Normal reshape case
// pack_const = 2
TEST_F(KPFusedSparseReshapeTest, Valid_NormalInput) {
  RunValidCase(
      TensorShape({4, 2}),   // slice_input shape
      {0, 1,
       1, 2,
       2, 3,
       3, 0},                // slice_input data
      {0, 1},                // begin = (0,1): select column 1
      {2, 4},                // new_shape = [2,4]
      {2},                   // pack_const = [2]
      TensorShape({4, 2}),   // expected indices shape
      {2, 4});               // expected shape
}

// pack_const = 1
TEST_F(KPFusedSparseReshapeTest, Valid_PackConst1) {
  RunValidCase(
      TensorShape({1, 2}),   // slice_input shape
      {0, 1},                // slice_input data
      {0, 1},                // begin = (0,1): select column 1
      {-1, 1},               // new_shape = [-1,1]
      {1},                   // pack_const = [1]
      TensorShape({1, 2}),   // expected indices shape
      {1, 1});               // expected shape
}
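// In the first case above, column 1 of slice_input is [1, 2, 3, 0], giving
// sparse indices {(0,1),(1,2),(2,3),(3,0)} in a dense [4, 2] shape that
// reshapes cleanly to [2, 4]; only the coordinates change, never the count.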
// ==================== Negative tests ====================

// Negative case 1: slice_input is not 2-D
TEST_F(KPFusedSparseReshapeTest, Invalid_SliceInputNot2D) {
  Status s = RunOpExpectFailure(
      TensorShape({4}), {0, 1, 2, 3},
      {0, 0},
      {2, 2},
      {4});
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "slice_input dims must == 2"));
}

// Negative case 2: new_shape dim size is not 2
TEST_F(KPFusedSparseReshapeTest, Invalid_NewShapeNotLen2) {
  Status s = RunOpExpectFailure(
      TensorShape({2, 2}), {0, 1, 1, 0},
      {0, 0},
      {4, 2, 1},  // new_shape has one extra element
      {2});
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "new_shape dim size must == 2"));
}

// Negative case 3: begin[1] exceeds the slice_input column count
TEST_F(KPFusedSparseReshapeTest, Invalid_BeginOutOfRange) {
  Status s = RunOpExpectFailure(
      TensorShape({2, 2}), {0, 1, 1, 0},
      {0, 2},  // beyond the column count
      {2, 2},
      {2});
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "begin[1] must < slice_input.dim_size(1)"));
}

// Negative case 4: target shape has more than one -1
TEST_F(KPFusedSparseReshapeTest, Invalid_MultipleUnknownDims) {
  Status s = RunOpExpectFailure(
      TensorShape({2, 2}), {0, 1, 1, 0},
      {0, 1},
      {-1, -1},  // two -1s
      {2});
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "only one output dimension may be -1"));
}

// Negative case 5: when inferring the missing dimension the element count
// does not divide evenly --> product * missing != dense_size
TEST_F(KPFusedSparseReshapeTest, Invalid_InferredShapeDoesNotMatch) {
  TensorShape input_indices_shape({6, 2});  // 6 nonzero elements, rank = 2
  std::vector<int64> input_indices_data = {
      0, 0,
      0, 1,
      0, 2,
      1, 0,
      1, 1,
      1, 2};  // corresponds to a 2x3 dense tensor

  std::vector<int32> begin_val = {0, 0};      // assumed begin input
  std::vector<int64> new_shape_val = {-1, 4}; // reshape to ?x4
  std::vector<int64> pack_const_val = {1};

  Status s = RunOpExpectFailure(
      input_indices_shape,
      input_indices_data,
      begin_val,
      new_shape_val,
      pack_const_val);

  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "Input to reshape is a SparseTensor with"));
}

// Negative case 6: element count mismatch after reshape
// --> output_shape.num_elements() != dense_size
TEST_F(KPFusedSparseReshapeTest, Invalid_SizeMismatch) {
  Status s = RunOpExpectFailure(
      TensorShape({2, 2}), {0, 1, 1, 0},
      {0, 1},
      {3, 3},  // expects 9 elements, but input dense size = 4
      {2});
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "Input to reshape is a tensor with"));
}

// Negative case 7: target_shape contains a negative number other than -1
TEST_F(KPFusedSparseReshapeTest, Invalid_NegativeDimNotMinusOne) {
  Status s = RunOpExpectFailure(
      TensorShape({2, 2}), {0, 1, 1, 0},
      {0, 0},
      {2, -2},  // -2 is illegal
      {2});
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "size 1 must be non-negative, not -2"))
      << "Actual error: " << s.message();
}

// Negative case 8: target_shape has -1, but the other dimensions' product is 0
TEST_F(KPFusedSparseReshapeTest, Invalid_ProductZeroWithUnknownDim) {
  // dense_size = 0 (empty SparseTensor), target_shape = [-1, 0];
  // product = 0 → inference is not allowed
  Status s = RunOpExpectFailure(
      TensorShape({0, 2}), {},  // empty slice_input
      {0, 0},
      {-1, 0},  // product = 0
      {2});
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "reshape cannot infer the missing input size for an empty tensor"))
      << "Actual error: " << s.message();
}

// Negative case 9: begin is 1-D but has length 1 (fewer than 2 elements)
TEST_F(KPFusedSparseReshapeTest, Invalid_BeginRank1ButSize1) {
  Status s = RunOpExpectFailure(
      TensorShape({2, 2}), {0, 1, 1, 0},
      {0},  // begin = [0], length 1
      {2, 2},
      {2});
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "begin must be 1D with exactly 2 elements"))
      << "Actual error: " << s.message();
}
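// Both begin-length cases rely on RunOpExpectFailure sizing the begin tensor
// from the vector, so the kernel's own rank/length check is what fires.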
// Negative case 10: begin is 1-D but has length 3 (more than 2)
TEST_F(KPFusedSparseReshapeTest, Invalid_BeginRank1ButSize3) {
  Status s = RunOpExpectFailure(
      TensorShape({2, 2}), {0, 1, 1, 0},
      {0, 1, 2},  // begin = [0,1,2], length 3
      {2, 2},
      {2});
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "begin must be 1D with exactly 2 elements"))
      << "Actual error: " << s.message();
}

// Negative case 11: pack_const is fed as a length-1 vector, but the kernel
// requires a 0-D scalar
TEST_F(KPFusedSparseReshapeTest, Invalid_PackConstNotScalar) {
  TF_CHECK_OK(NodeDefBuilder("kp_fused_sparse_reshape", "KPFusedSparseReshape")
                  .Input(FakeInput(DT_INT64))  // slice_input
                  .Input(FakeInput(DT_INT32))  // begin
                  .Input(FakeInput(DT_INT64))  // new_shape
                  .Input(FakeInput(DT_INT64))  // pack_const
                  .Finalize(node_def()));
  TF_CHECK_OK(InitOp());

  AddInputFromArray<int64>(TensorShape({2, 2}), {0, 1, 1, 0});
  AddInputFromArray<int32>(TensorShape({2}), {0, 1});
  AddInputFromArray<int64>(TensorShape({2}), {2, 2});
  AddInputFromArray<int64>(TensorShape({1}), {1});  // pack_const as a 1-D vector

  Status s = RunOpKernel();
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(absl::StrContains(s.message(), "pack_const must be a scalar"))
      << "Actual error: " << s.message();
}

}  // namespace
diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce.cc b/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce.cc
new file mode 100644
index 0000000..33bbd31
--- /dev/null
+++ b/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce.cc
@@ -0,0 +1,165 @@
/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
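//
// KPFusedSparseSegmentReduce fuses a strided-slice (segment ids come from
// column begin[1] of slice_input) with SparseSegmentSum/Mean over the
// embedding rows selected by indices; the inner loops use NEON float32x4
// vectors to process four floats per iteration.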
#include <arm_neon.h>  // NEON intrinsics: float32x4_t, vld1q_f32, ...

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/util/work_sharder.h"

using namespace tensorflow;

template <typename Tidx>
class KPFusedSparseSegmentReduceOp : public OpKernel {
 public:
  explicit KPFusedSparseSegmentReduceOp(OpKernelConstruction* context)
      : OpKernel(context) {
    int combiner_mode;
    OP_REQUIRES_OK(context, context->GetAttr("combiner", &combiner_mode));
    OP_REQUIRES(context, combiner_mode == 0 || combiner_mode == 1,
                errors::InvalidArgument("combiner must be 0 or 1"));
    is_mean_ = (combiner_mode == 1);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input_tensor = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& slice_input = context->input(2);
    const Tensor& begin = context->input(3);
    const Tensor& begin_1 = context->input(4);

    OP_REQUIRES(context, input_tensor.dims() == 2,
                errors::InvalidArgument("input must be 2-D"));
    OP_REQUIRES(context, slice_input.dims() == 2,
                errors::InvalidArgument("slice input must be 2-D"));
    OP_REQUIRES(context, begin.NumElements() == 2,
                errors::InvalidArgument("begin must have 2 elements"));
    OP_REQUIRES(context, begin_1.NumElements() == 1,
                errors::InvalidArgument("begin_1 must have 1 element"));
    int64_t num_indices = indices.dim_size(0);
    int64_t embedding_size = input_tensor.dim_size(1);
    int32 col = begin.flat<int32>().data()[1];
    int32 out_dim = static_cast<int32>(begin_1.flat<int32>()(0));

    OP_REQUIRES(context, col >= 0 && col < slice_input.dim_size(1),
                errors::InvalidArgument("Column index out of range"));
    OP_REQUIRES(context, num_indices == slice_input.dim_size(0),
                errors::InvalidArgument(
                    "indices and slice_input.dim_size(0) should have same size"));

    auto input_data = input_tensor.matrix<float>().data();
    auto indices_vec = indices.vec<Tidx>();
    auto slice_input_mat = slice_input.matrix<int64>();

    // Calculate max segment_id
    int64 max_seg_id = 0;
    for (int64_t i = 0; i < num_indices; ++i) {
      int64 seg_id = slice_input_mat(i, col);
      if (seg_id > max_seg_id) {
        max_seg_id = seg_id;
      }
    }
    const int64 batch_size = max_seg_id + 1;

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({batch_size, embedding_size}), &output));
    output->flat<float>().setZero();
    Tensor* slice_out = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, TensorShape({}), &slice_out));
    if (out_dim == 0) {
      slice_out->scalar<int32>()() = batch_size;
    } else {
      slice_out->scalar<int32>()() = embedding_size;
    }

    auto output_data = output->matrix<float>().data();

    if (is_mean_) {
      Tensor counts(DT_INT32, TensorShape({batch_size}));
      counts.flat<int32>().setZero();
      auto counts_vec = counts.flat<int32>();

      for (int64 i = 0; i < num_indices; ++i) {
        const int64 seg_id = slice_input_mat(i, col);
        const Tidx data_row = indices_vec(i);
        counts_vec(seg_id) += 1;

        float* output_row = output_data + seg_id * embedding_size;
        const float* input_data_row = input_data + data_row * embedding_size;
        int64 j = 0;
        // NEON: accumulate four floats per iteration.
        for (; j + 3 < embedding_size; j += 4) {
          float32x4_t out = vld1q_f32(output_row + j);
          float32x4_t data = vld1q_f32(input_data_row + j);
          out = vaddq_f32(out, data);
          vst1q_f32(output_row + j, out);
        }

        for (; j < embedding_size; ++j) {
          output_row[j] += input_data_row[j];
        }
      }
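      // Second pass: scale each accumulated row by 1/count, broadcast across
      // the NEON lanes, to turn the per-segment sums into means.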
      for (int64_t seg = 0; seg < batch_size; ++seg) {
        const int32_t count = counts_vec(seg);
        if (count > 0) {
          const float inv_count = 1.0f / static_cast<float>(count);
          const float32x4_t inv_count_vec = vdupq_n_f32(inv_count);

          float* row_start = output_data + seg * embedding_size;
          int64_t j = 0;

          for (; j + 3 < embedding_size; j += 4) {
            float32x4_t val = vld1q_f32(row_start + j);
            val = vmulq_f32(val, inv_count_vec);
            vst1q_f32(row_start + j, val);
          }

          for (; j < embedding_size; ++j) {
            row_start[j] *= inv_count;
          }
        }
      }
    } else {
      for (int64 i = 0; i < num_indices; ++i) {
        const int64 seg_id = slice_input_mat(i, col);
        const Tidx data_row = indices_vec(i);

        float* output_row = output_data + seg_id * embedding_size;
        const float* input_data_row = input_data + data_row * embedding_size;
        int64 j = 0;
        for (; j + 3 < embedding_size; j += 4) {
          float32x4_t out = vld1q_f32(output_row + j);
          float32x4_t data = vld1q_f32(input_data_row + j);
          out = vaddq_f32(out, data);
          vst1q_f32(output_row + j, out);
        }

        for (; j < embedding_size; ++j) {
          output_row[j] += input_data_row[j];
        }
      }
    }
  }

 private:
  bool is_mean_;
};

#define REGISTER_KERNEL(Tidx)                                  \
  REGISTER_KERNEL_BUILDER(Name("KPFusedSparseSegmentReduce")   \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<Tidx>("Tidx"),   \
                          KPFusedSparseSegmentReduceOp<Tidx>);
REGISTER_KERNEL(int64)
REGISTER_KERNEL(int32)
#undef REGISTER_KERNEL
\ No newline at end of file
diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_nonzero.cc b/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_nonzero.cc
new file mode 100644
index 0000000..5266476
--- /dev/null
+++ b/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_nonzero.cc
@@ -0,0 +1,159 @@
/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
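//
// KPFusedSparseSegmentReduceNonzero is the 1-D variant: it reduces scalar
// values per segment and then emits only the nonzero results as a sparse
// (indices, values) pair plus the dense segment count.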
#include <utility>
#include <vector>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/util/work_sharder.h"
#include "absl/container/flat_hash_map.h"

using namespace tensorflow;

template <typename Tidx>
class KPFusedSparseSegmentReduceNonzeroOp : public OpKernel {
 public:
  explicit KPFusedSparseSegmentReduceNonzeroOp(OpKernelConstruction* context)
      : OpKernel(context) {
    int combiner_mode;
    OP_REQUIRES_OK(context, context->GetAttr("combiner", &combiner_mode));
    OP_REQUIRES(context, combiner_mode == 0 || combiner_mode == 1,
                errors::InvalidArgument("combiner must be 0 or 1"));
    is_mean_ = (combiner_mode == 1);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input_tensor = context->input(0);
    const Tensor& indices = context->input(1);
    const Tensor& slice_input = context->input(2);
    const Tensor& begin = context->input(3);

    OP_REQUIRES(context, input_tensor.dims() == 1,
                errors::InvalidArgument("Input data must be a vector"));
    OP_REQUIRES(context, slice_input.dims() == 2,
                errors::InvalidArgument("slice input must be 2-D"));
    OP_REQUIRES(context, begin.NumElements() == 2,
                errors::InvalidArgument("begin must have 2 elements"));

    int64 num_indices = indices.dim_size(0);
    int32 col = begin.flat<int32>().data()[1];

    OP_REQUIRES(context, col >= 0 && col < slice_input.dim_size(1),
                errors::InvalidArgument("Column index out of range"));
    OP_REQUIRES(context, num_indices == slice_input.dim_size(0),
                errors::InvalidArgument(
                    "indices and slice_input.dim_size(0) should have same size"));

    auto input_data = input_tensor.flat<float>();
    auto indices_vec = indices.vec<Tidx>();
    auto slice_input_mat = slice_input.matrix<int64>();

    // Calculate max segment_id
    std::vector<int64> segment_ids(num_indices);
    int64 max_seg_id = 0;
    for (int64 i = 0; i < num_indices; ++i) {
      int64 seg_id = slice_input_mat(i, col);
      segment_ids[i] = seg_id;
      if (seg_id > max_seg_id) {
        max_seg_id = seg_id;
      }
    }

    const int64 batch_size = max_seg_id + 1;

    Tensor* output_shape = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(0, TensorShape({1}), &output_shape));
    output_shape->flat<int32>()(0) = static_cast<int32>(batch_size);

    std::vector<std::pair<int64, float>> results;
    int64 num_nonzero = 0;
    absl::flat_hash_map<int64, float> segment_sums;
    absl::flat_hash_map<int64, int32> segment_counts;
    std::vector<int64> segment_order;

    if (is_mean_) {
      for (int64 i = 0; i < num_indices; ++i) {
        const int64 seg_id = segment_ids[i];
        const Tidx data_row = indices_vec(i);

        if (segment_sums.find(seg_id) == segment_sums.end()) {
          segment_order.push_back(seg_id);
        }

        segment_sums[seg_id] += input_data(data_row);
        segment_counts[seg_id] += 1;
      }

      for (int64 seg_id : segment_order) {
        const int32_t count = segment_counts[seg_id];
        if (count > 0) {
          const float inv_count = 1.0f / static_cast<float>(count);
          float value = segment_sums[seg_id];
          if (value != 0) {
            results.push_back({seg_id, value * inv_count});
            num_nonzero++;
          }
        }
      }
    } else {
      for (int64 i = 0; i < num_indices; ++i) {
        const int64 seg_id = segment_ids[i];
        const Tidx data_row = indices_vec(i);

        if (segment_sums.find(seg_id) == segment_sums.end()) {
          segment_order.push_back(seg_id);
        }

        segment_sums[seg_id] += input_data(data_row);
      }
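      // Emit segments in first-appearance order, dropping zero sums so the
      // output stays sparse.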
+          results.push_back({seg_id, value});
+          num_nonzero++;
+        }
+      }
+    }
+
+    Tensor* output_indices = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, TensorShape({num_nonzero, 1}),
+                                            &output_indices));
+    auto output_indices_data = output_indices->flat<int32>();
+
+    Tensor* output_nonzero = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(2, TensorShape({num_nonzero}),
+                                            &output_nonzero));
+    auto output_nonzero_data = output_nonzero->flat<float>();
+    for (int64 i = 0; i < num_nonzero; ++i) {
+      output_indices_data(i) = static_cast<int32>(results[i].first);
+      output_nonzero_data(i) = results[i].second;
+    }
+  }
+
+ private:
+  bool is_mean_;
+};
+
+#define REGISTER_KERNEL(Tidx)                                       \
+  REGISTER_KERNEL_BUILDER(Name("KPFusedSparseSegmentReduceNonzero") \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<Tidx>("Tidx"),        \
+                          KPFusedSparseSegmentReduceNonzeroOp<Tidx>);
+REGISTER_KERNEL(int64)
+REGISTER_KERNEL(int32)
+#undef REGISTER_KERNEL
\ No newline at end of file
diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_nonzero_test.cc b/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_nonzero_test.cc
new file mode 100644
index 0000000..ac4a5c3
--- /dev/null
+++ b/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_nonzero_test.cc
@@ -0,0 +1,183 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ==============================================================================*/
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "absl/strings/match.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+class KPFusedSparseSegmentReduceNonzeroOpTest : public OpsTestBase {
+ protected:
+  void MakeOp(int combiner_mode) {
+    TF_ASSERT_OK(NodeDefBuilder("kp_fused_sparse_segment_reduce_nonzero",
+                                "KPFusedSparseSegmentReduceNonzero")
+                     .Input(FakeInput(DT_FLOAT))  // data
+                     .Input(FakeInput(DT_INT32))  // indices
+                     .Input(FakeInput(DT_INT64))  // slice_input
+                     .Input(FakeInput(DT_INT32))  // begin
+                     .Attr("combiner", combiner_mode)
+                     .Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+};
+
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestReduceMean) {
+  MakeOp(1);
+
+  AddInputFromArray<float>(TensorShape({8}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_INT32, TensorShape({1}));
+  test::FillValues<int32>(&expected, {4});
+  test::ExpectTensorEqual<int32>(expected, *GetOutput(0));  // output_shape
+
+  Tensor expected_1(allocator(), DT_INT32, TensorShape({2, 1}));
+  test::FillValues<int32>(&expected_1, {2, 3});
+  test::ExpectTensorEqual<int32>(expected_1, *GetOutput(1));  // output_indices
+
+  Tensor expected_2(allocator(), DT_FLOAT, TensorShape({2}));
+  test::FillValues<float>(&expected_2, {2, 2});
+  test::ExpectTensorEqual<float>(expected_2, *GetOutput(2));  // output_nonzero
+}
+
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestReduceSum) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({8}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_INT32, TensorShape({1}));
+  test::FillValues<int32>(&expected, {4});
+  test::ExpectTensorEqual<int32>(expected, *GetOutput(0));  // output_shape
+
+  Tensor expected_1(allocator(), DT_INT32, TensorShape({2, 1}));
+  test::FillValues<int32>(&expected_1, {2, 3});
+  test::ExpectTensorEqual<int32>(expected_1, *GetOutput(1));  // output_indices
+
+  Tensor expected_2(allocator(), DT_FLOAT, TensorShape({2}));
+  test::FillValues<float>(&expected_2, {4, 2});
+  test::ExpectTensorEqual<float>(expected_2, *GetOutput(2));  // output_nonzero
+}
+
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestInvalidData) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({4, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
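+  // data above is 2-D, which should trip the kernel's
+  // "Input data must be a vector" check.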
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "Input data must be a vector"));
+}
+
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestInvalidSliceInput) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({8}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4, 1}),  // slice_input is 3-D
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "slice input must be 2-D"));
+}
+
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestInvalidBegin) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({8}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});  // begin has 3 elements
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "begin must have 2 elements"));
+}
+
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestColsOutOfBounds) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({8}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 4});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "Column index out of range"));
+}
+
+TEST_F(KPFusedSparseSegmentReduceNonzeroOpTest, TestIndicesSizeMismatch) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({8}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});  // 2 indices vs. 3 rows
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 1});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(
+      s.message(),
+      "indices and slice_input.dim_size(0) must have the same size"));
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_test.cc b/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_test.cc
new file mode 100644
index 0000000..558ab85
--- /dev/null
+++ b/annc/tensorflow/kernels/embedding_fused_sparse_segment_reduce_test.cc
@@ -0,0 +1,205 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * ==============================================================================*/
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "absl/strings/match.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+class KPFusedSparseSegmentReduceOpTest : public OpsTestBase {
+ protected:
+  void MakeOp(int combiner_mode) {
+    TF_ASSERT_OK(NodeDefBuilder("kp_fused_sparse_segment_reduce",
+                                "KPFusedSparseSegmentReduce")
+                     .Input(FakeInput(DT_FLOAT))  // data
+                     .Input(FakeInput(DT_INT32))  // indices
+                     .Input(FakeInput(DT_INT64))  // slice_input
+                     .Input(FakeInput(DT_INT32))  // begin
+                     .Input(FakeInput(DT_INT32))  // begin_1
+                     .Attr("combiner", combiner_mode)
+                     .Finalize(node_def()));
+    TF_ASSERT_OK(InitOp());
+  }
+};
+
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestReduceMean) {
+  MakeOp(1);
+
+  AddInputFromArray<float>(TensorShape({4, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  AddInputFromArray<int32>(TensorShape({1}), {1});
+
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 2}));
+  test::FillValues<float>(&expected,
+                          {0.0f, 0.0f, 0.0f, 0.0f, 3.0f, 4.0f, 3.0f, 4.0f});
+  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+
+  Tensor expected_1(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&expected_1, {2});
+  test::ExpectTensorEqual<int32>(expected_1, *GetOutput(1));
+}
+
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestReduceSum) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({4, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  AddInputFromArray<int32>(TensorShape({1}), {0});
+
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 2}));
+  test::FillValues<float>(&expected,
+                          {0.0f, 0.0f, 0.0f, 0.0f, 6.0f, 8.0f, 3.0f, 4.0f});
+  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+
+  Tensor expected_1(allocator(), DT_INT32, TensorShape({}));
+  test::FillValues<int32>(&expected_1, {4});
+  test::ExpectTensorEqual<int32>(expected_1, *GetOutput(1));
+}
+
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestColsOutOfBounds) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({4, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 5});
+  AddInputFromArray<int32>(TensorShape({1}), {0});
+
+  Status s = RunOpKernel();
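+  // begin[1] == 5 addresses column 5 of a 4-column slice_input, so the op
+  // must fail with "Column index out of range".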
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "Column index out of range"));
+}
+
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestIndicesSizeMismatch) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({4, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({2}),
+                           {0, 2});  // num_indices != slice_input.dim_size(0)
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  AddInputFromArray<int32>(TensorShape({1}), {0});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(
+      s.message(), "indices and slice_input.dim_zie(0) should have same size"));
+}
+
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestInvalidData) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(
+      TensorShape({4, 2, 1}),
+      {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});  // data.dims() > 2
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  AddInputFromArray<int32>(TensorShape({1}), {0});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "input must be 2-D"));
+}
+
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestInvalidSliceInput) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({4, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(
+      TensorShape({3, 4, 1}),
+      {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});  // slice_input.dims() > 2
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  AddInputFromArray<int32>(TensorShape({1}), {0});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "slice input must be 2-D"));
+}
+
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestInvalidBegin) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({4, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({3}),
+                           {0, 2, 1});  // begin has 3 elements
+  AddInputFromArray<int32>(TensorShape({1}), {0});
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "begin must have 2 elements"));
+}
+
+TEST_F(KPFusedSparseSegmentReduceOpTest, TestInvalidBegin1) {
+  MakeOp(0);
+
+  AddInputFromArray<float>(TensorShape({4, 2}),
+                           {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
+  AddInputFromArray<int32>(TensorShape({3}), {0, 2, 1});
+  AddInputFromArray<int64>(TensorShape({3, 4}),
+                           {1, 2, 2, 2, 1, 1, 2, 3, 2, 2, 3, 4});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 2});
+  AddInputFromArray<int32>(TensorShape({2}), {0, 1});  // begin_1 has 2 elements
+
+  Status s = RunOpKernel();
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(), "begin_1 must have 1 element"));
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_select.cc b/annc/tensorflow/kernels/embedding_fused_sparse_select.cc
new file mode 100644
index 0000000..306a420
--- /dev/null
+++ b/annc/tensorflow/kernels/embedding_fused_sparse_select.cc
@@ -0,0 +1,111 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cstdint>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/logging.h"
+
+using namespace tensorflow;
+
+class KPFusedSparseSelect : public OpKernel {
+ public:
+  explicit KPFusedSparseSelect(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_a = context->input(0);
+    const Tensor& input_b = context->input(1);
+    const Tensor& input_c = context->input(2);
+    const Tensor& greater = context->input(3);
+    const Tensor& equal1 = context->input(4);
+    const Tensor& equal2 = context->input(5);
+    const Tensor& equal3 = context->input(6);
+
+    int32_t equal1_val = equal1.flat<int32>()(0);
+    int32_t equal2_val = equal2.flat<int32>()(0);
+    int32_t equal3_val = equal3.flat<int32>()(0);
+    VLOG(1) << "equal1_val: " << equal1_val;
+    VLOG(1) << "equal2_val: " << equal2_val;
+    VLOG(1) << "equal3_val: " << equal3_val;
+
+    int32_t greater_val = greater.flat<int32>()(0);
+    auto a_flat = input_a.flat<int32>();
+    auto b_flat = input_b.flat<int32>();
+    auto c_flat = input_c.flat<int32>();
+    VLOG(1) << "input_a shape: " << input_a.shape().DebugString();
+    VLOG(1) << "input_b shape: " << input_b.shape().DebugString();
+    VLOG(1) << "input_c shape: " << input_c.shape().DebugString();
+    OP_REQUIRES(
+        context, input_a.NumElements() == input_b.NumElements(),
+        errors::InvalidArgument("Input num elements of a and b must match"));
+    OP_REQUIRES(
+        context, input_a.NumElements() == input_c.NumElements(),
+        errors::InvalidArgument("Input num elements of a and c must match"));
+    auto N = input_a.NumElements();
+
+    Eigen::TensorMap<Eigen::Tensor<const int32, 2, Eigen::RowMajor>>
+        a_reshaped_tensor(a_flat.data(), N, 1);
+    Eigen::TensorMap<Eigen::Tensor<const int32, 2, Eigen::RowMajor>>
+        b_reshaped_tensor(b_flat.data(), N, 1);
+    Eigen::TensorMap<Eigen::Tensor<const int32, 2, Eigen::RowMajor>>
+        c_reshaped_tensor(c_flat.data(), N, 1);
+
+    Tensor* output_x = nullptr;
+    Tensor* output_y = nullptr;
+    Tensor* output_w = nullptr;
+
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, TensorShape({N, 1}), &output_x));
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, TensorShape({N, 1}), &output_y));
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(2, TensorShape({N, 2}), &output_w));
+
+    Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>> out_x(
+        output_x->flat<float>().data(), output_x->dim_size(0),
+        output_x->dim_size(1));
+
+    Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>> out_y(
+        output_y->flat<float>().data(), output_y->dim_size(0),
+        output_y->dim_size(1));
+
+    Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>> out_w(
+        output_w->flat<float>().data(), output_w->dim_size(0),
+        output_w->dim_size(1));
+
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    const int64 cost_per_unit =
+        std::max(N / worker_threads->num_threads, int64(10));
+
+    auto work = [&](int64 start, int64 end) {
+      for (int64 i = start; i < end; i++) {
+        // Greater(bool)+Cast.2406(float) --> 1.0f / 0.0f
+        float a_greater = (a_reshaped_tensor(i, 0) > greater_val) ? 1.0f : 0.0f;
+        float res_equal1 = (b_reshaped_tensor(i, 0) == equal1_val)
+                               ? 1.0f
+                               : a_greater;  // Fill.2409-->1.0f
+        float res_equal2 = (b_reshaped_tensor(i, 0) == equal2_val)
+                               ? 1.0f
+                               : res_equal1;  // Fill.2409-->1.0f
+        out_x(i, 0) = a_reshaped_tensor(i, 0);  // Reshape.2401
+        out_y(i, 0) = res_equal2;
+        out_w(i, 0) = res_equal2;  // Mul.2419 hard-codes 1.0f * input
+        out_w(i, 1) = 1.0f;  // select_2427 is folded away; Fill.2422 supplies 1.0f
+      }
+    };
+    Shard(worker_threads->num_threads, worker_threads->workers, N,
+          cost_per_unit, work);
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("KPFusedSparseSelect").Device(DEVICE_CPU),
+                        KPFusedSparseSelect);
\ No newline at end of file
diff --git a/annc/tensorflow/kernels/embedding_fused_sparse_select_test.cc b/annc/tensorflow/kernels/embedding_fused_sparse_select_test.cc
new file mode 100644
index 0000000..f956415
--- /dev/null
+++ b/annc/tensorflow/kernels/embedding_fused_sparse_select_test.cc
@@ -0,0 +1,182 @@
+/* Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/strings/match.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace {
+using tensorflow::AllocatorAttributes;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::FakeInput;
+using tensorflow::int64;
+using tensorflow::int32;
+using tensorflow::NodeDefBuilder;
+using tensorflow::OpsTestBase;
+using tensorflow::Status;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::test::ExpectClose;
+using tensorflow::test::FillValues;
+using tensorflow::test::AsTensor;
+using tensorflow::test::ExpectTensorEqual;
+
+class KPFusedSparseSelectTest : public OpsTestBase {
+ protected:
+  void RunValidCase(const TensorShape& shape,
+                    const std::vector<int32>& a_data,
+                    const std::vector<int32>& b_data,
+                    const std::vector<int32>& c_data,
+                    int32_t greater_val,
+                    int32_t equal1_val,
+                    int32_t equal2_val,
+                    const std::vector<float>& expected_y,
+                    const std::vector<float>& expected_w_col0) {
+    TF_EXPECT_OK(NodeDefBuilder("kp_fused_sparse_select", "KPFusedSparseSelect")
+                     .Input(FakeInput(DT_INT32))
+                     .Input(FakeInput(DT_INT32))
+                     .Input(FakeInput(DT_INT32))
+                     .Input(FakeInput(DT_INT32))  // greater
+                     .Input(FakeInput(DT_INT32))  // equal1
+                     .Input(FakeInput(DT_INT32))  // equal2
+                     .Input(FakeInput(DT_INT32))  // equal3
+                     .Finalize(node_def()));
+    TF_EXPECT_OK(InitOp());
+
+    AddInputFromArray<int32>(shape, a_data);
+    AddInputFromArray<int32>(shape, b_data);
+    AddInputFromArray<int32>(shape, c_data);
+    AddInputFromArray<int32>(TensorShape({}), {greater_val});  // scalar
+    AddInputFromArray<int32>(TensorShape({}), {equal1_val});
+    AddInputFromArray<int32>(TensorShape({}), {equal2_val});
+    AddInputFromArray<int32>(TensorShape({}), {0});  // equal3_val (unused)
+
+    TF_ASSERT_OK(RunOpKernel());
+
+    const Tensor& out_x = *GetOutput(0);
+    const Tensor& out_y = *GetOutput(1);
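+    // output_w is (N, 2): column 0 mirrors output_y and column 1 is the
+    // constant 1.0f that the fused Fill op supplies.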
+    const Tensor& out_w = *GetOutput(2);
+
+    int32 num_elements = expected_y.size();
+    // Verify output_x: it is simply input_a cast to float.
+    std::vector<float> a_data_float(a_data.begin(), a_data.end());
+    ExpectTensorEqual<float>(out_x,
+                             AsTensor<float>(a_data_float, {num_elements, 1}));
+
+    // Verify output_y.
+    ExpectTensorEqual<float>(out_y,
+                             AsTensor<float>(expected_y, {num_elements, 1}));
+    // Verify the first column of output_w.
+    auto w_mat = out_w.matrix<float>();
+    for (int i = 0; i < w_mat.dimension(0); ++i) {
+      EXPECT_FLOAT_EQ(w_mat(i, 0), expected_w_col0[i]);
+      EXPECT_FLOAT_EQ(w_mat(i, 1), 1.0f);  // the second column must be 1.0
+    }
+  }
+
+  Status RunOpExpectFailure(const TensorShape& shape,
+                            const std::vector<int32>& a_data,
+                            const std::vector<int32>& b_data,
+                            const std::vector<int32>& c_data,
+                            int32_t greater_val,
+                            int32_t equal1_val,
+                            int32_t equal2_val) {
+    TF_CHECK_OK(NodeDefBuilder("kp_fused_sparse_select", "KPFusedSparseSelect")
+                    .Input(FakeInput(DT_INT32))
+                    .Input(FakeInput(DT_INT32))
+                    .Input(FakeInput(DT_INT32))
+                    .Input(FakeInput(DT_INT32))
+                    .Input(FakeInput(DT_INT32))
+                    .Input(FakeInput(DT_INT32))
+                    .Input(FakeInput(DT_INT32))
+                    .Finalize(node_def()));
+    TF_CHECK_OK(InitOp());
+    TensorShape b_shape({static_cast<int64>(b_data.size())});
+    TensorShape c_shape({static_cast<int64>(c_data.size())});
+    AddInputFromArray<int32>(shape, a_data);
+    AddInputFromArray<int32>(b_shape, b_data);
+    AddInputFromArray<int32>(c_shape, c_data);
+    AddInputFromArray<int32>(TensorShape({}), {greater_val});
+    AddInputFromArray<int32>(TensorShape({}), {equal1_val});
+    AddInputFromArray<int32>(TensorShape({}), {equal2_val});
+    AddInputFromArray<int32>(TensorShape({}), {0});
+
+    return RunOpKernel();
+  }
+};
+
+// ==================== Positive tests ====================
+// See fused_embedding_sparse_select_test.py for more positive-case coverage.
+TEST_F(KPFusedSparseSelectTest, Valid_NormalInput) {
+  RunValidCase(TensorShape({3}),     // shape
+               {5, 3, 8},            // input_a
+               {1, 2, 1},            // input_b
+               {9, 8, 7},            // input_c (unused)
+               4,                    // greater_val
+               1,                    // equal1_val
+               3,                    // equal2_val
+               {1.0f, 0.0f, 1.0f},   // expected_y
+               {1.0f, 0.0f, 1.0f});  // expected_w_col0
+}
+
+TEST_F(KPFusedSparseSelectTest, Valid_2DInput) {
+  RunValidCase(TensorShape({2, 2}),
+               {6, 3, 8, 2},
+               {2, 1, 3, 4},
+               {0, 0, 0, 0},
+               5,
+               2,
+               3,
+               {1.0f, 0.0f, 1.0f, 0.0f},
+               {1.0f, 0.0f, 1.0f, 0.0f});
+}
+// ==================== Negative tests ====================
+// Negative case 1: element counts of input_a and input_b do not match.
+TEST_F(KPFusedSparseSelectTest, Invalid_DimMismatch_AB) {
+  Status s = RunOpExpectFailure(TensorShape({3}),  // a has 3 elements
+                                {1, 2, 3},
+                                {4, 5},  // b has 2 elements -> mismatch
+                                {6, 7, 8},
+                                0, 1, 2);
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(),
+                                "Input num elements of a and b must match"));
+}
+
+// Negative case 2: element counts of input_a and input_c do not match.
+TEST_F(KPFusedSparseSelectTest, Invalid_DimMismatch_AC) {
+  Status s = RunOpExpectFailure(TensorShape({2}),
+                                {1, 2},
+                                {3, 4},
+                                {5},  // c has only 1 element -> mismatch
+                                0, 1, 2);
+  EXPECT_FALSE(s.ok());
+  EXPECT_TRUE(absl::StrContains(s.message(),
+                                "Input num elements of a and c must match"));
+}
+
+}  // namespace
\ No newline at end of file
diff --git a/annc/tensorflow/ops/embedding_fused_ops.cc b/annc/tensorflow/ops/embedding_fused_ops.cc
new file mode 100644
index 0000000..bc0ce24
--- /dev/null
+++ b/annc/tensorflow/ops/embedding_fused_ops.cc
@@ -0,0 +1,133 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+
+namespace tensorflow {
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+using shape_inference::UnchangedShape;
+
+REGISTER_OP("KPFusedSparseSegmentReduce")
+    .Input("data: float")
+    .Input("indices: Tidx")
+    .Input("slice_input: int64")
+    .Input("begin: int32")
+    .Input("begin_1: int32")
+    .Attr("combiner: int = 1")  // 0 for SUM, 1 for MEAN
+    .Attr("Tidx: {int32, int64} = DT_INT32")
+    .Output("output: float")
+    .Output("slice_output: int32")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("KPFusedSparseSegmentReduceNonzero")
+    .Input("data: float")
+    .Input("indices: Tidx")
+    .Input("slice_input: int64")
+    .Input("begin: int32")
+    .Attr("combiner: int = 1")  // 0 for SUM, 1 for MEAN
+    .Attr("Tidx: {int32, int64} = DT_INT32")
+    .Output("output_shape: int32")
+    .Output("output_indices: int32")
+    .Output("output_nonzero: float")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("KPFusedEmbeddingPaddingFast")
+    .Input("input0: int64")
+    .Input("input1: float")
+    .Input("input2: int32")
+    .Input("input3: int32")
+    .Input("pack: int32")
+    .Output("output0: int32")
+    .Output("output1: int32")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle scalar_shape = c->Scalar();
+      c->set_output(0, scalar_shape);
+      c->set_output(1, scalar_shape);
+      return OkStatus();
+    });
+
+REGISTER_OP("KPFusedEmbeddingPadding")
+    .Input("input0: int64")
+    .Input("input1: float")
+    .Input("input2: int32")
+    .Input("input3: int32")
+    .Input("pack: int32")
+    .Output("output0: int32")
+    .Output("output1: float")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle out;
+      ShapeHandle scalar_shape = c->Scalar();
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(3, &out));
+      c->set_output(0, scalar_shape);
+      c->set_output(1, out);
+      return OkStatus();
+    });
+
+REGISTER_OP("KPFusedSparseSelect")
+    .Input("input_a: int32")
+    .Input("input_b: int32")
+    .Input("input_c: int32")
+    .Input("greater: int32")
+    .Input("equal1: int32")
+    .Input("equal2: int32")
+    .Input("equal3: int32")
+    .Output("output_x: float")
+    .Output("output_y: float")
+    .Output("output_w: float")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("KPFusedSparseReshape")
+    .Input("slice_input: int64")
+    .Input("begin: int32")
+    .Input("new_shape: int64")
+    .Input("pack_const: int64")
+    .Output("out_indices: int64")
+    .Output("out_shape: int64")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("KPFusedSparseDynamicStitch")
+    .Input("x: int64")
+    .Input("variables: N * float")
+    .Output("output: float")
+    .Attr("N: int >= 1")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("KPFusedGather")
+    .Input("data: float")
+    .Input("slice_input: int64")
+    .Input("begin: int32")
+    .Output("out_shape: int64")
+    .Output("out_indices: int32")
.Output("out_data: float") + .SetShapeFn(shape_inference::UnknownShape); + +REGISTER_OP("KPFusedEmbeddingActionIdGather") + .Input("input0: Tindices1") + .Input("input1: float") + .Input("input2: Tindices2") + .Input("input3: int32") + .Input("pack: int32") + .Attr("Tindices1: {int32, int64} = DT_INT64") + .Attr("Tindices2: {int32, int64} = DT_INT32") + .Output("output0: float") + .SetShapeFn(shape_inference::UnknownShape); +} // namespace tensorflow \ No newline at end of file diff --git a/annc/tensorflow/tf_annc_optimizer.patch b/annc/tensorflow/tf_annc_optimizer.patch new file mode 100644 index 0000000..c92330a --- /dev/null +++ b/annc/tensorflow/tf_annc_optimizer.patch @@ -0,0 +1,221 @@ +diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD +index 538574360ba..b43e0455802 100644 +--- a/tensorflow/core/BUILD ++++ b/tensorflow/core/BUILD +@@ -629,6 +629,7 @@ cc_library( + "//tensorflow/core/kernels/linalg", + "//tensorflow/core/kernels/sparse:kernels", + "//tensorflow/core/kernels/uniform_quant_ops:kernels", ++ "//tensorflow/core/kernels:embedding_fused_ops", + ] + if_mkl([ + "//tensorflow/core/kernels/mkl:mkl_concat_op", + "//tensorflow/core/kernels/mkl:mkl_dequantize_op", +diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD +index ecd559734ea..97a918ead6d 100644 +--- a/tensorflow/core/grappler/optimizers/BUILD ++++ b/tensorflow/core/grappler/optimizers/BUILD +@@ -880,6 +880,18 @@ tf_cuda_cc_test( + ], + ) + ++alias( ++ name = "is_aarch64", ++ actual = "@local_xla//xla/service/cpu:is_aarch64", ++ visibility = ["//visibility:public"], ++) ++ ++alias( ++ name = "aarch64_and_annc_disabled", ++ actual = "@local_xla//xla/service/cpu:aarch64_and_annc_disabled", ++ visibility = ["//visibility:public"], ++) ++ + tf_kernel_library( + name = "remapper", + srcs = ["remapper.cc"], +@@ -904,7 +916,12 @@ tf_kernel_library( + "//tensorflow/core/grappler/utils:symbolic_shapes", + "//tensorflow/core/grappler/utils:topological_sort", + "@com_google_absl//absl/container:flat_hash_set", +- ] + if_mkl(["//tensorflow/core/graph:mkl_graph_util"]), ++ ] + if_mkl(["//tensorflow/core/graph:mkl_graph_util"]) ++ + select({ ++ ":aarch64_and_annc_disabled": [], ++ ":is_aarch64": ["//tensorflow/core/grappler/optimizers/graph_optimizer:annc_graph_opt"], ++ "//conditions:default": [], ++ }) + ) + + tf_cuda_cc_test( +diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc +index 3c37150f496..2fbde836d1d 100644 +--- a/tensorflow/core/grappler/optimizers/remapper.cc ++++ b/tensorflow/core/grappler/optimizers/remapper.cc +@@ -52,6 +52,7 @@ limitations under the License. 
+ #include "third_party/gpus/cudnn/cudnn.h" + #endif // GOOGLE_CUDA + ++#include "tensorflow/core/grappler/optimizers/graph_optimizer/graph_opt.h" + namespace tensorflow { + namespace grappler { + +@@ -4596,6 +4597,7 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index, + Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) { + GrapplerItem mutable_item = item; ++ annc::run_graph_optimization(&mutable_item.graph); + Status status; + RemapperContext ctx(&mutable_item, &status, cpu_layout_conversion_, + xla_auto_clustering_on_); +diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD +index 22617ef8ab5..f89a084f94c 100644 +--- a/tensorflow/core/kernels/BUILD ++++ b/tensorflow/core/kernels/BUILD +@@ -3655,6 +3655,97 @@ tf_kernel_library( + ]) + [":fft_impl"], + ) + ++tf_kernel_library( ++ name = "embedding_fused_action_id_gather_op", ++ srcs = ["embedding_fused_action_id_gather.cc"], ++ deps = MATH_DEPS, ++) ++ ++tf_kernel_library( ++ name = "embedding_fused_gather_op", ++ srcs = ["embedding_fused_gather.cc"], ++ deps = MATH_DEPS, ++) ++ ++tf_kernel_library( ++ name = "embedding_fused_padding_op", ++ srcs = ["embedding_fused_padding.cc"], ++ deps = MATH_DEPS, ++) ++ ++tf_kernel_library( ++ name = "embedding_fused_sparse_dynamic_stitch_op", ++ srcs = ["embedding_fused_sparse_dynamic_stitch.cc"], ++ deps = MATH_DEPS, ++) ++ ++tf_kernel_library( ++ name = "embedding_fused_reshape_op", ++ srcs = ["embedding_fused_sparse_reshape.cc"], ++ deps = MATH_DEPS + [ ++ ":reshape_util", ++ ], ++) ++ ++tf_kernel_library( ++ name = "embedding_fused_sparse_segment_reduce_op", ++ srcs = ["embedding_fused_sparse_segment_reduce.cc"], ++ deps = MATH_DEPS, ++) ++ ++tf_kernel_library( ++ name = "embedding_fused_sparse_segment_reduce_nonzero_op", ++ srcs = ["embedding_fused_sparse_segment_reduce_nonzero.cc"], ++ deps = MATH_DEPS + ["@com_google_absl//absl/container:flat_hash_map"], ++) ++ ++tf_kernel_library( ++ name = "embedding_fused_sparse_select_op", ++ srcs = ["embedding_fused_sparse_select.cc"], ++ deps = MATH_DEPS, ++) ++ ++cc_library( ++ name = "embedding_fused_ops", ++ deps = [ ++ ":embedding_fused_action_id_gather_op", ++ ":embedding_fused_gather_op", ++ ":embedding_fused_padding_op", ++ ":embedding_fused_sparse_dynamic_stitch_op", ++ ":embedding_fused_reshape_op", ++ ":embedding_fused_sparse_segment_reduce_op", ++ ":embedding_fused_sparse_segment_reduce_nonzero_op", ++ ":embedding_fused_sparse_select_op", ++ ], ++) ++ ++tf_cc_test( ++ name = "embedding_fused_ops_test", ++ size = "small", ++ srcs = [ ++ "embedding_fused_action_id_gather_test.cc", ++ "embedding_fused_sparse_dynamic_stitch_test.cc", ++ "embedding_fused_sparse_segment_reduce_test.cc", ++ "embedding_fused_sparse_segment_reduce_nonzero_test.cc", ++ "embedding_fused_padding_test.cc", ++ "embedding_fused_sparse_select_test.cc", ++ "embedding_fused_gather_test.cc", ++ "embedding_fused_sparse_reshape_test.cc", ++ ], ++ deps = [ ++ ":ops_testutil", ++ ":ops_util", ++ ":embedding_fused_ops", ++ "//tensorflow/core:core_cpu", ++ "//tensorflow/core:framework", ++ "//tensorflow/core:lib", ++ "//tensorflow/core:protos_all_cc", ++ "//tensorflow/core:test", ++ "//tensorflow/core:test_main", ++ "//tensorflow/core:testlib", ++ ], ++) ++ + tf_kernel_library( + name = "reduction_ops", + gpu_srcs = ["reduction_gpu_kernels.cu.h"], +diff --git a/tensorflow/core/ops/BUILD b/tensorflow/core/ops/BUILD +index 91d80b6c2b5..b00b2a5d027 100644 +--- a/tensorflow/core/ops/BUILD ++++ 
b/tensorflow/core/ops/BUILD +@@ -55,6 +55,7 @@ tf_gen_op_libs( + "decode_proto_ops", + "encode_proto_ops", + "experimental_dataset_ops", ++ "embedding_fused_ops", + "filesystem_ops", + "function_ops", + "functional_ops", +@@ -323,6 +324,7 @@ cc_library( + ":training_ops_op_lib", + ":uniform_quant_ops_op_lib", + ":word2vec_ops", ++ ":embedding_fused_ops_op_lib", + ] + select({ + # Non-tpu platforms don't need tpu dependency. + "//tensorflow:chromiumos": [], +diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD +index 6e0ea613435..47c346b4e93 100644 +--- a/third_party/xla/xla/service/cpu/BUILD ++++ b/third_party/xla/xla/service/cpu/BUILD +@@ -198,6 +198,23 @@ cc_library( + ], + ) + ++config_setting( ++ name = "disable_annc", ++ define_values = {"disable_annc": "false"}, ++ visibility = ["//visibility:public"], ++) ++ ++config_setting( ++ name = "is_aarch64", ++ constraint_values = ["@platforms//cpu:aarch64"], ++) ++ ++config_setting( ++ name = "aarch64_and_annc_disabled", ++ constraint_values = ["@platforms//cpu:aarch64"], ++ define_values = {"disable_annc": "true"}, ++) ++ + cc_library( + name = "cpu_compiler_pure", + srcs = ["cpu_compiler.cc"], + + -- 2.33.0