diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 5472d68a..6151385e 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -110,6 +110,12 @@ def step( if self.closed: raise Exception("Session is closed, cannot perform step") + if start_from_position is not None: + assert start_from_position <= self._position + self._position = start_from_position + if self.history is not None and self.history.shape[1] >= start_from_position: + self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None + n_input_tokens = inputs.shape[1] if self.history is None: self.history = inputs @@ -330,7 +336,9 @@ def step( self._update_sequence(server_idx, block_idx, attempt_no) server_session = self._server_sessions[server_idx] - assert server_session.position == self.position, f"{server_session.position} and {self.position}" + assert ( + server_session.position == self.position + ), f"Position mismatch: {server_session.position} and {self.position}" inputs = server_session.step( inputs, prompts[server_session.span.start : server_session.span.end],