Skip to content

Commit beab4c1

Browse files
committed
Fix wrong scale eps applied
1 parent 7963f9c commit beab4c1

File tree

2 files changed

+78
-1
lines changed

2 files changed

+78
-1
lines changed

Diff for: test/quantization/test_quant_primitives.py

+60
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,66 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
957957
torch.testing.assert_close(expected_quantized, quantized)
958958
torch.testing.assert_close(expected_dequantized, dequantized)
959959

960+
@parameterized.expand(
961+
[
962+
torch.float64,
963+
torch.float32,
964+
torch.bfloat16,
965+
torch.float16,
966+
]
967+
)
968+
def test_choose_qparams_affine_for_inf_scale_reciprocal(self, hp_dtype):
969+
# Fixed by #1770, the test will fail for all the variants
970+
# before that fix, and will pass afterwards.
971+
#
972+
# The scale value must be forcefully clamped, within
973+
# _choose_qparams_affine() function, (that
974+
# choose_qparams_affine() and others call into) to a large
975+
# enough number so that its reciprocal does not become Inf.
976+
# Otherwise during the quantization, by multiplying with scale
977+
# reciprocal, all the values will be quantized to Inf value,
978+
# except from zero value that would produce NaN (0*Inf) as
979+
# quantized value.
980+
#
981+
# The minimal normalized value for given floating point data
982+
# type is given by torch.finfo(hp_dtype).tiny - let's call
983+
# this value "tiny". It could be seen by checking, that for
984+
# all of torch.float64, torch.float32, torch.float16 and
985+
torch.bfloat16, a denormalized number equal to tiny/4
986+
# will produce Inf as its reciprocal.
987+
#
988+
# Thus, to reproduce the problem, one would create a tensor
989+
# with such values that their absolute maximum, after being
990+
# divided with the range of quantized data (that is 57344 for
991+
# torch.float8_e5m2), would produce scale smaller than tiny/4.
992+
# Also, eps parameter should be set to value no greater than
993+
# tiny/4, as scale is clamped from below to that value. With
994+
# such inputs, choose_qparams_affine() will produce Inf as
995+
# scale value.
996+
#
997+
# Note that this may seem like a contrived reproducer. However,
998+
# there are cases with existing code that would pass
999+
# torch.finfo(torch.float32).eps as the eps value, regardless of
1000+
# scale_dtype. The float16 has rather small range, so this
1001+
# value is well below torch.finfo(torch.float32).eps, and for
1002+
# such an eps value, the code below would produce an Inf scale even
1003+
# for float16 tensor that has 0.5 as its maximum value.
1004+
float8_dtype = torch.float8_e5m2
1005+
tiny = torch.finfo(hp_dtype).tiny
1006+
x = torch.tensor([[0, 100 * tiny]], dtype=hp_dtype)
1007+
scale, _ = choose_qparams_affine(
1008+
input=x,
1009+
mapping_type=MappingType.SYMMETRIC,
1010+
block_size=[1, 2],
1011+
target_dtype=float8_dtype,
1012+
eps=tiny / 4,
1013+
scale_dtype=hp_dtype,
1014+
preserve_zero=True,
1015+
zero_point_domain=ZeroPointDomain.NONE,
1016+
)
1017+
scale_reciprocal = scale.reciprocal()
1018+
assert not torch.any(torch.isinf(scale_reciprocal)).item()
1019+
9601020

9611021
if __name__ == "__main__":
9621022
unittest.main()

Diff for: torchao/quantization/quant_primitives.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,7 @@ def _choose_qparams_affine(
856856
3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero`
857857
and `zero_point_domain`
858858
"""
859+
859860
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
860861
assert mapping_type in [
861862
MappingType.SYMMETRIC.name,
@@ -907,6 +908,16 @@ def _choose_qparams_affine(
907908
min_val_neg = min_val
908909
max_val_pos = max_val
909910

911+
# Prevent reciprocal of scale, calculated below, to become Inf.
912+
if torch.is_floating_point(max_val):
913+
# In this case, scale will be calculated below in
914+
# max_val.dtype.
915+
eps = max(eps, torch.finfo(max_val.dtype).tiny)
916+
else:
917+
# In this case, scale will be calculated below in
918+
# torch.float32 dtype.
919+
eps = max(eps, torch.finfo(torch.float32).tiny)
920+
910921
if (
911922
mapping_type == MappingType.SYMMETRIC.name
912923
or mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name
@@ -966,7 +977,13 @@ def _choose_qparams_affine(
966977

967978
if zero_point is not None:
968979
zero_point = zero_point.to(dtype=zero_point_dtype)
969-
return scale.to(dtype=scale_dtype), zero_point
980+
scale = scale.to(dtype=scale_dtype)
981+
if torch.is_floating_point(scale):
982+
# Again, prevent scale reciprocal to become Inf.
983+
scale = scale.clamp(
984+
min=torch.finfo(scale_dtype).tiny, max=torch.finfo(scale_dtype).max
985+
)
986+
return scale, zero_point
970987

971988

972989
def choose_qparams_and_quantize_affine_qqq(

0 commit comments

Comments
 (0)