#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, Embedding, Linear};
use mistralrs_quant::ShardedVarBuilder;
use serde::Deserialize;
use std::sync::Arc;

use crate::layers::{clamp_for_f16, embedding, linear_no_bias, MatMul};

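// Serde default helpers for optional fields in the T5 configuration JSON.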
fn default_relative_attention_max_distance() -> usize {
    128
}

fn default_is_decoder() -> bool {
    false
}

fn default_use_cache() -> bool {
    true
}

fn default_tie_word_embeddings() -> bool {
    true
}

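/// Builds a `(size, size)` upper-triangular mask (1 where the key index is
/// ahead of the query index), used as a causal mask for decoder self-attention.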
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<_> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (size, size), device)
}

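/// Writes `on_true` into `on_false` wherever `mask` is nonzero; used to set
/// masked attention scores to negative infinity before the softmax.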
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}

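/// Feed-forward activation plus a flag indicating whether the checkpoint uses
/// the gated ("GLU"-style) feed-forward layout.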
#[derive(Debug, Deserialize, Default, Clone, PartialEq)]
pub struct ActivationWithOptionalGating {
    pub gated: bool,
    pub activation: candle_nn::Activation,
}

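/// Deserializes the `feed_forward_proj` config string: `"gated-gelu"` and
/// `"gated-silu"` map to a gated activation, anything else is parsed as a
/// plain (non-gated) activation name.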
pub fn deserialize_feed_forward_proj_activation<'de, D>(
    deserializer: D,
) -> std::result::Result<ActivationWithOptionalGating, D::Error>
where
    D: serde::de::Deserializer<'de>,
{
    match String::deserialize(deserializer)?.as_str() {
        "gated-gelu" => Ok(ActivationWithOptionalGating {
            gated: true,
            activation: candle_nn::Activation::NewGelu,
        }),
        "gated-silu" => Ok(ActivationWithOptionalGating {
            gated: true,
            activation: candle_nn::Activation::Silu,
        }),
        buf => {
            let activation = serde_plain::from_str(buf).map_err(serde::de::Error::custom)?;
            Ok(ActivationWithOptionalGating {
                gated: false,
                activation,
            })
        }
    }
}

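/// T5 model configuration, mirroring the fields of the Hugging Face
/// `config.json` for T5-family checkpoints.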
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub d_model: usize,
    pub d_kv: usize,
    pub d_ff: usize,
    pub num_layers: usize,
    pub num_decoder_layers: Option<usize>,
    pub num_heads: usize,
    pub relative_attention_num_buckets: usize,
    #[serde(default = "default_relative_attention_max_distance")]
    pub relative_attention_max_distance: usize,
    pub dropout_rate: f64,
    pub layer_norm_epsilon: f64,
    pub initializer_factor: f64,
    #[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
    pub feed_forward_proj: ActivationWithOptionalGating,
    #[serde(default = "default_tie_word_embeddings")]
    pub tie_word_embeddings: bool,
    #[serde(default = "default_is_decoder")]
    pub is_decoder: bool,
    pub is_encoder_decoder: bool,
    #[serde(default = "default_use_cache")]
    pub use_cache: bool,
    pub pad_token_id: usize,
    pub eos_token_id: usize,
    pub decoder_start_token_id: Option<usize>,
}

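// These defaults correspond to the t5-small hyperparameters.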
impl Default for Config {
    fn default() -> Self {
        Self {
            vocab_size: 32128,
            d_model: 512,
            d_kv: 64,
            d_ff: 2048,
            num_layers: 6,
            num_decoder_layers: None,
            num_heads: 8,
            relative_attention_num_buckets: 32,
            relative_attention_max_distance: 128,
            dropout_rate: 0.1,
            layer_norm_epsilon: 1e-6,
            initializer_factor: 1.0,
            feed_forward_proj: ActivationWithOptionalGating {
                gated: false,
                activation: Activation::Relu,
            },
            tie_word_embeddings: true,
            is_decoder: false,
            is_encoder_decoder: true,
            use_cache: true,
            pad_token_id: 0,
            eos_token_id: 1,
            decoder_start_token_id: Some(0),
        }
    }
}

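/// T5-style layer norm: RMS normalization over the last dimension with a
/// learned scale, no mean subtraction and no bias. The reduction is done in
/// f32 before casting back to the input dtype.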
#[derive(Debug, Clone)]
struct T5LayerNorm {
    weight: Tensor,
    variance_epsilon: f64,
}

impl T5LayerNorm {
    fn load(h: usize, eps: f64, vb: ShardedVarBuilder) -> Result<Self> {
        let weight = vb.get(h, "weight")?;
        Ok(Self {
            weight,
            variance_epsilon: eps,
        })
    }
}

impl Module for T5LayerNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let dtype = xs.dtype();
        let xs_f32 = xs.to_dtype(DType::F32)?;
        let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
        let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
        let xs = xs.to_dtype(dtype)?;
        let xs = xs.broadcast_mul(&self.weight)?;
        Ok(xs)
    }
}

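/// Non-gated feed-forward block: `wo(act(wi(x)))`, with ReLU as the activation.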
#[derive(Debug, Clone)]
struct T5DenseActDense {
    wi: Linear,
    wo: Linear,
    act: Activation,
}

impl T5DenseActDense {
    fn load(vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
        let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
        Ok(Self {
            wi,
            wo,
            act: Activation::Relu,
        })
    }
}

impl Module for T5DenseActDense {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.wi.forward(xs)?;
        let xs = self.act.forward(&xs)?;
        let xs = self.wo.forward(&xs)?;
        Ok(xs)
    }
}

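/// Gated feed-forward block: `wo(act(wi_0(x)) * wi_1(x))`, used when the
/// config requests a gated activation (T5 v1.1 / FLAN-style checkpoints).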
#[derive(Debug, Clone)]
struct T5DenseGatedActDense {
    wi_0: Linear,
    wi_1: Linear,
    wo: Linear,
    act: Activation,
}

impl T5DenseGatedActDense {
    fn load(vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
        let wi_0 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
        let wi_1 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
        Ok(Self {
            wi_0,
            wi_1,
            wo,
            act: cfg.feed_forward_proj.activation,
        })
    }
}

impl Module for T5DenseGatedActDense {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
        let hidden_linear = self.wi_1.forward(xs)?;
        let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
        let xs = self.wo.forward(&xs)?;
        Ok(xs)
    }
}

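/// Pre-norm feed-forward layer with a residual connection. Holds either the
/// plain or the gated dense block; `cast_to` moves its weights between
/// devices when layer offloading is enabled.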
#[derive(Debug, Clone)]
struct T5LayerFF {
    dense_act: Option<T5DenseActDense>,
    gated_dense_act: Option<T5DenseGatedActDense>,
    layer_norm: T5LayerNorm,
}

impl T5LayerFF {
    fn load(vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
        let layer_norm =
            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
            (
                None,
                Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
            )
        } else {
            (
                Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?),
                None,
            )
        };
        Ok(Self {
            dense_act,
            gated_dense_act,
            layer_norm,
        })
    }

    fn cast_to(&mut self, device: &Device) -> Result<()> {
        self.layer_norm = T5LayerNorm {
            weight: self.layer_norm.weight.to_device(device)?,
            variance_epsilon: self.layer_norm.variance_epsilon,
        };
        if let Some(dense) = &mut self.dense_act {
            dense.wi = Linear::new(
                dense.wi.weight().to_device(device)?,
                dense.wi.bias().map(|x| x.to_device(device).unwrap()),
            );
            dense.wo = Linear::new(
                dense.wo.weight().to_device(device)?,
                dense.wo.bias().map(|x| x.to_device(device).unwrap()),
            );
        }
        if let Some(dense) = &mut self.gated_dense_act {
            dense.wi_0 = Linear::new(
                dense.wi_0.weight().to_device(device)?,
                dense.wi_0.bias().map(|x| x.to_device(device).unwrap()),
            );
            dense.wi_1 = Linear::new(
                dense.wi_1.weight().to_device(device)?,
                dense.wi_1.bias().map(|x| x.to_device(device).unwrap()),
            );
            dense.wo = Linear::new(
                dense.wo.weight().to_device(device)?,
                dense.wo.bias().map(|x| x.to_device(device).unwrap()),
            );
        }
        Ok(())
    }
}

impl Module for T5LayerFF {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let ys = self.layer_norm.forward(xs)?;
        let ys = match &self.dense_act {
            Some(dense_act) => dense_act.forward(&ys)?,
            None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?,
        };
        let xs = (xs + ys)?;
        Ok(xs)
    }
}

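/// Multi-head attention with T5's relative position bias. Only the first
/// layer of each stack owns the bias embedding; later layers reuse the bias
/// tensor passed back from the first layer. Note that T5 does not scale the
/// attention scores by `1/sqrt(d_kv)`.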
#[derive(Debug, Clone)]
struct T5Attention {
    q: Linear,
    k: Linear,
    v: Linear,
    o: Linear,
    n_heads: usize,
    d_kv: usize,
    relative_attention_bias: Option<Embedding>,
    relative_attention_num_buckets: usize,
    relative_attention_max_distance: usize,
    inner_dim: usize,
    use_cache: bool,
}

impl T5Attention {
    fn load(
        has_relative_attention_bias: bool,
        decoder: bool,
        vb: ShardedVarBuilder,
        cfg: &Config,
    ) -> Result<Self> {
        let inner_dim = cfg.num_heads * cfg.d_kv;
        let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
        let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
        let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
        let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
        let relative_attention_bias = if has_relative_attention_bias {
            let emb = embedding(
                cfg.relative_attention_num_buckets,
                cfg.num_heads,
                vb.pp("relative_attention_bias"),
                &None,
            )?;
            Some(emb)
        } else {
            None
        };
        Ok(Self {
            q,
            k,
            v,
            o,
            n_heads: cfg.num_heads,
            d_kv: cfg.d_kv,
            relative_attention_bias,
            relative_attention_num_buckets: cfg.relative_attention_num_buckets,
            relative_attention_max_distance: cfg.relative_attention_max_distance,
            inner_dim,
            use_cache: cfg.use_cache && decoder,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        position_bias: Option<&Tensor>,
        key_value_states: Option<&Tensor>,
        mask: Option<&Tensor>,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let kv_input = match key_value_states {
            None => xs,
            Some(key_value_states) => key_value_states,
        };
        let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
        let kv_len = kv_input.dim(1)?;
        let q = self.q.forward(xs)?;
        let k = self.k.forward(kv_input)?;
        let v = self.v.forward(kv_input)?;
        let q = q
            .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
            .transpose(1, 2)?
            .contiguous()?;
        let k = k
            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
            .transpose(1, 2)?;

        let k = k.contiguous()?;
        let v = v.contiguous()?;
        let scores = MatMul.matmul(&q, &k.t()?)?;
        let scores = match mask {
            None => scores,
            Some(mask) => masked_fill(
                &scores,
                &mask
                    .unsqueeze(0)?
                    .unsqueeze(0)?
                    .repeat((b_sz, self.n_heads))?,
                f32::NEG_INFINITY,
            )?,
        };

        let (scores, position_bias) = match position_bias {
            Some(position_bias) => (
                scores.broadcast_add(position_bias)?,
                Some(position_bias.clone()),
            ),
            None => match &self.relative_attention_bias {
                None => (scores, None),
                Some(relative_attention_bias) => {
                    let kv_len = k.dim(2)?;
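                    // Compute the relative position bucket for every (query, key) pair:
                    // half the buckets cover each direction, small offsets get an exact
                    // bucket, and larger offsets are binned logarithmically up to
                    // `relative_attention_max_distance`. With `use_cache`, only the
                    // trailing `q_len` query positions are generated.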
                    let (q_start, q_end) = match self.use_cache {
                        true => ((kv_len - q_len) as u32, kv_len as u32),
                        false => (0_u32, kv_len as u32),
                    };
                    let num_buckets = self.relative_attention_num_buckets as u32 / 2;
                    let max_exact = num_buckets / 2;
                    let relative_position = (q_start..q_end)
                        .map(|i| {
                            (0..kv_len as u32)
                                .map(|j| {
                                    if i < j {
                                        if j - i < max_exact {
                                            j - i + num_buckets
                                        } else {
                                            let b = f32::log(
                                                (j - i) as f32 / max_exact as f32,
                                                self.relative_attention_max_distance as f32
                                                    / max_exact as f32,
                                            ) * (num_buckets - max_exact) as f32;
                                            u32::min(
                                                max_exact + num_buckets + b as u32,
                                                self.relative_attention_num_buckets as u32 - 1,
                                            )
                                        }
                                    } else if i - j < max_exact {
                                        i - j
                                    } else {
                                        let b = f32::log(
                                            (i - j) as f32 / max_exact as f32,
                                            self.relative_attention_max_distance as f32
                                                / max_exact as f32,
                                        ) * (num_buckets - max_exact) as f32;
                                        u32::min(max_exact + b as u32, num_buckets - 1)
                                    }
                                })
                                .collect::<Vec<u32>>()
                        })
                        .collect::<Vec<Vec<_>>>();
                    let relative_buckets = Tensor::new(relative_position, q.device())?;
                    let position_bias = relative_attention_bias
                        .forward(&relative_buckets)?
                        .permute((2, 0, 1))?
                        .unsqueeze(0)?;
                    (scores.broadcast_add(&position_bias)?, Some(position_bias))
                }
            },
        };

        let attn_weights = candle_nn::ops::softmax_last_dim(&scores)?;
        let attn_output = MatMul.matmul(&attn_weights, &v)?;
        let attn_output = attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.inner_dim))?;
        let attn_output = self.o.forward(&attn_output)?;
        Ok((attn_output, position_bias))
    }
}

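/// Self-attention sub-layer: pre-layer-norm, attention, then residual add.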
#[derive(Debug, Clone)]
struct T5LayerSelfAttention {
    self_attention: T5Attention,
    layer_norm: T5LayerNorm,
}

impl T5LayerSelfAttention {
    fn load(h: bool, d: bool, vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
        let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?;
        let layer_norm =
            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
        Ok(Self {
            self_attention,
            layer_norm,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        position_bias: Option<&Tensor>,
        mask: Option<&Tensor>,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let normed_xs = self.layer_norm.forward(xs)?;
        let (ys, position_bias) =
            self.self_attention
                .forward(&normed_xs, position_bias, None, mask)?;
        let ys = (xs + ys)?;
        Ok((ys, position_bias))
    }

    fn cast_to(&mut self, device: &Device) -> Result<()> {
        self.self_attention.q = Linear::new(
            self.self_attention.q.weight().to_device(device)?,
            self.self_attention
                .q
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        self.self_attention.k = Linear::new(
            self.self_attention.k.weight().to_device(device)?,
            self.self_attention
                .k
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        self.self_attention.v = Linear::new(
            self.self_attention.v.weight().to_device(device)?,
            self.self_attention
                .v
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        self.self_attention.o = Linear::new(
            self.self_attention.o.weight().to_device(device)?,
            self.self_attention
                .o
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        if let Some(embed) = &mut self.self_attention.relative_attention_bias {
            *embed = Embedding::new(embed.embeddings().to_device(device)?, embed.hidden_size());
        }
        self.layer_norm = T5LayerNorm {
            weight: self.layer_norm.weight.to_device(device)?,
            variance_epsilon: self.layer_norm.variance_epsilon,
        };
        Ok(())
    }
}

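/// Cross-attention sub-layer (decoder only): queries come from the decoder
/// hidden states, keys and values from the encoder output.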
#[derive(Debug, Clone)]
struct T5LayerCrossAttention {
    cross_attention: T5Attention,
    layer_norm: T5LayerNorm,
}

impl T5LayerCrossAttention {
    fn load(decoder: bool, vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
        let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?;
        let layer_norm =
            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
        Ok(Self {
            cross_attention,
            layer_norm,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        position_bias: Option<&Tensor>,
        key_value_states: &Tensor,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
        let (ys, position_bias) = self.cross_attention.forward(
            &normed_hidden_states,
            position_bias,
            Some(key_value_states),
            None,
        )?;
        let ys = (hidden_states + ys)?;
        Ok((ys, position_bias))
    }

    fn cast_to(&mut self, device: &Device) -> Result<()> {
        self.cross_attention.q = Linear::new(
            self.cross_attention.q.weight().to_device(device)?,
            self.cross_attention
                .q
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        self.cross_attention.k = Linear::new(
            self.cross_attention.k.weight().to_device(device)?,
            self.cross_attention
                .k
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        self.cross_attention.v = Linear::new(
            self.cross_attention.v.weight().to_device(device)?,
            self.cross_attention
                .v
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        self.cross_attention.o = Linear::new(
            self.cross_attention.o.weight().to_device(device)?,
            self.cross_attention
                .o
                .bias()
                .map(|x| x.to_device(device).unwrap()),
        );
        if let Some(embed) = &mut self.cross_attention.relative_attention_bias {
            *embed = Embedding::new(embed.embeddings().to_device(device)?, embed.hidden_size());
        }
        self.layer_norm = T5LayerNorm {
            weight: self.layer_norm.weight.to_device(device)?,
            variance_epsilon: self.layer_norm.variance_epsilon,
        };
        Ok(())
    }
}

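/// One transformer block: self-attention, optional cross-attention (decoder),
/// and the feed-forward layer. Outputs are clamped in f16 to avoid overflow.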
#[derive(Debug, Clone)]
struct T5Block {
    self_attn: T5LayerSelfAttention,
    cross_attn: Option<T5LayerCrossAttention>,
    ff: T5LayerFF,
}

impl T5Block {
    fn load(
        has_relative_attention_bias: bool,
        decoder: bool,
        vb: ShardedVarBuilder,
        cfg: &Config,
    ) -> Result<Self> {
        let vb = vb.pp("layer");
        let self_attn =
            T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?;
        let cross_attn = if cfg.is_decoder {
            Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?)
        } else {
            None
        };
        let ff_i = if cross_attn.is_some() { 2 } else { 1 };
        let ff = T5LayerFF::load(vb.pp(ff_i.to_string()), cfg)?;
        Ok(Self {
            self_attn,
            cross_attn,
            ff,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        position_bias: Option<&Tensor>,
        encoder_hidden_states: Option<&Tensor>,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let mask = match self.cross_attn.is_some() {
            true => {
                let mask_len = xs.dim(1)?;
                if mask_len <= 1 {
                    None
                } else {
                    Some(get_mask(mask_len, xs.device())?)
                }
            }
            false => None,
        };
        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
        if xs.dtype() == DType::F16 {
            xs = clamp_for_f16(&xs)?;
        }
        if let Some(cross_attn) = &self.cross_attn {
            (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
            if xs.dtype() == DType::F16 {
                xs = clamp_for_f16(&xs)?;
            }
        }
        let mut xs = self.ff.forward(&xs)?;
        if xs.dtype() == DType::F16 {
            xs = clamp_for_f16(&xs)?;
        }
        Ok((xs, position_bias))
    }

    fn cast_to(&mut self, device: &Device) -> Result<()> {
        self.self_attn.cast_to(device)?;
        if let Some(cross_attn) = &mut self.cross_attn {
            cross_attn.cast_to(device)?;
        }
        self.ff.cast_to(device)?;
        Ok(())
    }
}

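/// A stack of `T5Block`s sharing the token embedding table. When `offloaded`
/// is set, each block is moved to the target device right before it runs and
/// back to the CPU afterwards to bound device memory usage.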
#[derive(Debug, Clone)]
struct T5Stack {
    block: Vec<T5Block>,
    shared: Arc<Embedding>,
    final_layer_norm: T5LayerNorm,
    device: Device,
    offloaded: bool,
}

impl T5Stack {
    fn load(
        decoder: bool,
        vb: ShardedVarBuilder,
        shared: &Arc<Embedding>,
        cfg: &Config,
        device: &Device,
        offloaded: bool,
    ) -> Result<Self> {
        let block = (0..cfg.num_layers)
            .map(|i| T5Block::load(i == 0, decoder, vb.pp(format!("block.{i}")), cfg))
            .collect::<Result<Vec<_>>>()?;
        let final_layer_norm = T5LayerNorm::load(
            cfg.d_model,
            cfg.layer_norm_epsilon,
            vb.pp("final_layer_norm").set_device(device.clone()),
        )?;
        Ok(Self {
            block,
            shared: shared.clone(),
            final_layer_norm,
            device: device.clone(),
            offloaded,
        })
    }

    fn forward(
        &mut self,
        input_ids: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
    ) -> Result<Tensor> {
        let input_embeds = self.shared.as_ref().forward(input_ids)?;
        let mut hidden_states = input_embeds;
        let mut position_bias = None;
        for block in self.block.iter_mut() {
            if self.offloaded {
                block.cast_to(&self.device)?;
            }
            (hidden_states, position_bias) = block.forward(
                &hidden_states,
                position_bias.as_ref(),
                encoder_hidden_states,
            )?;
            if self.offloaded {
                block.cast_to(&Device::Cpu)?;
            }
        }
        self.final_layer_norm.forward(&hidden_states)
    }
}

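/// Encoder-only T5 model. The shared token embedding is taken from
/// `shared.weight` when present, otherwise from the decoder or encoder
/// `embed_tokens` weights.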
#[derive(Debug, Clone)]
pub struct T5EncoderModel {
    encoder: T5Stack,
}

impl T5EncoderModel {
    pub fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        device: &Device,
        offloaded: bool,
    ) -> Result<Self> {
        let shared_vb = if vb.contains_tensor("shared.weight") {
            vb.pp("shared")
        } else if vb.contains_tensor("decoder.embed_tokens") {
            vb.pp("decoder").pp("embed_tokens")
        } else {
            vb.pp("encoder").pp("embed_tokens")
        };
        let shared = embedding(
            cfg.vocab_size,
            cfg.d_model,
            shared_vb.set_device(device.clone()),
            &None,
        )?;
        let shared = Arc::new(shared);
        let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg, device, offloaded)?;
        Ok(Self { encoder })
    }

    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        self.encoder.forward(input_ids, None)
    }
}