
Unifying TBE API using List (Frontend) - reland #3821

Closed
spcyppt wants to merge 1 commit from the export-D71010630 branch

Conversation

spcyppt
Contributor

@spcyppt spcyppt commented Mar 14, 2025

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/904

Re-land D68055168

Differential Revision: D71010630

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D71010630


netlify bot commented Mar 14, 2025

Deploy Preview for pytorch-fbgemm-docs ready!

| Name | Link |
| --- | --- |
| 🔨 Latest commit | 1c9cde3 |
| 🔍 Latest deploy log | https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/67edb7c0573f390008d97d7e |
| 😎 Deploy Preview | https://deploy-preview-3821--pytorch-fbgemm-docs.netlify.app |

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D71010630

spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request Mar 20, 2025
Summary:
X-link: facebookresearch/FBGEMM#904


Re-land of D68055168 with fixes.

**TLDR;**
Issue:
- Keeping `learning_rate` as a float causes PT2 recompilation.
- Making it a `learning_rate_tensor` inside `optimizer_args` causes a torch JIT script error (the issue with the previous land that caused S498612).

This diff:
- removes `learning_rate` from `optimizer_args`
- replaces it with `learning_rate_tensor` as a module attribute (i.e., `self.learning_rate_tensor`)

This solution avoids re-compilation in PT2 and is compatible with torch JIT script.

**Usage**

To get learning rate:
```
emb_op = SplitTableBatchedEmbeddingBagsCodegen(....)
lr = emb_op.get_learning_rate()
```
To set learning rate:
```
emb_op.set_learning_rate(lr)
```
See more detail below.
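For background on why the tensor form matters for PT2, here is a minimal sketch (illustrative only, not FBGEMM code; the function names are made up) of how `torch.compile` typically specializes on a Python float argument but treats a 0-dim tensor as a runtime input, so updating the learning rate in place does not force recompilation:
```
import torch

# Toy stand-in for an optimizer update; the real TBE lookup/backward kernels differ.
def apply_lr_float(grad: torch.Tensor, lr: float) -> torch.Tensor:
    # A Python float is typically specialized as a compile-time constant,
    # so every new lr value triggers a recompilation.
    return grad * lr

def apply_lr_tensor(grad: torch.Tensor, lr: torch.Tensor) -> torch.Tensor:
    # A 0-dim tensor is a runtime input; updating its value in place
    # reuses the same compiled graph.
    return grad * lr

compiled_float = torch.compile(apply_lr_float)
compiled_tensor = torch.compile(apply_lr_tensor)

grad = torch.randn(4)
lr_tensor = torch.tensor(0.1)  # kept on CPU, as in this diff

compiled_float(grad, 0.1)
compiled_float(grad, 0.05)        # new constant -> expect a recompile

compiled_tensor(grad, lr_tensor)
lr_tensor.fill_(0.05)             # in-place update, same graph
compiled_tensor(grad, lr_tensor)  # no recompile expected
```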

------
**Context:**
The previous landing of D68055168 caused S498612 with the following error:
```
RuntimeError: refined_slots[i]->isSubtypeOf(*attributes_[i].getType()) INTERNAL ASSERT FAILED at "fbcode/caffe2/aten/src/ATen/core/class_type.cpp":415, please report a bug to PyTorch.
```
Assertion failure point: https://fburl.com/code/q1xg0w8m 

**Root cause:**
This is a bug in PyTorch in which the JIT type system cannot infer the type correctly. The error occurs only under the following combination:
- only in JIT script
- only for module attributes

**Explanation**:
1. Module is scripted
e.g., [module_factory/model_materializer_full_sync.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/model_materializer_full_sync.py?lines=1445-1446%2C1622-1623), [module_factory/sync_sgd/train_module.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/sync_sgd/train_module.py?lines=1615)
2. JIT refines class type of `type __torch__.fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen` where module attributes are added to `refined_slots`
3. For each attribute, it asserts if the `refined_slot` attribute is a JIT subtype of the attribute ([here](https://www.internalfb.com/code/fbsource/[b5505c1f8e2a7945af8860fd29f89e220ffec919]/fbcode/caffe2/aten/src/ATen/core/class_type.cpp?lines=415) and [here](https://www.internalfb.com/code/fbsource/[d693ef764b7d8eb2e4205dd85a61a33d6fd7c977]/fbcode/caffe2/aten/src/ATen/core/jit_type_base.h?lines=396-409))
4. `SplitTableBatchedEmbeddingBagsCodegen` module has an attribute called `optimizer_args`. To avoid re-compilation in PT2, we need to change `learning_rate` to a Tensor (D65511904). We hence change the type of `optimizer_args` from `OptimizerArgs` (with `learning_rate` as float) to `OptimizerArgsPT2` (with `learning_rate_tensor`) as defined [here](https://www.internalfb.com/code/fbsource/[cec693b485d5fb800984943f4ce2bc6a1ca1c52a]/fbcode/deeplearning/fbgemm/fbgemm_gpu/codegen/training/python/lookup_args.template?lines=111).
5. The JIT subtype check sees the attribute as type `OptimizerArgsPT2`, but `refined_slots` records the type as `Tuple` (the actual schema of `OptimizerArgsPT2`) and concludes that they are not the same.
 {F1976090817} 
See [log of refined_slot debug ](https://www.internalfb.com/phabricator/paste/view/P1756112218?lines=730-726)
6. Since JIT type sees that they are of different types, the assertion is triggered.

**Note**:
The JIT type system fails to infer the type only when a *Tensor* is added to the class.
- Prior to landing D68055168, `optimizer_args` was of type `OptimizerArgs` (defined here), which contains only float, int, and bool fields. The JIT type system sees both sides as type `OptimizerArgs` and does not fail.
 {F1976090829} 
See [log of refined_slot debug](https://www.internalfb.com/intern/everpaste/?handle=GO2d1Bz063DvSSAEADUn74T4WyZPbsIXAAAB&phabricator_paste_number=1756112738)
- Adding any tensor to the class `OptimizerArgs` causes the error above.
- Adding floats, ints, or bools to the class `OptimizerArgs` is fine.
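To make the failing pattern concrete, here is a minimal sketch (illustrative names, not the real FBGEMM definitions; whether the exact `refined_slots` assert fires depends on the PyTorch version and scripting path):
```
from typing import NamedTuple

import torch

class OptimizerArgsFloatOnly(NamedTuple):
    # Only float fields: the pattern described above as working.
    learning_rate: float
    eps: float

class OptimizerArgsWithTensor(NamedTuple):
    # A Tensor field is added: the pattern described above as hitting
    # the refined_slots assert (version-dependent).
    learning_rate_tensor: torch.Tensor
    eps: float

class ToyModule(torch.nn.Module):
    def __init__(self, with_tensor: bool) -> None:
        super().__init__()
        if with_tensor:
            self.optimizer_args = OptimizerArgsWithTensor(torch.tensor(0.1), 1e-8)
        else:
            self.optimizer_args = OptimizerArgsFloatOnly(0.1, 1e-8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.optimizer_args.eps

torch.jit.script(ToyModule(with_tensor=False))  # float/int/bool-only attribute
torch.jit.script(ToyModule(with_tensor=True))   # Tensor-carrying attribute
```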
-------

__**Solution**__
We discussed and tested several workarounds and went with the last one, which satisfies both torch JIT script and PT2 compile.

1) Keep `optimizer_args` as `OptimizerArgs`, i.e., `learning_rate` remains a float, and create a `learning_rate_tensor` in the invokers before passing it to `lookup_function`.
❌ This doesn't work with PT2 because it causes re-compilation.
2) Keep both the v1 and v2 interfaces.
❌ This doesn't work with torch JIT script.
3) Remove `learning_rate` from `optimizer_args` and make `learning_rate_tensor` a module attribute.
✅ This works with `torch.jit.script` and PT2 `torch.compile`.

Notice:
Many users access `learning_rate` directly through `optimizer_args.learning_rate`, so changes to `optimizer_args` may break backward compatibility for those cases.

We will land D71444136, which addresses this and any other compatibility issue, first.

Note: `learning_rate_tensor` is always created and stays on CPU, so there should be no host-device synchronization.

Differential Revision: D71010630
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D71010630

spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request Mar 20, 2025
Summary:
X-link: facebookresearch/FBGEMM#904


Re-land of D68055168 with fixes.

**TLDR;**
Issue:
- Keeping `learning_rate` as a float causes PT2 recompilation.
- Making it a `learning_rate_tensor` inside `optimizer_args` causes a torch JIT script error (the issue with the previous land that caused S498612).

This diff:
- removes `learning_rate` from `optimizer_args`
- replaces it with `learning_rate_tensor` as a module attribute (i.e., `self.learning_rate_tensor`)

This solution avoids re-compilation in PT2 and is compatible with torch JIT script.

**Usage**

To get learning rate:
```
emb_op = SplitTableBatchedEmbeddingBagsCodegen(....)
lr = emb_op.get_learning_rate()
```
To set learning rate:
```
emb_op.set_learning_rate(lr)
```
See more detail below.

------
**Context:**
The previous landing of D68055168 caused S498612 with the following error:
```
RuntimeError: refined_slots[i]->isSubtypeOf(*attributes_[i].getType()) INTERNAL ASSERT FAILED at "fbcode/caffe2/aten/src/ATen/core/class_type.cpp":415, please report a bug to PyTorch.
```
Assertion failure point: https://fburl.com/code/q1xg0w8m 

**Root cause:**
This is a bug in PyTorch in which the JIT type system cannot infer the type correctly. The error occurs only under the following combination:
- only in JIT script
- only for module attributes

**Explanation**:
1. Module is scripted
e.g., [module_factory/model_materializer_full_sync.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/model_materializer_full_sync.py?lines=1445-1446%2C1622-1623), [module_factory/sync_sgd/train_module.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/sync_sgd/train_module.py?lines=1615)
2. JIT refines class type of `type __torch__.fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen` where module attributes are added to `refined_slots`
3. For each attribute, it asserts if the `refined_slot` attribute is a JIT subtype of the attribute ([here](https://www.internalfb.com/code/fbsource/[b5505c1f8e2a7945af8860fd29f89e220ffec919]/fbcode/caffe2/aten/src/ATen/core/class_type.cpp?lines=415) and [here](https://www.internalfb.com/code/fbsource/[d693ef764b7d8eb2e4205dd85a61a33d6fd7c977]/fbcode/caffe2/aten/src/ATen/core/jit_type_base.h?lines=396-409))
4. `SplitTableBatchedEmbeddingBagsCodegen` module has an attribute called `optimizer_args`. To avoid re-compilation in PT2, we need to change `learning_rate` to a Tensor (D65511904). We hence change the type of `optimizer_args` from `OptimizerArgs` (with `learning_rate` as float) to `OptimizerArgsPT2` (with `learning_rate_tensor`) as defined [here](https://www.internalfb.com/code/fbsource/[cec693b485d5fb800984943f4ce2bc6a1ca1c52a]/fbcode/deeplearning/fbgemm/fbgemm_gpu/codegen/training/python/lookup_args.template?lines=111).
5. The JIT subtype check sees the attribute as type `OptimizerArgsPT2`, but `refined_slots` records the type as `Tuple` (the actual schema of `OptimizerArgsPT2`) and concludes that they are not the same.
 {F1976090817} 
See [log of refined_slot debug ](https://www.internalfb.com/phabricator/paste/view/P1756112218?lines=730-726)
6. Since JIT type sees that they are of different types, the assertion is triggered.

**Note**:
The JIT type system fails to infer the type only when a *Tensor* is added to the class.
- Prior to landing D68055168, `optimizer_args` was of type `OptimizerArgs` (defined here), which contains only float, int, and bool fields. The JIT type system sees both sides as type `OptimizerArgs` and does not fail.
 {F1976090829} 
See [log of refined_slot debug](https://www.internalfb.com/intern/everpaste/?handle=GO2d1Bz063DvSSAEADUn74T4WyZPbsIXAAAB&phabricator_paste_number=1756112738)
- Adding any tensor to the class `OptimizerArgs` causes the error above.
- Adding floats, ints, or bools to the class `OptimizerArgs` is fine.
-------

__**Solution**__
We discussed and tested several workarounds and went with the last one, which satisfies both torch JIT script and PT2 compile.

1) Keep `optimizer_args` as `OptimizerArgs`, i.e., `learning_rate` remains a float, and create a `learning_rate_tensor` in the invokers before passing it to `lookup_function`.
❌ This doesn't work with PT2 because it causes re-compilation.
2) Keep both the v1 and v2 interfaces.
❌ This doesn't work with torch JIT script.
3) Remove `learning_rate` from `optimizer_args` and make `learning_rate_tensor` a module attribute.
✅ This works with `torch.jit.script` and PT2 `torch.compile`.

Notice:
Many users access `learning_rate` directly through `optimizer_args.learning_rate`, so changes to `optimizer_args` may break backward compatibility for those cases.

We will land D71444136, which addresses this and any other compatibility issue, first.

Note: `learning_rate_tensor` is always created and stays on CPU, so there should be no host-device synchronization.
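Putting the pieces together, here is a minimal sketch of the module-attribute approach (illustrative only; the real `SplitTableBatchedEmbeddingBagsCodegen` is far larger, and this toy module mirrors just the learning-rate handling):
```
import torch

class ToyEmbeddingModule(torch.nn.Module):
    def __init__(self, learning_rate: float = 0.01) -> None:
        super().__init__()
        # The learning rate lives on the module as a 0-dim CPU float32 tensor,
        # not inside the optimizer-args NamedTuple.
        self.learning_rate_tensor = torch.tensor(learning_rate, dtype=torch.float32)

    @torch.jit.export
    def get_learning_rate(self) -> float:
        return float(self.learning_rate_tensor.item())

    @torch.jit.export
    def set_learning_rate(self, lr: float) -> None:
        # In-place update keeps the same tensor object, so scripted modules
        # and compiled graphs keep referring to the same storage.
        self.learning_rate_tensor.fill_(lr)

    def forward(self, grad: torch.Tensor) -> torch.Tensor:
        return grad * self.learning_rate_tensor

scripted = torch.jit.script(ToyEmbeddingModule())
scripted.set_learning_rate(0.5)
print(scripted.get_learning_rate())  # 0.5

compiled = torch.compile(ToyEmbeddingModule())
compiled(torch.randn(4))
```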

Reviewed By: q10

Differential Revision: D71010630
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D71010630

spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request Mar 28, 2025
Summary:
X-link: facebookresearch/FBGEMM#904

**TLDR;**
- D68055168 was reverted due to S498612 (backout diff  D70996903)
- This diff re-lands D68055168 to enable the [new TBE front-end API](https://fb.workplace.com/groups/fbgemmusers/permalink/9662778130469548/)
- The main changes in this new unified API (API v2) are:
      1) Pack some arguments into lists due to the limit on the number of arguments.
      2) Enable `learning_rate_tensor` to avoid PT2 recompilation
         (see D65511904).
      3) Enable constant `info_B_num_bits` and `info_B_mask` to address
         the insufficient-bit issue (see D69387123).
      See details in D68055168.
- This diff addresses the error that caused S498612
```
  Issue:
   - Keeping learning_rate as a float causes PT2 recompilation
     (see D65511904).
   - Introducing learning_rate_tensor causes a torch JIT script error
     (the issue with the previous land that caused S498612).
  Solution:
   - Remove learning_rate from optimizer_args.
   - Replace it with learning_rate_tensor as a module attribute
     (i.e., self.learning_rate_tensor).

Hence, this solution satisfies both PT2 and torch JIT script.
```
See more detail on the root cause of S498612 and the proposed workaround below.
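As an aside on change 1) above, here is a hypothetical sketch of the packing idea; the parameter names (`momentum1`, `optim_tensors`, `aux_tensors`, etc.) are made up for illustration and are not the real unified TBE signature:
```
from typing import List

import torch

# v1-style interface: every auxiliary tensor is its own argument, so the
# operator schema grows (and can hit the argument-count limit) every time a
# new optimizer state or feature is added.
def tbe_lookup_v1(
    weights: torch.Tensor,
    indices: torch.Tensor,
    offsets: torch.Tensor,
    momentum1: torch.Tensor,
    momentum2: torch.Tensor,
    row_counter: torch.Tensor,
) -> torch.Tensor:
    return weights[indices]

# v2-style unified interface: related tensors are packed into lists, so the
# schema stays fixed and new entries only extend an existing list.
def tbe_lookup_v2(
    weights: torch.Tensor,
    indices: torch.Tensor,
    offsets: torch.Tensor,
    optim_tensors: List[torch.Tensor],
    aux_tensors: List[torch.Tensor],
) -> torch.Tensor:
    return weights[indices]

weights = torch.randn(10, 4)
indices = torch.tensor([0, 3, 7])
offsets = torch.tensor([0, 3])
out = tbe_lookup_v2(weights, indices, offsets, [torch.zeros(10, 4)], [])
```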

------
__**Context:**__
The previous landing of D68055168 caused S498612 with the following error:
```
RuntimeError: refined_slots[i]->isSubtypeOf(*attributes_[i].getType()) INTERNAL ASSERT FAILED at "fbcode/caffe2/aten/src/ATen/core/class_type.cpp":415, please report a bug to PyTorch.
```
Assertion failure point: https://fburl.com/code/q1xg0w8m 

__**Root cause:**__
This is a bug in PyTorch in which the JIT type system cannot infer the type correctly. The error occurs only under the following combination:
- only in JIT script
- only for module attributes

**Explanation**:
1. Module is scripted
e.g., [module_factory/model_materializer_full_sync.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/model_materializer_full_sync.py?lines=1445-1446%2C1622-1623), [module_factory/sync_sgd/train_module.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/sync_sgd/train_module.py?lines=1615)
2. JIT refines class type of `type __torch__.fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen` where module attributes are added to `refined_slots`
3. For each attribute, it asserts if the `refined_slot` attribute is a JIT subtype of the attribute ([here](https://www.internalfb.com/code/fbsource/[b5505c1f8e2a7945af8860fd29f89e220ffec919]/fbcode/caffe2/aten/src/ATen/core/class_type.cpp?lines=415) and [here](https://www.internalfb.com/code/fbsource/[d693ef764b7d8eb2e4205dd85a61a33d6fd7c977]/fbcode/caffe2/aten/src/ATen/core/jit_type_base.h?lines=396-409))
4. `SplitTableBatchedEmbeddingBagsCodegen` module has an attribute called `optimizer_args`. To avoid re-compilation in PT2, we need to change `learning_rate` to a Tensor (D65511904). We hence change the type of `optimizer_args` from `OptimizerArgs` (with `learning_rate` as float) to `OptimizerArgsPT2` (with `learning_rate_tensor`) as defined [here](https://www.internalfb.com/code/fbsource/[cec693b485d5fb800984943f4ce2bc6a1ca1c52a]/fbcode/deeplearning/fbgemm/fbgemm_gpu/codegen/training/python/lookup_args.template?lines=111).
5. The JIT subtype check sees the attribute as type `OptimizerArgsPT2`, but `refined_slots` records the type as `Tuple` (the actual schema of `OptimizerArgsPT2`) and concludes that they are not the same.
 {F1976090817} 
See [log of refined_slot debug ](https://www.internalfb.com/phabricator/paste/view/P1756112218?lines=730-726)
6. Since JIT type sees that they are of different types, the assertion is triggered.

**Note**:
The JIT type system fails to infer the type only when a *Tensor* is added to the class.
- Prior to landing D68055168, `optimizer_args` was of type `OptimizerArgs` (defined here), which contains only float, int, and bool fields. The JIT type system sees both sides as type `OptimizerArgs` and does not fail.
 {F1976090829} 
See [log of refined_slot debug](https://www.internalfb.com/intern/everpaste/?handle=GO2d1Bz063DvSSAEADUn74T4WyZPbsIXAAAB&phabricator_paste_number=1756112738)
- Adding any tensor to the class `OptimizerArgs` causes the error above.
- Adding floats, ints, or bools to the class `OptimizerArgs` is fine.
-------

__**Solution**__
We discussed and tested several workarounds and went with the last one, which satisfies both torch JIT script and PT2 compile.

1) Keep `optimizer_args` as `OptimizerArgs`, i.e., `learning_rate` remains a float, and create a `learning_rate_tensor` in the invokers before passing it to `lookup_function`.
❌ This doesn't work with PT2 because it causes re-compilation.
2) Keep both the v1 and v2 interfaces.
❌ This doesn't work with torch JIT script.
3) Remove `learning_rate` from `optimizer_args` and make `learning_rate_tensor` a module attribute.
✅ This works with `torch.jit.script` and PT2 `torch.compile`.

Notice:
Many users access `learning_rate` directly through `optimizer_args.learning_rate`, so changes to `optimizer_args` may break backward compatibility for those cases.

We already landed D71444136 to address any future compatibility issue.

Note: `learning_rate_tensor` is always created and stays on CPU, so there should be no host-device synchronization.

Reviewed By: q10, nautsimon

Differential Revision: D71010630
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D71010630

@spcyppt spcyppt force-pushed the export-D71010630 branch from 5f39f55 to 0d3c0dd Compare April 2, 2025 22:04
spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request Apr 2, 2025
Summary:
X-link: facebookresearch/FBGEMM#904

**TLDR;**
- D68055168 was reverted due to S498612 (backout diff  D70996903)
- This diff re-lands D68055168 to enable the [new TBE front-end API](https://fb.workplace.com/groups/fbgemmusers/permalink/9662778130469548/)
- The main changes in this new unified API (API v2) are:
      1) Pack some arguments into lists due to the limit on the number of arguments.
      2) Enable `learning_rate_tensor` to avoid PT2 recompilation
         (see D65511904).
      3) Enable constant `info_B_num_bits` and `info_B_mask` to address
         the insufficient-bit issue (see D69387123).
      See details in D68055168.
- This diff addresses the error that caused S498612
```
  Issue:
   - Keeping learning_rate as a float causes PT2 recompilation
     (see D65511904).
   - Introducing learning_rate_tensor causes a torch JIT script error
     (the issue with the previous land that caused S498612).
  Solution:
   - Remove learning_rate from optimizer_args.
   - Replace it with learning_rate_tensor as a module attribute
     (i.e., self.learning_rate_tensor).

Hence, this solution satisfies both PT2 and torch JIT script.
```
See more detail on the root cause of S498612 and the proposed workaround below.

------
__**Context:**__
The previous landing of D68055168 caused S498612 with the following error:
```
RuntimeError: refined_slots[i]->isSubtypeOf(*attributes_[i].getType()) INTERNAL ASSERT FAILED at "fbcode/caffe2/aten/src/ATen/core/class_type.cpp":415, please report a bug to PyTorch.
```
Assertion failure point: https://fburl.com/code/q1xg0w8m 

__**Root cause:**__
This is a bug in PyTorch in which the JIT type system cannot infer the type correctly. The error occurs only under the following combination:
- only in JIT script
- only for module attributes

**Explanation**:
1. Module is scripted
e.g., [module_factory/model_materializer_full_sync.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/model_materializer_full_sync.py?lines=1445-1446%2C1622-1623), [module_factory/sync_sgd/train_module.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/sync_sgd/train_module.py?lines=1615)
2. JIT refines class type of `type __torch__.fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen` where module attributes are added to `refined_slots`
3. For each attribute, it asserts if the `refined_slot` attribute is a JIT subtype of the attribute ([here](https://www.internalfb.com/code/fbsource/[b5505c1f8e2a7945af8860fd29f89e220ffec919]/fbcode/caffe2/aten/src/ATen/core/class_type.cpp?lines=415) and [here](https://www.internalfb.com/code/fbsource/[d693ef764b7d8eb2e4205dd85a61a33d6fd7c977]/fbcode/caffe2/aten/src/ATen/core/jit_type_base.h?lines=396-409))
4. `SplitTableBatchedEmbeddingBagsCodegen` module has an attribute called `optimizer_args`. To avoid re-compilation in PT2, we need to change `learning_rate` to a Tensor (D65511904). We hence change the type of `optimizer_args` from `OptimizerArgs` (with `learning_rate` as float) to `OptimizerArgsPT2` (with `learning_rate_tensor`) as defined [here](https://www.internalfb.com/code/fbsource/[cec693b485d5fb800984943f4ce2bc6a1ca1c52a]/fbcode/deeplearning/fbgemm/fbgemm_gpu/codegen/training/python/lookup_args.template?lines=111).
5. The JIT subtype check sees the attribute as type `OptimizerArgsPT2`, but `refined_slots` records the type as `Tuple` (the actual schema of `OptimizerArgsPT2`) and concludes that they are not the same.
 {F1976090817} 
See [log of refined_slot debug ](https://www.internalfb.com/phabricator/paste/view/P1756112218?lines=730-726)
6. Since JIT type sees that they are of different types, the assertion is triggered.

**Note**:
The JIT type system fails to infer the type only when a *Tensor* is added to the class.
- Prior to landing D68055168, `optimizer_args` was of type `OptimizerArgs` (defined here), which contains only float, int, and bool fields. The JIT type system sees both sides as type `OptimizerArgs` and does not fail.
 {F1976090829} 
See [log of refined_slot debug](https://www.internalfb.com/intern/everpaste/?handle=GO2d1Bz063DvSSAEADUn74T4WyZPbsIXAAAB&phabricator_paste_number=1756112738)
- Adding any tensor to the class `OptimizerArgs` causes the error above.
- Adding floats, ints, or bools to the class `OptimizerArgs` is fine.
-------

__**Solution**__
We discussed and tested several workarounds and went with the last one, which satisfies both torch JIT script and PT2 compile.

1) Keep `optimizer_args` as `OptimizerArgs`, i.e., `learning_rate` remains a float, and create a `learning_rate_tensor` in the invokers before passing it to `lookup_function`.
❌ This doesn't work with PT2 because it causes re-compilation.
2) Keep both the v1 and v2 interfaces.
❌ This doesn't work with torch JIT script.
3) Remove `learning_rate` from `optimizer_args` and make `learning_rate_tensor` a module attribute.
✅ This works with `torch.jit.script` and PT2 `torch.compile`.

Notice:
Many users access `learning_rate` directly through `optimizer_args.learning_rate`, so changes to `optimizer_args` may break backward compatibility for those cases.

We already landed D71444136 to address any future compatibility issue.

Note: `learning_rate_tensor` is always created and stays on CPU, so there should be no host-device synchronization.
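To illustrate the note above, a small sketch (not FBGEMM code) of the synchronization point being avoided: reading a device-resident scalar forces a host-device sync, while the CPU-resident learning-rate tensor does not:
```
import torch

lr_cpu = torch.tensor(0.5)  # how this diff keeps the learning rate

def read_lr_cpu() -> float:
    # CPU tensor: .item() is a plain host memory read, no stream sync.
    return float(lr_cpu.item())

if torch.cuda.is_available():
    lr_gpu = torch.tensor(0.5, device="cuda")

    def read_lr_gpu() -> float:
        # Device tensor: .item() must wait for the CUDA stream and copy the
        # value back to the host, i.e. a host-device synchronization.
        return float(lr_gpu.item())

print(read_lr_cpu())
```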

Reviewed By: q10, nautsimon

Differential Revision: D71010630
spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request Apr 2, 2025
Summary:
X-link: facebookresearch/FBGEMM#904

**TLDR;**
- D68055168 was reverted due to S498612 (backout diff  D70996903)
- This diff re-lands D68055168 to enable the [new TBE front-end API](https://fb.workplace.com/groups/fbgemmusers/permalink/9662778130469548/)
- The main changes in this new unified API (API v2) are:
      1) Pack some arguments into lists due to the limit on the number of arguments.
      2) Enable `learning_rate_tensor` to avoid PT2 recompilation
         (see D65511904).
      3) Enable constant `info_B_num_bits` and `info_B_mask` to address
         the insufficient-bit issue (see D69387123).
      See details in D68055168.
- This diff addresses the error that caused S498612
```
  Issue:
   - Keeping learning_rate as a float causes PT2 recompilation
     (see D65511904).
   - Introducing learning_rate_tensor causes a torch JIT script error
     (the issue with the previous land that caused S498612).
  Solution:
   - Remove learning_rate from optimizer_args.
   - Replace it with learning_rate_tensor as a module attribute
     (i.e., self.learning_rate_tensor).

Hence, this solution satisfies both PT2 and torch JIT script.
```
See more detail on the root cause of S498612 and the proposed workaround below.

------
__**Context:**__
The previous landing of D68055168 caused S498612 with the following error:
```
RuntimeError: refined_slots[i]->isSubtypeOf(*attributes_[i].getType()) INTERNAL ASSERT FAILED at "fbcode/caffe2/aten/src/ATen/core/class_type.cpp":415, please report a bug to PyTorch.
```
Assertion failure point: https://fburl.com/code/q1xg0w8m 

__**Root cause:**__
This is a bug in PyTorch in which the JIT type system cannot infer the type correctly. The error occurs only under the following combination:
- only in JIT script
- only for module attributes

**Explanation**:
1. Module is scripted
e.g., [module_factory/model_materializer_full_sync.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/model_materializer_full_sync.py?lines=1445-1446%2C1622-1623), [module_factory/sync_sgd/train_module.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/sync_sgd/train_module.py?lines=1615)
2. JIT refines class type of `type __torch__.fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen` where module attributes are added to `refined_slots`
3. For each attribute, it asserts if the `refined_slot` attribute is a JIT subtype of the attribute ([here](https://www.internalfb.com/code/fbsource/[b5505c1f8e2a7945af8860fd29f89e220ffec919]/fbcode/caffe2/aten/src/ATen/core/class_type.cpp?lines=415) and [here](https://www.internalfb.com/code/fbsource/[d693ef764b7d8eb2e4205dd85a61a33d6fd7c977]/fbcode/caffe2/aten/src/ATen/core/jit_type_base.h?lines=396-409))
4. `SplitTableBatchedEmbeddingBagsCodegen` module has an attribute called `optimizer_args`. To avoid re-compilation in PT2, we need to change `learning_rate` to a Tensor (D65511904). We hence change the type of `optimizer_args` from `OptimizerArgs` (with `learning_rate` as float) to `OptimizerArgsPT2` (with `learning_rate_tensor`) as defined [here](https://www.internalfb.com/code/fbsource/[cec693b485d5fb800984943f4ce2bc6a1ca1c52a]/fbcode/deeplearning/fbgemm/fbgemm_gpu/codegen/training/python/lookup_args.template?lines=111).
5. The JIT subtype check sees the attribute as type `OptimizerArgsPT2`, but `refined_slots` records the type as `Tuple` (the actual schema of `OptimizerArgsPT2`) and concludes that they are not the same.
 {F1976090817} 
See [log of refined_slot debug ](https://www.internalfb.com/phabricator/paste/view/P1756112218?lines=730-726)
6. Since JIT type sees that they are of different types, the assertion is triggered.

**Note**:
The JIT type system fails to infer the type only when a *Tensor* is added to the class.
- Prior to landing D68055168, `optimizer_args` was of type `OptimizerArgs` (defined here), which contains only float, int, and bool fields. The JIT type system sees both sides as type `OptimizerArgs` and does not fail.
 {F1976090829} 
See [log of refined_slot debug](https://www.internalfb.com/intern/everpaste/?handle=GO2d1Bz063DvSSAEADUn74T4WyZPbsIXAAAB&phabricator_paste_number=1756112738)
- Adding any tensor to the class `OptimizerArgs` causes the error above.
- Adding floats, ints, or bools to the class `OptimizerArgs` is fine.
-------

__**Solution**__
We discussed and tested several workarounds and went with the last one, which satisfies both torch JIT script and PT2 compile.

1) Keep `optimizer_args` as `OptimizerArgs`, i.e., `learning_rate` remains a float, and create a `learning_rate_tensor` in the invokers before passing it to `lookup_function`.
❌ This doesn't work with PT2 because it causes re-compilation.
2) Keep both the v1 and v2 interfaces.
❌ This doesn't work with torch JIT script.
3) Remove `learning_rate` from `optimizer_args` and make `learning_rate_tensor` a module attribute.
✅ This works with `torch.jit.script` and PT2 `torch.compile`.

Notice:
Many users access `learning_rate` directly through `optimizer_args.learning_rate`, so changes to `optimizer_args` may break backward compatibility for those cases.

We already landed D71444136 to address any future compatibility issue.
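Purely as an illustration of the kind of compatibility shim this refers to (hypothetical; not necessarily what D71444136 actually does), old `optimizer_args.learning_rate`-style reads can be kept working with a thin view over the new tensor:
```
import torch

class OptimizerArgsView:
    # Hypothetical read-only compatibility view over the new tensor.
    def __init__(self, learning_rate_tensor: torch.Tensor) -> None:
        self._lr_tensor = learning_rate_tensor

    @property
    def learning_rate(self) -> float:
        # Old call sites that read optimizer_args.learning_rate keep working,
        # backed by the module-level tensor.
        return float(self._lr_tensor.item())

lr_tensor = torch.tensor(0.5)
legacy_view = OptimizerArgsView(lr_tensor)
print(legacy_view.learning_rate)  # 0.5
lr_tensor.fill_(0.25)
print(legacy_view.learning_rate)  # 0.25
```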

Note: `learning_rate_tensor` is always created and stays on CPU, so there should be no host-device synchronization.

Reviewed By: q10, nautsimon

Differential Revision: D71010630
@spcyppt spcyppt force-pushed the export-D71010630 branch from 0d3c0dd to a79e7e6 Compare April 2, 2025 22:05
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D71010630

spcyppt added a commit to spcyppt/FBGEMM that referenced this pull request Apr 2, 2025
Summary:
X-link: facebookresearch/FBGEMM#904
Pull Request resolved: pytorch#3821

**TLDR;**
- D68055168 was reverted due to S498612 (backout diff  D70996903)
- This diff re-lands D68055168 to enable the [new TBE front-end API](https://fb.workplace.com/groups/fbgemmusers/permalink/9662778130469548/)
- The main changes in this new unified API (API v2) are:
      1) Pack some arguments into lists due to the limit on the number of arguments.
      2) Enable `learning_rate_tensor` to avoid PT2 recompilation
         (see D65511904).
      3) Enable constant `info_B_num_bits` and `info_B_mask` to address
         the insufficient-bit issue (see D69387123).
      See details in D68055168.
- This diff addresses the error that caused S498612
```
  Issue:
   - Keeping learning_rate as a float causes PT2 recompilation
     (see D65511904).
   - Introducing learning_rate_tensor causes a torch JIT script error
     (the issue with the previous land that caused S498612).
  Solution:
   - Remove learning_rate from optimizer_args.
   - Replace it with learning_rate_tensor as a module attribute
     (i.e., self.learning_rate_tensor).

Hence, this solution satisfies both PT2 and torch JIT script.
```
See more detail on the root cause of S498612 and the proposed workaround below.

------
__**Context:**__
The previous landing of D68055168 caused S498612 with the following error:
```
RuntimeError: refined_slots[i]->isSubtypeOf(*attributes_[i].getType()) INTERNAL ASSERT FAILED at "fbcode/caffe2/aten/src/ATen/core/class_type.cpp":415, please report a bug to PyTorch.
```
Assertion failure point: https://fburl.com/code/q1xg0w8m

__**Root cause:**__
This is a bug in PyTorch in which the JIT type system cannot infer the type correctly. The error occurs only under the following combination:
- only in JIT script
- only for module attributes

**Explanation**:
1. Module is scripted
e.g., [module_factory/model_materializer_full_sync.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/model_materializer_full_sync.py?lines=1445-1446%2C1622-1623), [module_factory/sync_sgd/train_module.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/sync_sgd/train_module.py?lines=1615)
2. JIT refines class type of `type __torch__.fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen` where module attributes are added to `refined_slots`
3. For each attribute, it asserts if the `refined_slot` attribute is a JIT subtype of the attribute ([here](https://www.internalfb.com/code/fbsource/[b5505c1f8e2a7945af8860fd29f89e220ffec919]/fbcode/caffe2/aten/src/ATen/core/class_type.cpp?lines=415) and [here](https://www.internalfb.com/code/fbsource/[d693ef764b7d8eb2e4205dd85a61a33d6fd7c977]/fbcode/caffe2/aten/src/ATen/core/jit_type_base.h?lines=396-409))
4. `SplitTableBatchedEmbeddingBagsCodegen` module has an attribute called `optimizer_args`. To avoid re-compilation in PT2, we need to change `learning_rate` to a Tensor (D65511904). We hence change the type of `optimizer_args` from `OptimizerArgs` (with `learning_rate` as float) to `OptimizerArgsPT2` (with `learning_rate_tensor`) as defined [here](https://www.internalfb.com/code/fbsource/[cec693b485d5fb800984943f4ce2bc6a1ca1c52a]/fbcode/deeplearning/fbgemm/fbgemm_gpu/codegen/training/python/lookup_args.template?lines=111).
5. The JIT subtype check sees the attribute as type `OptimizerArgsPT2`, but `refined_slots` records the type as `Tuple` (the actual schema of `OptimizerArgsPT2`) and concludes that they are not the same.
 {F1976090817}
See [log of refined_slot debug ](https://www.internalfb.com/phabricator/paste/view/P1756112218?lines=730-726)
6. Since JIT type sees that they are of different types, the assertion is triggered.

**Note**:
The JIT type system fails to infer the type only when a *Tensor* is added to the class.
- Prior to landing D68055168, `optimizer_args` was of type `OptimizerArgs` (defined here), which contains only float, int, and bool fields. The JIT type system sees both sides as type `OptimizerArgs` and does not fail.
 {F1976090829}
See [log of refined_slot debug](https://www.internalfb.com/intern/everpaste/?handle=GO2d1Bz063DvSSAEADUn74T4WyZPbsIXAAAB&phabricator_paste_number=1756112738)
- Adding any tensor to the class `OptimizerArgs` causes the error above.
- Adding floats, ints, or bools to the class `OptimizerArgs` is fine.
-------

__**Solution**__
We discussed and tested several workarounds and went with the last one, which satisfies both torch JIT script and PT2 compile.

1) Keep `optimizer_args` as `OptimizerArgs`, i.e., `learning_rate` remains a float, and create a `learning_rate_tensor` in the invokers before passing it to `lookup_function`.
❌ This doesn't work with PT2 because it causes re-compilation.
2) Keep both the v1 and v2 interfaces.
❌ This doesn't work with torch JIT script.
3) Remove `learning_rate` from `optimizer_args` and make `learning_rate_tensor` a module attribute.
✅ This works with `torch.jit.script` and PT2 `torch.compile`.

Notice:
Many users access `learning_rate` directly through `optimizer_args.learning_rate`, so changes to `optimizer_args` may break backward compatibility for those cases.

We already landed D71444136 to address any future compatibility issue.

Note: `learning_rate_tensor` is always created and stays on CPU, so there should be no host-device synchronization.

Reviewed By: q10, nautsimon

Differential Revision: D71010630
@spcyppt spcyppt force-pushed the export-D71010630 branch from a79e7e6 to af842d3 Compare April 2, 2025 22:09
Summary:
X-link: facebookresearch/FBGEMM#904
Pull Request resolved: pytorch#3821

**TLDR;**
- D68055168 was reverted due to S498612 (backout diff  D70996903)
- This diff re-lands D68055168 to enable the [new TBE front-end API](https://fb.workplace.com/groups/fbgemmusers/permalink/9662778130469548/)
- The main changes in this new unified API (API v2) are:
      1) Pack some arguments into lists due to the limit on the number of arguments.
      2) Enable `learning_rate_tensor` to avoid PT2 recompilation
         (see D65511904).
      3) Enable constant `info_B_num_bits` and `info_B_mask` to address
         the insufficient-bit issue (see D69387123).
      See details in D68055168.
- This diff addresses the error that caused S498612
```
  Issue:
   - Keeping learning_rate as a float causes PT2 recompilation
     (see D65511904).
   - Introducing learning_rate_tensor causes a torch JIT script error
     (the issue with the previous land that caused S498612).
  Solution:
   - Remove learning_rate from optimizer_args.
   - Replace it with learning_rate_tensor as a module attribute
     (i.e., self.learning_rate_tensor).

Hence, this solution satisfies both PT2 and torch JIT script.
```
See more detail on the root cause of S498612 and the proposed workaround below.

------
__**Context:**__
The previous landing of D68055168 caused S498612 with the following error:
```
RuntimeError: refined_slots[i]->isSubtypeOf(*attributes_[i].getType()) INTERNAL ASSERT FAILED at "fbcode/caffe2/aten/src/ATen/core/class_type.cpp":415, please report a bug to PyTorch.
```
Assertion failure point: https://fburl.com/code/q1xg0w8m

__**Root cause:**__
This is a bug in PyTorch in which the JIT type system cannot infer the type correctly. The error occurs only under the following combination:
- only in JIT script
- only for module attributes

**Explanation**:
1. Module is scripted
e.g., [module_factory/model_materializer_full_sync.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/model_materializer_full_sync.py?lines=1445-1446%2C1622-1623), [module_factory/sync_sgd/train_module.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/sync_sgd/train_module.py?lines=1615)
2. JIT refines class type of `type __torch__.fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen` where module attributes are added to `refined_slots`
3. For each attribute, it asserts if the `refined_slot` attribute is a JIT subtype of the attribute ([here](https://www.internalfb.com/code/fbsource/[b5505c1f8e2a7945af8860fd29f89e220ffec919]/fbcode/caffe2/aten/src/ATen/core/class_type.cpp?lines=415) and [here](https://www.internalfb.com/code/fbsource/[d693ef764b7d8eb2e4205dd85a61a33d6fd7c977]/fbcode/caffe2/aten/src/ATen/core/jit_type_base.h?lines=396-409))
4. `SplitTableBatchedEmbeddingBagsCodegen` module has an attribute called `optimizer_args`. To avoid re-compilation in PT2, we need to change `learning_rate` to a Tensor (D65511904). We hence change the type of `optimizer_args` from `OptimizerArgs` (with `learning_rate` as float) to `OptimizerArgsPT2` (with `learning_rate_tensor`) as defined [here](https://www.internalfb.com/code/fbsource/[cec693b485d5fb800984943f4ce2bc6a1ca1c52a]/fbcode/deeplearning/fbgemm/fbgemm_gpu/codegen/training/python/lookup_args.template?lines=111).
5. The JIT subtype check sees the attribute as type `OptimizerArgsPT2`, but `refined_slots` records the type as `Tuple` (the actual schema of `OptimizerArgsPT2`) and concludes that they are not the same.
 {F1976090817}
See [log of refined_slot debug ](https://www.internalfb.com/phabricator/paste/view/P1756112218?lines=730-726)
6. Since JIT type sees that they are of different types, the assertion is triggered.

**Note**:
The JIT type system fails to infer the type only when a *Tensor* is added to the class.
- Prior to landing D68055168, `optimizer_args` was of type `OptimizerArgs` (defined here), which contains only float, int, and bool fields. The JIT type system sees both sides as type `OptimizerArgs` and does not fail.
 {F1976090829}
See [log of refined_slot debug](https://www.internalfb.com/intern/everpaste/?handle=GO2d1Bz063DvSSAEADUn74T4WyZPbsIXAAAB&phabricator_paste_number=1756112738)
- Adding any tensor to the class `OptimizerArgs` causes the error above.
- Adding floats, ints, or bools to the class `OptimizerArgs` is fine.
-------

__**Solution**__
We discussed and tested several workarounds and went with the last one, which satisfies both torch JIT script and PT2 compile.

1) Keep `optimizer_args` as `OptimizerArgs`, i.e., `learning_rate` remains a float, and create a `learning_rate_tensor` in the invokers before passing it to `lookup_function`.
❌ This doesn't work with PT2 because it causes re-compilation.
2) Keep both the v1 and v2 interfaces.
❌ This doesn't work with torch JIT script.
3) Remove `learning_rate` from `optimizer_args` and make `learning_rate_tensor` a module attribute.
✅ This works with `torch.jit.script` and PT2 `torch.compile`.

Notice:
Many users access `learning_rate` directly through `optimizer_args.learning_rate`, so changes to `optimizer_args` may break backward compatibility for those cases.

We already landed D71444136 to address any future compatibility issue.

Note: `learning_rate_tensor` is always created and stays on CPU, so there should be no host-device synchronization.

Reviewed By: q10, nautsimon

Differential Revision: D71010630
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D71010630

@facebook-github-bot
Contributor

This pull request has been merged in 1a9c70a.

q10 pushed a commit to q10/FBGEMM that referenced this pull request Apr 10, 2025
Summary:
Pull Request resolved: facebookresearch/FBGEMM#904
X-link: pytorch#3821

**TLDR;**
- D68055168 was reverted due to S498612 (backout diff  D70996903)
- This diff re-lands D68055168 to enable the [new TBE front-end API](https://fb.workplace.com/groups/fbgemmusers/permalink/9662778130469548/)
- The main changes in this new unified API (API v2) are:
      1) Pack some arguments into lists due to the limit on the number of arguments.
      2) Enable `learning_rate_tensor` to avoid PT2 recompilation
         (see D65511904).
      3) Enable constant `info_B_num_bits` and `info_B_mask` to address
         the insufficient-bit issue (see D69387123).
      See details in D68055168.
- This diff addresses the error that caused S498612
```
  Issue:
   - Keeping learning_rate as a float causes PT2 recompilation
     (see D65511904).
   - Introducing learning_rate_tensor causes a torch JIT script error
     (the issue with the previous land that caused S498612).
  Solution:
   - Remove learning_rate from optimizer_args.
   - Replace it with learning_rate_tensor as a module attribute
     (i.e., self.learning_rate_tensor).

Hence, this solution satisfies both PT2 and torch JIT script.
```
- `torch.ops.fbgemm.get_infos_metadata` now returns `info_B_num_bits` and  `info_B_mask` that are calculated based on **number of features `T`** (see detail in D69387123)
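For the `info_B_num_bits` / `info_B_mask` bullet above, a small illustrative sketch of the bit packing these constants support; the packing layout and the example value are assumptions for illustration, and the actual computation from the number of features `T` is in D69387123:
```
def make_info_b_mask(info_B_num_bits: int) -> int:
    # Mask that keeps only the low info_B_num_bits bits.
    return (1 << info_B_num_bits) - 1

info_B_num_bits = 26  # illustrative value, not the computed one
info_B_mask = make_info_b_mask(info_B_num_bits)

# Assumed layout for illustration: feature id t in the high bits,
# batch index b in the low bits.
t, b = 7, 12345
info = (t << info_B_num_bits) | b

assert info & info_B_mask == b
assert info >> info_B_num_bits == t
```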

See more detail on the root cause of S498612 and the proposed workaround below.

------
__**Context:**__
The previous landing of D68055168 caused S498612 with the following error:
```
RuntimeError: refined_slots[i]->isSubtypeOf(*attributes_[i].getType()) INTERNAL ASSERT FAILED at "fbcode/caffe2/aten/src/ATen/core/class_type.cpp":415, please report a bug to PyTorch.
```
Assertion failure point: https://fburl.com/code/q1xg0w8m

__**Root cause:**__
This is a bug in PyTorch in which the JIT type system cannot infer the type correctly. The error occurs only under the following combination:
- only in JIT script
- only for module attributes

**Explanation**:
1. Module is scripted
e.g., [module_factory/model_materializer_full_sync.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/model_materializer_full_sync.py?lines=1445-1446%2C1622-1623), [module_factory/sync_sgd/train_module.py](https://www.internalfb.com/code/fbsource/[ed9d26f3fda9f40054948e7c492bbbbf5dab987e]/fbcode/caffe2/torch/fb/module_factory/sync_sgd/train_module.py?lines=1615)
2. JIT refines class type of `type __torch__.fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen` where module attributes are added to `refined_slots`
3. For each attribute, it asserts if the `refined_slot` attribute is a JIT subtype of the attribute ([here](https://www.internalfb.com/code/fbsource/[b5505c1f8e2a7945af8860fd29f89e220ffec919]/fbcode/caffe2/aten/src/ATen/core/class_type.cpp?lines=415) and [here](https://www.internalfb.com/code/fbsource/[d693ef764b7d8eb2e4205dd85a61a33d6fd7c977]/fbcode/caffe2/aten/src/ATen/core/jit_type_base.h?lines=396-409))
4. `SplitTableBatchedEmbeddingBagsCodegen` module has an attribute called `optimizer_args`. To avoid re-compilation in PT2, we need to change `learning_rate` to a Tensor (D65511904). We hence change the type of `optimizer_args` from `OptimizerArgs` (with `learning_rate` as float) to `OptimizerArgsPT2` (with `learning_rate_tensor`) as defined [here](https://www.internalfb.com/code/fbsource/[cec693b485d5fb800984943f4ce2bc6a1ca1c52a]/fbcode/deeplearning/fbgemm/fbgemm_gpu/codegen/training/python/lookup_args.template?lines=111).
5. The JIT subtype check sees the attribute as type `OptimizerArgsPT2`, but `refined_slots` records the type as `Tuple` (the actual schema of `OptimizerArgsPT2`) and concludes that they are not the same.
 {F1976090817}
See [log of refined_slot debug ](https://www.internalfb.com/phabricator/paste/view/P1756112218?lines=730-726)
6. Since JIT type sees that they are of different types, the assertion is triggered.

**Note**:
The JIT type system fails to infer the type only when a *Tensor* is added to the class.
- Prior to landing D68055168, `optimizer_args` was of type `OptimizerArgs` (defined here), which contains only float, int, and bool fields. The JIT type system sees both sides as type `OptimizerArgs` and does not fail.
 {F1976090829}
See [log of refined_slot debug](https://www.internalfb.com/intern/everpaste/?handle=GO2d1Bz063DvSSAEADUn74T4WyZPbsIXAAAB&phabricator_paste_number=1756112738)
- Adding any tensor to the class `OptimizerArgs` causes the error above.
- Adding floats, ints, or bools to the class `OptimizerArgs` is fine.
-------

__**Solution**__
We discussed and tested several workarounds and went with the last one, which satisfies both torch JIT script and PT2 compile.

1) Keep `optimizer_args` as `OptimizerArgs`, i.e., `learning_rate` remains a float, and create a `learning_rate_tensor` in the invokers before passing it to `lookup_function`.
❌ This doesn't work with PT2 because it causes re-compilation.
2) Keep both the v1 and v2 interfaces.
❌ This doesn't work with torch JIT script.
3) Remove `learning_rate` from `optimizer_args` and make `learning_rate_tensor` a module attribute.
✅ This works with `torch.jit.script` and PT2 `torch.compile`.

Notice:
Many users access `learning_rate` directly through `optimizer_args.learning_rate`, so changes to `optimizer_args` may break backward compatibility for those cases.

We already landed D71444136 to address any future compatibility issue.

Note: `learning_rate_tensor` is always created and stays on CPU, so there should be no host-device synchronization.

Reviewed By: q10, sryap, nautsimon

Differential Revision: D71010630

fbshipit-source-id: 961275d56a5b48183c331b7ecb7b3d80a70b0b25