Skip to content

Commit ce323c0

Browse files
zhuzilin and facebook-github-bot
authored and committed
Migrate id transformer binding (#835)
Summary: This PR integrates the migrated id transformer and lxu strategy to create a Python binding. There are several points worth mentioning: - We use the `IDTransformerVariant` and `LXUStrategyVariant` types to support adding more kinds of id transformers and lxu strategies. - The original design used [nlohmann::json](https://github.com/nlohmann/json) to parse the configs. To remove this dependency, we list all configs of all kinds of id transformers and lxu strategies as parameters of the variant's constructor. - We chose to use the PyTorch native binding (`CustomClassHolder`) instead of pybind to be consistent with fbgemm_gpu. Thank you for taking the time to review this PR :) Gently pinging divchenko colin2328 reyoung Pull Request resolved: #835 Reviewed By: s4ayub Differential Revision: D41507250 Pulled By: colin2328 fbshipit-source-id: f6e8bc6234f287dfba333e1fe7bbe7de64df88b6
1 parent d981c50 commit ce323c0

17 files changed

+476
-221
lines changed

benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_benchmark.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
namespace torchrec {
1313
void BM_MixedLFULRUStrategy(benchmark::State& state) {
1414
size_t num_ext_values = state.range(0);
15-
std::vector<MixedLFULRUStrategy::lxu_record_t> ext_values(num_ext_values);
15+
std::vector<lxu_record_t> ext_values(num_ext_values);
1616

1717
MixedLFULRUStrategy strategy;
1818
for (auto& v : ext_values) {

benchmarks/cpp/dynamic_embedding/mixed_lfu_lru_strategy_evict_benchmark.cpp

+8-7
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ class RecordIterator {
1818
public:
1919
RecordIterator(Container::const_iterator begin, Container::const_iterator end)
2020
: begin_(begin), end_(end) {}
21-
std::optional<TransformerRecord<uint32_t>> operator()() {
21+
std::optional<record_t> operator()() {
2222
if (begin_ == end_) {
2323
return std::nullopt;
2424
}
25-
TransformerRecord<uint32_t> record{};
26-
record.global_id_ = next_global_id_++;
27-
record.lxu_record_ = *reinterpret_cast<const uint32_t*>(&(*begin_++));
25+
record_t record{};
26+
record.global_id = next_global_id_++;
27+
record.lxu_record = *reinterpret_cast<const uint32_t*>(&(*begin_++));
2828
return record;
2929
}
3030

@@ -52,8 +52,8 @@ class RandomizeMixedLXUSet {
5252
std::uniform_int_distribution<uint32_t> time_dist(0, max_time - 1);
5353
for (size_t i = 0; i < n; ++i) {
5454
MixedLFULRUStrategy::Record record{};
55-
record.freq_power_ = freq_dist(engine) + min_freq;
56-
record.time_ = time_dist(engine);
55+
record.freq_power = freq_dist(engine) + min_freq;
56+
record.time = time_dist(engine);
5757
records_.emplace_back(record);
5858
}
5959
}
@@ -68,8 +68,9 @@ class RandomizeMixedLXUSet {
6868

6969
void BM_MixedLFULRUStrategyEvict(benchmark::State& state) {
7070
RandomizeMixedLXUSet lxuSet(state.range(0), state.range(1), state.range(2));
71+
MixedLFULRUStrategy strategy;
7172
for (auto _ : state) {
72-
MixedLFULRUStrategy::evict(lxuSet.Iterator(), state.range(3));
73+
strategy.evict(lxuSet.Iterator(), state.range(3));
7374
}
7475
}
7576

benchmarks/cpp/dynamic_embedding/naive_id_transformer_benchmark.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
namespace torchrec {
1414

1515
static void BM_NaiveIDTransformer(benchmark::State& state) {
16-
using Tag = int32_t;
17-
NaiveIDTransformer<Tag> transformer(2e8);
16+
NaiveIDTransformer transformer(2e8);
1817
torch::Tensor global_ids = torch::empty({1024, 1024}, torch::kLong);
1918
torch::Tensor cache_ids = torch::empty_like(global_ids);
2019
for (auto _ : state) {

test/cpp/dynamic_embedding/mixed_lfu_lru_strategy_test.cpp

+26-27
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
namespace torchrec {
1313
TEST(TDE, order) {
1414
MixedLFULRUStrategy::Record a;
15-
a.time_ = 1;
16-
a.freq_power_ = 31;
15+
a.time = 1;
16+
a.freq_power = 31;
1717
uint32_t i32 = a.ToUint32();
1818
ASSERT_EQ(0xF8000001, i32);
1919
}
@@ -23,42 +23,41 @@ TEST(TDE, MixedLFULRUStrategy_Evict) {
2323
{
2424
records.emplace_back();
2525
records.back().first = 1;
26-
records.back().second.time_ = 100;
27-
records.back().second.freq_power_ = 2;
26+
records.back().second.time = 100;
27+
records.back().second.freq_power = 2;
2828
}
2929
{
3030
records.emplace_back();
3131
records.back().first = 2;
32-
records.back().second.time_ = 10;
33-
records.back().second.freq_power_ = 2;
32+
records.back().second.time = 10;
33+
records.back().second.freq_power = 2;
3434
}
3535
{
3636
records.emplace_back();
3737
records.back().first = 3;
38-
records.back().second.time_ = 100;
39-
records.back().second.freq_power_ = 1;
38+
records.back().second.time = 100;
39+
records.back().second.freq_power = 1;
4040
}
4141
{
4242
records.emplace_back();
4343
records.back().first = 4;
44-
records.back().second.time_ = 150;
45-
records.back().second.freq_power_ = 2;
44+
records.back().second.time = 150;
45+
records.back().second.freq_power = 2;
4646
}
4747
size_t offset_{0};
48-
auto ids = MixedLFULRUStrategy::evict(
49-
[&offset_,
50-
&records]() -> std::optional<MixedLFULRUStrategy::transformer_record_t> {
48+
MixedLFULRUStrategy strategy;
49+
auto ids = strategy.evict(
50+
[&offset_, &records]() -> std::optional<record_t> {
5151
if (offset_ == records.size()) {
5252
return std::nullopt;
5353
}
5454
auto record = records[offset_++];
55-
MixedLFULRUStrategy::lxu_record_t ext_type =
56-
*reinterpret_cast<MixedLFULRUStrategy::lxu_record_t*>(
57-
&record.second);
58-
return MixedLFULRUStrategy::transformer_record_t{
59-
.global_id_ = record.first,
60-
.cache_id_ = 0,
61-
.lxu_record_ = ext_type,
55+
lxu_record_t ext_type =
56+
*reinterpret_cast<lxu_record_t*>(&record.second);
57+
return record_t{
58+
.global_id = record.first,
59+
.cache_id = 0,
60+
.lxu_record = ext_type,
6261
};
6362
},
6463
3);
@@ -73,12 +72,12 @@ TEST(TDE, MixedLFULRUStrategy_Transform) {
7372
constexpr static size_t n_iter = 1000000;
7473
MixedLFULRUStrategy strategy;
7574
strategy.update_time(10);
76-
MixedLFULRUStrategy::lxu_record_t val;
75+
lxu_record_t val;
7776
{
7877
val = strategy.update(0, 0, std::nullopt);
7978
auto record = reinterpret_cast<MixedLFULRUStrategy::Record*>(&val);
80-
ASSERT_EQ(record->freq_power_, 5);
81-
ASSERT_EQ(record->time_, 10);
79+
ASSERT_EQ(record->freq_power, 5);
80+
ASSERT_EQ(record->time, 10);
8281
}
8382

8483
uint32_t freq_power_5_cnt = 0;
@@ -87,13 +86,13 @@ TEST(TDE, MixedLFULRUStrategy_Transform) {
8786
for (size_t i = 0; i < n_iter; ++i) {
8887
auto tmp = strategy.update(0, 0, val);
8988
auto record = reinterpret_cast<MixedLFULRUStrategy::Record*>(&tmp);
90-
ASSERT_EQ(record->time_, 10);
91-
if (record->freq_power_ == 5) {
89+
ASSERT_EQ(record->time, 10);
90+
if (record->freq_power == 5) {
9291
++freq_power_5_cnt;
93-
} else if (record->freq_power_ == 6) {
92+
} else if (record->freq_power == 6) {
9493
++freq_power_6_cnt;
9594
} else {
96-
ASSERT_TRUE(record->freq_power_ == 5 || record->freq_power_ == 6);
95+
ASSERT_TRUE(record->freq_power == 5 || record->freq_power == 6);
9796
}
9897
}
9998

test/cpp/dynamic_embedding/naive_id_transformer_test.cpp

+4-8
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
namespace torchrec {
1313

1414
TEST(tde, NaiveThreadedIDTransformer_NoFilter) {
15-
using Tag = int32_t;
16-
NaiveIDTransformer<Tag, Bitmap<uint8_t>> transformer(16);
15+
NaiveIDTransformer<Bitmap<uint8_t>> transformer(16);
1716
const int64_t global_ids[5] = {100, 101, 100, 102, 101};
1817
int64_t cache_ids[5];
1918
int64_t expected_cache_ids[5] = {0, 1, 0, 2, 1};
@@ -24,8 +23,7 @@ TEST(tde, NaiveThreadedIDTransformer_NoFilter) {
2423
}
2524

2625
TEST(tde, NaiveThreadedIDTransformer_Full) {
27-
using Tag = int32_t;
28-
NaiveIDTransformer<Tag, Bitmap<uint8_t>> transformer(4);
26+
NaiveIDTransformer<Bitmap<uint8_t>> transformer(4);
2927
const int64_t global_ids[5] = {100, 101, 102, 103, 104};
3028
int64_t cache_ids[5];
3129
int64_t expected_cache_ids[5] = {0, 1, 2, 3, -1};
@@ -37,8 +35,7 @@ TEST(tde, NaiveThreadedIDTransformer_Full) {
3735
}
3836

3937
TEST(tde, NaiveThreadedIDTransformer_Evict) {
40-
using Tag = int32_t;
41-
NaiveIDTransformer<Tag, Bitmap<uint8_t>> transformer(4);
38+
NaiveIDTransformer<Bitmap<uint8_t>> transformer(4);
4239
const int64_t global_ids[5] = {100, 101, 102, 103, 104};
4340
int64_t cache_ids[5];
4441

@@ -60,8 +57,7 @@ TEST(tde, NaiveThreadedIDTransformer_Evict) {
6057
}
6158

6259
TEST(tde, NaiveThreadedIDTransformer_Iterator) {
63-
using Tag = int32_t;
64-
NaiveIDTransformer<Tag, Bitmap<uint8_t>> transformer(16);
60+
NaiveIDTransformer<Bitmap<uint8_t>> transformer(16);
6561
const int64_t global_ids[5] = {100, 101, 100, 102, 101};
6662
int64_t cache_ids[5];
6763
int64_t expected_cache_ids[5] = {3, 4, 3, 5, 4};

torchrec/csrc/dynamic_embedding/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66

77
add_library(tde_cpp_objs
88
OBJECT
9+
bind.cpp
10+
id_transformer_wrapper.cpp
911
details/clz_impl.cpp
1012
details/ctz_impl.cpp
1113
details/random_bits_generator.cpp
12-
details/mixed_lfu_lru_strategy.cpp
1314
details/io_registry.cpp
1415
details/io.cpp
1516
details/notification.cpp)
+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <torch/torch.h>
10+
11+
#include <torchrec/csrc/dynamic_embedding/details/io_registry.h>
12+
#include <torchrec/csrc/dynamic_embedding/id_transformer_wrapper.h>
13+
14+
namespace torchrec {
15+
TORCH_LIBRARY(tde, m) {
  // Free function `register_io(name)`: forwards the given name to the
  // process-wide IORegistry as a C string.
  m.def("register_io", [](const std::string& name) {
    auto&& registry = IORegistry::Instance();
    registry.register_plugin(name.c_str());
  });

  // `TransformResult`: exposes its two fields as read-only attributes.
  m.class_<TransformResult>("TransformResult")
      .def_readonly("success", &TransformResult::success)
      .def_readonly("ids_to_fetch", &TransformResult::ids_to_fetch);

  // `IDTransformer`: custom class backed by IDTransformerWrapper, constructed
  // from (int64_t, std::string, std::string, int64_t).
  m.class_<IDTransformerWrapper>("IDTransformer")
      .def(torch::init<int64_t, std::string, std::string, int64_t>())
      .def("transform", &IDTransformerWrapper::transform)
      .def("evict", &IDTransformerWrapper::evict)
      .def("save", &IDTransformerWrapper::save);
}
30+
} // namespace torchrec
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
#include <torchrec/csrc/dynamic_embedding/details/types.h>
11+
#include <functional>
12+
#include <span>
13+
#include <type_traits>
14+
15+
namespace torchrec {
16+
17+
namespace transform_default {
18+
inline lxu_record_t no_update(
19+
int64_t global_id,
20+
int64_t cache_id,
21+
std::optional<lxu_record_t> record) {
22+
return record.value_or(lxu_record_t{});
23+
};
24+
25+
inline void no_fetch(int64_t global_id, int64_t cache_id) {}
26+
} // namespace transform_default
27+
28+
class IDTransformer {
29+
public:
30+
/**
31+
* Transform global ids to cache ids
32+
*
33+
* @tparam Update Update the eviction strategy tag type. Update LXU Record
34+
* @tparam Fetch Fetch the not existing global-id/cache-id pair. It is used
35+
* by dynamic embedding parameter server.
36+
*
37+
* @param global_ids Global ID vector
38+
* @param cache_ids [out] Cache ID vector
39+
* @param update update lambda. See `Update` doc.
40+
* @param fetch fetch lambda. See `Fetch` doc.
41+
* @return true if all transformed, otherwise need eviction.
42+
*/
43+
virtual bool transform(
44+
std::span<const int64_t> global_ids,
45+
std::span<int64_t> cache_ids,
46+
update_t update = transform_default::no_update,
47+
fetch_t fetch = transform_default::no_fetch) = 0;
48+
49+
/**
50+
* Evict global ids from the transformer
51+
*
52+
* @param global_ids Global IDs to evict.
53+
*/
54+
virtual void evict(std::span<const int64_t> global_ids) = 0;
55+
56+
/**
57+
* Create an iterator of the id transformer, a possible usecase is:
58+
*
59+
* auto iterator = transformer.iterator();
60+
* auto record = iterator();
61+
* while (record.has_value()) {
62+
* // do sth with the record
63+
* // ...
64+
* // get next record
65+
* auto record = iterator();
66+
* }
67+
*
68+
* @return the iterator created.
69+
*/
70+
virtual iterator_t iterator() const = 0;
71+
};
72+
73+
} // namespace torchrec
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
#include <torchrec/csrc/dynamic_embedding/details/types.h>
11+
#include <optional>
12+
13+
namespace torchrec {
14+
15+
class LXUStrategy {
16+
public:
17+
LXUStrategy() = default;
18+
LXUStrategy(const LXUStrategy&) = delete;
19+
LXUStrategy(LXUStrategy&& o) noexcept = default;
20+
21+
virtual void update_time(uint32_t time) = 0;
22+
virtual int64_t time(lxu_record_t record) = 0;
23+
24+
virtual lxu_record_t update(
25+
int64_t global_id,
26+
int64_t cache_id,
27+
std::optional<lxu_record_t> val) = 0;
28+
29+
/**
30+
* Analysis all ids and returns the num_elems that are most need to evict.
31+
* @param iterator Returns each global_id to ExtValue pair. Returns nullopt
32+
* when at ends.
33+
* @param num_to_evict
34+
* @return
35+
*/
36+
virtual std::vector<int64_t> evict(
37+
iterator_t iterator,
38+
uint64_t num_to_evict) = 0;
39+
};
40+
41+
} // namespace torchrec

0 commit comments

Comments
 (0)