background.ts
export {};
import { getDB, getUrlId, lockAndRunPglite, deletePagesOlderThan, storeEmbeddings, urlIsPresentOrInDatetimeRange } from "~db";
import { pipeline, env, type PipelineType } from "@xenova/transformers";
import { MODEL_TYPE, type Chunk } from "~lib/chunk";
// IMPORTANT: see this issue https://github.com/microsoft/onnxruntime/issues/14445#issuecomment-1625861446
env.backends.onnx.wasm.numThreads = 1;
// Lazily instantiates the transformers.js feature-extraction pipeline once per
// worker. The promise itself is cached, so concurrent getInstance() calls await
// the same load instead of creating duplicate pipelines.
class PipelineSingleton {
  // TODO: is it possible to ensure there's only a single instance of the
  // transformers Pipeline, and that each worker reuses the same Pipeline?
  static task = "feature-extraction" as PipelineType;
  static model = MODEL_TYPE;
  static instance: Promise<any> | null = null;

  static async getInstance(progress_callback?: (progress: any) => void) {
    if (this.instance === null) {
      this.instance = pipeline(this.task, this.model, {
        progress_callback,
        // dtype: "fp32",
        // device: !!navigator.gpu ? "webgpu" : "wasm"
      });
    }
    return this.instance;
  }
}
const getLLMPipeline = async () => {
  return await PipelineSingleton.getInstance((x) => {
    console.log("Progress update", x);
  });
};
// Embeds a single text chunk: mean-pools the token embeddings and L2-normalizes
// them, yielding a flat numeric vector whose length depends on MODEL_TYPE.
// The parameter is named `extractor` so it does not shadow the `pipeline` import.
const runInference = async (extractor: any, chunk: string) => {
  const output = await extractor(chunk, {
    pooling: "mean",
    normalize: true,
  });
  const embedding = Array.from(output.data) as number[];
  return embedding;
};
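
// Usage sketch, using only names defined above:
//   const extractor = await getLLMPipeline();
//   const vector = await runInference(extractor, "some text"); // number[]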
const deleteIfUrlIsExpired = async ({ db }: { db: any }) => {
  return await deletePagesOlderThan(db);
};
chrome.runtime.onMessage.addListener((message, sender, sendResponse) => {
  // Check the URL: if it already exists, check its age; if it is older than a
  // day, delete the old entries, then chunk the new content and store fresh embeddings.
  if (message.type === "process_page") {
    const { url, textChunks }: { url: string, textChunks: Chunk[] } = message;
    (async () => {
      const urlIsPresent = await lockAndRunPglite(urlIsPresentOrInDatetimeRange, { url });
      if (urlIsPresent) {
        // URL exists OR is still within its valid window. Do not process; just respond.
        sendResponse({ ok: true, error: null, msg: `URL ${url} has been processed before. Skipping...` });
      } else {
        const urlId = await lockAndRunPglite(getUrlId, { url });
        const extractor = await getLLMPipeline();
        if (textChunks.length === 0) {
          sendResponse({ ok: false, error: null, msg: `No chunks sent.` });
        } else {
          navigator.locks.request("pglite", async () => {
            const pg = await getDB();
            for (const chunk of textChunks) {
              const embedding = await runInference(extractor, chunk.content);
              await storeEmbeddings(pg, urlId, chunk, embedding);
            }
            sendResponse({ ok: true, error: null, msg: `Done processing ${textChunks.length} chunks.` });
          });
        }
      }
    })();
  } else if (message.type === "get_embedding") {
    (async () => {
      const { chunk } = message;
      const extractor = await getLLMPipeline();
      const embedding = await runInference(extractor, chunk);
      sendResponse({ ok: true, error: null, embedding });
    })();
  } else if (message.type === "clean_up") {
    (async () => {
      await lockAndRunPglite(deleteIfUrlIsExpired, {});
      sendResponse({ ok: true, error: null });
    })();
  }
  // IMPORTANT: returning true keeps the message channel open so the async
  // handlers above can call sendResponse later.
  return true;
});
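
// For reference, a hypothetical caller (e.g. a content script, not part of this
// file) would talk to this background worker roughly like so. This is a sketch;
// only the message shapes, which match the handlers above, are taken from this file:
//
//   chrome.runtime.sendMessage(
//     { type: "process_page", url: location.href, textChunks },
//     (response) => console.log(response.msg),
//   );
//
//   chrome.runtime.sendMessage(
//     { type: "get_embedding", chunk: "some query text" },
//     (response) => console.log(response.embedding),
//   );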