Commit af842d3

spcyppt authored and facebook-github-bot committed
Unifying TBE API using List (Frontend) - reland (#3821)
Summary:

X-link: facebookresearch/FBGEMM#904

Pull Request resolved: #3821

**TLDR;**

- D68055168 was reverted due to S498612 (backout diff D70996903).
- This diff re-lands D68055168 to enable the [new TBE front-end API](https://fb.workplace.com/groups/fbgemmusers/permalink/9662778130469548/).
- The main changes in this new unified API (API v2):
  1) Pack some arguments into lists due to the limitation on the number of arguments.
  2) Enable `learning_rate_tensor` to avoid PT2 recompilation (see D65511904).
  3) Enable constant `info_B_num_bits` and `info_B_mask` to address the insufficient-bit issue (see D69387123).

  See details in D68055168.
- This diff addresses the error that caused S498612:

```
Issue:
- Keeping learning_rate as a float causes PT2 recompilation (see D65511904).
- Introducing learning_rate_tensor causes a torch JIT script error (the issue with the previous land that caused S498612).

Solution:
- Remove learning_rate from optimizer_args.
- Replace it with learning_rate_tensor as a module attribute (i.e., self.learning_rate_tensor).

Hence, this solution satisfies both PT2 and torch JIT script.
```

See more detail on the root cause of S498612 and the proposed workaround below.

------

__**Context:**__

The previous landing of D68055168 caused S498612 with the error

```
RuntimeError: refined_slots[i]->isSubtypeOf(*attributes_[i].getType()) INTERNAL ASSERT FAILED at "fbcode/caffe2/aten/src/ATen/core/class_type.cpp":415, please report a bug to PyTorch.
```

Assertion failure point: https://fburl.com/code/q1xg0w8m

__**Root cause:**__

This is a bug in PyTorch in which the JIT type system cannot infer the type correctly. The error occurs only in the following combination:

- Only in JIT script
- Only for module attributes

**Explanation**:

1. The module is scripted, e.g., [module_factory/model_materializer_full_sync.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/model_materializer_full_sync.py?lines=1445-1446%2C1622-1623), [module_factory/sync_sgd/train_module.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/sync_sgd/train_module.py?lines=1615).
2. JIT refines the class type of `__torch__.fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen`, adding the module attributes to `refined_slots`.
3. For each attribute, it asserts that the `refined_slots` entry is a JIT subtype of the attribute ([here](https://www.internalfb.com/code/fbsource/[b5505c1f8e2a7945af8860fd29f89e220ffec919]/fbcode/caffe2/aten/src/ATen/core/class_type.cpp?lines=415) and [here](https://www.internalfb.com/code/fbsource/[d693ef764b7d8eb2e4205dd85a61a33d6fd7c977]/fbcode/caffe2/aten/src/ATen/core/jit_type_base.h?lines=396-409)).
4. The `SplitTableBatchedEmbeddingBagsCodegen` module has an attribute called `optimizer_args`. To avoid recompilation in PT2, we need to change `learning_rate` to a Tensor (D65511904). We hence changed the type of `optimizer_args` from `OptimizerArgs` (with `learning_rate` as a float) to `OptimizerArgsPT2` (with `learning_rate_tensor`), as defined [here](https://www.internalfb.com/code/fbsource/[cec693b485d5fb800984943f4ce2bc6a1ca1c52a]/fbcode/deeplearning/fbgemm/fbgemm_gpu/codegen/training/python/lookup_args.template?lines=111).
5. The JIT subtype check sees the attribute as type `OptimizerArgsPT2`, but `refined_slots` assumes the type is `Tuple` (the actual schema of `OptimizerArgsPT2`) and concludes that they are not the same. {F1976090817} See the [log of the refined_slot debug](https://www.internalfb.com/phabricator/paste/view/P1756112218?lines=730-726).
6. Since the JIT type system sees them as different types, the assertion is triggered.

**Note**: The JIT type system fails to match the types only when a *Tensor* is added to the class.

- Prior to landing D68055168, `optimizer_args` was of type `OptimizerArgs` (defined here). The class contains only floats, ints, and bools; the JIT type system sees both sides as `OptimizerArgs` and does not fail. {F1976090829} See the [log of the refined_slot debug](https://www.internalfb.com/intern/everpaste/?handle=GO2d1Bz063DvSSAEADUn74T4WyZPbsIXAAAB&phabricator_paste_number=1756112738).
- Adding any Tensor to the class `OptimizerArgs` causes the error above.
- Adding floats, ints, or bools to the class `OptimizerArgs` is fine.

-------

__**Solution**__

We discussed and tested several workarounds and went with the last one, which satisfies both torch JIT script and PT2 compile (a minimal sketch of the adopted pattern follows this summary):

1) Keep `optimizer_args` as `OptimizerArgs` unchanged, i.e., `learning_rate` remains a float, and create a `learning_rate_tensor` in the invokers before passing it to `lookup_function`. ❌ This does not work with PT2 because it causes recompilation.
2) Keep both the v1 and v2 interfaces. ❌ This does not work with torch JIT script.
3) Remove `learning_rate` from `optimizer_args` and make `learning_rate_tensor` a module attribute. ✅ This works with both torch.jit.script and PT2 torch.compile.

Notice: Many users access `learning_rate` directly through `optimizer_args.learning_rate`, so changes to `optimizer_args` may break backward compatibility for those cases. We already landed D71444136 to address any future compatibility issues.

Note: `learning_rate_tensor` is always created on, and stays on, the CPU, so there should be no host-device synchronization.

Reviewed By: q10, nautsimon

Differential Revision: D71010630
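As referenced in the Solution above, here is a minimal, hypothetical sketch of the adopted pattern (the `ToyOptimizerArgs` and `ToyTBEModule` names are invented, not the FBGEMM classes): keep the optimizer-args NamedTuple scalar-only and hold the learning rate as a CPU tensor attribute directly on the module, which scripts cleanly under torch.jit.script and leaves the learning rate adjustable without forcing PT2 recompilation.

```python
from typing import NamedTuple

import torch


class ToyOptimizerArgs(NamedTuple):
    # Scalars only: adding a torch.Tensor field here is what triggered the
    # refined_slots subtype assertion described in the root-cause analysis.
    eps: float
    weight_decay: float


class ToyTBEModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.optimizer_args = ToyOptimizerArgs(eps=1.0e-8, weight_decay=0.01)
        # The learning rate lives as a plain CPU tensor attribute on the
        # module, so its value can change without being baked into a graph.
        self.learning_rate_tensor = torch.tensor(0.01, dtype=torch.float32)

    def forward(self, grad: torch.Tensor) -> torch.Tensor:
        # Reading .item() from a CPU tensor does not synchronize with the GPU.
        lr = float(self.learning_rate_tensor.item())
        return grad * (lr + self.optimizer_args.weight_decay)


# The toy module scripts without hitting the refined_slots assertion.
scripted = torch.jit.script(ToyTBEModule())
print(scripted(torch.ones(4)))
```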
1 parent def7bbe · commit af842d3

File tree

6 files changed: +351 −384


Diff for: fbgemm_gpu/codegen/genscript/generate_backward_split.py

+1 −1

```diff
@@ -447,7 +447,7 @@ def generate() -> None:
         ssd_optimizers.append(optim)
 
         BackwardSplitGenerator.generate_backward_split(
-            ssd_tensors=ssd_tensors, **optimizer
+            ssd_tensors=ssd_tensors, aux_args=aux_args, **optimizer
         )
         BackwardSplitGenerator.generate_rocm_backward_split()
```

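The generator call above now also receives `aux_args`. Purely as a hypothetical illustration of the "pack some arguments into lists" idea from the summary (the grouping and names below are invented and do not reflect the generated op schema):

```python
from typing import List, NamedTuple


class AuxArgs(NamedTuple):
    """Hypothetical per-type grouping of auxiliary arguments."""

    aux_bool: List[bool]
    aux_int: List[int]
    aux_float: List[float]


def pack_aux_args(is_experimental: bool, iteration: int, max_gradient: float) -> AuxArgs:
    # Instead of one op argument per value, related scalars are packed into
    # typed lists, keeping the op signature short and stable as options grow.
    return AuxArgs(
        aux_bool=[is_experimental],
        aux_int=[iteration],
        aux_float=[max_gradient],
    )


print(pack_aux_args(False, 10, 1.0))
```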
Diff for: fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu

+1 −0

```diff
@@ -603,6 +603,7 @@ Tensor {{ embedding_cuda_op }}(
 
     {%- if "learning_rate" in args.split_kernel_arg_names %}
     // convert `learning rate` to float since `learning rate` is float in kernels
+    TORCH_CHECK(learning_rate_tensor.is_cpu(), "learning_rate_tensor tensor needs to be on CPU. Ensure learning_rate_tensor is on CPU or contact FBGEMM team if you get this error.")
     const float learning_rate = learning_rate_tensor.item<float>();
     {%- endif %}
```

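The added `TORCH_CHECK` lives in generated C++/CUDA template code. The following is only a Python analogue (not FBGEMM code) of the same invariant: the learning-rate tensor must already be on the CPU so that reading its value as a float does not force a host-device synchronization.

```python
import torch


def read_learning_rate(learning_rate_tensor: torch.Tensor) -> float:
    # Mirror of the generated-code check: a CUDA tensor here would make
    # .item() block on a device-to-host copy for every backward call.
    if not learning_rate_tensor.is_cpu:
        raise ValueError(
            "learning_rate_tensor needs to be on CPU; got device "
            f"{learning_rate_tensor.device}"
        )
    return float(learning_rate_tensor.item())


print(read_learning_rate(torch.tensor(0.01)))
```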
Diff for: fbgemm_gpu/codegen/training/python/lookup_args.template

+5 −61

```diff
@@ -49,74 +49,18 @@ class CommonArgs(NamedTuple):
     {%- if ssd %}
     ssd_tensors: Dict[str, torch.Tensor]
     {%- endif %}
-
-
-class OptimizerArgs(NamedTuple):
-    stochastic_rounding: bool
-    gradient_clipping: bool
-    max_gradient: float
-    max_norm: float
-    learning_rate: float
-    eps: float
-    beta1: float
-    beta2: float
-    weight_decay: float
-    weight_decay_mode: int
-    eta: float
-    momentum: float
-    counter_halflife: int
-    adjustment_iter: int
-    adjustment_ub: float
-    learning_rate_mode: int
-    grad_sum_decay: int
-    tail_id_threshold: float
-    is_tail_id_thresh_ratio: int
-    total_hash_size: int # Required for OptimType.NONE
-    weight_norm_coefficient: float
-    lower_bound: float
-    regularization_mode: int
-    use_rowwise_bias_correction: bool # Used for OptimType.ADAM
-
-class CommonArgsPT2(NamedTuple):
-    placeholder_autograd_tensor: torch.Tensor
-    dev_weights: torch.Tensor
-    host_weights: torch.Tensor
-    uvm_weights: torch.Tensor
-    lxu_cache_weights: torch.Tensor
-    weights_placements: torch.Tensor
-    weights_offsets: torch.Tensor
-    D_offsets: torch.Tensor
-    total_D: int
-    max_D: int
-    hash_size_cumsum: torch.Tensor
-    total_hash_size_bits: int
-    indices: torch.Tensor
-    offsets: torch.Tensor
-    pooling_mode: int
-    indice_weights: Optional[torch.Tensor]
-    feature_requires_grad: Optional[torch.Tensor]
-    lxu_cache_locations: torch.Tensor
-    uvm_cache_stats: Optional[torch.Tensor]
-    output_dtype: int
-    vbe_metadata: VBEMetadata
-    is_experimental: bool
-    use_uniq_cache_locations_bwd: bool
-    use_homogeneous_placements: bool
+    learning_rate_tensor: torch.Tensor
     info_B_num_bits: int
     info_B_mask: int
-    {%- if ssd %}
-    ssd_tensors: Dict[str, torch.Tensor]
-    {%- endif %}
 
-class OptimizerArgsPT2(NamedTuple):
-    """
-    Optimizer arguments for PT2 interface
-    """
+
+# Do not add a parameter of Type tensor here. It will cause JIT script error due to a bug in PyTorch.
+# See more detail in D71010630.
+class OptimizerArgs(NamedTuple):
     stochastic_rounding: bool
     gradient_clipping: bool
     max_gradient: float
     max_norm: float
-    learning_rate_tensor: torch.Tensor
     eps: float
     beta1: float
     beta2: float
```

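This template change is what moves the learning rate from a float field in `OptimizerArgs` to the `learning_rate_tensor` field of `CommonArgs`. As a standalone, assumed sketch (toy functions, not FBGEMM code) of why a tensor-valued learning rate matters for PT2: torch.compile tends to specialize on Python float inputs, so each new float value can trigger a recompile, whereas a 0-dim tensor input is a runtime graph input whose value can change freely.

```python
import torch


@torch.compile
def apply_lr_float(grad: torch.Tensor, lr: float) -> torch.Tensor:
    # A Python float argument is typically baked into the compiled graph.
    return grad * lr


@torch.compile
def apply_lr_tensor(grad: torch.Tensor, lr: torch.Tensor) -> torch.Tensor:
    # A 0-dim tensor argument stays a runtime input to the graph.
    return grad * lr


grad = torch.randn(8)

# Each distinct float value can trigger another compilation.
for lr in (0.1, 0.01, 0.001):
    apply_lr_float(grad, lr)

# Updating the tensor in place reuses the already-compiled graph.
lr_tensor = torch.tensor(0.1)
for value in (0.1, 0.01, 0.001):
    lr_tensor.fill_(value)
    apply_lr_tensor(grad, lr_tensor)
```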