
Commit a7ac730

xla: update patches for current master, add ROCm cherry-picks
based on openxla/xla#23574
1 parent d100d6d commit a7ac730

7 files changed (+276 -64 lines)

openxla/patches/20240605-001-Added-FFI-handler-registration-API-to-the-FFI-PjRt-extension.patch → openxla/patches/0001-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch
+7 -6
@@ -1,7 +1,7 @@
-From c79f202be6fde802b4e5d697a5925d7eccea3d25 Mon Sep 17 00:00:00 2001
+From e888ca450bbc58331c81e7537dd3f2b933f92df7 Mon Sep 17 00:00:00 2001
 From: Hugo Mano <[email protected]>
 Date: Wed, 5 Feb 2025 19:25:03 +0100
-Subject: [PATCH] Added FFI handler registration API to the FFI PjRt
+Subject: [PATCH 1/6] Added FFI handler registration API to the FFI PjRt

 PR: https://github.com/openxla/xla/pull/13420
 ---
@@ -11,7 +11,7 @@ PR: https://github.com/openxla/xla/pull/13420
 3 files changed, 54 insertions(+), 2 deletions(-)

 diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD
-index ad1b3987fe..0598281ad1 100644
+index ad2ed95bce..0e7c35c30f 100644
 --- a/xla/pjrt/c/BUILD
 +++ b/xla/pjrt/c/BUILD
 @@ -69,7 +69,12 @@ cc_library(
@@ -28,10 +28,10 @@ index ad1b3987fe..0598281ad1 100644
 )

 diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
-index c5766f2a19..3d74e7cbf3 100644
+index a33bd4aa9c..3309194538 100644
 --- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h
 +++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
-@@ -67,12 +67,28 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data);
+@@ -66,12 +66,28 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data);
 // Adds a user data to the execute context.
 typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args);

@@ -128,4 +128,5 @@ index 0375b39d0b..3527a0756e 100644
 }

 --
-2.39.5 (Apple Git-154)
+2.43.0
+
openxla/patches/20240901-001-Various-macOS-QOL-enchancements.patch → openxla/patches/0002-Various-macOS-QOL-enchancements.patch
+8 -8
@@ -1,7 +1,7 @@
-From f59ff33447ea8312ed8c4bf7e87c7d5409f0d2b9 Mon Sep 17 00:00:00 2001
+From ba424272d5e5e0b139d05d530faad7ff1fbb6af5 Mon Sep 17 00:00:00 2001
 From: Hugo Mano <[email protected]>
 Date: Wed, 12 Feb 2025 13:10:04 +0100
-Subject: [PATCH] Various macOS QOL enchancements
+Subject: [PATCH 2/6] Various macOS QOL enchancements

 This PR adds various small quality of life improvements to macOS builds:

@@ -18,10 +18,10 @@ Co-authored-by: Steeve Morin <[email protected]>
 1 file changed, 10 insertions(+), 8 deletions(-)

 diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD
-index a0485b6a43..6f67ee6b78 100644
+index 0e7c35c30f..b3de80a5e7 100644
 --- a/xla/pjrt/c/BUILD
 +++ b/xla/pjrt/c/BUILD
-@@ -321,9 +321,14 @@ cc_library(
+@@ -326,9 +326,14 @@ cc_library(

 # PJRT CPU plugin.
 xla_cc_binary(
@@ -38,7 +38,7 @@ index a0485b6a43..6f67ee6b78 100644
 [
 "-Wl,--version-script,$(location :pjrt_c_api_cpu_version_script.lds)",
 "-Wl,--no-undefined",
-@@ -336,10 +341,7 @@ xla_cc_binary(
+@@ -341,10 +346,7 @@ xla_cc_binary(
 "notsan",
 ],
 visibility = ["//visibility:public"],
@@ -50,7 +50,7 @@ index a0485b6a43..6f67ee6b78 100644
 )

 cc_library(
-@@ -408,7 +410,8 @@ cc_library(
+@@ -413,7 +415,8 @@ cc_library(

 # PJRT GPU plugin. Can be configured to be built for CUDA or ROCM.
 xla_cc_binary(
@@ -60,7 +60,7 @@ index a0485b6a43..6f67ee6b78 100644
 linkopts = [
 "-Wl,--version-script,$(location :pjrt_c_api_gpu_version_script.lds)",
 "-Wl,--no-undefined",
-@@ -422,7 +425,6 @@ xla_cc_binary(
+@@ -427,7 +430,6 @@ xla_cc_binary(
 ],
 deps = [
 ":pjrt_c_api_gpu",
@@ -69,5 +69,5 @@ index a0485b6a43..6f67ee6b78 100644
 ] + if_cuda_is_configured([
 "//xla/stream_executor:cuda_platform",
 --
-2.39.5 (Apple Git-154)
+2.43.0

openxla/patches/20250120-001-Enable-nvptxcompiler-with-nvjitlink.patch → openxla/patches/0003-Expose-nvptxcompiler-to-link-against-in-XLA-if-enabl.patch
+11 -11
@@ -1,19 +1,19 @@
-From a13ee6b68e951f3fa95f26fa3a4a9d0f8e9ab17d Mon Sep 17 00:00:00 2001
+From 87a2f9bdec76d447bfb7f1c379f3ccab93324824 Mon Sep 17 00:00:00 2001
 From: Hugo Mano <[email protected]>
 Date: Tue, 21 Jan 2025 14:41:42 +0100
-Subject: [PATCH] Expose nvptxcompiler to link against in XLA if
+Subject: [PATCH 3/6] Expose nvptxcompiler to link against in XLA if
 enable_libnvptxcompiler_support is set

 Only for ZML, no PR on XLA side.
 ---
-.../gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl | 8 ++++++++
-xla/stream_executor/cuda/BUILD | 12 +++++++++++-
+third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl | 8 ++++++++
+xla/stream_executor/cuda/BUILD | 12 +++++++++++-
 2 files changed, 19 insertions(+), 1 deletion(-)

-diff --git a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl
+diff --git a/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl
 index 16ff3c8bea..d27832bb2e 100644
---- a/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl
--+++ b/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl
+--- a/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl
++++ b/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl
 @@ -45,6 +45,14 @@ filegroup(
 visibility = ["//visibility:public"],
 )
@@ -30,10 +30,10 @@ index 16ff3c8bea..d27832bb2e 100644
 name = "bin",
 srcs = glob([
 diff --git a/xla/stream_executor/cuda/BUILD b/xla/stream_executor/cuda/BUILD
-index 4ccfcbfc72..87d93f5f9c 100644
+index 62c8e4dbb2..d2f40db277 100644
 --- a/xla/stream_executor/cuda/BUILD
 +++ b/xla/stream_executor/cuda/BUILD
-@@ -78,6 +78,11 @@ config_setting(
+@@ -79,6 +79,11 @@ config_setting(
 },
 )

@@ -45,7 +45,7 @@ index 4ccfcbfc72..87d93f5f9c 100644
 cc_library(
 name = "cuda_platform_id",
 srcs = ["cuda_platform_id.cc"],
-@@ -122,7 +127,12 @@ cc_library(
+@@ -123,7 +128,12 @@ cc_library(
 "@tsl//tsl/platform:errors",
 "@tsl//tsl/platform:status",
 "@tsl//tsl/platform:statusor",
@@ -60,5 +60,5 @@ index 4ccfcbfc72..87d93f5f9c 100644
 )

 --
-2.39.5 (Apple Git-154)
+2.43.0

openxla/patches/20250205-001-Use-hermetic-CC-toolchain-glibc-2.31.patch → openxla/patches/0004-build-use-hermetic-cc-toolchain-for-Linux-CPU-use-gl.patch
+9 -7
@@ -1,20 +1,21 @@
-From 86de695ba9579ace447bc6c4f5c54bc03d467f85 Mon Sep 17 00:00:00 2001
+From dfdd3a3cd7f33e8c9febf787cc18dd5f38977f9e Mon Sep 17 00:00:00 2001
 From: Hugo Mano <[email protected]>
 Date: Wed, 5 Feb 2025 16:28:27 +0100
-Subject: [PATCH] build: use hermetic cc toolchain for Linux CPU (use glibc 2.31)
+Subject: [PATCH 4/6] build: use hermetic cc toolchain for Linux CPU (use glibc
+ 2.31)

 Only for ZML, no PR on XLA side.
 ---
 WORKSPACE | 24 ++++++++++++++++++++++++
-1 files changed, 24 insertions(+)
+1 file changed, 24 insertions(+)

 diff --git a/WORKSPACE b/WORKSPACE
-index 028dcdc7ef..55b6ed691f 100644
+index fb250a66da..3671de7c06 100644
 --- a/WORKSPACE
 +++ b/WORKSPACE
 @@ -99,3 +99,27 @@ load(
 )
-
+
 nccl_configure(name = "local_config_nccl")
 +
 +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
@@ -40,5 +41,6 @@ index 028dcdc7ef..55b6ed691f 100644
 +register_toolchains(
 + "@zig_sdk//toolchain:linux_amd64_gnu.2.31",
 +)
---
-2.39.5 (Apple Git-154)
+--
+2.43.0
+
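For background on why this patch pins gnu.2.31: a binary linked against glibc 2.31 will load on any Linux host whose glibc is 2.31 or newer, which is what makes the CPU plugin portable across distros. A standalone C++ sketch for checking what a host provides (illustrative only, not part of this commit; gnu_get_libc_version() is a glibc-specific API):

#include <gnu/libc-version.h>  // glibc-only header
#include <iostream>

int main() {
  // Prints the runtime glibc version, e.g. "2.31". Binaries produced with the
  // hermetic linux_amd64_gnu.2.31 toolchain load on hosts reporting >= 2.31.
  std::cout << "host glibc: " << gnu_get_libc_version() << "\n";
}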
(new patch file added; filename not shown in this view)
+138 -0
@@ -0,0 +1,138 @@
+From 3c5a3dc9cae6552fc6c58659d75077055f899126 Mon Sep 17 00:00:00 2001
+From: Dragan Mladjenovic <[email protected]>
+Date: Wed, 19 Feb 2025 10:36:32 -0600
+Subject: [PATCH 5/6] [ROCm] Pass correct warp size to Triton pipeline
+
+---
+ xla/backends/gpu/codegen/triton/compilation_pipeline.h | 2 +-
+ .../gpu/codegen/triton/compilation_pipeline_cuda.cc | 6 ++----
+ .../gpu/codegen/triton/compilation_pipeline_rocm.cc | 7 +++----
+ .../gpu/codegen/triton/compilation_pipeline_stub.cc | 2 +-
+ xla/backends/gpu/codegen/triton/fusion_emitter.cc | 4 +---
+ .../gpu/codegen/triton/fusion_emitter_stub_test.cc | 2 +-
+ xla/service/gpu/ir_emitter_unnested.cc | 7 ++++---
+ 7 files changed, 13 insertions(+), 17 deletions(-)
+
+diff --git a/xla/backends/gpu/codegen/triton/compilation_pipeline.h b/xla/backends/gpu/codegen/triton/compilation_pipeline.h
+index 9acd6fee99..c9e65798a5 100644
+--- a/xla/backends/gpu/codegen/triton/compilation_pipeline.h
++++ b/xla/backends/gpu/codegen/triton/compilation_pipeline.h
+@@ -41,7 +41,7 @@ namespace gpu {
+ // parameter which would give a hint to Triton which cluster dims we prefer to
+ // use, but that's not the case currently.
+ absl::Status CreateTritonPipeline(
+- mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
++ mlir::OpPassManager* pm, const se::DeviceDescription& device_info, int num_warps, int num_ctas,
+ int num_stages, mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info,
+ bool is_xla_fusion);
+
+diff --git a/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc b/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc
+index b57300ea88..e0fcf5bfd1 100644
+--- a/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc
++++ b/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc
+@@ -43,13 +43,11 @@ namespace mt = ::mlir::triton;
+ namespace mt_xla = ::mlir::triton::xla;
+
+ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm,
+- std::string arch_name, int num_warps,
++ const se::DeviceDescription& device_info, int num_warps,
+ int num_ctas, int num_stages,
+ mt::nvidia_gpu::ClusterInfo& out_cluster_info,
+ bool is_xla_fusion) {
+- TF_ASSIGN_OR_RETURN(
+- const stream_executor::CudaComputeCapability cc,
+- stream_executor::CudaComputeCapability::FromString(arch_name));
++ auto cc = device_info.cuda_compute_capability();
+ const int ccAsInt = cc.major * 10 + cc.minor;
+ const int threadsPerWarp = 32;
+
+diff --git a/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc b/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc
+index 03fc4bb230..64a493ed2b 100644
+--- a/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc
++++ b/xla/backends/gpu/codegen/triton/compilation_pipeline_rocm.cc
+@@ -58,13 +58,12 @@ using ::mlir::Value;
+ using mlir::ValueRange;
+
+ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm,
+- std::string arch_name, int num_warps,
++ const se::DeviceDescription& device_info, int num_warps,
+ int num_ctas, int num_stages,
+ mt::nvidia_gpu::ClusterInfo& out_cluster_info,
+ bool is_xla_fusion) {
+- // TODO(ROCm): Check why some test fail when threadsPerWarp is set to 64.
+- const int threadsPerWarp = 32;
+- auto cc = se::RocmComputeCapability(std::move(arch_name));
++ const int threadsPerWarp = device_info.threads_per_warp();
++ auto cc = device_info.rocm_compute_capability();
+
+ if (is_xla_fusion) {
+ pm->addPass(mt_xla::CreateInt4ToPackedInt4RewritePass());
+diff --git a/xla/backends/gpu/codegen/triton/compilation_pipeline_stub.cc b/xla/backends/gpu/codegen/triton/compilation_pipeline_stub.cc
+index d91acda7f5..ce7517a6b5 100644
+--- a/xla/backends/gpu/codegen/triton/compilation_pipeline_stub.cc
++++ b/xla/backends/gpu/codegen/triton/compilation_pipeline_stub.cc
+@@ -23,7 +23,7 @@ namespace xla {
+ namespace gpu {
+
+ absl::Status CreateTritonPipeline(
+- mlir::OpPassManager* pm, std::string arch_name, int num_warps, int num_ctas,
++ mlir::OpPassManager* pm, const se::DeviceDescription& device_info, int num_warps, int num_ctas,
+ int num_stages, mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info,
+ bool is_xla_fusion) {
+ return absl::UnimplementedError("not supported for this build configuration");
+diff --git a/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/xla/backends/gpu/codegen/triton/fusion_emitter.cc
+index 02644b9dc4..d164ffa9e4 100644
+--- a/xla/backends/gpu/codegen/triton/fusion_emitter.cc
++++ b/xla/backends/gpu/codegen/triton/fusion_emitter.cc
+@@ -1544,8 +1544,6 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
+ mlir::ModuleOp triton_module, llvm::Module* llvm_module,
+ mlir::MLIRContext& mlir_context, bool is_xla_fusion, bool emit_kernel) {
+ const auto& cc = device_info.gpu_compute_capability();
+- std::string arch_name =
+- std::visit([](auto& cc) { return cc.ToString(); }, cc);
+ if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
+ auto ccCuda = std::get<se::CudaComputeCapability>(cc);
+ if (!ccCuda.IsAtLeastAmpere()) {
+@@ -1606,7 +1604,7 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
+ pm.addPass(CreateConvertIndexTypePass());
+
+ mlir::triton::nvidia_gpu::ClusterInfo cluster_info;
+- if (!CreateTritonPipeline(&pm, arch_name, block_level_parameters.num_warps,
++ if (!CreateTritonPipeline(&pm, device_info, block_level_parameters.num_warps,
+ block_level_parameters.num_ctas,
+ block_level_parameters.num_stages, cluster_info,
+ is_xla_fusion)
+diff --git a/xla/backends/gpu/codegen/triton/fusion_emitter_stub_test.cc b/xla/backends/gpu/codegen/triton/fusion_emitter_stub_test.cc
+index 20accf012b..26b0d91fee 100644
+--- a/xla/backends/gpu/codegen/triton/fusion_emitter_stub_test.cc
++++ b/xla/backends/gpu/codegen/triton/fusion_emitter_stub_test.cc
+@@ -51,7 +51,7 @@ TEST(TritonStub, CallStubApi) {
+ mlir::OpPassManager pm;
+ ::mlir::triton::nvidia_gpu::ClusterInfo cluster_info;
+
+- EXPECT_FALSE(CreateTritonPipeline(&pm, "", 1, 1, 1, cluster_info,
++ EXPECT_FALSE(CreateTritonPipeline(&pm, {}, 1, 1, 1, cluster_info,
+ /*is_xla_fusion=*/true)
+ .ok());
+ EXPECT_EQ(GetLibdevicePath({}, {}), "");
+diff --git a/xla/service/gpu/ir_emitter_unnested.cc b/xla/service/gpu/ir_emitter_unnested.cc
+index fcedefa8f3..75d970b7ae 100644
+--- a/xla/service/gpu/ir_emitter_unnested.cc
++++ b/xla/service/gpu/ir_emitter_unnested.cc
+@@ -1434,9 +1434,10 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall(
+ KernelArguments::Create(ir_emitter_context_->buffer_assignment(), instr,
+ instr->operands(),
+ /*dedup=*/false));
+- auto launch_dimensions =
+- LaunchDimensions(se::BlockDim(call.grid_x, call.grid_y, call.grid_z),
+- se::ThreadDim(call.num_warps * 32));
++ auto launch_dimensions = LaunchDimensions(
++ se::BlockDim(call.grid_x, call.grid_y, call.grid_z),
++ se::ThreadDim(call.num_warps *
++ ir_emitter_context_->gpu_device_info().threads_per_warp()));
+
+ std::string sanitized_kernel_name =
+ GetSanitizedUniqueName(*ir_emitter_context_, kernel_name);
+--
+2.43.0
+

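Note on patch 5/6: its core invariant is that the Triton compilation pipeline and the kernel launch dimensions must agree on the device's warp size instead of assuming 32 everywhere. A minimal standalone sketch of that invariant (illustrative only, not XLA code: DeviceDescription here is a stand-in for se::DeviceDescription, and the 32/64 warp widths are the usual NVIDIA/AMD-CDNA values):

#include <iostream>

// Stand-in for se::DeviceDescription, reduced to the one field this needs.
struct DeviceDescription {
  int threads_per_warp;  // 32 on NVIDIA; 64 on AMD CDNA (wavefront size)
};

// Mirrors the post-patch computation in EmitTritonCustomCall:
// threads per block = num_warps * the device's warp size.
int ThreadDimX(const DeviceDescription& device_info, int num_warps) {
  return num_warps * device_info.threads_per_warp;
}

int main() {
  const DeviceDescription nvidia{32};
  const DeviceDescription amd_cdna{64};
  // A Triton kernel compiled with num_warps = 4:
  std::cout << ThreadDimX(nvidia, 4) << "\n";    // 128 threads per block
  std::cout << ThreadDimX(amd_cdna, 4) << "\n";  // 256 threads per block
  // The pre-patch code launched num_warps * 32 threads unconditionally,
  // which under-sizes the block on a 64-wide wavefront.
}

This is also why the ROCm pipeline now reads threads_per_warp() from the DeviceDescription passed to CreateTritonPipeline rather than hard-coding 32.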