@@ -957,6 +957,66 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
        torch.testing.assert_close(expected_quantized, quantized)
        torch.testing.assert_close(expected_dequantized, dequantized)

+    @parameterized.expand(
+        [
+            torch.float64,
+            torch.float32,
+            torch.bfloat16,
+            torch.float16,
+        ]
+    )
+    def test_choose_qparams_affine_for_inf_scale_reciprocal(self, hp_dtype):
+        # Fixed by #1770; this test will fail for all the variants
+        # before that fix, and will pass afterwards.
+        #
+        # The scale value must be forcefully clamped, within the
+        # _choose_qparams_affine() function (which
+        # choose_qparams_affine() and others call into), to a large
+        # enough number so that its reciprocal does not become Inf.
+        # Otherwise, during quantization, multiplying by the scale
+        # reciprocal quantizes all the values to Inf, except for
+        # zero, which produces NaN (0 * Inf) as its quantized value.
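+        # For instance, multiplication by an Inf reciprocal behaves
+        # like this (an illustrative aside, not part of the test):
+        #   torch.tensor([0.0, 1.0]) * float("inf")  # tensor([nan, inf])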
+        #
+        # The smallest normal value for a given floating-point data
+        # type is given by torch.finfo(hp_dtype).tiny - let's call
+        # this value "tiny". One can check that, for all of
+        # torch.float64, torch.float32, torch.float16 and
+        # torch.bfloat16, the denormal number equal to tiny/4
+        # produces Inf as its reciprocal.
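+        # For example, for torch.float32 (a quick illustrative
+        # check, not executed by this test): tiny is 2**-126, so
+        # 1 / (tiny / 4) = 2**128, which exceeds the largest
+        # representable float32 value and overflows:
+        #   t = torch.finfo(torch.float32).tiny
+        #   torch.tensor(t / 4).reciprocal()  # tensor(inf)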
+        #
+        # Thus, to reproduce the problem, one would create a tensor
+        # whose values' absolute maximum, after being divided by the
+        # range of the quantized data type (which is 57344 for
+        # torch.float8_e5m2), produces a scale smaller than tiny/4.
+        # Also, the eps parameter should be set to a value no greater
+        # than tiny/4, as the scale is clamped from below to eps.
+        # With such inputs, choose_qparams_affine() will produce a
+        # scale whose reciprocal is Inf.
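+        #
+        # Concretely, for the input constructed below, the computed
+        # scale is 100 * tiny / 57344, roughly tiny / 573, which is
+        # well below the tiny/4 threshold.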
+        #
+        # Note that this may seem like a contrived reproducer.
+        # However, there are cases in existing code that pass
+        # torch.finfo(torch.float32).eps as the eps value, regardless
+        # of scale_dtype. float16 has a rather small range, so
+        # torch.finfo(torch.float32).eps is well below float16's
+        # tiny/4, and for such an eps value, the code below would
+        # produce an Inf scale reciprocal even for a float16 tensor
+        # that has 0.5 as its maximum value.
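+        # As a rough illustration of that float16 case (assumed
+        # values, not asserted here): the scale would be
+        # 0.5 / 57344 ~= 8.7e-6, which exceeds
+        # torch.finfo(torch.float32).eps (~1.2e-7) so the clamp does
+        # not raise it, yet its reciprocal (~114688) overflows the
+        # float16 maximum of 65504 to Inf.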
+        float8_dtype = torch.float8_e5m2
+        tiny = torch.finfo(hp_dtype).tiny
+        x = torch.tensor([[0, 100 * tiny]], dtype=hp_dtype)
+        scale, _ = choose_qparams_affine(
+            input=x,
+            mapping_type=MappingType.SYMMETRIC,
+            block_size=[1, 2],
+            target_dtype=float8_dtype,
+            eps=tiny / 4,
+            scale_dtype=hp_dtype,
+            preserve_zero=True,
+            zero_point_domain=ZeroPointDomain.NONE,
+        )
+        scale_reciprocal = scale.reciprocal()
+        assert not torch.any(torch.isinf(scale_reciprocal)).item()
+

if __name__ == "__main__":
    unittest.main()