// mistralrs_core/vision_models/llava/llava_llm/mod.rs
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use candle_core::{DType, Device, Result, Tensor};

use crate::pipeline::{
    text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
    IsqModel, NormalModel,
};
9pub(crate) trait LLaVALLM: IsqModel + NormalModel + Sync + Send {
10 fn embed(&self, input_ids: &Tensor) -> Result<Tensor>;
12 #[allow(clippy::too_many_arguments)]
13 fn forward_input_embed(
14 &self,
15 input_ids: &Tensor, input_embed: Tensor, seqlen_offsets: &[usize],
18 context_lens: Vec<(usize, usize)>,
19 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
20 flash_params: &FlashParams,
21 ) -> Result<Tensor>;
22}
24#[derive(Debug)]
25pub(crate) struct OrdinaryRoPE;
26
27impl OrdinaryRoPE {
28 fn create_parameters(
29 n_elem: usize,
30 max_seq_len: usize,
31 rope_theta: f32,
32 dtype: DType,
33 device: &Device,
34 ) -> Result<(Tensor, Tensor)> {
35 let theta: Vec<_> = (0..n_elem)
36 .step_by(2)
37 .map(|i| 1f32 / rope_theta.powf(i as f32 / n_elem as f32))
38 .collect();
39 let theta = Tensor::new(theta.as_slice(), device)?;
40 let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
41 .to_dtype(DType::F32)?
42 .reshape((max_seq_len, 1))?
43 .matmul(&theta.reshape((1, theta.elem_count()))?)?;
44 let cos = idx_theta.cos()?.to_dtype(dtype)?;
45 let sin = idx_theta.sin()?.to_dtype(dtype)?;
46 Result::Ok((cos, sin))
47 }
48 fn forward(x: &Tensor, index_pos: usize, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
49 let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?;
50 let cos = cos.narrow(0, index_pos, seq_len)?;
51 let sin = sin.narrow(0, index_pos, seq_len)?;
52 candle_nn::rotary_emb::rope(x, &cos, &sin)
53 }
54}
pub(crate) mod llama;
pub(crate) mod mistral;

pub use llama::Llama;
pub use mistral::Model as Mistral;