Skip to content

Commit d3310a6

Browse files
authored
Client md5 auth and clean up scram (#77)
* client md5 auth and clean up scram * add pw * add user * add user * log
1 parent d412238 commit d3310a6

File tree

8 files changed

+213
-95
lines changed

8 files changed

+213
-95
lines changed

.circleci/run_tests.sh

+22-19
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ function start_pgcat() {
1313

1414
# Setup the database with shards and user
1515
PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 5432 -U postgres -f tests/sharding/query_routing_setup.sql
16+
1617
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard0 -i
1718
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard1 -i
1819
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard2 -i
@@ -30,26 +31,28 @@ toxiproxy-cli create -l 127.0.0.1:5433 -u 127.0.0.1:5432 postgres_replica
3031

3132
start_pgcat "info"
3233

34+
export PGPASSWORD=sharding_user
35+
3336
# pgbench test
34-
pgbench -i -h 127.0.0.1 -p 6432
35-
pgbench -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol simple -f tests/pgbench/simple.sql
36-
pgbench -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol extended
37+
pgbench -U sharding_user -i -h 127.0.0.1 -p 6432
38+
pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol simple -f tests/pgbench/simple.sql
39+
pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol extended
3740

3841
# COPY TO STDOUT test
39-
psql -h 127.0.0.1 -p 6432 -c 'COPY (SELECT * FROM pgbench_accounts LIMIT 15) TO STDOUT;' > /dev/null
42+
psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'COPY (SELECT * FROM pgbench_accounts LIMIT 15) TO STDOUT;' > /dev/null
4043

4144
# Query cancellation test
42-
(psql -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(5)' || true) &
45+
(psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(5)' || true) &
4346
killall psql -s SIGINT
4447

4548
# Sharding insert
46-
psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_insert.sql
49+
psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_insert.sql
4750

4851
# Sharding select
49-
psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_select.sql > /dev/null
52+
psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_select.sql > /dev/null
5053

5154
# Replica/primary selection & more sharding tests
52-
psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_primary_replica.sql > /dev/null
55+
psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_primary_replica.sql > /dev/null
5356

5457
#
5558
# ActiveRecord tests
@@ -61,15 +64,15 @@ cd tests/ruby && \
6164
cd ../..
6265

6366
# Admin tests
64-
psql -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
65-
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null
66-
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null
67-
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null
68-
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW LISTS' > /dev/null
69-
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW POOLS' > /dev/null
70-
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW VERSION' > /dev/null
71-
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev/null # will ignore
72-
(! psql -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null)
67+
psql -U sharding_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
68+
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null
69+
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null
70+
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null
71+
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW LISTS' > /dev/null
72+
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW POOLS' > /dev/null
73+
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW VERSION' > /dev/null
74+
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev/null # will ignore
75+
(! psql -U sharding_user -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null)
7376

7477
# Start PgCat in debug to demonstrate failover better
7578
start_pgcat "trace"
@@ -79,7 +82,7 @@ toxiproxy-cli toxic add -t latency -a latency=300 postgres_replica
7982
sleep 1
8083

8184
# Note the failover in the logs
82-
timeout 5 psql -e -h 127.0.0.1 -p 6432 <<-EOF
85+
timeout 5 psql -U sharding_user -e -h 127.0.0.1 -p 6432 <<-EOF
8386
SELECT 1;
8487
SELECT 1;
8588
SELECT 1;
@@ -97,7 +100,7 @@ sed -i 's/pool_mode = "transaction"/pool_mode = "session"/' pgcat.toml
97100
kill -SIGHUP $(pgrep pgcat)
98101

99102
# Prepared statements that will only work in session mode
100-
pgbench -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol prepared
103+
pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol prepared
101104

102105
# Attempt clean shut down
103106
killall pgcat -s SIGINT

Cargo.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgcat"
3-
version = "0.2.0-beta1"
3+
version = "0.2.1-beta1"
44
edition = "2021"
55

66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

src/client.rs

+41-4
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ impl Client {
7272
server_info: BytesMut,
7373
stats: Reporter,
7474
) -> Result<Client, Error> {
75-
let config = get_config();
75+
let config = get_config().clone();
7676
let transaction_mode = config.general.pool_mode.starts_with("t");
77-
drop(config);
77+
// drop(config);
7878
loop {
7979
trace!("Waiting for StartupMessage");
8080

@@ -108,14 +108,51 @@ impl Client {
108108
// Regular startup message.
109109
PROTOCOL_VERSION_NUMBER => {
110110
trace!("Got StartupMessage");
111-
112-
// TODO: perform actual auth.
113111
let parameters = parse_startup(bytes.clone())?;
114112

115113
// Generate random backend ID and secret key
116114
let process_id: i32 = rand::random();
117115
let secret_key: i32 = rand::random();
118116

117+
// Perform MD5 authentication.
118+
// TODO: Add SASL support.
119+
let salt = md5_challenge(&mut stream).await?;
120+
121+
let code = match stream.read_u8().await {
122+
Ok(p) => p,
123+
Err(_) => return Err(Error::SocketError),
124+
};
125+
126+
// PasswordMessage
127+
if code as char != 'p' {
128+
debug!("Expected p, got {}", code as char);
129+
return Err(Error::ProtocolSyncError);
130+
}
131+
132+
let len = match stream.read_i32().await {
133+
Ok(len) => len,
134+
Err(_) => return Err(Error::SocketError),
135+
};
136+
137+
let mut password_response = vec![0u8; (len - 4) as usize];
138+
139+
match stream.read_exact(&mut password_response).await {
140+
Ok(_) => (),
141+
Err(_) => return Err(Error::SocketError),
142+
};
143+
144+
// Compare server and client hashes.
145+
let password_hash =
146+
md5_hash_password(&config.user.name, &config.user.password, &salt);
147+
148+
if password_hash != password_response {
149+
debug!("Password authentication failed");
150+
wrong_password(&mut stream, &config.user.name).await?;
151+
return Err(Error::ClientError);
152+
}
153+
154+
debug!("Password authentication successful");
155+
119156
auth_ok(&mut stream).await?;
120157
write_all(&mut stream, server_info).await?;
121158
backend_key_data(&mut stream, process_id, secret_key).await?;

src/errors.rs

+1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ pub enum Error {
99
ServerError,
1010
BadConfig,
1111
AllServersDown,
12+
ClientError,
1213
}

src/messages.rs

+68-8
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,26 @@ pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
4040
Ok(write_all(stream, auth_ok).await?)
4141
}
4242

43+
/// Generate md5 password challenge.
44+
pub async fn md5_challenge(stream: &mut TcpStream) -> Result<[u8; 4], Error> {
45+
// let mut rng = rand::thread_rng();
46+
let salt: [u8; 4] = [
47+
rand::random(),
48+
rand::random(),
49+
rand::random(),
50+
rand::random(),
51+
];
52+
53+
let mut res = BytesMut::new();
54+
res.put_u8(b'R');
55+
res.put_i32(12);
56+
res.put_i32(5); // MD5
57+
res.put_slice(&salt[..]);
58+
59+
write_all(stream, res).await?;
60+
Ok(salt)
61+
}
62+
4363
/// Give the client the process_id and secret we generated
4464
/// used in query cancellation.
4565
pub async fn backend_key_data(
@@ -160,14 +180,8 @@ pub fn parse_startup(bytes: BytesMut) -> Result<HashMap<String, String>, Error>
160180
Ok(result)
161181
}
162182

163-
/// Send password challenge response to the server.
164-
/// This is the MD5 challenge.
165-
pub async fn md5_password(
166-
stream: &mut TcpStream,
167-
user: &str,
168-
password: &str,
169-
salt: &[u8],
170-
) -> Result<(), Error> {
183+
/// Create md5 password hash given a salt.
184+
pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
171185
let mut md5 = Md5::new();
172186

173187
// First pass
@@ -186,6 +200,19 @@ pub async fn md5_password(
186200
.collect::<Vec<u8>>();
187201
password.push(0);
188202

203+
password
204+
}
205+
206+
/// Send password challenge response to the server.
207+
/// This is the MD5 challenge.
208+
pub async fn md5_password(
209+
stream: &mut TcpStream,
210+
user: &str,
211+
password: &str,
212+
salt: &[u8],
213+
) -> Result<(), Error> {
214+
let password = md5_hash_password(user, password, salt);
215+
189216
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
190217

191218
message.put_u8(b'p');
@@ -264,6 +291,39 @@ pub async fn error_response(stream: &mut OwnedWriteHalf, message: &str) -> Resul
264291
Ok(write_all_half(stream, res).await?)
265292
}
266293

294+
pub async fn wrong_password(stream: &mut TcpStream, user: &str) -> Result<(), Error> {
295+
let mut error = BytesMut::new();
296+
297+
// Error level
298+
error.put_u8(b'S');
299+
error.put_slice(&b"FATAL\0"[..]);
300+
301+
// Error level (non-translatable)
302+
error.put_u8(b'V');
303+
error.put_slice(&b"FATAL\0"[..]);
304+
305+
// Error code: not sure how much this matters.
306+
error.put_u8(b'C');
307+
error.put_slice(&b"28P01\0"[..]); // system_error, see Appendix A.
308+
309+
// The short error message.
310+
error.put_u8(b'M');
311+
error.put_slice(&format!("password authentication failed for user \"{}\"\0", user).as_bytes());
312+
313+
// No more fields follow.
314+
error.put_u8(0);
315+
316+
// Compose the two message reply.
317+
let mut res = BytesMut::new();
318+
319+
res.put_u8(b'E');
320+
res.put_i32(error.len() as i32 + 4);
321+
322+
res.put(error);
323+
324+
write_all(stream, res).await
325+
}
326+
267327
/// Respond to a SHOW SHARD command.
268328
pub async fn show_response(
269329
stream: &mut OwnedWriteHalf,

0 commit comments

Comments
 (0)