// mistralrs_quant/safetensors.rs

use candle_core::{DType, Device, Error, IndexOp, Result, Shape, Storage, Tensor, WithDType};
use candle_nn::var_builder::{Backend, SimpleBackend, VarBuilderArgs};
use float8::F8E4M3;
use regex::Regex;
use safetensors::tensor as st;
use safetensors::tensor::SafeTensors;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;

fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
    let size_in_bytes = T::DTYPE.size_in_bytes();
    let elem_count = data.len() / size_in_bytes;
    if (data.as_ptr() as usize) % size_in_bytes == 0 {
        // SAFETY: This is safe because we just checked that the pointer is
        // correctly aligned.
        let data: &[T] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
        Tensor::from_slice(data, shape, device)
    } else {
        // XXX: We need to specify `T` here, otherwise the compiler would infer u8 because of
        // the cast below, making this vector too small to fit the full f16/f32/f64 weights
        // and resulting in out-of-bounds accesses.
        let mut c: Vec<T> = Vec::with_capacity(elem_count);
        // SAFETY: We just created c, so the allocated memory is necessarily
        // contiguous and non-overlapping with the view's data.
        // We're downgrading the `c` pointer from T to u8, which removes alignment
        // constraints.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
            c.set_len(elem_count)
        }
        Tensor::from_slice(&c, shape, device)
    }
}

fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
    data: &[u8],
    shape: &[usize],
    device: &Device,
    conv: F,
) -> Result<Tensor> {
    let size_in_bytes = std::mem::size_of::<T>();
    let elem_count = data.len() / size_in_bytes;
    if (data.as_ptr() as usize) % size_in_bytes == 0 {
        // SAFETY: This is safe because we just checked that the pointer is
        // correctly aligned.
        let data: &[T] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
        let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
        Tensor::from_vec(data, shape, device)
    } else {
        // XXX: We need to specify `T` here, otherwise the compiler would infer u8 because of
        // the cast below, making this vector too small to fit the full f16/f32/f64 weights
        // and resulting in out-of-bounds accesses.
        let mut c: Vec<T> = Vec::with_capacity(elem_count);
        // SAFETY: We just created c, so the allocated memory is necessarily
        // contiguous and non-overlapping with the view's data.
        // We're downgrading the `c` pointer from T to u8, which removes alignment
        // constraints.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
            c.set_len(elem_count)
        }
        let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
        Tensor::from_vec(c, shape, device)
    }
}

fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
    view: &st::TensorView<'_>,
    device: &Device,
    conv: F,
) -> Result<Tensor> {
    convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
}

fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
    convert_slice::<T>(view.data(), view.shape(), device)
}

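/// Loads a safetensors [`st::TensorView`] into a candle [`Tensor`], optionally
/// casting to `dtype` while copying.
///
/// A minimal usage sketch (the file and tensor names here are hypothetical):
///
/// ```ignore
/// use candle_core::{DType, Device};
///
/// let st = unsafe { MmapedSafetensors::new("model.safetensors")? };
/// let view = st.get("model.embed_tokens.weight")?;
/// // Load as stored, or e.g. upcast BF16 -> F32 during the copy.
/// let tensor = view.load(&Device::Cpu, Some(DType::F32))?;
/// ```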
pub trait Load {
    fn load(&self, device: &Device, dtype: Option<DType>) -> Result<Tensor>;
}

impl Load for st::TensorView<'_> {
    fn load(&self, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
        convert(self, device, dtype)
    }
}

fn convert(
    view: &st::TensorView<'_>,
    device: &Device,
    cast_dtype: Option<DType>,
) -> Result<Tensor> {
    match (view.dtype(), cast_dtype) {
        (st::Dtype::BF16, Some(DType::F16)) => {
            let conv = |x: half::bf16| Ok(half::f16::from_f32(x.to_f32()));
            convert_with_cast_::<half::bf16, half::f16, _>(view, device, conv)
        }
        (st::Dtype::BF16, Some(DType::F32)) => {
            let conv = |x: half::bf16| Ok(x.to_f32());
            convert_with_cast_::<half::bf16, f32, _>(view, device, conv)
        }
        (st::Dtype::F16, Some(DType::BF16)) => {
            let conv = |x: half::f16| Ok(half::bf16::from_f32(x.to_f32()));
            convert_with_cast_::<half::f16, half::bf16, _>(view, device, conv)
        }
        (st::Dtype::F16, Some(DType::F32)) => {
            let conv = |x: half::f16| Ok(x.to_f32());
            convert_with_cast_::<half::f16, f32, _>(view, device, conv)
        }
        (st::Dtype::F32, Some(DType::BF16)) => {
            let conv = |x: f32| Ok(half::bf16::from_f32(x));
            convert_with_cast_::<f32, half::bf16, _>(view, device, conv)
        }
        (st::Dtype::F32, Some(DType::F16)) => {
            let conv = |x: f32| Ok(half::f16::from_f32(x));
            convert_with_cast_::<f32, half::f16, _>(view, device, conv)
        }

        (st::Dtype::U8, _) => convert_::<u8>(view, device),
        (st::Dtype::U16, _) => {
            let conv = |x| Ok(u32::from(x));
            convert_with_cast_::<u16, u32, _>(view, device, conv)
        }
        (st::Dtype::U32, _) => convert_::<u32>(view, device),
        (st::Dtype::I16, _) => convert_::<i16>(view, device),
        (st::Dtype::I32, _) => convert_::<i32>(view, device),
        (st::Dtype::I64, _) => convert_::<i64>(view, device),
        (st::Dtype::BF16, None | Some(DType::BF16)) => convert_::<half::bf16>(view, device),
        (st::Dtype::F16, None | Some(DType::F16)) => convert_::<half::f16>(view, device),
        (st::Dtype::F32, _) => convert_::<f32>(view, device),
        (st::Dtype::F64, _) => convert_::<f64>(view, device),
        (st::Dtype::F8_E4M3, _) => convert_::<F8E4M3>(view, device),
        (st::Dtype::F6_E2M3, _)
        | (st::Dtype::F6_E3M2, _)
        | (st::Dtype::F4, _)
        | (st::Dtype::F8_E8M0, _) => {
            // These packed dtypes have no per-element Rust representation, so
            // load them into a storage variant that keeps the raw bytes
            // alongside the correct dtype.
            convert_dummy(view, device)
        }
        (dtype, _) => Err(Error::UnsupportedSafeTensorDtype(dtype)),
    }
}

fn convert_dummy(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
    // For dummy types, create the appropriate storage variant that preserves
    // both the raw data and the correct dtype.
    let (dtype, _dtype_name) = match view.dtype() {
        st::Dtype::F6_E2M3 => (DType::F6E2M3, "F6_E2M3 (MX6)"),
        st::Dtype::F6_E3M2 => (DType::F6E3M2, "F6_E3M2 (MX6)"),
        st::Dtype::F4 => (DType::F4, "F4 (MX4)"),
        st::Dtype::F8_E8M0 => (DType::F8E8M0, "F8_E8M0"),
        _ => unreachable!("convert_dummy called with non-dummy dtype"),
    };

    // Load the raw bytes.
    let data = view.data();
    let shape = view.shape();

    // Create storage with the appropriate dummy type variant.
    let storage = match device {
        Device::Cpu => {
            let cpu_storage = match dtype {
                DType::F6E2M3 => candle_core::cpu_backend::CpuStorage::F6E2M3(data.to_vec()),
                DType::F6E3M2 => candle_core::cpu_backend::CpuStorage::F6E3M2(data.to_vec()),
                DType::F4 => candle_core::cpu_backend::CpuStorage::F4(data.to_vec()),
                DType::F8E8M0 => candle_core::cpu_backend::CpuStorage::F8E8M0(data.to_vec()),
                _ => unreachable!(),
            };
            Storage::Cpu(cpu_storage)
        }
        #[cfg(feature = "cuda")]
        Device::Cuda(device) => {
            let mut slice = unsafe { device.alloc::<u8>(data.len())? };
            device.memcpy_htod(data, &mut slice)?;

            let slice = match dtype {
                DType::F6E2M3 => candle_core::cuda_backend::CudaStorageSlice::F6E2M3(slice),
                DType::F6E3M2 => candle_core::cuda_backend::CudaStorageSlice::F6E3M2(slice),
                DType::F4 => candle_core::cuda_backend::CudaStorageSlice::F4(slice),
                DType::F8E8M0 => candle_core::cuda_backend::CudaStorageSlice::F8E8M0(slice),
                _ => unreachable!(),
            };
            let storage = candle_core::cuda_backend::CudaStorage {
                slice,
                device: device.clone(),
            };
            Storage::Cuda(storage)
        }
        #[cfg(not(feature = "cuda"))]
        Device::Cuda(_) => {
            return Err(Error::Msg("CUDA support not compiled".to_string()));
        }
        #[cfg(feature = "metal")]
        Device::Metal(device) => {
            let buffer = device.new_buffer_with_data(data)?;

            let storage = candle_core::metal_backend::MetalStorage::new(
                buffer,
                device.clone(),
                data.len(),
                dtype,
            );
            Storage::Metal(storage)
        }
        #[cfg(not(feature = "metal"))]
        Device::Metal(_) => {
            return Err(Error::Msg("Metal support not compiled".to_string()));
        }
    };

    Ok(Tensor::from((storage, shape)))
}

#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);

pub struct MmapedSafetensors {
    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
    routing: Option<HashMap<String, usize>>,
}

impl MmapedSafetensors {
    /// Creates a wrapper around a memory mapped file and deserializes the safetensors header.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`].
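    ///
    /// A minimal sketch (hypothetical file name):
    ///
    /// ```ignore
    /// let st = unsafe { MmapedSafetensors::new("model.safetensors")? };
    /// for (name, _view) in st.tensors() {
    ///     println!("{name}");
    /// }
    /// ```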
    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
        let p = p.as_ref();
        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
        let file = memmap2::MmapOptions::new()
            .map(&file)
            .map_err(|e| Error::from(e).with_path(p))?;
        let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
            file,
            |data: &[u8]| {
                let st = safetensors::SafeTensors::deserialize(data)
                    .map_err(|e| Error::from(e).with_path(p))?;
                Ok::<_, Error>(SafeTensors_(st))
            },
        )?;
        Ok(Self {
            safetensors: vec![safetensors],
            routing: None,
        })
    }

    /// Creates a wrapper around multiple memory mapped files and deserializes the safetensors headers.
    ///
    /// If a tensor name appears in multiple files, the entry from the last file wins.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`].
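    ///
    /// A minimal sketch (hypothetical shard names):
    ///
    /// ```ignore
    /// let st = unsafe {
    ///     MmapedSafetensors::multi(&[
    ///         "model-00001-of-00002.safetensors",
    ///         "model-00002-of-00002.safetensors",
    ///     ])?
    /// };
    /// ```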
    pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
        let mut routing = HashMap::new();
        let mut safetensors = vec![];
        for (index, p) in paths.iter().enumerate() {
            let p = p.as_ref();
            let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
            let file = memmap2::MmapOptions::new()
                .map(&file)
                .map_err(|e| Error::from(e).with_path(p))?;
            let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
                file,
                |data: &[u8]| {
                    let st = safetensors::SafeTensors::deserialize(data)
                        .map_err(|e| Error::from(e).with_path(p))?;
                    Ok::<_, Error>(SafeTensors_(st))
                },
            )?;
            for k in data.get().0.names() {
                routing.insert(k.to_string(), index);
            }
            safetensors.push(data)
        }
        Ok(Self {
            safetensors,
            routing: Some(routing),
        })
    }

    pub fn load(&self, name: &str, dev: &Device, dtype: Option<DType>) -> Result<Tensor> {
        self.get(name)?.load(dev, dtype)
    }

    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
        let mut tensors = vec![];
        for safetensors in self.safetensors.iter() {
            tensors.push(safetensors.get().0.tensors())
        }
        tensors.into_iter().flatten().collect()
    }

    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
        let index = match &self.routing {
            None => 0,
            Some(routing) => {
                let index = routing.get(name).ok_or_else(|| {
                    Error::CannotFindTensor {
                        path: name.to_string(),
                    }
                    .bt()
                })?;
                *index
            }
        };
        Ok(self.safetensors[index].get().0.tensor(name)?)
    }
}

impl SimpleBackend for MmapedSafetensors {
    fn get(
        &self,
        s: Shape,
        name: &str,
        _: candle_nn::Init,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor> {
        let tensor = self.get_unchecked(name, dtype, dev)?;
        if tensor.shape() != &s {
            Err(candle_core::Error::UnexpectedShape {
                msg: format!("shape mismatch for {name}"),
                expected: s,
                got: tensor.shape().clone(),
            }
            .bt())?
        }
        Ok(tensor)
    }

    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
        self.load(name, dev, Some(dtype))
    }

    fn contains_tensor(&self, name: &str) -> bool {
        self.get(name).is_ok()
    }
}

pub enum ShardedSafeTensors {
    Sharded {
        b: MmapedSafetensors,
        make_dummy_regexes: Option<Arc<Vec<Regex>>>,
        predicate: Arc<dyn Fn(String) -> bool + Send + Sync + 'static>,
    },
    SimpleBackend(Box<dyn SimpleBackend + 'static>),
}

pub type ShardedVarBuilder = VarBuilderArgs<'static, ShardedSafeTensors>;

impl ShardedSafeTensors {
    /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
    /// files and makes them usable in a sharded way.
    ///
    /// - If `make_dummy_regexes` is specified, any tensor whose name matches one of the
    ///   regexes is reported as missing.
    /// - Only tensors for which `predicate` evaluates to `true` are included.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`].
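    ///
    /// A minimal sketch (hypothetical path; keep every tensor, no dummies):
    ///
    /// ```ignore
    /// let vb = unsafe {
    ///     ShardedSafeTensors::sharded(
    ///         &["model.safetensors"],
    ///         DType::BF16,
    ///         &Device::Cpu,
    ///         None,
    ///         Arc::new(|_name| true),
    ///     )?
    /// };
    /// ```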
    pub unsafe fn sharded<P: AsRef<std::path::Path>>(
        paths: &[P],
        dtype: DType,
        dev: &Device,
        make_dummy_regexes: Option<Arc<Vec<Regex>>>,
        predicate: Arc<dyn Fn(String) -> bool + Send + Sync + 'static>,
    ) -> Result<ShardedVarBuilder> {
        let tensors = MmapedSafetensors::multi(paths)?;
        let backend = ShardedSafeTensors::Sharded {
            b: tensors,
            make_dummy_regexes,
            predicate,
        };
        Ok(VarBuilderArgs::new_with_args(backend, dtype, dev))
    }
}

impl ShardedSafeTensors {
    pub fn wrap(
        backend: Box<dyn SimpleBackend + 'static>,
        dtype: DType,
        dev: Device,
    ) -> ShardedVarBuilder {
        VarBuilderArgs::new_with_args(Self::SimpleBackend(backend), dtype, &dev)
    }
}

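/// A tensor shard: either an equal split along `dim` by `(rank, world_size)`,
/// or an explicit `offset..offset + len` window along `dim`.
///
/// A small sketch of what a shard selects (shapes are illustrative):
///
/// ```ignore
/// use candle_core::{DType, Device, Tensor};
///
/// let t = Tensor::zeros((1024, 1024), DType::F32, &Device::Cpu)?;
/// // Rank 1 of 2 along dim 0 keeps rows 512..1024.
/// let shard = Shard::Simple { dim: 0, rank: 1, world_size: 2 };
/// assert_eq!(shard.apply_to(&t)?.dims(), &[512, 1024]);
/// ```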
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Shard {
    Simple {
        dim: usize,
        rank: usize,
        world_size: usize,
    },
    Offset {
        dim: usize,
        offset: usize,
        len: usize,
    },
}

impl Shard {
    pub fn apply_to(&self, tensor: &Tensor) -> Result<Tensor> {
        match *self {
            Shard::Simple {
                dim,
                rank,
                world_size,
            } => {
                let size = tensor.dim(dim)?;
                let shape = tensor.dims().to_vec();

                if size % world_size != 0 {
                    return Err(Error::ShapeMismatchSplit {
                        shape: shape.into(),
                        dim,
                        n_parts: world_size,
                    });
                }
                let block_size = size / world_size;
                let start = rank * block_size;
                let stop = (rank + 1) * block_size;

                if dim == 0 {
                    tensor.i(start..stop)
                } else if dim == 1 {
                    tensor.i((.., start..stop))
                } else if dim == 2 {
                    tensor.i((.., .., start..stop))
                } else {
                    candle_core::bail!("Sharding is only supported along dimensions 0, 1, or 2")
                }
            }
            Shard::Offset { dim, offset, len } => {
                let start = offset;
                let stop = start + len;

                if dim == 0 {
                    tensor.i(start..stop)
                } else if dim == 1 {
                    tensor.i((.., start..stop))
                } else if dim == 2 {
                    tensor.i((.., .., start..stop))
                } else {
                    candle_core::bail!("Sharding is only supported along dimensions 0, 1, or 2")
                }
            }
        }
    }
}

impl Default for Shard {
    fn default() -> Self {
        Self::Simple {
            dim: 0,
            rank: 0,
            world_size: 1,
        }
    }
}

/// Get part of a tensor, typically used for Tensor Parallelism sharding.
///
/// Suppose the tensor is of size (1024, 1024):
///
/// - `dim` is the dimension to slice along,
/// - `rank` is the rank of the current process,
/// - `world_size` is the total number of ranks in the process group.
///
/// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`
/// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`
/// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`
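///
/// A hedged sketch of driving this through the var builder (assuming the
/// `get_with_hints` method from `VarBuilderArgs`; tensor name is hypothetical):
///
/// ```ignore
/// let weight = vb.get_with_hints(
///     (1024, 1024),
///     "model.layers.0.mlp.gate_proj.weight",
///     Shard::Simple { dim: 0, rank: 0, world_size: 2 },
/// )?; // yields the first 512 rows
/// ```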
impl Backend for ShardedSafeTensors {
    type Hints = Shard;

    fn get(
        &self,
        target_shape: Shape,
        path: &str,
        h: Self::Hints,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor> {
        if let Shard::Simple { world_size: 1, .. } = &h {
            // There is no sharding to be applied here so we use the default backend to speed
            // things up.
            match self {
                Self::Sharded {
                    b,
                    make_dummy_regexes,
                    predicate,
                } => {
                    if let Some(make_dummy_regexes) = make_dummy_regexes {
                        if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
                            return Err(Error::CannotFindTensor {
                                path: path.to_string(),
                            });
                        }
                    }
                    let should_include = predicate(path.to_string());
                    if !should_include {
                        return Err(Error::CannotFindTensor {
                            path: path.to_string(),
                        });
                    }

                    return SimpleBackend::get(
                        b,
                        target_shape,
                        path,
                        Default::default(),
                        dtype,
                        dev,
                    );
                }
                Self::SimpleBackend(b) => {
                    return SimpleBackend::get(
                        b.as_ref(),
                        target_shape,
                        path,
                        Default::default(),
                        dtype,
                        dev,
                    )
                }
            }
        }

        let result = match h {
            Shard::Simple {
                dim,
                rank,
                world_size,
            } => {
                match self {
                    Self::Sharded {
                        b,
                        make_dummy_regexes,
                        predicate,
                    } => {
                        use safetensors::slice::IndexOp;

                        if let Some(make_dummy_regexes) = make_dummy_regexes {
                            if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
                                return Err(Error::CannotFindTensor {
                                    path: path.to_string(),
                                });
                            }
                        }
                        let should_include = predicate(path.to_string());
                        if !should_include {
                            return Err(Error::CannotFindTensor {
                                path: path.to_string(),
                            });
                        }

                        let view = b.get(path)?;
                        let view_dtype = view.dtype();
                        let mut shape = view.shape().to_vec();
                        let size = shape[dim];

                        if size % world_size != 0 {
                            return Err(Error::ShapeMismatchSplit {
                                shape: shape.into(),
                                dim,
                                n_parts: world_size,
                            });
                        }
                        let block_size = size / world_size;
                        let start = rank * block_size;
                        let stop = (rank + 1) * block_size;

                        // Everything is expressed in tensor dimensions;
                        // byte offsets are handled automatically by safetensors.

                        let iterator = if dim == 0 {
                            view.slice(start..stop).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else if dim == 1 {
                            view.slice((.., start..stop)).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else if dim == 2 {
                            view.slice((.., .., start..stop)).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else {
                            candle_core::bail!("Sharding is only supported along dimensions 0, 1, or 2")
                        };

                        shape[dim] = block_size;

                        let view_dtype: DType = view_dtype.try_into()?;
                        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
                        Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)?
                    }
                    Self::SimpleBackend(b) => {
                        let tensor = b.get(target_shape, path, Default::default(), dtype, dev)?;
                        h.apply_to(&tensor)?
                    }
                }
            }
            Shard::Offset { dim, offset, len } => {
                match self {
                    Self::Sharded {
                        b,
                        make_dummy_regexes,
                        predicate,
                    } => {
                        use safetensors::slice::IndexOp;

                        if let Some(make_dummy_regexes) = make_dummy_regexes {
                            if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
                                return Err(Error::CannotFindTensor {
                                    path: path.to_string(),
                                });
                            }
                        }
                        let should_include = predicate(path.to_string());
                        if !should_include {
                            return Err(Error::CannotFindTensor {
                                path: path.to_string(),
                            });
                        }

                        let view = b.get(path)?;
                        let view_dtype = view.dtype();
                        let mut shape = view.shape().to_vec();

                        let start = offset;
                        let stop = start + len;

                        // Everything is expressed in tensor dimensions;
                        // byte offsets are handled automatically by safetensors.

                        let iterator = if dim == 0 {
                            view.slice(start..stop).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else if dim == 1 {
                            view.slice((.., start..stop)).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else if dim == 2 {
                            view.slice((.., .., start..stop)).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else {
                            candle_core::bail!("Sharding is only supported along dimensions 0, 1, or 2")
                        };

                        shape[dim] = len;

                        let view_dtype: DType = view_dtype.try_into()?;
                        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
                        Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)?
                    }
                    Self::SimpleBackend(b) => {
                        let tensor = b.get(target_shape, path, Default::default(), dtype, dev)?;
                        h.apply_to(&tensor)?
                    }
                }
            }
        };

        result.contiguous()
    }

    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
        match self {
            Self::Sharded {
                b,
                make_dummy_regexes,
                predicate,
            } => {
                if let Some(make_dummy_regexes) = make_dummy_regexes {
                    if make_dummy_regexes.iter().any(|x| x.is_match(name)) {
                        return Err(Error::CannotFindTensor {
                            path: name.to_string(),
                        });
                    }
                }
                let should_include = predicate(name.to_string());
                if !should_include {
                    return Err(Error::CannotFindTensor {
                        path: name.to_string(),
                    });
                }
                <MmapedSafetensors as SimpleBackend>::get_unchecked(b, name, dtype, dev)
            }
            Self::SimpleBackend(b) => b.as_ref().get_unchecked(name, dtype, dev),
        }
    }

    fn contains_tensor(&self, name: &str) -> bool {
        match self {
            Self::Sharded {
                b,
                make_dummy_regexes,
                predicate,
            } => {
                if let Some(make_dummy_regexes) = make_dummy_regexes {
                    if make_dummy_regexes.iter().any(|x| x.is_match(name)) {
                        return false;
                    }
                }
                let should_include = predicate(name.to_string());
                if !should_include {
                    return false;
                }
                b.get(name).is_ok()
            }
            Self::SimpleBackend(b) => b.as_ref().contains_tensor(name),
        }
    }
}
735}