Skip to content

Commit 38dec03

Browse files
authored
Add pgvector support for sorting (#46)
* save * shard by order by * save * save * save * save * save * save * wip * save * save * fix decoding oid * remove diff
1 parent 63e2c4e commit 38dec03

34 files changed

+819
-109
lines changed

pgdog.toml

+11-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ host = "127.0.0.1"
3838
database_name = "shard_1"
3939
shard = 1
4040

41-
4241
#
4342
# Read/write access to theses tables will be automatically
4443
# sharded.
@@ -47,13 +46,24 @@ shard = 1
4746
database = "pgdog_sharded"
4847
table = "sharded"
4948
column = "id"
49+
primary = true
5050

5151
[[sharded_tables]]
5252
database = "pgdog_sharded"
5353
table = "users"
5454
column = "id"
5555
primary = true
5656

57+
# [[sharded_tables]]
58+
# database = "pgdog_sharded"
59+
# table = "vectors"
60+
# column = "embedding"
61+
# primary = true
62+
# centroids = [[
63+
# [1, 2, 3],
64+
# [100, 200, 300],
65+
# ]]
66+
5767
#
5868
# ActiveRecord sends these queries
5969
# at startup to figure out the schema.

pgdog/src/backend/databases.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use crate::{
1515
use super::{
1616
pool::{Address, Config},
1717
replication::ReplicationConfig,
18-
Cluster, Error, ShardedTables,
18+
Cluster, ClusterShardConfig, Error, ShardedTables,
1919
};
2020

2121
static DATABASES: Lazy<ArcSwap<Databases>> =
@@ -217,7 +217,8 @@ pub fn from_config(config: &ConfigAndUsers) -> Databases {
217217
config: Config::new(general, replica, user),
218218
})
219219
.collect::<Vec<_>>();
220-
shard_configs.push((primary, replicas));
220+
221+
shard_configs.push(ClusterShardConfig { primary, replicas });
221222
}
222223

223224
let sharded_tables = sharded_tables

pgdog/src/backend/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pub mod server;
1010
pub mod stats;
1111

1212
pub use error::Error;
13-
pub use pool::{Cluster, Pool, Replicas, Shard};
13+
pub use pool::{Cluster, ClusterShardConfig, Pool, Replicas, Shard, ShardingSchema};
1414
pub use prepared_statements::PreparedStatements;
1515
pub use replication::ShardedTables;
1616
pub use schema::Schema;

pgdog/src/backend/pool/cluster.rs

+25-2
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,25 @@ pub struct Cluster {
3232
replication_sharding: Option<String>,
3333
}
3434

