Skip to content

Commit 5e7cc0e

Browse files
committed
test
1 parent 9aecb3f commit 5e7cc0e

File tree

1 file changed

+101
-18
lines changed

1 file changed

+101
-18
lines changed

tests/test_speculative_generation.py

+101-18
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,116 @@
33
import pytest
44
import torch
55

6+
import transformers
7+
8+
from petals import AutoDistributedModelForCausalLM
69
from petals import AutoDistributedConfig, RemoteSequential
710
from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
811
from petals.server.from_pretrained import load_pretrained_block
912
from test_utils import *
1013

1114

15+
@pytest.fixture
16+
def tokenizer():
17+
# We set use_fast=False since LlamaTokenizerFast is slow on load
18+
return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
19+
20+
21+
@pytest.fixture
22+
def model():
23+
return AutoDistributedModelForCausalLM.from_pretrained(
24+
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
25+
)
26+
27+
@pytest.fixture
28+
def model2():
29+
return transformers.AutoModelForCausalLM.from_pretrained(
30+
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
31+
)
32+
33+
@pytest.fixture
34+
def ref_model():
35+
return transformers.AutoModelForCausalLM.from_pretrained(
36+
MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
37+
)
38+
39+
# @pytest.mark.forked
40+
# def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, atol_inference=1e-3):
41+
# config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
42+
# remote_sequential = RemoteSequential(config)
43+
44+
# block_index = random.randint(0, config.num_hidden_layers - 1)
45+
# remote_block = remote_sequential[block_index]
46+
47+
# inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
48+
# short_inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
49+
# short_inputs[:, :2, :] = inputs[:, :2, :]
50+
51+
# initial_outputs_inference = None
52+
# secondary_outputs_inference = None
53+
# with torch.inference_mode():
54+
# with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
55+
# initial_outputs_inference = sess.step(inputs)
56+
# secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
57+
# result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
58+
59+
# ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
60+
# (outputs_local,) = ref_block(short_inputs)
61+
62+
# assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)
63+
64+
# @pytest.mark.forked
65+
# def test_speculative_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4):
66+
# inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
67+
68+
# options = dict(max_new_tokens=max_new_tokens, do_sample=False)
69+
# outputs = model.generate(inputs, **options)
70+
# print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@", outputs.shape, outputs)
71+
# ref_outputs = ref_model.generate(inputs, **options)
72+
# assert torch.allclose(
73+
# outputs, ref_outputs
74+
# ), f"Greedy generation is not identical to HF with {multiple_calls=}, {inputs.shape=}"
75+
1276
@pytest.mark.forked
13-
def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, atol_inference=1e-3):
14-
config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
15-
remote_sequential = RemoteSequential(config)
77+
def test_speculative_greedy_generation(tokenizer, model, model2, ref_model, max_new_tokens=50, batch_size=10):
78+
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
79+
generated_ids = inputs
80+
81+
with torch.no_grad():
82+
while generated_ids.shape[1] < max_new_tokens + inputs.shape[1]:
83+
outputs2 = model2.generate(generated_ids, max_new_tokens=batch_size, do_sample=False)
84+
new_tokens = outputs2[:, -batch_size:]
85+
86+
random_pos = random.randrange(1, batch_size)
87+
new_tokens[:, random_pos] = random.randrange(1, 100)
1688

17-
block_index = random.randint(0, config.num_hidden_layers - 1)
18-
remote_block = remote_sequential[block_index]
89+
combined_ids = torch.cat((generated_ids, new_tokens), dim=1)
90+
logits = model(combined_ids, start_from_position=1).logits
1991

20-
inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
21-
short_inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
22-
short_inputs[:, :2, :] = inputs[:, :2, :]
92+
# Найти первую позицию, где токены совпали
93+
match_length = 0
94+
for i in range(batch_size):
95+
top_predicted_id_model2 = new_tokens[:, i]
96+
top_predicted_id_model = torch.argmax(logits[:, generated_ids.shape[1] + i - 1, :], dim=-1)
97+
98+
if top_predicted_id_model2 == top_predicted_id_model:
99+
match_length += 1
100+
else:
101+
break
102+
print(f"Принято {match_length} из {batch_size}")
23103

24-
initial_outputs_inference = None
25-
secondary_outputs_inference = None
26-
with torch.inference_mode():
27-
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
28-
initial_outputs_inference = sess.step(inputs)
29-
secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
30-
result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
104+
if match_length > 0:
105+
generated_ids = torch.cat((generated_ids, new_tokens[:, :match_length]), dim=1)
106+
print(f"Всего {generated_ids.shape[1]}")
107+
else:
108+
break
109+
110+
ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens, do_sample=False)
111+
112+
gen_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
113+
ref_text = tokenizer.decode(ref_outputs[0], skip_special_tokens=True)
31114

32-
ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
33-
(outputs_local,) = ref_block(short_inputs)
115+
print(f"Generated by speculative decoding: {gen_text}")
116+
print(f"Reference generation: {ref_text}")
34117

35-
assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)
118+
assert gen_text == ref_text, "The outputs do not match!"

0 commit comments

Comments
 (0)