#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use std::any::Any;
use std::num::NonZeroUsize;
use std::sync::Arc;
use candle_core::Result;
use candle_core::{DType, Device, Tensor};
use image::GenericImageView;
use image::Rgb;
use itertools::Itertools;
use regex_automata::meta::Regex;
use tokenizers::Tokenizer;
use tracing::warn;
use super::llava15::LLaVAVisionSpecificArgs;
use super::utils::{expand2square, LLaVAImageProcessor};
use crate::device_map::DeviceMapper;
use crate::pipeline::text_models_inputs_processor::{
get_completion_input, get_prompt_input, PagedAttentionMeta,
};
use crate::pipeline::{
text_models_inputs_processor, InputProcessorOutput, InputsProcessor, InputsProcessorType,
MessagesAction, Processor,
};
use crate::sequence::Sequence;
use crate::vision_models::image_processor::{self, ImagePreProcessor, PreprocessedImages};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::preprocessor_config::{PreProcessorConfig, ToFilter};
use crate::vision_models::{preprocessor_config, ModelInputs};
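/// Chat-template processor for LLaVA models. Wraps the shared
/// [`LLaVAInputProcessor`] and flattens multimodal messages to plain text
/// before templating (see [`Self::template_action`]).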
pub struct LLaVAProcessor {
inputs_processor: Arc<LLaVAInputProcessor>,
}
impl Processor for LLaVAProcessor {
fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
self.inputs_processor.clone()
}
fn get_special_tokens(&self) -> &[&'static str] {
&[]
}
fn template_action(&self) -> MessagesAction {
MessagesAction::FlattenOnlyText
}
}
impl LLaVAProcessor {
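/// Builds the processor from the model's JSON configuration; the `<image>`
/// regex is compiled once here and reused later to split prompts around image
/// placeholders.
///
/// A minimal usage sketch (the `config_json` string is assumed to hold the
/// model's serialized `LLaVAConfig`):
/// ```ignore
/// let processor = LLaVAProcessor::new(&config_json);
/// let inputs_processor = processor.inputs_processor();
/// ```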
pub fn new(config: &str) -> Self {
let model_config =
serde_json::from_str::<LLaVAConfig>(config).expect("Failed to parse model config.");
let image_tag_splitter = Regex::new(r"<image>").expect("Failed to compile split regex.");
let inputs_processor = Arc::new(LLaVAInputProcessor {
image_tag_splitter,
model_config: model_config.clone(),
});
Self { inputs_processor }
}
}
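/// Inputs processor that expands each `<image>` tag into a run of per-patch
/// placeholder tokens and produces the pixel values fed to the vision tower.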
pub struct LLaVAInputProcessor {
image_tag_splitter: Regex,
model_config: LLaVAConfig,
}
impl LLaVAInputProcessor {
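/// Number of image tokens contributed by one image: (image_size / patch_size)^2
/// patches, e.g. (336 / 14)^2 = 576 for the common LLaVA 1.5 configuration.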
fn get_num_image_tokens(&self) -> usize {
let patch_size = self.model_config.vision_config.patch_size;
let patch_per_side = self.model_config.vision_config.image_size / patch_size;
patch_per_side * patch_per_side
}
}
impl InputsProcessor for LLaVAInputProcessor {
fn get_type(&self) -> InputsProcessorType {
InputsProcessorType::Vision
}
fn process_inputs(
&self,
tokenizer: Option<Arc<Tokenizer>>,
input_seqs: &mut [&mut Sequence],
is_prompt: bool,
is_xlora: bool,
device: &Device,
no_kv_cache: bool,
last_n_context_len: Option<(usize, usize)>,
return_raw_logits: bool,
other_config: Option<Arc<dyn Any>>,
mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
prompt_batchsize: Option<NonZeroUsize>,
_mapper: Option<&dyn DeviceMapper>,
) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
if is_xlora {
return Box::new(std::iter::once(Err(anyhow::Error::msg(
"Cannot make inputs for X-LoRA vision model.",
))));
}
if no_kv_cache {
return Box::new(std::iter::once(Err(anyhow::Error::msg(
"Vision model must have kv cache.",
))));
}
if prompt_batchsize.is_some() {
warn!("`prompt_batchsize` is set. LLaVA does not support prompt batching.");
}
let Some(tokenizer) = tokenizer else {
return Box::new(std::iter::once(Err(anyhow::Error::msg(
"LLaVAInputProcessor requires a specified tokenizer.",
))));
};
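// The preprocessor config is passed through `other_config` as a type-erased Arc.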
let config = other_config
.clone()
.expect("Need a PreProcessorConfig.");
let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");
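// Take the image path only if every sequence in the batch carries at least one image.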
let has_images = input_seqs
.iter()
.all(|seq| seq.images().is_some_and(|images| !images.is_empty()));
let (pixel_values, num_img_tokens) = if has_images {
let mut pixel_values_accum = Vec::new();
let mut num_img_tokens_accum = Vec::new();
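// Preprocess each sequence's images, accumulating pixel values and per-image token counts.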
for seq in input_seqs.iter_mut() {
let imgs = seq
.take_images()
.expect("Need to have images by this point.");
let PreprocessedImages {
pixel_values,
pixel_attention_mask: _,
image_sizes: _,
num_img_tokens,
aspect_ratio_ids: _,
aspect_ratio_mask: _,
num_tiles: _,
image_grid_thw: _,
video_grid_thw: _,
rows: _,
cols: _,
} = self
.preprocess(
imgs.clone(),
vec![],
config,
device,
(usize::MAX, usize::MAX),
)
.expect("Preprocessor failed");
pixel_values_accum.push(pixel_values);
num_img_tokens_accum.push(num_img_tokens.unwrap());
}
(
Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
Some(num_img_tokens_accum),
)
} else {
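// No images in this batch: defer to the plain-text inputs processor and rewrap its output as vision `ModelInputs`.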
return Box::new(
text_models_inputs_processor::TextInputsProcessor
.process_inputs(
Some(tokenizer),
input_seqs,
is_prompt,
is_xlora,
device,
no_kv_cache,
last_n_context_len,
return_raw_logits,
other_config,
paged_attn_metadata,
None, None,
)
.map(|metadata| {
let InputProcessorOutput {
inputs,
seq_indices,
} = metadata?;
let text_models_inputs_processor::ModelInputs {
input_ids,
input_ids_full: _,
seqlen_offsets,
seqlen_offsets_full: _,
seqlen_offsets_kernel,
seqlen_offsets_kernel_full: _,
context_lens,
position_ids,
paged_attn_meta,
flash_meta,
flash_meta_full: _,
} = *inputs
.downcast::<text_models_inputs_processor::ModelInputs>()
.expect("Downcast failed.");
let inputs: Box<dyn Any> = Box::new(ModelInputs {
input_ids,
seqlen_offsets,
seqlen_offsets_kernel,
context_lens,
position_ids,
pixel_values: None,
model_specific_args: Box::new(LLaVAVisionSpecificArgs {}),
paged_attn_meta,
flash_meta,
});
Ok(InputProcessorOutput {
inputs,
seq_indices,
})
}),
);
};
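// Rebuild each prompt: detokenize, split on `<image>`, re-tokenize the text chunks, and splice image placeholders in between.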
let mut toks = Vec::new();
let detokenized = tokenizer
.decode_batch(
&input_seqs
.iter()
.map(|seq| seq.get_toks())
.collect::<Vec<_>>(),
false,
)
.expect("Decoding failed");
for (detokenized, (seq, num_img_tokens)) in detokenized.into_iter().zip(
input_seqs
.iter_mut()
.zip(num_img_tokens.unwrap().into_iter()),
) {
let splits = self
.image_tag_splitter
.split(&detokenized)
.map(|span| &detokenized[span.range()])
.collect::<Vec<_>>();
let prompt_chunks = splits
.iter()
.map(|s| {
tokenizer
.encode(*s, true)
.unwrap()
.get_ids()
.iter()
.map(|x| *x as i64)
.collect()
})
.collect::<Vec<Vec<_>>>();
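// One placeholder run per image; the first slot carries -(image index + 1) as a sentinel marking where that image's tokens start.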
let mut image_ids_pad = Vec::new();
for (i, num_img_token) in num_img_tokens.iter().enumerate() {
let mut image_id_pad = vec![0; *num_img_token];
image_id_pad[0] = -(i as i64 + 1);
image_ids_pad.push(image_id_pad);
}
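// Interleave the text chunks with the placeholder runs, preserving the original `<image>` positions.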
let mut input_ids: Vec<i64> = Vec::new();
for item in prompt_chunks
.iter()
.map(|x| x.to_vec())
.interleave(image_ids_pad)
{
input_ids.extend(item);
}
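// Store the tokens back on the sequence as u32, clamping the negative image sentinels to 0.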
seq.set_toks(
input_ids
.iter()
.map(|x| if *x < 0 { 0u32 } else { *x as u32 })
.collect::<Vec<_>>(),
);
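// The prompt length changed, so re-allocate this sequence's paged-attention blocks.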
if let Some(ref mut metadata) = paged_attn_metadata {
metadata.block_engine.free_sequence(*seq.id());
metadata.block_engine.allocate(*seq);
}
toks.push(input_ids);
}
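// Build the usual prompt or completion inputs from the expanded token ids.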
let iter = if is_prompt {
get_prompt_input(
toks,
input_seqs,
device,
last_n_context_len,
return_raw_logits,
paged_attn_metadata.as_mut(),
None, None,
)
} else {
get_completion_input(
toks,
input_seqs,
device,
no_kv_cache,
last_n_context_len,
return_raw_logits,
paged_attn_metadata.as_mut(),
None, None,
)
};
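// Attach the batched pixel values and LLaVA-specific args to each produced set of text inputs.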
Box::new(iter.into_iter().map(move |metadata| {
let text_models_inputs_processor::InnerInputProcessorOutput {
inputs:
text_models_inputs_processor::InputMetadata {
input,
positions,
positions_kernel,
context_lens,
position_ids,
paged_attn_meta,
flash_meta,
},
seq_indices,
} = metadata?;
let inputs: Box<dyn Any> = Box::new(ModelInputs {
input_ids: input,
seqlen_offsets: positions,
seqlen_offsets_kernel: positions_kernel,
context_lens,
position_ids,
pixel_values: pixel_values.clone(),
model_specific_args: Box::new(LLaVAVisionSpecificArgs {}),
paged_attn_meta,
flash_meta,
});
Ok(InputProcessorOutput {
inputs,
seq_indices,
})
}))
}
}
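// CLIP-style preprocessing: pad to a square with the per-channel mean color, resize to the configured edge, and normalize with the CLIP mean/std below.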
impl ImagePreProcessor for LLaVAInputProcessor {
#[allow(clippy::excessive_precision)]
const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
#[allow(clippy::excessive_precision)]
const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];
fn preprocess(
&self,
images: Vec<image::DynamicImage>,
videos: Vec<Vec<image::DynamicImage>>,
config: &preprocessor_config::PreProcessorConfig,
device: &candle_core::Device,
(_, _): (usize, usize),
) -> candle_core::Result<image_processor::PreprocessedImages> {
if images.len() > 1 {
candle_core::bail!("Can only process one image per batch");
}
assert!(videos.is_empty());
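// The resize target comes from the preprocessor config's `size.shortest_edge`.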
let resized_size = *config.size.as_ref().unwrap().get("shortest_edge").unwrap() as usize;
let original_size = images[0].dimensions();
let filter = config.resampling.to_filter()?;
let image_mean = config
.image_mean
.unwrap_or(Self::DEFAULT_MEAN)
.map(|x| x as f32);
let mean_color = image_mean
.iter()
.map(|x| ((*x) * 255.0) as u8)
.collect::<Vec<u8>>();
let mean_color = Rgb::from([mean_color[0], mean_color[1], mean_color[2]]);
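// Pad to a square with the mean color so the padded region normalizes to zero.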
let image = expand2square(&images[0], mean_color);
let image_std = config
.image_std
.unwrap_or(Self::DEFAULT_STD)
.map(|x| x as f32);
let pixel_values = [image]
.iter()
.map(|x| {
LLaVAImageProcessor::process_one_image(
x,
config,
resized_size as u32,
filter,
DType::BF16,
device,
&image_mean,
&image_std,
)
})
.collect::<Result<Vec<Tensor>>>()?;
let pixel_values = Tensor::stack(&pixel_values, 0)?;
Ok(image_processor::PreprocessedImages {
pixel_values,
pixel_attention_mask: None,
image_sizes: Some((original_size.0 as usize, original_size.1 as usize)),
num_img_tokens: Some(vec![self.get_num_image_tokens()]),
aspect_ratio_ids: None,
aspect_ratio_mask: None,
num_tiles: None,
image_grid_thw: None,
video_grid_thw: None,
rows: None,
cols: None,
})
}
}