Add pgvector support for sorting #46

Merged: 13 commits, Mar 21, 2025
12 changes: 11 additions & 1 deletion pgdog.toml
@@ -38,7 +38,6 @@ host = "127.0.0.1"
database_name = "shard_1"
shard = 1


#
# Read/write access to these tables will be automatically
# sharded.
@@ -47,13 +46,24 @@ shard = 1
database = "pgdog_sharded"
table = "sharded"
column = "id"
primary = true

[[sharded_tables]]
database = "pgdog_sharded"
table = "users"
column = "id"
primary = true

# [[sharded_tables]]
# database = "pgdog_sharded"
# table = "vectors"
# column = "embedding"
# primary = true
# centroids = [[
# [1, 2, 3],
# [100, 200, 300],
# ]]

#
# ActiveRecord sends these queries
# at startup to figure out the schema.
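Note on the commented-out block above: it sketches centroid-based sharding for pgvector columns, where a row's shard is picked by whichever configured centroid is closest to its embedding. A minimal illustration of that routing idea, assuming plain f64 vectors and L2 distance (names are illustrative, not PgDog's API):

    /// Squared L2 distance between two equal-length vectors.
    /// The square root is skipped: it does not change the argmin.
    fn distance_l2_squared(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum()
    }

    /// Pick the shard whose centroid is closest to the embedding.
    fn route_to_shard(embedding: &[f64], centroids: &[Vec<f64>]) -> Option<usize> {
        centroids
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| {
                distance_l2_squared(embedding, a)
                    .partial_cmp(&distance_l2_squared(embedding, b))
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(index, _)| index)
    }

    fn main() {
        // With the two example centroids from the config above,
        // a vector near [1, 2, 3] routes to shard 0.
        let centroids = vec![vec![1.0, 2.0, 3.0], vec![100.0, 200.0, 300.0]];
        assert_eq!(route_to_shard(&[2.0, 2.0, 2.0], &centroids), Some(0));
    }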
5 changes: 3 additions & 2 deletions pgdog/src/backend/databases.rs
@@ -15,7 +15,7 @@ use crate::{
use super::{
pool::{Address, Config},
replication::ReplicationConfig,
- Cluster, Error, ShardedTables,
+ Cluster, ClusterShardConfig, Error, ShardedTables,
};

static DATABASES: Lazy<ArcSwap<Databases>> =
@@ -217,7 +217,8 @@ pub fn from_config(config: &ConfigAndUsers) -> Databases {
config: Config::new(general, replica, user),
})
.collect::<Vec<_>>();
- shard_configs.push((primary, replicas));
+
+ shard_configs.push(ClusterShardConfig { primary, replicas });
}

let sharded_tables = sharded_tables
2 changes: 1 addition & 1 deletion pgdog/src/backend/mod.rs
@@ -10,7 +10,7 @@ pub mod server;
pub mod stats;

pub use error::Error;
- pub use pool::{Cluster, Pool, Replicas, Shard};
+ pub use pool::{Cluster, ClusterShardConfig, Pool, Replicas, Shard, ShardingSchema};
pub use prepared_statements::PreparedStatements;
pub use replication::ShardedTables;
pub use schema::Schema;
27 changes: 25 additions & 2 deletions pgdog/src/backend/pool/cluster.rs
@@ -32,11 +32,25 @@ pub struct Cluster {
replication_sharding: Option<String>,
}

/// Sharding configuration from the cluster.
#[derive(Debug, Clone, Default)]
pub struct ShardingSchema {
/// Number of shards.
pub shards: usize,
/// Sharded tables.
pub tables: ShardedTables,
}

pub struct ClusterShardConfig {
pub primary: Option<PoolConfig>,
pub replicas: Vec<PoolConfig>,
}

impl Cluster {
/// Create new cluster of shards.
pub fn new(
name: &str,
- shards: &[(Option<PoolConfig>, Vec<PoolConfig>)],
+ shards: &[ClusterShardConfig],
lb_strategy: LoadBalancingStrategy,
password: &str,
pooler_mode: PoolerMode,
@@ -46,7 +60,7 @@ impl Cluster {
Self {
shards: shards
.iter()
- .map(|addr| Shard::new(addr.0.clone(), &addr.1, lb_strategy))
+ .map(|config| Shard::new(&config.primary, &config.replicas, lb_strategy))
.collect(),
name: name.to_owned(),
password: password.to_owned(),
@@ -191,6 +205,14 @@ impl Cluster {
.as_ref()
.and_then(|database| databases().replication(database))
}

/// Get all data required for sharding.
pub fn sharding_schema(&self) -> ShardingSchema {
ShardingSchema {
shards: self.shards.len(),
tables: self.sharded_tables.clone(),
}
}
}

#[cfg(test)]
@@ -210,6 +232,7 @@ mod test {
name: Some("sharded".into()),
column: "id".into(),
primary: true,
centroids: vec![],
}]),
shards: vec![Shard::default(), Shard::default()],
..Default::default()
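Note: sharding_schema() hands the query router one cloneable value with everything it needs, the shard count plus the sharded-table config, instead of making it reach back into the Cluster. A hypothetical consumer mapping a sharding key to a shard (DefaultHasher is a stand-in; real routing has to hash the same way Postgres hash partitioning does):

    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    /// Map a sharding key to a shard number, given the shard count
    /// (ShardingSchema::shards in the code above).
    fn shard_for_key<K: Hash>(key: &K, shards: usize) -> usize {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        (hasher.finish() % shards as u64) as usize
    }

    fn main() {
        // With two shards, every key lands on shard 0 or 1.
        assert!(shard_for_key(&1234_i64, 2) < 2);
    }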
20 changes: 20 additions & 0 deletions pgdog/src/backend/pool/connection/aggregate.rs
@@ -30,6 +30,9 @@ impl Grouping {
}

/// The aggregate accumulator.
///
/// This transforms distributed aggregate functions
/// into a single value.
#[derive(Debug)]
struct Accumulator<'a> {
target: &'a AggregateTarget,
@@ -54,6 +57,7 @@ impl<'a> Accumulator<'a> {
.collect()
}

/// Transform COUNT(*), MIN, MAX, etc., from multiple shards into a single value.
fn accumulate(&mut self, row: &DataRow, rd: &RowDescription) -> Result<(), Error> {
let column = row.get_column(self.target.column(), rd)?;
if let Some(column) = column {
@@ -77,6 +81,13 @@ impl<'a> Accumulator<'a> {
self.datum = column.value;
}
}
AggregateFunction::Sum => {
if !self.datum.is_null() {
self.datum = self.datum.clone() + column.value;
} else {
self.datum = column.value;
}
}
_ => (),
}
}
@@ -122,6 +133,15 @@ impl<'a> Aggregates<'a> {

let mut rows = VecDeque::new();
for (grouping, accumulator) in self.mappings {
//
// Aggregate rules in Postgres dictate that the only
// columns present in the row are either:
//
// 1. part of the GROUP BY, which means they are
// stored in the grouping
// 2. aggregate functions, which means they
// are stored in the accumulator
//
let mut row = DataRow::new();
for (idx, datum) in grouping.columns {
row.insert(idx, datum);
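Note: the new Sum arm completes the set of directly mergeable aggregates: COUNT and SUM combine by addition, MIN and MAX by comparison. A self-contained sketch of that merge rule over plain integers (the Accumulator above applies the same logic to wire-protocol Datum values):

    #[derive(Clone, Copy)]
    enum Agg {
        Count,
        Sum,
        Min,
        Max,
    }

    /// Merge one shard's partial result into the running value.
    fn merge(agg: Agg, acc: Option<i64>, partial: i64) -> Option<i64> {
        Some(match (agg, acc) {
            (_, None) => partial, // first shard seen
            (Agg::Count, Some(a)) | (Agg::Sum, Some(a)) => a + partial,
            (Agg::Min, Some(a)) => a.min(partial),
            (Agg::Max, Some(a)) => a.max(partial),
        })
    }

    fn main() {
        // SUM over three shards: 10 + 5 + 7 = 22.
        let total = [10, 5, 7]
            .iter()
            .fold(None, |acc, &p| merge(Agg::Sum, acc, p));
        assert_eq!(total, Some(22));
    }

AVG is absent from this list because an average of per-shard averages is wrong; it would have to be rewritten into SUM and COUNT parts first.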
62 changes: 45 additions & 17 deletions pgdog/src/backend/pool/connection/buffer.rs
@@ -4,7 +4,7 @@ use std::{cmp::Ordering, collections::VecDeque};

use crate::{
frontend::router::parser::{Aggregate, OrderBy},
- net::messages::{DataRow, FromBytes, Message, Protocol, RowDescription, ToBytes},
+ net::messages::{DataRow, FromBytes, Message, Protocol, RowDescription, ToBytes, Vector},
};

use super::Aggregates;
@@ -34,33 +34,61 @@ impl Buffer {

/// Sort the buffer.
pub(super) fn sort(&mut self, columns: &[OrderBy], rd: &RowDescription) {
- // Calculate column indecies once, since
- // fetching indecies by name is O(n).
+ // Calculate column indices once, since
+ // fetching indices by name is O(number of columns).
let mut cols = vec![];
for column in columns {
- if let Some(index) = column.index() {
- cols.push(Some((index, column.asc())));
- } else if let Some(name) = column.name() {
- if let Some(index) = rd.field_index(name) {
- cols.push(Some((index, column.asc())));
- } else {
- cols.push(None);
- }
- } else {
- cols.push(None);
- };
+ match column {
+ OrderBy::Asc(_) => cols.push(column.clone()),
+ OrderBy::AscColumn(name) => {
+ if let Some(index) = rd.field_index(name) {
+ cols.push(OrderBy::Asc(index + 1));
+ }
+ }
+ OrderBy::Desc(_) => cols.push(column.clone()),
+ OrderBy::DescColumn(name) => {
+ if let Some(index) = rd.field_index(name) {
+ cols.push(OrderBy::Desc(index + 1));
+ }
+ }
+ OrderBy::AscVectorL2(_, _) => cols.push(column.clone()),
+ OrderBy::AscVectorL2Column(name, vector) => {
+ if let Some(index) = rd.field_index(name) {
+ cols.push(OrderBy::AscVectorL2(index + 1, vector.clone()));
+ }
+ }
+ }
}

// Sort rows.
let order_by = move |a: &DataRow, b: &DataRow| -> Ordering {
- for col in cols.iter().flatten() {
- let (index, asc) = col;
- let left = a.get_column(*index, rd);
- let right = b.get_column(*index, rd);
+ for col in cols.iter() {
+ let index = col.index();
+ let asc = col.asc();
+ let index = if let Some(index) = index {
+ index
+ } else {
+ continue;
+ };
+ let left = a.get_column(index, rd);
+ let right = b.get_column(index, rd);

let ordering = match (left, right) {
(Ok(Some(left)), Ok(Some(right))) => {
- if *asc {
+ // Handle the special vector case.
+ if let OrderBy::AscVectorL2(_, vector) = col {
+ let left: Option<Vector> = left.value.try_into().ok();
+ let right: Option<Vector> = right.value.try_into().ok();
+
+ if let (Some(left), Some(right)) = (left, right) {
+ let left = left.distance_l2(vector);
+ let right = right.distance_l2(vector);
+
+ left.partial_cmp(&right)
+ } else {
+ Some(Ordering::Equal)
+ }
+ } else if asc {
left.value.partial_cmp(&right.value)
} else {
right.value.partial_cmp(&left.value)
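Note: for ORDER BY embedding <-> '[...]', rows cannot be compared by column value; each row's L2 distance to the query vector decides its position. A simplified version of the comparator above, with rows as plain f32 vectors:

    use std::cmp::Ordering;

    /// L2 (Euclidean) distance, the metric behind pgvector's `<->` operator.
    fn distance_l2(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f32>().sqrt()
    }

    /// Sort buffered rows by distance to the query vector, closest first,
    /// the same ordering OrderBy::AscVectorL2 produces in the diff above.
    fn sort_by_distance(rows: &mut [Vec<f32>], query: &[f32]) {
        rows.sort_by(|a, b| {
            distance_l2(a, query)
                .partial_cmp(&distance_l2(b, query))
                .unwrap_or(Ordering::Equal)
        });
    }

    fn main() {
        let mut rows = vec![vec![100.0, 100.0], vec![1.0, 1.0]];
        sort_by_distance(&mut rows, &[0.0, 0.0]);
        assert_eq!(rows[0], vec![1.0, 1.0]); // nearest row first
    }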
8 changes: 6 additions & 2 deletions pgdog/src/backend/pool/connection/mod.rs
@@ -15,7 +15,7 @@ use crate::{

use super::{
super::{pool::Guard, Error},
- Address, Cluster, Request,
+ Address, Cluster, Request, ShardingSchema,
};

use std::{mem::replace, time::Duration};
@@ -91,8 +91,12 @@ impl Connection {
&mut self,
shard: Option<usize>,
replication_config: &ReplicationConfig,
sharding_schema: &ShardingSchema,
) -> Result<(), Error> {
- self.binding = Binding::Replication(None, Buffer::new(shard, replication_config));
+ self.binding = Binding::Replication(
+ None,
+ Buffer::new(shard, replication_config, sharding_schema),
+ );
Ok(())
}

7 changes: 5 additions & 2 deletions pgdog/src/backend/pool/inner.rs
@@ -6,7 +6,7 @@ use std::{cmp::max, time::Instant};
use crate::backend::Server;
use crate::net::messages::BackendKeyData;

- use super::{Ban, Config, Error, Mapping, Stats};
+ use super::{Ban, Config, Error, Mapping, Oids, Stats};

/// Pool internals protected by a mutex.
#[derive(Default)]
@@ -21,7 +21,7 @@ pub(super) struct Inner {
pub(super) waiting: usize,
/// Pool ban status.
pub(super) ban: Option<Ban>,
- /// Pool is online and availble to clients.
+ /// Pool is online and available to clients.
pub(super) online: bool,
/// Pool is paused.
pub(super) paused: bool,
@@ -33,6 +33,8 @@ pub(super) struct Inner {
pub(super) errors: usize,
/// Stats
pub(super) stats: Stats,
/// OIDs.
pub(super) oids: Option<Oids>,
}

impl std::fmt::Debug for Inner {
@@ -63,6 +65,7 @@ impl Inner {
out_of_sync: 0,
errors: 0,
stats: Stats::default(),
oids: None,
}
}
/// Total number of connections managed by the pool.
4 changes: 3 additions & 1 deletion pgdog/src/backend/pool/mod.rs
@@ -13,6 +13,7 @@ pub mod healthcheck;
pub mod inner;
pub mod mapping;
pub mod monitor;
pub mod oids;
pub mod pool_impl;
pub mod replicas;
pub mod request;
@@ -22,13 +23,14 @@ pub mod stats;
pub mod waiting;

pub use address::Address;
- pub use cluster::{Cluster, PoolConfig};
+ pub use cluster::{Cluster, ClusterShardConfig, PoolConfig, ShardingSchema};
pub use config::Config;
pub use connection::Connection;
pub use error::Error;
pub use guard::Guard;
pub use healthcheck::Healtcheck;
use monitor::Monitor;
pub use oids::Oids;
pub use pool_impl::Pool;
pub use replicas::Replicas;
pub use request::Request;
18 changes: 16 additions & 2 deletions pgdog/src/backend/pool/monitor.rs
@@ -34,7 +34,7 @@

use std::time::{Duration, Instant};

- use super::{Error, Guard, Healtcheck, Pool};
+ use super::{Error, Guard, Healtcheck, Oids, Pool, Request};
use crate::backend::Server;
use crate::net::messages::BackendKeyData;

@@ -95,7 +95,7 @@ impl Monitor {

select! {
// A client is requesting a connection and no idle
- // connections are availble.
+ // connections are available.
_ = comms.request.notified() => {
let (
idle,
@@ -263,6 +263,20 @@ impl Monitor {
ok
}

#[allow(dead_code)]
async fn fetch_oids(pool: &Pool) -> Result<(), Error> {
if pool.lock().oids.is_none() {
let oids = Oids::load(&mut pool.get(&Request::default()).await?)
.await
.ok();
if let Some(oids) = oids {
pool.lock().oids = Some(oids);
}
}

Ok(())
}

/// Perform a periodic healthcheck on the pool.
async fn healthcheck(pool: &Pool) -> Result<(), Error> {
let (conn, healthcheck_timeout) = {
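Note: fetch_oids() (unused for now, hence the dead_code allow) caches catalog OIDs once per pool: check under the lock, run the catalog query on a pooled connection without holding the lock, then re-lock to store. The shape of that pattern, with tokio's async mutex standing in for the pool internals (names and the sample OID value are illustrative):

    use std::sync::Arc;
    use tokio::sync::Mutex;

    #[derive(Debug, Clone)]
    struct Oids {
        /// OID of pgvector's `vector` type, as found by something like
        /// `SELECT oid FROM pg_type WHERE typname = 'vector'`.
        vector: Option<u32>,
    }

    /// Stand-in for Oids::load(&mut server) in the real code.
    async fn load_oids() -> Oids {
        Oids { vector: Some(16_385) } // illustrative value only
    }

    async fn ensure_oids(cache: &Arc<Mutex<Option<Oids>>>) {
        if cache.lock().await.is_none() {
            // Query without holding the lock, so other tasks are not
            // blocked behind a network round-trip.
            let oids = load_oids().await;
            *cache.lock().await = Some(oids);
        }
    }

    #[tokio::main]
    async fn main() {
        let cache = Arc::new(Mutex::new(None));
        ensure_oids(&cache).await;
        assert!(cache.lock().await.is_some());
    }

As in the original, two concurrent callers can both run the query; the second store simply overwrites the first, which is harmless for immutable catalog data.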