@@ -84,8 +84,13 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
84
84
break # this message means "done sending"
85
85
86
86
def step (
87
- self , inputs : torch .Tensor , prompts : torch .Tensor , hypo_ids : torch .LongTensor , * ,
88
- step_id : str , start_from_position : int
87
+ self ,
88
+ inputs : torch .Tensor ,
89
+ prompts : torch .Tensor ,
90
+ hypo_ids : torch .LongTensor ,
91
+ * ,
92
+ step_id : str ,
93
+ start_from_position : int ,
89
94
) -> torch .Tensor :
90
95
"""
91
96
Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -266,8 +271,11 @@ def __enter__(self) -> "InferenceSession":
266
271
return self
267
272
268
273
def step (
269
- self , inputs : torch .Tensor , prompts : Optional [torch .Tensor ] = None ,
270
- hypo_ids : Optional [torch .Tensor ] = None , start_from_position : Optional [int ] = None
274
+ self ,
275
+ inputs : torch .Tensor ,
276
+ prompts : Optional [torch .Tensor ] = None ,
277
+ hypo_ids : Optional [torch .Tensor ] = None ,
278
+ start_from_position : Optional [int ] = None ,
271
279
) -> torch .Tensor :
272
280
273
281
if start_from_position is not None :
@@ -317,8 +325,11 @@ def step(
317
325
318
326
server_session = self ._server_sessions [server_idx ]
319
327
inputs = server_session .step (
320
- inputs , prompts [server_session .span .start : server_session .span .end ], hypo_ids ,
321
- step_id = step_id , start_from_position = start_from_position
328
+ inputs ,
329
+ prompts [server_session .span .start : server_session .span .end ],
330
+ hypo_ids ,
331
+ step_id = step_id ,
332
+ start_from_position = start_from_position ,
322
333
)
323
334
324
335
server_idx += 1
0 commit comments