mistralrs_core/diffusion_models/
processor.rs

1use std::{any::Any, sync::Arc};
2
3use anyhow::{Context, Result};
4use candle_core::Device;
5use indexmap::IndexMap;
6use tokenizers::Tokenizer;
7
8use crate::{
9    device_map::DeviceMapper,
10    pipeline::{
11        text_models_inputs_processor::PagedAttentionMeta, InputProcessorOutput, InputsProcessor,
12        InputsProcessorType, MessagesAction, Processor,
13    },
14    sequence::Sequence,
15    MessageContent, Pipeline,
16};
17
18use super::DiffusionGenerationParams;
19
/// Stateless [`Processor`] for diffusion pipelines.
///
/// It does not handle chat messages; its role is to hand out a
/// [`DiffusionInputsProcessor`] that turns sequences into model inputs.
pub struct DiffusionProcessor;
21
22impl Processor for DiffusionProcessor {
23    fn process(
24        &self,
25        _pipeline: &dyn Pipeline,
26        _messages: Vec<IndexMap<String, MessageContent>>,
27        _add_generation_prompt: bool,
28        _add_special_tokens: bool,
29        _enable_thinking: Option<bool>,
30        _tools: Vec<crate::Tool>,
31    ) -> Result<(Vec<u32>, String)> {
32        anyhow::bail!(
33            "DiffusionProcessor::process should not be used. It does not expect chat messages."
34        )
35    }
36    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
37        Arc::new(DiffusionInputsProcessor)
38    }
39    fn get_special_tokens(&self) -> &[&'static str] {
40        &[]
41    }
42    fn template_action(&self) -> MessagesAction {
43        // Just a default
44        MessagesAction::FlattenOnlyText
45    }
46}
47
/// Stateless [`InputsProcessor`] that packages the prompts and generation
/// parameters of a batch of sequences into [`ModelInputs`].
pub struct DiffusionInputsProcessor;
49
50#[derive(Clone)]
51pub struct ModelInputs {
52    pub(crate) prompts: Vec<String>,
53    pub(crate) params: DiffusionGenerationParams,
54}
55
56impl InputsProcessor for DiffusionInputsProcessor {
57    fn get_type(&self) -> InputsProcessorType {
58        InputsProcessorType::Text
59    }
60
61    fn process_inputs(
62        &self,
63        _tokenizer: Option<Arc<Tokenizer>>,
64        input_seqs: &mut [&mut Sequence],
65        _is_prompt: bool,
66        _is_xlora: bool,
67        _device: &Device,
68        _no_kv_cache: bool,
69        _last_n_context_len: Option<(usize, usize)>,
70        _return_raw_logits: bool,
71        _other_config: Option<Arc<dyn Any>>,
72        _paged_attn_metadata: Option<PagedAttentionMeta>,
73        _mapper: Option<&dyn DeviceMapper>,
74    ) -> Result<InputProcessorOutput> {
75        let inputs = ModelInputs {
76            prompts: input_seqs
77                .iter_mut()
78                .map(|seq| seq.get_initial_prompt().to_string())
79                .collect::<Vec<_>>(),
80            params: input_seqs[0]
81                .get_diffusion_diffusion_params()
82                .context("Diffusion model params must be present")?,
83        };
84        Ok(InputProcessorOutput {
85            inputs: Box::new(inputs),
86            seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
87        })
88    }
89}