From 67d280a4d478c9857bc6016a4840abb3ef467648 Mon Sep 17 00:00:00 2001 From: Jiaxu Zhu Date: Mon, 18 Sep 2023 13:14:22 -0700 Subject: [PATCH] support `torch.sub/sign/abs` in eager mode Summary: As title Achieve by `vizard.quantization.functional. FloatFunctional` and `vizard.quantization.prepare.prepare_eager` Differential Revision: D48377683 --- d2go/quantization/fx.py | 9 ++++++--- d2go/quantization/modeling.py | 7 ++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/d2go/quantization/fx.py b/d2go/quantization/fx.py index 4f3d4166..493a5dd0 100644 --- a/d2go/quantization/fx.py +++ b/d2go/quantization/fx.py @@ -10,16 +10,19 @@ TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) if TORCH_VERSION > (1, 10): - from torch.ao.quantization.quantize import convert + from torch.ao.quantization.quantize import convert, prepare, prepare_qat from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx else: - from torch.quantization.quantize import convert + from torch.quantization.quantize import convert, prepare, prepare_qat from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx @fb_overwritable() def get_prepare_fx_fn(cfg, is_qat): - return prepare_qat_fx if is_qat else prepare_fx + if cfg.QUANTIZATION.EAGER_MODE: + return prepare_qat if is_qat else prepare + else: + return prepare_qat_fx if is_qat else prepare_fx @fb_overwritable() diff --git a/d2go/quantization/modeling.py b/d2go/quantization/modeling.py index bce13345..3a0e96d2 100644 --- a/d2go/quantization/modeling.py +++ b/d2go/quantization/modeling.py @@ -352,11 +352,8 @@ def prepare_fake_quant_model(cfg, model, is_qat, example_input=None): ) model = default_prepare_for_quant(cfg, model) # NOTE: eager model needs to call prepare after `prepare_for_quant` - if is_qat: - torch.ao.quantization.prepare_qat(model, inplace=True) - else: - torch.ao.quantization.prepare(model, inplace=True) - + prepare_fn = get_prepare_fx_fn(cfg, is_qat) + prepare_fn(model, inplace=True) else: # FX graph mode requires the model to be symbolically traceable, swap common # modules like SyncBN to FX-friendly version.