From da4a136c83b35ad681bb3715d0ace830b0de5d2e Mon Sep 17 00:00:00 2001 From: pelikhan Date: Thu, 23 Jan 2025 21:35:00 -0800 Subject: [PATCH] towards audio --- packages/core/src/chat.ts | 34 +++++++++++++++---- packages/core/src/chattypes.ts | 3 ++ packages/core/src/expander.ts | 12 +++++-- packages/core/src/promptdom.ts | 48 +++++++++++++++++++++++++++ packages/core/src/runpromptcontext.ts | 2 ++ packages/vscode/src/lmaccess.ts | 5 +++ 6 files changed, 95 insertions(+), 9 deletions(-) diff --git a/packages/core/src/chat.ts b/packages/core/src/chat.ts index 76c76508e9..441efa4154 100644 --- a/packages/core/src/chat.ts +++ b/packages/core/src/chat.ts @@ -1,6 +1,11 @@ // cspell: disable import { MarkdownTrace, TraceOptions } from "./trace" -import { PromptImage, PromptPrediction, renderPromptNode } from "./promptdom" +import { + PromptAudio, + PromptImage, + PromptPrediction, + renderPromptNode, +} from "./promptdom" import { host, runtimeHost } from "./host" import { GenerationOptions } from "./generation" import { dispose } from "./dispose" @@ -46,6 +51,7 @@ import { parseModelIdentifier, traceLanguageModelConnection } from "./models" import { ChatCompletionAssistantMessageParam, ChatCompletionContentPartImage, + ChatCompletionContentPartInputAudio, ChatCompletionMessageParam, ChatCompletionResponse, ChatCompletionsOptions, @@ -95,9 +101,11 @@ import { deleteUndefinedValues } from "./cleaners" export function toChatCompletionUserMessage( expanded: string, - images?: PromptImage[] + images?: PromptImage[], + audios?: PromptAudio[] ): ChatCompletionUserMessageParam { const imgs = images?.filter(({ url }) => url) || [] + const auds = audios?.filter(({ data }) => data) || [] if (imgs.length) return { role: "user", @@ -108,13 +116,23 @@ export function toChatCompletionUserMessage( }, ...imgs.map( ({ url, detail }) => - { + ({ type: "image_url", image_url: { url, detail, }, - } + }) satisfies ChatCompletionContentPartImage + ), + ...auds.map( + ({ data, format }) => + ({ 
+ type: "input_audio", + input_audio: { + data, + format, + }, + }) satisfies ChatCompletionContentPartInputAudio ), ], } @@ -135,9 +153,11 @@ export type ChatCompletionHandler = ( export type ListModelsFunction = ( cfg: LanguageModelConfiguration, options: TraceOptions & CancellationOptions -) => Promise +) => Promise< + ResponseStatus & { + models?: LanguageModelInfo[] + } +> export type PullModelFunction = ( cfg: LanguageModelConfiguration, diff --git a/packages/core/src/chattypes.ts b/packages/core/src/chattypes.ts index 2b652b4efe..a3fefc2606 100644 --- a/packages/core/src/chattypes.ts +++ b/packages/core/src/chattypes.ts @@ -100,6 +100,9 @@ export type ChatCompletionUserMessageParam = export type ChatCompletionContentPartImage = OpenAI.Chat.Completions.ChatCompletionContentPartImage +export type ChatCompletionContentPartInputAudio = + OpenAI.Chat.Completions.ChatCompletionContentPartInputAudio + // Parameters for creating embeddings export type EmbeddingCreateParams = OpenAI.Embeddings.EmbeddingCreateParams diff --git a/packages/core/src/expander.ts b/packages/core/src/expander.ts index 9b863cc2a1..685d7355fe 100644 --- a/packages/core/src/expander.ts +++ b/packages/core/src/expander.ts @@ -10,6 +10,7 @@ import { } from "./constants" import { finalizeMessages, + PromptAudio, PromptImage, PromptPrediction, renderPromptNode, @@ -50,6 +51,7 @@ export async function callExpander( let logs = "" let messages: ChatCompletionMessageParam[] = [] let images: PromptImage[] = [] + let audios: PromptAudio[] = [] let schemas: Record = {} let functions: ToolCallback[] = [] let fileMerges: FileMergeHandler[] = [] @@ -82,6 +84,7 @@ export async function callExpander( const { messages: msgs, images: imgs, + audios: auds, errors, schemas: schs, functions: fns, @@ -98,6 +101,7 @@ export async function callExpander( }) messages = msgs images = imgs + audios = auds schemas = schs functions = fns fileMerges = fms @@ -136,6 +140,7 @@ export async function callExpander( statusText, 
messages, images, + audios, schemas, functions: Object.freeze(functions), fileMerges, @@ -247,6 +252,7 @@ export async function expandTemplate( const { status, statusText, messages } = prompt const images = prompt.images.slice(0) + const audios = prompt.audios.slice(0) const schemas = structuredClone(prompt.schemas) const tools = prompt.functions.slice(0) const fileMerges = prompt.fileMerges.slice(0) @@ -279,8 +285,8 @@ export async function expandTemplate( } } - if (prompt.images?.length) - messages.push(toChatCompletionUserMessage("", prompt.images)) + if (images?.length || audios?.length) + messages.push(toChatCompletionUserMessage("", images, audios)) if (prompt.aici) messages.push(prompt.aici) const addSystemMessage = (content: string) => { @@ -314,6 +320,7 @@ export async function expandTemplate( const sysr = await callExpander(prj, system, env, trace, options) if (sysr.images) images.push(...sysr.images) + if (sysr.audios) audios.push(...sysr.audios) if (sysr.schemas) Object.assign(schemas, sysr.schemas) if (sysr.functions) tools.push(...sysr.functions) if (sysr.fileMerges) fileMerges.push(...sysr.fileMerges) @@ -394,6 +401,7 @@ ${schemaTs} cache, messages, images, + audios, schemas, tools, status: status, diff --git a/packages/core/src/promptdom.ts b/packages/core/src/promptdom.ts index 6fe89040e4..bd7e645d3f 100644 --- a/packages/core/src/promptdom.ts +++ b/packages/core/src/promptdom.ts @@ -45,6 +45,7 @@ export interface PromptNode extends ContextExpansionOptions { type?: | "text" | "image" + | "audio" | "schema" | "tool" | "fileMerge" @@ -150,6 +151,18 @@ export interface PromptImageNode extends PromptNode { resolved?: PromptImage // Resolved image information } +export interface PromptAudio { + filename?: string + data: string + format: "mp3" | "wav" +} + +export interface PromptAudioNode extends PromptNode { + type: "audio" + value: Awaitable // Audio information + resolved?: PromptAudio // Resolved audio information +} + // Interface for a schema 
node. export interface PromptSchemaNode extends PromptNode { type: "schema" @@ -418,6 +431,15 @@ export function createImageNode( return { type: "image", value, ...(options || {}) } } +// Function to create an audio node. +export function createAudioNode( + value: Awaitable, + options?: ContextExpansionOptions +): PromptAudioNode { + assert(value !== undefined) + return { type: "audio", value, ...(options || {}) } +} + // Function to create a schema node. export function createSchemaNode( name: string, @@ -556,6 +578,7 @@ export interface PromptNodeVisitor { def?: (node: PromptDefNode) => Awaitable // Definition node visitor defData?: (node: PromptDefDataNode) => Awaitable // Definition data node visitor image?: (node: PromptImageNode) => Awaitable // Image node visitor + audio?: (node: PromptAudioNode) => Awaitable // Audio node visitor schema?: (node: PromptSchemaNode) => Awaitable // Schema node visitor tool?: (node: PromptToolNode) => Awaitable // Function node visitor fileMerge?: (node: PromptFileMergeNode) => Awaitable // File merge node visitor @@ -585,6 +608,9 @@ export async function visitNode(node: PromptNode, visitor: PromptNodeVisitor) { case "image": await visitor.image?.(node as PromptImageNode) break + case "audio": + await visitor.audio?.(node as PromptAudioNode) + break case "schema": await visitor.schema?.(node as PromptSchemaNode) break @@ -632,6 +658,7 @@ export async function visitNode(node: PromptNode, visitor: PromptNodeVisitor) { // Interface for representing a rendered prompt node. export interface PromptNodeRender { images: PromptImage[] // Images included in the prompt + audios: PromptAudio[] errors: unknown[] // Errors encountered during rendering schemas: Record // Schemas included in the prompt functions: ToolCallback[] // Functions included in the prompt @@ -847,6 +874,15 @@ async function resolvePromptNode( n.error = e } }, + audio: async (n) => { + try { + const v = await n.value + n.resolved = v + n.preview = n.resolved ? `