mistralrs_core/vision_models/llava/llava_llm/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use candle_core::{DType, Device, Result, Tensor};

use crate::pipeline::{
    text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
    IsqModel, NormalModel,
};
/// A normal model (without AnyMoE) extended with `embed` and
/// `forward_input_embed`. This is only a temporary solution: once the RoPE
/// problem is solved for the normal LLM models, this should be refactored.
pub(crate) trait LLaVALLM: IsqModel + NormalModel + Sync + Send {
    fn embed(&self, input_ids: &Tensor) -> Result<Tensor>;
    #[allow(clippy::too_many_arguments)]
    fn forward_input_embed(
        &self,
        input_ids: &Tensor,  // used only to build the attention mask
        input_embed: Tensor, // taken by value so the caller does not have to clone
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor>;
}
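
// A minimal sketch of how a caller might drive this trait (hypothetical
// `model`, `merge_image_embeds`, and inputs; the real call sites live in the
// LLaVA pipelines): embed the text tokens, splice the projected image features
// into the embedding sequence, then decode from the merged embeddings.
//
//     let text_embeds = model.embed(&input_ids)?;
//     let input_embeds = merge_image_embeds(&text_embeds, &image_embeds)?;
//     let logits = model.forward_input_embed(
//         &input_ids,
//         input_embeds,
//         &seqlen_offsets,
//         context_lens,
//         None, // no PagedAttention KV cache metadata
//         &flash_params,
//     )?;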

/// Plain (non-scaled) rotary position embedding (RoPE) helpers shared by the
/// LLaVA text backbones.
#[derive(Debug)]
pub(crate) struct OrdinaryRoPE;

impl OrdinaryRoPE {
    /// Precomputes the RoPE cos/sin tables.
    ///
    /// The inverse frequencies are `1 / rope_theta^(2i / n_elem)` for
    /// `i = 0, 1, ..., n_elem / 2 - 1`, and the returned tensors have shape
    /// `(max_seq_len, n_elem / 2)`: row `t` holds `cos(t * f_i)` and
    /// `sin(t * f_i)` for each frequency `f_i`.
    fn create_parameters(
        n_elem: usize,
        max_seq_len: usize,
        rope_theta: f32,
        dtype: DType,
        device: &Device,
    ) -> Result<(Tensor, Tensor)> {
        let theta: Vec<_> = (0..n_elem)
            .step_by(2)
            .map(|i| 1f32 / rope_theta.powf(i as f32 / n_elem as f32))
            .collect();
        let theta = Tensor::new(theta.as_slice(), device)?;
        let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?
            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
        let cos = idx_theta.cos()?.to_dtype(dtype)?;
        let sin = idx_theta.sin()?.to_dtype(dtype)?;
        Ok((cos, sin))
    }

    /// Applies RoPE to `x`, expected to have shape
    /// `(b_sz, n_heads, seq_len, head_dim)`, starting at position `index_pos`
    /// in the precomputed tables.
    fn forward(x: &Tensor, index_pos: usize, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
        let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?;
        let cos = cos.narrow(0, index_pos, seq_len)?;
        let sin = sin.narrow(0, index_pos, seq_len)?;
        candle_nn::rotary_emb::rope(x, &cos, &sin)
    }
}

pub(crate) mod llama;
pub(crate) mod mistral;

pub use llama::Llama;
pub use mistral::Model as Mistral;
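
// A minimal sanity check of the RoPE helpers above -- an illustrative sketch,
// not part of the original module: it asserts that the cos/sin tables have
// shape (max_seq_len, n_elem / 2) and that `forward` preserves the input
// shape. The concrete sizes below are arbitrary.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ordinary_rope_shapes() -> Result<()> {
        let device = Device::Cpu;
        let (n_elem, max_seq_len) = (8, 16);
        let (cos, sin) =
            OrdinaryRoPE::create_parameters(n_elem, max_seq_len, 10_000.0, DType::F32, &device)?;
        assert_eq!(cos.dims(), &[max_seq_len, n_elem / 2]);
        assert_eq!(sin.dims(), &[max_seq_len, n_elem / 2]);

        // (batch, n_heads, seq_len, head_dim); `rope` expects a contiguous input.
        let x = Tensor::randn(0f32, 1f32, (1, 2, 4, n_elem), &device)?;
        let y = OrdinaryRoPE::forward(&x, 0, &cos, &sin)?;
        assert_eq!(y.dims(), x.dims());
        Ok(())
    }
}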