Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize QLinearSoftmax Transpose #22849

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1654,14 +1654,14 @@ static bool HandleSplit(HandlerArgs& args) {

constexpr HandlerInfo split_handler = {&FirstInput, &HandleSplit};

// Handles Concat: transposes all inputs and the output, updating the axis attribute.
// Non-static (declared in the optimizer header) so extended handlers, e.g. the
// com.microsoft.QLinearConcat handler, can reuse it directly.
bool HandleConcat(HandlerArgs& args) {
  return HandleSimpleNodeWithAxis(args);
}

constexpr HandlerInfo concat_handler = {&AllInputs, &HandleConcat};

// Handles Softmax, Hardmax, and LogSoftmax
static bool HandleSoftHardMax(HandlerArgs& args) {
bool HandleSoftHardMax(HandlerArgs& args) {
if (args.ctx.opset >= 13) {
return HandleSimpleNodeWithAxis(args, /*default_axis*/ -1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ bool HandleSimpleNodeBroadcast(HandlerArgs& args);
// Transposes all inputs and all outputs. Updates axis attribute.
bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional<int64_t> default_axis = std::nullopt);

bool HandleConcat(HandlerArgs& args);
bool HandleSoftHardMax(HandlerArgs& args);

// base handlers that are used by extended handlers. add from transpose_optimizer.cc as needed.
bool HandleReduceOps(HandlerArgs& args);
bool HandleResize([[maybe_unused]] HandlerArgs& args);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@ static bool EPAwareHandleResize(HandlerArgs& args) {

constexpr HandlerInfo ep_aware_resize_handler = {&FirstInput, &EPAwareHandleResize};

static bool HandleQLinearConcat(HandlerArgs& args) {
return HandleSimpleNodeWithAxis(args);
}

std::vector<size_t> QLinearConcatInputs(OptimizerCtx& ctx, api::NodeRef& node) {
(void)ctx;
std::vector<size_t> indices;
Expand All @@ -48,19 +44,15 @@ std::vector<size_t> QLinearConcatInputs(OptimizerCtx& ctx, api::NodeRef& node) {
return indices;
}

constexpr HandlerInfo q_linear_concat_handler = {&QLinearConcatInputs, &HandleQLinearConcat};

static bool HandleQLinearBinaryOp(HandlerArgs& args) {
return HandleSimpleNodeBroadcast(args);
}
constexpr HandlerInfo q_linear_concat_handler = {&QLinearConcatInputs, &HandleConcat};

// Returns the indices of the inputs that carry transposable data for a
// QLinearBinaryOp node. The full input list is
// [A, A_scale, A_zero_point, B, B_scale, B_zero_point, C_scale, C_zero_point];
// only the data tensors A (index 0) and B (index 3) participate in the transpose.
std::vector<size_t> QLinearBinaryOpInputs(OptimizerCtx&, api::NodeRef&) {
  std::vector<size_t> data_input_indices{0, 3};
  return data_input_indices;
}

constexpr HandlerInfo q_linear_binary_op_handler = {&QLinearBinaryOpInputs, &HandleQLinearBinaryOp};
constexpr HandlerInfo q_linear_binary_op_handler = {&QLinearBinaryOpInputs, &HandleSimpleNodeBroadcast};

static bool HandleQLinearPoolOp(HandlerArgs& args) {
// Swap between channel first/last variants. Only works for applicable values of perm.
Expand Down Expand Up @@ -129,6 +121,7 @@ constexpr HandlerInfo max_pool_op_handler = {&FirstInput, &HandleMaxPool};

constexpr HandlerInfo node_1_inp_handler = {&FirstInput, &HandleSimpleNode};
constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOps};
constexpr HandlerInfo soft_hard_max_handler = {&FirstInput, &HandleSoftHardMax};
constexpr HandlerInfo contrib_quantize_dequantize_linear_handler = {&FirstInput,
&HandleContribQuantizeDequantizeLinear};

Expand All @@ -148,6 +141,7 @@ const HandlerMap& OrtExtendedHandlers() {
{"com.microsoft.QLinearMul", q_linear_binary_op_handler},
{"com.microsoft.QLinearReduceMean", reduce_op_handler},
{"com.microsoft.QLinearSigmoid", node_1_inp_handler},
{"com.microsoft.QLinearSoftmax", soft_hard_max_handler},
};

return map;
Expand Down
41 changes: 41 additions & 0 deletions onnxruntime/test/optimizer/transpose_optimizer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "test/optimizer/graph_transform_test_builder.h"
#include "test/providers/internal_testing/internal_testing_execution_provider.h"
#include "test/util/include/asserts.h"
#include "test/util/include/default_providers.h"
#include "test/util/include/inference_session_wrapper.h"
#include "test/util/include/test_utils.h"

Expand Down Expand Up @@ -3800,6 +3801,46 @@ TEST(TransposeOptimizerTests, TestCast) {
/*opset_version*/ {15, 18});
}

// Verifies that the transpose optimizer handles com.microsoft.QLinearSoftmax:
// a Transpose (NHWC->NCHW) feeding QLinearSoftmax followed by the inverse
// Transpose (NCHW->NHWC) should be optimized away entirely, leaving an
// estimated transpose cost of 0.
// NOTE: removed GitHub review-UI text that had been pasted into the body
// ("... marked this conversation as resolved." / "Show resolved Hide resolved"),
// which made the function uncompilable.
TEST(TransposeOptimizerTests, TestQLinearSoftmax) {
  auto build_test_case_1 = [&](ModelTestBuilder& builder) {
    // uint8 NHWC input; scales/zero points are scalar initializers as required
    // by the QLinearSoftmax contrib-op signature.
    auto* input0_arg = MakeInput<uint8_t>(builder, std::nullopt, {1, 384, 384, 21}, 0, 255);
    auto* transpose_1_out_0 = builder.MakeIntermediate();
    auto* input_x_scale = builder.MakeScalarInitializer<float>(0.5086354613304138);
    auto* input_x_zero_point = builder.MakeScalarInitializer<uint8_t>(74);
    auto* input_y_scale = builder.MakeScalarInitializer<float>(0.003921568859368563);
    auto* input_y_zero_point = builder.MakeScalarInitializer<uint8_t>(0);
    auto* qlinearsoftmax_1_out_0 = builder.MakeIntermediate();
    auto* transpose_2_out_0 = builder.MakeOutput();

    auto& transpose_1 = builder.AddNode("Transpose", {input0_arg}, {transpose_1_out_0});
    transpose_1.AddAttribute("perm", std::vector<int64_t>{0, 3, 1, 2});
    auto& qlinearsoftmax_1 = builder.AddNode("QLinearSoftmax",
                                             {transpose_1_out_0, input_x_scale, input_x_zero_point,
                                              input_y_scale, input_y_zero_point},
                                             {qlinearsoftmax_1_out_0}, kMSDomain);
    // axis=1 matches the NCHW layout produced by transpose_1.
    qlinearsoftmax_1.AddAttribute("axis", static_cast<int64_t>(1));
    qlinearsoftmax_1.AddAttribute("opset", static_cast<int64_t>(13));
    auto& transpose_2 = builder.AddNode("Transpose", {qlinearsoftmax_1_out_0}, {transpose_2_out_0});
    transpose_2.AddAttribute("perm", std::vector<int64_t>{0, 2, 3, 1});
  };

  auto check_optimized_graph_1 = [&](InferenceSessionWrapper& session) {
    // Both Transpose nodes should have been eliminated.
    int transpose_cost = EstimateTransposeCost(session.GetGraph());
    EXPECT_EQ(transpose_cost, 0);
  };

  TransformerTester(build_test_case_1,
                    check_optimized_graph_1,
                    TransformerLevel::Level2,
                    TransformerLevel::Level3,
                    /*opset_version*/ 13,
                    /*per_sample_tolerance*/ 0.0,
                    /*relative_per_sample_tolerance*/ 0.0,
                    /*transformer*/ nullptr,
                    /*add_session_options*/ {},
                    /*disabled_optimizers*/ {},
                    /*ep*/ DefaultCpuExecutionProvider());
}

TEST(TransposeOptimizerTests, TestBroadcastReusedInputs) {
auto build_test_case_1 = [&](ModelTestBuilder& builder) {
auto* input0_arg = MakeInput<float>(builder, {{-1, -1, 3, 4}}, {1, 2, 3, 4}, 0.0, 1.0);
Expand Down
Loading