Unifying TBE API using List (Frontend) - reland (#3821)
Summary:
X-link: facebookresearch/FBGEMM#904
Pull Request resolved: #3821
**TLDR;**
- D68055168 was reverted due to S498612 (backout diff D70996903)
- This diff re-lands D68055168 to enable the [new TBE front-end API](https://fb.workplace.com/groups/fbgemmusers/permalink/9662778130469548/)
- The main changes in this new unified API (API v2) are:
1) Pack some arguments into lists due to the limit on the number of operator arguments (see the sketch after this list)
2) Enable `learning_rate_tensor` to avoid PT2 recompilation
(see D65511904)
3) Enable constant `info_B_num_bits` and `info_B_mask` to address the
insufficient-bit issue (see D69387123)
See details in D68055168
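For change (1), here is a minimal hypothetical sketch of the packing idea; the signature, names, and body below are made up for illustration and are not the real FBGEMM operator (which is generated from the codegen templates). The point is that auxiliary scalars are grouped by type into lists so the operator stays under the argument-count limit while remaining TorchScript-friendly.
```
from typing import List

import torch


# Hypothetical sketch only -- argument names and the body are illustrative;
# the real v2 lookup op is generated from FBGEMM templates.
def lookup_v2(
    weights: torch.Tensor,
    indices: torch.Tensor,
    aux_float: List[float],  # e.g., [eps, weight_decay]
    aux_int: List[int],      # e.g., [iter]
    aux_bool: List[bool],    # e.g., [gradient_clipping]
) -> torch.Tensor:
    eps, weight_decay = aux_float[0], aux_float[1]
    # A real implementation would dispatch to the backend kernels here.
    return weights[indices] * (1.0 - weight_decay) + eps


out = lookup_v2(
    torch.randn(10, 4),
    torch.tensor([0, 3, 7]),
    aux_float=[1e-8, 0.01],
    aux_int=[0],
    aux_bool=[False],
)
```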
- This diff addresses the error that caused S498612
```
Issue:
- Keeping learning_rate as a float causes PT2 recompilation
  (see D65511904).
- Introducing learning_rate_tensor inside optimizer_args causes a torch JIT
  script error (the issue with the previous land that caused S498612).
Solution:
- Remove learning_rate from optimizer_args.
- Replace it with learning_rate_tensor as a module attribute
  (i.e., self.learning_rate_tensor).
Hence, this solution satisfies both PT2 and torch JIT script.
```
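To illustrate the recompilation half of the issue, here is a minimal sketch using plain `torch.compile` rather than the FBGEMM lookup itself; the function names are made up for illustration.
```
import torch


@torch.compile
def sgd_step_float(param: torch.Tensor, grad: torch.Tensor, lr: float) -> torch.Tensor:
    # A Python float is typically specialized into the graph as a constant,
    # so each distinct learning-rate value can trigger a recompile.
    return param - lr * grad


@torch.compile
def sgd_step_tensor(param: torch.Tensor, grad: torch.Tensor, lr: torch.Tensor) -> torch.Tensor:
    # A 0-dim tensor is a runtime input, so changing its value reuses the
    # same compiled graph.
    return param - lr * grad


param, grad = torch.randn(4), torch.randn(4)
lr_tensor = torch.tensor(0.1)  # stays on CPU, mirroring learning_rate_tensor

for lr in (0.1, 0.01, 0.001):
    sgd_step_float(param, grad, lr)          # may recompile per value
    lr_tensor.fill_(lr)
    sgd_step_tensor(param, grad, lr_tensor)  # compiles once
```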
See more detail on the root cause of S498612 and the proposed workaround below.
------
__**Context:**__
The previous landing of D68055168 caused S498612 with the error
```
RuntimeError: refined_slots[i]->isSubtypeOf(*attributes_[i].getType()) INTERNAL ASSERT FAILED at "fbcode/caffe2/aten/src/ATen/core/class_type.cpp":415, please report a bug to PyTorch.
```
Assertion failure point: https://fburl.com/code/q1xg0w8m
__**Root cause:**__
This is a bug in PyTorch in which the JIT type system cannot infer the type correctly. The error occurs only under the following combination:
- Only in JIT script
- Only fails for module attributes
**Explanation**:
1. The module is scripted,
e.g., [module_factory/model_materializer_full_sync.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/model_materializer_full_sync.py?lines=1445-1446%2C1622-1623), [module_factory/sync_sgd/train_module.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/sync_sgd/train_module.py?lines=1615)
2. JIT refines the class type `__torch__.fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen`, where module attributes are added to `refined_slots`
3. For each attribute, it asserts that the `refined_slot` entry is a JIT subtype of the attribute's type ([here](https://www.internalfb.com/code/fbsource/[b5505c1f8e2a7945af8860fd29f89e220ffec919]/fbcode/caffe2/aten/src/ATen/core/class_type.cpp?lines=415) and [here](https://www.internalfb.com/code/fbsource/[d693ef764b7d8eb2e4205dd85a61a33d6fd7c977]/fbcode/caffe2/aten/src/ATen/core/jit_type_base.h?lines=396-409))
4. The `SplitTableBatchedEmbeddingBagsCodegen` module has an attribute called `optimizer_args`. To avoid re-compilation in PT2, we need to change `learning_rate` to a Tensor (D65511904). We hence changed the type of `optimizer_args` from `OptimizerArgs` (with `learning_rate` as a float) to `OptimizerArgsPT2` (with `learning_rate_tensor`), as defined [here](https://www.internalfb.com/code/fbsource/[cec693b485d5fb800984943f4ce2bc6a1ca1c52a]/fbcode/deeplearning/fbgemm/fbgemm_gpu/codegen/training/python/lookup_args.template?lines=111).
5. The JIT subtype check sees the attribute as type `OptimizerArgsPT2`, but `refined_slots` treats it as a `Tuple` (the underlying schema of `OptimizerArgsPT2`) and concludes that the types do not match.
{F1976090817}
See the [log of refined_slot debug](https://www.internalfb.com/phabricator/paste/view/P1756112218?lines=730-726)
6. Since the JIT type system sees them as different types, the assertion is triggered.
**Note**:
The JIT type system fails to infer the type correctly only when a *Tensor* is added to the class.
- Prior to landing D68055168, `optimizer_args` was of type `OptimizerArgs` (defined here), which contains only float, int, and bool fields. The JIT type system sees both sides as type `OptimizerArgs` and does not fail.
{F1976090829}
See [log of refined_slot debug](https://www.internalfb.com/intern/everpaste/?handle=GO2d1Bz063DvSSAEADUn74T4WyZPbsIXAAAB&phabricator_paste_number=1756112738)
- Adding any Tensor to the class `OptimizerArgs` causes the error above.
- Adding floats, ints, or bools to the class `OptimizerArgs` is fine.
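A hedged repro sketch of this pattern is shown below, with stand-in types rather than the real FBGEMM classes; whether the exact `refined_slots` assert fires depends on the PyTorch version and the scripting path.
```
from typing import NamedTuple

import torch


class OptimizerArgsFloatOnly(NamedTuple):
    # Only float/int/bool fields: scripting this as a module attribute is fine.
    eps: float
    weight_decay: float
    gradient_clipping: bool


class OptimizerArgsWithTensor(NamedTuple):
    # The Tensor field is what the refined_slots subtype check stumbles on.
    eps: float
    learning_rate_tensor: torch.Tensor


class ToyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.optimizer_args = OptimizerArgsFloatOnly(1e-8, 0.0, False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.optimizer_args.eps


torch.jit.script(ToyModule())  # scripts cleanly
# Swapping the attribute for OptimizerArgsWithTensor is the pattern that,
# per the analysis above, can trip the INTERNAL ASSERT in class_type.cpp
# during class-type refinement.
```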
-------
__**Solution**__
We discussed and tested several workarounds and went with the last one, which satisfies both torch JIT script and PT2 compile.
1) Keep `optimizer_args` as `OptimizerArgs`, i.e., `learning_rate` remains a float, and create a `learning_rate_tensor` in the invokers before passing it to `lookup_function`.
❌ This doesn't work with PT2 because it causes re-compilation.
2) Keep both the v1 and v2 interfaces.
❌ This doesn't work with torch JIT script.
3) Remove `learning_rate` from `optimizer_args` and make `learning_rate_tensor` a module attribute.
✅ This works with both torch.jit.script and PT2 torch.compile.
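A minimal sketch of the chosen layout with stand-in names (the real module is `SplitTableBatchedEmbeddingBagsCodegen` and the real args type is `OptimizerArgsPT2`): `learning_rate` is dropped from the args NamedTuple and the tensor lives directly on the module, so scripting and compilation see consistent types.
```
from typing import NamedTuple

import torch


class OptimizerArgsNoLR(NamedTuple):
    # Only float/int/bool fields -- no Tensor inside the NamedTuple.
    eps: float
    weight_decay: float


class ToyTBEModule(torch.nn.Module):
    def __init__(self, learning_rate: float = 0.01) -> None:
        super().__init__()
        self.weights = torch.nn.Parameter(torch.randn(16, 4))
        self.optimizer_args = OptimizerArgsNoLR(eps=1e-8, weight_decay=0.0)
        # learning_rate_tensor as a plain module attribute, always on CPU.
        self.learning_rate_tensor = torch.tensor(learning_rate, dtype=torch.float32)

    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        return self.weights[indices]


scripted = torch.jit.script(ToyTBEModule())  # works under torch.jit.script
compiled = torch.compile(ToyTBEModule())     # and under PT2 torch.compile
```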
Notice:
Many users access `learning_rate` directly through `optimizer_args.learning_rate`, so changes to `optimizer_args` may break backward compatibility for those cases.
We already landed D71444136 to address any future compatibility issues.
Note: `learning_rate_tensor` is always created and stays on CPU, so there should be no host-device synchronization.
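A brief hedged sketch of why this holds (the variable name mirrors the module attribute; nothing here is FBGEMM-specific): a 0-dim CPU tensor can be updated in place and read back as a Python float without touching the GPU, which is also the condition the `TORCH_CHECK` in the diff excerpt further below enforces.
```
import torch

# learning_rate_tensor is created once on CPU and stays there.
learning_rate_tensor = torch.tensor(0.01, dtype=torch.float32)

learning_rate_tensor.fill_(0.005)                 # in-place update, same tensor object
lr = learning_rate_tensor.item()                  # CPU read: no host-device sync
assert learning_rate_tensor.device.type == "cpu"  # the condition the TORCH_CHECK enforces
```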
Reviewed By: q10, nautsimon
Differential Revision: D71010630
{%- if"learning_rate" in args.split_kernel_arg_names %}
605
605
// convert `learning rate` to float since `learning rate` is float in kernels
606
+
TORCH_CHECK(learning_rate_tensor.is_cpu(), "learning_rate_tensor tensor needs to be on CPU. Ensure learning_rate_tensor is on CPU or contact FBGEMM team if you get this error.")