forked from tyrchen/geektime-rust
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdiff_refactor
254 lines (242 loc) · 9.48 KB
/
diff_refactor
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
diff --git a/39/kv/src/error.rs b/39/kv/src/error.rs
index 3a5b1e7..09724c7 100644
--- a/39/kv/src/error.rs
+++ b/39/kv/src/error.rs
@@ -2,8 +2,8 @@ use thiserror::Error;
#[derive(Error, Debug)]
pub enum KvError {
- #[error("Not found for table: {0}, key: {1}")]
- NotFound(String, String),
+ #[error("Not found: {0}")]
+ NotFound(String),
#[error("Frame is larger than max size")]
FrameError,
#[error("Command is invalid: `{0}`")]
diff --git a/39/kv/src/pb/mod.rs b/39/kv/src/pb/mod.rs
index bf784aa..b625e22 100644
--- a/39/kv/src/pb/mod.rs
+++ b/39/kv/src/pb/mod.rs
@@ -231,7 +231,7 @@ impl From<KvError> for CommandResponse {
};
match e {
- KvError::NotFound(_, _) => result.status = StatusCode::NOT_FOUND.as_u16() as _,
+ KvError::NotFound(_) => result.status = StatusCode::NOT_FOUND.as_u16() as _,
KvError::InvalidCommand(_) => result.status = StatusCode::BAD_REQUEST.as_u16() as _,
_ => {}
}
diff --git a/39/kv/src/service/command_service.rs b/39/kv/src/service/command_service.rs
index e609036..dfbb2e3 100644
--- a/39/kv/src/service/command_service.rs
+++ b/39/kv/src/service/command_service.rs
@@ -4,7 +4,7 @@ impl CommandService for Hget {
fn execute(self, store: &impl Storage) -> CommandResponse {
match store.get(&self.table, &self.key) {
Ok(Some(v)) => v.into(),
- Ok(None) => KvError::NotFound(self.table, self.key).into(),
+ Ok(None) => KvError::NotFound(format!("table {}, key {}", self.table, self.key)).into(),
Err(e) => e.into(),
}
}
diff --git a/39/kv/src/service/topic.rs b/39/kv/src/service/topic.rs
index 3c5874e..80fe82f 100644
--- a/39/kv/src/service/topic.rs
+++ b/39/kv/src/service/topic.rs
@@ -6,7 +6,7 @@ use std::sync::{
use tokio::sync::mpsc;
use tracing::{debug, info, warn};
-use crate::{CommandResponse, Value};
+use crate::{CommandResponse, KvError, Value};
/// topic 里最大存放的数据
const BROADCAST_CAPACITY: usize = 128;
@@ -23,7 +23,7 @@ pub trait Topic: Send + Sync + 'static {
/// 订阅某个主题
fn subscribe(self, name: String) -> mpsc::Receiver<Arc<CommandResponse>>;
/// 取消对主题的订阅
- fn unsubscribe(self, name: String, id: u32);
+ fn unsubscribe(self, name: String, id: u32) -> Result<u32, KvError>;
/// 往主题里发布一个数据
fn publish(self, name: String, value: Arc<CommandResponse>);
}
@@ -68,49 +68,67 @@ impl Topic for Arc<Broadcaster> {
rx
}
- fn unsubscribe(self, name: String, id: u32) {
- if let Some(v) = self.topics.get_mut(&name) {
- // 在 topics 表里找到 topic 的 subscription id,删除
- v.remove(&id);
-
- // 如果这个 topic 为空,则也删除 topic
- if v.is_empty() {
- info!("Topic: {:?} is deleted", &name);
- drop(v);
- self.topics.remove(&name);
- }
+ fn unsubscribe(self, name: String, id: u32) -> Result<u32, KvError> {
+ match self.remove_subscription(name, id) {
+ Some(id) => Ok(id),
+ None => Err(KvError::NotFound(format!("subscription {}", id))),
}
-
- debug!("Subscription {} is removed!", id);
- // 在 subscription 表中同样删除
- self.subscriptions.remove(&id);
}
fn publish(self, name: String, value: Arc<CommandResponse>) {
tokio::spawn(async move {
+ let mut ids = vec![];
match self.topics.get(&name) {
- Some(chan) => {
+ Some(topic) => {
// 复制整个 topic 下所有的 subscription id
// 这里我们每个 id 是 u32,如果一个 topic 下有 10k 订阅,复制的成本
// 也就是 40k 堆内存(外加一些控制结构),所以效率不算差
// 这也是为什么我们用 NEXT_ID 来控制 subscription id 的生成
- let chan = chan.value().clone();
+
+ let subscriptions = topic.value().clone();
+ // 尽快释放锁
+ drop(topic);
// 循环发送
- for id in chan.into_iter() {
+ for id in subscriptions.into_iter() {
if let Some(tx) = self.subscriptions.get(&id) {
if let Err(e) = tx.send(value.clone()).await {
warn!("Publish to {} failed! error: {:?}", id, e);
+ // client 中断连接
+ ids.push(id);
}
}
}
}
None => {}
}
+ for id in ids {
+ self.remove_subscription(name.clone(), id);
+ }
});
}
}
+impl Broadcaster { // helper called by both `unsubscribe` and `publish` to drop one subscription
+ pub fn remove_subscription(&self, name: String, id: u32) -> Option<u32> { // yields Some(id) only when `id` was present in `subscriptions`
+ if let Some(v) = self.topics.get_mut(&name) {
+ // Remove this subscription id from the topic's entry in the topics table
+ v.remove(&id);
+
+ // If the topic now has no subscription ids left, delete the topic as well
+ if v.is_empty() {
+ info!("Topic: {:?} is deleted", &name);
+ drop(v); // NOTE(review): presumably releases the map guard before `remove` touches the same map - confirm
+ self.topics.remove(&name);
+ }
+ }
+
+ debug!("Subscription {} is removed!", id); // logged unconditionally, even if `id` turns out not to be registered
+ // Also remove the id from the subscriptions table
+ self.subscriptions.remove(&id).map(|(id, _)| id)
+ }
+}
+
#[cfg(test)]
mod tests {
use std::convert::TryInto;
@@ -145,7 +163,8 @@ mod tests {
assert_res_ok(&res1, &[v.clone()], &[]);
// 如果 subscriber 取消订阅,则收不到新数据
- b.clone().unsubscribe(lobby.clone(), id1 as _);
+ let result = b.clone().unsubscribe(lobby.clone(), id1 as _).unwrap();
+ assert_eq!(result, id1 as _);
// publish
let v: Value = "world".into();
diff --git a/39/kv/src/service/topic_service.rs b/39/kv/src/service/topic_service.rs
index c7f49c6..95deef4 100644
--- a/39/kv/src/service/topic_service.rs
+++ b/39/kv/src/service/topic_service.rs
@@ -20,8 +20,11 @@ impl TopicService for Subscribe {
impl TopicService for Unsubscribe {
fn execute(self, topic: impl Topic) -> StreamingResponse {
- topic.unsubscribe(self.topic, self.id);
- Box::pin(stream::once(async { Arc::new(CommandResponse::ok()) }))
+ let res = match topic.unsubscribe(self.topic, self.id) {
+ Ok(_) => CommandResponse::ok(),
+ Err(e) => e.into(),
+ };
+ Box::pin(stream::once(async { Arc::new(res) }))
}
}
@@ -31,3 +34,76 @@ impl TopicService for Publish {
Box::pin(stream::once(async { Arc::new(CommandResponse::ok()) }))
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{assert_res_error, assert_res_ok, dispatch_stream, Broadcaster, CommandRequest};
+ use futures::StreamExt;
+ use std::{convert::TryInto, time::Duration};
+ use tokio::time;
+
+ #[tokio::test]
+ async fn dispatch_publish_should_work() {
+ let topic = Arc::new(Broadcaster::default());
+ let cmd = CommandRequest::new_publish("lobby", vec!["hello".into()]);
+ let mut res = dispatch_stream(cmd, topic);
+ let data = res.next().await.unwrap();
+ assert_res_ok(&data, &[], &[]);
+ }
+
+ #[tokio::test]
+ async fn dispatch_subscribe_should_work() {
+ let topic = Arc::new(Broadcaster::default());
+ let cmd = CommandRequest::new_subscribe("lobby");
+ let mut res = dispatch_stream(cmd, topic);
+ let id: i64 = res.next().await.unwrap().as_ref().try_into().unwrap(); // first stream item is converted into the subscription id
+ assert!(id > 0);
+ }
+
+ #[tokio::test]
+ async fn dispatch_subscribe_abnormal_quit_should_be_removed_on_next_publish() {
+ let topic = Arc::new(Broadcaster::default());
+ let id = {
+ let cmd = CommandRequest::new_subscribe("lobby");
+ let mut res = dispatch_stream(cmd, topic.clone());
+ let id: i64 = res.next().await.unwrap().as_ref().try_into().unwrap();
+ drop(res); // drop the response stream to stand in for a subscriber that quit without unsubscribing
+ id as u32
+ };
+
+ // By the time of this publish the subscription is already dead, so it gets removed
+ let cmd = CommandRequest::new_publish("lobby", vec!["hello".into()]);
+ dispatch_stream(cmd, topic.clone());
+ time::sleep(Duration::from_millis(10)).await; // publish runs in a spawned task; give it time to finish its cleanup
+
+ // Unsubscribing the same id again should now return a KvError
+ let result = topic.unsubscribe("lobby".into(), id);
+ assert!(result.is_err());
+ }
+
+ #[tokio::test]
+ async fn dispatch_unsubscribe_should_work() {
+ let topic = Arc::new(Broadcaster::default());
+ let cmd = CommandRequest::new_subscribe("lobby");
+ let mut res = dispatch_stream(cmd, topic.clone());
+ let id: i64 = res.next().await.unwrap().as_ref().try_into().unwrap();
+
+ let cmd = CommandRequest::new_unsubscribe("lobby", id as _);
+ let mut res = dispatch_stream(cmd, topic);
+ let data = res.next().await.unwrap();
+
+ assert_res_ok(&data, &[], &[]);
+ }
+
+ #[tokio::test]
+ async fn dispatch_unsubscribe_random_id_should_error() {
+ let topic = Arc::new(Broadcaster::default());
+
+ let cmd = CommandRequest::new_unsubscribe("lobby", 9527); // no subscribe was issued on this fresh Broadcaster, so this id is unknown
+ let mut res = dispatch_stream(cmd, topic);
+ let data = res.next().await.unwrap();
+
+ assert_res_error(&data, 404, "Not found: subscription 9527");
+ }
+}