// mistralrs_core/diffusion_models/flux/sampling.rs
1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use candle_core::{Device, Result, Tensor};
4
5pub fn get_noise(
6 num_samples: usize,
7 height: usize,
8 width: usize,
9 device: &Device,
10) -> Result<Tensor> {
11 let height = height.div_ceil(16) * 2;
12 let width = width.div_ceil(16) * 2;
13 Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
14}
15
16#[derive(Debug, Clone)]
17pub struct State {
18 pub img: Tensor,
19 pub img_ids: Tensor,
20 pub txt: Tensor,
21 pub txt_ids: Tensor,
22 pub vec: Tensor,
23}
24
25impl State {
26 pub fn new(t5_emb: &Tensor, clip_emb: &Tensor, img: &Tensor) -> Result<Self> {
27 let dtype = img.dtype();
28 let (bs, c, h, w) = img.dims4()?;
29 let dev = img.device();
30 let img = img.reshape((bs, c, h / 2, 2, w / 2, 2))?; let img = img.permute((0, 2, 4, 1, 3, 5))?; let img = img.reshape((bs, h / 2 * w / 2, c * 4))?;
33 let img_ids = Tensor::stack(
34 &[
35 Tensor::full(0u32, (h / 2, w / 2), dev)?,
36 Tensor::arange(0u32, h as u32 / 2, dev)?
37 .reshape(((), 1))?
38 .broadcast_as((h / 2, w / 2))?,
39 Tensor::arange(0u32, w as u32 / 2, dev)?
40 .reshape((1, ()))?
41 .broadcast_as((h / 2, w / 2))?,
42 ],
43 2,
44 )?
45 .to_dtype(dtype)?;
46 let img_ids = img_ids.reshape((1, h / 2 * w / 2, 3))?;
47 let img_ids = img_ids.repeat((bs, 1, 1))?;
48 let txt = t5_emb.repeat(bs)?;
49 let txt_ids = Tensor::zeros((bs, txt.dim(1)?, 3), dtype, dev)?;
50 let vec = clip_emb.repeat(bs)?;
51 Ok(Self {
52 img,
53 img_ids,
54 txt,
55 txt_ids,
56 vec,
57 })
58 }
59}
60
/// Warps a timestep `t` in `[0, 1]` with offset `mu` and exponent `sigma`.
///
/// Computes `e^mu / (e^mu + (1 / t - 1)^sigma)`. For `sigma > 0` the endpoints are
/// preserved: `t = 1` maps to `1`, and `t = 0` maps to `0` (because `1 / 0` is `+inf`).
fn time_shift(mu: f64, sigma: f64, t: f64) -> f64 {
    let e = mu.exp();
    e / (e + (1. / t - 1.).powf(sigma))
}

/// Builds a descending sampling schedule of `num_steps + 1` timesteps from `1.0` to `0.0`.
///
/// The base grid is evenly spaced. When `shift` is `Some((image_seq_len, y1, y2))`,
/// every timestep is warped by [`time_shift`] with `sigma = 1`. The offset `mu` lies on
/// the line through `(256, y1)` and `(4096, y2)`, evaluated at `image_seq_len`, and is
/// extrapolated linearly outside that range.
///
/// `num_steps == 0` yields the single starting timestep `[1.0]`, like
/// `linspace(1, 0, 1)`, instead of the `[NaN]` a `0 / 0` division would produce.
pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec<f64> {
    let timesteps: Vec<f64> = if num_steps == 0 {
        // Dividing by `num_steps` below would give NaN; the grid is just the start point.
        vec![1.0]
    } else {
        let denom = num_steps as f64;
        (0..=num_steps).rev().map(|v| v as f64 / denom).collect()
    };
    match shift {
        None => timesteps,
        Some((image_seq_len, y1, y2)) => {
            let (x1, x2) = (256., 4096.);
            let m = (y2 - y1) / (x2 - x1);
            let b = y1 - m * x1;
            let mu = m * image_seq_len as f64 + b;
            timesteps
                .into_iter()
                .map(|t| time_shift(mu, 1., t))
                .collect()
        }
    }
}
86
87pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
88 let (b, _h_w, c_ph_pw) = xs.dims3()?;
89 let height = height.div_ceil(16);
90 let width = width.div_ceil(16);
91 xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? .permute((0, 3, 1, 4, 2, 5))? .reshape((b, c_ph_pw / 4, height * 2, width * 2))
94}
95
96#[allow(clippy::too_many_arguments)]
97fn denoise_inner(
98 model: &mut super::model::Flux,
99 img: &Tensor,
100 img_ids: &Tensor,
101 txt: &Tensor,
102 txt_ids: &Tensor,
103 vec_: &Tensor,
104 timesteps: &[f64],
105 guidance: Option<f64>,
106) -> Result<Tensor> {
107 let b_sz = img.dim(0)?;
108 let dev = img.device();
109 let guidance = if let Some(guidance) = guidance {
110 Some(Tensor::full(guidance as f32, b_sz, dev)?)
111 } else {
112 None
113 };
114 let mut img = img.clone();
115 for window in timesteps.windows(2) {
116 let (t_curr, t_prev) = match window {
117 [a, b] => (a, b),
118 _ => continue,
119 };
120 let t_vec = Tensor::full(*t_curr as f32, b_sz, dev)?;
121 let pred = model.forward(&img, img_ids, txt, txt_ids, &t_vec, vec_, guidance.as_ref())?;
122 img = (img + pred * (t_prev - t_curr))?
123 }
124 Ok(img)
125}
126
127#[allow(clippy::too_many_arguments)]
128pub fn denoise(
129 model: &mut super::model::Flux,
130 img: &Tensor,
131 img_ids: &Tensor,
132 txt: &Tensor,
133 txt_ids: &Tensor,
134 vec_: &Tensor,
135 timesteps: &[f64],
136 guidance: f64,
137) -> Result<Tensor> {
138 denoise_inner(
139 model,
140 img,
141 img_ids,
142 txt,
143 txt_ids,
144 vec_,
145 timesteps,
146 Some(guidance),
147 )
148}
149
150#[allow(clippy::too_many_arguments)]
151pub fn denoise_no_guidance(
152 model: &mut super::model::Flux,
153 img: &Tensor,
154 img_ids: &Tensor,
155 txt: &Tensor,
156 txt_ids: &Tensor,
157 vec_: &Tensor,
158 timesteps: &[f64],
159) -> Result<Tensor> {
160 denoise_inner(model, img, img_ids, txt, txt_ids, vec_, timesteps, None)
161}