Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 7e1d71d

Browse files
committedMar 19, 2025·
save
1 parent 3d1498b commit 7e1d71d

File tree

22 files changed

+450
-60
lines changed

22 files changed

+450
-60
lines changed
 

‎pgdog.toml

+9-3
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,14 @@ name = "pgdog_sharded"
3131
host = "127.0.0.1"
3232
database_name = "shard_0"
3333
shard = 0
34+
centroid = [0.0, 0.0, 0.0]
3435

3536
[[databases]]
3637
name = "pgdog_sharded"
3738
host = "127.0.0.1"
3839
database_name = "shard_1"
3940
shard = 1
40-
41+
centroid = [100.0, 100.0, 100.0]
4142

4243
#
4344
# Read/write access to theses tables will be automatically
@@ -47,11 +48,16 @@ shard = 1
4748
database = "pgdog_sharded"
4849
table = "sharded"
4950
column = "id"
51+
primary = true
5052

5153
[[sharded_tables]]
5254
database = "pgdog_sharded"
53-
table = "users"
54-
column = "id"
55+
column = "user_id"
56+
57+
[[sharded_tables]]
58+
database = "pgdog_sharded"
59+
table = "vectors"
60+
column = "embedding"
5561
primary = true
5662

5763
#

‎pgdog/src/backend/databases.rs

+15-6
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ use once_cell::sync::Lazy;
99
use crate::{
1010
backend::pool::PoolConfig,
1111
config::{config, load, ConfigAndUsers, ManualQuery, Role},
12-
net::messages::BackendKeyData,
12+
net::messages::{BackendKeyData, Vector},
1313
};
1414

1515
use super::{
1616
pool::{Address, Config},
1717
replication::ReplicationConfig,
18-
Cluster, Error, ShardedTables,
18+
Cluster, ClusterShardConfig, Error, ShardedTables,
1919
};
2020

