Skip to content

Commit 7fe735e

Browse files
authored
feat: New FlagEmbedding models (#6)
* feat: new embedding models * test: canonical value checks update
1 parent 62e65cc commit 7fe735e

File tree

3 files changed

+31
-91
lines changed

3 files changed

+31
-91
lines changed

README.md

+3
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,11 @@ The default embedding supports "query" and "passage" prefixes for the input text
2525

2626
## 🤖 Models
2727

28+
- [**BAAI/bge-base-en**](https://huggingface.co/BAAI/bge-base-en)
2829
- [**BAAI/bge-base-en-v1.5**](https://huggingface.co/BAAI/bge-base-en-v1.5)
30+
- [**BAAI/bge-small-en**](https://huggingface.co/BAAI/bge-small-en)
2931
- [**BAAI/bge-small-en-v1.5**](https://huggingface.co/BAAI/bge-small-en-v1.5) - Default
32+
- [**BAAI/bge-small-zh-v1.5**](https://huggingface.co/BAAI/bge-small-zh-v1.5)
3033
- [**sentence-transformers/all-MiniLM-L6-v2**](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
3134

3235
## 🚀 Installation

fastembed.go

+20-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ type EmbeddingModel string
2626
const (
2727
AllMiniLML6V2 EmbeddingModel = "fast-all-MiniLM-L6-v2"
2828
BGEBaseEN EmbeddingModel = "fast-bge-base-en"
29+
BGEBaseENV15 EmbeddingModel = "fast-bge-base-en-v1.5"
2930
BGESmallEN EmbeddingModel = "fast-bge-small-en"
31+
BGESmallENV15 EmbeddingModel = "fast-bge-small-en-v1.5"
32+
BGESmallZH EmbeddingModel = "fast-bge-small-zh-v1.5"
3033

3134
// A model with type "Unigram" is not yet supported by the tokenizer
3235
// Ref: https://github.com/sugarme/tokenizer/blob/448e79b1ed65947b8c6343bf9aa39e78364f45c8/pretrained/model.go#L152
@@ -79,7 +82,7 @@ func NewFlagEmbedding(options *InitOptions) (*FlagEmbedding, error) {
7982
}
8083

8184
if options.Model == "" {
82-
options.Model = BGESmallEN
85+
options.Model = BGESmallENV15
8386
}
8487

8588
if options.MaxLength == 0 {
@@ -281,10 +284,25 @@ func ListSupportedModels() []ModelInfo {
281284
Dim: 768,
282285
Description: "Base English model",
283286
},
287+
{
288+
Model: BGEBaseENV15,
289+
Dim: 768,
290+
Description: "v1.5 release of the base English model",
291+
},
284292
{
285293
Model: BGESmallEN,
286294
Dim: 384,
287-
Description: "Fast and Default English model",
295+
Description: "Fast English model",
296+
},
297+
{
298+
Model: BGESmallENV15,
299+
Dim: 384,
300+
Description: "Fast, default English model",
301+
},
302+
{
303+
Model: BGESmallZH,
304+
Dim: 512,
305+
Description: "Fast Chinese model",
288306
},
289307
// {
290308
// Model: MLE5Large,

fastembed_test.go

+8-89
Original file line numberDiff line numberDiff line change
@@ -5,95 +5,14 @@ import (
55
"testing"
66
)
77

8-
func TestEmbedBGEBaseEN(t *testing.T) {
9-
// Test with a single input
10-
fe, err := NewFlagEmbedding(&InitOptions{
11-
Model: BGEBaseEN,
12-
})
13-
defer fe.Destroy()
14-
if err != nil {
15-
t.Fatalf("Expected no error, got %v", err)
16-
}
17-
input := []string{"hello world"}
18-
result, err := fe.Embed(input, 1)
19-
if err != nil {
20-
t.Fatalf("Expected no error, got %v", err)
21-
}
22-
23-
if len(result) != len(input) {
24-
t.Errorf("Expected result length %v, got %v", len(input), len(result))
25-
}
26-
}
27-
28-
func TestEmbedAllMiniLML6V2(t *testing.T) {
29-
// Test with a single input
30-
fe, err := NewFlagEmbedding(&InitOptions{
31-
Model: AllMiniLML6V2,
32-
})
33-
defer fe.Destroy()
34-
if err != nil {
35-
t.Fatalf("Expected no error, got %v", err)
36-
}
37-
input := []string{"hello world"}
38-
result, err := fe.Embed(input, 1)
39-
if err != nil {
40-
t.Fatalf("Expected no error, got %v", err)
41-
}
42-
43-
if len(result) != len(input) {
44-
t.Errorf("Expected result length %v, got %v", len(input), len(result))
45-
}
46-
}
47-
48-
func TestEmbedBGESmallEN(t *testing.T) {
49-
// Test with a single input
50-
fe, err := NewFlagEmbedding(&InitOptions{
51-
Model: BGESmallEN,
52-
})
53-
defer fe.Destroy()
54-
if err != nil {
55-
t.Fatalf("Expected no error, got %v", err)
56-
}
57-
input := []string{"hello world"}
58-
result, err := fe.Embed(input, 1)
59-
if err != nil {
60-
t.Fatalf("Expected no error, got %v", err)
61-
}
62-
63-
if len(result) != len(input) {
64-
t.Errorf("Expected result length %v, got %v", len(input), len(result))
65-
}
66-
}
67-
68-
// A model type "Unigram" is not yet supported by the tokenizer
69-
// Ref: https://github.com/sugarme/tokenizer/blob/448e79b1ed65947b8c6343bf9aa39e78364f45c8/pretrained/model.go#L152
70-
// func TestEmbedMLE5Large(t *testing.T) {
71-
// // Test with a single input
72-
// show := false
73-
// fe, err := NewFlagEmbedding(&InitOptions{
74-
// Model: MLE5Large,
75-
// ShowDownloadProgress: &show,
76-
// })
77-
// defer fe.Destroy()
78-
// if err != nil {
79-
// t.Fatalf("Expected no error, got %v", err)
80-
// }
81-
// input := []string{"hello world"}
82-
// result, err := fe.Embed(input, 1)
83-
// if err != nil {
84-
// t.Fatalf("Expected no error, got %v", err)
85-
// }
86-
87-
// if len(result) != len(input) {
88-
// t.Errorf("Expected result length %v, got %v", len(input), len(result))
89-
// }
90-
// }
91-
928
func TestCanonicalValues(T *testing.T) {
939
canonicalValues := map[EmbeddingModel]([]float32){
94-
AllMiniLML6V2: []float32{0.02591, 0.00573, 0.01147, 0.03796, -0.02328, -0.05493, 0.014040, -0.01079, -0.02440, -0.01822},
95-
BGESmallEN: []float32{-0.02313, -0.02552, 0.017357, -0.06393, -0.00061, 0.02212, -0.01472, 0.03925, 0.03444, 0.00459},
96-
BGEBaseEN: []float32{0.01140, 0.03722, 0.02941, 0.01230, 0.03451, 0.00876, 0.02356, 0.05414, -0.02945, -0.05472},
10+
AllMiniLML6V2: []float32{0.02591, 0.00573, 0.01147, 0.03796, -0.02328},
11+
BGESmallEN: []float32{-0.02313, -0.02552, 0.017357, -0.06393, -0.00061},
12+
BGEBaseEN: []float32{0.01140, 0.03722, 0.02941, 0.01230, 0.03451},
13+
BGEBaseENV15: []float32{0.01129394, 0.05493144, 0.02615099, 0.00328772, 0.02996045},
14+
BGESmallENV15: []float32{0.01522374, -0.02271799, 0.00860278, -0.07424029, 0.00386434},
15+
BGESmallZH: []float32{-0.01023294, 0.07634465, 0.0691722, -0.04458365, -0.03160762},
9716
}
9817

9918
for model, expected := range canonicalValues {
@@ -114,10 +33,10 @@ func TestCanonicalValues(T *testing.T) {
11433
T.Errorf("Expected result length %v, got %v", len(input), len(result))
11534
}
11635

117-
epsilon := float64(1e-5)
36+
epsilon := float64(1e-4)
11837
for i, v := range expected {
11938
if math.Abs(float64(result[0][i]-v)) > float64(epsilon) {
120-
T.Errorf("Element %d mismatch: expected %.6f, got %.6f", i, v, result[0][i])
39+
T.Errorf("Element %d mismatch for %s: expected %.6f, got %.6f", i, model, v, result[0][i])
12140
}
12241
}
12342
}

0 commit comments

Comments
 (0)