diff --git a/examples/utils/generator.py b/examples/utils/generator.py
index 439af507d..96e34fa78 100644
--- a/examples/utils/generator.py
+++ b/examples/utils/generator.py
@@ -240,7 +240,6 @@ def generate(
 
         generated_tokens = inputs["input_ids"].numpy() if use_io_binding else inputs["input_ids"].copy()
         batch_size, prompt_length = generated_tokens.shape
-        valid_prompt_len = int(inputs["attention_mask_2d"].sum(axis=-1).max())
         has_eos = np.zeros(batch_size, dtype=bool)
 
         # buffers to keep numpy copy of model inputs, don't want to keep going back and forth between OrtValue and numpy
@@ -249,6 +248,7 @@ def generate(
                 inputs.pop("attention_mask_2d").numpy() if use_io_binding else inputs.pop("attention_mask_2d")
             )
         }
+        valid_prompt_len = int(np_buffers["attention_mask"].sum(axis=-1).max())
         if self.use_position_ids:
             np_buffers["position_ids"] = (
                 np_buffers["attention_mask"]
@@ -450,7 +450,10 @@ def get_initial_inputs(
                 [attention_mask, np.zeros((batch_size, cache.max_cache_len - prompt_length), dtype=np.int32)],
                 1
             )
-        inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "attention_mask_2d": attention_mask}
+        inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "attention_mask_2d": attention_mask.copy()}
+        if isinstance(cache, GQASharedCache) and self.prompt_len:
+            # prompt processing needs to attend to the whole prompt+padding
+            inputs["attention_mask"][:, :prompt_length] = 1
         if self.extended_attention_mask:
             replace_with_extended_mask(inputs, "causal", -1000)
             inputs["attention_mask"] = inputs["attention_mask"].astype(self.input_info["attention_mask"]["dtype"])
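For context, here is a minimal standalone sketch of the mask handling these hunks rely on. The shapes, values, and `max_cache_len` below are made up for illustration and are not the repo's actual API; it only mirrors the pattern shown in the diff: keep an untouched 2D copy of the tokenizer mask, derive `valid_prompt_len` from it, pad the working mask out to the cache length, and let prompt processing attend to the full prompt window (padding included) for the GQA shared-cache case.

```python
import numpy as np

# Hypothetical sizes, for illustration only.
batch_size, prompt_length, max_cache_len = 2, 6, 10

# Left-padded prompts: 0 = padding, 1 = real token (the "attention_mask_2d" copy).
attention_mask_2d = np.array(
    [[0, 0, 1, 1, 1, 1],
     [0, 1, 1, 1, 1, 1]],
    dtype=np.int32,
)

# Longest real prompt across the batch, taken from the untouched 2D mask.
valid_prompt_len = int(attention_mask_2d.sum(axis=-1).max())  # -> 5

# Working mask padded out to the static cache length.
attention_mask = np.concatenate(
    [attention_mask_2d, np.zeros((batch_size, max_cache_len - prompt_length), dtype=np.int32)],
    axis=1,
)

# GQA shared-cache prompt processing attends to the whole prompt window,
# padding included, while the 2D copy still records the real token positions.
attention_mask[:, :prompt_length] = 1
```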