mistralrs_core/pipeline/
response.rs1use std::io::Cursor;
2
3use base64::{engine::general_purpose::STANDARD, Engine};
4use candle_core::Tensor;
5use image::DynamicImage;
6use uuid::Uuid;
7
8use crate::{
9 sequence::{Sequence, SequenceState, StopReason},
10 ImageChoice, ImageGenerationResponse, ImageGenerationResponseFormat,
11};
12
13pub async fn send_image_responses(
14 input_seqs: &mut [&mut Sequence],
15 images: Vec<DynamicImage>,
16) -> candle_core::Result<()> {
17 if input_seqs.len() != images.len() {
18 candle_core::bail!(
19 "Input seqs len ({}) does not match images generated len ({})",
20 input_seqs.len(),
21 images.len()
22 );
23 }
24
25 for (seq, image) in input_seqs.iter_mut().zip(images) {
26 let choice = match seq
27 .image_gen_response_format()
28 .unwrap_or(ImageGenerationResponseFormat::Url)
29 {
30 ImageGenerationResponseFormat::Url => {
31 let saved_path = format!("image-generation-{}.png", Uuid::new_v4());
32 image
33 .save_with_format(&saved_path, image::ImageFormat::Png)
34 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
35 ImageChoice {
36 url: Some(saved_path),
37 b64_json: None,
38 }
39 }
40 ImageGenerationResponseFormat::B64Json => {
41 let mut buffer = Vec::new();
42 image
43 .write_to(&mut Cursor::new(&mut buffer), image::ImageFormat::Png)
44 .expect("Failed to encode image");
45 let encoded = STANDARD.encode(&buffer);
46 let serialized_b64 = format!("data:image/png;base64,{encoded}");
47 ImageChoice {
48 url: None,
49 b64_json: Some(serialized_b64),
50 }
51 }
52 };
53 seq.add_image_choice_to_group(choice);
54
55 let group = seq.get_mut_group();
56 group
57 .maybe_send_image_gen_response(
58 ImageGenerationResponse {
59 created: seq.creation_time() as u128,
60 data: group.get_image_choices().to_vec(),
61 },
62 seq.responder(),
63 )
64 .await
65 .map_err(candle_core::Error::msg)?;
66
67 seq.set_state(SequenceState::Done(StopReason::GeneratedImage));
68 }
69
70 Ok(())
71}
72
73pub async fn send_raw_responses(
74 input_seqs: &mut [&mut Sequence],
75 logits_chunks: Vec<Vec<Tensor>>,
76) -> candle_core::Result<()> {
77 let logits_chunks = if logits_chunks.len() == 1 {
78 logits_chunks[0].clone()
79 } else {
80 candle_core::bail!("Raw response only supports batch size of 1.");
81 };
82 assert_eq!(input_seqs.len(), 1);
83
84 let seq = &mut *input_seqs[0];
85
86 seq.add_raw_choice_to_group(logits_chunks);
87
88 let group = seq.get_mut_group();
89 group
90 .maybe_send_raw_done_response(seq.responder())
91 .await
92 .map_err(candle_core::Error::msg)?;
93
94 seq.set_state(SequenceState::Done(StopReason::Length(0)));
95
96 Ok(())
97}