We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 96b34ae commit 222b676Copy full SHA for 222b676
finetune_moss.py
@@ -118,7 +118,7 @@ def collate_fn(self, batch):
118
batch_labels.append(label)
119
120
batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=self.tokenizer.eos_token_id)
121
- batch_attn_mask = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=0).to(torch.bool)
+ batch_attn_mask = torch.nn.utils.rnn.pad_sequence(batch_attn_mask, batch_first=True, padding_value=0).to(torch.bool)
122
batch_labels = torch.nn.utils.rnn.pad_sequence(batch_labels, batch_first=True, padding_value=-100)
123
124
return batch_input_ids, batch_attn_mask, batch_labels
0 commit comments