diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index d9fdabd5..0b633ce7 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -10,7 +10,7 @@ from hypothesis import assume, reject from hypothesis.strategies import (SearchStrategy, booleans, composite, floats, integers, just, lists, none, one_of, - sampled_from, shared, builds) + sampled_from, shared, builds, permutations) from . import _array_module as xp, api_version from . import array_helpers as ah @@ -148,6 +148,25 @@ def mutually_promotable_dtypes( return one_of(strats).map(tuple) +@composite +def mutually_non_promotable_dtypes( + draw, + max_size: Optional[int] = 2, +) -> Sequence[Tuple[DataType, ...]]: + """Generate a pair of dtypes which cannot be promoted.""" + assert max_size == 2 + + _categories = [ + (xp.bool,), + dh.uint_dtypes + dh.int_dtypes, + dh.real_float_dtypes + dh.complex_dtypes + ] + cat_st = permutations(_categories).map(lambda s: s[:2]) + cat_from, cat_to = draw(cat_st) + from_, to = draw(sampled_from(cat_from)), draw(sampled_from(cat_to)) + return from_, to + + class OnewayPromotableDtypes(NamedTuple): input_dtype: DataType result_dtype: DataType diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index 34c40024..2ae98329 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -204,3 +204,13 @@ def test_isdtype(dtype, kind): def test_result_type(dtypes): out = xp.result_type(*dtypes) ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out") + + +@given(hh.mutually_non_promotable_dtypes(2)) +def test_result_type_false(dtypes): + """Test _very_ strict promotion rules according to the spec. + Conforming array libraries may extend the promotion rules, and + then they'll need to xfail this test. + """ + with pytest.raises(TypeError): + xp.result_type(*dtypes)