// mistralrs_core/diffusion_models/processor.rs

1use std::{any::Any, num::NonZeroUsize, sync::Arc};
2
3use anyhow::{Context, Result};
4use candle_core::Device;
5use indexmap::IndexMap;
6use tokenizers::Tokenizer;
7
8use crate::{
9    device_map::DeviceMapper,
10    pipeline::{
11        text_models_inputs_processor::PagedAttentionMeta, InputProcessorOutput, InputsProcessor,
12        InputsProcessorType, MessagesAction, Processor,
13    },
14    sequence::Sequence,
15    MessageContent, Pipeline,
16};
17
18use super::DiffusionGenerationParams;
19
/// Chat-message processor stub for diffusion pipelines.
///
/// Diffusion models consume raw image-generation prompts rather than chat
/// transcripts, so [`Processor::process`] is deliberately unimplemented and
/// only [`Processor::inputs_processor`] is meaningful.
pub struct DiffusionProcessor;
21
22impl Processor for DiffusionProcessor {
23    fn process(
24        &self,
25        _pipeline: &dyn Pipeline,
26        _messages: Vec<IndexMap<String, MessageContent>>,
27        _add_generation_prompt: bool,
28        _add_special_tokens: bool,
29        _tools: Vec<crate::Tool>,
30    ) -> Result<(Vec<u32>, String)> {
31        anyhow::bail!(
32            "DiffusionProcessor::process should not be used. It does not expect chat messages."
33        )
34    }
35    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
36        Arc::new(DiffusionInputsProcessor)
37    }
38    fn get_special_tokens(&self) -> &[&'static str] {
39        &[]
40    }
41    fn template_action(&self) -> MessagesAction {
42        // Just a default
43        MessagesAction::FlattenOnlyText
44    }
45}
46
/// Builds [`ModelInputs`] for a batch of diffusion sequences; see
/// [`InputsProcessor::process_inputs`].
pub struct DiffusionInputsProcessor;
48
/// Inputs handed to a diffusion model for one generation step.
#[derive(Clone)]
pub struct ModelInputs {
    // One raw text prompt per sequence in the batch.
    pub(crate) prompts: Vec<String>,
    // Generation parameters shared by the whole batch (taken from the
    // first sequence — presumably all sequences carry the same params;
    // TODO(review): confirm against the scheduler).
    pub(crate) params: DiffusionGenerationParams,
}
54
55impl InputsProcessor for DiffusionInputsProcessor {
56    fn get_type(&self) -> InputsProcessorType {
57        InputsProcessorType::Text
58    }
59
60    fn process_inputs(
61        &self,
62        _tokenizer: Option<Arc<Tokenizer>>,
63        input_seqs: &mut [&mut Sequence],
64        _is_prompt: bool,
65        _is_xlora: bool,
66        _device: &Device,
67        _no_kv_cache: bool,
68        _last_n_context_len: Option<(usize, usize)>,
69        _return_raw_logits: bool,
70        _other_config: Option<Arc<dyn Any>>,
71        _paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
72        prompt_chunksize: Option<NonZeroUsize>,
73        _mapper: Option<&dyn DeviceMapper>,
74    ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
75        let mut make_value = if prompt_chunksize.is_some() {
76            return Box::new(std::iter::once(Err(anyhow::Error::msg(
77                "Prompt batching is unsupported for diffusion models",
78            ))));
79        } else {
80            || {
81                let inputs = ModelInputs {
82                    prompts: input_seqs
83                        .iter_mut()
84                        .map(|seq| seq.get_initial_prompt().to_string())
85                        .collect::<Vec<_>>(),
86                    params: input_seqs[0]
87                        .get_diffusion_diffusion_params()
88                        .context("Diffusion model params must be present")?,
89                };
90                Ok(InputProcessorOutput {
91                    inputs: Box::new(inputs),
92                    seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
93                })
94            }
95        };
96        Box::new(std::iter::once(make_value()))
97    }
98}