mistralrs_quant/
safetensors.rs

1use candle_core::{
2    from_storage_no_op, DType, Device, Error, IndexOp, Result, Shape, Storage, Tensor, WithDType,
3};
4use candle_nn::var_builder::{Backend, SimpleBackend, VarBuilderArgs};
5use float8::F8E4M3;
6use regex::Regex;
7use safetensors::tensor as st;
8use safetensors::tensor::SafeTensors;
9use std::collections::HashMap;
10use std::path::Path;
11use std::sync::Arc;
12
13fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
14    let size_in_bytes = T::DTYPE.size_in_bytes();
15    let elem_count = data.len() / size_in_bytes;
16    if (data.as_ptr() as usize) % size_in_bytes == 0 {
17        // SAFETY This is safe because we just checked that this
18        // was correctly aligned.
19        let data: &[T] =
20            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
21        Tensor::from_slice(data, shape, device)
22    } else {
23        // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
24        // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
25        let mut c: Vec<T> = Vec::with_capacity(elem_count);
26        // SAFETY: We just created c, so the allocated memory is necessarily
27        // contiguous and non overlapping with the view's data.
28        // We're downgrading the `c` pointer from T to u8, which removes alignment
29        // constraints.
30        unsafe {
31            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
32            c.set_len(elem_count)
33        }
34        Tensor::from_slice(&c, shape, device)
35    }
36}
37
38fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
39    data: &[u8],
40    shape: &[usize],
41    device: &Device,
42    conv: F,
43) -> Result<Tensor> {
44    let size_in_bytes = std::mem::size_of::<T>();
45    let elem_count = data.len() / size_in_bytes;
46    if (data.as_ptr() as usize) % size_in_bytes == 0 {
47        // SAFETY This is safe because we just checked that this
48        // was correctly aligned.
49        let data: &[T] =
50            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
51        let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
52        Tensor::from_vec(data, shape, device)
53    } else {
54        // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
55        // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
56        let mut c: Vec<T> = Vec::with_capacity(elem_count);
57        // SAFETY: We just created c, so the allocated memory is necessarily
58        // contiguous and non overlapping with the view's data.
59        // We're downgrading the `c` pointer from T to u8, which removes alignment
60        // constraints.
61        unsafe {
62            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
63            c.set_len(elem_count)
64        }
65        let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
66        Tensor::from_vec(c, shape, device)
67    }
68}
69
70fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
71    view: &st::TensorView<'_>,
72    device: &Device,
73    conv: F,
74) -> Result<Tensor> {
75    convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
76}
77
78fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
79    convert_slice::<T>(view.data(), view.shape(), device)
80}
81
82pub trait Load {
83    fn load(&self, device: &Device, dtype: Option<DType>) -> Result<Tensor>;
84}
85
86impl Load for st::TensorView<'_> {
87    fn load(&self, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
88        convert(self, device, dtype)
89    }
90}
91
92fn convert(
93    view: &st::TensorView<'_>,
94    device: &Device,
95    cast_dtype: Option<DType>,
96) -> Result<Tensor> {
97    match (view.dtype(), cast_dtype) {
98        (st::Dtype::BF16, Some(DType::F16)) => {
99            let conv = |x: half::bf16| Ok(half::f16::from_f32(x.to_f32()));
100            convert_with_cast_::<half::bf16, half::f16, _>(view, device, conv)
101        }
102        (st::Dtype::BF16, Some(DType::F32)) => {
103            let conv = |x: half::bf16| Ok(x.to_f32());
104            convert_with_cast_::<half::bf16, f32, _>(view, device, conv)
105        }
106        (st::Dtype::F16, Some(DType::BF16)) => {
107            let conv = |x: half::f16| Ok(half::bf16::from_f32(x.to_f32()));
108            convert_with_cast_::<half::f16, half::bf16, _>(view, device, conv)
109        }
110        (st::Dtype::F16, Some(DType::F32)) => {
111            let conv = |x: half::f16| Ok(x.to_f32());
112            convert_with_cast_::<half::f16, f32, _>(view, device, conv)
113        }
114        (st::Dtype::F32, Some(DType::BF16)) => {
115            let conv = |x: f32| Ok(half::bf16::from_f32(x));
116            convert_with_cast_::<f32, half::bf16, _>(view, device, conv)
117        }
118        (st::Dtype::F32, Some(DType::F16)) => {
119            let conv = |x: f32| Ok(half::f16::from_f32(x));
120            convert_with_cast_::<f32, half::f16, _>(view, device, conv)
121        }
122
123        (st::Dtype::U8, _) => convert_::<u8>(view, device),
124        (st::Dtype::U16, _) => {
125            let conv = |x| Ok(u32::from(x));
126            convert_with_cast_::<u16, u32, _>(view, device, conv)
127        }
128        (st::Dtype::U32, _) => convert_::<u32>(view, device),
129        (st::Dtype::I16, _) => convert_::<i16>(view, device),
130        (st::Dtype::I32, _) => convert_::<i32>(view, device),
131        (st::Dtype::I64, _) => convert_::<i64>(view, device),
132        (st::Dtype::BF16, None | Some(DType::BF16)) => convert_::<half::bf16>(view, device),
133        (st::Dtype::F16, None | Some(DType::F16)) => convert_::<half::f16>(view, device),
134        (st::Dtype::F32, _) => convert_::<f32>(view, device),
135        (st::Dtype::F64, _) => convert_::<f64>(view, device),
136        (st::Dtype::F8_E4M3, _) => convert_::<F8E4M3>(view, device),
137        (st::Dtype::F6_E2M3, _)
138        | (st::Dtype::F6_E3M2, _)
139        | (st::Dtype::F4, _)
140        | (st::Dtype::F8_E8M0, _) => {
141            // For dummy types, we need to handle loading by creating a dummy tensor
142            // Since these types don't have actual data representation, we'll create
143            // a tensor that indicates it's a dummy type
144            convert_dummy(view, device)
145        }
146        (dtype, _) => Err(Error::UnsupportedSafeTensorDtype(dtype)),
147    }
148}
149
150fn convert_dummy(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
151    // For dummy types, we'll create the appropriate storage variant that preserves
152    // both the raw data and the correct dtype
153    let (dtype, _dtype_name) = match view.dtype() {
154        st::Dtype::F6_E2M3 => (DType::F6E2M3, "F6_E2M3 (MX6)"),
155        st::Dtype::F6_E3M2 => (DType::F6E3M2, "F6_E3M2 (MX6)"),
156        st::Dtype::F4 => (DType::F4, "F4 (MX4)"),
157        st::Dtype::F8_E8M0 => (DType::F8E8M0, "F8_E8M0"),
158        _ => unreachable!("convert_dummy called with non-dummy dtype"),
159    };
160
161    // Load the raw bytes
162    let data = view.data();
163    let shape = view.shape();
164
165    // Create storage with the appropriate dummy type variant
166    let storage = match device {
167        Device::Cpu => {
168            let cpu_storage = match dtype {
169                DType::F6E2M3 => candle_core::cpu_backend::CpuStorage::F6E2M3(data.to_vec()),
170                DType::F6E3M2 => candle_core::cpu_backend::CpuStorage::F6E3M2(data.to_vec()),
171                DType::F4 => candle_core::cpu_backend::CpuStorage::F4(data.to_vec()),
172                DType::F8E8M0 => candle_core::cpu_backend::CpuStorage::F8E8M0(data.to_vec()),
173                _ => unreachable!(),
174            };
175            Storage::Cpu(cpu_storage)
176        }
177        #[cfg(feature = "cuda")]
178        Device::Cuda(device) => {
179            let mut slice = unsafe { device.alloc::<u8>(data.len())? };
180            device.memcpy_htod(data, &mut slice)?;
181
182            let slice = match dtype {
183                DType::F6E2M3 => candle_core::cuda_backend::CudaStorageSlice::F6E2M3(slice),
184                DType::F6E3M2 => candle_core::cuda_backend::CudaStorageSlice::F6E3M2(slice),
185                DType::F4 => candle_core::cuda_backend::CudaStorageSlice::F4(slice),
186                DType::F8E8M0 => candle_core::cuda_backend::CudaStorageSlice::F8E8M0(slice),
187                _ => unreachable!(),
188            };
189            let storage = candle_core::cuda_backend::CudaStorage {
190                slice,
191                device: device.clone(),
192            };
193            Storage::Cuda(storage)
194        }
195        #[cfg(not(feature = "cuda"))]
196        Device::Cuda(_) => {
197            return Err(Error::Msg("CUDA support not compiled".to_string()));
198        }
199        #[cfg(feature = "metal")]
200        Device::Metal(device) => {
201            let buffer = device.new_buffer_with_data(data)?;
202
203            let storage = candle_core::metal_backend::MetalStorage::new(
204                buffer,
205                device.clone(),
206                data.len(),
207                dtype,
208            );
209            Storage::Metal(storage)
210        }
211        #[cfg(not(feature = "metal"))]
212        Device::Metal(_) => {
213            return Err(Error::Msg("Metal support not compiled".to_string()));
214        }
215    };
216
217    Ok(from_storage_no_op(storage, shape, false))
218}
219
220#[derive(yoke::Yokeable)]
221struct SafeTensors_<'a>(SafeTensors<'a>);
222
223pub struct MmapedSafetensors {
224    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
225    routing: Option<HashMap<String, usize>>,
226}
227
228impl MmapedSafetensors {
229    /// Creates a wrapper around a memory mapped file and deserialize the safetensors header.
230    ///
231    /// # Safety
232    ///
233    /// The unsafe is inherited from [`memmap2::MmapOptions`].
234    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
235        let p = p.as_ref();
236        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
237        let file = memmap2::MmapOptions::new()
238            .map(&file)
239            .map_err(|e| Error::from(e).with_path(p))?;
240        let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
241            file,
242            |data: &[u8]| {
243                let st = safetensors::SafeTensors::deserialize(data)
244                    .map_err(|e| Error::from(e).with_path(p))?;
245                Ok::<_, Error>(SafeTensors_(st))
246            },
247        )?;
248        Ok(Self {
249            safetensors: vec![safetensors],
250            routing: None,
251        })
252    }
253
254    /// Creates a wrapper around multiple memory mapped file and deserialize the safetensors headers.
255    ///
256    /// If a tensor name appears in multiple files, the last entry is returned.
257    ///
258    /// # Safety
259    ///
260    /// The unsafe is inherited from [`memmap2::MmapOptions`].
261    pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
262        let mut routing = HashMap::new();
263        let mut safetensors = vec![];
264        for (index, p) in paths.iter().enumerate() {
265            let p = p.as_ref();
266            let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
267            let file = memmap2::MmapOptions::new()
268                .map(&file)
269                .map_err(|e| Error::from(e).with_path(p))?;
270            let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
271                file,
272                |data: &[u8]| {
273                    let st = safetensors::SafeTensors::deserialize(data)
274                        .map_err(|e| Error::from(e).with_path(p))?;
275                    Ok::<_, Error>(SafeTensors_(st))
276                },
277            )?;
278            for k in data.get().0.names() {
279                routing.insert(k.to_string(), index);
280            }
281            safetensors.push(data)
282        }
283        Ok(Self {
284            safetensors,
285            routing: Some(routing),
286        })
287    }
288
289    pub fn load(&self, name: &str, dev: &Device, dtype: Option<DType>) -> Result<Tensor> {
290        self.get(name)?.load(dev, dtype)
291    }
292
293    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
294        let mut tensors = vec![];
295        for safetensors in self.safetensors.iter() {
296            tensors.push(safetensors.get().0.tensors())
297        }
298        tensors.into_iter().flatten().collect()
299    }
300
301    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
302        let index = match &self.routing {
303            None => 0,
304            Some(routing) => {
305                let index = routing.get(name).ok_or_else(|| {
306                    Error::CannotFindTensor {
307                        path: name.to_string(),
308                    }
309                    .bt()
310                })?;
311                *index
312            }
313        };
314        Ok(self.safetensors[index].get().0.tensor(name)?)
315    }
316}
317
318impl SimpleBackend for MmapedSafetensors {
319    fn get(
320        &self,
321        s: Shape,
322        name: &str,
323        _: candle_nn::Init,
324        dtype: DType,
325        dev: &Device,
326    ) -> Result<Tensor> {
327        let tensor = self.get_unchecked(name, dtype, dev)?;
328        if tensor.shape() != &s {
329            Err(candle_core::Error::UnexpectedShape {
330                msg: format!("shape mismatch for {name}"),
331                expected: s,
332                got: tensor.shape().clone(),
333            }
334            .bt())?
335        }
336        Ok(tensor)
337    }
338
339    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
340        self.load(name, dev, Some(dtype))
341    }
342
343    fn contains_tensor(&self, name: &str) -> bool {
344        self.get(name).is_ok()
345    }
346}
347
348pub enum ShardedSafeTensors {
349    Sharded {
350        b: MmapedSafetensors,
351        make_dummy_regexes: Option<Arc<Vec<Regex>>>,
352        predicate: Arc<dyn Fn(String) -> bool + Send + Sync + 'static>,
353    },
354    SimpleBackend(Box<dyn SimpleBackend + 'static>),
355}
356
357pub type ShardedVarBuilder = VarBuilderArgs<'static, ShardedSafeTensors>;
358
359impl ShardedSafeTensors {
360    /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
361    /// files and make them usable in a sharded way.
362    ///
363    /// - If `regexes` is specified, this will be used in `make_dummy_predicate` based on `.any`
364    /// - Only include keys for which predicate evaluates to true.
365    ///
366    /// # Safety
367    ///
368    /// The unsafe is inherited from [`memmap2::MmapOptions`].
369    pub unsafe fn sharded<P: AsRef<std::path::Path>>(
370        paths: &[P],
371        dtype: DType,
372        dev: &Device,
373        make_dummy_regexes: Option<Arc<Vec<Regex>>>,
374        predicate: Arc<dyn Fn(String) -> bool + Send + Sync + 'static>,
375    ) -> Result<ShardedVarBuilder> {
376        let tensors = MmapedSafetensors::multi(paths)?;
377        let backend = ShardedSafeTensors::Sharded {
378            b: tensors,
379            make_dummy_regexes,
380            predicate,
381        };
382        Ok(VarBuilderArgs::new_with_args(backend, dtype, dev))
383    }
384}
385
386impl ShardedSafeTensors {
387    pub fn wrap(
388        backend: Box<dyn SimpleBackend + 'static>,
389        dtype: DType,
390        dev: Device,
391    ) -> ShardedVarBuilder {
392        VarBuilderArgs::new_with_args(Self::SimpleBackend(backend), dtype, &dev)
393    }
394}
395
396#[derive(Debug, Clone, Copy, Eq, PartialEq)]
397pub enum Shard {
398    Simple {
399        dim: usize,
400        rank: usize,
401        world_size: usize,
402    },
403    Offset {
404        dim: usize,
405        offset: usize,
406        len: usize,
407    },
408}
409
410impl Shard {
411    pub fn apply_to(&self, tensor: &Tensor) -> Result<Tensor> {
412        match *self {
413            Shard::Simple {
414                dim,
415                rank,
416                world_size,
417            } => {
418                let size = tensor.dim(dim)?;
419                let shape = tensor.dims().to_vec();
420
421                if size % world_size != 0 {
422                    return Err(Error::ShapeMismatchSplit {
423                        shape: shape.into(),
424                        dim,
425                        n_parts: world_size,
426                    });
427                }
428                let block_size = size / world_size;
429                let start = rank * block_size;
430                let stop = (rank + 1) * block_size;
431
432                if dim == 0 {
433                    tensor.i(start..stop)
434                } else if dim == 1 {
435                    tensor.i((.., start..stop))
436                } else if dim == 2 {
437                    tensor.i((.., .., start..stop))
438                } else {
439                    candle_core::bail!("Got sharded on dimensions != 0 or 1 or 2")
440                }
441            }
442            Shard::Offset { dim, offset, len } => {
443                let start = offset;
444                let stop = start + len;
445
446                if dim == 0 {
447                    tensor.i(start..stop)
448                } else if dim == 1 {
449                    tensor.i((.., start..stop))
450                } else if dim == 2 {
451                    tensor.i((.., .., start..stop))
452                } else {
453                    candle_core::bail!("Got sharded on dimensions != 0 or 1 or 2")
454                }
455            }
456        }
457    }
458}
459
460impl Default for Shard {
461    fn default() -> Self {
462        Self::Simple {
463            dim: 0,
464            rank: 0,
465            world_size: 1,
466        }
467    }
468}
469
470/// Get part of a tensor, typically used to do Tensor Parallelism sharding.
471///
472/// If the tensor is of size (1024, 1024).
473///
474/// `dim` corresponds to the dimension to slice into
475/// `rank` is the rank of the current process
476/// `world_size` is the total number of ranks in the process group
477///
478/// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`
479/// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`
480/// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`
481impl Backend for ShardedSafeTensors {
482    type Hints = Shard;
483
484    fn get(
485        &self,
486        target_shape: Shape,
487        path: &str,
488        h: Self::Hints,
489        dtype: DType,
490        dev: &Device,
491    ) -> Result<Tensor> {
492        if let Shard::Simple { world_size: 1, .. } = &h {
493            // There is no sharding to be applied here so we use the default backend to speed
494            // things up.
495            match self {
496                Self::Sharded {
497                    b,
498                    make_dummy_regexes,
499                    predicate,
500                } => {
501                    if let Some(make_dummy_regexes) = make_dummy_regexes {
502                        if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
503                            return Err(Error::CannotFindTensor {
504                                path: path.to_string(),
505                            });
506                        }
507                    }
508                    let should_include = predicate(path.to_string());
509                    if !should_include {
510                        return Err(Error::CannotFindTensor {
511                            path: path.to_string(),
512                        });
513                    }
514
515                    return SimpleBackend::get(
516                        b,
517                        target_shape,
518                        path,
519                        Default::default(),
520                        dtype,
521                        dev,
522                    );
523                }
524                Self::SimpleBackend(b) => {
525                    return SimpleBackend::get(
526                        b.as_ref(),
527                        target_shape,
528                        path,
529                        Default::default(),
530                        dtype,
531                        dev,
532                    )
533                }
534            }
535        }
536
537        let result = match h {
538            Shard::Simple {
539                dim,
540                rank,
541                world_size,
542            } => {
543                match self {
544                    Self::Sharded {
545                        b,
546                        make_dummy_regexes,
547                        predicate,
548                    } => {
549                        use safetensors::slice::IndexOp;
550
551                        if let Some(make_dummy_regexes) = make_dummy_regexes {
552                            if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
553                                return Err(Error::CannotFindTensor {
554                                    path: path.to_string(),
555                                });
556                            }
557                        }
558                        let should_include = predicate(path.to_string());
559                        if !should_include {
560                            return Err(Error::CannotFindTensor {
561                                path: path.to_string(),
562                            });
563                        }
564
565                        let view = b.get(path)?;
566                        let view_dtype = view.dtype();
567                        let mut shape = view.shape().to_vec();
568                        let size = shape[dim];
569
570                        if size % world_size != 0 {
571                            return Err(Error::ShapeMismatchSplit {
572                                shape: shape.into(),
573                                dim,
574                                n_parts: world_size,
575                            });
576                        }
577                        let block_size = size / world_size;
578                        let start = rank * block_size;
579                        let stop = (rank + 1) * block_size;
580
581                        // Everything is expressed in tensor dimension
582                        // bytes offsets is handled automatically for safetensors.
583
584                        let iterator = if dim == 0 {
585                            view.slice(start..stop).map_err(|_| {
586                                Error::Msg(format!(
587                                    "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
588                                ))
589                            })?
590                        } else if dim == 1 {
591                            view.slice((.., start..stop)).map_err(|_| {
592                                Error::Msg(format!(
593                                    "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
594                                ))
595                            })?
596                        } else if dim == 2 {
597                            view.slice((.., .., start..stop)).map_err(|_| {
598                                Error::Msg(format!(
599                                    "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
600                                ))
601                            })?
602                        } else {
603                            candle_core::bail!("Got sharded on dimensions != 0 or 1 or 2")
604                        };
605
606                        shape[dim] = block_size;
607
608                        let view_dtype: DType = view_dtype.try_into()?;
609                        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
610                        Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)?
611                    }
612                    Self::SimpleBackend(b) => {
613                        let tensor = b.get(target_shape, path, Default::default(), dtype, dev)?;
614                        h.apply_to(&tensor)?
615                    }
616                }
617            }
618            Shard::Offset { dim, offset, len } => {
619                match self {
620                    Self::Sharded {
621                        b,
622                        make_dummy_regexes,
623                        predicate,
624                    } => {
625                        use safetensors::slice::IndexOp;
626
627                        if let Some(make_dummy_regexes) = make_dummy_regexes {
628                            if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
629                                return Err(Error::CannotFindTensor {
630                                    path: path.to_string(),
631                                });
632                            }
633                        }
634                        let should_include = predicate(path.to_string());
635                        if !should_include {
636                            return Err(Error::CannotFindTensor {
637                                path: path.to_string(),
638                            });
639                        }
640
641                        let view = b.get(path)?;
642                        let view_dtype = view.dtype();
643                        let mut shape = view.shape().to_vec();
644
645                        let start = offset;
646                        let stop = start + len;
647
648                        // Everything is expressed in tensor dimension
649                        // bytes offsets is handled automatically for safetensors.
650
651                        let iterator = if dim == 0 {
652                            view.slice(start..stop).map_err(|_| {
653                                Error::Msg(format!(
654                                    "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
655                                ))
656                            })?
657                        } else if dim == 1 {
658                            view.slice((.., start..stop)).map_err(|_| {
659                                Error::Msg(format!(
660                                    "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
661                                ))
662                            })?
663                        } else if dim == 2 {
664                            view.slice((.., .., start..stop)).map_err(|_| {
665                                Error::Msg(format!(
666                                    "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
667                                ))
668                            })?
669                        } else {
670                            candle_core::bail!("Got sharded on dimensions != 0 or 1 or 2")
671                        };
672
673                        shape[dim] = len;
674
675                        let view_dtype: DType = view_dtype.try_into()?;
676                        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
677                        Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)?
678                    }
679                    Self::SimpleBackend(b) => {
680                        let tensor = b.get(target_shape, path, Default::default(), dtype, dev)?;
681                        h.apply_to(&tensor)?
682                    }
683                }
684            }
685        };
686
687        result.contiguous()
688    }
689
690    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
691        match self {
692            Self::Sharded {
693                b,
694                make_dummy_regexes,
695                predicate,
696            } => {
697                if let Some(make_dummy_regexes) = make_dummy_regexes {
698                    if make_dummy_regexes.iter().any(|x| x.is_match(name)) {
699                        return Err(Error::CannotFindTensor {
700                            path: name.to_string(),
701                        });
702                    }
703                }
704                let should_include = predicate(name.to_string());
705                if !should_include {
706                    return Err(Error::CannotFindTensor {
707                        path: name.to_string(),
708                    });
709                }
710                <MmapedSafetensors as SimpleBackend>::get_unchecked(b, name, dtype, dev)
711            }
712            Self::SimpleBackend(b) => b.as_ref().get_unchecked(name, dtype, dev),
713        }
714    }
715
716    fn contains_tensor(&self, name: &str) -> bool {
717        match self {
718            Self::Sharded {
719                b,
720                make_dummy_regexes,
721                predicate,
722            } => {
723                if let Some(make_dummy_regexes) = make_dummy_regexes {
724                    if make_dummy_regexes.iter().any(|x| x.is_match(name)) {
725                        return false;
726                    }
727                }
728                let should_include = predicate(name.to_string());
729                if !should_include {
730                    return false;
731                }
732                b.get(name).is_ok()
733            }
734            Self::SimpleBackend(b) => b.as_ref().contains_tensor(name),
735        }
736    }
737}