 import pytest
 import torch
 
+import transformers
+
+from petals import AutoDistributedModelForCausalLM
 from petals import AutoDistributedConfig, RemoteSequential
 from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
 from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *
 
 
+@pytest.fixture
+def tokenizer():
+    # We set use_fast=False since LlamaTokenizerFast is slow on load
+    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
+
+
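+# Distributed Petals client model; it acts as the verifier in the speculative decoding test below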
+@pytest.fixture
+def model():
+    return AutoDistributedModelForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
+    )
+
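+# Local Hugging Face model (loaded from REF_NAME) that plays the role of the draft model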
+@pytest.fixture
+def model2():
+    return transformers.AutoModelForCausalLM.from_pretrained(
+        REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+
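+# Reference model for plain greedy generation, used to validate the speculative output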
+@pytest.fixture
+def ref_model():
+    return transformers.AutoModelForCausalLM.from_pretrained(
+        MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+
+# @pytest.mark.forked
+# def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, atol_inference=1e-3):
+#     config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+#     remote_sequential = RemoteSequential(config)
+
+#     block_index = random.randint(0, config.num_hidden_layers - 1)
+#     remote_block = remote_sequential[block_index]
+
+#     inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
+#     short_inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
+#     short_inputs[:, :2, :] = inputs[:, :2, :]
+
+#     initial_outputs_inference = None
+#     secondary_outputs_inference = None
+#     with torch.inference_mode():
+#         with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
+#             initial_outputs_inference = sess.step(inputs)
+#             secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
+#             result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
+
+#     ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
+#     (outputs_local,) = ref_block(short_inputs)
+
+#     assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)
+
+# @pytest.mark.forked
+# def test_speculative_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4):
+#     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+
+#     options = dict(max_new_tokens=max_new_tokens, do_sample=False)
+#     outputs = model.generate(inputs, **options)
+#     print("Distributed model outputs:", outputs.shape, outputs)
+#     ref_outputs = ref_model.generate(inputs, **options)
+#     assert torch.allclose(
+#         outputs, ref_outputs
+#     ), f"Greedy generation is not identical to HF with {inputs.shape=}"
+
 @pytest.mark.forked
-def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, atol_inference=1e-3):
-    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
-    remote_sequential = RemoteSequential(config)
+def test_speculative_greedy_generation(tokenizer, model, model2, ref_model, max_new_tokens=50, batch_size=10):
+    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+    generated_ids = inputs
+
+    with torch.no_grad():
+        while generated_ids.shape[1] < max_new_tokens + inputs.shape[1]:
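+            # Draft step: the local model greedily proposes the next batch_size tokens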
+            outputs2 = model2.generate(generated_ids, max_new_tokens=batch_size, do_sample=False)
+            new_tokens = outputs2[:, -batch_size:]
+
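+            # Corrupt one draft token on purpose so that verification has to reject a suffix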
+            random_pos = random.randrange(1, batch_size)
+            new_tokens[:, random_pos] = random.randrange(1, 100)
 
-    block_index = random.randint(0, config.num_hidden_layers - 1)
-    remote_block = remote_sequential[block_index]
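+            # Verification step: rescore the prompt plus draft with the distributed model in a
+            # single forward pass, using start_from_position to roll back the remote cache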
+            combined_ids = torch.cat((generated_ids, new_tokens), dim=1)
+            logits = model(combined_ids, start_from_position=1).logits
 
-    inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
-    short_inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
-    short_inputs[:, :2, :] = inputs[:, :2, :]
+            # Find how many leading draft tokens match the distributed model's own greedy predictions
+            match_length = 0
+            for i in range(batch_size):
+                top_predicted_id_model2 = new_tokens[:, i]
+                top_predicted_id_model = torch.argmax(logits[:, generated_ids.shape[1] + i - 1, :], dim=-1)
+
+                if top_predicted_id_model2 == top_predicted_id_model:
+                    match_length += 1
+                else:
+                    break
+            print(f"Accepted {match_length} of {batch_size} draft tokens")
 
-    initial_outputs_inference = None
-    secondary_outputs_inference = None
-    with torch.inference_mode():
-        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
-            initial_outputs_inference = sess.step(inputs)
-            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
-            result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
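+            # Accept the matched prefix; if nothing was accepted, stop generating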
+            if match_length > 0:
+                generated_ids = torch.cat((generated_ids, new_tokens[:, :match_length]), dim=1)
+                print(f"Total length so far: {generated_ids.shape[1]}")
+            else:
+                break
+
+    ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens, do_sample=False)
+
+    # Trim any overshoot so that exactly max_new_tokens new tokens are compared with the reference
+    generated_ids = generated_ids[:, : inputs.shape[1] + max_new_tokens]
+    gen_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+    ref_text = tokenizer.decode(ref_outputs[0], skip_special_tokens=True)
 
-    ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
-    (outputs_local,) = ref_block(short_inputs)
+    print(f"Generated by speculative decoding: {gen_text}")
+    print(f"Reference generation: {ref_text}")
 
-    assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)
+    assert gen_text == ref_text, "Speculative greedy generation does not match the reference greedy output"