Skip to content

Commit

Permalink
feat: move RerankerModelInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
honsunrise committed Jan 13, 2025
1 parent bf44169 commit 4497fae
Show file tree
Hide file tree
Showing 13 changed files with 205 additions and 196 deletions.
2 changes: 1 addition & 1 deletion src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ pub fn load_tokenizer(tokenizer_files: TokenizerFiles, max_length: usize) -> Res
let pad_id = config["pad_token_id"].as_u64().unwrap_or(0) as u32;
let pad_token = tokenizer_config["pad_token"]
.as_str()
.expect("Error reading pad_token from tokenier_config.json")
.expect("Error reading pad_token from tokenizer_config.json")
.into();

let mut tokenizer = tokenizer
Expand Down
6 changes: 4 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ pub use ort::execution_providers::ExecutionProviderDispatch;
pub use crate::common::{
read_file_to_bytes, Embedding, Error, SparseEmbedding, TokenizerFiles, DEFAULT_CACHE_DIR,
};
pub use crate::models::{model_info::ModelInfo, quantization::QuantizationMode};
pub use crate::models::{
model_info::ModelInfo, model_info::RerankerModelInfo, quantization::QuantizationMode,
};
pub use crate::output::{EmbeddingOutput, OutputKey, OutputPrecedence, SingleBatchOutput};
pub use crate::pooling::Pooling;

