
Commit dfdbee8 (0 parents)

Initial commit

fbshipit-source-id: 8f9686235729bb0aa9e03e3dbf73f74e75932b3f

128 files changed (+28284, -0 lines)


.github/workflows/pre-commit.yaml

+20
@@ -0,0 +1,20 @@
name: pre-commit

on:
  push:
    branches: [master]
  pull_request:

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      - name: Setup Python
        uses: actions/setup-python@v2
        with:
          python-version: 3.8
          architecture: x64
      - name: Checkout Torchrec
        uses: actions/checkout@v2
      - name: Run pre-commit
        uses: pre-commit/[email protected]

.pre-commit-config.yaml

+16
@@ -0,0 +1,16 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.0.1
    hooks:
      - id: check-toml
      - id: check-yaml
        exclude: packaging/.*
      - id: end-of-file-fixer

  - repo: https://github.com/omnilib/ufmt
    rev: v1.3.0
    hooks:
      - id: ufmt
        additional_dependencies:
          - black == 21.9b0
          - usort == 0.6.4

pyproject.toml

+3
@@ -0,0 +1,3 @@
[tool.usort]

first_party_detection = false

setup.py

+9
@@ -0,0 +1,9 @@
#!/usr/bin/env python3

from setuptools import setup, find_packages

# Minimal setup configuration.
setup(
    name="torchrec",
    packages=find_packages(exclude=("*tests",)),
)

test_installation.py

+36
@@ -0,0 +1,36 @@
#!/usr/bin/env python3

import os

import torchx.specs as specs
from torchx.components.base import torch_dist_role
from torchx.specs.api import Resource


def test_installation() -> specs.AppDef:
    cwd = os.getcwd()
    entrypoint = os.path.join(cwd, "test_installation_main.py")

    user = os.environ.get("USER")
    image = f"/data/home/{user}"

    return specs.AppDef(
        name="test_installation",
        roles=[
            torch_dist_role(
                name="trainer",
                image=image,
                # AWS p4d instance (https://aws.amazon.com/ec2/instance-types/p4/).
                resource=Resource(
                    cpu=96,
                    gpu=8,
                    memMB=-1,
                ),
                num_replicas=1,
                entrypoint=entrypoint,
                nproc_per_node="1",
                rdzv_backend="c10d",
                args=[],
            ),
        ],
    )
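
For context, test_installation() above is a TorchX component: it returns an AppDef describing a single "trainer" role that launches test_installation_main.py on one replica sized like an AWS p4d host. Below is a minimal sketch of building and inspecting that spec locally; it is illustrative only (the example_user value and print calls are not part of the commit) and assumes torchx is installed and the file above is importable from the working directory.

# Illustrative only: build the AppDef defined in test_installation.py and inspect it.
import os

from test_installation import test_installation

os.environ.setdefault("USER", "example_user")  # the component reads $USER to build the image path

app = test_installation()
print(app.name)           # "test_installation"
print(len(app.roles))     # 1 -- the single trainer role
print(app.roles[0].name)  # "trainer"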

test_installation_main.py

+120
@@ -0,0 +1,120 @@
#!/usr/bin/env python3

import os
import sys
from typing import Iterator, List, Tuple

import torch
import torch.distributed as dist
from torchrec import EmbeddingBagCollection
from torchrec import KeyedJaggedTensor
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.models.dlrm import DLRM
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.optim.keyed import KeyedOptimizerWrapper


class RandomIterator(Iterator):
    def __init__(
        self, batch_size: int, num_dense: int, num_sparse: int, num_embeddings: int
    ) -> None:
        self.batch_size = batch_size
        self.num_dense = num_dense
        self.num_sparse = num_sparse
        self.sparse_keys = [f"feature{id}" for id in range(self.num_sparse)]
        self.num_embeddings = num_embeddings
        self.num_ids_per_feature = 3
        self.num_ids_to_generate = (
            self.num_sparse * self.num_ids_per_feature * self.batch_size
        )

    def __next__(self) -> Tuple[torch.Tensor, KeyedJaggedTensor, torch.Tensor]:
        float_features = torch.randn(
            self.batch_size,
            self.num_dense,
        )
        labels = torch.randint(
            low=0,
            high=2,
            size=(self.batch_size,),
        )
        # Sparse IDs index into the embedding tables, so draw them from
        # [0, num_embeddings).
        sparse_ids = torch.randint(
            high=self.num_embeddings,
            size=(self.num_ids_to_generate,),
        )
        sparse_features = KeyedJaggedTensor.from_offsets_sync(
            keys=self.sparse_keys,
            values=sparse_ids,
            offsets=torch.tensor(
                list(range(0, self.num_ids_to_generate + 1, self.num_ids_per_feature)),
                dtype=torch.int32,
            ),
        )
        return (float_features, sparse_features, labels)


def main(argv: List[str]) -> None:
    batch_size = 1024
    num_dense = 1000
    num_sparse = 20
    num_embeddings = 1000000

    configs = [
        EmbeddingBagConfig(
            name=f"table{id}",
            embedding_dim=64,
            num_embeddings=num_embeddings,
            feature_names=[f"feature{id}"],
        )
        for id in range(num_sparse)
    ]

    rank = int(os.environ["LOCAL_RANK"])
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
        backend = "nccl"
        torch.cuda.set_device(device)
    else:
        raise Exception("CUDA not available")

    if not torch.distributed.is_initialized():
        dist.init_process_group(backend=backend)

    model = DLRM(
        embedding_bag_collection=EmbeddingBagCollection(
            tables=configs, device=torch.device("meta")
        ),
        dense_in_features=num_dense,
        dense_arch_layer_sizes=[500, 64],
        over_arch_layer_sizes=[32, 16, 1],
        dense_device=device,
    )
    model = DistributedModelParallel(
        module=model,
        device=device,
    )
    optimizer = KeyedOptimizerWrapper(
        dict(model.named_parameters()),
        lambda params: torch.optim.SGD(params, lr=0.01),
    )

    random_iterator = RandomIterator(batch_size, num_dense, num_sparse, num_embeddings)
    loss_fn = torch.nn.BCEWithLogitsLoss()
    for _ in range(10):
        (dense_features, sparse_features, labels) = next(random_iterator)
        dense_features = dense_features.to(device)
        sparse_features = sparse_features.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        output = model(dense_features, sparse_features)
        loss = loss_fn(output.squeeze(), labels.float())
        loss.backward()
        optimizer.step()

    print(
        "\033[92m"
        + "Successfully ran a few epochs for DLRM. Installation looks good!"
        + "\033[0m"
    )


if __name__ == "__main__":
    main(sys.argv[1:])
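
The smoke test feeds the model KeyedJaggedTensor batches built with from_offsets_sync, where values are laid out key-major and offsets mark the boundary of each per-sample bag. A tiny hypothetical illustration of that layout (two features, batch size 2, made-up ids):

# Hypothetical example of the KeyedJaggedTensor layout RandomIterator produces.
import torch
from torchrec import KeyedJaggedTensor

kjt = KeyedJaggedTensor.from_offsets_sync(
    keys=["feature0", "feature1"],
    values=torch.tensor([10, 11, 12, 20, 21, 22]),
    offsets=torch.tensor([0, 2, 3, 4, 6], dtype=torch.int32),
)
# feature0 gets the first two bags: ids [10, 11] and [12]
# feature1 gets the last two bags:  ids [20] and [21, 22]
print(kjt.keys())     # ["feature0", "feature1"]
print(kjt.lengths())  # tensor([2, 1, 1, 2])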

torchrec/__init__.py

+23
@@ -0,0 +1,23 @@
#!/usr/bin/env python3

import torchrec.distributed  # noqa
import torchrec.quant  # noqa
from torchrec.fx import tracer  # noqa
from torchrec.modules.embedding_configs import (  # noqa
    EmbeddingBagConfig,
    EmbeddingConfig,
    DataType,
    PoolingType,
)
from torchrec.modules.embedding_modules import (  # noqa
    EmbeddingBagCollection,
    EmbeddingCollection,
    EmbeddingBagCollectionInterface,
)  # noqa
from torchrec.modules.score_learning import PositionWeightsAttacher  # noqa
from torchrec.sparse.jagged_tensor import (  # noqa
    JaggedTensor,
    KeyedJaggedTensor,
    KeyedTensor,
)
from torchrec.types import Pipelineable, Multistreamable  # noqa

torchrec/datasets/__init__.py

+6
@@ -0,0 +1,6 @@
#!/usr/bin/env python3

import torchrec.datasets.criteo  # noqa
import torchrec.datasets.movielens  # noqa
import torchrec.datasets.random  # noqa
import torchrec.datasets.utils  # noqa

torchrec/datasets/criteo.py

+126
@@ -0,0 +1,126 @@
#!/usr/bin/env python3

from typing import (
    Iterator,
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Union,
)

import torch
import torch.utils.data.datapipes as dp
from torch.utils.data import IterDataPipe
from torchrec.datasets.utils import LoadFiles, ReadLinesFromCSV, safe_cast


INT_FEATURE_COUNT = 13
CAT_FEATURE_COUNT = 26
DEFAULT_LABEL_NAME = "label"
DEFAULT_INT_NAMES: List[str] = [f"int_{idx}" for idx in range(INT_FEATURE_COUNT)]
DEFAULT_CAT_NAMES: List[str] = [f"cat_{idx}" for idx in range(CAT_FEATURE_COUNT)]
DEFAULT_COLUMN_NAMES: List[str] = [
    DEFAULT_LABEL_NAME,
    *DEFAULT_INT_NAMES,
    *DEFAULT_CAT_NAMES,
]

COLUMN_TYPE_CASTERS: List[Callable[[Union[int, str]], Union[int, str]]] = [
    lambda val: safe_cast(val, int, 0),
    *(lambda val: safe_cast(val, int, 0) for _ in range(INT_FEATURE_COUNT)),
    *(lambda val: safe_cast(val, str, "") for _ in range(CAT_FEATURE_COUNT)),
]


def _default_row_mapper(example: List[str]) -> Dict[str, Union[int, str]]:
    column_names = reversed(DEFAULT_COLUMN_NAMES)
    column_type_casters = reversed(COLUMN_TYPE_CASTERS)
    return {
        next(column_names): next(column_type_casters)(val) for val in reversed(example)
    }


class CriteoIterDataPipe(IterDataPipe):
    def __init__(
        self,
        paths: Iterable[str],
        *,
        # pyre-ignore[2]
        row_mapper: Optional[Callable[[List[str]], Any]] = _default_row_mapper,
        # pyre-ignore[2]
        **open_kw,
    ) -> None:
        self.paths = paths
        self.row_mapper = row_mapper
        self.open_kw: Any = open_kw  # pyre-ignore[4]

    # pyre-ignore[3]
    def __iter__(self) -> Iterator[Any]:
        worker_info = torch.utils.data.get_worker_info()
        paths = self.paths
        if worker_info is not None:
            # Shard the input files across DataLoader workers.
            paths = (
                path
                for (idx, path) in enumerate(paths)
                if idx % worker_info.num_workers == worker_info.id
            )
        datapipe = LoadFiles(paths, mode="r", **self.open_kw)
        datapipe = ReadLinesFromCSV(datapipe, delimiter="\t")
        if self.row_mapper:
            datapipe = dp.iter.Mapper(datapipe, self.row_mapper)
        yield from datapipe


def criteo_terabyte(
    paths: Iterable[str],
    *,
    # pyre-ignore[2]
    row_mapper: Optional[Callable[[List[str]], Any]] = _default_row_mapper,
    # pyre-ignore[2]
    **open_kw,
) -> IterDataPipe:
    """`Criteo 1TB Click Logs <https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/>`_ Dataset

    Args:
        paths (Iterable[str]): local paths to the TSV files that constitute the Criteo 1TB dataset.
        row_mapper (Optional[Callable[[List[str]], Any]]): function to apply to each split TSV line.
        open_kw: options to pass to the underlying invocation of iopath.common.file_io.PathManager.open.

    Example:
        >>> datapipe = criteo_terabyte(
        >>>     ("/home/datasets/criteo/day_0.tsv", "/home/datasets/criteo/day_1.tsv")
        >>> )
        >>> datapipe = dp.iter.Batcher(datapipe, 100)
        >>> datapipe = dp.iter.Collator(datapipe)
        >>> batch = next(iter(datapipe))
    """
    return CriteoIterDataPipe(paths, row_mapper=row_mapper, **open_kw)


def criteo_kaggle(
    path: str,
    *,
    # pyre-ignore[2]
    row_mapper: Optional[Callable[[List[str]], Any]] = _default_row_mapper,
    # pyre-ignore[2]
    **open_kw,
) -> IterDataPipe:
    """`Kaggle/Criteo Display Advertising <https://www.kaggle.com/c/criteo-display-ad-challenge/>`_ Dataset

    Args:
        path (str): local path to the train or test dataset file.
        row_mapper (Optional[Callable[[List[str]], Any]]): function to apply to each split TSV line.
        open_kw: options to pass to the underlying invocation of iopath.common.file_io.PathManager.open.

    Example:
        >>> train_datapipe = criteo_kaggle(
        >>>     "/home/datasets/criteo_kaggle/train.txt",
        >>> )
        >>> example = next(iter(train_datapipe))
        >>> test_datapipe = criteo_kaggle(
        >>>     "/home/datasets/criteo_kaggle/test.txt",
        >>> )
        >>> example = next(iter(test_datapipe))
    """
    return CriteoIterDataPipe((path,), row_mapper=row_mapper, **open_kw)
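
As a quick illustration of _default_row_mapper's behavior: it walks the row and the column names from the end, presumably so that a row without the leading label column (as in the Kaggle test split) still maps its integer and categorical fields to the right keys. The rows below are fabricated, and the snippet only exercises the module defined above.

# Illustrative only: apply _default_row_mapper to fabricated Criteo rows.
from torchrec.datasets.criteo import (
    CAT_FEATURE_COUNT,
    INT_FEATURE_COUNT,
    _default_row_mapper,
)

# A full row: label, 13 integer features, 26 categorical hashes (all fabricated).
full_row = ["1"] + ["7"] * INT_FEATURE_COUNT + ["ab12cd34"] * CAT_FEATURE_COUNT
mapped = _default_row_mapper(full_row)
print(mapped["label"], mapped["int_0"], mapped["cat_25"])  # 1 7 ab12cd34

# A label-less row, as in the Kaggle test file: the fields still line up from
# the end, and the "label" key is simply absent.
test_row = ["7"] * INT_FEATURE_COUNT + ["ab12cd34"] * CAT_FEATURE_COUNT
print("label" in _default_row_mapper(test_row))  # False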
