mistralrs_core/diffusion_models/
processor.rs1use std::{any::Any, sync::Arc};
2
3use anyhow::{Context, Result};
4use candle_core::Device;
5use indexmap::IndexMap;
6use tokenizers::Tokenizer;
7
8use crate::{
9 device_map::DeviceMapper,
10 pipeline::{
11 text_models_inputs_processor::PagedAttentionMeta, InputProcessorOutput, InputsProcessor,
12 InputsProcessorType, MessagesAction, Processor,
13 },
14 sequence::Sequence,
15 MessageContent, Pipeline,
16};
17
18use super::DiffusionGenerationParams;
19
/// Stateless [`Processor`] used by diffusion pipelines.
///
/// It carries no data: chat-message processing is rejected, and the actual
/// batching work is delegated to [`DiffusionInputsProcessor`].
pub struct DiffusionProcessor;
21
22impl Processor for DiffusionProcessor {
23 fn process(
24 &self,
25 _pipeline: &dyn Pipeline,
26 _messages: Vec<IndexMap<String, MessageContent>>,
27 _add_generation_prompt: bool,
28 _add_special_tokens: bool,
29 _enable_thinking: Option<bool>,
30 _tools: Vec<crate::Tool>,
31 ) -> Result<(Vec<u32>, String)> {
32 anyhow::bail!(
33 "DiffusionProcessor::process should not be used. It does not expect chat messages."
34 )
35 }
36 fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
37 Arc::new(DiffusionInputsProcessor)
38 }
39 fn get_special_tokens(&self) -> &[&'static str] {
40 &[]
41 }
42 fn template_action(&self) -> MessagesAction {
43 MessagesAction::FlattenOnlyText
45 }
46}
47
/// Stateless [`InputsProcessor`] that packs the prompts of a batch of
/// sequences, plus the diffusion generation parameters, into [`ModelInputs`].
pub struct DiffusionInputsProcessor;
49
50#[derive(Clone)]
51pub struct ModelInputs {
52 pub(crate) prompts: Vec<String>,
53 pub(crate) params: DiffusionGenerationParams,
54}
55
56impl InputsProcessor for DiffusionInputsProcessor {
57 fn get_type(&self) -> InputsProcessorType {
58 InputsProcessorType::Text
59 }
60
61 fn process_inputs(
62 &self,
63 _tokenizer: Option<Arc<Tokenizer>>,
64 input_seqs: &mut [&mut Sequence],
65 _is_prompt: bool,
66 _is_xlora: bool,
67 _device: &Device,
68 _no_kv_cache: bool,
69 _last_n_context_len: Option<(usize, usize)>,
70 _return_raw_logits: bool,
71 _other_config: Option<Arc<dyn Any>>,
72 _paged_attn_metadata: Option<PagedAttentionMeta>,
73 _mapper: Option<&dyn DeviceMapper>,
74 ) -> Result<InputProcessorOutput> {
75 let inputs = ModelInputs {
76 prompts: input_seqs
77 .iter_mut()
78 .map(|seq| seq.get_initial_prompt().to_string())
79 .collect::<Vec<_>>(),
80 params: input_seqs[0]
81 .get_diffusion_diffusion_params()
82 .context("Diffusion model params must be present")?,
83 };
84 Ok(InputProcessorOutput {
85 inputs: Box::new(inputs),
86 seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
87 })
88 }
89}