Expand All @@ -90,7 +92,7 @@ pub use crate::image_embedding::{
pub use crate::models::image_embedding::ImageEmbeddingModel;

// For Reranking
pub use crate::models::reranking::{RerankerModel, RerankerModelInfo};
pub use crate::models::reranking::RerankerModel;
pub use crate::reranking::{
OnnxSource, RerankInitOptions, RerankInitOptionsUserDefined, RerankResult, TextRerank,
UserDefinedRerankingModel,
Expand Down
12 changes: 12 additions & 0 deletions src/models/model_info.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::RerankerModel;

/// Data struct about the available models
#[derive(Debug, Clone)]
pub struct ModelInfo<T> {
Expand All @@ -8,3 +10,13 @@ pub struct ModelInfo<T> {
pub model_file: String,
pub additional_files: Vec<String>,
}

/// Data struct describing one of the available reranker models.
///
/// All fields are public and the type derives `Debug` and `Clone`.
#[derive(Debug, Clone)]
pub struct RerankerModelInfo {
/// The `RerankerModel` variant this entry describes.
pub model: RerankerModel,
/// Free-form description of the model.
pub description: String,
/// NOTE(review): presumably the model's repository identifier — confirm against `reranker_model_list`.
pub model_code: String,
/// NOTE(review): presumably the path of the model file inside that repository — confirm.
pub model_file: String,
/// Names of any further files that accompany `model_file`.
pub additional_files: Vec<String>,
}
12 changes: 2 additions & 10 deletions src/models/reranking.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::fmt::Display;

use crate::RerankerModelInfo;

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RerankerModel {
/// BAAI/bge-reranker-base
Expand Down Expand Up @@ -46,16 +48,6 @@ pub fn reranker_model_list() -> Vec<RerankerModelInfo> {
reranker_model_list
}

/// Data struct describing one of the available reranker models.
///
/// All fields are public and the type derives `Debug` and `Clone`.
#[derive(Debug, Clone)]
pub struct RerankerModelInfo {
/// The `RerankerModel` variant this entry describes.
pub model: RerankerModel,
/// Free-form description of the model.
pub description: String,
/// NOTE(review): presumably the model's repository identifier — confirm against `reranker_model_list`.
pub model_code: String,
/// NOTE(review): presumably the path of the model file inside that repository — confirm.
pub model_file: String,
/// Names of any further files that accompany `model_file`.
pub additional_files: Vec<String>,
}

impl Display for RerankerModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let model_info = reranker_model_list()
Expand Down
45 changes: 1 addition & 44 deletions src/models/sparse.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::fmt::Display;

use crate::{common::SparseEmbedding, ModelInfo};
use ndarray::{ArrayViewD, Axis, CowArray, Dim};
use crate::ModelInfo;

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SparseModel {
Expand All @@ -20,48 +19,6 @@ pub fn models_list() -> Vec<ModelInfo<SparseModel>> {
}]
}

impl SparseModel {
    /// Converts raw model output into one [`SparseEmbedding`] per row of scores.
    ///
    /// For [`SparseModel::SPLADEPPV1`] every activation is mapped through
    /// `ln(1 + max(x, 0))`, multiplied by the attention mask (which first gains a
    /// trailing axis), and max-reduced over axis 1. Only strictly positive scores
    /// are kept, each paired with its position inside the row.
    ///
    /// NOTE(review): assumes `model_output` is 3-D so that it lines up with the
    /// expanded mask — confirm against the caller.
    pub fn post_process(
        &self,
        model_output: &ArrayViewD<f32>,
        attention_mask: &CowArray<i64, Dim<[usize; 2]>>,
    ) -> Vec<SparseEmbedding> {
        match self {
            SparseModel::SPLADEPPV1 => {
                // Apply ReLU and logarithm transformation
                let relu_log = model_output.mapv(|x| (1.0 + x.max(0.0)).ln());

                // Convert to f32 and expand the dimensions
                let attention_mask = attention_mask.mapv(|x| x as f32).insert_axis(Axis(2));

                // Weight the transformed values by the attention mask
                let weighted_log = relu_log * attention_mask;

                // Get the max scores
                let scores = weighted_log.fold_axis(Axis(1), f32::NEG_INFINITY, |r, &v| r.max(v));

                scores
                    .rows()
                    .into_iter()
                    .map(|row_scores| {
                        // A row can contribute at most its own length; `scores.len()`
                        // would reserve room for every row of the batch at once.
                        let row_len = row_scores.len();
                        let mut values: Vec<f32> = Vec::with_capacity(row_len);
                        let mut indices: Vec<usize> = Vec::with_capacity(row_len);

                        row_scores.into_iter().enumerate().for_each(|(idx, f)| {
                            if *f > 0.0 {
                                values.push(*f);
                                indices.push(idx);
                            }
                        });

                        SparseEmbedding { values, indices }
                    })
                    .collect()
            }
        }
    }
}

impl Display for SparseModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let model_info = models_list()
Expand Down
78 changes: 1 addition & 77 deletions src/models/text_embedding.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
use crate::pooling::Pooling;
use std::{collections::HashMap, fmt::Display, sync::OnceLock};

use super::model_info::ModelInfo;

use super::quantization::QuantizationMode;

use std::{collections::HashMap, fmt::Display, sync::OnceLock};

/// Lazy static list of all available models.
static MODEL_MAP: OnceLock<HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>>> = OnceLock::new();

Expand Down Expand Up @@ -338,78 +334,6 @@ pub fn models_list() -> Vec<ModelInfo<EmbeddingModel>> {
models_map().values().cloned().collect()
}

impl EmbeddingModel {
pub fn get_default_pooling_method(&self) -> Option<Pooling> {
match self {
EmbeddingModel::AllMiniLML6V2 => Some(Pooling::Mean),
EmbeddingModel::AllMiniLML6V2Q => Some(Pooling::Mean),
EmbeddingModel::AllMiniLML12V2 => Some(Pooling::Mean),
EmbeddingModel::AllMiniLML12V2Q => Some(Pooling::Mean),

EmbeddingModel::BGEBaseENV15 => Some(Pooling::Cls),
EmbeddingModel::BGEBaseENV15Q => Some(Pooling::Cls),
EmbeddingModel::BGELargeENV15 => Some(Pooling::Cls),
EmbeddingModel::BGELargeENV15Q => Some(Pooling::Cls),
EmbeddingModel::BGESmallENV15 => Some(Pooling::Cls),
EmbeddingModel::BGESmallENV15Q => Some(Pooling::Cls),
EmbeddingModel::BGESmallZHV15 => Some(Pooling::Cls),

EmbeddingModel::NomicEmbedTextV1 => Some(Pooling::Mean),
EmbeddingModel::NomicEmbedTextV15 => Some(Pooling::Mean),
EmbeddingModel::NomicEmbedTextV15Q => Some(Pooling::Mean),

EmbeddingModel::ParaphraseMLMiniLML12V2 => Some(Pooling::Mean),
EmbeddingModel::ParaphraseMLMiniLML12V2Q => Some(Pooling::Mean),
EmbeddingModel::ParaphraseMLMpnetBaseV2 => Some(Pooling::Mean),

EmbeddingModel::MultilingualE5Base => Some(Pooling::Mean),
EmbeddingModel::MultilingualE5Small => Some(Pooling::Mean),
EmbeddingModel::MultilingualE5Large => Some(Pooling::Mean),

EmbeddingModel::MxbaiEmbedLargeV1 => Some(Pooling::Cls),
EmbeddingModel::MxbaiEmbedLargeV1Q => Some(Pooling::Cls),

EmbeddingModel::GTEBaseENV15 => Some(Pooling::Cls),
EmbeddingModel::GTEBaseENV15Q => Some(Pooling::Cls),
EmbeddingModel::GTELargeENV15 => Some(Pooling::Cls),
EmbeddingModel::GTELargeENV15Q => Some(Pooling::Cls),

EmbeddingModel::ClipVitB32 => Some(Pooling::Mean),

EmbeddingModel::JinaEmbeddingsV2BaseCode => Some(Pooling::Mean),
}
}

/// Get the quantization mode of the model.
///
/// Any models with a `Q` suffix in their name are quantized models.
///
/// Currently only 6 supported models have dynamic quantization:
/// - Alibaba-NLP/gte-base-en-v1.5
/// - Alibaba-NLP/gte-large-en-v1.5
/// - mixedbread-ai/mxbai-embed-large-v1
/// - nomic-ai/nomic-embed-text-v1.5
/// - Xenova/all-MiniLM-L12-v2
/// - Xenova/all-MiniLM-L6-v2
///
// TODO: Update this list when more models are added
pub fn get_quantization_mode(&self) -> QuantizationMode {
match self {
EmbeddingModel::AllMiniLML6V2Q => QuantizationMode::Dynamic,
EmbeddingModel::AllMiniLML12V2Q => QuantizationMode::Dynamic,
EmbeddingModel::BGEBaseENV15Q => QuantizationMode::Static,
EmbeddingModel::BGELargeENV15Q => QuantizationMode::Static,
EmbeddingModel::BGESmallENV15Q => QuantizationMode::Static,
EmbeddingModel::NomicEmbedTextV15Q => QuantizationMode::Dynamic,
EmbeddingModel::ParaphraseMLMiniLML12V2Q => QuantizationMode::Static,
EmbeddingModel::MxbaiEmbedLargeV1Q => QuantizationMode::Dynamic,
EmbeddingModel::GTEBaseENV15Q => QuantizationMode::Dynamic,
EmbeddingModel::GTELargeENV15Q => QuantizationMode::Dynamic,
_ => QuantizationMode::None,
}
}
}

impl Display for EmbeddingModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let model_info = get_model_info(self).expect("Model not found.");
Expand Down
2 changes: 1 addition & 1 deletion src/output/embedding_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl SingleBatchOutput<'_, '_> {

// If no pooling is specified, default to cls so as not to break the existing implementations
// TODO: Consider return output as is to support custom model that has built-in pooling layer:
// - [] Add model with built-in pooling to the list of supported model in ``models::text_embdding::models_list``
// - [] Add model with built-in pooling to the list of supported model in ``models::text_embedding::models_list``
// - [] Write unit test for new model
// - [] Update ``pooling::Pooling`` to include None type
// - [] Change the line below to return output as is
Expand Down
10 changes: 5 additions & 5 deletions src/reranking/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl TextRerank {
Ok(Self::new(tokenizer, session))
}

/// Reranks documents using the reranker model and returns the results sorted by score in descending order.
/// Reranks documents using the reranker model and returns the results sorted by score in descending order.
pub fn rerank<S: AsRef<str> + Send + Sync>(
&self,
query: S,
Expand Down Expand Up @@ -151,16 +151,16 @@ impl TextRerank {

let mut ids_array = Vec::with_capacity(max_size);
let mut mask_array = Vec::with_capacity(max_size);
let mut typeids_array = Vec::with_capacity(max_size);
let mut type_ids_array = Vec::with_capacity(max_size);

encodings.iter().for_each(|encoding| {
let ids = encoding.get_ids();
let mask = encoding.get_attention_mask();
let typeids = encoding.get_type_ids();
let type_ids = encoding.get_type_ids();

ids_array.extend(ids.iter().map(|x| *x as i64));
mask_array.extend(mask.iter().map(|x| *x as i64));
typeids_array.extend(typeids.iter().map(|x| *x as i64));
type_ids_array.extend(type_ids.iter().map(|x| *x as i64));
});

let inputs_ids_array =
Expand All @@ -170,7 +170,7 @@ impl TextRerank {
Array::from_shape_vec((batch_size, encoding_length), mask_array)?;

let token_type_ids_array =
Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;
Array::from_shape_vec((batch_size, encoding_length), type_ids_array)?;

let mut session_inputs = ort::inputs![
"input_ids" => Value::from_array(inputs_ids_array)?,
Expand Down
56 changes: 50 additions & 6 deletions src/sparse_text_embedding/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use hf_hub::{
api::sync::{ApiBuilder, ApiRepo},
Cache,
};
use ndarray::{Array, CowArray};
use ndarray::{Array, ArrayViewD, Axis, CowArray, Dim};
use ort::{session::Session, value::Value};
#[cfg_attr(not(feature = "online"), allow(unused_imports))]
use rayon::{iter::ParallelIterator, slice::ParallelSlice};
Expand Down Expand Up @@ -138,19 +138,19 @@ impl SparseTextEmbedding {
// Preallocate arrays with the maximum size
let mut ids_array = Vec::with_capacity(max_size);
let mut mask_array = Vec::with_capacity(max_size);
let mut typeids_array = Vec::with_capacity(max_size);
let mut type_ids_array = Vec::with_capacity(max_size);

// Not using par_iter because the closure needs to be FnMut
encodings.iter().for_each(|encoding| {
let ids = encoding.get_ids();
let mask = encoding.get_attention_mask();
let typeids = encoding.get_type_ids();
let type_ids = encoding.get_type_ids();

// Extend the preallocated arrays with the current encoding
// Requires the closure to be FnMut
ids_array.extend(ids.iter().map(|x| *x as i64));
mask_array.extend(mask.iter().map(|x| *x as i64));
typeids_array.extend(typeids.iter().map(|x| *x as i64));
type_ids_array.extend(type_ids.iter().map(|x| *x as i64));
});

// Create CowArrays from vectors
Expand All @@ -161,7 +161,7 @@ impl SparseTextEmbedding {
let attention_mask_array = CowArray::from(&owned_attention_mask);

let token_type_ids_array =
Array::from_shape_vec((batch_size, encoding_length), typeids_array)?;
Array::from_shape_vec((batch_size, encoding_length), type_ids_array)?;

let mut session_inputs = ort::inputs![
"input_ids" => Value::from_array(inputs_ids_array)?,
Expand All @@ -186,7 +186,11 @@ impl SparseTextEmbedding {

let output_data = outputs[last_hidden_state_key].try_extract_tensor::<f32>()?;

let embeddings = self.model.post_process(&output_data, &attention_mask_array);
let embeddings = SparseTextEmbedding::post_process(
&self.model,
&output_data,
&attention_mask_array,
);

Ok(embeddings)
})
Expand All @@ -197,4 +201,44 @@ impl SparseTextEmbedding {

Ok(output)
}

/// Converts raw model output into one [`SparseEmbedding`] per row of scores.
///
/// For [`SparseModel::SPLADEPPV1`] every activation is mapped through
/// `ln(1 + max(x, 0))`, multiplied by the attention mask (which first gains a
/// trailing axis), and max-reduced over axis 1. Only strictly positive scores
/// are kept, each paired with its position inside the row.
///
/// NOTE(review): assumes `model_output` is 3-D so that it lines up with the
/// expanded mask — confirm against the caller.
fn post_process(
    model_name: &SparseModel,
    model_output: &ArrayViewD<f32>,
    attention_mask: &CowArray<i64, Dim<[usize; 2]>>,
) -> Vec<SparseEmbedding> {
    match model_name {
        SparseModel::SPLADEPPV1 => {
            // Apply ReLU and logarithm transformation
            let relu_log = model_output.mapv(|x| (1.0 + x.max(0.0)).ln());

            // Convert to f32 and expand the dimensions
            let attention_mask = attention_mask.mapv(|x| x as f32).insert_axis(Axis(2));

            // Weight the transformed values by the attention mask
            let weighted_log = relu_log * attention_mask;

            // Get the max scores
            let scores = weighted_log.fold_axis(Axis(1), f32::NEG_INFINITY, |r, &v| r.max(v));

            scores
                .rows()
                .into_iter()
                .map(|row_scores| {
                    // A row can contribute at most its own length; `scores.len()`
                    // would reserve room for every row of the batch at once.
                    let row_len = row_scores.len();
                    let mut values: Vec<f32> = Vec::with_capacity(row_len);
                    let mut indices: Vec<usize> = Vec::with_capacity(row_len);

                    row_scores.into_iter().enumerate().for_each(|(idx, f)| {
                        if *f > 0.0 {
                            values.push(*f);
                            indices.push(idx);
                        }
                    });

                    SparseEmbedding { values, indices }
                })
                .collect()
        }
    }
}
}
Loading

0 comments on commit 4497fae

Please sign in to comment.