Skip to content

Commit

Permalink
Add FakeQuantize op support in TS transformations (#17243)
Browse files Browse the repository at this point in the history
* Add FQ op support in TS transformations

* codestyle

* Mark FQ as supported op in the TS ops list
  • Loading branch information
itikhono authored Apr 27, 2023
1 parent 22bb3af commit 40bf400
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "itt.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/prelu.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/op_types.hpp"
Expand All @@ -25,7 +26,8 @@ TSBinaryForward::TSBinaryForward() {
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
op::util::BinaryElementwiseComparison,
op::util::BinaryElementwiseLogical,
ov::op::v0::PRelu>([](const Output<Node>& output) -> bool {
ov::op::v0::PRelu,
ov::op::v0::FakeQuantize>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && IfNodeHasTransposeInputs(output);
});

Expand Down Expand Up @@ -62,7 +64,8 @@ TSBinaryBackward::TSBinaryBackward() {
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
op::util::BinaryElementwiseComparison,
op::util::BinaryElementwiseLogical,
ov::op::v0::PRelu>([](const Output<Node>& output) -> bool {
ov::op::v0::PRelu,
ov::op::v0::FakeQuantize>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ bool CanPropagateForwardThrough(Node* node) {
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Reshape, node)
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Unsqueeze, node)
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Transpose, node)
CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::FakeQuantize, node)

return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,22 @@ FactoryPtr CreateReshapeFactory(const std::string& type_name) {
return std::make_shared<ReshapeFactory>(type_name);
}

class FakeQuantizeFactory : public IFactory {
public:
explicit FakeQuantizeFactory(const std::string& type_name) : IFactory(type_name) {}
NodePtr create(const OutputVector& parent_nodes) const override {
return std::make_shared<FakeQuantize>(parent_nodes[0],
parent_nodes[1],
parent_nodes[2],
parent_nodes[3],
parent_nodes[4],
128);
}
};

FactoryPtr CreateFakeQuantizeFactory(const std::string& type_name) {
return std::make_shared<FakeQuantizeFactory>(type_name);
}
// ----------------------------------------------------------------------------

#undef CREATE_UNARY_FACTORY
Expand Down Expand Up @@ -255,6 +271,9 @@ FactoryPtr CreateReshapeFactory(const std::string& type_name) {
#undef CREATE_RESHAPE_FACTORY
#define CREATE_RESHAPE_FACTORY(type_name) CreateReshapeFactory(#type_name)

#undef CREATE_FQ_FACTORY
#define CREATE_FQ_FACTORY(type_name) common::CreateFakeQuantizeFactory(#type_name)

// ----------------------------------------------------------------------------

vector<FactoryPtr> unary_factories = {
Expand Down Expand Up @@ -393,6 +412,42 @@ auto test_forward_binary = []() {

INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryForward, TSTestFixture, test_forward_binary());

auto test_forward_fq = []() {
TestCase test_case;

// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSBinaryForward);
test_case.num_main_ops = {1, 10};
test_case.inputs_to_main = {
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {55, 55, 96, 1}),
parameter(element::f32, {1}),
parameter(element::f32, {55, 1, 1, 1}),
parameter(element::f32, {55, 55, 1, 1}),
};

// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = {CREATE_FQ_FACTORY(FakeQuantize)};
test_case.model.model_template = create_model;

// Reference model description:
auto set_unsqueeze_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec = out_vec;
auto indices = make_shared<Constant>(element::i64, Shape{3}, std::vector<int64_t>{0, 1, 2});
new_out_vec[2] = make_shared<Unsqueeze>(out_vec[2], indices);
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{set_unsqueeze_for, set_transpose_for}, {{2}, {1, 2, 3, 4}}};
test_case.model_ref.main_op = {CREATE_FQ_FACTORY(FakeQuantize)};
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.model_template = create_model;

return wrapper(test_case);
};

INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonFQForward, TSTestFixture, test_forward_fq());

auto test_forward_concat = []() {
TestCase test_case;

Expand Down Expand Up @@ -867,6 +922,42 @@ auto test_backward_binary = []() {

INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryBackward, TSTestFixture, test_backward_binary());

auto test_backward_fq = []() {
TestCase test_case;

// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSBinaryBackward);
test_case.num_main_ops = {1, 10};
test_case.inputs_to_main = {
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {1}),
parameter(element::f32, {1, 96, 55, 1}),
parameter(element::f32, {1, 96, 1, 1}),
};

// Test model description:
test_case.model.main_op = {CREATE_FQ_FACTORY(FakeQuantize)};
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = create_model;

auto set_unsqueeze_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec = out_vec;
auto indices = make_shared<Constant>(element::i64, Shape{3}, std::vector<int64_t>{0, 1, 2});
new_out_vec[2] = make_shared<Unsqueeze>(out_vec[2], indices);
return new_out_vec;
};

// Reference model description:
test_case.model_ref.preprocess_inputs_to_main = {{set_unsqueeze_for, set_transpose_for}, {{2}, {0, 1, 2, 3, 4}}};
test_case.model_ref.main_op = {CREATE_FQ_FACTORY(FakeQuantize)};
test_case.model_ref.model_template = create_model;

return wrapper(test_case);
};

INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonFQBackward, TSTestFixture, test_backward_fq());

auto test_backward_concat = []() {
TestCase test_case;

Expand Down

0 comments on commit 40bf400

Please sign in to comment.