-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathbase_connection.py
158 lines (133 loc) · 5.59 KB
/
base_connection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import codecs
import json
import socket
from typing import Optional

from .exceptions import IncompatibleVersionException, InternalServerException, TaskCanceledException
from .init_messages import client_init_messages, server_init_message
from ..commands import responses
class BaseConnection:
    """
    Base class for connections that access the control server via the Duet API
    using a UNIX socket.

    Messages are exchanged as JSON objects written back-to-back on the socket.
    Text received after a complete object is kept in ``self.input`` so the
    next ``receive_json()`` call can return it.
    """

    # Maximum number of bytes requested per recv() call. A larger buffer means
    # fewer re-scans of the input buffer for big messages.
    RECV_BUFFER_SIZE = 65536  # 64 KiB

    def __init__(self, debug: bool = False, timeout: int = 3):
        # When True, every sent/received JSON message is printed
        self.debug = debug
        # NOTE(review): not read anywhere in this class; presumably used by subclasses
        self.timeout = timeout
        self.socket: Optional[socket.socket] = None
        # Connection id assigned by the server during connect()
        self.id = None
        # Decoded text received from the socket but not yet returned
        self.input = ""
        # Incremental decoder: a multi-byte UTF-8 character may be split across
        # two recv() chunks, so leftover bytes must be carried over
        self._decoder = codecs.getincrementaldecoder("utf8")()

    def connect(
        self,
        init_message: "client_init_messages.ClientInitMessage",
        socket_file: str,
        timeout: int = 0,
    ):
        """Establish a connection to the given UNIX socket file.

        Reads the server init message and checks API compatibility. Then it
        stores the connection id, sends ``init_message`` and checks the
        server's reply. If any step fails, the socket is closed again before
        the exception propagates.

        :param init_message: client init message selecting the connection mode
        :param socket_file: path of the UNIX socket to connect to
        :param timeout: socket timeout in seconds; 0 means block indefinitely
        :raises IncompatibleVersionException: if the server API version is incompatible
        :raises socket.timeout: if the server does not answer within ``timeout``
        :raises Exception: if the server rejects the requested connection mode
        """
        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        # Never mix buffered data from a previous connection into this one
        self.input = ""
        self._decoder.reset()
        try:
            self.socket.connect(socket_file)
            # settimeout(None) is fully blocking; a positive value is a timeout
            self.socket.settimeout(None if timeout == 0 else timeout)

            # Read the init message like any other JSON object instead of a fixed
            # recv(50), which could truncate it or drop following bytes
            server_json = self.receive_json()
            if server_json is None:
                raise socket.timeout("timed out waiting for the server init message")
            server_init_msg = server_init_message.ServerInitMessage.from_json(
                json.loads(server_json)
            )
            if not server_init_msg.is_compatible():
                raise IncompatibleVersionException(
                    f"Incompatible API version (need {server_init_msg.PROTOCOL_VERSION}, got {server_init_msg.version})"
                )
            self.id = server_init_msg.id

            self.send(init_message)
            response = self.receive_response()
            if not response.success:
                raise Exception(
                    f"Could not set connection type {init_message.mode} ({response.error_type}: {response.error_message})"
                )
        except BaseException:
            # Do not leak a half-initialised socket
            self.close()
            raise

    def close(self):
        """Close the current connection and dispose of it; safe to call repeatedly."""
        if self.socket is not None:
            self.socket.close()
            self.socket = None
        # Buffered data belonged to the closed connection
        self.input = ""
        self._decoder.reset()

    def perform_command(self, command, cls=None):
        """Send an arbitrary command and return the server's response.

        :param command: command object to serialize and send
        :param cls: optional type whose ``from_json`` converts a non-null result
        :returns: the decoded response, with ``result`` converted when ``cls`` is given
        :raises TaskCanceledException: if the server reports a canceled task
        :raises InternalServerException: for any other server-side error
        """
        self.send(command)
        response = self.receive_response()
        if not response.success:
            if response.error_type == "TaskCanceledException":
                raise TaskCanceledException(response.error_message)
            raise InternalServerException(
                command, response.error_type, response.error_message
            )
        if cls is not None and response.result is not None:
            response.result = cls.from_json(response.result)
        return response

    def send(self, msg):
        """Serialize an arbitrary object into compact JSON and send it to the server.

        Objects the ``json`` module cannot encode natively are serialized via
        their ``__dict__``. No delimiter is appended after the JSON text.

        :raises RuntimeError: if there is no open socket
        """
        if self.socket is None:
            raise RuntimeError("socket is closed or missing")
        json_string = json.dumps(msg, separators=(",", ":"), default=lambda o: o.__dict__)
        if self.debug:
            print(f"send: {json_string}")
        self.socket.sendall(json_string.encode("utf8"))

    def receive(self, cls):
        """Receive the next object from the server and decode it with ``cls.from_json``.

        :returns: the decoded object, or None if the socket timed out or the
            payload could not be parsed or decoded
        """
        json_string = self.receive_json()
        if json_string is None:
            return None
        try:
            return cls.from_json(json.loads(json_string))
        except Exception as exc:
            # Deliberately best-effort: an undecodable payload yields None
            if self.debug:
                print(f"receive: could not decode {json_string!r}: {exc!r}")
            return None

    def receive_response(self):
        """Receive a base response from the server.

        :raises socket.timeout: if no complete response arrived before the socket timeout
        """
        json_string = self.receive_json()
        if json_string is None:
            raise socket.timeout("timed out waiting for a response from the server")
        return responses.decode_response(json.loads(json_string))

    def receive_json(self) -> Optional[str]:
        """Receive the next complete JSON object from the server as text.

        Anything received after that object stays in ``self.input`` for later
        calls.

        :returns: the JSON text, or None if the socket timed out first; partial
            data read so far remains buffered and is not lost
        :raises RuntimeError: if there is no open socket
        :raises ConnectionError: if the server closed the connection before a
            complete object arrived
        """
        if not self.socket:
            raise RuntimeError("socket is closed or missing")
        # There might already be a full object waiting in the buffer
        end_index = self.get_json_object_end_index(self.input)
        while end_index < 0:
            try:
                chunk = self.socket.recv(self.RECV_BUFFER_SIZE)
            except socket.timeout:
                return None
            if not chunk:
                # recv() returns b"" only when the peer has closed the stream;
                # retrying would spin forever
                raise ConnectionError("connection closed by the server")
            self.input += self._decoder.decode(chunk)
            end_index = self.get_json_object_end_index(self.input)
        json_string = self.input[:end_index]
        self.input = self.input[end_index:]
        if self.debug:
            print("recv:", json_string)
        return json_string

    @staticmethod
    def get_json_object_end_index(json_string: str) -> int:
        """Return the end index (exclusive) of the first complete JSON object in the string.

        Characters before the first opening brace (such as whitespace between
        messages) are skipped. Braces inside JSON string literals, including
        escaped quotes, are not counted.

        :returns: the index just past the closing brace, or -1 if the string
            does not (yet) contain a complete object or starts with an
            unbalanced closing brace
        """
        depth = 0
        in_string = False
        escaped = False
        for index, char in enumerate(json_string):
            if in_string:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_string = False
            elif char == '"':
                in_string = True
            elif char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth < 0:  # Unbalanced curly braces - incomplete input?
                    return -1
                if depth == 0:  # Found a complete object
                    return index + 1
        return -1  # Nothing complete here yet