// mistralrs_core/diffusion_models/processor.rs
use std::{any::Any, num::NonZeroUsize, sync::Arc};
2
3use anyhow::{Context, Result};
4use candle_core::Device;
5use indexmap::IndexMap;
6use tokenizers::Tokenizer;
7
8use crate::{
9 device_map::DeviceMapper,
10 pipeline::{
11 text_models_inputs_processor::PagedAttentionMeta, InputProcessorOutput, InputsProcessor,
12 InputsProcessorType, MessagesAction, Processor,
13 },
14 sequence::Sequence,
15 MessageContent, Pipeline,
16};
17
18use super::DiffusionGenerationParams;
19
/// [`Processor`] implementation for diffusion (image-generation) models.
///
/// Diffusion pipelines consume raw prompt strings rather than chat
/// transcripts, so the chat-message entry point (`process`) is rejected
/// and all real work happens in [`DiffusionInputsProcessor`].
pub struct DiffusionProcessor;
21
22impl Processor for DiffusionProcessor {
23 fn process(
24 &self,
25 _pipeline: &dyn Pipeline,
26 _messages: Vec<IndexMap<String, MessageContent>>,
27 _add_generation_prompt: bool,
28 _add_special_tokens: bool,
29 _tools: Vec<crate::Tool>,
30 ) -> Result<(Vec<u32>, String)> {
31 anyhow::bail!(
32 "DiffusionProcessor::process should not be used. It does not expect chat messages."
33 )
34 }
35 fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
36 Arc::new(DiffusionInputsProcessor)
37 }
38 fn get_special_tokens(&self) -> &[&'static str] {
39 &[]
40 }
41 fn template_action(&self) -> MessagesAction {
42 MessagesAction::FlattenOnlyText
44 }
45}
46
/// [`InputsProcessor`] for diffusion models: collects each sequence's raw
/// prompt text plus per-request generation parameters (no tokenization).
pub struct DiffusionInputsProcessor;
48
/// Inputs handed to a diffusion model for one generation pass.
#[derive(Clone)]
pub struct ModelInputs {
    /// Raw (untokenized) prompt text, one entry per sequence in the batch.
    pub(crate) prompts: Vec<String>,
    /// Generation parameters; taken from the first sequence in the batch.
    pub(crate) params: DiffusionGenerationParams,
}
54
55impl InputsProcessor for DiffusionInputsProcessor {
56 fn get_type(&self) -> InputsProcessorType {
57 InputsProcessorType::Text
58 }
59
60 fn process_inputs(
61 &self,
62 _tokenizer: Option<Arc<Tokenizer>>,
63 input_seqs: &mut [&mut Sequence],
64 _is_prompt: bool,
65 _is_xlora: bool,
66 _device: &Device,
67 _no_kv_cache: bool,
68 _last_n_context_len: Option<(usize, usize)>,
69 _return_raw_logits: bool,
70 _other_config: Option<Arc<dyn Any>>,
71 _paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
72 prompt_chunksize: Option<NonZeroUsize>,
73 _mapper: Option<&dyn DeviceMapper>,
74 ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
75 let mut make_value = if prompt_chunksize.is_some() {
76 return Box::new(std::iter::once(Err(anyhow::Error::msg(
77 "Prompt batching is unsupported for diffusion models",
78 ))));
79 } else {
80 || {
81 let inputs = ModelInputs {
82 prompts: input_seqs
83 .iter_mut()
84 .map(|seq| seq.get_initial_prompt().to_string())
85 .collect::<Vec<_>>(),
86 params: input_seqs[0]
87 .get_diffusion_diffusion_params()
88 .context("Diffusion model params must be present")?,
89 };
90 Ok(InputProcessorOutput {
91 inputs: Box::new(inputs),
92 seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
93 })
94 }
95 };
96 Box::new(std::iter::once(make_value()))
97 }
98}