
Commit 76d3c55

Solaryee authored and guizili0 committed
[XLA] Add more info for docs (#2063)
1 parent 5b9f729 commit 76d3c55

File tree

3 files changed: +58 −35 lines changed

docs/guide/OpenXLA_Support_on_GPU.md

+50 −27
@@ -38,41 +38,64 @@ Then we can get the library with xla extension **./bazel-bin/itex/libitex_xla_
 $ export PJRT_NAMES_AND_LIBRARY_PATHS='xpu:Your_itex_path/bazel-bin/itex/libitex_xla_extension.so'
 $ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:Your_Python_site-packages/jaxlib # Some functions defined in xla_extension.so are needed by libitex_xla_extension.so
 
-$ export ONEDNN_VERBOSE=1 # Optional variable setting. Enable onednn verbose to check if it runs on GPU.
+$ export ITEX_VERBOSE=1 # Optional variable setting. It shows detailed optimization/compilation/execution info.
 ```
 
 * **Run the below jax python code.**
 ```python
+import jax
 import jax.numpy as jnp
-from jax import random
-key = random.PRNGKey(0)
-size = 3000
-x = random.normal(key, (size, size), dtype=jnp.float32)
-y = jnp.dot(x, x.T).block_until_ready()
-print(y)
+
+@jax.jit
+def lax_conv():
+    key = jax.random.PRNGKey(0)
+    lhs = jax.random.uniform(key, (2,1,9,9), jnp.float32)
+    rhs = jax.random.uniform(key, (1,1,4,4), jnp.float32)
+    side = jax.random.uniform(key, (1,1,1,1), jnp.float32)
+    out = jax.lax.conv_with_general_padding(lhs, rhs, (1,1), ((0,0),(0,0)), (1,1), (1,1))
+    out = jax.nn.relu(out)
+    out = jnp.multiply(out, side)
+    return out
+
+print(lax_conv())
 ```
 * **Reference result:**
 ```
-onednn_verbose,info,oneDNN v3.1.0 (commit xxxx)
-onednn_verbose,info,cpu,runtime:DPC++,nthr:1
-onednn_verbose,info,cpu,isa:Intel 64
-onednn_verbose,info,gpu,runtime:DPC++
-onednn_verbose,info,cpu,engine,0,backend:OpenCL,name:Intel(R) Xeon(R) Gold 6346 CPU @ 3.10GHz,driver_version:2022.15.12,binary_kernels:disabled
-onednn_verbose,info,gpu,engine,0,backend:Level Zero,name:Intel(R) Data Center GPU Flex Series 170 [0x56c0],driver_version:1.3.25018,binary_kernels:enabled
-onednn_verbose,info,gpu,engine,1,backend:Level Zero,name:Intel(R) Data Center GPU Flex Series 170 [0x56c0],driver_version:1.3.25018,binary_kernels:enabled
-onednn_verbose,info,experimental features are enabled
-onednn_verbose,info,use batch_normalization stats one pass is enabled
-onednn_verbose,info,experimental functionality for sparse domain is enabled
-onednn_verbose,info,prim_template:operation,engine,primitive,implementation,prop_kind,memory_descriptors,attributes,auxiliary,problem_desc,exec_time
-onednn_verbose,exec,gpu:0,matmul,jit:gemm:any,undef,src_f32::blocked:abc:f0 wei_f32::blocked:abc:f0 dst_f32::blocked:abc:f0,attr-scratchpad:user ,,1x3000x3000:1x3000x3000:1x3000x3000,xxxxxxxx
-[[2938.1716 17.388428 36.508217 ... 32.315964 51.31904 -34.432026]
- [17.388428 3031.179 41.194576 ... 47.248768 58.077858 -13.371612]
- [36.508217 41.194576 3000.4697 ... 8.10901 -42.501842 26.495111]
- ...
- [32.315964 47.248768 8.10901 ... 2916.339 34.38107 39.404522]
- [51.31904 58.077858 -42.501842 ... 34.38107 3032.2844 63.69183 ]
- [-34.432026 -13.371612 26.495111 ... 39.404522 63.69183 3033.4866 ]]
+I itex/core/devices/gpu/itex_gpu_runtime.cc:129] Selected platform: Intel(R) Level-Zero
+I itex/core/compiler/xla/service/service.cc:176] XLA service 0x56060b5ae740 initialized for platform sycl (this does not guarantee that XLA will be used). Devices:
+I itex/core/compiler/xla/service/service.cc:184] StreamExecutor device (0): <undefined>, <undefined>
+I itex/core/compiler/xla/service/service.cc:184] StreamExecutor device (1): <undefined>, <undefined>
+[[[[2.0449753 2.093208 2.1844783 1.9769732 1.5857391 1.6942389]
+   [1.9218378 2.2862523 2.1549542 1.8367321 1.3978379 1.3860377]
+   [1.9456574 2.062028 2.0365305 1.901286 1.5255247 1.1421617]
+   [2.0621 2.2933435 2.1257985 2.1095486 1.5584903 1.1229166]
+   [1.7746235 2.2446113 1.7870374 1.8216239 1.557919 0.9832508]
+   [2.0887792 2.5433128 1.9749291 2.2580051 1.6096935 1.264905 ]]]
+
+
+ [[[2.175818 2.0094342 2.005763 1.6559253 1.3896458 1.4036925]
+   [2.1342552 1.8239582 1.6091168 1.434404 1.671778 1.7397764]
+   [1.930626 1.659667 1.6508744 1.3305787 1.4061482 2.0829628]
+   [2.130649 1.6637266 1.594426 1.2636002 1.7168686 1.8598001]
+   [1.9009514 1.7938274 1.4870623 1.6193901 1.5297288 2.0247464]
+   [2.0905268 1.7598859 1.9362347 1.9513799 1.9403584 2.1483061]]]]
+```
+If `ITEX_VERBOSE=1` is set, the log looks like this:
+```
+I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:301] Running HLO pass pipeline on module jit_lax_conv: optimization
+I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass fusion
+I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass fusion_merger
+I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass multi_output_fusion
+I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass gpu-conv-rewriter
+I itex/core/compiler/xla/service/hlo_pass_pipeline.cc:181] HLO pass onednn-fused-convolution-rewriter
+
+I itex/core/compiler/xla/service/gpu/gpu_compiler.cc:1221] Build kernel via LLVM kernel compilation.
+I itex/core/compiler/xla/service/gpu/spir_compiler.cc:255] CompileTargetBinary - CompileToSpir time: 11 us (cumulative: 99.2 ms, max: 74.9 ms, #called: 8)
+
+I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2201] Executing computation jit_lax_conv; num_replicas=1 num_partitions=1 num_addressable_devices=1
+I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2268] Replicated execution complete.
+I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1208] PjRtStreamExecutorBuffer::Delete
+I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1299] PjRtStreamExecutorBuffer::ToLiteral
 ```
-Check it runs on GPU but not CPU. For example, "onednn_verbose,exec,**gpu**:0,matmul, ..." means "matmul" runs on GPU.
 
 **4. More JAX examples.**
 Get examples from [https://github.com/google/jax](https://github.com/google/jax/tree/jaxlib-v0.4.4/examples) to run.
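To confirm the setup this guide describes actually loads the plugin, a minimal check (not part of this commit) is to list the devices JAX sees; this assumes the `PJRT_NAMES_AND_LIBRARY_PATHS` and `LD_LIBRARY_PATH` exports above are in effect:

```python
# Minimal sanity check, assuming the env vars from the guide are exported.
# If libitex_xla_extension.so loads, the default device list should come
# from the "xpu" PJRT plugin instead of falling back to CPU.
import jax

for d in jax.devices():
    print(d, d.platform)
```

If only CPU devices are printed, the plugin did not load; re-check the two exports at the top of this diff.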

itex/core/compiler/xla/service/gpu/onednn_fused_conv_rewriter.cc

+7 −7
@@ -179,7 +179,7 @@ StatusOr<bool> FuseConvertToFloat(HloComputation* comp) {
     if (!Match(instr, pattern)) {
       continue;
     }
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+    if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] {
           return absl::StrCat("FuseConvertToFloat: ", conv->ToString());
         })) {
       continue;
@@ -229,7 +229,7 @@ StatusOr<bool> FuseConvAlpha(HloComputation* comp) {
     if (config.conv_result_scale() != 1) {
       continue;
     }
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+    if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] {
           return absl::StrCat("FuseConvAlpha: ", conv->ToString());
         })) {
       continue;
@@ -327,7 +327,7 @@ StatusOr<bool> FuseBiasOrSideInput(HloComputation* comp) {
       continue;
     }
 
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+    if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] {
           return absl::StrCat("FuseBiasOrSideInput: ", conv->ToString());
         })) {
       continue;
@@ -401,7 +401,7 @@ StatusOr<bool> FuseSideInputAlpha(HloComputation* comp) {
             }))))) {
       continue;
     }
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+    if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] {
           return absl::StrCat("FuseSideInputAlpha: ", conv->ToString());
         })) {
       continue;
@@ -481,7 +481,7 @@ StatusOr<bool> FuseRelu(HloComputation* comp) {
       continue;
     }
 
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+    if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] {
           return absl::StrCat("FuseRelu: ", conv->ToString());
         })) {
       continue;
@@ -524,7 +524,7 @@ StatusOr<bool> FuseConvertToF16(HloComputation* comp) {
               0, m::GetTupleElement(m::Op().WithPredicate(IsConvCustomCall))));
       continue;
     }
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+    if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] {
           return absl::StrCat("FuseConvertToF16: ", conv->ToString());
         })) {
       continue;
@@ -609,7 +609,7 @@ StatusOr<bool> FuseConvertToS8(HloComputation* comp) {
     } else {
       continue;
     }
-    if (!ConsumeFuel("cudnn-fused-convolution-rewriter", [&] {
+    if (!ConsumeFuel("onednn-fused-convolution-rewriter", [&] {
           return absl::StrCat("FuseConvertToS8: ", conv->ToString());
         })) {
       continue;
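All seven renames change the key passed to ConsumeFuel, so any fuel limit must now target the new name. As a hedged sketch (assuming ITEX keeps upstream XLA's `--xla_fuel` flag semantics; the key below must match the string introduced in this diff), the pass can be capped like this:

```python
# Hedged sketch: bound how many rewrites this pass may perform using XLA's
# generic fuel mechanism. Assumes ITEX inherits upstream XLA's --xla_fuel
# flag; the fuel key must be the renamed pass, not the old cudnn name.
import os

# Allow at most 5 fused-convolution rewrites, then ConsumeFuel returns false.
os.environ["XLA_FLAGS"] = "--xla_fuel=onednn-fused-convolution-rewriter=5"

import jax  # import after setting XLA_FLAGS so the flag is parsed at startup
```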

itex/core/compiler/xla/service/gpu/onednn_fused_conv_rewriter.h

+1 −1
@@ -97,7 +97,7 @@ namespace gpu {
 class CudnnFusedConvRewriter : public HloModulePass {
  public:
   absl::string_view name() const override {
-    return "cudnn-fused-convolution-rewriter";
+    return "onednn-fused-convolution-rewriter";
   }
 
   StatusOr<bool> Run(HloModule* module) override;
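The string returned by name() is what pipeline logging prints (the "HLO pass onednn-fused-convolution-rewriter" line in the ITEX_VERBOSE output above) and what pass-control flags match against; note the class itself is still CudnnFusedConvRewriter, only the reported name changes. A hedged sketch of disabling the pass by its new name, assuming upstream XLA's `--xla_disable_hlo_passes` semantics carry over to ITEX:

```python
# Hedged sketch: skip the renamed pass entirely, e.g. to compare numerics
# with and without oneDNN convolution fusion. Assumes upstream XLA's
# --xla_disable_hlo_passes flag behaves the same in ITEX builds.
import os

os.environ["XLA_FLAGS"] = "--xla_disable_hlo_passes=onednn-fused-convolution-rewriter"

import jax  # must be imported after XLA_FLAGS is set
```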
