@@ -238,7 +238,6 @@ def sample(
238
238
select_api_start_id_top_k = 10 ,
239
239
):
240
240
device = next (model .parameters ()).device
241
- positions = positions .clone ()
242
241
max_seq_len = seq_len + 1
243
242
244
243
# validate
@@ -258,7 +257,11 @@ def sample(
258
257
259
258
# sampling positions - different sequences have different cursors
260
259
261
- positions = default (positions , torch .zeros ((batch_size ,), device = device , dtype = torch .long ))
260
+ if exists (positions ):
261
+ positions = positions .clone ()
262
+ else :
263
+ positions = torch .zeros ((batch_size ,), device = device , dtype = torch .long )
264
+
262
265
assert (positions <= (prime_length + 1 )).all () and (positions <= max_seq_len ).all (), 'all positions must be less then initial prime length as well as the total sequence length + 1 (plus one for noop if one sequence finished sampling before the other)'
263
266
264
267
# eval model
@@ -516,7 +519,7 @@ def loss_fn(weight, probs):
516
519
selected_indices = indices [selected_mask ]
517
520
518
521
ret = FilteredResults (
519
- selected_mask .sum ().item ()
522
+ selected_mask .sum ().item (),
520
523
(~ selected_mask ).sum ().item (),
521
524
selected_indices ,
522
525
selected_mask ,
@@ -563,6 +566,22 @@ def PromptDataloader(ds: Dataset, *args, padding_value = 0, **kwargs):
563
566
collate_fn = partial (prompt_collate_fn , padding_value = padding_value )
564
567
return DataLoader (ds , * args , collate_fn = collate_fn , ** kwargs )
565
568
569
+ class FinetuneDataset (Dataset ):
570
+ def __init__ (
571
+ self ,
572
+ tokens : torch .Tensor
573
+ ):
574
+ self .tokens = tokens
575
+
576
+ def __len__ (self ):
577
+ return len (self .tokens )
578
+
579
+ def __getitem__ (self , idx ):
580
+ return self .tokens [idx ]
581
+
582
+ def FinetuneDataloader (ds : Dataset , * args , padding_value = 0 , ** kwargs ):
583
+ return DataLoader (ds , * args , collate_fn = partial (pad_sequence , padding_value = padding_value ), ** kwargs )
584
+
566
585
# classes
567
586
568
587
@beartype
@@ -585,8 +604,16 @@ def __init__(
585
604
model_seq_len = 2048 ,
586
605
tokenizer_encode : Callable = tokenizer .encode ,
587
606
tokenizer_decode : Callable = tokenizer .decode ,
607
+ post_prompt_callback : Callable = identity ,
588
608
prompt_input_tag : str = DEFAULT_PROMPT_INPUT_TAG ,
589
- exclude_filters : dict [str , Callable [[str ], bool ]] = dict ()
609
+ exclude_filters : dict [str , Callable [[str ], bool ]] = dict (),
610
+ finetune = False ,
611
+ finetune_lr = 1e-4 ,
612
+ finetune_wd = 1e-2 ,
613
+ finetune_betas = (0.9 , 0.99 ),
614
+ finetune_eps = 1e-8 ,
615
+ finetune_epochs = 3 ,
616
+ finetune_batch_size = 16
590
617
):
591
618
super ().__init__ ()
592
619
self .model = model
@@ -596,6 +623,8 @@ def __init__(
596
623
self .prompt_batch_size = prompt_batch_size
597
624
self .prompt_input_tag = prompt_input_tag
598
625
626
+ self .post_prompt_callback = post_prompt_callback # for easy mocking
627
+
599
628
self .tokenizer_encode = tokenizer_encode
600
629
self .tokenizer_decode = tokenizer_decode
601
630
self .tokenizer_encode_to_tensor = lambda s : torch .tensor (tokenizer_encode (s )).long ()
@@ -631,6 +660,22 @@ def __init__(
631
660
self .teach_tool_prompt = teach_tool_prompt
632
661
self .exclude_filters = exclude_filters
633
662
663
+ self .should_finetune = finetune
664
+
665
+ if not finetune :
666
+ return
667
+
668
+ self .finetune_batch_size = finetune_batch_size
669
+ self .finetune_epochs = finetune_epochs
670
+
671
+ self .optimizer = get_optimizer (
672
+ model .parameters (),
673
+ lr = finetune_lr ,
674
+ wd = finetune_wd ,
675
+ betas = finetune_betas ,
676
+ eps = finetune_eps
677
+ )
678
+
634
679
def generate_data_with_api_calls (
635
680
self ,
636
681
data : List [str ],
@@ -706,22 +751,46 @@ def filter_and_keep_only_first_api_call(
706
751
707
752
return included , excluded
708
753
754
+ @torch .no_grad ()
709
755
def sample_model_with_api_calls (
710
756
self ,
711
- prime : torch .Tensor ,
757
+ prime : Union [ torch .Tensor , str ] ,
712
758
occurrence = 1 ,
713
759
** kwargs
714
760
):
761
+ self .model .eval ()
762
+
763
+ prime_is_str = isinstance (prime , str )
764
+
765
+ if prime_is_str :
766
+ prime = self .tokenizer_encode (prime )
767
+ prime = torch .tensor (prime ).long ()
768
+ prime = rearrange (prime , 'n -> 1 n' )
769
+
770
+ assert prime .shape [0 ] == 1 , 'only one at a time for now'
771
+
772
+ invoke_tools_ = partial (invoke_tools , self .registry )
773
+
774
+ def call_apis (t : torch .Tensor ):
775
+ t = self .tokenizer_decode (t [0 ])
776
+ t = invoke_tools_ (t )
777
+ t = self .tokenizer_encode_to_tensor (t )
778
+ return rearrange (t , 'n -> 1 n' )
779
+
715
780
output = sample_with_api_call (
716
781
model = self .model ,
782
+ prime = prime ,
717
783
seq_len = self .model_seq_len ,
718
- call_apis = partial ( invoke_tools , self . registry ) ,
784
+ call_apis = call_apis ,
719
785
api_end_token_id = self .api_stop_id ,
720
786
occurrence = occurrence ,
721
787
** kwargs
722
788
)
723
789
724
- return output
790
+ if not prime_is_str :
791
+ return output
792
+
793
+ return self .tokenizer_decode (output [0 ])
725
794
726
795
def make_api_calls (
727
796
self ,
@@ -764,17 +833,50 @@ def filter_by_api_responses(
764
833
765
834
return filtered_results
766
835
836
+ def finetune (
837
+ self ,
838
+ filtered_results : Union [FilteredResults , torch .Tensor ]
839
+ ):
840
+ self .model .train ()
841
+
842
+ if isinstance (filtered_results , FilteredResults ):
843
+ filtered_results = filtered_results .filtered_tokens_without_api_response
844
+
845
+ dataset = FinetuneDataset (tokens = filtered_results )
846
+ dl = FinetuneDataloader (dataset , batch_size = self .finetune_batch_size , shuffle = True )
847
+
848
+ for epoch in tqdm (range (self .finetune_epochs ), desc = 'finetune epochs' ):
849
+ for batch in dl :
850
+ inp , labels = batch [:, :- 1 ], batch [:, 1 :]
851
+
852
+ logits = self .model (inp )
853
+ logits = rearrange (logits , 'b n c -> b c n' )
854
+
855
+ loss = F .cross_entropy (logits , labels , ignore_index = self .pad_id )
856
+ loss .backward ()
857
+
858
+ print (f'loss: { loss .item ()} ' )
859
+ self .optimizer .step ()
860
+ self .optimizer .zero_grad ()
861
+
862
+ print (f'finished finetuning on { len (dataset )} filtered samples' )
863
+
767
864
def forward (
768
865
self ,
769
866
data : List [str ]
770
867
):
771
868
data_with_api_calls = self .generate_data_with_api_calls (data )
772
869
870
+ data_with_api_calls = self .post_prompt_callback (data_with_api_calls )
871
+
773
872
filtered_data , filtered_data_with_api_calls = self .filter_and_keep_only_first_api_call (data , data_with_api_calls )
774
873
775
874
assert len (filtered_data_with_api_calls ) > 0 , 'your model failed to follow instructions and make API calls. please try a better model or do some better prompt engineering'
776
875
777
876
data_with_responses = self .make_api_calls (filtered_data_with_api_calls )
778
877
filtered_results = self .filter_by_api_responses (filtered_data , filtered_data_with_api_calls , data_with_responses )
779
878
780
- return filtered_results
879
+ if not self .should_finetune :
880
+ return filtered_results
881
+
882
+ self .finetune (filtered_results )
0 commit comments