-
Notifications
You must be signed in to change notification settings - Fork 957
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add PaliGemma. * PaliGemma inference loop. * Running PaliGemma example. * Tweak the prompt.
- Loading branch information
1 parent
0ebb388
commit 2f49e1b
Showing
5 changed files
with
434 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# PaliGemma | ||
|
||
[HuggingFace Model Card](https://huggingface.co/google/paligemma-3b-pt-224) - | ||
[Model Page](https://ai.google.dev/gemma/docs/paligemma) | ||
|
||
```bash | ||
cargo run --features cuda --release --example paligemma -- \ | ||
--prompt "caption fr" --image candle-examples/examples/yolo-v8/assets/bike.jpg | ||
``` | ||
|
||
``` | ||
loaded image with shape Tensor[dims 1, 3, 224, 224; bf16, cuda:0] | ||
loaded the model in 1.267744448s | ||
caption fr. Un groupe de cyclistes qui sont dans la rue. | ||
13 tokens generated (56.52 token/s) | ||
``` | ||
|
||
```bash | ||
cargo run --features cuda --release --example paligemma -- \ | ||
--prompt "caption fr" --image candle-examples/examples/flux/assets/flux-robot.jpg | ||
``` | ||
|
||
``` | ||
loaded image with shape Tensor[dims 1, 3, 224, 224; bf16, cuda:0] | ||
loaded the model in 1.271492621s | ||
caption fr une image d' un robot sur la plage avec le mot rouillé | ||
15 tokens generated (62.78 token/s) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,276 @@ | ||
#[cfg(feature = "mkl")] | ||
extern crate intel_mkl_src; | ||
|
||
#[cfg(feature = "accelerate")] | ||
extern crate accelerate_src; | ||
|
||
use anyhow::{Error as E, Result}; | ||
use clap::Parser; | ||
|
||
use candle_transformers::models::paligemma::{Config, Model}; | ||
|
||
use candle::{DType, Device, Tensor}; | ||
use candle_examples::token_output_stream::TokenOutputStream; | ||
use candle_nn::VarBuilder; | ||
use candle_transformers::generation::LogitsProcessor; | ||
use hf_hub::{api::sync::Api, Repo, RepoType}; | ||
use tokenizers::Tokenizer; | ||
|
||
struct TextGeneration { | ||
model: Model, | ||
image: Tensor, | ||
device: Device, | ||
tokenizer: TokenOutputStream, | ||
logits_processor: LogitsProcessor, | ||
repeat_penalty: f32, | ||
repeat_last_n: usize, | ||
} | ||
|
||
impl TextGeneration { | ||
#[allow(clippy::too_many_arguments)] | ||
fn new( | ||
model: Model, | ||
image: Tensor, | ||
tokenizer: Tokenizer, | ||
seed: u64, | ||
temp: Option<f64>, | ||
top_p: Option<f64>, | ||
repeat_penalty: f32, | ||
repeat_last_n: usize, | ||
device: &Device, | ||
) -> Self { | ||
let logits_processor = LogitsProcessor::new(seed, temp, top_p); | ||
Self { | ||
model, | ||
image, | ||
tokenizer: TokenOutputStream::new(tokenizer), | ||
logits_processor, | ||
repeat_penalty, | ||
repeat_last_n, | ||
device: device.clone(), | ||
} | ||
} | ||
|
||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> { | ||
use std::io::Write; | ||
self.tokenizer.clear(); | ||
let mut tokens = self | ||
.tokenizer | ||
.tokenizer() | ||
.encode(prompt, true) | ||
.map_err(E::msg)? | ||
.get_ids() | ||
.to_vec(); | ||
for &t in tokens.iter() { | ||
if let Some(t) = self.tokenizer.next_token(t)? { | ||
print!("{t}") | ||
} | ||
} | ||
std::io::stdout().flush()?; | ||
|
||
let mut generated_tokens = 0usize; | ||
let eos_token = match self.tokenizer.get_token("<eos>") { | ||
Some(token) => token, | ||
None => anyhow::bail!("cannot find the <eos> token"), | ||
}; | ||
let start_gen = std::time::Instant::now(); | ||
for index in 0..sample_len { | ||
let context_size = if index > 0 { 1 } else { tokens.len() }; | ||
let start_pos = tokens.len().saturating_sub(context_size); | ||
let ctxt = &tokens[start_pos..]; | ||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; | ||
let logits = if index > 0 { | ||
self.model.forward(&input)? | ||
} else { | ||
self.model.setup(&self.image, &input)? | ||
}; | ||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; | ||
let logits = if self.repeat_penalty == 1. { | ||
logits | ||
} else { | ||
let start_at = tokens.len().saturating_sub(self.repeat_last_n); | ||
candle_transformers::utils::apply_repeat_penalty( | ||
&logits, | ||
self.repeat_penalty, | ||
&tokens[start_at..], | ||
)? | ||
}; | ||
|
||
let next_token = self.logits_processor.sample(&logits)?; | ||
tokens.push(next_token); | ||
generated_tokens += 1; | ||
if next_token == eos_token { | ||
break; | ||
} | ||
if let Some(t) = self.tokenizer.next_token(next_token)? { | ||
print!("{t}"); | ||
std::io::stdout().flush()?; | ||
} | ||
} | ||
let dt = start_gen.elapsed(); | ||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? { | ||
print!("{rest}"); | ||
} | ||
std::io::stdout().flush()?; | ||
println!( | ||
"\n{generated_tokens} tokens generated ({:.2} token/s)", | ||
generated_tokens as f64 / dt.as_secs_f64(), | ||
); | ||
Ok(()) | ||
} | ||
} | ||
|
||
#[derive(Parser, Debug)] | ||
#[command(author, version, about, long_about = None)] | ||
struct Args { | ||
/// Run on CPU rather than on GPU. | ||
#[arg(long)] | ||
cpu: bool, | ||
|
||
/// Enable tracing (generates a trace-timestamp.json file). | ||
#[arg(long)] | ||
tracing: bool, | ||
|
||
#[arg(long)] | ||
prompt: String, | ||
|
||
/// The temperature used to generate samples. | ||
#[arg(long)] | ||
temperature: Option<f64>, | ||
|
||
/// Nucleus sampling probability cutoff. | ||
#[arg(long)] | ||
top_p: Option<f64>, | ||
|
||
/// The seed to use when generating random samples. | ||
#[arg(long, default_value_t = 299792458)] | ||
seed: u64, | ||
|
||
/// The length of the sample to generate (in tokens). | ||
#[arg(long, short = 'n', default_value_t = 10000)] | ||
sample_len: usize, | ||
|
||
#[arg(long)] | ||
model_id: Option<String>, | ||
|
||
#[arg(long, default_value = "main")] | ||
revision: String, | ||
|
||
#[arg(long)] | ||
tokenizer_file: Option<String>, | ||
|
||
#[arg(long)] | ||
weight_files: Option<String>, | ||
|
||
/// Penalty to be applied for repeating tokens, 1. means no penalty. | ||
#[arg(long, default_value_t = 1.1)] | ||
repeat_penalty: f32, | ||
|
||
/// The context size to consider for the repeat penalty. | ||
#[arg(long, default_value_t = 64)] | ||
repeat_last_n: usize, | ||
|
||
#[arg(long)] | ||
image: String, | ||
} | ||
|
||
fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> { | ||
let img = image::ImageReader::open(path)?.decode()?; | ||
let (height, width) = (image_size, image_size); | ||
let img = img.resize_to_fill( | ||
width as u32, | ||
height as u32, | ||
image::imageops::FilterType::Triangle, | ||
); | ||
let img = img.to_rgb8(); | ||
let img = img.into_raw(); | ||
let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)? | ||
.permute((2, 0, 1))? | ||
.to_dtype(DType::F32)? | ||
.affine(2. / 255., -1.)?; | ||
Ok(img) | ||
} | ||
|
||
fn main() -> Result<()> { | ||
use tracing_chrome::ChromeLayerBuilder; | ||
use tracing_subscriber::prelude::*; | ||
|
||
let args = Args::parse(); | ||
let _guard = if args.tracing { | ||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); | ||
tracing_subscriber::registry().with(chrome_layer).init(); | ||
Some(guard) | ||
} else { | ||
None | ||
}; | ||
println!( | ||
"avx: {}, neon: {}, simd128: {}, f16c: {}", | ||
candle::utils::with_avx(), | ||
candle::utils::with_neon(), | ||
candle::utils::with_simd128(), | ||
candle::utils::with_f16c() | ||
); | ||
println!( | ||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", | ||
args.temperature.unwrap_or(0.), | ||
args.repeat_penalty, | ||
args.repeat_last_n | ||
); | ||
|
||
let start = std::time::Instant::now(); | ||
let api = Api::new()?; | ||
let model_id = match &args.model_id { | ||
Some(model_id) => model_id.to_string(), | ||
None => "google/paligemma-3b-mix-224".to_string(), | ||
}; | ||
let repo = api.repo(Repo::with_revision( | ||
model_id, | ||
RepoType::Model, | ||
args.revision, | ||
)); | ||
let tokenizer_filename = match args.tokenizer_file { | ||
Some(file) => std::path::PathBuf::from(file), | ||
None => repo.get("tokenizer.json")?, | ||
}; | ||
let filenames = match args.weight_files { | ||
Some(files) => files | ||
.split(',') | ||
.map(std::path::PathBuf::from) | ||
.collect::<Vec<_>>(), | ||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, | ||
}; | ||
println!("retrieved the files in {:?}", start.elapsed()); | ||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; | ||
|
||
let device = candle_examples::device(args.cpu)?; | ||
let dtype = if device.is_cuda() { | ||
DType::BF16 | ||
} else { | ||
DType::F32 | ||
}; | ||
let config = Config::paligemma_3b_224(); | ||
let image = load_image(&args.image, config.vision_config.image_size)? | ||
.to_device(&device)? | ||
.to_dtype(dtype)? | ||
.unsqueeze(0)?; | ||
println!("loaded image with shape {:?}", image); | ||
let start = std::time::Instant::now(); | ||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; | ||
let model = Model::new(&config, vb)?; | ||
println!("loaded the model in {:?}", start.elapsed()); | ||
|
||
let mut pipeline = TextGeneration::new( | ||
model, | ||
image, | ||
tokenizer, | ||
args.seed, | ||
args.temperature, | ||
args.top_p, | ||
args.repeat_penalty, | ||
args.repeat_last_n, | ||
&device, | ||
); | ||
let prompt = format!("{}\n", args.prompt); | ||
pipeline.run(&prompt, args.sample_len)?; | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.