mistralrs_core/diffusion_models/processor.rs

1use std::{any::Any, num::NonZeroUsize, sync::Arc};
2
3use anyhow::{Context, Result};
4use candle_core::Device;
5use indexmap::IndexMap;
6use tokenizers::Tokenizer;
7
8use crate::{
9    device_map::DeviceMapper,
10    pipeline::{
11        text_models_inputs_processor::PagedAttentionMeta, InputProcessorOutput, InputsProcessor,
12        InputsProcessorType, MessagesAction, Processor,
13    },
14    sequence::Sequence,
15    MessageContent, Pipeline,
16};
17
18use super::DiffusionGenerationParams;
19
20pub struct DiffusionProcessor;
21
22impl Processor for DiffusionProcessor {
23    fn process(
24        &self,
25        _pipeline: &dyn Pipeline,
26        _messages: Vec<IndexMap<String, MessageContent>>,
27        _add_generation_prompt: bool,
28        _add_special_tokens: bool,
29        _enable_thinking: Option<bool>,
30        _tools: Vec<crate::Tool>,
31    ) -> Result<(Vec<u32>, String)> {
32        anyhow::bail!(
33            "DiffusionProcessor::process should not be used. It does not expect chat messages."
34        )
35    }
36    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
37        Arc::new(DiffusionInputsProcessor)
38    }
39    fn get_special_tokens(&self) -> &[&'static str] {
40        &[]
41    }
42    fn template_action(&self) -> MessagesAction {
43        // Just a default
44        MessagesAction::FlattenOnlyText
45    }
46}
47
48pub struct DiffusionInputsProcessor;
49
/// Model inputs for one diffusion generation call, built by
/// `DiffusionInputsProcessor::process_inputs`.
#[derive(Clone)]
pub struct ModelInputs {
    // Initial prompt text of every sequence in the batch, in batch order.
    pub(crate) prompts: Vec<String>,
    // Generation parameters taken from the first sequence of the batch.
    pub(crate) params: DiffusionGenerationParams,
}
55
56impl InputsProcessor for DiffusionInputsProcessor {
57    fn get_type(&self) -> InputsProcessorType {
58        InputsProcessorType::Text
59    }
60
61    fn process_inputs(
62        &self,
63        _tokenizer: Option<Arc<Tokenizer>>,
64        input_seqs: &mut [&mut Sequence],
65        _is_prompt: bool,
66        _is_xlora: bool,
67        _device: &Device,
68        _no_kv_cache: bool,
69        _last_n_context_len: Option<(usize, usize)>,
70        _return_raw_logits: bool,
71        _other_config: Option<Arc<dyn Any>>,
72        _paged_attn_metadata: Option<PagedAttentionMeta>,
73        prompt_chunksize: Option<NonZeroUsize>,
74        _mapper: Option<&dyn DeviceMapper>,
75    ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
76        let mut make_value = if prompt_chunksize.is_some() {
77            return Box::new(std::iter::once(Err(anyhow::Error::msg(
78                "Prompt batching is unsupported for diffusion models",
79            ))));
80        } else {
81            || {
82                let inputs = ModelInputs {
83                    prompts: input_seqs
84                        .iter_mut()
85                        .map(|seq| seq.get_initial_prompt().to_string())
86                        .collect::<Vec<_>>(),
87                    params: input_seqs[0]
88                        .get_diffusion_diffusion_params()
89                        .context("Diffusion model params must be present")?,
90                };
91                Ok(InputProcessorOutput {
92                    inputs: Box::new(inputs),
93                    seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
94                })
95            }
96        };
97        Box::new(std::iter::once(make_value()))
98    }
99}