Skip to content

Commit

Permalink
Formatting issue
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Sep 12, 2024
1 parent 5c33dec commit 248bc3a
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
6 changes: 5 additions & 1 deletion torchbenchmark/operators/addmm/hstu.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
from typing import Tuple

with add_path(str(SUBMODULE_PATH)):
triton_addmm = importlib.import_module("generative-recommenders.ops.triton.triton_addmm")
triton_addmm = importlib.import_module(
"generative-recommenders.ops.triton.triton_addmm"
)
_addmm_fwd = triton_addmm._addmm_fwd


class _AddMmFunction(torch.autograd.Function):
@staticmethod
# pyre-ignore[14]
Expand Down Expand Up @@ -77,6 +80,7 @@ def backward(

return dx, dw, dy


@torch.fx.wrap
def triton_addmm(
input: torch.Tensor,
Expand Down
2 changes: 1 addition & 1 deletion torchbenchmark/operators/ragged_attention/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .operator import Operator
from .operator import Operator
2 changes: 1 addition & 1 deletion torchbenchmark/operators/ragged_attention/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ def get_x_val(self, example_inputs):
def get_input_iter(self):
for _input_id in range(self._num_inputs):
inputs = get_test_inputs(self.batch_size, self.num_heads, self.max_seq_len)
yield inputs
yield inputs

0 comments on commit 248bc3a

Please sign in to comment.