From bf43770a4193c11fb2a34a5cdb31842ec4982bf1 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 20 Feb 2025 17:44:28 +0100 Subject: [PATCH 1/3] ENH: torch: add type promotion for (uintN, uintM) --- array_api_compat/torch/_aliases.py | 47 ++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 5b20aabc..ed47af78 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -96,6 +96,53 @@ } +try: + # torch >=2.3 + _uint_promotion_table = { + # uints + (torch.uint8, torch.uint16): torch.uint16, + (torch.uint8, torch.uint32): torch.uint32, + (torch.uint8, torch.uint64): torch.uint64, + (torch.uint16, torch.uint8): torch.uint16, + (torch.uint16, torch.uint16): torch.uint16, + (torch.uint16, torch.uint32): torch.uint32, + (torch.uint16, torch.uint64): torch.uint64, + (torch.uint32, torch.uint8): torch.uint32, + (torch.uint32, torch.uint16): torch.uint32, + (torch.uint32, torch.uint32): torch.uint32, + (torch.uint32, torch.uint64): torch.uint64, + (torch.uint64, torch.uint8): torch.uint64, + (torch.uint64, torch.uint16): torch.uint64, + (torch.uint64, torch.uint32): torch.uint64, + (torch.uint64, torch.uint64): torch.uint64, + # ints and uints (mixed sign) + (torch.int8, torch.uint16): torch.int32, + (torch.int8, torch.uint32): torch.int64, + (torch.int16, torch.uint8): torch.int16, + (torch.int16, torch.uint16): torch.int32, + (torch.int16, torch.uint32): torch.int64, + (torch.int32, torch.uint8): torch.int32, + (torch.int32, torch.uint16): torch.int32, + (torch.int32, torch.uint32): torch.int64, + (torch.int64, torch.uint8): torch.int64, + (torch.int64, torch.uint16): torch.int64, + (torch.int64, torch.uint32): torch.int64, + (torch.uint16, torch.int8): torch.int32, + (torch.uint16, torch.int16): torch.int32, + (torch.uint16, torch.int32): torch.int32, + (torch.uint16, torch.int64): torch.int64, + (torch.uint32, torch.int8): torch.int64, + (torch.uint32, torch.int16): torch.int64, + (torch.uint32, torch.int32): torch.int64, + (torch.uint32, torch.int64): torch.int64, + } +except AttributeError: + _uint_promotion_table = {} + pass + +_promotion_table.update(_uint_promotion_table) + + def _two_arg(f): @_wraps(f) def _f(x1, x2, /, **kwargs): From e2dc3ad7053db20e18439c71d0e01adcd997039e Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Feb 2025 17:41:26 +0100 Subject: [PATCH 2/3] ENH: torch: add uintN type to __array_namespace_info__ --- array_api_compat/torch/_info.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py index 34fbcb21..d1cce55b 100644 --- a/array_api_compat/torch/_info.py +++ b/array_api_compat/torch/_info.py @@ -169,16 +169,26 @@ def _dtypes(self, kind): int32 = torch.int32 int64 = torch.int64 uint8 = torch.uint8 - # uint16, uint32, and uint64 are present in newer versions of pytorch, - # but they aren't generally supported by the array API functions, so - # we omit them from this function. + try: + # pytorch >= 2.3 + uint16 = torch.uint16 + uint32 = torch.uint32 + uint64 = torch.uint64 + uint_kinds = { + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + except AttributeError: + uint_kinds = {} + float32 = torch.float32 float64 = torch.float64 complex64 = torch.complex64 complex128 = torch.complex128 if kind is None: - return { + kinds = { "bool": bool, "int8": int8, "int16": int16, @@ -190,6 +200,8 @@ def _dtypes(self, kind): "complex64": complex64, "complex128": complex128, } + kinds.update(uint_kinds) + return kinds if kind == "bool": return {"bool": bool} if kind == "signed integer": @@ -200,17 +212,21 @@ def _dtypes(self, kind): "int64": int64, } if kind == "unsigned integer": - return { + kinds= { "uint8": uint8, } + kinds.update(uint_kinds) + return kinds if kind == "integral": - return { + kinds= { "int8": int8, "int16": int16, "int32": int32, "int64": int64, "uint8": uint8, } + kinds.update(uint_kinds) + return kinds if kind == "real floating": return { "float32": float32, @@ -222,7 +238,7 @@ def _dtypes(self, kind): "complex128": complex128, } if kind == "numeric": - return { + kinds = { "int8": int8, "int16": int16, "int32": int32, @@ -233,6 +249,9 @@ def _dtypes(self, kind): "complex64": complex64, "complex128": complex128, } + kinds.update(uint_kinds) + return kinds + if isinstance(kind, tuple): res = {} for k in kind: From a3649910804142c0c037bea846e8a0fcdefe4fa0 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 26 Feb 2025 10:23:31 +0100 Subject: [PATCH 3/3] CI: test torch uints on CI --- .github/workflows/array-api-tests-torch.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml index 56ab81a3..1b01f755 100644 --- a/.github/workflows/array-api-tests-torch.yml +++ b/.github/workflows/array-api-tests-torch.yml @@ -7,5 +7,3 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: torch - extra-env-vars: | - ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64