Skip to content

Commit 17009c6

Browse files
committed
research embeddings
1 parent 1727728 commit 17009c6

File tree

5 files changed

+233
-15
lines changed

5 files changed

+233
-15
lines changed

Diff for: package-lock.json

+59-14
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -216,4 +216,4 @@
216216
"@xenova/transformers": "^2.17.1",
217217
"langchain": "^0.1.17"
218218
}
219-
}
219+
}

Diff for: src/extension.ts

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import { tokenizer } from "./common/prompt/tokenizer";
1111
import { login } from "./common/auth";
1212
import { secretsStorage } from "./common/utils/secretStore";
1313
import { getSuppabaseClient } from "./common/auth/supabaseClient";
14+
import { startTest } from "./test";
1415

1516
export async function activate(context: vscode.ExtensionContext) {
1617
FirecoderTelemetrySenderInstance.init(context);
@@ -40,6 +41,7 @@ export async function activate(context: vscode.ExtensionContext) {
4041

4142
statusBar.init(context);
4243
await tokenizer.init();
44+
startTest();
4345

4446
context.subscriptions.push(
4547
vscode.commands.registerCommand(

Diff for: src/hft.ts

+122
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
2+
import { chunkArray } from "@langchain/core/utils/chunk_array";
3+
import { getSaveFolder } from "./common/download/utils";
4+
5+
export interface HuggingFaceTransformersEmbeddingsParams
6+
extends EmbeddingsParams {
7+
/** Model name to use */
8+
modelName: string;
9+
10+
/**
11+
* Timeout to use when making requests to OpenAI.
12+
*/
13+
timeout?: number;
14+
15+
/**
16+
* The maximum number of documents to embed in a single request.
17+
*/
18+
batchSize?: number;
19+
20+
/**
21+
* Whether to strip new lines from the input text. This is recommended by
22+
* OpenAI, but may not be suitable for all use cases.
23+
*/
24+
stripNewLines?: boolean;
25+
}
26+
27+
/**
28+
* @example
29+
* ```typescript
30+
* const model = new HuggingFaceTransformersEmbeddings({
31+
* modelName: "Xenova/all-MiniLM-L6-v2",
32+
* });
33+
*
34+
* // Embed a single query
35+
* const res = await model.embedQuery(
36+
* "What would be a good company name for a company that makes colorful socks?"
37+
* );
38+
* console.log({ res });
39+
*
40+
* // Embed multiple documents
41+
* const documentRes = await model.embedDocuments(["Hello world", "Bye bye"]);
42+
* console.log({ documentRes });
43+
* ```
44+
*/
45+
export class HuggingFaceTransformersEmbeddingsLocal
46+
extends Embeddings
47+
implements HuggingFaceTransformersEmbeddingsParams
48+
{
49+
modelName = "Xenova/all-MiniLM-L6-v2";
50+
51+
batchSize = 1;
52+
53+
stripNewLines = true;
54+
55+
timeout?: number;
56+
57+
private pipelinePromise?: Promise<any>;
58+
59+
constructor(fields?: Partial<HuggingFaceTransformersEmbeddingsParams>) {
60+
super(fields ?? {});
61+
62+
this.modelName = fields?.modelName ?? this.modelName;
63+
this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines;
64+
this.timeout = fields?.timeout;
65+
}
66+
67+
async embedDocuments(texts: string[]): Promise<number[][]> {
68+
const batches = chunkArray(
69+
this.stripNewLines ? texts.map((t) => t.replace(/\n/g, " ")) : texts,
70+
this.batchSize
71+
);
72+
73+
const batchRequests = batches.map((batch) => this.runEmbedding(batch));
74+
const batchResponses = await Promise.all(batchRequests);
75+
const embeddings: number[][] = [];
76+
77+
for (let i = 0; i < batchResponses.length; i += 1) {
78+
const batchResponse = batchResponses[i];
79+
for (let j = 0; j < batchResponse.length; j += 1) {
80+
embeddings.push(batchResponse[j]);
81+
}
82+
}
83+
84+
return embeddings;
85+
}
86+
87+
async embedQuery(text: string): Promise<number[]> {
88+
const data = await this.runEmbedding([
89+
this.stripNewLines ? text.replace(/\n/g, " ") : text,
90+
]);
91+
return data[0];
92+
}
93+
94+
private async runEmbedding(texts: string[]) {
95+
return this.caller.call(async () => {
96+
try {
97+
const { pipeline } = await import("@xenova/transformers");
98+
99+
const pipe = await (this.pipelinePromise ??= pipeline(
100+
"feature-extraction",
101+
this.modelName,
102+
{
103+
cache_dir: await getSaveFolder(),
104+
quantized: true,
105+
}
106+
));
107+
108+
// cls 0.537 place 1
109+
// mean 0.601 place 13
110+
// none shit
111+
const output = await pipe(texts, {
112+
pooling: "cls",
113+
normalize: true,
114+
});
115+
return output.tolist();
116+
} catch (error) {
117+
console.log(error);
118+
debugger;
119+
}
120+
});
121+
}
122+
}

Diff for: src/test.ts

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import { HuggingFaceTransformersEmbeddingsLocal } from "./hft";
2+
import { DirectoryLoader } from "langchain/document_loaders/fs/directory";
3+
import { TextLoader } from "langchain/document_loaders/fs/text";
4+
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
5+
import { MemoryVectorStore } from "langchain/vectorstores/memory";
6+
7+
export const startTest = async () => {
8+
const loader = new DirectoryLoader(
9+
"/home/gespispace/helper/helper-coder/src",
10+
{
11+
".ts": (path) => new TextLoader(path),
12+
}
13+
);
14+
15+
const docs = await loader.load();
16+
17+
const splitter = RecursiveCharacterTextSplitter.fromLanguage("js", {
18+
chunkSize: 4000,
19+
chunkOverlap: 0,
20+
});
21+
const jsOutput = await splitter.splitDocuments(docs);
22+
const vectorStore = await MemoryVectorStore.fromDocuments(
23+
jsOutput,
24+
new HuggingFaceTransformersEmbeddingsLocal({
25+
// modelName: "jinaai/jina-embeddings-v2-base-code",
26+
modelName: "Xenova/bge-m3",
27+
maxConcurrency: 1,
28+
})
29+
);
30+
31+
const resultOne = await vectorStore.similaritySearchWithScore(
32+
"what properties do we send with each events to telemetry?",
33+
20
34+
);
35+
console.log(resultOne);
36+
// const model = new HuggingFaceTransformersEmbeddingsLocal({
37+
// // modelName: "jinaai/jina-embeddings-v2-base-code",
38+
// modelName: "Xenova/bge-m3",
39+
// });
40+
41+
// /* Embed queries */
42+
// const res = await model.embedQuery(
43+
// "What would be a good company name for a company that makes colorful socks?"
44+
// );
45+
// console.log({ res });
46+
// /* Embed documents */
47+
// const documentRes = await model.embedDocuments(["Hello world", "Bye bye"]);
48+
// console.log({ documentRes });
49+
};

0 commit comments

Comments
 (0)