XLA Kernel Fusion / Selection Expectations #23630

Open
YixuanSeanZhou opened this issue Mar 12, 2025 · 0 comments

Hi team,

I am writing to inquire about the expected behavior of XLA kernel selection / fusion. My interest primarily lies in XLA<>GPU, but I am also eager to learn about the XLA<>TPU kernel selection process.

I started with a simple torch model and used torch_xla and torch.autocast to export it to StableHLO with fp16 operations (saving the model weights as numpy arrays, since they are treated as model inputs). I then ran XLA AOT compilation with a pre-processing step in which I froze all the model weights as constants. I have some questions about the resulting compiled kernels.
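As a quick sanity check that autocast actually produced fp16 operations in the export, one can count f16 vs f32 element types in the saved StableHLO text. A minimal sketch; the path below is a placeholder for the text file written by the export script attached at the end:

# Placeholder path for the StableHLO text dump written by the export script below.
stablehlo_text_path = "stablehlo_export/forward.mlir.txt"

with open(stablehlo_text_path) as f:
    hlo_text = f.read()

# If autocast took effect, most dot / convolution operands should be f16.
print("f16 occurrences:", hlo_text.count("f16"))
print("f32 occurrences:", hlo_text.count("f32"))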

From what I can see (see the fusion report below), the convolutions are consistently fused as cudnn-conv-bias-activation. For the linear layers, however, some are fused as gemm_fusion_dot while others become fused_reduce. Taking a closer look at fused_reduce, the reduce op after the matrix multiplication actually happens in FP32, which inserts two casts (converts) before and after the reduce and slows down inference.

  convert.20.3 = f32[10,1024]{1,0} convert(multiply.7.3)
  constant_28_1 = f32[] constant(0)
  reduce.1.3 = f32[10]{0} reduce(convert.20.3, constant_28_1), dimensions={1}, to_apply=scalar_add_computation
  ROOT convert.21.1 = f16[10]{0} convert(reduce.1.3), metadata={op_name="dot.89"}

module_0001.IrToHlo.106.sm_8.6_gpu_after_optimizations-memory-usage-report.txt
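In case it helps with reproducing this, the post-optimization HLO (where these fusion decisions show up) can be dumped with XLA's standard dump flags; a minimal sketch, with /tmp/xla_dump as an example path:

import os

# Dump HLO modules as text, including the *after_optimizations* module that
# shows the final fusions, for any compilation triggered after this point.
os.environ["XLA_FLAGS"] = " ".join(filter(None, [
    os.environ.get("XLA_FLAGS", ""),
    "--xla_dump_to=/tmp/xla_dump",
    "--xla_dump_hlo_as_text",
]))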

Therefore, my questions are: what is the mechanism by which XLA selects which kernel to use during compilation? How can we modify the behavior so that the better kernel is selected? And is there any documentation describing the fusion / kernel-selection logic in the different optimization passes of XLA that I could read?
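To make the "how can we modify the behavior" part of the question concrete, here is the kind of knob I mean; I do not know whether flags like --xla_gpu_enable_triton_gemm or --xla_gpu_autotune_level are the intended levers here, and the values below are only examples:

import os

# Example only: turn off Triton GEMM fusions so dots lower to cuBLAS
# custom-calls instead, and raise the autotuning level XLA uses when picking
# among candidate algorithms / fusion configurations.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_gpu_enable_triton_gemm=false"
    + " --xla_gpu_autotune_level=4"
)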

Thank you in advance.

Attached below is the torch -> torch_xla -> HLO export flow:

import json
import os
from typing import Any, List

import glog
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo
from torch_xla.stablehlo import VariableType

# Project-specific helpers used below (not shown): constants, TensorProto,
# TensorInfoProto, numpy_dtype_to_enum_dtype, INPUTKINDS_TO_FREEZE, POSITIONAL_PARAM_TYPES_TO_FREEZE.

xla_flags = os.environ.get("XLA_FLAGS", "")

def torch2hlo(model: torch.nn.Module, sample_input: Any, output_dir: str,
              input_names: List[str], output_names: List[str]):
    """Export a torch model to stableHLO.

    Args:
        model: The torch model to export.
        sample_input: A sample input to the model.
        output_dir: The directory to save the HLO files to.
        input_names: The names of the input tensors.
        output_names: The names of the output tensors.

    Returns:
        The stableHLO program. Useful for validation.
    """
    exported = export(model, sample_input)

    # freeze the weights / attribute params as tensors and save them to disk
    params_to_freeze = {}
    input_specs = exported.graph_signature.input_specs
    state_dict = exported.state_dict
    for idx, input_spec in enumerate(input_specs):
        # TODO(yixzhou): I believe these are all the input kinds that correspond to
        # weights, but we should be able to easily support other types if needed.
        if input_spec.kind in INPUTKINDS_TO_FREEZE:
            params_to_freeze[input_spec.target] = state_dict[
                input_spec.target].detach().cpu().numpy()

    output_data_dir = os.path.join(output_dir, constants.MLIR_DATA_DIR)
    os.makedirs(output_data_dir, exist_ok=True)
    for k, v in params_to_freeze.items():
        tensor = TensorProto(
            tensor_info=TensorInfoProto(
                name=k,
                dtype=numpy_dtype_to_enum_dtype[v.dtype.type],
                shape=v.shape,
            ),
            data=v.tobytes(),
        )
        tensor_path = os.path.join(output_data_dir, f"{k}.npy")
        np.save(tensor_path, v)
        glog.info(f"Saved tensor to {tensor_path}")

    stablehlo_program = exported_program_to_stablehlo(exported)

    # save the mapping from positional argument index to argument name
    func = stablehlo_program._name_to_stablehlo[
        constants.MLIR_DEFAULT_FUNCTION_NAME]
    meta = func.meta
    arg_position_to_name_mapping = {}
    for idx, loc in enumerate(meta.input_locations):
        if loc.type_ in POSITIONAL_PARAM_TYPES_TO_FREEZE:
            arg_position_to_name_mapping[idx] = loc.name

    json_path = os.path.join(output_data_dir,
                             constants.MLIR_POSITION_TO_ARG_NAME_MAP_FILENAME)
    with open(json_path, "w") as f:
        f.write(json.dumps(arg_position_to_name_mapping))
    glog.info(f"Saved arg position to name mapping to {json_path}")

    mlir_binary_path = os.path.join(output_dir, constants.MLIR_BINARY_FILENAME)
    with open(mlir_binary_path, "wb+") as f:
        f.write(stablehlo_program.get_stablehlo_bytecode('forward'))
    glog.info(f"Saved MLIR binary to {mlir_binary_path}")
    mlir_text_path = os.path.join(output_dir, constants.MLIR_TEXT_FILENAME)
    with open(mlir_text_path, "w+") as f:
        f.write(stablehlo_program.get_stablehlo_text('forward'))
    glog.info(f"Saved MLIR debug text to {mlir_text_path}")
    return stablehlo_program


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.linear = nn.Linear(16 * 32 * 32, 2**12)
        self.linear2 = nn.Linear(2**12, 2**12)
        self.linear3 = nn.Linear(2**12, 2**10)
        self.linear4 = nn.Linear(2**10, 2**10)
        self.linear5 = nn.Linear(2**10, 10)
        
    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.conv4(x)
        x = F.relu(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        x = F.relu(x)
        x = self.linear4(x)
        x = F.relu(x)
        x = self.linear5(x)
        return x

# Initialize model
model = SimpleModel().eval()

# Create dummy input (batch_size=1, channels=3, height=32, width=32)
dummy_input = torch.randn(1, 3, 32, 32, dtype=torch.float32)

# Run inference
with torch.no_grad():
    output = model(dummy_input)

args = (dummy_input,)

# Placeholder output directory for the exported artifacts.
workdir = "stablehlo_export"

with torch.no_grad():
    # Autocast must be enabled here for the exported StableHLO to contain fp16 ops.
    with torch.autocast(device_type="cpu", dtype=torch.float16, enabled=True):
        torch2hlo(model, args, output_dir=workdir,
                  input_names=["input"], output_names=["output"])