diff --git a/async-openai/src/config.rs b/async-openai/src/config.rs index 91b3699a..230d2b10 100644 --- a/async-openai/src/config.rs +++ b/async-openai/src/config.rs @@ -1,5 +1,7 @@ //! Client configurations: [OpenAIConfig] for OpenAI, [AzureConfig] for Azure OpenAI Service. -use reqwest::header::{HeaderMap, AUTHORIZATION}; +use std::collections::HashMap; + +use reqwest::header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION}; use secrecy::{ExposeSecret, Secret}; use serde::Deserialize; @@ -31,10 +33,20 @@ pub trait Config: Clone { pub struct OpenAIConfig { api_base: String, api_key: Secret<String>, + #[serde(deserialize_with = "deserialize_header_map")] + headers: HashMap<String, String>, org_id: String, project_id: String, } +fn deserialize_header_map<'de, D>(deserializer: D) -> Result<HashMap<String, String>, D::Error> +where + D: serde::Deserializer<'de>, +{ + let header_map: HashMap<String, String> = HashMap::deserialize(deserializer)?; + Ok(header_map) +} + impl Default for OpenAIConfig { fn default() -> Self { Self { @@ -42,6 +54,7 @@ impl Default for OpenAIConfig { api_key: std::env::var("OPENAI_API_KEY") .unwrap_or_else(|_| "".to_string()) .into(), + headers: HashMap::new(), org_id: Default::default(), project_id: Default::default(), } @@ -78,6 +91,12 @@ impl OpenAIConfig { self } + /// Add custom headers to the existing headers + pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self { + self.headers.extend(headers); + self + } + pub fn org_id(&self) -> &str { &self.org_id } @@ -112,6 +131,13 @@ impl Config for OpenAIConfig { // Calls to the Assistants API require that you pass a Beta header headers.insert(OPENAI_BETA_HEADER, "assistants=v2".parse().unwrap()); + headers.extend(self.headers.iter().map(|(k, v)| { + ( + HeaderName::from_bytes(k.as_bytes()).unwrap(), + HeaderValue::from_str(v).unwrap(), + ) + })); + headers }