Skip to content

Commit 9aecb3f

Browse files
committed
style
1 parent 269028d commit 9aecb3f

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

src/petals/client/inference_session.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,13 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
8484
break # this message means "done sending"
8585

8686
def step(
87-
self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *,
88-
step_id: str, start_from_position: int
87+
self,
88+
inputs: torch.Tensor,
89+
prompts: torch.Tensor,
90+
hypo_ids: torch.LongTensor,
91+
*,
92+
step_id: str,
93+
start_from_position: int,
8994
) -> torch.Tensor:
9095
"""
9196
Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -266,8 +271,11 @@ def __enter__(self) -> "InferenceSession":
266271
return self
267272

268273
def step(
269-
self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None,
270-
hypo_ids: Optional[torch.Tensor] = None, start_from_position: Optional[int] = None
274+
self,
275+
inputs: torch.Tensor,
276+
prompts: Optional[torch.Tensor] = None,
277+
hypo_ids: Optional[torch.Tensor] = None,
278+
start_from_position: Optional[int] = None,
271279
) -> torch.Tensor:
272280

273281
if start_from_position is not None:
@@ -317,8 +325,11 @@ def step(
317325

318326
server_session = self._server_sessions[server_idx]
319327
inputs = server_session.step(
320-
inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids,
321-
step_id=step_id, start_from_position=start_from_position
328+
inputs,
329+
prompts[server_session.span.start : server_session.span.end],
330+
hypo_ids,
331+
step_id=step_id,
332+
start_from_position=start_from_position,
322333
)
323334

324335
server_idx += 1

src/petals/server/block_functions.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,9 @@ async def iterate_rpc_inference(
162162
async for request, step_metadata in input_iterator:
163163
if "start_from_position" in step_metadata:
164164
start_from_position = step_metadata["start_from_position"]
165-
assert prefix_length >= start_from_position, f"prefix_length={prefix_length}, start_from_position={start_from_position}"
165+
assert (
166+
prefix_length >= start_from_position
167+
), f"prefix_length={prefix_length}, start_from_position={start_from_position}"
166168
prefix_length = start_from_position
167169

168170
flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)

0 commit comments

Comments
 (0)