
Commit 4c2c8f0

take the first steps towards an end-to-end solution

1 parent 93c5431 commit 4c2c8f0

File tree

4 files changed: +193 −11 lines changed

Diff for: README.md (+64 −1)
```diff
@@ -20,6 +20,65 @@ $ pip install toolformer-pytorch
 
 ## Usage
 
+Example usage, giving a language model awareness of the current date and time.
+
+```python
+import torch
+from toolformer_pytorch import Toolformer, PaLM
+
+# simple calendar api call - a function that returns a string
+
+def Calendar():
+    import datetime
+    from calendar import day_name, month_name
+    now = datetime.datetime.now()
+    return f'Today is {day_name[now.weekday()]}, {month_name[now.month]} {now.day}, {now.year}.'
+
+# prompt for teaching the model to use the Calendar function above
+
+prompt = f"""
+Your task is to add calls to a Calendar API to a piece of text.
+The API calls should help you get information required to complete the text.
+You can call the API by writing "[Calendar()]"
+Here are some examples of API calls:
+Input: Today is the first Friday of the year.
+Output: Today is the first [Calendar()] Friday of the year.
+Input: The president of the United States is Joe Biden.
+Output: The president of the United States is [Calendar()] Joe Biden.
+Input: [input]
+Output:
+"""
+
+data = [
+    "The store is never open on the weekend, so today it is closed.",
+    "The number of days from now until Christmas is 30",
+    "The current day of the week is Wednesday."
+]
+
+# model - here using PaLM, but any nn.Module that returns logits of shape (batch, seq, num_tokens) will work
+
+model = PaLM(
+    dim = 512,
+    depth = 2,
+    heads = 8,
+    dim_head = 64
+).cuda()
+
+# toolformer
+
+toolformer = Toolformer(
+    model = model,
+    model_seq_len = 256,
+    teach_tool_prompt = prompt,
+    tool_id = 'Calendar',
+    tool = Calendar
+)
+
+data_with_api_calls = toolformer.generate_data_with_api_calls(data)
+
+# complete the filtering and fine-tuning steps
+```
+
```
The main novelty of the paper is defining a fitness score for the outputs of a transformer instructed to insert API calls. The score is used to filter the sampled outputs before finetuning, so the transformer learns to make only those API calls that decrease the perplexity of the text that follows them.
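As a rough sketch of that criterion (illustrative only, not the library's implementation; `lm_nll`, `keep_api_call`, and the threshold `tau` are hypothetical names, and the paper additionally down-weights tokens further from the call, which is omitted here):

```python
import torch
import torch.nn.functional as F

def lm_nll(model, token_ids, continuation_start):
    # summed negative log likelihood of the tokens from continuation_start
    # onward, each conditioned on everything before it
    logits = model(token_ids.unsqueeze(0))          # (1, seq, num_tokens)
    logprobs = F.log_softmax(logits, dim = -1)
    nll = -logprobs[0, :-1].gather(-1, token_ids[1:].unsqueeze(-1)).squeeze(-1)
    return nll[continuation_start - 1:].sum()

def keep_api_call(model, with_result, with_call_only, no_call, tau = 1.0):
    # each argument is a (token_ids, continuation_start) pair; the continuation
    # is the same text in all three, only the prefix differs
    loss_plus = lm_nll(model, *with_result)         # api call + its result prepended
    loss_minus = min(lm_nll(model, *with_call_only), lm_nll(model, *no_call))
    # keep the sampled call only if seeing the API result helps the model
    # predict the following text by at least tau nats
    return (loss_minus - loss_plus) >= tau
```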
```diff
@@ -98,10 +157,14 @@ invoke_tools(function_registry, text)
 
 - [x] create custom generate function for palm that can do external API calls
 - [x] allow for generating tokens at different cursor indices
+- [x] api token (which was left and right brackets in paper) needs to be customizable
 - [ ] allow for customizing how to handle errors in function name, parameters, or execution and output
-- [ ] api token (which was left and right brackets in paper) needs to be customizable
 - [ ] Toolformer should eventually calculate all statistics (how many properly sampled, filtered out by different criteria, the distribution of scores, as well as how many were rejected) before the final fine-tuning
 - [ ] do end-to-end training in `Toolformer`
+    - [x] doing the prompting and bootstrapping the data
+    - [ ] prefiltering of bootstrapped data followed by api calls and then another round of filtering
+    - [ ] keep track of all stats
+    - [ ] take care of fine-tuning, with the interleaving of datasets + optimizer hyperparams
 - [ ] hook up gpt-j
 - [ ] test for a simple calculator eval dataset
```

Diff for: setup.py (+1 −1)

```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'toolformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.17',
+  version = '0.0.19',
   license='MIT',
   description = 'Toolformer - Pytorch',
   author = 'Phil Wang',
```

Diff for: toolformer_pytorch/palm.py (+4 −2)

```diff
@@ -2,6 +2,8 @@
 from torch import nn, einsum
 from einops import rearrange
 
+from x_clip.tokenizer import tokenizer
+
 # helpers
 
 def exists(val):
@@ -162,7 +164,7 @@ def __init__(
         depth,
         heads,
         dim_head,
-        ff_mult=4,
+        ff_mult = 4,
     ):
         super().__init__()
         self.layers = nn.ModuleList([])
@@ -184,8 +186,8 @@ class PaLM(nn.Module):
     def __init__(
        self,
        dim,
-       num_tokens,
        depth,
+       num_tokens=tokenizer.vocab_size,
        dim_head=64,
        heads=8,
        ff_mult=4,
```
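One practical effect of the palm.py change: `num_tokens` now defaults to the x-clip tokenizer's vocabulary size, so a PaLM can be constructed without specifying it. A minimal sketch (the hyperparameters here are arbitrary):

```python
from toolformer_pytorch import PaLM

# num_tokens is no longer required - it defaults to tokenizer.vocab_size
# from x_clip, the same tokenizer Toolformer now uses to encode prompts
model = PaLM(
    dim = 512,
    depth = 2,
    heads = 8,
    dim_head = 64
)
```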

Diff for: toolformer_pytorch/toolformer_pytorch.py (+124 −7)

```diff
@@ -6,6 +6,8 @@
 import torch
 import torch.nn.functional as F
 from torch import nn, einsum
+from torch.utils.data import Dataset, DataLoader
+from torch.nn.utils.rnn import pad_sequence
 
 from einops import rearrange, reduce
 
@@ -17,6 +19,7 @@
 from beartype.typing import Callable, Optional, Union, List
 
 from tqdm import tqdm
+from x_clip.tokenizer import tokenizer
 
 # helpers
 
@@ -26,6 +29,23 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+def identity(t):
+    return t
+
+def always(val):
+    def inner(*args, **kwargs):
+        return val
+    return inner
+
+def try_except(fn, callback = identity):
+    @wraps(fn)
+    def inner(*args):
+        try:
+            return fn(*args)
+        except Exception as e:
+            return callback(e)
+    return inner
+
 # tensor helpers
 
 def log(t, eps = 1e-20):
```
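For context, a quick sketch of what the new helpers buy (hypothetical standalone usage; `flaky_tool` is made up, and the helpers are the ones added above):

```python
def flaky_tool(x):
    return 10 / x

# try_except routes any exception into a fallback callback; always(None) makes
# the fallback None, which replace_fn (next hunk) treats as "arrest the process"
safe_tool = try_except(flaky_tool, always(None))

print(safe_tool(5))   # 2.0
print(safe_tool(0))   # None (the ZeroDivisionError is swallowed)
```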
```diff
@@ -113,10 +133,7 @@ def replace_fn(
 
     # just return original text if there is some error with the function
 
-    try:
-        out = fn(*params)
-    except:
-        return orig_text
+    out = try_except(fn, always(None))(*params)
 
     # the api calling function can also arrest the process, by returning None
 
@@ -218,6 +235,7 @@ def sample(
     select_api_start_id_top_k = 10,
 ):
     device = next(model.parameters()).device
+    positions = positions.clone()
     max_seq_len = seq_len + 1
 
     # validate
@@ -247,7 +265,7 @@
     # lengthen the prime to the entire sequence length
 
     remain_iterations = seq_len - prime_length
-    output = F.pad(prime, (max_seq_len - prime_length, 0), value = 0.)
+    output = F.pad(prime, (0, max_seq_len - prime_length), value = 0.)
 
     batch_indices = torch.arange(batch_size, device = device)
     batch_indices = rearrange(batch_indices, 'b -> b 1')
```
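The `F.pad` fix above is worth spelling out: for the last dimension, `F.pad` takes `(pad_left, pad_right)`, so the old call shifted the prime to the end of the buffer, while the new one keeps it at the start and leaves room for generation on the right. A quick check:

```python
import torch
import torch.nn.functional as F

prime = torch.tensor([[5, 6, 7]])          # (batch, prime_length)

print(F.pad(prime, (2, 0), value = 0))     # old: tensor([[0, 0, 5, 6, 7]])
print(F.pad(prime, (0, 2), value = 0))     # new: tensor([[5, 6, 7, 0, 0]])
```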
```diff
@@ -322,7 +340,6 @@ def create_api_token_mask(num_tokens, api_start_token_id):
     # remove the last token in output (use as noop placeholder)
 
     output = output[:, :-1]
-
     return output
 
 @beartype
@@ -496,6 +513,42 @@ def loss_fn(weight, probs):
 
     return ret
 
+# datasets and dataloaders
+
+# for bootstrapping the initial datasets with api calls
+# as well as for the final finetuning
+
+@beartype
+class PromptDataset(Dataset):
+    def __init__(
+        self,
+        prompt: str,
+        prompt_input_tag: str,
+        data: List[str],
+        tokenizer_encode: Callable
+    ):
+        self.data = data
+        self.prompt = prompt
+        self.prompt_input_tag_regex = re.escape(prompt_input_tag)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        data_string = self.data[idx]
+        data_with_prompt = re.sub(self.prompt_input_tag_regex, data_string, self.prompt)
+        token_ids = tokenizer.encode(data_with_prompt)
+        return torch.tensor(token_ids).long(), torch.tensor(len(token_ids)).long()
+
+def prompt_collate_fn(data, padding_value = 0):
+    prompts, prompt_lengths = zip(*data)
+    prompts = pad_sequence(prompts, batch_first = True, padding_value = padding_value)
+    return prompts, torch.stack(prompt_lengths)
+
+def PromptDataloader(ds: Dataset, *args, padding_value = 0, **kwargs):
+    collate_fn = partial(prompt_collate_fn, padding_value = padding_value)
+    return DataLoader(ds, *args, collate_fn = collate_fn, **kwargs)
+
 # classes
 
 @beartype
```
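A hypothetical standalone use of the new dataset plumbing (inside `Toolformer` this wiring happens in `generate_data_with_api_calls`, shown in the next hunk). Note that as of this commit `__getitem__` calls the module-level `tokenizer.encode` rather than the `tokenizer_encode` that was passed in:

```python
from x_clip.tokenizer import tokenizer

dataset = PromptDataset(
    prompt = 'Input: [input]\nOutput:',   # made-up prompt with an [input] tag, as in the README prompt
    prompt_input_tag = '[input]',         # re.escape'd and substituted via re.sub
    data = ['The current day of the week is Wednesday.'],
    tokenizer_encode = tokenizer.encode
)

dl = PromptDataloader(dataset, batch_size = 4)

for prompt_ids, prompt_lengths in dl:
    # prompt_ids:     (batch, max_len), right-padded with padding_value
    # prompt_lengths: (batch,) unpadded lengths, used as generation start positions
    print(prompt_ids.shape, prompt_lengths)
```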
```diff
@@ -506,12 +559,36 @@ def __init__(
         *,
         tool_id: str,
         tool: Callable,
+        api_start_str = ' [',
+        api_stop_str = ']',
+        api_start_id = None,
+        api_stop_id = None,
         teach_tool_prompt: str,
+        pad_id = 0,
+        prompt_batch_size = 4,
+        model_seq_len = 2048,
+        tokenizer_encode: Callable = tokenizer.encode,
+        tokenizer_decode: Callable = tokenizer.decode,
         prompt_input_tag: str = DEFAULT_PROMPT_INPUT_TAG,
         exclude_filters: dict[str, Callable[[str], bool]] = dict()
     ):
         super().__init__()
         self.model = model
+        self.model_seq_len = model_seq_len
+
+        self.teach_tool_prompt = teach_tool_prompt
+        self.prompt_batch_size = prompt_batch_size
+        self.prompt_input_tag = prompt_input_tag
+
+        self.tokenizer_encode = tokenizer_encode
+        self.tokenizer_decode = tokenizer_decode
+
+        self.api_start_str = api_start_str
+        self.api_stop_str = api_stop_str
+
+        self.api_start_id = api_start_id
+        self.api_stop_id = api_stop_id
+        self.pad_id = pad_id
 
         self.tool_id = tool_id
         self.tool = tool
@@ -522,8 +599,48 @@ def __init__(
         self.teach_tool_prompt = teach_tool_prompt
         self.exclude_filters = exclude_filters
 
+    def generate_data_with_api_calls(
+        self,
+        data: List[str],
+        temperature: float = 0.9
+    ) -> List[str]:
+
+        dataset = PromptDataset(
+            data = data,
+            prompt_input_tag = self.prompt_input_tag,
+            prompt = self.teach_tool_prompt,
+            tokenizer_encode = self.tokenizer_encode
+        )
+
+        dl = PromptDataloader(
+            dataset,
+            batch_size = self.prompt_batch_size
+        )
+
+        prompted_outputs = []
+
+        for prime, positions in dl:
+
+            sampled_outputs = sample(
+                model = self.model,
+                prime = prime,
+                positions = positions,
+                seq_len = self.model_seq_len,
+                pad_id = self.pad_id,
+                temperature = temperature
+            )
+
+            for sample_output, position in zip(sampled_outputs, positions):
+                start_position = position.item()
+
+                prompted_output = self.tokenizer_decode(sample_output[start_position:])
+                prompted_outputs.append(prompted_output)
+
+        return prompted_outputs
+
     def forward(
         self,
         data: List[str]
     ):
-        raise NotImplementedError
+        data_with_api_calls = self.generate_data_with_api_calls(data)
+        return data_with_api_calls
```
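With this change, calling the `Toolformer` module directly runs the first stage of the pipeline instead of raising. A minimal sketch, assuming the `toolformer` and `data` from the README example above:

```python
# calling the module is now equivalent to the explicit method call
data_with_api_calls = toolformer(data)
data_with_api_calls = toolformer.generate_data_with_api_calls(data)

# each element is the decoded model continuation for one datum, ideally
# containing inline calls such as "[Calendar()]" for later execution
```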
