From 2ef228328f327cb4bddbdbfcdb5476481c8a55b8 Mon Sep 17 00:00:00 2001
From: Rob Suderman
Date: Thu, 25 Jan 2024 16:40:21 -0800
Subject: [PATCH] [torch] `torch.dequantize` for per channel tensors to
 `linalg` (#2769)

Support lowering dequantization of per-channel tensors from the `torch`
dialect to a `linalg` decomposition. Tested via a numerical `torch` test.
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  53 ++++++++
 .../TorchToLinalg/Uncategorized.cpp           | 114 ++++++++++++++++--
 .../Transforms/AbstractInterpLibrary.cpp      |  32 +++++
 .../base_lazy_backend/shape_inference.cpp     |  20 +++
 projects/pt1/e2e_testing/xfail_sets.py        |   1 +
 .../build_tools/abstract_interp_lib_gen.py    |  17 +++
 .../build_tools/torch_ods_gen.py              |   2 +
 .../test_suite/elementwise.py                 |  27 +++++
 8 files changed, 258 insertions(+), 8 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index a46c79acb941..c09900ce8ecc 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -14465,6 +14465,33 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
   }];
 }
 
+def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$scales,
+    AnyTorchTensorType:$zero_points,
+    Torch_IntType:$axis,
+    Torch_IntType:$dtype
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenQuantizePerChannelOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
+    }
+    void AtenQuantizePerChannelOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
+    }
+  }];
+}
+
 def Torch_AtenQuantizePerTensorOp : Torch_Op<"aten.quantize_per_tensor", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -14560,6 +14587,32 @@ def Torch_AtenIntReprOp : Torch_Op<"aten.int_repr", [
   }];
 }
 
+def Torch_Aten_MakePerChannelQuantizedTensorOp : Torch_Op<"aten._make_per_channel_quantized_tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$scale,
+    AnyTorchTensorType:$zero_point,
+    Torch_IntType:$axis
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_MakePerChannelQuantizedTensorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void Aten_MakePerChannelQuantizedTensorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_Aten_MakePerTensorQuantizedTensorOp : Torch_Op<"aten._make_per_tensor_quantized_tensor", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 593afeb1aa84..9ff4c63741b2 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1344,7 +1344,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     auto makeQTensor =
         qtensor.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>();
     if (!makeQTensor) {
-      op->emitError(
+      op->emitWarning(
           "unimplemented: dequantizing tensor of unknown scale / zero-point");
       return nullptr;
     }
@@ -2221,16 +2221,109 @@ class ConvertAtenIntReprOp : public OpConversionPattern<AtenIntReprOp> {
 } // namespace
 
 namespace {
-class ConvertMakePerTensorQuantizedTensorOp
-    : public OpConversionPattern<Aten_MakePerTensorQuantizedTensorOp> {
+class ConvertDequantizePerChannel
+    : public OpConversionPattern<AtenDequantizeSelfOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(Aten_MakePerTensorQuantizedTensorOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenDequantizeSelfOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    RankedTensorType resultType = getTypeConverter()
-                                      ->convertType(op->getResult(0).getType())
-                                      .cast<RankedTensorType>();
+    auto loc = op.getLoc();
+    auto qoperand = op.getOperand();
+    auto make = qoperand.getDefiningOp<Aten_MakePerChannelQuantizedTensorOp>();
+    if (!make) {
+      llvm::errs() << "Did not find make per channel\n";
+      return rewriter.notifyMatchFailure(op, "did not find per channel qint");
+    }
+
+    auto converter = getTypeConverter();
+    auto operand = make.getOperand(0);
+    auto scale = make.getScale();
+    auto zeropoint = make.getZeroPoint();
+    auto axis = make.getAxis();
+
+    IntegerAttr axisAttr;
+    if (!matchPattern(axis, m_Constant(&axisAttr))) {
+      return failure();
+    }
+
+    auto operandDTy = operand.getType().cast<ValueTensorType>().getDtype();
+    auto zeropointDTy = zeropoint.getType().cast<ValueTensorType>().getDtype();
+    operand = converter->materializeTargetConversion(
+        rewriter, loc, converter->convertType(operand.getType()), operand);
+    scale = converter->materializeTargetConversion(
+        rewriter, loc, converter->convertType(scale.getType()), scale);
+    zeropoint = converter->materializeTargetConversion(
+        rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint);
+
+    auto resultType = converter->convertType(op->getResult(0).getType())
+                          .cast<RankedTensorType>();
+
+    llvm::SmallVector<Value> dynSizes;
+    for (auto [index, dim] : llvm::enumerate(resultType.getShape())) {
+      if (ShapedType::isDynamic(dim)) {
+        dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, operand, index));
+      }
+    }
+
+    llvm::SmallVector<utils::IteratorType> iterators(
+        resultType.getRank(), utils::IteratorType::parallel);
+    llvm::SmallVector<AffineMap> maps(
+        4, {rewriter.getMultiDimIdentityMap(resultType.getRank())});
+    auto broadcastMap = AffineMap::get(
+        resultType.getRank(), /*symbolCount=*/0,
+        {rewriter.getAffineDimExpr(axisAttr.getInt())}, rewriter.getContext());
+    maps[1] = broadcastMap;
+    maps[2] = broadcastMap;
+
+    auto empty =
+        rewriter.create<tensor::EmptyOp>(op.getLoc(), resultType, dynSizes);
+    auto linalgOp = rewriter.create<linalg::GenericOp>(
+        loc, resultType, ValueRange{operand, scale, zeropoint},
+        ValueRange{empty}, maps, iterators,
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          Value operand = args[0];
+          Value scale = args[1];
+          Value zeropoint = args[2];
+          if (operandDTy.isUnsignedInteger(8)) {
+            operand = b.create<arith::ExtUIOp>(loc, b.getI32Type(), operand);
+          } else if (operandDTy.isSignedInteger(8)) {
+            operand = b.create<arith::ExtSIOp>(loc, b.getI32Type(), operand);
+          }
+
+          if (zeropointDTy.isUnsignedInteger(8)) {
+            zeropoint =
+                b.create<arith::ExtUIOp>(loc, b.getI32Type(), zeropoint);
+          } else if (zeropointDTy.isSignedInteger(8)) {
+            zeropoint =
+                b.create<arith::ExtSIOp>(loc, b.getI32Type(), zeropoint);
+          }
+
+          Value sub = rewriter.create<arith::SubIOp>(loc, operand, zeropoint);
+          Value fp =
+              rewriter.create<arith::SIToFPOp>(loc, args[3].getType(), sub);
+          Value mul = rewriter.create<arith::MulFOp>(loc, fp, scale);
+          b.create<linalg::YieldOp>(loc, mul);
+        });
+    rewriter.replaceOp(op, linalgOp.getResults());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+
+template <typename OpTy>
+class ConvertCastEquivalentOp : public OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+  using OpAdaptor = typename OpTy::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(OpTy op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto converter = this->getTypeConverter();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        converter->convertType(op->getResult(0).getType()));
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
                                                 adaptor.getSelf());
     return success();
   }
 };
@@ -2283,6 +2376,11 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   target.addIllegalOp<AtenIntReprOp>();
   patterns.add<ConvertAtenIntReprOp>(typeConverter, context);
   target.addIllegalOp<Aten_MakePerTensorQuantizedTensorOp>();
-  patterns.add<ConvertMakePerTensorQuantizedTensorOp>(typeConverter, context);
+  patterns.add<ConvertCastEquivalentOp<Aten_MakePerTensorQuantizedTensorOp>>(
+      typeConverter, context);
+  target.addIllegalOp<Aten_MakePerChannelQuantizedTensorOp>();
+  patterns.add<ConvertCastEquivalentOp<Aten_MakePerChannelQuantizedTensorOp>>(
+      typeConverter, context);
   target.addIllegalOp<AtenDequantizeSelfOp>();
+  patterns.add<ConvertDequantizePerChannel>(typeConverter, context);
 }
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 590bea8d7176..bb9717303e6b 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6549,6 +6549,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.quantize_per_channel\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.quantize_per_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -6565,6 +6569,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten._make_per_channel_quantized_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
@@ -12632,6 +12640,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_channel\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n"
+"    return %arg4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
 "    return %arg3 : !torch.int\n"
 "  }\n"
@@ -12664,6 +12675,27 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten._make_per_channel_quantized_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int) -> !torch.int {\n"
+"    %int14 = torch.constant.int 14\n"
+"    %int12 = torch.constant.int 12\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int13 = torch.constant.int 13\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.int) {\n"
+"      torch.prim.If.yield %int13 : !torch.int\n"
+"    } else {\n"
+"      %3 = torch.aten.eq.int %0#1, %int1 : !torch.int, !torch.int -> !torch.bool\n"
+"      %4 = torch.prim.If %3 -> (!torch.int) {\n"
+"        torch.prim.If.yield %int12 : !torch.int\n"
+"      } else {\n"
+"        torch.prim.If.yield %int14 : !torch.int\n"
+"      }\n"
+"      torch.prim.If.yield %4 : !torch.int\n"
+"    }\n"
+"    return %2 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.int) -> !torch.int {\n"
 "    %int14 = torch.constant.int 14\n"
 "    %int12 = torch.constant.int 12\n"
diff --git a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp
index ff43359ebe80..325e89e14d5e 100644
--- a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp
+++ b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp
@@ -39,6 +39,20 @@ std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
 
+std::vector<torch::lazy::Shape>
+compute_shape__make_per_channel_quantized_tensor(const at::Tensor &self,
+                                                 const at::Tensor &scale,
+                                                 const at::Tensor &zero_point,
+                                                 int64_t axis) {
+  if (self.scalar_type() == at::kChar)
+    return {Shape(at::kQInt8, self.sizes().vec())};
+  if (self.scalar_type() == at::kByte)
+    return {Shape(at::kQUInt8, self.sizes().vec())};
+  if (self.scalar_type() == at::kInt)
+    return {Shape(at::kQInt32, self.sizes().vec())};
+  assert(false);
+}
+
 std::vector<torch::lazy::Shape> compute_shape__make_per_tensor_quantized_tensor(
     const at::Tensor &self, double scale, int64_t zero_point) {
   if (self.scalar_type() == at::kChar)
@@ -75,6 +89,12 @@ std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor& self) {
   return {Shape(at::kBool, self.sizes().vec())};
 }
 
+std::vector<torch::lazy::Shape> compute_shape_quantize_per_channel(
+    const at::Tensor &self, const at::Tensor &scales,
+    const at::Tensor &zero_points, int64_t axis, at::ScalarType dtype) {
+  return {Shape(dtype, self.sizes().vec())};
+}
+
 std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
     const at::Tensor& self, at::IntArrayRef kernel_size,
     at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index f0261b16f6af..f43c325069ce 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -313,6 +313,7 @@
     "GroupNormNoWeightAndBiasModule_basic",
 
     # Dynamo does not support tracing quantized tensors
+    "ElementwiseDequantizePerChannelModule_basic",
    "ElementwiseDequantizePerTensorModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",
     "AtenMmQuint8_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 28e87cc60990..91e98d99c9ff 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -251,6 +251,9 @@ def aten〇clamp_max〡shape(self: List[int], max: float) -> List[int]:
 def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇quantize_per_channel〡shape(self: List[int], scales: List[int], zero_points: List[int], axis: int, dtype: int) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇quantize_per_tensor〡shape(self: List[int], scale: float, zero_point: int, dtype: int) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -263,6 +266,9 @@ def aten〇dequantize〇tensor〡shape(qtensor: List[int]) -> List[int]:
 def aten〇int_repr〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇_make_per_channel_quantized_tensor〡shape(self: List[int], scale: List[int], zero_point: List[int], axis: int) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇_make_per_tensor_quantized_tensor〡shape(self: List[int], scale: float, zero_point: int) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -4280,6 +4286,9 @@ def prims〇collapse〡dtype(a_rank_dtype: Tuple[int, int], start: int, end: int
     return a_dtype
 
 
+def aten〇quantize_per_channel〡dtype(self_rank_dtype: Tuple[int, int], scales_rank_dtype: Tuple[int, int], zero_points_rank_dtype: Tuple[int, int], axis: int, dtype: int) -> int:
+    return dtype
+
 def aten〇quantize_per_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int, dtype: int) -> int:
     return dtype
 
@@ -4297,6 +4306,14 @@ def aten〇int_repr〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
         return torch.int8
     return torch.int32
 
+def aten〇_make_per_channel_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], axis: int) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    if (self_dtype == torch.uint8):
+        return torch.quint8
+    if (self_dtype == torch.int8):
+        return torch.qint8
+    return torch.qint32
+
 def aten〇_make_per_tensor_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int) -> int:
     self_rank, self_dtype = self_rank_dtype
     if (self_dtype == torch.uint8):
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index ae4c608c6de7..3b930c20e79d 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -820,10 +820,12 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
 
     # quantized ops
+    emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)")
     emit("aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)")
     emit("aten::dequantize.self : (Tensor) -> (Tensor)")
     emit("aten::dequantize.tensor : (Tensor) -> (Tensor)")
     emit("aten::int_repr : (Tensor) -> (Tensor)")
+    emit("aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)")
     emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)")
 
     # ==========================================================================
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index a422772fc298..26eac617a4a6 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -4328,6 +4328,33 @@ def ElementwiseDequantizePerTensorModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ElementwiseDequantizePerChannelModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 4], torch.int8, True),
+        ([4], torch.int8, True),
+        ([4], torch.float, True),
+    ])
+    def forward(self, x, zeropoint, scale):
+        qx = torch._make_per_channel_quantized_tensor(x, scale, zeropoint, axis=1)
+        qx = torch.dequantize(qx)
+        return qx
+
+@register_test_case(module_factory=lambda: ElementwiseDequantizePerChannelModule())
+def ElementwiseDequantizePerChannelModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(3, 4, low=-128, high=127).to(torch.int8),
+        tu.randint(4, low=-128, high=127).to(torch.int8),
+        tu.rand(4)
+    )
+
+# ==============================================================================
+
 class GluStaticModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
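
A minimal eager-PyTorch sketch (not part of the patch) of the semantics the new
lowering implements, mirroring the ElementwiseDequantizePerChannelModule test
above: per-channel dequantization computes (int_value - zero_point) * scale,
with scale and zero_point broadcast along the quantization axis. It assumes, as
the test does, that torch._make_per_channel_quantized_tensor accepts int8
zero_points and float32 scales.

    import torch

    x = torch.randint(-128, 127, (3, 4), dtype=torch.int8)
    zeropoint = torch.randint(-128, 127, (4,), dtype=torch.int8)
    scale = torch.rand(4)

    # Wrap the raw int8 values as a per-channel quantized tensor on axis 1,
    # then dequantize back to float.
    qx = torch._make_per_channel_quantized_tensor(x, scale, zeropoint, axis=1)
    dq = torch.dequantize(qx)

    # Reference computed from the definition: widen, subtract the zero point,
    # convert to float, and scale -- the same steps the linalg.generic payload
    # in ConvertDequantizePerChannel performs.
    ref = (x.to(torch.int32) - zeropoint.to(torch.int32)).to(torch.float32) * scale
    assert torch.allclose(dq, ref)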