Skip to content

Commit 221c35c

Browse files
Split out model and dataset creation into conftest
1 parent d7c0779 commit 221c35c

File tree

3 files changed

+302
-270
lines changed

3 files changed

+302
-270
lines changed
+143
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import os
2+
3+
import hugectr
4+
import numpy as np
5+
import pandas as pd
6+
import pytest
7+
8+
from merlin.io import Dataset
9+
from merlin.schema import Tags
10+
from merlin.transforms import Workflow
11+
from merlin.transforms.ops import AddTags, Categorify
12+
13+
14+
@pytest.fixture
15+
def hugectr_example_dataset(tmpdir):
16+
num_rows = 64
17+
18+
df = pd.DataFrame(
19+
{
20+
"a": np.arange(num_rows).astype(np.int64),
21+
"b": np.arange(num_rows).astype(np.int64),
22+
"c": np.arange(num_rows).astype(np.int64),
23+
"d": np.random.rand(num_rows).astype(np.float32),
24+
"label": np.array([0] * num_rows).astype(np.float32),
25+
},
26+
)
27+
categorical_columns = ["a", "b", "c"]
28+
dense_columns = ["d"]
29+
target_columns = ["label"]
30+
31+
workflow = Workflow(
32+
(categorical_columns >> Categorify())
33+
+ (dense_columns >> AddTags(Tags.CONTINUOUS))
34+
+ (target_columns >> AddTags(Tags.TARGET))
35+
)
36+
37+
dataset = workflow.fit_transform(Dataset(df))
38+
39+
return dataset
40+
41+
42+
@pytest.fixture
43+
def hugectr_example_model(hugectr_example_dataset, tmpdir):
44+
dataset = hugectr_example_dataset
45+
46+
train_path = os.path.join(tmpdir, "hugectr_example_data/")
47+
os.mkdir(train_path)
48+
49+
dataset.to_parquet(
50+
output_path=tmpdir,
51+
cats=dataset.schema.select_by_tag(Tags.CATEGORICAL).column_names,
52+
conts=dataset.schema.select_by_tag(Tags.CONTINUOUS).column_names,
53+
labels=dataset.schema.select_by_tag(Tags.TARGET).column_names,
54+
)
55+
56+
# slot_sizes = list of caridinalities per column
57+
slot_sizes = [
58+
col.properties["embedding_sizes"]["cardinality"]
59+
for col in dataset.schema.select_by_tag(Tags.CATEGORICAL)
60+
]
61+
62+
# dense_dim = num of dense inputs
63+
dense_dim = len(dataset.schema.select_by_tag(Tags.CONTINUOUS))
64+
65+
solver = hugectr.CreateSolver(
66+
vvgpu=[[0]],
67+
batchsize=10,
68+
batchsize_eval=10,
69+
max_eval_batches=50,
70+
i64_input_key=True,
71+
use_mixed_precision=False,
72+
repeat_dataset=True,
73+
)
74+
# https://github.com/NVIDIA-Merlin/HugeCTR/blob/9e648f879166fc93931c676a5594718f70178a92/docs/source/api/python_interface.md#datareaderparams
75+
reader = hugectr.DataReaderParams(
76+
data_reader_type=hugectr.DataReaderType_t.Parquet,
77+
source=[os.path.join(train_path, "_file_list.txt")],
78+
eval_source=os.path.join(train_path, "_file_list.txt"),
79+
check_type=hugectr.Check_t.Non,
80+
)
81+
82+
optimizer = hugectr.CreateOptimizer(optimizer_type=hugectr.Optimizer_t.Adam)
83+
model = hugectr.Model(solver, reader, optimizer)
84+
85+
model.add(
86+
hugectr.Input(
87+
label_dim=1,
88+
label_name="label",
89+
dense_dim=dense_dim,
90+
dense_name="dense",
91+
data_reader_sparse_param_array=[
92+
hugectr.DataReaderSparseParam("data1", len(slot_sizes) + 1, True, len(slot_sizes))
93+
],
94+
)
95+
)
96+
model.add(
97+
hugectr.SparseEmbedding(
98+
embedding_type=hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
99+
workspace_size_per_gpu_in_mb=107,
100+
embedding_vec_size=16,
101+
combiner="sum",
102+
sparse_embedding_name="sparse_embedding1",
103+
bottom_name="data1",
104+
slot_size_array=slot_sizes,
105+
optimizer=optimizer,
106+
)
107+
)
108+
model.add(
109+
hugectr.DenseLayer(
110+
layer_type=hugectr.Layer_t.InnerProduct,
111+
bottom_names=["dense"],
112+
top_names=["fc1"],
113+
num_output=512,
114+
)
115+
)
116+
model.add(
117+
hugectr.DenseLayer(
118+
layer_type=hugectr.Layer_t.Reshape,
119+
bottom_names=["sparse_embedding1"],
120+
top_names=["reshape1"],
121+
leading_dim=48,
122+
)
123+
)
124+
model.add(
125+
hugectr.DenseLayer(
126+
layer_type=hugectr.Layer_t.InnerProduct,
127+
bottom_names=["reshape1", "fc1"],
128+
top_names=["fc2"],
129+
num_output=1,
130+
)
131+
)
132+
model.add(
133+
hugectr.DenseLayer(
134+
layer_type=hugectr.Layer_t.BinaryCrossEntropyLoss,
135+
bottom_names=["fc2", "label"],
136+
top_names=["loss"],
137+
)
138+
)
139+
model.compile()
140+
model.summary()
141+
model.fit(max_iter=20, display=100, eval_interval=200, snapshot=10)
142+
143+
return model

0 commit comments

Comments
 (0)