
Commit

jambayk committed Feb 18, 2025
1 parent 0d9b7cc commit 2d6c026
Showing 1 changed file with 16 additions and 2 deletions.
examples/utils/generator.py (16 additions, 2 deletions)
@@ -240,7 +240,6 @@ def generate(

generated_tokens = inputs["input_ids"].numpy() if use_io_binding else inputs["input_ids"].copy()
batch_size, prompt_length = generated_tokens.shape
-valid_prompt_len = int(inputs["attention_mask_2d"].sum(axis=-1).max())
has_eos = np.zeros(batch_size, dtype=bool)

# buffers to keep numpy copy of model inputs, don't want to keep going back and forth between OrtValue and numpy
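A minimal sketch of the OrtValue/numpy round trip that the comment above is avoiding: mutate a plain numpy buffer between steps and push it into the existing OrtValue in place. The array values and names here are illustrative, not taken from generator.py.

import numpy as np
import onnxruntime as ort

# illustrative 2D attention mask: 1 marks real tokens, 0 marks padding
mask_np = np.array([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=np.int32)

# OrtValue handed to the session (e.g. via IO binding)
mask_ort = ort.OrtValue.ortvalue_from_numpy(mask_np)

# keep mutating the numpy buffer between steps ...
mask_np[:, 3] = 1
# ... and copy it into the existing OrtValue instead of converting back and forth
mask_ort.update_inplace(mask_np)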
@@ -249,6 +248,7 @@ def generate(
inputs.pop("attention_mask_2d").numpy() if use_io_binding else inputs.pop("attention_mask_2d")
)
}
+valid_prompt_len = int(np_buffers["attention_mask"].sum(axis=-1).max())
if self.use_position_ids:
np_buffers["position_ids"] = (
np_buffers["attention_mask"]
@@ -267,6 +267,8 @@ def generate(
session = self.sessions["iterator"]
io_binding = session.io_binding() if use_io_binding else None

+print(np_buffers["attention_mask"])
+
if use_io_binding:
if idx < 2:
# need to bind logits twice, once for prompt processing and once for token generation
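The relocated valid_prompt_len line now reads the popped-out numpy copy of the mask (np_buffers["attention_mask"]) rather than the model input, which is an OrtValue under IO binding and so has no .sum(), which is likely why the computation moved after the pop. A small sketch of the reduction, assuming a right-padded mask; the position-ids expression is one common derivation, shown only because the actual expression in generator.py is truncated in this diff.

import numpy as np

# illustrative right-padded 2D mask for a batch of two prompts
attention_mask = np.array(
    [[1, 1, 1, 0, 0],
     [1, 1, 1, 1, 0]],
    dtype=np.int32,
)

# longest run of real (non-padding) prompt tokens in the batch
valid_prompt_len = int(attention_mask.sum(axis=-1).max())  # 4

# one common way to derive position ids from the mask (padding positions get 0)
position_ids = (attention_mask.cumsum(axis=-1) - 1) * attention_mask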
@@ -383,6 +385,13 @@ def generate(
# GQA, or static during token generation
inputs["attention_mask"].update_inplace(attention_mask)

+first_cache = outputs[1]
+if use_io_binding:
+    first_cache = first_cache.numpy()
+print(first_cache[:, 0, :, 0])
+if idx == 2:
+    raise RuntimeError("intentional debug stop after the second iteration")
+
# update cache
cache.update(outputs[1:])
if use_io_binding:
@@ -436,6 +445,8 @@ def get_initial_inputs(
attention_mask = encodings_dict["attention_mask"]
if not self.extended_attention_mask:
attention_mask.astype(self.input_info["attention_mask"]["dtype"])
+# print(input_ids)
+# print(attention_mask)

cache = self.get_fresh_cache(
batch_size,
@@ -450,7 +461,10 @@ def get_initial_inputs(
[attention_mask, np.zeros((batch_size, cache.max_cache_len - prompt_length), dtype=np.int32)], 1
)

inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "attention_mask_2d": attention_mask}
inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "attention_mask_2d": attention_mask.copy()}
if isinstance(cache, GQASharedCache) and self.prompt_len:
# prompt processing needs to attend to the whole prompt+padding
inputs["attention_mask"][:, :prompt_length] = 1
if self.extended_attention_mask:
replace_with_extended_mask(inputs, "causal", -1000)
inputs["attention_mask"] = inputs["attention_mask"].astype(self.input_info["attention_mask"]["dtype"])
