diff --git a/tensorflow/lite/micro/kernels/activations.cc b/tensorflow/lite/micro/kernels/activations.cc index 6772e4765af..f9f0f6afa5e 100644 --- a/tensorflow/lite/micro/kernels/activations.cc +++ b/tensorflow/lite/micro/kernels/activations.cc @@ -54,14 +54,23 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } case kTfLiteInt8: { - tflite::ReluQuantized(data, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData<int8_t>(input), - tflite::micro::GetTensorData<int8_t>(output)); + tflite::ReluQuantized( + data, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData<int8_t>(input), + tflite::micro::GetTensorData<int8_t>(output)); + return kTfLiteOk; + } + case kTfLiteInt16: { + tflite::ReluQuantized( + data, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData<int16_t>(input), + tflite::micro::GetTensorData<int16_t>(output)); + return kTfLiteOk; + } default: { - MicroPrintf("Only float32 is supported currently, got %s", + MicroPrintf("Only float32/int8/int16 is supported currently, got %s", TfLiteTypeGetName(input->type)); return kTfLiteError; } @@ -109,7 +118,7 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } default: { - MicroPrintf("Only float32 is supported currently, got %s", + MicroPrintf("Only float32/int8/int16 is supported currently, got %s", TfLiteTypeGetName(input->type)); return kTfLiteError; } diff --git a/tensorflow/lite/micro/kernels/activations.h b/tensorflow/lite/micro/kernels/activations.h index eaf93c2df26..b0b4fd6e8b8 100644 --- a/tensorflow/lite/micro/kernels/activations.h +++ b/tensorflow/lite/micro/kernels/activations.h @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { @@ -36,9 +37,23 @@ struct Relu6OpData { int32_t zero; }; +template <typename T> void ReluQuantized(const ReluOpData& data, const RuntimeShape& input_shape, - const RuntimeShape& output_shape, const int8_t* input_data, - int8_t* output_data); + const RuntimeShape& output_shape, const T* input_data, + T* output_data) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); + for (int i = 0; i < flat_size; ++i) { + const int32_t val = static_cast<int32_t>(input_data[i]); + int32_t clamped = + data.params.output_offset + + MultiplyByQuantizedMultiplier(val - data.params.input_offset, + data.params.output_multiplier, + data.params.output_shift); + clamped = std::max(data.params.quantized_activation_min, clamped); + clamped = std::min(data.params.quantized_activation_max, clamped); + output_data[i] = static_cast<T>(clamped); + } +} template <typename T> void CalculateReluOpData(const TfLiteTensor* input, TfLiteTensor* output, diff --git a/tensorflow/lite/micro/kernels/activations_common.cc b/tensorflow/lite/micro/kernels/activations_common.cc index 90062447791..05920e1d0f9 100644 --- a/tensorflow/lite/micro/kernels/activations_common.cc +++ b/tensorflow/lite/micro/kernels/activations_common.cc @@ -33,23 +33,6 @@ namespace tflite { const int kActivationsInputTensor = 0; const int kActivationsOutputTensor = 0; -void ReluQuantized(const ReluOpData& data, const RuntimeShape& input_shape, - const RuntimeShape& output_shape, const int8_t* input_data, - int8_t* output_data) { - const int flat_size = MatchingFlatSize(input_shape, output_shape); - for (int i = 0; i < flat_size; ++i) { - const int32_t val = static_cast<int32_t>(input_data[i]); - int32_t clamped = - data.params.output_offset + - MultiplyByQuantizedMultiplier(val - data.params.input_offset, - data.params.output_multiplier, - 
data.params.output_shift); - clamped = std::max(data.params.quantized_activation_min, clamped); - clamped = std::min(data.params.quantized_activation_max, clamped); - output_data[i] = static_cast<int8_t>(clamped); - } -} - template <typename T> void CalculateReluOpData(const TfLiteTensor* input, TfLiteTensor* output, ReluOpData* data) { @@ -116,6 +99,10 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) { if (input->type == kTfLiteInt8) { CalculateReluOpData<int8_t>(input, output, data); + } else if (input->type == kTfLiteInt16) { + TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + CalculateReluOpData<int16_t>(input, output, data); } micro_context->DeallocateTempTfLiteTensor(input); diff --git a/tensorflow/lite/micro/kernels/activations_test.cc b/tensorflow/lite/micro/kernels/activations_test.cc index 479668a164b..bdcfabf6d10 100644 --- a/tensorflow/lite/micro/kernels/activations_test.cc +++ b/tensorflow/lite/micro/kernels/activations_test.cc @@ -129,6 +129,46 @@ void TestReluInt8(int* input_dims_data, const float* input_data, } +void TestReluInt16(int* input_dims_data, const float* input_data, + int16_t* input_data_quantized, const float input_scale, + const int input_zero_point, const float* golden, + int16_t* golden_quantized, int* output_dims_data, + const float output_scale, const int output_zero_point, + int16_t* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + const int output_elements_count = ElementCount(*output_dims); + constexpr int inputs_size = 1; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input_data, input_data_quantized, input_dims, + input_scale, input_zero_point), + CreateQuantizedTensor(output_data, output_dims, output_scale, + output_zero_point), + }; + + int inputs_array_data[] = 
{1, 0}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 1}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + + const TFLMRegistration registration = Register_RELU(); + micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, + outputs_array, + /*builtin_data=*/nullptr); + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); + + Quantize(golden, golden_quantized, output_elements_count, output_scale, + output_zero_point); + + for (int i = 0; i < output_elements_count; ++i) { + TF_LITE_MICRO_EXPECT_EQ(golden_quantized[i], output_data[i]); + } +} + void TestRelu6Int8(int* input_dims_data, const float* input_data, int8_t* input_data_quantized, const float input_scale, const int input_zero_point, const float* golden, @@ -265,6 +305,29 @@ TF_LITE_MICRO_TEST(SimpleReluTestInt8) { output_zero_point, output_data); } +TF_LITE_MICRO_TEST(SimpleReluTestInt16) { + const int elements_count = 10; + + int input_shape[] = {2, 2, 5}; + const float input_data[] = {256, 257, 258, 259, 260, + -256, -257, -258, -259, -260}; + int16_t input_quantized[elements_count]; + int output_shape[] = {2, 2, 5}; + const float golden[] = {256, 257, 258, 259, 260, 0, 0, 0, 0, 0}; + int16_t golden_quantized[elements_count]; + int16_t output_data[elements_count]; + + const float input_scale = 0.5f; + const int input_zero_point = 0; + const float output_scale = 0.5f; + const int output_zero_point = 0; + + tflite::testing::TestReluInt16(input_shape, input_data, input_quantized, + input_scale, input_zero_point, golden, + golden_quantized, output_shape, output_scale, + output_zero_point, output_data); +} + TF_LITE_MICRO_TEST(SimpleRelu6TestInt8) { const int elements_count = 10;