@@ -4,20 +4,26 @@ use crate::net::NTSocket;
4
4
use async_std:: net:: TcpListener ;
5
5
use async_std:: sync:: { channel, Arc , Mutex , Sender } ;
6
6
use async_std:: task;
7
+ use async_std:: io;
7
8
use async_tungstenite:: tungstenite:: handshake:: server:: { Request , Response } ;
8
9
use async_tungstenite:: tungstenite:: http:: { HeaderValue , StatusCode } ;
9
10
use itertools:: Itertools ;
10
11
use proto:: prelude:: { DataType , MessageBody , NTBinaryMessage , NTMessage , NTTextMessage } ;
11
12
use std:: collections:: HashMap ;
12
13
use std:: ops:: DerefMut ;
14
+ use crate :: persist:: restore_persistent;
15
+ use proto:: prelude:: publish:: SetFlags ;
13
16
14
17
mod broadcast;
15
18
use broadcast:: * ;
16
19
17
20
mod loop_;
18
- use crate :: persist:: restore_persistent;
19
21
use loop_:: * ;
20
- use proto:: prelude:: publish:: SetFlags ;
22
+
23
+ mod tls;
24
+ use tls:: * ;
25
+ use async_tungstenite:: stream:: Stream ;
26
+ use std:: time:: Duration ;
21
27
22
28
pub static MAX_BATCHING_SIZE : usize = 5 ;
23
29
@@ -39,9 +45,46 @@ async fn tcp_loop(state: Arc<Mutex<NTServer>>, tx: Sender<ServerMessage>) -> any
39
45
let listener = TcpListener :: bind ( "0.0.0.0:5810" ) . await ?;
40
46
41
47
while let Ok ( ( sock, addr) ) = listener. accept ( ) . await {
42
- log:: info!( "TCP connection at {}" , addr) ;
48
+ log:: info!( "Unsecure TCP connection at {}" , addr) ;
49
+ let cid = rand:: random :: < u32 > ( ) ;
50
+ let sock = async_tungstenite:: accept_hdr_async ( Stream :: Plain ( sock) , |req : & Request , mut res : Response | {
51
+ let ws_proto = req. headers ( ) . iter ( ) . find ( |( hdr, _) | * * hdr == "Sec-WebSocket-Protocol" ) ;
52
+
53
+ match ws_proto. map ( |( _, s) | s. to_str ( ) . unwrap ( ) ) {
54
+ Some ( "networktables.first.wpi.edu" ) => {
55
+ res. headers_mut ( ) . insert ( "Sec-WebSocket-Protocol" , HeaderValue :: from_static ( "networktables.first.wpi.edu" ) ) ;
56
+ Ok ( res)
57
+ }
58
+ _ => {
59
+ log:: error!( "Rejecting client that did not specify correct subprotocol" ) ;
60
+ Err ( Response :: builder ( )
61
+ . status ( StatusCode :: BAD_REQUEST )
62
+ . body ( Some ( "Protocol 'networktables.first.wpi.edu' required to communicate with this server" . to_string ( ) ) )
63
+ . unwrap ( ) )
64
+ }
65
+ }
66
+ } ) . await ;
67
+
68
+ if let Ok ( sock) = sock {
69
+ log:: info!( "Client assigned CID {}" , cid) ;
70
+ let client = ConnectedClient :: new ( NTSocket :: new ( sock) , tx. clone ( ) , cid) ;
71
+ state. lock ( ) . await . clients . insert ( cid, client) ;
72
+ task:: spawn ( update_new_client ( cid, state. clone ( ) ) ) ;
73
+ }
74
+ }
75
+ Ok ( ( ) )
76
+ }
77
+
78
+ async fn tls_loop ( state : Arc < Mutex < NTServer > > , tx : Sender < ServerMessage > ) -> anyhow:: Result < ( ) > {
79
+ let listener = TcpListener :: bind ( "0.0.0.0:5811" ) . await ?;
80
+ let acceptor = generate_acceptor ( ) ;
81
+
82
+ while let Ok ( ( sock, addr) ) = listener. accept ( ) . await {
83
+ log:: info!( "Secure TCP connection at {}" , addr) ;
43
84
let cid = rand:: random :: < u32 > ( ) ;
44
- let sock = async_tungstenite:: accept_hdr_async ( sock, |req : & Request , mut res : Response | {
85
+ let sock = acceptor. accept ( sock) . await ?;
86
+ log:: info!( "TLS handshake completed" ) ;
87
+ let sock = async_tungstenite:: accept_hdr_async ( Stream :: Tls ( sock) , |req : & Request , mut res : Response | {
45
88
let ws_proto = req. headers ( ) . iter ( ) . find ( |( hdr, _) | * * hdr == "Sec-WebSocket-Protocol" ) ;
46
89
47
90
match ws_proto. map ( |( _, s) | s. to_str ( ) . unwrap ( ) ) {
@@ -121,7 +164,8 @@ impl NTServer {
121
164
122
165
let ( tx, rx) = channel ( 32 ) ;
123
166
124
- task:: spawn ( tcp_loop ( _self. clone ( ) , tx) ) ;
167
+ task:: spawn ( tcp_loop ( _self. clone ( ) , tx. clone ( ) ) ) ;
168
+ task:: spawn ( tls_loop ( _self. clone ( ) , tx) ) ;
125
169
task:: spawn ( channel_loop ( _self. clone ( ) , rx) ) ;
126
170
task:: spawn ( broadcast_loop ( _self. clone ( ) ) ) ;
127
171
task:: spawn ( flush_persistent_loop ( _self. clone ( ) ) ) ;
0 commit comments