35+
/// Sharding configuration from the cluster.
36+
#[derive(Debug, Clone, Default)]
37+
pub struct ShardingSchema {
38+
/// Number of shards.
39+
pub shards: usize,
40+
/// Sharded tables.
41+
pub tables: ShardedTables,
42+
}
43+
44+
pub struct ClusterShardConfig {
45+
pub primary: Option<PoolConfig>,
46+
pub replicas: Vec<PoolConfig>,
47+
}
48+
3549
impl Cluster {
3650
/// Create new cluster of shards.
3751
pub fn new(
3852
name: &str,
39-
shards: &[(Option<PoolConfig>, Vec<PoolConfig>)],
53+
shards: &[ClusterShardConfig],
4054
lb_strategy: LoadBalancingStrategy,
4155
password: &str,
4256
pooler_mode: PoolerMode,
@@ -46,7 +60,7 @@ impl Cluster {
4660
Self {
4761
shards: shards
4862
.iter()
49-
.map(|addr| Shard::new(addr.0.clone(), &addr.1, lb_strategy))
63+
.map(|config| Shard::new(&config.primary, &config.replicas, lb_strategy))
5064
.collect(),
5165
name: name.to_owned(),
5266
password: password.to_owned(),
@@ -191,6 +205,14 @@ impl Cluster {
191205
.as_ref()
192206
.and_then(|database| databases().replication(database))
193207
}
208+
209+
/// Get all data required for sharding.
210+
pub fn sharding_schema(&self) -> ShardingSchema {
211+
ShardingSchema {
212+
shards: self.shards.len(),
213+
tables: self.sharded_tables.clone(),
214+
}
215+
}
194216
}
195217

196218
#[cfg(test)]
@@ -210,6 +232,7 @@ mod test {
210232
name: Some("sharded".into()),
211233
column: "id".into(),
212234
primary: true,
235+
centroids: vec![],
213236
}]),
214237
shards: vec![Shard::default(), Shard::default()],
215238
..Default::default()

pgdog/src/backend/pool/connection/aggregate.rs

+20
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ impl Grouping {
3030
}
3131

3232
/// The aggregate accumulator.
33+
///
34+
/// This transfors disttributed aggregate functions
35+
/// into a single value.
3336
#[derive(Debug)]
3437
struct Accumulator<'a> {
3538
target: &'a AggregateTarget,
@@ -54,6 +57,7 @@ impl<'a> Accumulator<'a> {
5457
.collect()
5558
}
5659

60+
/// Transform COUNT(*), MIN, MAX, etc., from multiple shards into a single value.
5761
fn accumulate(&mut self, row: &DataRow, rd: &RowDescription) -> Result<(), Error> {
5862
let column = row.get_column(self.target.column(), rd)?;
5963
if let Some(column) = column {
@@ -77,6 +81,13 @@ impl<'a> Accumulator<'a> {
7781
self.datum = column.value;
7882
}
7983
}
84+
AggregateFunction::Sum => {
85+
if !self.datum.is_null() {
86+
self.datum = self.datum.clone() + column.value;
87+
} else {
88+
self.datum = column.value;
89+
}
90+
}
8091
_ => (),
8192
}
8293
}
@@ -122,6 +133,15 @@ impl<'a> Aggregates<'a> {
122133

123134
let mut rows = VecDeque::new();
124135
for (grouping, accumulator) in self.mappings {
136+
//
137+
// Aggregate rules in Postgres dictate that the only
138+
// columns present in the row are either:
139+
//
140+
// 1. part of the GROUP BY, which means they are
141+
// stored in the grouping
142+
// 2. are aggregate functions, which means they
143+
// are stored in the accunmulator
144+
//
125145
let mut row = DataRow::new();
126146
for (idx, datum) in grouping.columns {
127147
row.insert(idx, datum);

pgdog/src/backend/pool/connection/buffer.rs

+45-17
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::{cmp::Ordering, collections::VecDeque};
44

55
use crate::{
66
frontend::router::parser::{Aggregate, OrderBy},
7-
net::messages::{DataRow, FromBytes, Message, Protocol, RowDescription, ToBytes},
7+
net::messages::{DataRow, FromBytes, Message, Protocol, RowDescription, ToBytes, Vector},
88
};
99

1010
use super::Aggregates;
@@ -34,33 +34,61 @@ impl Buffer {
3434

3535
/// Sort the buffer.
3636
pub(super) fn sort(&mut self, columns: &[OrderBy], rd: &RowDescription) {
37-
// Calculate column indecies once, since
38-
// fetching indecies by name is O(n).
37+
// Calculate column indices once, since
38+
// fetching indices by name is O(number of columns).
3939
let mut cols = vec![];
4040
for column in columns {
41-
if let Some(index) = column.index() {
42-
cols.push(Some((index, column.asc())));
43-
} else if let Some(name) = column.name() {
44-
if let Some(index) = rd.field_index(name) {
45-
cols.push(Some((index, column.asc())));
46-
} else {
47-
cols.push(None);
41+
match column {
42+
OrderBy::Asc(_) => cols.push(column.clone()),
43+
OrderBy::AscColumn(name) => {
44+
if let Some(index) = rd.field_index(name) {
45+
cols.push(OrderBy::Asc(index + 1));
46+
}
47+
}
48+
OrderBy::Desc(_) => cols.push(column.clone()),
49+
OrderBy::DescColumn(name) => {
50+
if let Some(index) = rd.field_index(name) {
51+
cols.push(OrderBy::Desc(index + 1));
52+
}
53+
}
54+
OrderBy::AscVectorL2(_, _) => cols.push(column.clone()),
55+
OrderBy::AscVectorL2Column(name, vector) => {
56+
if let Some(index) = rd.field_index(name) {
57+
cols.push(OrderBy::AscVectorL2(index + 1, vector.clone()));
58+
}
4859
}
49-
} else {
50-
cols.push(None);
5160
};
5261
}
5362

5463
// Sort rows.
5564
let order_by = move |a: &DataRow, b: &DataRow| -> Ordering {
56-
for col in cols.iter().flatten() {
57-
let (index, asc) = col;
58-
let left = a.get_column(*index, rd);
59-
let right = b.get_column(*index, rd);
65+
for col in cols.iter() {
66+
let index = col.index();
67+
let asc = col.asc();
68+
let index = if let Some(index) = index {
69+
index
70+
} else {
71+
continue;
72+
};
73+
let left = a.get_column(index, rd);
74+
let right = b.get_column(index, rd);
6075

6176
let ordering = match (left, right) {
6277
(Ok(Some(left)), Ok(Some(right))) => {
63-
if *asc {
78+
// Handle the special vector case.
79+
if let OrderBy::AscVectorL2(_, vector) = col {
80+
let left: Option<Vector> = left.value.try_into().ok();
81+
let right: Option<Vector> = right.value.try_into().ok();
82+
83+
if let (Some(left), Some(right)) = (left, right) {
84+
let left = left.distance_l2(vector);
85+
let right = right.distance_l2(vector);
86+
87+
left.partial_cmp(&right)
88+
} else {
89+
Some(Ordering::Equal)
90+
}
91+
} else if asc {
6492
left.value.partial_cmp(&right.value)
6593
} else {
6694
right.value.partial_cmp(&left.value)

pgdog/src/backend/pool/connection/mod.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use crate::{
1515

1616
use super::{
1717
super::{pool::Guard, Error},
18-
Address, Cluster, Request,
18+
Address, Cluster, Request, ShardingSchema,
1919
};
2020

2121
use std::{mem::replace, time::Duration};
@@ -91,8 +91,12 @@ impl Connection {
9191
&mut self,
9292
shard: Option<usize>,
9393
replication_config: &ReplicationConfig,
94+
sharding_schema: &ShardingSchema,
9495
) -> Result<(), Error> {
95-
self.binding = Binding::Replication(None, Buffer::new(shard, replication_config));
96+
self.binding = Binding::Replication(
97+
None,
98+
Buffer::new(shard, replication_config, sharding_schema),
99+
);
96100
Ok(())
97101
}
98102

pgdog/src/backend/pool/inner.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::{cmp::max, time::Instant};
66
use crate::backend::Server;
77
use crate::net::messages::BackendKeyData;
88

9-
use super::{Ban, Config, Error, Mapping, Stats};
9+
use super::{Ban, Config, Error, Mapping, Oids, Stats};
1010

1111
/// Pool internals protected by a mutex.
1212
#[derive(Default)]
@@ -21,7 +21,7 @@ pub(super) struct Inner {
2121
pub(super) waiting: usize,
2222
/// Pool ban status.
2323
pub(super) ban: Option<Ban>,
24-
/// Pool is online and availble to clients.
24+
/// Pool is online and available to clients.
2525
pub(super) online: bool,
2626
/// Pool is paused.
2727
pub(super) paused: bool,
@@ -33,6 +33,8 @@ pub(super) struct Inner {
3333
pub(super) errors: usize,
3434
/// Stats
3535
pub(super) stats: Stats,
36+
/// OIDs.
37+
pub(super) oids: Option<Oids>,
3638
}
3739

3840
impl std::fmt::Debug for Inner {
@@ -63,6 +65,7 @@ impl Inner {
6365
out_of_sync: 0,
6466
errors: 0,
6567
stats: Stats::default(),
68+
oids: None,
6669
}
6770
}
6871
/// Total number of connections managed by the pool.

pgdog/src/backend/pool/mod.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub mod healthcheck;
1313
pub mod inner;
1414
pub mod mapping;
1515
pub mod monitor;
16+
pub mod oids;
1617
pub mod pool_impl;
1718
pub mod replicas;
1819
pub mod request;
@@ -22,13 +23,14 @@ pub mod stats;
2223
pub mod waiting;
2324

2425
pub use address::Address;
25-
pub use cluster::{Cluster, PoolConfig};
26+
pub use cluster::{Cluster, ClusterShardConfig, PoolConfig, ShardingSchema};
2627
pub use config::Config;
2728
pub use connection::Connection;
2829
pub use error::Error;
2930
pub use guard::Guard;
3031
pub use healthcheck::Healtcheck;
3132
use monitor::Monitor;
33+
pub use oids::Oids;
3234
pub use pool_impl::Pool;
3335
pub use replicas::Replicas;
3436
pub use request::Request;

pgdog/src/backend/pool/monitor.rs

+16-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
3535
use std::time::{Duration, Instant};
3636

37-
use super::{Error, Guard, Healtcheck, Pool};
37+
use super::{Error, Guard, Healtcheck, Oids, Pool, Request};
3838
use crate::backend::Server;
3939
use crate::net::messages::BackendKeyData;
4040

@@ -95,7 +95,7 @@ impl Monitor {
9595

9696
select! {
9797
// A client is requesting a connection and no idle
98-
// connections are availble.
98+
// connections are available.
9999
_ = comms.request.notified() => {
100100
let (
101101
idle,
@@ -263,6 +263,20 @@ impl Monitor {
263263
ok
264264
}
265265

266+
#[allow(dead_code)]
267+
async fn fetch_oids(pool: &Pool) -> Result<(), Error> {
268+
if pool.lock().oids.is_none() {
269+
let oids = Oids::load(&mut pool.get(&Request::default()).await?)
270+
.await
271+
.ok();
272+
if let Some(oids) = oids {
273+
pool.lock().oids = Some(oids);
274+
}
275+
}
276+
277+
Ok(())
278+
}
279+
266280
/// Perform a periodic healthcheck on the pool.
267281
async fn healthcheck(pool: &Pool) -> Result<(), Error> {
268282
let (conn, healthcheck_timeout) = {

0 commit comments

Comments
 (0)