@@ -1,4 +1,5 @@
 import logging
+import pickle
 from typing import (
     Callable,
     Dict,
@@ -9,12 +10,12 @@
 )

 import numpy as np
-from sklearn.manifold import locally_linear_embedding
 from sklearn.mixture import GaussianMixture
 import torch
 from torch import nn
 from tqdm import tqdm
 from transformers import PreTrainedTokenizer
+import umap.umap_ as umap

 from fortuna.conformal import BinaryClassificationMulticalibrator
 from fortuna.hallucination.grouping.clustering.base import GroupingModel
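The removed `locally_linear_embedding` import came from scikit-learn's stateless manifold API; its replacement lives in the separate `umap-learn` package. A minimal sketch of the import idiom (assuming `umap-learn` is installed; the `umap.umap_` submodule form is commonly used to sidestep name clashes with the unrelated `umap` package on PyPI):

    # pip install umap-learn
    import umap.umap_ as umap

    reducer = umap.UMAP(n_neighbors=20)  # a scikit-learn-style estimator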
@@ -26,7 +27,7 @@ def __init__(
         self,
         generative_model: nn.Module,
         tokenizer: PreTrainedTokenizer,
-        embedding_reduction_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
+        embedding_reduction_model: Optional = None,
         clustering_models: Optional[List] = None,
         scoring_fn: Optional[
             Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]
@@ -49,8 +50,8 @@ def __init__(
             A generative model.
         tokenizer: PreTrainedTokenizer
             A tokenizer.
-        embedding_reduction_fn: Optional[Callable[[np.ndarray], np.ndarray]]
-            A function aimed at reducing the embedding dimensionality.
+        embedding_reduction_model: Optional
+            An embedding reduction model.
         clustering_models: Optional[List]
             A list of clustering models.
         scoring_fn: Optional[Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]]
@@ -61,8 +62,8 @@ def __init__(
61
62
if self .tokenizer .pad_token is None :
62
63
self .tokenizer .pad_token = self .tokenizer .eos_token
63
64
logging .info ("`tokenizer.pad_token` is None. Set to `tokenizer.eos_token`." )
64
- self .embedding_reduction_fn = (
65
- embedding_reduction_fn or locally_linear_embedding_fn
65
+ self .embedding_reduction_model = embedding_reduction_model or umap . UMAP (
66
+ n_neighbors = 20
66
67
)
67
68
self .scoring_fn = scoring_fn or inv_perplexity
68
69
self .clustering_models = clustering_models or [
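Swapping the reduction function for a reduction model is the substantive change here: `locally_linear_embedding` returned a one-off embedding (as an `(embedding, reconstruction_error)` tuple, hence the `[0]` in the deleted helper at the bottom of the diff), with no way to project new points into the same space later. A stateful estimator separates fitting from projecting. A minimal sketch with hypothetical shapes, assuming `umap-learn`:

    import numpy as np
    import umap.umap_ as umap

    reducer = umap.UMAP(n_neighbors=20)        # n_components defaults to 2

    calib = np.random.randn(500, 128)          # hypothetical calibration embeddings
    test = np.random.randn(10, 128)            # hypothetical test embeddings

    low_calib = reducer.fit_transform(calib)   # learn the projection once
    low_test = reducer.transform(test)         # reuse it on unseen points

    print(low_calib.shape, low_test.shape)     # (500, 2) (10, 2)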
@@ -124,7 +125,7 @@ def fit(
         else:
             targets = np.array(targets)

-        embeddings = self.embedding_reduction_fn(embeddings)
+        embeddings = self.embedding_reduction_model.fit_transform(embeddings)
         embeddings = np.concatenate((embeddings, scores[:, None]), axis=1)

         self.grouping_model = GroupingModel()
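After reduction, the scalar score is appended as one extra feature column, so the grouping model clusters on the reduced embedding and the score jointly. A quick shape check with assumed sizes:

    import numpy as np

    embeddings = np.random.randn(100, 2)   # reduced calibration embeddings
    scores = np.random.rand(100)           # one score per answer

    features = np.concatenate((embeddings, scores[:, None]), axis=1)
    print(features.shape)                  # (100, 3)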
@@ -147,7 +148,7 @@ def fit(

     def predict_proba(
         self,
-        texts: Union[List[str], List[List[str]]],
+        texts: List[str],
         contexts: List[str],
         calibrate: bool = True,
     ) -> np.ndarray:
@@ -156,7 +157,7 @@ def predict_proba(

         Parameters
         ----------
-        texts: Union[List[str], List[List[str]]]
+        texts: List[str]
             The texts to fit.
             This may either be a list of strings (e.g. a list of single answers),
             or a list of lists of strings (e.g. a list of multi-choice answers).
@@ -176,14 +177,14 @@ def predict_proba(
         (
             scores,
             embeddings,
-            which_choices,
+            _,
         ) = self._compute_scores_embeddings_which_choices(
             texts=texts, contexts=contexts
         )
         if not calibrate:
             return scores

-        embeddings = self.embedding_reduction_fn(embeddings)
+        embeddings = self.embedding_reduction_model.transform(embeddings)
         embeddings = np.concatenate((embeddings, scores[:, None]), axis=1)

         group_scores = self.grouping_model.predict_proba(
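At prediction time the already-fitted reducer's `transform` is called instead of re-running the reduction from scratch, so test points land in the same reduced space the grouping model and multicalibrator were fitted on; re-fitting here would silently invalidate the calibration. A short sketch with hypothetical shapes:

    import numpy as np
    import umap.umap_ as umap

    reducer = umap.UMAP(n_neighbors=20)
    reducer.fit(np.random.randn(200, 64))                # done once, inside fit()

    reduced = reducer.transform(np.random.randn(5, 64))  # same projection reused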
@@ -195,7 +196,7 @@ def predict_proba(

     def predict(
         self,
-        texts: Union[List[str], List[List[str]]],
+        texts: List[str],
         contexts: List[str],
         calibrate: bool = True,
         probs: Optional[np.ndarray] = None,
@@ -206,7 +207,7 @@ def predict(

         Parameters
         ----------
-        texts: Union[List[str], List[List[str]]]
+        texts: List[str]
             The texts to fit.
             This may either be a list of strings (e.g. a list of single answers),
             or a list of lists of strings (e.g. a list of multi-choice answers).
@@ -253,7 +254,7 @@ def _compute_scores_embeddings_which_choices(
                 embeddings.append(_embeddings[which_choice, None])
             elif isinstance(text, str):
                 embeddings.append(_embeddings)
-                scores.append(_scores)
+                scores.append(_scores[0])

         return (
             np.array(scores),
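The `_scores[0]` change presumably matters because the scoring function returns a length-1 array for a single string answer; appending the array itself would make `np.array(scores)` an `(n, 1)` matrix, and the later `scores[:, None]` would then have shape `(n, 1, 1)` and break the concatenation. A quick illustration with assumed values:

    import numpy as np

    per_answer = [np.array([0.7]), np.array([0.4])]   # length-1 arrays per answer
    print(np.array(per_answer).shape)                 # (2, 1) -- wrong rank

    flat = [a[0] for a in per_answer]                 # take the scalar instead
    print(np.array(flat).shape)                       # (2,) -- what downstream expects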
@@ -278,16 +279,26 @@ def _get_logits_scores(
         with torch.no_grad():
             _logits = self.generative_model(**inputs).logits

-            _scores = self.scoring_fn(
-                logits=_logits,
-                labels=inputs["input_ids"],
-                init_pos=len(context_inputs),
-            )
+        _scores = self.scoring_fn(
+            logits=_logits,
+            labels=inputs["input_ids"],
+            init_pos=len(context_inputs),
+        )

         return _logits.cpu().numpy(), _scores.cpu().numpy()

+    def save(self, path):
+        state = dict(
+            embedding_reduction_model=self.embedding_reduction_model,
+            grouping_model=self.grouping_model,
+            multicalibrator=self.multicalibrator,
+            _quantiles=self._quantiles,
+        )
+
+        with open(path, "wb") as filehandler:
+            pickle.dump(state, filehandler, -1)

-def locally_linear_embedding_fn(x: np.ndarray) -> np.ndarray:
-    return locally_linear_embedding(
-        x, n_neighbors=300, n_components=200, method="modified"
-    )[0]
+    def load(self, path):
+        state = pickle.load(open(path, "rb"))
+        for k, v in state.items():
+            setattr(self, k, v)
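The new `save`/`load` pair pickles only the fitted calibration state (reduction model, grouping model, multicalibrator, quantiles), not the generative model or tokenizer, so a restored object must first be constructed around the same model. A hypothetical round trip, with class and argument names assumed purely for illustration:

    calibrator = HallucinationCalibrator(generative_model=model, tokenizer=tokenizer)
    calibrator.fit(texts=calib_texts, contexts=calib_contexts, targets=calib_targets)
    calibrator.save("state.pkl")

    # In a fresh process: rebuild the wrapper, then restore the fitted state.
    calibrator = HallucinationCalibrator(generative_model=model, tokenizer=tokenizer)
    calibrator.load("state.pkl")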