2121
static DATABASES: Lazy<ArcSwap<Databases>> =
@@ -201,13 +201,17 @@ pub fn from_config(config: &ConfigAndUsers) -> Databases {
201201
if let Some(shards) = config_databases.get(&user.database) {
202202
let mut shard_configs = vec![];
203203
for user_databases in shards {
204+
let mut centroid: Option<Vector> = None;
204205
let primary =
205206
user_databases
206207
.iter()
207208
.find(|d| d.role == Role::Primary)
208-
.map(|primary| PoolConfig {
209-
address: Address::new(primary, user),
210-
config: Config::new(general, primary, user),
209+
.map(|primary| {
210+
centroid = primary.centroid.clone();
211+
PoolConfig {
212+
address: Address::new(primary, user),
213+
config: Config::new(general, primary, user),
214+
}
211215
});
212216
let replicas = user_databases
213217
.iter()
@@ -217,7 +221,12 @@ pub fn from_config(config: &ConfigAndUsers) -> Databases {
217221
config: Config::new(general, replica, user),
218222
})
219223
.collect::<Vec<_>>();
220-
shard_configs.push((primary, replicas));
224+
225+
shard_configs.push(ClusterShardConfig {
226+
primary,
227+
replicas,
228+
centroid,
229+
});
221230
}
222231

223232
let sharded_tables = sharded_tables

‎pgdog/src/backend/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pub mod server;
1010
pub mod stats;
1111

1212
pub use error::Error;
13-
pub use pool::{Cluster, Pool, Replicas, Shard};
13+
pub use pool::{Cluster, ClusterShardConfig, Pool, Replicas, Shard, ShardingSchema};
1414
pub use prepared_statements::PreparedStatements;
1515
pub use replication::ShardedTables;
1616
pub use schema::Schema;

‎pgdog/src/backend/pool/cluster.rs

+55-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use crate::{
44
backend::{databases::databases, replication::ReplicationConfig, ShardedTables},
55
config::{PoolerMode, ShardedTable},
6-
net::messages::BackendKeyData,
6+
net::messages::{BackendKeyData, Vector},
77
};
88

99
use super::{Address, Config, Error, Guard, Request, Shard};
@@ -32,11 +32,47 @@ pub struct Cluster {
3232
replication_sharding: Option<String>,
3333
}
3434

35+
/// Sharding configuration from the cluster.
36+
#[derive(Debug, Clone, Default)]
37+
pub struct ShardingSchema {
38+
/// Number of shards.
39+
pub shards: usize,
40+
/// Vector centroids.
41+
pub centroids: Vec<Option<Vector>>,
42+
/// Sharded tables.
43+
pub tables: ShardedTables,
44+
}
45+
46+
impl ShardingSchema {
47+
pub fn shard_by_distance_l2(&self, vector: &Vector) -> Option<usize> {
48+
let mut shard = None;
49+
let mut min_distance = f64::MAX;
50+
51+
for (i, c) in self.centroids.iter().enumerate() {
52+
if let Some(c) = c {
53+
let distance = vector.distance_l2(c);
54+
if distance < min_distance {
55+
min_distance = distance;
56+
shard = Some(i);
57+
}
58+
}
59+
}
60+
61+
shard
62+
}
63+
}
64+
65+
pub struct ClusterShardConfig {
66+
pub primary: Option<PoolConfig>,
67+
pub replicas: Vec<PoolConfig>,
68+
pub centroid: Option<Vector>,
69+
}
70+
3571
impl Cluster {
3672
/// Create new cluster of shards.
3773
pub fn new(
3874
name: &str,
39-
shards: &[(Option<PoolConfig>, Vec<PoolConfig>)],
75+
shards: &[ClusterShardConfig],
4076
lb_strategy: LoadBalancingStrategy,
4177
password: &str,
4278
pooler_mode: PoolerMode,
@@ -46,7 +82,14 @@ impl Cluster {
4682
Self {
4783
shards: shards
4884
.iter()
49-
.map(|addr| Shard::new(addr.0.clone(), &addr.1, lb_strategy))
85+
.map(|config| {
86+
Shard::new(
87+
&config.primary,
88+
&config.replicas,
89+
lb_strategy,
90+
config.centroid.clone(),
91+
)
92+
})
5093
.collect(),
5194
name: name.to_owned(),
5295
password: password.to_owned(),
@@ -191,6 +234,15 @@ impl Cluster {
191234
.as_ref()
192235
.and_then(|database| databases().replication(database))
193236
}
237+
238+
/// Get all data required for sharding.
239+
pub fn sharding_schema(&self) -> ShardingSchema {
240+
ShardingSchema {
241+
shards: self.shards.len(),
242+
centroids: self.shards().iter().map(|s| s.centroid().clone()).collect(),
243+
tables: self.sharded_tables.clone(),
244+
}
245+
}
194246
}
195247

196248
#[cfg(test)]

‎pgdog/src/backend/pool/connection/mod.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use crate::{
1515

1616
use super::{
1717
super::{pool::Guard, Error},
18-
Address, Cluster, Request,
18+
Address, Cluster, Request, ShardingSchema,
1919
};
2020

2121
use std::{mem::replace, time::Duration};
@@ -91,8 +91,12 @@ impl Connection {
9191
&mut self,
9292
shard: Option<usize>,
9393
replication_config: &ReplicationConfig,
94+
sharding_schema: &ShardingSchema,
9495
) -> Result<(), Error> {
95-
self.binding = Binding::Replication(None, Buffer::new(shard, replication_config));
96+
self.binding = Binding::Replication(
97+
None,
98+
Buffer::new(shard, replication_config, sharding_schema),
99+
);
96100
Ok(())
97101
}
98102

‎pgdog/src/backend/pool/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ pub mod stats;
2222
pub mod waiting;
2323

2424
pub use address::Address;
25-
pub use cluster::{Cluster, PoolConfig};
25+
pub use cluster::{Cluster, ClusterShardConfig, PoolConfig, ShardingSchema};
2626
pub use config::Config;
2727
pub use connection::Connection;
2828
pub use error::Error;

‎pgdog/src/backend/pool/pool_impl.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ impl Clone for Pool {
4242

4343
impl Pool {
4444
/// Create new connection pool.
45-
pub fn new(config: PoolConfig) -> Self {
45+
pub fn new(config: &PoolConfig) -> Self {
4646
Self {
4747
inner: Arc::new(Mutex::new(Inner::new(config.config))),
4848
comms: Arc::new(Comms::new()),
49-
addr: config.address,
49+
addr: config.address.clone(),
5050
}
5151
}
5252

@@ -138,7 +138,7 @@ impl Pool {
138138

139139
/// Create new identical connection pool.
140140
pub fn duplicate(&self) -> Pool {
141-
Pool::new(PoolConfig {
141+
Pool::new(&PoolConfig {
142142
address: self.addr().clone(),
143143
config: *self.lock().config(),
144144
})

‎pgdog/src/backend/pool/replicas.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ impl Replicas {
3434
/// Create new replicas pools.
3535
pub fn new(addrs: &[PoolConfig], lb_strategy: LoadBalancingStrategy) -> Replicas {
3636
Self {
37-
pools: addrs.iter().map(|p| Pool::new(p.clone())).collect(),
37+
pools: addrs.iter().map(Pool::new).collect(),
3838
checkout_timeout: Duration::from_millis(5_000),
3939
round_robin: Arc::new(AtomicUsize::new(0)),
4040
lb_strategy,

‎pgdog/src/backend/pool/shard.rs

+19-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
//! A shard is a collection of replicas and a primary.
22
3-
use crate::{config::LoadBalancingStrategy, net::messages::BackendKeyData};
3+
use crate::{
4+
config::LoadBalancingStrategy,
5+
net::messages::{BackendKeyData, Vector},
6+
};
47

58
use super::{Error, Guard, Pool, PoolConfig, Replicas, Request};
69

@@ -9,19 +12,25 @@ use super::{Error, Guard, Pool, PoolConfig, Replicas, Request};
912
pub struct Shard {
1013
pub(super) primary: Option<Pool>,
1114
pub(super) replicas: Replicas,
15+
pub(super) centroid: Option<Vector>,
1216
}
1317

1418
impl Shard {
1519
/// Create new shard connection pool.
1620
pub fn new(
17-
primary: Option<PoolConfig>,
21+
primary: &Option<PoolConfig>,
1822
replicas: &[PoolConfig],
1923
lb_strategy: LoadBalancingStrategy,
24+
centroid: Option<Vector>,
2025
) -> Self {
21-
let primary = primary.map(Pool::new);
26+
let primary = primary.as_ref().map(Pool::new);
2227
let replicas = Replicas::new(replicas, lb_strategy);
2328

24-
Self { primary, replicas }
29+
Self {
30+
primary,
31+
replicas,
32+
centroid,
33+
}
2534
}
2635

2736
/// Get a connection to the shard primary database.
@@ -51,6 +60,7 @@ impl Shard {
5160
Self {
5261
primary: self.primary.as_ref().map(|primary| primary.duplicate()),
5362
replicas: self.replicas.duplicate(),
63+
centroid: self.centroid.clone(),
5464
}
5565
}
5666

@@ -84,4 +94,9 @@ impl Shard {
8494
pub fn shutdown(&self) {
8595
self.pools().iter().for_each(|pool| pool.shutdown());
8696
}
97+
98+
/// Get the shard vector centroid.
99+
pub fn centroid(&self) -> &Option<Vector> {
100+
&self.centroid
101+
}
87102
}

‎pgdog/src/backend/pool/test/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ pub fn pool() -> Pool {
2020
..Default::default()
2121
};
2222

23-
let pool = Pool::new(PoolConfig {
23+
let pool = Pool::new(&PoolConfig {
2424
address: Address {
2525
host: "127.0.0.1".into(),
2626
port: 5432,

‎pgdog/src/backend/replication/buffer.rs

+10-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use fnv::FnvHashMap as HashMap;
22
use fnv::FnvHashSet as HashSet;
33
use std::collections::VecDeque;
44

5+
use crate::backend::ShardingSchema;
56
use crate::frontend::router::sharding::shard_str;
67
use crate::net::messages::FromBytes;
78
use crate::net::messages::Protocol;
@@ -23,11 +24,16 @@ pub struct Buffer {
2324
shard: Option<usize>,
2425
oid: Option<i32>,
2526
buffer: VecDeque<Message>,
27+
sharding_schema: ShardingSchema,
2628
}
2729

2830
impl Buffer {
2931
/// New replication buffer.
30-
pub fn new(shard: Option<usize>, cluster: &ReplicationConfig) -> Self {
32+
pub fn new(
33+
shard: Option<usize>,
34+
cluster: &ReplicationConfig,
35+
sharding_schema: &ShardingSchema,
36+
) -> Self {
3137
Self {
3238
begin: None,
3339
message: None,
@@ -37,6 +43,7 @@ impl Buffer {
3743
oid: None,
3844
buffer: VecDeque::new(),
3945
replication_config: cluster.clone(),
46+
sharding_schema: sharding_schema.clone(),
4047
}
4148
}
4249

@@ -74,7 +81,7 @@ impl Buffer {
7481
.and_then(|column| update.column(column))
7582
.and_then(|column| column.as_str());
7683
if let Some(column) = column {
77-
let shard = shard_str(column, self.replication_config.shards());
84+
let shard = shard_str(column, &self.sharding_schema);
7885
if self.shard == shard {
7986
self.message = Some(xlog_data);
8087
return self.flush();
@@ -92,7 +99,7 @@ impl Buffer {
9299
.and_then(|column| insert.column(column))
93100
.and_then(|column| column.as_str());
94101
if let Some(column) = column {
95-
let shard = shard_str(column, self.replication_config.shards());
102+
let shard = shard_str(column, &self.sharding_schema);
96103
if self.shard == shard {
97104
self.message = Some(xlog_data);
98105
return self.flush();

‎pgdog/src/config/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use serde::{Deserialize, Serialize};
1919
use tracing::info;
2020
use tracing::warn;
2121

22+
use crate::net::messages::Vector;
2223
use crate::util::random_string;
2324

2425
static CONFIG: Lazy<ArcSwap<ConfigAndUsers>> =
@@ -388,6 +389,8 @@ pub struct Database {
388389
// Maximum number of connections to this database from this pooler.
389390
// #[serde(default = "Database::max_connections")]
390391
// pub max_connections: usize,
392+
/// Centroid used for a vector index.
393+
pub centroid: Option<Vector>,
391394
}
392395

393396
impl Database {

‎pgdog/src/frontend/client/inner.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ impl Inner {
4040

4141
// Configure replication mode.
4242
if client.shard.is_some() {
43-
if let Some(config) = backend.cluster()?.replication_sharding_config() {
44-
backend.replication_mode(client.shard, &config)?;
43+
let cluster = backend.cluster()?;
44+
if let Some(config) = cluster.replication_sharding_config() {
45+
backend.replication_mode(client.shard, &config, &cluster.sharding_schema())?;
4546
router.replication_mode();
4647
debug!("logical replication sharding [{}]", client.addr);
4748
}

‎pgdog/src/frontend/router/parser/comment.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ use once_cell::sync::Lazy;
22
use pg_query::{protobuf::Token, scan, Error};
33
use regex::Regex;
44

5+
use crate::backend::ShardingSchema;
6+
57
use super::super::sharding::shard_str;
68

79
static SHARD: Lazy<Regex> = Lazy::new(|| Regex::new(r#"pgdog_shard: *([0-9]+)"#).unwrap());
@@ -15,15 +17,15 @@ static SHARDING_KEY: Lazy<Regex> =
1517
///
1618
/// See [`SHARD`] and [`SHARDING_KEY`] for the style of comment we expect.
1719
///
18-
pub fn shard(query: &str, shards: usize) -> Result<Option<usize>, Error> {
20+
pub fn shard(query: &str, schema: &ShardingSchema) -> Result<Option<usize>, Error> {
1921
let tokens = scan(query)?;
2022

2123
for token in tokens.tokens.iter() {
2224
if token.token == Token::CComment as i32 {
2325
let comment = &query[token.start as usize..token.end as usize];
2426
if let Some(cap) = SHARDING_KEY.captures(comment) {
2527
if let Some(sharding_key) = cap.get(1) {
26-
return Ok(shard_str(sharding_key.as_str(), shards));
28+
return Ok(shard_str(sharding_key.as_str(), schema));
2729
}
2830
}
2931
if let Some(cap) = SHARD.captures(comment) {

‎pgdog/src/frontend/router/parser/copy.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use pg_query::{protobuf::CopyStmt, NodeEnum};
44

55
use crate::{
6-
backend::Cluster,
6+
backend::{Cluster, ShardingSchema},
77
frontend::router::{sharding::shard_str, CopyRow},
88
net::messages::CopyData,
99
};
@@ -50,6 +50,8 @@ pub struct CopyParser {
5050
pub is_from: bool,
5151
/// CSV parser that can handle incomplete records.
5252
csv_stream: CsvStream,
53+
54+
sharding_schema: ShardingSchema,
5355
}
5456

5557
impl Default for CopyParser {
@@ -62,6 +64,7 @@ impl Default for CopyParser {
6264
columns: 0,
6365
is_from: false,
6466
csv_stream: CsvStream::new(',', false),
67+
sharding_schema: ShardingSchema::default(),
6568
}
6669
}
6770
}
@@ -122,6 +125,7 @@ impl CopyParser {
122125
}
123126

124127
parser.csv_stream = CsvStream::new(parser.delimiter(), parser.headers);
128+
parser.sharding_schema = cluster.sharding_schema();
125129

126130
Ok(Some(parser))
127131
}
@@ -154,7 +158,7 @@ impl CopyParser {
154158
let shard = if let Some(sharding_column) = self.sharded_column {
155159
let key = record.get(sharding_column).ok_or(Error::NoShardingColumn)?;
156160

157-
shard_str(key, self.shards)
161+
shard_str(key, &self.sharding_schema)
158162
} else {
159163
None
160164
};

‎pgdog/src/frontend/router/parser/query.rs

+14-16
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
use std::collections::{BTreeSet, HashSet};
33

44
use crate::{
5-
backend::{databases::databases, Cluster},
5+
backend::{databases::databases, Cluster, ShardingSchema},
66
frontend::{
77
router::{parser::OrderBy, round_robin, sharding::shard_str, CopyRow},
88
Buffer,
@@ -110,8 +110,10 @@ impl QueryParser {
110110
}
111111
}
112112

113+
let sharding_schema = cluster.sharding_schema();
114+
113115
// Hardcoded shard from a comment.
114-
let shard = super::comment::shard(query, cluster.shards().len()).map_err(Error::PgQuery)?;
116+
let shard = super::comment::shard(query, &sharding_schema).map_err(Error::PgQuery)?;
115117

116118
// Cluster is read only or write only, traffic split isn't needed,
117119
// so don't parse the query further.
@@ -144,11 +146,11 @@ impl QueryParser {
144146
round_robin::next() % cluster.shards().len(),
145147
))));
146148
} else {
147-
Self::select(stmt, cluster, params)
149+
Self::select(stmt, &sharding_schema, params)
148150
}
149151
}
150152
Some(NodeEnum::CopyStmt(ref stmt)) => Self::copy(stmt, cluster),
151-
Some(NodeEnum::InsertStmt(ref stmt)) => Self::insert(stmt, cluster, &params),
153+
Some(NodeEnum::InsertStmt(ref stmt)) => Self::insert(stmt, &sharding_schema, &params),
152154
Some(NodeEnum::UpdateStmt(ref stmt)) => Self::update(stmt),
153155
Some(NodeEnum::DeleteStmt(ref stmt)) => Self::delete(stmt),
154156
Some(NodeEnum::TransactionStmt(ref stmt)) => match stmt.kind() {
@@ -193,11 +195,10 @@ impl QueryParser {
193195

194196
fn select(
195197
stmt: &SelectStmt,
196-
cluster: &Cluster,
198+
sharding_schema: &ShardingSchema,
197199
params: Option<Bind>,
198200
) -> Result<Command, Error> {
199201
let order_by = Self::select_sort(&stmt.sort_clause);
200-
let sharded_tables = cluster.sharded_tables();
201202
let mut shards = HashSet::new();
202203
let table_name = stmt
203204
.from_clause
@@ -215,13 +216,13 @@ impl QueryParser {
215216
.flatten();
216217
if let Some(where_clause) = WhereClause::new(table_name, &stmt.where_clause) {
217218
// Complexity: O(number of sharded tables * number of columns in the query)
218-
for table in sharded_tables {
219+
for table in sharding_schema.tables.tables() {
219220
let table_name = table.name.as_deref();
220221
let keys = where_clause.keys(table_name, &table.column);
221222
for key in keys {
222223
match key {
223224
Key::Constant(value) => {
224-
if let Some(shard) = shard_str(&value, cluster.shards().len()) {
225+
if let Some(shard) = shard_str(&value, &sharding_schema) {
225226
shards.insert(shard);
226227
}
227228
}
@@ -231,8 +232,7 @@ impl QueryParser {
231232
if let Some(param) = params.parameter(param)? {
232233
// TODO: Handle binary encoding.
233234
if let Some(text) = param.text() {
234-
if let Some(shard) = shard_str(text, cluster.shards().len())
235-
{
235+
if let Some(shard) = shard_str(text, &sharding_schema) {
236236
shards.insert(shard);
237237
}
238238
}
@@ -310,7 +310,7 @@ impl QueryParser {
310310

311311
fn insert(
312312
stmt: &InsertStmt,
313-
cluster: &Cluster,
313+
sharding_schema: &ShardingSchema,
314314
params: &Option<Bind>,
315315
) -> Result<Command, Error> {
316316
let insert = Insert::new(stmt);
@@ -320,17 +320,15 @@ impl QueryParser {
320320
.map(|column| column.name)
321321
.collect::<Vec<_>>();
322322
let table = insert.table().unwrap().name;
323-
let num_shards = cluster.shards().len();
324-
325-
let sharding_column = cluster.sharded_column(table, &columns);
323+
let sharding_column = sharding_schema.tables.sharded_column(table, &columns);
326324
let mut shards = BTreeSet::new();
327325
if let Some(column) = sharding_column {
328326
for tuple in insert.tuples() {
329327
if let Some(value) = tuple.get(column) {
330328
shards.insert(if let Some(bind) = params {
331-
value.shard_placeholder(bind, num_shards)
329+
value.shard_placeholder(bind, &sharding_schema)
332330
} else {
333-
value.shard(num_shards)
331+
value.shard(&sharding_schema)
334332
});
335333
}
336334
}

‎pgdog/src/frontend/router/parser/value.rs

+7-6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use pg_query::{
66
};
77

88
use crate::{
9+
backend::ShardingSchema,
910
frontend::router::sharding::{shard_int, shard_str},
1011
net::messages::Bind,
1112
};
@@ -22,23 +23,23 @@ pub enum Value<'a> {
2223

2324
impl<'a> Value<'a> {
2425
/// Extract value from a Bind (F) message and shard on it.
25-
pub fn shard_placeholder(&self, bind: &'a Bind, shards: usize) -> Option<usize> {
26+
pub fn shard_placeholder(&self, bind: &'a Bind, schema: &ShardingSchema) -> Option<usize> {
2627
match self {
2728
Value::Placeholder(placeholder) => bind
2829
.parameter(*placeholder as usize - 1)
2930
.ok()
3031
.flatten()
31-
.and_then(|value| value.text().map(|value| shard_str(value, shards)))
32+
.and_then(|value| value.text().map(|value| shard_str(value, schema)))
3233
.flatten(),
33-
_ => self.shard(shards),
34+
_ => self.shard(schema),
3435
}
3536
}
3637

3738
/// Shard the value given the number of shards in the cluster.
38-
pub fn shard(&self, shards: usize) -> Option<usize> {
39+
pub fn shard(&self, schema: &ShardingSchema) -> Option<usize> {
3940
match self {
40-
Value::String(v) => shard_str(v, shards),
41-
Value::Integer(v) => Some(shard_int(*v, shards)),
41+
Value::String(v) => shard_str(v, schema),
42+
Value::Integer(v) => Some(shard_int(*v, schema)),
4243
_ => None,
4344
}
4445
}

‎pgdog/src/frontend/router/sharding/mod.rs

+41-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
use uuid::Uuid;
22

3+
use crate::{
4+
backend::ShardingSchema,
5+
net::messages::{Format, FromDataType, Vector},
6+
};
7+
38
pub mod ffi;
9+
pub mod vector;
410

511
/// Hash `BIGINT`.
612
pub fn bigint(id: i64) -> u64 {
@@ -18,12 +24,22 @@ pub fn uuid(uuid: Uuid) -> u64 {
1824
}
1925

2026
/// Shard an integer.
21-
pub fn shard_int(value: i64, shards: usize) -> usize {
22-
bigint(value) as usize % shards
27+
pub fn shard_int(value: i64, schema: &ShardingSchema) -> usize {
28+
bigint(value) as usize % schema.shards
2329
}
2430

25-
/// Shard a string value, parsing out a BIGINT or UUID.
26-
pub fn shard_str(value: &str, shards: usize) -> Option<usize> {
31+
/// Shard a string value, parsing out a BIGINT, UUID, or vector.
32+
///
33+
/// TODO: This is really not great, we should pass in the type oid
34+
/// from RowDescription in here to avoid guessing.
35+
pub fn shard_str(value: &str, schema: &ShardingSchema) -> Option<usize> {
36+
let shards = schema.shards;
37+
if value.starts_with('[') {
38+
let vector = Vector::decode(value.as_bytes(), Format::Text).ok();
39+
if let Some(vector) = vector {
40+
return schema.shard_by_distance_l2(&vector);
41+
}
42+
}
2743
Some(match value.parse::<i64>() {
2844
Ok(value) => bigint(value) as usize % shards,
2945
Err(_) => match value.parse::<Uuid>() {
@@ -32,3 +48,24 @@ pub fn shard_str(value: &str, shards: usize) -> Option<usize> {
3248
},
3349
})
3450
}
51+
52+
#[cfg(test)]
53+
mod test {
54+
use super::*;
55+
56+
#[test]
57+
fn test_shard_str() {
58+
let schema = ShardingSchema {
59+
shards: 2,
60+
centroids: vec![
61+
Some(Vector::from(&[0.0, 1.0, 2.0][..])),
62+
Some(Vector::from(&[1.0, 2.0, 3.0][..])),
63+
],
64+
..Default::default()
65+
};
66+
let shard = shard_str("[1,2,3]", &schema);
67+
assert_eq!(shard, Some(1));
68+
let shard = shard_str("[0.0,0.5,0.1]", &schema);
69+
assert_eq!(shard, Some(0));
70+
}
71+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
use crate::net::messages::Vector;
2+
3+
pub enum Distance<'a> {
4+
Euclidean(&'a Vector, &'a Vector),
5+
}
6+
7+
impl Distance<'_> {
8+
pub fn distance(&self) -> f64 {
9+
match self {
10+
// TODO: SIMD this.
11+
Self::Euclidean(p, q) => {
12+
assert_eq!(p.len(), q.len());
13+
p.iter()
14+
.zip(q.iter())
15+
.map(|(p, q)| (**q - **p).powi(2))
16+
.sum::<f64>()
17+
.sqrt()
18+
}
19+
}
20+
}
21+
}
22+
23+
#[cfg(test)]
24+
mod test {
25+
use crate::net::messages::Vector;
26+
27+
use super::Distance;
28+
29+
#[test]
30+
fn test_euclidean() {
31+
let v1 = Vector::from(&[1.0, 2.0, 3.0][..]);
32+
let v2 = Vector::from(&[1.5, 2.0, 3.0][..]);
33+
let distance = Distance::Euclidean(&v1, &v2).distance();
34+
assert_eq!(distance, 0.5);
35+
}
36+
}

‎pgdog/src/net/messages/data_types/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@ pub mod text;
1212
pub mod timestamp;
1313
pub mod timestamptz;
1414
pub mod uuid;
15+
pub mod vector;
1516

1617
pub use interval::Interval;
1718
pub use numeric::Numeric;
1819
pub use timestamp::Timestamp;
1920
pub use timestamptz::TimestampTz;
21+
pub use vector::Vector;
2022

2123
pub trait FromDataType: Sized + PartialOrd + Ord + PartialEq {
2224
fn decode(bytes: &[u8], encoding: Format) -> Result<Self, Error>;

‎pgdog/src/net/messages/data_types/numeric.rs

+60
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ use std::{
55
};
66

77
use bytes::Buf;
8+
use serde::Deserialize;
9+
use serde::{
10+
de::{self, Visitor},
11+
Serialize,
12+
};
813
use tracing::warn;
914

1015
use crate::net::messages::data_row::Data;
@@ -13,6 +18,7 @@ use super::*;
1318

1419
/// We don't expect NaN's so we're going to implement Ord for this below.
1520
#[derive(PartialEq, Copy, Clone, Debug)]
21+
#[repr(C)]
1622
pub struct Numeric {
1723
data: f64,
1824
}
@@ -104,3 +110,57 @@ impl ToDataRowColumn for Numeric {
104110
self.encode(Format::Text).unwrap().into()
105111
}
106112
}
113+
114+
impl From<f32> for Numeric {
115+
fn from(value: f32) -> Self {
116+
Self { data: value as f64 }
117+
}
118+
}
119+
120+
impl From<f64> for Numeric {
121+
fn from(value: f64) -> Self {
122+
Self { data: value }
123+
}
124+
}
125+
126+
struct NumericVisitor;
127+
128+
impl<'de> Visitor<'de> for NumericVisitor {
129+
type Value = Numeric;
130+
131+
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
132+
formatter.write_str("a floating point (f32 or f64)")
133+
}
134+
135+
fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
136+
where
137+
E: de::Error,
138+
{
139+
Ok(Numeric { data: v })
140+
}
141+
142+
fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
143+
where
144+
E: de::Error,
145+
{
146+
Ok(Numeric { data: v as f64 })
147+
}
148+
}
149+
150+
impl<'de> Deserialize<'de> for Numeric {
151+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
152+
where
153+
D: de::Deserializer<'de>,
154+
{
155+
deserializer.deserialize_f64(NumericVisitor)
156+
}
157+
}
158+
159+
impl Serialize for Numeric {
160+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
161+
where
162+
S: serde::Serializer,
163+
{
164+
serializer.serialize_f64(self.data)
165+
}
166+
}
+153
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
use crate::{
2+
frontend::router::sharding::vector::Distance,
3+
net::{messages::Format, Error},
4+
};
5+
use bytes::Bytes;
6+
use serde::{
7+
de::{self, Visitor},
8+
ser::SerializeSeq,
9+
Deserialize, Serialize,
10+
};
11+
use std::{ops::Deref, str::from_utf8};
12+
13+
use super::{FromDataType, Numeric};
14+
15+
#[derive(Debug, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)]
16+
#[repr(C)]
17+
pub struct Vector {
18+
values: Vec<Numeric>,
19+
}
20+
21+
impl FromDataType for Vector {
22+
fn decode(bytes: &[u8], encoding: Format) -> Result<Self, Error> {
23+
match encoding {
24+
Format::Binary => Err(Error::NotTextEncoding),
25+
Format::Text => {
26+
let no_brackets = &bytes[1..bytes.len() - 1];
27+
let floats = no_brackets
28+
.split(|n| n == &b',')
29+
.map(|b| from_utf8(b).map(|n| n.trim().parse::<f32>().ok()))
30+
.flatten()
31+
.flatten()
32+
.map(|f| Numeric::from(f))
33+
.collect();
34+
Ok(Self { values: floats })
35+
}
36+
}
37+
}
38+
39+
fn encode(&self, encoding: Format) -> Result<bytes::Bytes, Error> {
40+
match encoding {
41+
Format::Text => Ok(Bytes::from(format!(
42+
"[{}]",
43+
self.values
44+
.iter()
45+
.map(|v| v.to_string())
46+
.collect::<Vec<_>>()
47+
.join(",")
48+
))),
49+
Format::Binary => Err(Error::NotTextEncoding),
50+
}
51+
}
52+
}
53+
54+
impl Vector {
55+
/// Length of the vector.
56+
pub fn len(&self) -> usize {
57+
self.values.len()
58+
}
59+
60+
/// Is the vector empty?
61+
pub fn is_emtpy(&self) -> bool {
62+
self.len() == 0
63+
}
64+
65+
/// Compute L2 distance between the vectors.
66+
pub fn distance_l2(&self, other: &Self) -> f64 {
67+
Distance::Euclidean(self, other).distance()
68+
}
69+
}
70+
71+
impl Deref for Vector {
72+
type Target = Vec<Numeric>;
73+
74+
fn deref(&self) -> &Self::Target {
75+
&self.values
76+
}
77+
}
78+
79+
impl From<&[f64]> for Vector {
80+
fn from(value: &[f64]) -> Self {
81+
Self {
82+
values: value.iter().map(|v| Numeric::from(*v)).collect(),
83+
}
84+
}
85+
}
86+
87+
impl From<&[f32]> for Vector {
88+
fn from(value: &[f32]) -> Self {
89+
Self {
90+
values: value.iter().map(|v| Numeric::from(*v)).collect(),
91+
}
92+
}
93+
}
94+
95+
struct VectorVisitor;
96+
97+
impl<'de> Visitor<'de> for VectorVisitor {
98+
type Value = Vector;
99+
100+
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
101+
where
102+
A: de::SeqAccess<'de>,
103+
{
104+
let mut results = vec![];
105+
while let Some(n) = seq.next_element::<f64>()? {
106+
results.push(n);
107+
}
108+
109+
Ok(Vector::from(results.as_slice()))
110+
}
111+
112+
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
113+
formatter.write_str("expected a list of floating points")
114+
}
115+
}
116+
117+
impl<'de> Deserialize<'de> for Vector {
118+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
119+
where
120+
D: de::Deserializer<'de>,
121+
{
122+
deserializer.deserialize_seq(VectorVisitor)
123+
}
124+
}
125+
126+
impl Serialize for Vector {
127+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
128+
where
129+
S: serde::Serializer,
130+
{
131+
let mut seq = serializer.serialize_seq(Some(self.len()))?;
132+
for v in &self.values {
133+
seq.serialize_element(v)?;
134+
}
135+
seq.end()
136+
}
137+
}
138+
139+
#[cfg(test)]
140+
mod test {
141+
use super::*;
142+
143+
#[test]
144+
fn test_vectors() {
145+
let v = "[1,2,3]";
146+
let vector = Vector::decode(v.as_bytes(), Format::Text).unwrap();
147+
assert_eq!(vector.values[0], 1.0.into());
148+
assert_eq!(vector.values[1], 2.0.into());
149+
assert_eq!(vector.values[2], 3.0.into());
150+
let b = vector.encode(Format::Text).unwrap();
151+
assert_eq!(&b, &"[1,2,3]");
152+
}
153+
}

0 commit comments

Comments
 (0)
Please sign in to comment.