Skip to content

Commit a174400

Browse files
authored
Support for basic count(*), min, max functions (#42)
* save * save * save * save * Add more trait impl
1 parent 1a2bd0c commit a174400

File tree

17 files changed

+456
-142
lines changed

17 files changed

+456
-142
lines changed

pgdog/src/backend/databases.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ pub fn reconnect() {
4444
replace_databases(databases().duplicate());
4545
}
4646

47-
/// Iniitialize the databases for the first time.
47+
/// Initialize the databases for the first time.
4848
pub fn init() {
4949
let config = config();
5050
replace_databases(from_config(&config));
+229
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
//! Buffer messages to sort and aggregate them later.
2+
3+
use std::{cmp::Ordering, collections::VecDeque};
4+
5+
use crate::{
6+
frontend::router::parser::{Aggregate, AggregateTarget, OrderBy},
7+
net::messages::{DataRow, Datum, FromBytes, Message, Protocol, RowDescription, ToBytes},
8+
};
9+
10+
/// Sort and aggregate rows received from multiple shards.
11+
#[derive(Default, Debug)]
12+
pub(super) struct Buffer {
13+
buffer: VecDeque<DataRow>,
14+
full: bool,
15+
}
16+
17+
impl Buffer {
18+
/// Add message to buffer.
19+
pub(super) fn add(&mut self, message: Message) -> Result<(), super::Error> {
20+
let dr = DataRow::from_bytes(message.to_bytes()?)?;
21+
22+
self.buffer.push_back(dr);
23+
24+
Ok(())
25+
}
26+
27+
/// Mark the buffer as full. It will start returning messages now.
28+
/// Caller is responsible for sorting the buffer if needed.
29+
pub(super) fn full(&mut self) {
30+
self.full = true;
31+
}
32+
33+
/// Sort the buffer.
34+
pub(super) fn sort(&mut self, columns: &[OrderBy], rd: &RowDescription) {
35+
// Calculate column indecies once, since
36+
// fetching indecies by name is O(n).
37+
let mut cols = vec![];
38+
for column in columns {
39+
if let Some(index) = column.index() {
40+
cols.push(Some((index, column.asc())));
41+
} else if let Some(name) = column.name() {
42+
if let Some(index) = rd.field_index(name) {
43+
cols.push(Some((index, column.asc())));
44+
} else {
45+
cols.push(None);
46+
}
47+
} else {
48+
cols.push(None);
49+
};
50+
}
51+
52+
// Sort rows.
53+
let order_by = move |a: &DataRow, b: &DataRow| -> Ordering {
54+
for col in cols.iter().flatten() {
55+
let (index, asc) = col;
56+
let left = a.get_column(*index, rd);
57+
let right = b.get_column(*index, rd);
58+
59+
let ordering = match (left, right) {
60+
(Ok(Some(left)), Ok(Some(right))) => {
61+
if *asc {
62+
left.value.partial_cmp(&right.value)
63+
} else {
64+
right.value.partial_cmp(&left.value)
65+
}
66+
}
67+
68+
_ => Some(Ordering::Equal),
69+
};
70+
71+
if ordering != Some(Ordering::Equal) {
72+
return ordering.unwrap_or(Ordering::Equal);
73+
}
74+
}
75+
76+
Ordering::Equal
77+
};
78+
79+
self.buffer.make_contiguous().sort_by(order_by);
80+
}
81+
82+
/// Execute aggregate functions.
83+
///
84+
/// This function is the entrypoint for aggregation, so if you're reading this,
85+
/// understand that this will be a WIP for a while. Some (many) assumptions are made
86+
/// about queries and they will be tested (and adjusted) over time.
87+
///
88+
/// Some aggregates will require query rewriting. This information will need to be passed in,
89+
/// and extra columns fetched from Postgres removed from the final result.
90+
pub(super) fn aggregate(
91+
&mut self,
92+
aggregates: &[Aggregate],
93+
rd: &RowDescription,
94+
) -> Result<(), super::Error> {
95+
let buffer: VecDeque<DataRow> = self.buffer.drain(0..).collect();
96+
let mut result = DataRow::new();
97+
98+
for aggregate in aggregates {
99+
match aggregate {
100+
// COUNT(*) are summed across shards. This is the easiest of the aggregates,
101+
// yet it's probably the most common one.
102+
//
103+
// TODO: If there is a GROUP BY clause, we need to sum across specified columns.
104+
Aggregate::Count(AggregateTarget::Star(index)) => {
105+
let mut count = Datum::Bigint(0);
106+
for row in &buffer {
107+
let column = row.get_column(*index, rd)?;
108+
if let Some(column) = column {
109+
count = count + column.value;
110+
}
111+
}
112+
113+
result.insert(*index, count);
114+
}
115+
116+
Aggregate::Max(AggregateTarget::Star(index)) => {
117+
let mut max = Datum::Bigint(i64::MIN);
118+
for row in &buffer {
119+
let column = row.get_column(*index, rd)?;
120+
if let Some(column) = column {
121+
if max < column.value {
122+
max = column.value;
123+
}
124+
}
125+
}
126+
127+
result.insert(*index, max);
128+
}
129+
130+
Aggregate::Min(AggregateTarget::Star(index)) => {
131+
let mut min = Datum::Bigint(i64::MAX);
132+
for row in &buffer {
133+
let column = row.get_column(*index, rd)?;
134+
if let Some(column) = column {
135+
if min > column.value {
136+
min = column.value;
137+
}
138+
}
139+
}
140+
141+
result.insert(*index, min);
142+
}
143+
_ => (),
144+
}
145+
}
146+
147+
if !result.is_empty() {
148+
self.buffer.push_back(result);
149+
} else {
150+
self.buffer = buffer;
151+
}
152+
153+
Ok(())
154+
}
155+
156+
/// Take messages from buffer.
157+
pub(super) fn take(&mut self) -> Option<Message> {
158+
if self.full {
159+
self.buffer.pop_front().and_then(|s| s.message().ok())
160+
} else {
161+
None
162+
}
163+
}
164+
165+
pub(super) fn len(&self) -> usize {
166+
self.buffer.len()
167+
}
168+
169+
#[allow(dead_code)]
170+
pub(super) fn is_empty(&self) -> bool {
171+
self.len() == 0
172+
}
173+
}
174+
175+
#[cfg(test)]
176+
mod test {
177+
use super::*;
178+
use crate::net::messages::{Field, Format};
179+
180+
#[test]
181+
fn test_sort_buffer() {
182+
let mut buf = Buffer::default();
183+
let rd = RowDescription::new(&[Field::bigint("one"), Field::text("two")]);
184+
let columns = [OrderBy::Asc(1), OrderBy::Desc(2)];
185+
186+
for i in 0..25_i64 {
187+
let mut dr = DataRow::new();
188+
dr.add(25 - i).add((25 - i).to_string());
189+
buf.add(dr.message().unwrap()).unwrap();
190+
}
191+
192+
buf.sort(&columns, &rd);
193+
buf.full();
194+
195+
let mut i = 1;
196+
while let Some(message) = buf.take() {
197+
let dr = DataRow::from_bytes(message.to_bytes().unwrap()).unwrap();
198+
let one = dr.get::<i64>(0, Format::Text).unwrap();
199+
let two = dr.get::<String>(1, Format::Text).unwrap();
200+
assert_eq!(one, i);
201+
assert_eq!(two, i.to_string());
202+
i += 1;
203+
}
204+
205+
assert_eq!(i, 26);
206+
}
207+
208+
#[test]
209+
fn test_aggregate_buffer() {
210+
let mut buf = Buffer::default();
211+
let rd = RowDescription::new(&[Field::bigint("count")]);
212+
let agg = [Aggregate::Count(AggregateTarget::Star(0))];
213+
214+
for _ in 0..6 {
215+
let mut dr = DataRow::new();
216+
dr.add(15_i64);
217+
buf.add(dr.message().unwrap()).unwrap();
218+
}
219+
220+
buf.aggregate(&agg, &rd).unwrap();
221+
buf.full();
222+
223+
assert_eq!(buf.len(), 1);
224+
let row = buf.take().unwrap();
225+
let dr = DataRow::from_bytes(row.to_bytes().unwrap()).unwrap();
226+
let count = dr.get::<i64>(0, Format::Text).unwrap();
227+
assert_eq!(count, 15 * 6);
228+
}
229+
}

pgdog/src/backend/pool/connection/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ use super::{
2121
use std::{mem::replace, time::Duration};
2222

2323
mod binding;
24+
mod buffer;
2425
mod multi_shard;
25-
mod sort_buffer;
2626

2727
use binding::Binding;
2828
use multi_shard::MultiShard;

pgdog/src/backend/pool/connection/multi_shard.rs

+15-10
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::{
99
},
1010
};
1111

12-
use super::sort_buffer::SortBuffer;
12+
use super::buffer::Buffer;
1313

1414
/// Multi-shard state.
1515
#[derive(Default, Debug)]
@@ -32,8 +32,8 @@ pub(super) struct MultiShard {
3232
rd: Option<RowDescription>,
3333
/// Rewritten CommandComplete message.
3434
command_complete: Option<Message>,
35-
/// Sorting buffer.
36-
sort_buffer: SortBuffer,
35+
/// Sorting/aggregate buffer.
36+
buffer: Buffer,
3737
}
3838

3939
impl MultiShard {
@@ -51,7 +51,6 @@ impl MultiShard {
5151
/// or modified.
5252
pub(super) fn forward(&mut self, message: Message) -> Result<Option<Message>, super::Error> {
5353
let mut forward = None;
54-
let order_by = self.route.order_by();
5554

5655
match message.code() {
5756
'Z' => {
@@ -74,13 +73,19 @@ impl MultiShard {
7473
self.cc += 1;
7574

7675
if self.cc == self.shards {
77-
self.sort_buffer.full();
76+
self.buffer.full();
7877
if let Some(ref rd) = self.rd {
79-
self.sort_buffer.sort(order_by, rd);
78+
self.buffer.aggregate(self.route.aggregate(), rd)?;
79+
self.buffer.sort(self.route.order_by(), rd);
8080
}
8181

8282
if has_rows {
83-
self.command_complete = Some(cc.rewrite(self.rows)?.message()?);
83+
let rows = if self.route.should_buffer() {
84+
self.buffer.len()
85+
} else {
86+
self.rows
87+
};
88+
self.command_complete = Some(cc.rewrite(rows)?.message()?);
8489
} else {
8590
forward = Some(cc.message()?);
8691
}
@@ -107,10 +112,10 @@ impl MultiShard {
107112
}
108113

109114
'D' => {
110-
if order_by.is_empty() {
115+
if !self.route.should_buffer() {
111116
forward = Some(message);
112117
} else {
113-
self.sort_buffer.add(message)?;
118+
self.buffer.add(message)?;
114119
}
115120
}
116121

@@ -129,7 +134,7 @@ impl MultiShard {
129134

130135
/// Multi-shard state is ready to send messages.
131136
pub(super) fn message(&mut self) -> Option<Message> {
132-
if let Some(data_row) = self.sort_buffer.take() {
137+
if let Some(data_row) = self.buffer.take() {
133138
Some(data_row)
134139
} else {
135140
self.command_complete.take()

0 commit comments

Comments
 (0)