Commit 316a360

make it run

1 parent 2a87eba

3 files changed (+120 -14)

Diff for: README.md (+9 -5)
@@ -71,20 +71,24 @@ toolformer = Toolformer(
     model_seq_len = 256,
     teach_tool_prompt = prompt,
     tool_id = 'Calendar',
-    tool = Calendar
+    tool = Calendar,
+    finetune = True
 )
 
 # invoking this will
 # (1) prompt the model with your inputs (data), inserted into [input] tag
 # (2) with the sampled outputs, filter out the ones that made proper API calls
 # (3) execute the API calls with the `tool` given
 # (4) filter with the specialized filter function (which can be used independently as shown in the next section)
+# (5) fine-tune on the filtered results
 
-filtered_results = toolformer(data)
+toolformer(data)
 
-# then finetune with token ids at
-# -> filtered_results.filtered_tokens_without_api_response
-# (5) complete this with toolformer.finetune(filtered_results) - and return all statistics
+# then, once you see the 'finetune complete' message
+
+response = toolformer.sample_model_with_api_calls("How many days until the next new years?")
+
+# hopefully you see it invoke the calendar and utilize the response of the api call...
 
 ```

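For reference, the post-commit usage snippet assembles to the following. This is a sketch reconstructed from the hunk above; the `model`, `prompt`, `Calendar`, and `data` definitions come from earlier parts of the README and are assumed here.

    toolformer = Toolformer(
        model = model,              # language model defined earlier in the README (assumed)
        model_seq_len = 256,
        teach_tool_prompt = prompt,
        tool_id = 'Calendar',
        tool = Calendar,
        finetune = True             # new in this commit: fine-tune on the filtered results
    )

    # generates API-call data, filters it, then fine-tunes on the survivors
    toolformer(data)

    # afterwards, sample with live API calls
    response = toolformer.sample_model_with_api_calls("How many days until the next new years?")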
Diff for: setup.py (+1 -1)

@@ -3,7 +3,7 @@
 setup(
   name = 'toolformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.22',
+  version = '0.0.24',
   license='MIT',
   description = 'Toolformer - Pytorch',
   author = 'Phil Wang',

Diff for: toolformer_pytorch/toolformer_pytorch.py (+110 -8)
@@ -238,7 +238,6 @@ def sample(
     select_api_start_id_top_k = 10,
 ):
     device = next(model.parameters()).device
-    positions = positions.clone()
     max_seq_len = seq_len + 1
 
     # validate
@@ -258,7 +257,11 @@
 
     # sampling positions - different sequences have different cursors
 
-    positions = default(positions, torch.zeros((batch_size,), device = device, dtype = torch.long))
+    if exists(positions):
+        positions = positions.clone()
+    else:
+        positions = torch.zeros((batch_size,), device = device, dtype = torch.long)
+
     assert (positions <= (prime_length + 1)).all() and (positions <= max_seq_len).all(), 'all positions must be less than the initial prime length as well as the total sequence length + 1 (plus one for noop if one sequence finished sampling before the other)'
 
     # eval model
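The hunk above, together with the removed `positions.clone()` line in the previous hunk, is a single bug fix: `positions` defaults to `None`, so cloning it unconditionally raised `AttributeError: 'NoneType' object has no attribute 'clone'` whenever the caller passed no explicit cursors. A minimal sketch of the guarded pattern, with `exists` being a plain `None` check as in the repo (`init_positions` is a hypothetical helper for illustration):

    import torch

    def exists(val):
        return val is not None

    def init_positions(positions, batch_size, device):
        # hypothetical helper, not part of the repo
        if exists(positions):
            return positions.clone()  # copy, so the caller's tensor is never mutated
        return torch.zeros((batch_size,), device = device, dtype = torch.long)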
@@ -516,7 +519,7 @@ def loss_fn(weight, probs):
     selected_indices = indices[selected_mask]
 
     ret = FilteredResults(
-        selected_mask.sum().item()
+        selected_mask.sum().item(),
         (~selected_mask).sum().item(),
         selected_indices,
         selected_mask,
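The added comma above fixes a runtime crash: without it, Python joins the two argument lines into one expression, so the int returned by `.item()` is itself called. A minimal repro:

    import torch

    mask = torch.tensor([True, False, True])

    n = mask.sum().item()       # 2, a plain int
    # n((~mask).sum().item())   # what the missing comma effectively produced:
    #                           # TypeError: 'int' object is not callable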
@@ -563,6 +566,22 @@ def PromptDataloader(ds: Dataset, *args, padding_value = 0, **kwargs):
     collate_fn = partial(prompt_collate_fn, padding_value = padding_value)
     return DataLoader(ds, *args, collate_fn = collate_fn, **kwargs)
 
+class FinetuneDataset(Dataset):
+    def __init__(
+        self,
+        tokens: torch.Tensor
+    ):
+        self.tokens = tokens
+
+    def __len__(self):
+        return len(self.tokens)
+
+    def __getitem__(self, idx):
+        return self.tokens[idx]
+
+def FinetuneDataloader(ds: Dataset, *args, padding_value = 0, **kwargs):
+    return DataLoader(ds, *args, collate_fn = partial(pad_sequence, padding_value = padding_value), **kwargs)
+
 # classes
 
 @beartype
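The new `FinetuneDataset`/`FinetuneDataloader` pair batches variable-length token tensors, padding each batch to its longest sequence. A small usage sketch, assuming the `pad_sequence` referenced above behaves like a batch-first wrapper over `torch.nn.utils.rnn.pad_sequence` (an assumption; the actual helper is defined elsewhere in the file):

    import torch
    from functools import partial
    from torch.nn.utils.rnn import pad_sequence as rnn_pad_sequence

    # assumed batch-first wrapper, mirroring how the dataloader above collates
    pad_sequence = partial(rnn_pad_sequence, batch_first = True)

    tokens = [torch.randint(0, 100, (n,)) for n in (12, 7, 19)]  # ragged token ids
    batch = pad_sequence(tokens, padding_value = 0)
    print(batch.shape)  # torch.Size([3, 19]), padded up to the longest sequence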
@@ -585,8 +604,16 @@ def __init__(
         model_seq_len = 2048,
         tokenizer_encode: Callable = tokenizer.encode,
         tokenizer_decode: Callable = tokenizer.decode,
+        post_prompt_callback: Callable = identity,
         prompt_input_tag: str = DEFAULT_PROMPT_INPUT_TAG,
-        exclude_filters: dict[str, Callable[[str], bool]] = dict()
+        exclude_filters: dict[str, Callable[[str], bool]] = dict(),
+        finetune = False,
+        finetune_lr = 1e-4,
+        finetune_wd = 1e-2,
+        finetune_betas = (0.9, 0.99),
+        finetune_eps = 1e-8,
+        finetune_epochs = 3,
+        finetune_batch_size = 16
     ):
         super().__init__()
         self.model = model
@@ -596,6 +623,8 @@ def __init__(
         self.prompt_batch_size = prompt_batch_size
         self.prompt_input_tag = prompt_input_tag
 
+        self.post_prompt_callback = post_prompt_callback # for easy mocking
+
         self.tokenizer_encode = tokenizer_encode
         self.tokenizer_decode = tokenizer_decode
         self.tokenizer_encode_to_tensor = lambda s: torch.tensor(tokenizer_encode(s)).long()
@@ -631,6 +660,22 @@ def __init__(
         self.teach_tool_prompt = teach_tool_prompt
         self.exclude_filters = exclude_filters
 
+        self.should_finetune = finetune
+
+        if not finetune:
+            return
+
+        self.finetune_batch_size = finetune_batch_size
+        self.finetune_epochs = finetune_epochs
+
+        self.optimizer = get_optimizer(
+            model.parameters(),
+            lr = finetune_lr,
+            wd = finetune_wd,
+            betas = finetune_betas,
+            eps = finetune_eps
+        )
+
     def generate_data_with_api_calls(
         self,
         data: List[str],
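With `finetune = True`, the constructor builds the optimizer up front from the new `finetune_*` hyperparameters. A rough plain-PyTorch equivalent, assuming `get_optimizer` is essentially an AdamW wrapper (an assumption; the real helper lives elsewhere in the repo):

    import torch
    from torch.optim import AdamW

    model = torch.nn.Linear(512, 512)  # stand-in for the transformer handed to Toolformer

    # defaults below mirror the new finetune_* constructor arguments
    optimizer = AdamW(
        model.parameters(),
        lr = 1e-4,            # finetune_lr
        weight_decay = 1e-2,  # finetune_wd
        betas = (0.9, 0.99),  # finetune_betas
        eps = 1e-8            # finetune_eps
    )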
@@ -706,22 +751,46 @@ def filter_and_keep_only_first_api_call(
 
         return included, excluded
 
+    @torch.no_grad()
     def sample_model_with_api_calls(
         self,
-        prime: torch.Tensor,
+        prime: Union[torch.Tensor, str],
         occurrence = 1,
         **kwargs
     ):
+        self.model.eval()
+
+        prime_is_str = isinstance(prime, str)
+
+        if prime_is_str:
+            prime = self.tokenizer_encode(prime)
+            prime = torch.tensor(prime).long()
+            prime = rearrange(prime, 'n -> 1 n')
+
+        assert prime.shape[0] == 1, 'only one at a time for now'
+
+        invoke_tools_ = partial(invoke_tools, self.registry)
+
+        def call_apis(t: torch.Tensor):
+            t = self.tokenizer_decode(t[0])
+            t = invoke_tools_(t)
+            t = self.tokenizer_encode_to_tensor(t)
+            return rearrange(t, 'n -> 1 n')
+
         output = sample_with_api_call(
             model = self.model,
+            prime = prime,
             seq_len = self.model_seq_len,
-            call_apis = partial(invoke_tools, self.registry),
+            call_apis = call_apis,
             api_end_token_id = self.api_stop_id,
             occurrence = occurrence,
             **kwargs
         )
 
-        return output
+        if not prime_is_str:
+            return output
+
+        return self.tokenizer_decode(output[0])
 
     def make_api_calls(
         self,
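`sample_model_with_api_calls` now accepts a raw string as well as a token tensor: a string prime is tokenized, sampled under `torch.no_grad()` in eval mode with tool invocations spliced into the decoding, and decoded back to text. A usage sketch, given a constructed `toolformer` as in the README:

    # string in, string out; a tensor prime would return a tensor instead
    response = toolformer.sample_model_with_api_calls(
        "How many days until the next new years?"
    )
    print(response)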
@@ -764,17 +833,50 @@ def filter_by_api_responses(
 
         return filtered_results
 
+    def finetune(
+        self,
+        filtered_results: Union[FilteredResults, torch.Tensor]
+    ):
+        self.model.train()
+
+        if isinstance(filtered_results, FilteredResults):
+            filtered_results = filtered_results.filtered_tokens_without_api_response
+
+        dataset = FinetuneDataset(tokens = filtered_results)
+        dl = FinetuneDataloader(dataset, batch_size = self.finetune_batch_size, shuffle = True)
+
+        for epoch in tqdm(range(self.finetune_epochs), desc = 'finetune epochs'):
+            for batch in dl:
+                inp, labels = batch[:, :-1], batch[:, 1:]
+
+                logits = self.model(inp)
+                logits = rearrange(logits, 'b n c -> b c n')
+
+                loss = F.cross_entropy(logits, labels, ignore_index = self.pad_id)
+                loss.backward()
+
+                print(f'loss: {loss.item()}')
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+
+        print(f'finished finetuning on {len(dataset)} filtered samples')
+
     def forward(
         self,
         data: List[str]
     ):
         data_with_api_calls = self.generate_data_with_api_calls(data)
 
+        data_with_api_calls = self.post_prompt_callback(data_with_api_calls)
+
         filtered_data, filtered_data_with_api_calls = self.filter_and_keep_only_first_api_call(data, data_with_api_calls)
 
         assert len(filtered_data_with_api_calls) > 0, 'your model failed to follow instructions and make API calls. please try a better model or do some better prompt engineering'
 
         data_with_responses = self.make_api_calls(filtered_data_with_api_calls)
         filtered_results = self.filter_by_api_responses(filtered_data, filtered_data_with_api_calls, data_with_responses)
 
-        return filtered_results
+        if not self.should_finetune:
+            return filtered_results
+
+        self.finetune(filtered_results)
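Note the behavioral change in `forward`: with `finetune = True` it consumes the filtered results and fine-tunes in place, implicitly returning `None`, while the manual two-step flow still works when the flag is off. A sketch of both call styles, given a constructed `toolformer`:

    # manual flow (finetune = False, the previous behavior)
    filtered_results = toolformer(data)
    toolformer.finetune(filtered_results)

    # automatic flow (finetune = True): forward fine-tunes internally and returns None
    toolformer(data)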
