mistralrs_core/diffusion_models/t5/mod.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

// T5 Text Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py

use candle_core::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, Embedding, Linear};
use mistralrs_quant::ShardedVarBuilder;
use serde::Deserialize;
use std::sync::Arc;

use crate::layers::{clamp_for_f16, embedding, linear_no_bias, MatMul};

fn default_relative_attention_max_distance() -> usize {
    128
}

fn default_is_decoder() -> bool {
    false
}

fn default_use_cache() -> bool {
    true
}

fn default_tie_word_embeddings() -> bool {
    true
}

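// `get_mask` builds an upper-triangular mask of shape (size, size): entry (i, j) is 1
// when j > i, i.e. the future positions a decoder query at index i must not attend to.
// `masked_fill` then replaces those positions in the attention scores with the given
// value (f32::NEG_INFINITY at the call site) before the softmax.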
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<_> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (size, size), device)
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}

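// `feed_forward_proj` in T5 config files is a string such as "relu", "gelu-new",
// "gated-gelu", or "gated-silu". The custom deserializer below maps the "gated-*"
// variants to a gated feed-forward (see `T5DenseGatedActDense`); anything else is
// parsed as a plain activation and routed through the ungated `T5DenseActDense` path.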
#[derive(Debug, Deserialize, Default, Clone, PartialEq)]
pub struct ActivationWithOptionalGating {
    pub gated: bool,
    pub activation: candle_nn::Activation,
}

pub fn deserialize_feed_forward_proj_activation<'de, D>(
    deserializer: D,
) -> std::result::Result<ActivationWithOptionalGating, D::Error>
where
    D: serde::de::Deserializer<'de>,
{
    match String::deserialize(deserializer)?.as_str() {
        "gated-gelu" => Ok(ActivationWithOptionalGating {
            gated: true,
            activation: candle_nn::Activation::NewGelu,
        }),
        "gated-silu" => Ok(ActivationWithOptionalGating {
            gated: true,
            activation: candle_nn::Activation::Silu,
        }),
        buf => {
            let activation = serde_plain::from_str(buf).map_err(serde::de::Error::custom)?;
            Ok(ActivationWithOptionalGating {
                gated: false,
                activation,
            })
        }
    }
}

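// Mirrors the fields of the Hugging Face `T5Config` (config.json). Fields carrying a
// `#[serde(default = ...)]` attribute may be absent from the checkpoint's config file.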
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub d_model: usize,
    pub d_kv: usize,
    pub d_ff: usize,
    pub num_layers: usize,
    pub num_decoder_layers: Option<usize>,
    pub num_heads: usize,
    pub relative_attention_num_buckets: usize,
    #[serde(default = "default_relative_attention_max_distance")]
    pub relative_attention_max_distance: usize,
    pub dropout_rate: f64,
    pub layer_norm_epsilon: f64,
    pub initializer_factor: f64,
    #[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
    pub feed_forward_proj: ActivationWithOptionalGating,
    #[serde(default = "default_tie_word_embeddings")]
    pub tie_word_embeddings: bool,
    #[serde(default = "default_is_decoder")]
    pub is_decoder: bool,
    pub is_encoder_decoder: bool,
    #[serde(default = "default_use_cache")]
    pub use_cache: bool,
    pub pad_token_id: usize,
    pub eos_token_id: usize,
    pub decoder_start_token_id: Option<usize>,
}

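// These defaults correspond to the t5-small hyperparameters (d_model = 512, 6 layers,
// 8 heads, ungated ReLU feed-forward).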
impl Default for Config {
    fn default() -> Self {
        Self {
            vocab_size: 32128,
            d_model: 512,
            d_kv: 64,
            d_ff: 2048,
            num_layers: 6,
            num_decoder_layers: None,
            num_heads: 8,
            relative_attention_num_buckets: 32,
            relative_attention_max_distance: 128,
            dropout_rate: 0.1,
            layer_norm_epsilon: 1e-6,
            initializer_factor: 1.0,
            feed_forward_proj: ActivationWithOptionalGating {
                gated: false,
                activation: Activation::Relu,
            },
            tie_word_embeddings: true,
            is_decoder: false,
            is_encoder_decoder: true,
            use_cache: true,
            pad_token_id: 0,
            eos_token_id: 1,
            decoder_start_token_id: Some(0),
        }
    }
}

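// T5 uses an RMS-style layer norm: no mean subtraction and no bias, just a learned
// per-channel scale. The variance is computed in f32 for numerical stability and the
// result is cast back to the input dtype before scaling.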
#[derive(Debug, Clone)]
struct T5LayerNorm {
    weight: Tensor,
    variance_epsilon: f64,
}

impl T5LayerNorm {
    fn load(h: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        let weight = vb.get(h, "weight")?;
        Ok(Self {
            weight,
            variance_epsilon: eps,
        })
    }
}

impl Module for T5LayerNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let dtype = xs.dtype();
        let xs_f32 = xs.to_dtype(DType::F32)?;
        // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
        let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
        let xs = xs.to_dtype(dtype)?;
        let xs = xs.broadcast_mul(&self.weight)?;
        Ok(xs)
    }
}

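// Ungated feed-forward: wi -> ReLU -> wo, with bias-free projections. The activation is
// hard-coded to ReLU, matching the original (non-gated) T5 checkpoints.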
#[derive(Debug, Clone)]
struct T5DenseActDense {
    wi: Linear,
    wo: Linear,
    act: Activation,
}

impl T5DenseActDense {
    fn load(vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
        let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
        Ok(Self {
            wi,
            wo,
            act: Activation::Relu,
        })
    }
}

impl Module for T5DenseActDense {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.wi.forward(xs)?;
        let xs = self.act.forward(&xs)?;
        let xs = self.wo.forward(&xs)?;
        Ok(xs)
    }
}

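// Gated feed-forward used by the "gated-gelu"/"gated-silu" variants:
// act(wi_0(x)) * wi_1(x) -> wo, with the activation taken from the config.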
#[derive(Debug, Clone)]
struct T5DenseGatedActDense {
    wi_0: Linear,
    wi_1: Linear,
    wo: Linear,
    act: Activation,
}

impl T5DenseGatedActDense {
    fn load(vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
        let wi_0 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
        let wi_1 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
        Ok(Self {
            wi_0,
            wi_1,
            wo,
            act: cfg.feed_forward_proj.activation,
        })
    }
}

impl Module for T5DenseGatedActDense {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
        let hidden_linear = self.wi_1.forward(xs)?;
        let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
        let xs = self.wo.forward(&xs)?;
        Ok(xs)
    }
}

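// Pre-norm feed-forward block with a residual connection. Exactly one of `dense_act`
// and `gated_dense_act` is populated, selected by `cfg.feed_forward_proj.gated`.
// `cast_to` rebuilds the layer's weights on the given device so that blocks can be
// offloaded to the CPU and moved back on demand (see `T5Stack::forward`).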
#[derive(Debug, Clone)]
struct T5LayerFF {
    dense_act: Option<T5DenseActDense>,
    gated_dense_act: Option<T5DenseGatedActDense>,
    layer_norm: T5LayerNorm,
}

impl T5LayerFF {
    fn load(vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
        let layer_norm =
            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
            (
                None,
                Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
            )
        } else {
            (
                Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?),
                None,
            )
        };
        Ok(Self {
            dense_act,
            gated_dense_act,
            layer_norm,
        })
    }

    fn cast_to(&mut self, device: &Device) -> Result<()> {
        self.layer_norm = T5LayerNorm {
            weight: self.layer_norm.weight.to_device(device)?,
            variance_epsilon: self.layer_norm.variance_epsilon,
        };
        if let Some(dense) = &mut self.dense_act {
            dense.wi = Linear::new(
                dense.wi.weight().to_device(device)?,
                dense.wi.bias().map(|x| x.to_device(device).unwrap()),
            );
            dense.wo = Linear::new(
                dense.wo.weight().to_device(device)?,
                dense.wo.bias().map(|x| x.to_device(device).unwrap()),
            );
        }
        if let Some(dense) = &mut self.gated_dense_act {
            dense.wi_0 = Linear::new(
                dense.wi_0.weight().to_device(device)?,
                dense.wi_0.bias().map(|x| x.to_device(device).unwrap()),
            );
            dense.wi_1 = Linear::new(
                dense.wi_1.weight().to_device(device)?,
                dense.wi_1.bias().map(|x| x.to_device(device).unwrap()),
            );
            dense.wo = Linear::new(
                dense.wo.weight().to_device(device)?,
                dense.wo.bias().map(|x| x.to_device(device).unwrap()),
            );
        }
        Ok(())
    }
}

impl Module for T5LayerFF {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let ys = self.layer_norm.forward(xs)?;
        let ys = match &self.dense_act {
            Some(dense_act) => dense_act.forward(&ys)?,
            None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?,
        };
        let xs = (xs + ys)?;
        Ok(xs)
    }
}

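// Multi-head attention as used by T5: bias-free q/k/v/o projections and no
// 1/sqrt(d_kv) scaling of the scores, matching the reference implementation linked at
// the top of this file. Only the first block of a stack carries
// `relative_attention_bias`; the resulting position bias tensor is returned and reused
// by the later blocks.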
#[derive(Debug, Clone)]
struct T5Attention {
    q: Linear,
    k: Linear,
    v: Linear,
    o: Linear,
    n_heads: usize,
    d_kv: usize,
    relative_attention_bias: Option<Embedding>,
    relative_attention_num_buckets: usize,
    relative_attention_max_distance: usize,
    inner_dim: usize,
    use_cache: bool,
}

impl T5Attention {
    fn load(
        has_relative_attention_bias: bool,
        decoder: bool,
        vb: ShardedVarBuilder,
        cfg: &Config,
    ) -> Result<Self> {
        let inner_dim = cfg.num_heads * cfg.d_kv;
        let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
        let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
        let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
        let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
        let relative_attention_bias = if has_relative_attention_bias {
            let emb = embedding(
                cfg.relative_attention_num_buckets,
                cfg.num_heads,
                vb.pp("relative_attention_bias"),
                &None,
            )?;
            Some(emb)
        } else {
            None
        };
        Ok(Self {
            q,
            k,
            v,
            o,
            n_heads: cfg.num_heads,
            d_kv: cfg.d_kv,
            relative_attention_bias,
            relative_attention_num_buckets: cfg.relative_attention_num_buckets,
            relative_attention_max_distance: cfg.relative_attention_max_distance,
            inner_dim,
            use_cache: cfg.use_cache && decoder,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        position_bias: Option<&Tensor>,
        key_value_states: Option<&Tensor>,
        mask: Option<&Tensor>,
    ) -> Result<(Tensor, Option<Tensor>)> {
        // Performs self-attention (if key_value_states is None) or cross-attention over
        // the encoder output (provided via key_value_states).
        let kv_input = match key_value_states {
            None => xs,
            Some(key_value_states) => key_value_states,
        };
        let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
        let kv_len = kv_input.dim(1)?;
        let q = self.q.forward(xs)?;
        let k = self.k.forward(kv_input)?;
        let v = self.v.forward(kv_input)?;
        let q = q
            .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
            .transpose(1, 2)?
            .contiguous()?;
        let k = k
            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
            .transpose(1, 2)?;

        let k = k.contiguous()?;
        let v = v.contiguous()?;
        // TODO: Use flash_attn.
        let scores = MatMul.matmul(&q, &k.t()?)?;
        let scores = match mask {
            None => scores,
            Some(mask) => masked_fill(
                &scores,
                &mask
                    .unsqueeze(0)?
                    .unsqueeze(0)?
                    .repeat((b_sz, self.n_heads))?,
                f32::NEG_INFINITY,
            )?,
        };

        let (scores, position_bias) = match position_bias {
            Some(position_bias) => (
                scores.broadcast_add(position_bias)?,
                Some(position_bias.clone()),
            ),
            None => match &self.relative_attention_bias {
                None => (scores, None),
                Some(relative_attention_bias) => {
                    // This only handles the bidirectional case.
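                    // Relative positions are mapped to buckets as in the reference HF
                    // implementation (bidirectional case): half of the buckets cover
                    // j > i and half cover j <= i; within each half the first
                    // `max_exact` buckets hold exact offsets and larger offsets are
                    // binned logarithmically up to `relative_attention_max_distance`.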
                    let kv_len = k.dim(2)?;
                    let (q_start, q_end) = match self.use_cache {
                        true => ((kv_len - q_len) as u32, kv_len as u32),
                        false => (0_u32, kv_len as u32),
                    };
                    let num_buckets = self.relative_attention_num_buckets as u32 / 2;
                    let max_exact = num_buckets / 2;
                    let relative_position = (q_start..q_end)
                        .map(|i| {
                            (0..kv_len as u32)
                                .map(|j| {
                                    if i < j {
                                        if j - i < max_exact {
                                            j - i + num_buckets
                                        } else {
                                            let b = f32::log(
                                                (j - i) as f32 / max_exact as f32,
                                                self.relative_attention_max_distance as f32
                                                    / max_exact as f32,
                                            ) * (num_buckets - max_exact) as f32;
                                            u32::min(
                                                max_exact + num_buckets + b as u32,
                                                self.relative_attention_num_buckets as u32 - 1,
                                            )
                                        }
                                    } else if i - j < max_exact {
                                        i - j
                                    } else {
                                        let b = f32::log(
                                            (i - j) as f32 / max_exact as f32,
                                            self.relative_attention_max_distance as f32
                                                / max_exact as f32,
                                        ) * (num_buckets - max_exact) as f32;
                                        u32::min(max_exact + b as u32, num_buckets - 1)
                                    }
                                })
                                .collect::<Vec<u32>>()
                        })
                        .collect::<Vec<Vec<_>>>();
                    let relative_buckets = Tensor::new(relative_position, q.device())?;
                    let position_bias = relative_attention_bias
                        .forward(&relative_buckets)?
                        .permute((2, 0, 1))?
                        .unsqueeze(0)?;
                    (scores.broadcast_add(&position_bias)?, Some(position_bias))
                    // TODO: position_bias_masked?
                }
            },
        };

        let attn_weights = candle_nn::ops::softmax_last_dim(&scores)?;
        let attn_output = MatMul.matmul(&attn_weights, &v)?;
        let attn_output = attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.inner_dim))?;
        let attn_output = self.o.forward(&attn_output)?;
        Ok((attn_output, position_bias))
    }
}

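// Pre-norm self-attention with a residual connection. `cast_to` moves the projection
// weights, the optional relative-attention-bias embedding, and the layer norm to the
// given device for offloading.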
#[derive(Debug, Clone)]
struct T5LayerSelfAttention {
    self_attention: T5Attention,
    layer_norm: T5LayerNorm,
}

impl T5LayerSelfAttention {
    fn load(h: bool, d: bool, vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
        let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?;
        let layer_norm =
            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
        Ok(Self {
            self_attention,
            layer_norm,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        position_bias: Option<&Tensor>,
        mask: Option<&Tensor>,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let normed_xs = self.layer_norm.forward(xs)?;
        let (ys, position_bias) =
            self.self_attention
                .forward(&normed_xs, position_bias, None, mask)?;
        let ys = (xs + ys)?;
        Ok((ys, position_bias))
    }

    fn cast_to(&mut self, device: &Device) -> Result<()> {
        self.self_attention.q = Linear::new(
            self.self_attention.q.weight().to_device(device)?,
            self.self_attention
                .q
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        self.self_attention.k = Linear::new(
            self.self_attention.k.weight().to_device(device)?,
            self.self_attention
                .k
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        self.self_attention.v = Linear::new(
            self.self_attention.v.weight().to_device(device)?,
            self.self_attention
                .v
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        self.self_attention.o = Linear::new(
            self.self_attention.o.weight().to_device(device)?,
            self.self_attention
                .o
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        if let Some(embed) = &mut self.self_attention.relative_attention_bias {
            *embed = Embedding::new(embed.embeddings().to_device(device)?, embed.hidden_size());
        }
        self.layer_norm = T5LayerNorm {
            weight: self.layer_norm.weight.to_device(device)?,
            variance_epsilon: self.layer_norm.variance_epsilon,
        };
        Ok(())
    }
}

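// Pre-norm cross-attention over the encoder output (`EncDecAttention` in the
// checkpoint); only constructed for decoder blocks. No causal mask is applied here.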
#[derive(Debug, Clone)]
struct T5LayerCrossAttention {
    cross_attention: T5Attention,
    layer_norm: T5LayerNorm,
}

impl T5LayerCrossAttention {
    fn load(decoder: bool, vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
        let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?;
        let layer_norm =
            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
        Ok(Self {
            cross_attention,
            layer_norm,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        position_bias: Option<&Tensor>,
        key_value_states: &Tensor,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
        let (ys, position_bias) = self.cross_attention.forward(
            &normed_hidden_states,
            position_bias,
            Some(key_value_states),
            None,
        )?;
        let ys = (hidden_states + ys)?;
        Ok((ys, position_bias))
    }

    fn cast_to(&mut self, device: &Device) -> Result<()> {
        self.cross_attention.q = Linear::new(
            self.cross_attention.q.weight().to_device(device)?,
            self.cross_attention
                .q
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        self.cross_attention.k = Linear::new(
            self.cross_attention.k.weight().to_device(device)?,
            self.cross_attention
                .k
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        self.cross_attention.v = Linear::new(
            self.cross_attention.v.weight().to_device(device)?,
            self.cross_attention
                .v
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        self.cross_attention.o = Linear::new(
            self.cross_attention.o.weight().to_device(device)?,
            self.cross_attention
                .o
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        if let Some(embed) = &mut self.cross_attention.relative_attention_bias {
            *embed = Embedding::new(embed.embeddings().to_device(device)?, embed.hidden_size());
        }
        self.layer_norm = T5LayerNorm {
            weight: self.layer_norm.weight.to_device(device)?,
            variance_epsilon: self.layer_norm.variance_epsilon,
        };
        Ok(())
    }
}

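// One transformer block: self-attention, optional cross-attention (decoder only), and
// the feed-forward layer, each with its own pre-norm and residual. A causal mask is
// built only when cross-attention is present (i.e. in the decoder), and activations are
// clamped after each sub-layer when running in f16 (see `clamp_for_f16`).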
#[derive(Debug, Clone)]
struct T5Block {
    self_attn: T5LayerSelfAttention,
    cross_attn: Option<T5LayerCrossAttention>,
    ff: T5LayerFF,
}

impl T5Block {
    fn load(
        has_relative_attention_bias: bool,
        decoder: bool,
        vb: ShardedVarBuilder,
        cfg: &Config,
    ) -> Result<Self> {
        let vb = vb.pp("layer");
        let self_attn =
            T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?;
        let cross_attn = if cfg.is_decoder {
            Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?)
        } else {
            None
        };
        let ff_i = if cross_attn.is_some() { 2 } else { 1 };
        let ff = T5LayerFF::load(vb.pp(ff_i.to_string()), cfg)?;
        Ok(Self {
            self_attn,
            cross_attn,
            ff,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        position_bias: Option<&Tensor>,
        encoder_hidden_states: Option<&Tensor>,
    ) -> Result<(Tensor, Option<Tensor>)> {
        // TODO: Cache masks
        let mask = match self.cross_attn.is_some() {
            true => {
                let mask_len = xs.dim(1)?;
                // If the input sequence length is 1, no mask is needed; this also helps
                // avoid shape issues when using the KV cache in the decoder.
                if mask_len <= 1 {
                    None
                } else {
                    Some(get_mask(mask_len, xs.device())?)
                }
            }
            false => None,
        };
        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
        // Clamp for f16
        if xs.dtype() == DType::F16 {
            xs = clamp_for_f16(&xs)?;
        }
        if let Some(cross_attn) = &self.cross_attn {
            (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
            // Clamp for f16
            if xs.dtype() == DType::F16 {
                xs = clamp_for_f16(&xs)?;
            }
        }
        let mut xs = self.ff.forward(&xs)?;
        // Clamp for f16
        if xs.dtype() == DType::F16 {
            xs = clamp_for_f16(&xs)?;
        }
        Ok((xs, position_bias))
    }

    fn cast_to(&mut self, device: &Device) -> Result<()> {
        self.self_attn.cast_to(device)?;
        if let Some(cross_attn) = &mut self.cross_attn {
            cross_attn.cast_to(device)?;
        }
        self.ff.cast_to(device)?;
        Ok(())
    }
}

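// A stack of `T5Block`s sharing one token embedding. The relative position bias is
// computed by block 0 and passed to the remaining blocks. When `offloaded` is set, each
// block is moved to `device` just before its forward pass and back to the CPU right
// after, trading speed for memory.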
#[derive(Debug, Clone)]
struct T5Stack {
    block: Vec<T5Block>,
    shared: Arc<Embedding>,
    final_layer_norm: T5LayerNorm,
    device: Device,
    offloaded: bool,
}

impl T5Stack {
    fn load(
        decoder: bool,
        vb: ShardedVarBuilder,
        shared: &Arc<Embedding>,
        cfg: &Config,
        device: &Device,
        offloaded: bool,
    ) -> Result<Self> {
        let block = (0..cfg.num_layers)
            .map(|i| T5Block::load(i == 0, decoder, vb.pp(format!("block.{i}")), cfg))
            .collect::<Result<Vec<_>>>()?;
        let final_layer_norm = T5LayerNorm::load(
            cfg.d_model,
            cfg.layer_norm_epsilon,
            vb.pp("final_layer_norm").set_device(device.clone()),
        )?;
        Ok(Self {
            block,
            shared: shared.clone(),
            final_layer_norm,
            device: device.clone(),
            offloaded,
        })
    }

    fn forward(
        &mut self,
        input_ids: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
    ) -> Result<Tensor> {
        let input_embeds = self.shared.as_ref().forward(input_ids)?;
        let mut hidden_states = input_embeds;
        let mut position_bias = None;
        for block in self.block.iter_mut() {
            if self.offloaded {
                block.cast_to(&self.device)?;
            }
            (hidden_states, position_bias) = block.forward(
                &hidden_states,
                position_bias.as_ref(),
                encoder_hidden_states,
            )?;
            if self.offloaded {
                block.cast_to(&Device::Cpu)?;
            }
        }
        self.final_layer_norm.forward(&hidden_states)
    }
}

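// Encoder-only wrapper used as the text encoder for the diffusion pipelines. The shared
// token embedding is looked up under `shared`, `decoder.embed_tokens`, or
// `encoder.embed_tokens`, whichever exists in the checkpoint.
//
// Minimal usage sketch (the code that builds the `ShardedVarBuilder`, parses `Config`
// from config.json, and tokenizes the prompt is assumed and not shown; `token_ids` is a
// hypothetical Vec<u32> of tokenizer output):
//
//     let mut t5 = T5EncoderModel::load(vb, &cfg, &device, /* offloaded */ false)?;
//     let input_ids = Tensor::new(token_ids, &device)?.unsqueeze(0)?; // (1, seq_len)
//     let hidden = t5.forward(&input_ids)?; // (1, seq_len, cfg.d_model)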
#[derive(Debug, Clone)]
pub struct T5EncoderModel {
    encoder: T5Stack,
}

impl T5EncoderModel {
    pub fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        device: &Device,
        offloaded: bool,
    ) -> Result<Self> {
        let shared_vb = if vb.contains_tensor("shared.weight") {
            vb.pp("shared")
        } else if vb.contains_tensor("decoder.embed_tokens") {
            vb.pp("decoder").pp("embed_tokens")
        } else {
            vb.pp("encoder").pp("embed_tokens")
        };
        let shared = embedding(
            cfg.vocab_size,
            cfg.d_model,
            shared_vb.set_device(device.clone()),
            &None,
        )?;
        let shared = Arc::new(shared);
        let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg, device, offloaded)?;
        Ok(Self { encoder })
    }

    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        self.encoder.forward(input_ids, None)
    }
}