mistralrs_core/diffusion_models/processor.rs

use std::{any::Any, num::NonZeroUsize, sync::Arc};

use anyhow::{Context, Result};
use candle_core::Device;
use indexmap::IndexMap;
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::PagedAttentionMeta, InputProcessorOutput, InputsProcessor,
        InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    MessageContent, Pipeline,
};

use super::DiffusionGenerationParams;
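
/// `Processor` for diffusion (image generation) models. Chat-style message
/// processing is deliberately unsupported: the raw prompt text is passed
/// through to the model instead.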
pub struct DiffusionProcessor;

impl Processor for DiffusionProcessor {
    fn process(
        &self,
        _pipeline: &dyn Pipeline,
        _messages: Vec<IndexMap<String, MessageContent>>,
        _add_generation_prompt: bool,
        _add_special_tokens: bool,
        _enable_thinking: Option<bool>,
        _tools: Vec<crate::Tool>,
    ) -> Result<(Vec<u32>, String)> {
        anyhow::bail!(
            "DiffusionProcessor::process should not be used. It does not expect chat messages."
        )
    }
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(DiffusionInputsProcessor)
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
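        // Flatten multimodal chat messages, keeping only their text content.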
        MessagesAction::FlattenOnlyText
    }
}
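
/// Builds the [`ModelInputs`] for a batch of diffusion requests.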
pub struct DiffusionInputsProcessor;
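
/// A batch of diffusion inputs: the raw prompt for each sequence plus the
/// generation parameters shared by the batch.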
#[derive(Clone)]
pub struct ModelInputs {
    pub(crate) prompts: Vec<String>,
    pub(crate) params: DiffusionGenerationParams,
}

impl InputsProcessor for DiffusionInputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
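        // Diffusion prompts go through the text input-processing path.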
        InputsProcessorType::Text
    }

    fn process_inputs(
        &self,
        _tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        _is_prompt: bool,
        _is_xlora: bool,
        _device: &Device,
        _no_kv_cache: bool,
        _last_n_context_len: Option<(usize, usize)>,
        _return_raw_logits: bool,
        _other_config: Option<Arc<dyn Any>>,
        _paged_attn_metadata: Option<PagedAttentionMeta>,
        prompt_chunksize: Option<NonZeroUsize>,
        _mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
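        // Prompt chunking is not supported for diffusion models: fail fast by
        // yielding a single error instead of building inputs.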
        let mut make_value = if prompt_chunksize.is_some() {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Prompt batching is unsupported for diffusion models",
            ))));
        } else {
            || {
                let inputs = ModelInputs {
                    prompts: input_seqs
                        .iter_mut()
                        .map(|seq| seq.get_initial_prompt().to_string())
                        .collect::<Vec<_>>(),
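                    // The generation parameters are taken from the first
                    // sequence and applied to the whole batch.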
                    params: input_seqs[0]
                        .get_diffusion_diffusion_params()
                        .context("Diffusion model params must be present")?,
                };
                Ok(InputProcessorOutput {
                    inputs: Box::new(inputs),
                    seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
                })
            }
        };
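        // All sequences are handled as one batch, so the iterator yields a
        // single InputProcessorOutput covering every sequence index.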
        Box::new(std::iter::once(make_value()))
    }
}