Commit 723fde7 (1 parent: 4c2c8f0)

just need filter stats and fine-tuning logic

File tree: 3 files changed, +121 −7 lines

README.md (+15 −2)

@@ -76,7 +76,20 @@ toolformer = Toolformer(
 
 data_with_api_calls = toolformer.generate_data_with_api_calls(data)
 
-# complete the filtering and fine tuning step
+filtered_data, filtered_data_with_api_calls = toolformer.filter_and_keep_only_first_api_call(data, data_with_api_calls)
+
+data_with_api_responses = toolformer.make_api_calls(filtered_data_with_api_calls)
+
+filtered_results = toolformer.filter_by_api_responses(
+    filtered_data,
+    filtered_data_with_api_calls,
+    data_with_api_responses
+)
+
+# then finetune with token ids at
+# -> filtered_results.filtered_tokens_without_api_response
+# complete this with toolformer.finetune(filtered_results)
+
 ```
 
 The main novelty of the paper is defining a fitness score for the outputs from a transformer instructed to insert API calls. The score is used to filter the sampled outputs for finetuning the transformer to make API calls that decrease the perplexity of the text that follows them.
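
A sketch of that fitness score, assuming the paper's formulation rather than this repo's exact code: each candidate gets a weighted cross-entropy loss over the tokens following the call site, and a sampled API call survives only if conditioning on the call plus its response beats both alternatives by a threshold τ.

```python
import torch

# minimal sketch of the paper-style filtering criterion, not the repo's exact code.
# each loss is a weighted cross-entropy over the tokens after the api call site:
#   loss_no_call       - plain passage, no api call inserted
#   loss_call_only     - '[call]' inserted, but no response
#   loss_call_response - '[call → response]' inserted
def keep_api_call(loss_no_call, loss_call_only, loss_call_response, threshold = 1.):
    loss_minus = torch.minimum(loss_no_call, loss_call_only)
    # keep the sampled call only if the response buys at least `threshold` of loss reduction
    return (loss_minus - loss_call_response) >= threshold
```
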
@@ -162,7 +175,7 @@ invoke_tools(function_registry, text)
 - [ ] Toolformer should eventually calculate all statistics (how many properly sampled, filtered out by different criteria, the distribution of scores as well as how many were rejected) before the final fine-tuning
 - [ ] do end-to-end training in `Toolformer`
 - [x] doing the prompting and bootstrapping the data
-    - [ ] prefiltering of bootstrapped data followed by api calls and then another round of filtering
+    - [x] prefiltering of bootstrapped data followed by api calls and then another round of filtering
 - [ ] keep track of all stats
 - [ ] take care of fine-tuning, with the interleaving of datasets + optimizer hyperparams
 - [ ] hook up gpt-j

setup.py (+1 −1)

@@ -3,7 +3,7 @@
 setup(
   name = 'toolformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.19',
+  version = '0.0.20',
   license='MIT',
   description = 'Toolformer - Pytorch',
   author = 'Phil Wang',

toolformer_pytorch/toolformer_pytorch.py (+105 −4)

@@ -16,7 +16,7 @@
 from toolformer_pytorch.prompts import DEFAULT_PROMPT_INPUT_TAG
 
 from beartype import beartype
-from beartype.typing import Callable, Optional, Union, List
+from beartype.typing import Callable, Optional, Union, List, Tuple
 
 from tqdm import tqdm
 from x_clip.tokenizer import tokenizer
 
@@ -142,7 +142,7 @@ def replace_fn(
 
     # return original text with the output delimiter and the stringified output
 
-    return f'{text_without_end_api_token} {delimiter} {str(out)}{end_api_token}'
+    return f'{text_without_end_api_token} {delimiter} {str(out)} {end_api_token}'
 
 # main function, which takes a registry of functions, the text in question, and makes all the appropriate api calls and append the output
 
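The one-character fix pads the stringified output away from the closing token, presumably so the api-stop string tokenizes consistently as its own token. A toy illustration with hypothetical values:

```python
# hypothetical values, for illustration only
text_without_end_api_token = ' [Calculator(1 + 2)'
delimiter = '→'
out = 3
end_api_token = ']'

# before the fix the output hugged the closing token: ' [Calculator(1 + 2) → 3]'
# after the fix it is evenly spaced:                   ' [Calculator(1 + 2) → 3 ]'
print(f'{text_without_end_api_token} {delimiter} {str(out)} {end_api_token}')
```
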
@@ -451,11 +451,16 @@ def filter_tokens_with_api_response(
     assert all_contains_id(tokens_with_api_response, api_start_token_id)
     assert all_contains_id(tokens_with_api_response, api_end_token_id)
 
+    # auto set devices
+
+    device = next(model.parameters()).device
+    tokens, tokens_without_api_response, tokens_with_api_response = map(lambda t: t.to(device), (tokens, tokens_without_api_response, tokens_with_api_response))
+
     # get all the logits
 
     with torch.no_grad():
         model.eval()
-        logits, logits_without_api_response, logits_with_api_response = map(model, (tokens, tokens_with_api_response, tokens_without_api_response))
+        logits, logits_without_api_response, logits_with_api_response = map(model, (tokens, tokens_without_api_response, tokens_with_api_response))
 
     # derive all predicted prob of the actual next token id in sequence
 
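Two fixes land in this hunk: input tensors now follow the model's device, and the unpacking order finally matches the tuple passed to `map` (the without/with-response logits were previously swapped). The device idiom as a standalone sketch, with a toy module standing in for the language model:

```python
import torch
from torch import nn

model = nn.Linear(8, 8)  # toy stand-in for the language model

# read the device off the first parameter, then move every input tensor to it,
# so cpu-built token tensors follow a model that may live on gpu
device = next(model.parameters()).device
tokens, tokens_a, tokens_b = (torch.randn(2, 8) for _ in range(3))
tokens, tokens_a, tokens_b = map(lambda t: t.to(device), (tokens, tokens_a, tokens_b))
```
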
@@ -472,8 +477,10 @@ def filter_tokens_with_api_response(
 
     # deriving the weighting for the original passage is more tricky
     # would need to start counting up from <api> start token location
+    # this would also assume that the language model perfectly copied the passage over and that both token ids are aligned except for the inserted API call - but this can be done with the custom filtering functions eventually
 
     weight = weight_and_mask_fn(tokens_without_api_response[:, 1:], api_start_token_id) # shift to the left by one since <api> does not exist in the original sequence
+    weight = weight[:, :probs.shape[-1]]
 
     # get the loss L for all three types of sequences
 
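The new truncation line simply clips the weights to the probability tensor's length so the two broadcast cleanly. As for `weight_and_mask_fn` itself, here is a sketch under the paper's decay w_t = max(0, 1 − 0.2·t), counting t from the `<api>` site; the repo's actual default may differ:

```python
import torch

# sketch of a paper-style weight-and-mask function, assuming the <api> start
# token occurs once per row; not necessarily the repo's default implementation
def weight_and_mask(tokens, api_start_token_id, decay = 0.2):
    # 1 at and after the first <api> token, 0 before it
    after_call = (tokens == api_start_token_id).cumsum(dim = -1) > 0

    # t = distance from the <api> site; decay w_t = max(0, 1 - 0.2 * t)
    t = after_call.cumsum(dim = -1).float() - 1.
    weight = torch.clamp(1. - decay * t, min = 0.) * after_call.float()
    return weight
```
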
@@ -561,9 +568,11 @@ def __init__(
         tool: Callable,
         api_start_str = ' [',
         api_stop_str = ']',
+        api_response_delimiter = '→',
         api_start_id = None,
         api_stop_id = None,
         teach_tool_prompt: str,
+        filter_threshold = 1.,
         pad_id = 0,
         prompt_batch_size = 4,
         model_seq_len = 2048,
 
@@ -582,12 +591,28 @@ def __init__(
 
         self.tokenizer_encode = tokenizer_encode
         self.tokenizer_decode = tokenizer_decode
+        self.tokenizer_encode_to_tensor = lambda s: torch.tensor(tokenizer_encode(s)).long()
+
+        self.filter_threshold = filter_threshold
 
         self.api_start_str = api_start_str
         self.api_stop_str = api_stop_str
+        self.api_response_delimiter = api_response_delimiter
+
+        if not exists(api_start_id):
+            api_start_id = tokenizer_encode(api_start_str)
+            assert len(api_start_id) == 1
+            api_start_id = api_start_id[0]
 
         self.api_start_id = api_start_id
+
+        if not exists(api_stop_id):
+            api_stop_id = tokenizer_encode(api_stop_str)
+            assert len(api_stop_id) == 1
+            api_stop_id = api_stop_id[0]
+
         self.api_stop_id = api_stop_id
+
         self.pad_id = pad_id
 
         self.tool_id = tool_id
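
Deriving the token ids in `__init__` hinges on the start/stop strings each encoding to a single token, hence the asserts. A quick standalone check, with `tokenizer_encode` as a stand-in for whatever encoder you hand to `Toolformer`:

```python
# `tokenizer_encode` is hypothetical here - substitute your actual encoder
def single_token_id(tokenizer_encode, s):
    ids = tokenizer_encode(s)
    assert len(ids) == 1, f'{s!r} must encode to exactly one token, got {len(ids)}'
    return ids[0]

# e.g. api_start_id = single_token_id(tokenizer_encode, ' [')
#      api_stop_id  = single_token_id(tokenizer_encode, ']')
```
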
@@ -638,9 +663,85 @@ def generate_data_with_api_calls(
 
         return prompted_outputs
 
+    def filter_and_keep_only_first_api_call(
+        self,
+        data,
+        data_with_api_calls: List[str],
+        return_excluded = False
+    ):
+        included = []
+        excluded = []
+
+        api_start_stop_kwargs = dict(api_start = self.api_start_str, api_stop = self.api_stop_str)
+
+        has_api_calls_ = partial(has_api_calls, **api_start_stop_kwargs)
+        replace_all_but_first_ = partial(replace_all_but_first, **api_start_stop_kwargs)
+
+        for datum, data_with_api_call in zip(data, data_with_api_calls):
+            if has_api_calls_(data_with_api_call):
+                data_with_api_call = replace_all_but_first_(data_with_api_call)
+                included.append((datum, data_with_api_call))
+            else:
+                excluded.append((datum, data_with_api_call))
+
+        included = tuple(map(list, zip(*included)))
+
+        if not return_excluded:
+            return included
+
+        excluded = tuple(map(list, zip(*excluded)))
+        return included, excluded
+
+    def make_api_calls(
+        self,
+        filtered_data_with_api_calls: List[str]
+    ):
+        invoke_tools_ = partial(
+            invoke_tools,
+            self.registry,
+            api_start = self.api_start_str,
+            api_stop = self.api_stop_str,
+            delimiter = self.api_response_delimiter
+        )
+
+        data_with_api_responses = []
+        for data in filtered_data_with_api_calls:
+            output = invoke_tools_(data)
+            data_with_api_responses.append(output)
+
+        return data_with_api_responses
+
+    def filter_by_api_responses(
+        self,
+        data: List[str],
+        data_with_api_calls: List[str],
+        data_with_api_responses: List[str]
+    ) -> FilteredResults:
+
+        to_token_ids = lambda l: pad_sequence([*map(self.tokenizer_encode_to_tensor, l)], batch_first = True, padding_value = self.pad_id)
+
+        tokens, tokens_without_api_response, tokens_with_api_response = map(to_token_ids, (data, data_with_api_calls, data_with_api_responses))
+
+        filtered_results = filter_tokens_with_api_response(
+            model = self.model,
+            tokens = tokens,
+            tokens_with_api_response = tokens_with_api_response,
+            tokens_without_api_response = tokens_without_api_response,
+            filter_threshold = self.filter_threshold,
+            api_start_token_id = self.api_start_id,
+            api_end_token_id = self.api_stop_id
+        )
+
+        return filtered_results
+
     def forward(
         self,
         data: List[str]
     ):
         data_with_api_calls = self.generate_data_with_api_calls(data)
-        return data_with_api_calls
+        filtered_data, filtered_data_with_api_calls = self.filter_and_keep_only_first_api_call(data, data_with_api_calls)
+
+        assert len(filtered_data_with_api_calls) > 0, 'your model failed to follow instructions and make API calls. please try a better model or do some better prompt engineering'
+
+        data_with_responses = self.make_api_calls(filtered_data_with_api_calls)
+
+        return data_with_responses
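
Chained together, the three new methods reproduce the README pipeline. A toy walk-through, assuming a `toolformer` constructed as in the README with a calculator tool and the default ' [' / ']' / '→' delimiters:

```python
# hypothetical bootstrapped samples; the second never made an api call
data = ['x = 1 + 2. The answer is 3.', 'nothing to see here']
data_with_api_calls = [
    'x = 1 + 2 [Calculator(1 + 2)]. The answer is 3.',
    'nothing to see here'
]

filtered_data, filtered_calls = toolformer.filter_and_keep_only_first_api_call(data, data_with_api_calls)
# only the first sample survives; any extra calls in it would have been stripped

responses = toolformer.make_api_calls(filtered_calls)
# -> roughly ['x = 1 + 2 [Calculator(1 + 2) → 3 ]. The answer is 3.']

results = toolformer.filter_by_api_responses(filtered_data, filtered_calls, responses)
# results.filtered_tokens_without_api_response then feeds the fine-tuning step
```
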
