From 2d920f7ccc7bed1ed06cdb52e0ef50f96f8100ac Mon Sep 17 00:00:00 2001 From: lloydzhou Date: Sat, 28 Sep 2024 15:05:41 +0800 Subject: [PATCH] using stream: schema to fetch in App --- app/global.d.ts | 1 + app/utils.ts | 41 +--------------- app/utils/stream.ts | 100 ++++++++++++++++++++++++++++++++++++++++ src-tauri/Cargo.lock | 36 +-------------- src-tauri/Cargo.toml | 1 + src-tauri/src/main.rs | 51 ++------------------ src-tauri/src/stream.rs | 96 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 204 insertions(+), 122 deletions(-) create mode 100644 app/utils/stream.ts create mode 100644 src-tauri/src/stream.rs diff --git a/app/global.d.ts b/app/global.d.ts index 8ee636bcd3c..a1453dc33b4 100644 --- a/app/global.d.ts +++ b/app/global.d.ts @@ -12,6 +12,7 @@ declare module "*.svg"; declare interface Window { __TAURI__?: { + convertFileSrc(url: string, protocol?: string): string; writeText(text: string): Promise; invoke(command: string, payload?: Record): Promise; dialog: { diff --git a/app/utils.ts b/app/utils.ts index 5be7bb2d9f7..fbe77c114c5 100644 --- a/app/utils.ts +++ b/app/utils.ts @@ -3,6 +3,7 @@ import { showToast } from "./components/ui-lib"; import Locale from "./locales"; import { RequestMessage } from "./client/api"; import { ServiceProvider } from "./constant"; +import { fetch } from "./utils/stream"; export function trimTopic(topic: string) { // Fix an issue where double quotes still show in the Indonesian language @@ -286,46 +287,6 @@ export function showPlugins(provider: ServiceProvider, model: string) { return false; } -export function fetch( - url: string, - options?: Record, -): Promise { - if (window.__TAURI__) { - const tauriUri = window.__TAURI__.convertFileSrc(url, "sse"); - return window.fetch(tauriUri, options).then((r) => { - // 1. create response, - // TODO using event to get status and statusText and headers - const { status, statusText } = r; - const { readable, writable } = new TransformStream(); - const res = new Response(readable, { status, statusText }); - // 2. call fetch_read_body multi times, and write to Response.body - const writer = writable.getWriter(); - let unlisten; - window.__TAURI__.event - .listen("sse-response", (e) => { - const { id, payload } = e; - console.log("event", id, payload); - writer.ready.then(() => { - if (payload !== 0) { - writer.write(new Uint8Array(payload)); - } else { - writer.releaseLock(); - writable.close(); - unlisten && unlisten(); - } - }); - }) - .then((u) => (unlisten = u)); - return res; - }); - } - return window.fetch(url, options); -} - -if (undefined !== window) { - window.tauriFetch = fetch; -} - export function adapter(config: Record) { const { baseURL, url, params, ...rest } = config; const path = baseURL ? `${baseURL}${url}` : url; diff --git a/app/utils/stream.ts b/app/utils/stream.ts new file mode 100644 index 00000000000..8f9ccfbaa1d --- /dev/null +++ b/app/utils/stream.ts @@ -0,0 +1,100 @@ +// using tauri register_uri_scheme_protocol, register `stream:` protocol +// see src-tauri/src/stream.rs, and src-tauri/src/main.rs +// 1. window.fetch(`stream://localhost/${fetchUrl}`), get request_id +// 2. listen event: `stream-response` multi times to get response headers and body + +type ResponseEvent = { + id: number; + payload: { + request_id: number; + status?: number; + error?: string; + name?: string; + value?: string; + chunk?: number[]; + }; +}; + +export function fetch(url: string, options?: RequestInit): Promise { + if (window.__TAURI__) { + const tauriUri = window.__TAURI__.convertFileSrc(url, "stream"); + const { signal, ...rest } = options || {}; + return window + .fetch(tauriUri, rest) + .then((r) => r.text()) + .then((rid) => parseInt(rid)) + .then((request_id: number) => { + // 1. using event to get status and statusText and headers, and resolve it + let resolve: Function | undefined; + let reject: Function | undefined; + let status: number; + let writable: WritableStream | undefined; + let writer: WritableStreamDefaultWriter | undefined; + const headers = new Headers(); + let unlisten: Function | undefined; + + if (signal) { + signal.addEventListener("abort", () => { + // Reject the promise with the abort reason. + unlisten && unlisten(); + reject && reject(signal.reason); + }); + } + // @ts-ignore 2. listen response multi times, and write to Response.body + window.__TAURI__.event + .listen("stream-response", (e: ResponseEvent) => { + const { id, payload } = e; + const { + request_id: rid, + status: _status, + name, + value, + error, + chunk, + } = payload; + if (request_id != rid) { + return; + } + /** + * 1. get status code + * 2. get headers + * 3. start get body, then resolve response + * 4. get body chunk + */ + if (error) { + unlisten && unlisten(); + return reject && reject(error); + } else if (_status) { + status = _status; + } else if (name && value) { + headers.append(name, value); + } else if (chunk) { + if (resolve) { + const ts = new TransformStream(); + writable = ts.writable; + writer = writable.getWriter(); + resolve(new Response(ts.readable, { status, headers })); + resolve = undefined; + } + writer && + writer.ready.then(() => { + writer && writer.write(new Uint8Array(chunk)); + }); + } else if (_status === 0) { + // end of body + unlisten && unlisten(); + writer && + writer.ready.then(() => { + writer && writer.releaseLock(); + writable && writable.close(); + }); + } + }) + .then((u: Function) => (unlisten = u)); + return new Promise( + (_resolve, _reject) => ([resolve, reject] = [_resolve, _reject]), + ); + }); + } + return window.fetch(url, options); +} diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index fcc06d163cd..c9baffc0acc 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -1986,6 +1986,7 @@ checksum = "e4a24736216ec316047a1fc4252e27dabb04218aa4a3f37c6e7ddbf1f9782b54" name = "nextchat" version = "0.1.0" dependencies = [ + "bytes", "futures-util", "percent-encoding", "reqwest", @@ -2216,17 +2217,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "os_info" -version = "3.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae99c7fa6dd38c7cafe1ec085e804f8f555a2f8659b0dbe03f1f9963a9b51092" -dependencies = [ - "log", - "serde", - "windows-sys 0.52.0", -] - [[package]] name = "overload" version = "0.1.1" @@ -3251,19 +3241,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "sys-locale" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8a11bd9c338fdba09f7881ab41551932ad42e405f61d01e8406baea71c07aee" -dependencies = [ - "js-sys", - "libc", - "wasm-bindgen", - "web-sys", - "windows-sys 0.45.0", -] - [[package]] name = "system-configuration" version = "0.5.1" @@ -3412,7 +3389,6 @@ dependencies = [ "objc", "once_cell", "open", - "os_info", "percent-encoding", "rand 0.8.5", "raw-window-handle", @@ -3425,7 +3401,6 @@ dependencies = [ "serde_repr", "serialize-to-javascript", "state", - "sys-locale", "tar", "tauri-macros", "tauri-runtime", @@ -4345,15 +4320,6 @@ dependencies = [ "windows-targets 0.48.0", ] -[[package]] -name = "windows-sys" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" -dependencies = [ - "windows-targets 0.52.0", -] - [[package]] name = "windows-targets" version = "0.42.2" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 31ecfd83e4d..c954deb72a8 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -41,6 +41,7 @@ tauri-plugin-window-state = { git = "https://github.com/tauri-apps/plugins-works percent-encoding = "2.3.1" reqwest = "0.11.18" futures-util = "0.3.30" +bytes = "1.7.2" [features] # this feature is used for production builds or when `devPath` points to the filesystem and the built-in dev server is disabled. diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index 792c656cf51..e382082572f 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -1,57 +1,14 @@ // Prevents additional console window on Windows in release, DO NOT REMOVE!! #![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] -use futures_util::{StreamExt}; -use reqwest::Client; -use tauri::{ Manager}; -use tauri::http::{ResponseBuilder}; +mod stream; fn main() { tauri::Builder::default() .plugin(tauri_plugin_window_state::Builder::default().build()) - .register_uri_scheme_protocol("sse", |app_handle, request| { - let path = request.uri().strip_prefix("sse://localhost/").unwrap(); - let path = percent_encoding::percent_decode(path.as_bytes()) - .decode_utf8_lossy() - .to_string(); - // println!("path : {}", path); - let client = Client::new(); - let window = app_handle.get_window("main").unwrap(); - // send http request - let body = reqwest::Body::from(request.body().clone()); - let response_future = client.request(request.method().clone(), path) - .headers(request.headers().clone()) - .body(body).send(); - - // get response and emit to client - tauri::async_runtime::spawn(async move { - let res = response_future.await; - - match res { - Ok(res) => { - let mut stream = res.bytes_stream(); - - while let Some(chunk) = stream.next().await { - match chunk { - Ok(bytes) => { - window.emit("sse-response", bytes).unwrap(); - } - Err(err) => { - println!("Error: {:?}", err); - } - } - } - window.emit("sse-response", 0).unwrap(); - } - Err(err) => { - println!("Error: {:?}", err); - } - } - }); - ResponseBuilder::new() - .header("Access-Control-Allow-Origin", "*") - .status(200).body("OK".into()) - }) + .register_uri_scheme_protocol("stream", move |app_handle, request| { + stream::stream(app_handle, request) + }) .run(tauri::generate_context!()) .expect("error while running tauri application"); } diff --git a/src-tauri/src/stream.rs b/src-tauri/src/stream.rs new file mode 100644 index 00000000000..5e84e0f00d1 --- /dev/null +++ b/src-tauri/src/stream.rs @@ -0,0 +1,96 @@ + +use std::error::Error; +use futures_util::{StreamExt}; +use reqwest::Client; +use tauri::{ Manager, AppHandle }; +use tauri::http::{Request, ResponseBuilder}; +use tauri::http::Response; + +static mut REQUEST_COUNTER: u32 = 0; + +#[derive(Clone, serde::Serialize)] +pub struct ErrorPayload { + request_id: u32, + error: String, +} + +#[derive(Clone, serde::Serialize)] +pub struct StatusPayload { + request_id: u32, + status: u16, +} + +#[derive(Clone, serde::Serialize)] +pub struct HeaderPayload { + request_id: u32, + name: String, + value: String, +} + +#[derive(Clone, serde::Serialize)] +pub struct ChunkPayload { + request_id: u32, + chunk: bytes::Bytes, +} + +pub fn stream(app_handle: &AppHandle, request: &Request) -> Result> { + let mut request_id = 0; + let event_name = "stream-response"; + unsafe { + REQUEST_COUNTER += 1; + request_id = REQUEST_COUNTER; + } + let path = request.uri().to_string().replace("stream://localhost/", "").replace("http://stream.localhost/", ""); + let path = percent_encoding::percent_decode(path.as_bytes()) + .decode_utf8_lossy() + .to_string(); + // println!("path : {}", path); + let client = Client::new(); + let handle = app_handle.app_handle(); + // send http request + let body = reqwest::Body::from(request.body().clone()); + let response_future = client.request(request.method().clone(), path) + .headers(request.headers().clone()) + .body(body).send(); + + // get response and emit to client + tauri::async_runtime::spawn(async move { + let res = response_future.await; + + match res { + Ok(res) => { + handle.emit_all(event_name, StatusPayload{ request_id, status: res.status().as_u16() }).unwrap(); + for (name, value) in res.headers() { + handle.emit_all(event_name, HeaderPayload { + request_id, + name: name.to_string(), + value: std::str::from_utf8(value.as_bytes()).unwrap().to_string() + }).unwrap(); + } + let mut stream = res.bytes_stream(); + + while let Some(chunk) = stream.next().await { + match chunk { + Ok(bytes) => { + handle.emit_all(event_name, ChunkPayload{ request_id, chunk: bytes }).unwrap(); + } + Err(err) => { + println!("Error: {:?}", err); + } + } + } + handle.emit_all(event_name, StatusPayload { request_id, status: 0 }).unwrap(); + } + Err(err) => { + println!("Error: {:?}", err.source().expect("REASON").to_string()); + handle.emit_all(event_name, ErrorPayload { + request_id, + error: err.source().expect("REASON").to_string() + }).unwrap(); + } + } + }); + return ResponseBuilder::new() + .header("Access-Control-Allow-Origin", "*") + .status(200).body(request_id.to_string().into()) +}