Skip to content

Commit f277280

Browse files
jerryzh168jainapurva
authored andcommitted
Add example code for printing the operator and shapes in a model (#902)
Summary: This will be useful for people to do understand the ops/shapes for a model that they are interested in optimizing, also helpful for microbenchmarks with target ops/shapes Test Plan: python tutorials/developer_api_guide/print_op_and_shapes.py Reviewers: Subscribers: Tasks: Tags:
1 parent 49ddf71 commit f277280

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import torch
2+
3+
linear_shapes = []
4+
from torch.overrides import TorchFunctionMode
5+
class TorchFunctionLoggingMode(TorchFunctionMode):
6+
def __torch_function__(cls, func, types, args=(), kwargs=None):
7+
if kwargs is None:
8+
kwargs = {}
9+
if func is torch.nn.functional.linear:
10+
input_tensor, weight_tensor, bias = (
11+
args[0],
12+
args[1],
13+
args[2] if len(args) > 2 else None,
14+
)
15+
flattened_input_tensor = input_tensor.view(-1, input_tensor.shape[-1])
16+
M, K = flattened_input_tensor.shape[0], flattened_input_tensor.shape[1]
17+
assert K == weight_tensor.shape[1]
18+
N = weight_tensor.shape[0]
19+
print(f"TORCH_FUNC={str(func)} (M, K, N):", M, K, N)
20+
linear_shapes.append((M, K, N))
21+
else:
22+
arg_shape = args[0].shape if len(args) > 0 and isinstance(args[0], torch.Tensor) else None
23+
print(f"TORCH_FUNC={str(func)} args[0] shape:", arg_shape)
24+
return func(*args, **kwargs)
25+
26+
# NOTE: Modify this with your own model
27+
from torchvision import models
28+
m = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
29+
example_inputs = (torch.randn(1, 3, 224, 224),)
30+
31+
with TorchFunctionLoggingMode():
32+
m(*example_inputs)
33+
34+
print()
35+
print("all linear shapes (M, K, N):", linear_shapes)

0 commit comments

Comments
 (0)