// mistralrs_core/pipeline/response.rs

1use std::io::Cursor;
2
3use base64::{engine::general_purpose::STANDARD, Engine};
4use candle_core::Tensor;
5use image::DynamicImage;
6use uuid::Uuid;
7
8use crate::{
9    sequence::{Sequence, SequenceState, StopReason},
10    ImageChoice, ImageGenerationResponse, ImageGenerationResponseFormat,
11};
12
13pub async fn send_image_responses(
14    input_seqs: &mut [&mut Sequence],
15    images: Vec<DynamicImage>,
16) -> candle_core::Result<()> {
17    if input_seqs.len() != images.len() {
18        candle_core::bail!(
19            "Input seqs len ({}) does not match images generated len ({})",
20            input_seqs.len(),
21            images.len()
22        );
23    }
24
25    for (seq, image) in input_seqs.iter_mut().zip(images) {
26        let choice = match seq
27            .image_gen_response_format()
28            .unwrap_or(ImageGenerationResponseFormat::Url)
29        {
30            ImageGenerationResponseFormat::Url => {
31                let saved_path = format!("image-generation-{}.png", Uuid::new_v4());
32                image
33                    .save_with_format(&saved_path, image::ImageFormat::Png)
34                    .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
35                ImageChoice {
36                    url: Some(saved_path),
37                    b64_json: None,
38                }
39            }
40            ImageGenerationResponseFormat::B64Json => {
41                let mut buffer = Vec::new();
42                image
43                    .write_to(&mut Cursor::new(&mut buffer), image::ImageFormat::Png)
44                    .expect("Failed to encode image");
45                let encoded = STANDARD.encode(&buffer);
46                let serialized_b64 = format!("data:image/png;base64,{encoded}");
47                ImageChoice {
48                    url: None,
49                    b64_json: Some(serialized_b64),
50                }
51            }
52        };
53        seq.add_image_choice_to_group(choice);
54
55        let group = seq.get_mut_group();
56        group
57            .maybe_send_image_gen_response(
58                ImageGenerationResponse {
59                    created: seq.creation_time() as u128,
60                    data: group.get_image_choices().to_vec(),
61                },
62                seq.responder(),
63            )
64            .await
65            .map_err(candle_core::Error::msg)?;
66
67        seq.set_state(SequenceState::Done(StopReason::GeneratedImage));
68    }
69
70    Ok(())
71}
72
73pub async fn send_raw_responses(
74    input_seqs: &mut [&mut Sequence],
75    logits_chunks: Vec<Vec<Tensor>>,
76) -> candle_core::Result<()> {
77    let logits_chunks = if logits_chunks.len() == 1 {
78        logits_chunks[0].clone()
79    } else {
80        candle_core::bail!("Raw response only supports batch size of 1.");
81    };
82    assert_eq!(input_seqs.len(), 1);
83
84    let seq = &mut *input_seqs[0];
85
86    seq.add_raw_choice_to_group(logits_chunks);
87
88    let group = seq.get_mut_group();
89    group
90        .maybe_send_raw_done_response(seq.responder())
91        .await
92        .map_err(candle_core::Error::msg)?;
93
94    seq.set_state(SequenceState::Done(StopReason::Length(0)));
95
96    Ok(())
97}