Hi team,
I am writing to inquire about the expected behavior of XLA kernel selection / fusion. My interest primarily lies in XLA<>GPU, but I am also eager to learn about the XLA<>TPU kernel selection process.
What I have done is start with a simple torch model and use torch_xla and torch.autocast to export it to StableHLO with fp16 operations (saving the model weights as numpy arrays, since they are treated as model inputs). I then ran XLA AOT compilation with a pre-processing step in which I froze all the model weights as constants, and I have some questions about the compiled kernels.
From what I can see (see the attached report, module_0001.IrToHlo.106.sm_8.6_gpu_after_optimizations-memory-usage-report.txt), the convolutions are consistently fused as cudnn-conv-bias-activation. For the linear layers, however, some are fused as gemm_fusion_dot while others become fused_reduce. Looking more closely at the fused_reduce, the reduce op after the matrix multiplication actually happens in FP32, which inserts two casts (converts) before and after the reduce and slows down inference.
Therefore, my questions are: what is the mechanism by which XLA selects which kernel to use after compilation? How can I change this behavior so that the better kernel is selected? And is there any documentation describing the fusion / kernel selection logic in XLA's different optimization passes that I can read?
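For reference, the per-pass HLO can be dumped to see which pass introduces the converts around the reduce. A minimal sketch using the standard XLA dump flags (the dump directory is a placeholder, and XLA_FLAGS must be set before the XLA client initializes):

import os

# Sketch: dump HLO around every optimization pass so the pass that introduces
# the f32 converts can be identified. The dump directory is a placeholder.
os.environ["XLA_FLAGS"] = " ".join([
    os.environ.get("XLA_FLAGS", ""),
    "--xla_dump_to=/tmp/xla_dump",   # where dump files are written
    "--xla_dump_hlo_as_text",        # emit human-readable HLO text
    "--xla_dump_hlo_pass_re=.*",     # dump around every pass (verbose)
])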
Thank you in advance.
Attached below is the torch -> torch_xla -> hlo export flow:
import json
import os
from pathlib import Path
from typing import Any, List

import glog
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo
from torch_xla.stablehlo import VariableType

# TensorProto, TensorInfoProto, constants, numpy_dtype_to_enum_dtype,
# INPUTKINDS_TO_FREEZE and POSITIONAL_PARAM_TYPES_TO_FREEZE come from
# project-specific helpers and are omitted here.

xla_flags = os.environ.get("XLA_FLAGS", "")
def torch2hlo(model: torch.nn.Module, sample_input: Any, output_dir: str,
              input_names: List[str], output_names: List[str]):
    """Export a torch model to StableHLO.

    Args:
        model: The torch model to export.
        sample_input: A sample input to the model.
        output_dir: The directory to save the HLO files to.
        input_names: The names of the input tensors.
        output_names: The names of the output tensors.

    Returns:
        The StableHLO program. Useful for validation.
    """
    exported = export(model, sample_input)

    # Freeze the weights / attribute params as tensors and save them to disk.
    params_to_freeze = {}
    input_specs = exported.graph_signature.input_specs
    state_dict = exported.state_dict
    for idx, input_spec in enumerate(input_specs):
        # TODO(yixzhou): I believe these are all the types that correspond to
        # weights, but we should be able to easily support other types if needed.
        if input_spec.kind in INPUTKINDS_TO_FREEZE:
            params_to_freeze[input_spec.target] = state_dict[
                input_spec.target].detach().cpu().numpy()

    output_data_dir = os.path.join(output_dir, constants.MLIR_DATA_DIR)
    os.makedirs(output_data_dir, exist_ok=True)
    for k, v in params_to_freeze.items():
        tensor = TensorProto(
            tensor_info=TensorInfoProto(
                name=k,
                dtype=numpy_dtype_to_enum_dtype[v.dtype.type],
                shape=v.shape,
            ),
            data=v.tobytes(),
        )
        tensor_path = os.path.join(output_data_dir, f"{k}.npy")
        np.save(tensor_path, v)
        glog.info(f"Saved tensor to {tensor_path}")

    stablehlo_program = exported_program_to_stablehlo(exported)

    # Save the mapping from positional arguments to the argument names.
    func = stablehlo_program._name_to_stablehlo[
        constants.MLIR_DEFAULT_FUNCTION_NAME]
    meta = func.meta
    arg_position_to_name_mapping = {}
    for idx, loc in enumerate(meta.input_locations):
        if loc.type_ in POSITIONAL_PARAM_TYPES_TO_FREEZE:
            arg_position_to_name_mapping[idx] = loc.name
    json_file = Path(output_data_dir) / constants.MLIR_POSITION_TO_ARG_NAME_MAP_FILENAME
    json_str = json.dumps(arg_position_to_name_mapping)
    json_file.write_text(json_str)
    glog.info(f"Saved arg position to name mapping to {json_file}")

    mlir_binary_path = os.path.join(output_dir, constants.MLIR_BINARY_FILENAME)
    with open(mlir_binary_path, "wb+") as f:
        f.write(stablehlo_program.get_stablehlo_bytecode('forward'))
    glog.info(f"Saved MLIR binary to {mlir_binary_path}")

    mlir_text_path = os.path.join(output_dir, constants.MLIR_TEXT_FILENAME)
    with open(mlir_text_path, "w+") as f:
        f.write(stablehlo_program.get_stablehlo_text('forward'))
    glog.info(f"Saved MLIR debug text to {mlir_text_path}")

    return stablehlo_program
class SimpleModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.linear = nn.Linear(16 * 32 * 32, 2**12)
        self.linear2 = nn.Linear(2**12, 2**12)
        self.linear3 = nn.Linear(2**12, 2**10)
        self.linear4 = nn.Linear(2**10, 2**10)
        self.linear5 = nn.Linear(2**10, 10)

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.conv4(x)
        x = F.relu(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        x = F.relu(x)
        x = self.linear4(x)
        x = F.relu(x)
        x = self.linear5(x)
        return x
# Initialize model
model = SimpleModel().eval()

# Create dummy input (batch_size=1, channels=3, height=32, width=32)
dummy_input = torch.randn(1, 3, 32, 32, dtype=torch.float32)

# Run inference once as a sanity check
with torch.no_grad():
    output = model(dummy_input)

# Export to StableHLO. `workdir` was undefined in the original snippet;
# a placeholder output directory is used here.
workdir = "/tmp/simple_model_stablehlo"
args = (dummy_input,)
with torch.no_grad():
    # autocast is enabled so the exported StableHLO contains fp16 ops,
    # matching the behavior described above
    with torch.autocast(device_type="cpu", dtype=torch.float16, enabled=True):
        torch2hlo(model, args, output_dir=workdir,
                  input_names=["input"], output_names=["output"])
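Not shown above is the read-back side used by the constant-freezing pre-processing step. A minimal sketch, assuming the file layout written by torch2hlo and using placeholder filenames in place of the project constants:

import json
import os
import numpy as np

# Sketch: read back the frozen weights and the position-to-name map so they
# can be supplied (or constant-folded) at AOT-compile time. The directory and
# filename defaults are placeholders mirroring the constants used above.
def load_frozen_inputs(output_dir, data_dir="mlir_data",
                       map_filename="position_to_arg_name.json"):
    data_path = os.path.join(output_dir, data_dir)
    with open(os.path.join(data_path, map_filename)) as f:
        position_to_name = json.load(f)  # {arg position: tensor name}
    # Load each frozen weight as a numpy array keyed by its argument position.
    return {
        int(pos): np.load(os.path.join(data_path, f"{name}.npy"))
        for pos, name in position_to_name.items()
    }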