#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm::RmsNormNonQuantized, LayerNorm, Linear, RmsNorm};
use mistralrs_quant::ShardedVarBuilder;
use serde::Deserialize;

use crate::layers::{self, MatMul};

const MLP_RATIO: f64 = 4.;
const HIDDEN_SIZE: usize = 3072;
const AXES_DIM: &[usize] = &[16, 56, 56];
const THETA: usize = 10000;

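/// Model hyperparameters read from the checkpoint's `config.json`.
/// `guidance_embeds` is expected to be true for guidance-distilled
/// checkpoints (e.g. FLUX.1-dev) and false for FLUX.1-schnell.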
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    pub in_channels: usize,
    pub pooled_projection_dim: usize,
    pub joint_attention_dim: usize,
    pub num_attention_heads: usize,
    pub num_layers: usize,
    pub num_single_layers: usize,
    pub guidance_embeds: bool,
}

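/// Bias-free LayerNorm with a fixed all-ones scale, i.e. plain normalization
/// with no learned affine parameters (eps = 1e-6).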
fn layer_norm(dim: usize, vb: ShardedVarBuilder) -> Result<LayerNorm> {
    let ws = Tensor::ones(dim, vb.dtype(), vb.device())?;
    Ok(LayerNorm::new_no_bias(ws, 1e-6))
}

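/// Unmasked scaled dot-product attention, softmax(q k^T / sqrt(d)) v, with
/// the leading batch dims flattened for the matmuls and restored afterwards.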
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let dim = q.dim(D::Minus1)?;
    let scale_factor = 1.0 / (dim as f64).sqrt();
    let mut batch_dims = q.dims().to_vec();
    batch_dims.pop();
    batch_dims.pop();
    let q = q.flatten_to(batch_dims.len() - 1)?;
    let k = k.flatten_to(batch_dims.len() - 1)?;
    let v = v.flatten_to(batch_dims.len() - 1)?;
    let attn_weights = (MatMul.matmul(&q, &k.t()?)? * scale_factor)?;
    let attn_scores = MatMul.matmul(&candle_nn::ops::softmax_last_dim(&attn_weights)?, &v)?;
    batch_dims.push(attn_scores.dim(D::Minus2)?);
    batch_dims.push(attn_scores.dim(D::Minus1)?);
    attn_scores.reshape(batch_dims)
}

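/// Builds rotary-embedding rotation matrices for one positional axis. Each
/// (position, frequency) pair yields the 2x2 rotation [[cos, -sin], [sin, cos]];
/// the result has shape (b, n, dim/2, 2, 2).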
fn rope(pos: &Tensor, dim: usize, theta: usize) -> Result<Tensor> {
    if dim % 2 == 1 {
        candle_core::bail!("dim {dim} is odd")
    }
    let dev = pos.device();
    let theta = theta as f64;
    let inv_freq: Vec<_> = (0..dim)
        .step_by(2)
        .map(|i| 1f32 / theta.powf(i as f64 / dim as f64) as f32)
        .collect();
    let inv_freq_len = inv_freq.len();
    let inv_freq = Tensor::from_vec(inv_freq, (1, 1, inv_freq_len), dev)?;
    let inv_freq = inv_freq.to_dtype(pos.dtype())?;
    let freqs = pos.unsqueeze(2)?.broadcast_mul(&inv_freq)?;
    let cos = freqs.cos()?;
    let sin = freqs.sin()?;
    let out = Tensor::stack(&[&cos, &sin.neg()?, &sin, &cos], 3)?;
    let (b, n, d, _ij) = out.dims4()?;
    out.reshape((b, n, d, 2, 2))
}

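/// Rotates q/k by the precomputed matrices: x is viewed as (.., n_embd/2, 2)
/// pairs, each pair is multiplied by its 2x2 rotation from `freq_cis`, and
/// the result is reshaped back to x's original layout.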
fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
    let dims = x.dims();
    let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
    let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
    let x0 = x.narrow(D::Minus1, 0, 1)?;
    let x1 = x.narrow(D::Minus1, 1, 1)?;
    let fr0 = freq_cis.get_on_dim(D::Minus1, 0)?;
    let fr1 = freq_cis.get_on_dim(D::Minus1, 1)?;
    (fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
}

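/// RoPE-then-attend: applies the rotary embedding to q and k, runs attention,
/// and folds the heads back into the hidden dim ((b, h, l, d) -> (b, l, h * d)).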
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
    let q = apply_rope(q, pe)?.contiguous()?;
    let k = apply_rope(k, pe)?.contiguous()?;
    let x = scaled_dot_product_attention(&q, &k, v)?;
    x.transpose(1, 2)?.flatten_from(2)
}

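/// Sinusoidal timestep embedding: t is scaled by 1000 and embedded with
/// dim/2 cosine plus dim/2 sine features at frequencies
/// MAX_PERIOD^(-i / (dim/2)), computed in f32 before casting to `dtype`.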
fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
    const TIME_FACTOR: f64 = 1000.;
    const MAX_PERIOD: f64 = 10000.;
    if dim % 2 == 1 {
        candle_core::bail!("{dim} is odd")
    }
    let dev = t.device();
    let half = dim / 2;
    let t = (t * TIME_FACTOR)?;
    let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle_core::DType::F32)?;
    let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?;
    let args = t
        .unsqueeze(1)?
        .to_dtype(candle_core::DType::F32)?
        .broadcast_mul(&freqs.unsqueeze(0)?)?;
    let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?;
    Ok(emb)
}

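/// Multi-axis rotary position embedder: each component of the position ids
/// gets its own RoPE table of width `axes_dim[i]` (AXES_DIM = [16, 56, 56]),
/// and the per-axis tables are concatenated along the frequency dim.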
#[derive(Debug, Clone)]
pub struct EmbedNd {
    #[allow(unused)]
    dim: usize,
    theta: usize,
    axes_dim: Vec<usize>,
}

impl EmbedNd {
    fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
        Self {
            dim,
            theta,
            axes_dim,
        }
    }
}

impl candle_core::Module for EmbedNd {
    fn forward(&self, ids: &Tensor) -> Result<Tensor> {
        let n_axes = ids.dim(D::Minus1)?;
        let mut emb = Vec::with_capacity(n_axes);
        for idx in 0..n_axes {
            let r = rope(
                &ids.get_on_dim(D::Minus1, idx)?,
                self.axes_dim[idx],
                self.theta,
            )?;
            emb.push(r)
        }
        let emb = Tensor::cat(&emb, 2)?;
        emb.unsqueeze(1)
    }
}

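/// Two-layer SiLU MLP used to embed the timestep, guidance, and pooled text
/// vectors into the model width: out_layer(silu(in_layer(x))).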
#[derive(Debug, Clone)]
pub struct MlpEmbedder {
    in_layer: Linear,
    out_layer: Linear,
}

impl MlpEmbedder {
    fn new(in_sz: usize, h_sz: usize, vb: ShardedVarBuilder) -> Result<Self> {
        let in_layer = layers::linear(in_sz, h_sz, vb.pp("in_layer"))?;
        let out_layer = layers::linear(h_sz, h_sz, vb.pp("out_layer"))?;
        Ok(Self {
            in_layer,
            out_layer,
        })
    }
}

impl candle_core::Module for MlpEmbedder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer)
    }
}

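/// QK-norm: per-head-dim RMSNorm applied separately to queries and keys
/// before the attention product (eps = 1e-6).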
#[derive(Debug, Clone)]
pub struct QkNorm {
    query_norm: RmsNorm<RmsNormNonQuantized>,
    key_norm: RmsNorm<RmsNormNonQuantized>,
}

impl QkNorm {
    fn new(dim: usize, vb: ShardedVarBuilder) -> Result<Self> {
        let query_norm = vb.get(dim, "query_norm.scale")?;
        let query_norm = RmsNorm::<RmsNormNonQuantized>::new(query_norm, 1e-6);
        let key_norm = vb.get(dim, "key_norm.scale")?;
        let key_norm = RmsNorm::<RmsNormNonQuantized>::new(key_norm, 1e-6);
        Ok(Self {
            query_norm,
            key_norm,
        })
    }
}

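/// One (shift, scale, gate) triple from an adaLN-style modulation head.
/// `scale_shift` computes xs * (1 + scale) + shift; `gate` multiplies a
/// branch output by the gate before it is added back to the residual.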
struct ModulationOut {
    shift: Tensor,
    scale: Tensor,
    gate: Tensor,
}

impl ModulationOut {
    fn scale_shift(&self, xs: &Tensor) -> Result<Tensor> {
        xs.broadcast_mul(&(&self.scale + 1.)?)?
            .broadcast_add(&self.shift)
    }

    fn gate(&self, xs: &Tensor) -> Result<Tensor> {
        self.gate.broadcast_mul(xs)
    }
}

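/// Modulation head emitting a single triple: lin(silu(vec)) is chunked into
/// (shift, scale, gate). Used by the single-stream blocks.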
#[derive(Debug, Clone)]
struct Modulation1 {
    lin: Linear,
}

impl Modulation1 {
    fn new(dim: usize, vb: ShardedVarBuilder) -> Result<Self> {
        let lin = layers::linear(dim, 3 * dim, vb.pp("lin"))?;
        Ok(Self { lin })
    }

    fn forward(&self, vec_: &Tensor) -> Result<ModulationOut> {
        let ys = vec_
            .silu()?
            .apply(&self.lin)?
            .unsqueeze(1)?
            .chunk(3, D::Minus1)?;
        if ys.len() != 3 {
            candle_core::bail!("unexpected len from chunk {ys:?}")
        }
        Ok(ModulationOut {
            shift: ys[0].clone(),
            scale: ys[1].clone(),
            gate: ys[2].clone(),
        })
    }
}

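/// Like `Modulation1` but emits two triples (six chunks): one for the
/// attention branch and one for the MLP branch of a double-stream block.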
#[derive(Debug, Clone)]
struct Modulation2 {
    lin: Linear,
}

impl Modulation2 {
    fn new(dim: usize, vb: ShardedVarBuilder) -> Result<Self> {
        let lin = layers::linear(dim, 6 * dim, vb.pp("lin"))?;
        Ok(Self { lin })
    }

    fn forward(&self, vec_: &Tensor) -> Result<(ModulationOut, ModulationOut)> {
        let ys = vec_
            .silu()?
            .apply(&self.lin)?
            .unsqueeze(1)?
            .chunk(6, D::Minus1)?;
        if ys.len() != 6 {
            candle_core::bail!("unexpected len from chunk {ys:?}")
        }
        let mod1 = ModulationOut {
            shift: ys[0].clone(),
            scale: ys[1].clone(),
            gate: ys[2].clone(),
        };
        let mod2 = ModulationOut {
            shift: ys[3].clone(),
            scale: ys[4].clone(),
            gate: ys[5].clone(),
        };
        Ok((mod1, mod2))
    }
}

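/// Self-attention with a fused QKV projection and QK-norm. `qkv` reshapes the
/// projection to (b, l, 3, heads, head_dim) and returns head-major q/k/v;
/// `cast_to` moves the weights between devices for block-wise offloading.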
#[derive(Debug, Clone)]
pub struct SelfAttention {
    qkv: Linear,
    norm: QkNorm,
    proj: Linear,
    num_attention_heads: usize,
}

impl SelfAttention {
    fn new(
        dim: usize,
        num_attention_heads: usize,
        qkv_bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let head_dim = dim / num_attention_heads;
        let qkv = layers::linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?;
        let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
        let proj = layers::linear(dim, dim, vb.pp("proj"))?;
        Ok(Self {
            qkv,
            norm,
            proj,
            num_attention_heads,
        })
    }

    fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
        let qkv = xs.apply(&self.qkv)?;
        let (b, l, _khd) = qkv.dims3()?;
        let qkv = qkv.reshape((b, l, 3, self.num_attention_heads, ()))?;
        let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
        let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
        let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
        let q = q.apply(&self.norm.query_norm)?;
        let k = k.apply(&self.norm.key_norm)?;
        Ok((q, k, v))
    }

    #[allow(unused)]
    fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result<Tensor> {
        let (q, k, v) = self.qkv(xs)?;
        attention(&q, &k, &v, pe)?.apply(&self.proj)
    }

    fn cast_to(&mut self, device: &Device) -> Result<()> {
        self.qkv = Linear::new(
            self.qkv.weight().to_device(device)?,
            self.qkv.bias().map(|x| x.to_device(device).unwrap()),
        );
        self.proj = Linear::new(
            self.proj.weight().to_device(device)?,
            self.proj.bias().map(|x| x.to_device(device).unwrap()),
        );
        self.norm = QkNorm {
            query_norm: RmsNorm::<RmsNormNonQuantized>::new(
                self.norm.query_norm.inner().weight().to_device(device)?,
                1e-6,
            ),
            key_norm: RmsNorm::<RmsNormNonQuantized>::new(
                self.norm.key_norm.inner().weight().to_device(device)?,
                1e-6,
            ),
        };
        Ok(())
    }
}

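/// Feed-forward block: lin2(gelu(lin1(x))). The "0" / "2" weight prefixes
/// match the Linear layers' indices inside the checkpoint's nn.Sequential.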
#[derive(Debug, Clone)]
struct Mlp {
    lin1: Linear,
    lin2: Linear,
}

impl Mlp {
    fn new(in_sz: usize, mlp_sz: usize, vb: ShardedVarBuilder) -> Result<Self> {
        let lin1 = layers::linear(in_sz, mlp_sz, vb.pp("0"))?;
        let lin2 = layers::linear(mlp_sz, in_sz, vb.pp("2"))?;
        Ok(Self { lin1, lin2 })
    }

    fn cast_to(&mut self, device: &Device) -> Result<()> {
        self.lin1 = Linear::new(
            self.lin1.weight().to_device(device)?,
            self.lin1.bias().map(|x| x.to_device(device).unwrap()),
        );
        self.lin2 = Linear::new(
            self.lin2.weight().to_device(device)?,
            self.lin2.bias().map(|x| x.to_device(device).unwrap()),
        );
        Ok(())
    }
}

impl candle_core::Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
    }
}

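/// MMDiT-style double-stream block: image and text tokens keep separate
/// weights (modulation, norms, attention projections, MLPs) but attend
/// jointly, with the text tokens concatenated ahead of the image tokens.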
#[derive(Debug, Clone)]
pub struct DoubleStreamBlock {
    img_mod: Modulation2,
    img_norm1: LayerNorm,
    img_attn: SelfAttention,
    img_norm2: LayerNorm,
    img_mlp: Mlp,
    txt_mod: Modulation2,
    txt_norm1: LayerNorm,
    txt_attn: SelfAttention,
    txt_norm2: LayerNorm,
    txt_mlp: Mlp,
}

impl DoubleStreamBlock {
    fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let h_sz = HIDDEN_SIZE;
        let mlp_sz = (h_sz as f64 * MLP_RATIO) as usize;
        let img_mod = Modulation2::new(h_sz, vb.pp("img_mod"))?;
        let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?;
        let img_attn = SelfAttention::new(h_sz, cfg.num_attention_heads, true, vb.pp("img_attn"))?;
        let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?;
        let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?;
        let txt_mod = Modulation2::new(h_sz, vb.pp("txt_mod"))?;
        let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?;
        let txt_attn = SelfAttention::new(h_sz, cfg.num_attention_heads, true, vb.pp("txt_attn"))?;
        let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?;
        let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?;
        Ok(Self {
            img_mod,
            img_norm1,
            img_attn,
            img_norm2,
            img_mlp,
            txt_mod,
            txt_norm1,
            txt_attn,
            txt_norm2,
            txt_mlp,
        })
    }

    fn forward(
        &self,
        img: &Tensor,
        txt: &Tensor,
        vec_: &Tensor,
        pe: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        let (img_mod1, img_mod2) = self.img_mod.forward(vec_)?;
        let (txt_mod1, txt_mod2) = self.txt_mod.forward(vec_)?;
        let img_modulated = img.apply(&self.img_norm1)?;
        let img_modulated = img_mod1.scale_shift(&img_modulated)?;
        let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;

        let txt_modulated = txt.apply(&self.txt_norm1)?;
        let txt_modulated = txt_mod1.scale_shift(&txt_modulated)?;
        let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;

        let q = Tensor::cat(&[txt_q, img_q], 2)?;
        let k = Tensor::cat(&[txt_k, img_k], 2)?;
        let v = Tensor::cat(&[txt_v, img_v], 2)?;

        let attn = attention(&q, &k, &v, pe)?;
        let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
        let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;

        let img = (img + img_mod1.gate(&img_attn.apply(&self.img_attn.proj)?))?;
        let img = (&img
            + img_mod2.gate(
                &img_mod2
                    .scale_shift(&img.apply(&self.img_norm2)?)?
                    .apply(&self.img_mlp)?,
            )?)?;

        let txt = (txt + txt_mod1.gate(&txt_attn.apply(&self.txt_attn.proj)?))?;
        let txt = (&txt
            + txt_mod2.gate(
                &txt_mod2
                    .scale_shift(&txt.apply(&self.txt_norm2)?)?
                    .apply(&self.txt_mlp)?,
            )?)?;

        Ok((img, txt))
    }

    fn cast_to(&mut self, device: &Device) -> Result<()> {
        self.img_mod.lin = Linear::new(
            self.img_mod.lin.weight().to_device(device)?,
            self.img_mod
                .lin
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        self.img_norm1 = LayerNorm::new_no_bias(self.img_norm1.weight().to_device(device)?, 1e-6);
        self.img_attn.cast_to(device)?;
        self.img_norm2 = LayerNorm::new_no_bias(self.img_norm2.weight().to_device(device)?, 1e-6);
        self.img_mlp.cast_to(device)?;

        self.txt_mod.lin = Linear::new(
            self.txt_mod.lin.weight().to_device(device)?,
            self.txt_mod
                .lin
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        self.txt_norm1 = LayerNorm::new_no_bias(self.txt_norm1.weight().to_device(device)?, 1e-6);
        self.txt_attn.cast_to(device)?;
        self.txt_norm2 = LayerNorm::new_no_bias(self.txt_norm2.weight().to_device(device)?, 1e-6);
        self.txt_mlp.cast_to(device)?;

        Ok(())
    }
}

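/// Single-stream block operating on the concatenated (txt, img) sequence:
/// one fused linear produces qkv plus the MLP hidden state, the attention
/// output and gelu(mlp) are concatenated, and a second linear projects back
/// to the hidden size.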
#[derive(Debug, Clone)]
pub struct SingleStreamBlock {
    linear1: Linear,
    linear2: Linear,
    norm: QkNorm,
    pre_norm: LayerNorm,
    modulation: Modulation1,
    h_sz: usize,
    mlp_sz: usize,
    num_attention_heads: usize,
}

impl SingleStreamBlock {
    fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let h_sz = HIDDEN_SIZE;
        let mlp_sz = (h_sz as f64 * MLP_RATIO) as usize;
        let head_dim = h_sz / cfg.num_attention_heads;
        let linear1 = layers::linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?;
        let linear2 = layers::linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?;
        let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
        let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?;
        let modulation = Modulation1::new(h_sz, vb.pp("modulation"))?;
        Ok(Self {
            linear1,
            linear2,
            norm,
            pre_norm,
            modulation,
            h_sz,
            mlp_sz,
            num_attention_heads: cfg.num_attention_heads,
        })
    }

    fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
        let mod_ = self.modulation.forward(vec_)?;
        let x_mod = mod_.scale_shift(&xs.apply(&self.pre_norm)?)?;
        let x_mod = x_mod.apply(&self.linear1)?;
        let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?;
        let (b, l, _khd) = qkv.dims3()?;
        let qkv = qkv.reshape((b, l, 3, self.num_attention_heads, ()))?;
        let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
        let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
        let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
        let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?;
        let q = q.apply(&self.norm.query_norm)?;
        let k = k.apply(&self.norm.key_norm)?;
        let attn = attention(&q, &k, &v, pe)?;
        let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?;
        xs + mod_.gate(&output)
    }

    fn cast_to(&mut self, device: &Device) -> Result<()> {
        self.linear1 = Linear::new(
            self.linear1.weight().to_device(device)?,
            self.linear1.bias().map(|x| x.to_device(device).unwrap()),
        );
        self.linear2 = Linear::new(
            self.linear2.weight().to_device(device)?,
            self.linear2.bias().map(|x| x.to_device(device).unwrap()),
        );
        self.norm = QkNorm {
            query_norm: RmsNorm::<RmsNormNonQuantized>::new(
                self.norm.query_norm.inner().weight().to_device(device)?,
                1e-6,
            ),
            key_norm: RmsNorm::<RmsNormNonQuantized>::new(
                self.norm.key_norm.inner().weight().to_device(device)?,
                1e-6,
            ),
        };
        self.pre_norm = LayerNorm::new_no_bias(self.pre_norm.weight().to_device(device)?, 1e-6);
        self.modulation.lin = Linear::new(
            self.modulation.lin.weight().to_device(device)?,
            self.modulation
                .lin
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        Ok(())
    }
}

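/// Final adaLN layer: `adaLN_modulation` turns the conditioning vector into
/// (shift, scale), the normalized tokens are modulated, and `linear`
/// projects each token to p_sz * p_sz * out_c output values.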
#[derive(Debug, Clone)]
pub struct LastLayer {
    norm_final: LayerNorm,
    linear: Linear,
    ada_ln_modulation: Linear,
}

impl LastLayer {
    fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: ShardedVarBuilder) -> Result<Self> {
        let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?;
        let linear = layers::linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?;
        let ada_ln_modulation = layers::linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?;
        Ok(Self {
            norm_final,
            linear,
            ada_ln_modulation,
        })
    }

    fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
        let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?;
        let (shift, scale) = (&chunks[0], &chunks[1]);
        let xs = xs
            .apply(&self.norm_final)?
            .broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)?
            .broadcast_add(&shift.unsqueeze(1)?)?;
        xs.apply(&self.linear)
    }
}

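/// The FLUX diffusion transformer. When `offloaded` is set, the double- and
/// single-stream blocks stay on the CPU and are moved to `device` one at a
/// time during the forward pass to bound accelerator memory use.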
#[derive(Debug, Clone)]
pub struct Flux {
    img_in: Linear,
    txt_in: Linear,
    time_in: MlpEmbedder,
    vector_in: MlpEmbedder,
    guidance_in: Option<MlpEmbedder>,
    pe_embedder: EmbedNd,
    double_blocks: Vec<DoubleStreamBlock>,
    single_blocks: Vec<SingleStreamBlock>,
    final_layer: LastLayer,
    device: Device,
    offloaded: bool,
}

impl Flux {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        device: Device,
        offloaded: bool,
    ) -> Result<Self> {
        let img_in = layers::linear(
            cfg.in_channels,
            HIDDEN_SIZE,
            vb.pp("img_in").set_device(device.clone()),
        )?;
        let txt_in = layers::linear(
            cfg.joint_attention_dim,
            HIDDEN_SIZE,
            vb.pp("txt_in").set_device(device.clone()),
        )?;
        let mut double_blocks = Vec::with_capacity(cfg.num_layers);
        let vb_d = vb.pp("double_blocks");
        for idx in 0..cfg.num_layers {
            let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?;
            double_blocks.push(db)
        }
        let mut single_blocks = Vec::with_capacity(cfg.num_single_layers);
        let vb_s = vb.pp("single_blocks");
        for idx in 0..cfg.num_single_layers {
            let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?;
            single_blocks.push(sb)
        }
        let time_in = MlpEmbedder::new(
            256,
            HIDDEN_SIZE,
            vb.pp("time_in").set_device(device.clone()),
        )?;
        let vector_in = MlpEmbedder::new(
            cfg.pooled_projection_dim,
            HIDDEN_SIZE,
            vb.pp("vector_in").set_device(device.clone()),
        )?;
        let guidance_in = if cfg.guidance_embeds {
            let mlp = MlpEmbedder::new(
                256,
                HIDDEN_SIZE,
                vb.pp("guidance_in").set_device(device.clone()),
            )?;
            Some(mlp)
        } else {
            None
        };
        let final_layer = LastLayer::new(
            HIDDEN_SIZE,
            1,
            cfg.in_channels,
            vb.pp("final_layer").set_device(device.clone()),
        )?;
        let pe_dim = HIDDEN_SIZE / cfg.num_attention_heads;
        let pe_embedder = EmbedNd::new(pe_dim, THETA, AXES_DIM.to_vec());
        Ok(Self {
            img_in,
            txt_in,
            time_in,
            vector_in,
            guidance_in,
            pe_embedder,
            double_blocks,
            single_blocks,
            final_layer,
            device: device.clone(),
            offloaded,
        })
    }

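    /// Full forward pass: embed image/text tokens and the conditioning
    /// vector, run the double-stream blocks on (img, txt), concatenate the
    /// streams for the single-stream blocks, then project the image tokens
    /// through the final adaLN layer.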
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &mut self,
        img: &Tensor,
        img_ids: &Tensor,
        txt: &Tensor,
        txt_ids: &Tensor,
        timesteps: &Tensor,
        y: &Tensor,
        guidance: Option<&Tensor>,
    ) -> Result<Tensor> {
        if txt.rank() != 3 {
            candle_core::bail!("unexpected shape for txt {:?}", txt.shape())
        }
        if img.rank() != 3 {
            candle_core::bail!("unexpected shape for img {:?}", img.shape())
        }
        let dtype = img.dtype();
        let pe = {
            let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
            ids.apply(&self.pe_embedder)?
        };
        let mut txt = txt.apply(&self.txt_in)?;
        let mut img = img.apply(&self.img_in)?;
        let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
        let vec_ = match (self.guidance_in.as_ref(), guidance) {
            (Some(g_in), Some(guidance)) => {
                (vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))?
            }
            _ => vec_,
        };
        let vec_ = (vec_ + y.apply(&self.vector_in))?;

        for block in self.double_blocks.iter_mut() {
            if self.offloaded {
                block.cast_to(&self.device)?;
            }
            (img, txt) = block.forward(&img, &txt, &vec_, &pe)?;
            if self.offloaded {
                block.cast_to(&Device::Cpu)?;
            }
        }
        let mut img = Tensor::cat(&[&txt, &img], 1)?;
        for block in self.single_blocks.iter_mut() {
            if self.offloaded {
                block.cast_to(&self.device)?;
            }
            img = block.forward(&img, &vec_, &pe)?;
            if self.offloaded {
                block.cast_to(&Device::Cpu)?;
            }
        }
        let img = img.i((.., txt.dim(1)?..))?;
        self.final_layer.forward(&img, &vec_)
    }
}