@@ -5,95 +5,14 @@ import (
5
5
"testing"
6
6
)
7
7
8
- func TestEmbedBGEBaseEN (t * testing.T ) {
9
- // Test with a single input
10
- fe , err := NewFlagEmbedding (& InitOptions {
11
- Model : BGEBaseEN ,
12
- })
13
- defer fe .Destroy ()
14
- if err != nil {
15
- t .Fatalf ("Expected no error, got %v" , err )
16
- }
17
- input := []string {"hello world" }
18
- result , err := fe .Embed (input , 1 )
19
- if err != nil {
20
- t .Fatalf ("Expected no error, got %v" , err )
21
- }
22
-
23
- if len (result ) != len (input ) {
24
- t .Errorf ("Expected result length %v, got %v" , len (input ), len (result ))
25
- }
26
- }
27
-
28
- func TestEmbedAllMiniLML6V2 (t * testing.T ) {
29
- // Test with a single input
30
- fe , err := NewFlagEmbedding (& InitOptions {
31
- Model : AllMiniLML6V2 ,
32
- })
33
- defer fe .Destroy ()
34
- if err != nil {
35
- t .Fatalf ("Expected no error, got %v" , err )
36
- }
37
- input := []string {"hello world" }
38
- result , err := fe .Embed (input , 1 )
39
- if err != nil {
40
- t .Fatalf ("Expected no error, got %v" , err )
41
- }
42
-
43
- if len (result ) != len (input ) {
44
- t .Errorf ("Expected result length %v, got %v" , len (input ), len (result ))
45
- }
46
- }
47
-
48
- func TestEmbedBGESmallEN (t * testing.T ) {
49
- // Test with a single input
50
- fe , err := NewFlagEmbedding (& InitOptions {
51
- Model : BGESmallEN ,
52
- })
53
- defer fe .Destroy ()
54
- if err != nil {
55
- t .Fatalf ("Expected no error, got %v" , err )
56
- }
57
- input := []string {"hello world" }
58
- result , err := fe .Embed (input , 1 )
59
- if err != nil {
60
- t .Fatalf ("Expected no error, got %v" , err )
61
- }
62
-
63
- if len (result ) != len (input ) {
64
- t .Errorf ("Expected result length %v, got %v" , len (input ), len (result ))
65
- }
66
- }
67
-
68
- // A model type "Unigram" is not yet supported by the tokenizer
69
- // Ref: https://github.com/sugarme/tokenizer/blob/448e79b1ed65947b8c6343bf9aa39e78364f45c8/pretrained/model.go#L152
70
- // func TestEmbedMLE5Large(t *testing.T) {
71
- // // Test with a single input
72
- // show := false
73
- // fe, err := NewFlagEmbedding(&InitOptions{
74
- // Model: MLE5Large,
75
- // ShowDownloadProgress: &show,
76
- // })
77
- // defer fe.Destroy()
78
- // if err != nil {
79
- // t.Fatalf("Expected no error, got %v", err)
80
- // }
81
- // input := []string{"hello world"}
82
- // result, err := fe.Embed(input, 1)
83
- // if err != nil {
84
- // t.Fatalf("Expected no error, got %v", err)
85
- // }
86
-
87
- // if len(result) != len(input) {
88
- // t.Errorf("Expected result length %v, got %v", len(input), len(result))
89
- // }
90
- // }
91
-
92
8
func TestCanonicalValues (T * testing.T ) {
93
9
canonicalValues := map [EmbeddingModel ]([]float32 ){
94
- AllMiniLML6V2 : []float32 {0.02591 , 0.00573 , 0.01147 , 0.03796 , - 0.02328 , - 0.05493 , 0.014040 , - 0.01079 , - 0.02440 , - 0.01822 },
95
- BGESmallEN : []float32 {- 0.02313 , - 0.02552 , 0.017357 , - 0.06393 , - 0.00061 , 0.02212 , - 0.01472 , 0.03925 , 0.03444 , 0.00459 },
96
- BGEBaseEN : []float32 {0.01140 , 0.03722 , 0.02941 , 0.01230 , 0.03451 , 0.00876 , 0.02356 , 0.05414 , - 0.02945 , - 0.05472 },
10
+ AllMiniLML6V2 : []float32 {0.02591 , 0.00573 , 0.01147 , 0.03796 , - 0.02328 },
11
+ BGESmallEN : []float32 {- 0.02313 , - 0.02552 , 0.017357 , - 0.06393 , - 0.00061 },
12
+ BGEBaseEN : []float32 {0.01140 , 0.03722 , 0.02941 , 0.01230 , 0.03451 },
13
+ BGEBaseENV15 : []float32 {0.01129394 , 0.05493144 , 0.02615099 , 0.00328772 , 0.02996045 },
14
+ BGESmallENV15 : []float32 {0.01522374 , - 0.02271799 , 0.00860278 , - 0.07424029 , 0.00386434 },
15
+ BGESmallZH : []float32 {- 0.01023294 , 0.07634465 , 0.0691722 , - 0.04458365 , - 0.03160762 },
97
16
}
98
17
99
18
for model , expected := range canonicalValues {
@@ -114,10 +33,10 @@ func TestCanonicalValues(T *testing.T) {
114
33
T .Errorf ("Expected result length %v, got %v" , len (input ), len (result ))
115
34
}
116
35
117
- epsilon := float64 (1e-5 )
36
+ epsilon := float64 (1e-4 )
118
37
for i , v := range expected {
119
38
if math .Abs (float64 (result [0 ][i ]- v )) > float64 (epsilon ) {
120
- T .Errorf ("Element %d mismatch: expected %.6f, got %.6f" , i , v , result [0 ][i ])
39
+ T .Errorf ("Element %d mismatch for %s : expected %.6f, got %.6f" , i , model , v , result [0 ][i ])
121
40
}
122
41
}
123
42
}
0 commit comments