mistralrs_quant/safetensors.rs

use candle_core::{DType, Device, Error, Result, Shape, Tensor, WithDType};
use candle_nn::var_builder::{Backend, SimpleBackend, VarBuilderArgs};
use float8::F8E4M3;
use regex::Regex;
use safetensors::tensor as st;
use safetensors::tensor::SafeTensors;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;

fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
    let size_in_bytes = T::DTYPE.size_in_bytes();
    let elem_count = data.len() / size_in_bytes;
    if (data.as_ptr() as usize) % size_in_bytes == 0 {
        // SAFETY: This is safe because we just checked that the data
        // was correctly aligned.
        let data: &[T] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
        Tensor::from_slice(data, shape, device)
    } else {
        // XXX: We need to specify `T` here, otherwise the compiler would infer u8 because of
        // the following cast, making this vector too small to fit full f16/f32/f64 weights
        // and resulting in out-of-bounds accesses.
        let mut c: Vec<T> = Vec::with_capacity(elem_count);
        // SAFETY: We just created c, so the allocated memory is necessarily
        // contiguous and non-overlapping with the view's data.
        // We're downgrading the `c` pointer from T to u8, which removes alignment
        // constraints.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
            c.set_len(elem_count)
        }
        Tensor::from_slice(&c, shape, device)
    }
}
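
// A minimal test sketch (not part of the original file) exercising `convert_slice`:
// serialize a few f32 values to raw little-endian bytes and check the round trip.
// Whether the aligned fast path or the copying fallback runs depends on the buffer's
// address; both must produce the same tensor.
#[cfg(test)]
mod convert_slice_example {
    use super::*;

    #[test]
    fn f32_round_trip() -> Result<()> {
        let values = [1.0f32, 2.0, 3.0, 4.0];
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        let t = convert_slice::<f32>(&bytes, &[2, 2], &Device::Cpu)?;
        assert_eq!(t.flatten_all()?.to_vec1::<f32>()?, values);
        Ok(())
    }
}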

fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
    data: &[u8],
    shape: &[usize],
    device: &Device,
    conv: F,
) -> Result<Tensor> {
    let size_in_bytes = std::mem::size_of::<T>();
    let elem_count = data.len() / size_in_bytes;
    if (data.as_ptr() as usize) % size_in_bytes == 0 {
        // SAFETY: This is safe because we just checked that the data
        // was correctly aligned.
        let data: &[T] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
        let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
        Tensor::from_vec(data, shape, device)
    } else {
        // XXX: We need to specify `T` here, otherwise the compiler would infer u8 because of
        // the following cast, making this vector too small to fit full f16/f32/f64 weights
        // and resulting in out-of-bounds accesses.
        let mut c: Vec<T> = Vec::with_capacity(elem_count);
        // SAFETY: We just created c, so the allocated memory is necessarily
        // contiguous and non-overlapping with the view's data.
        // We're downgrading the `c` pointer from T to u8, which removes alignment
        // constraints.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
            c.set_len(elem_count)
        }
        let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
        Tensor::from_vec(c, shape, device)
    }
}
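
// A small test sketch (illustrative addition) of the casting variant: u16 is not a
// candle dtype, so the raw bytes are widened element-wise to u32 through the `conv`
// closure, mirroring the U16 arm of `convert` below.
#[cfg(test)]
mod convert_slice_with_cast_example {
    use super::*;

    #[test]
    fn u16_widens_to_u32() -> Result<()> {
        let values = [7u16, 513];
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        let conv = |x: u16| Ok(u32::from(x));
        let t = convert_slice_with_cast::<u16, u32, _>(&bytes, &[2], &Device::Cpu, conv)?;
        assert_eq!(t.to_vec1::<u32>()?, vec![7u32, 513]);
        Ok(())
    }
}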

fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
    view: &st::TensorView<'_>,
    device: &Device,
    conv: F,
) -> Result<Tensor> {
    convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
}

fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
    convert_slice::<T>(view.data(), view.shape(), device)
}

/// Load a safetensors [`st::TensorView`] onto a device as a candle [`Tensor`],
/// optionally casting floating-point data to `dtype` on the way in.
pub trait Load {
    fn load(&self, device: &Device, dtype: Option<DType>) -> Result<Tensor>;
}

impl Load for st::TensorView<'_> {
    fn load(&self, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
        convert(self, device, dtype)
    }
}

fn convert(
    view: &st::TensorView<'_>,
    device: &Device,
    cast_dtype: Option<DType>,
) -> Result<Tensor> {
    match (view.dtype(), cast_dtype) {
        (st::Dtype::U8, _) => convert_::<u8>(view, device),
        (st::Dtype::U16, _) => {
            let conv = |x| Ok(u32::from(x));
            convert_with_cast_::<u16, u32, _>(view, device, conv)
        }
        (st::Dtype::U32, _) => convert_::<u32>(view, device),
        (st::Dtype::I16, _) => convert_::<i16>(view, device),
        (st::Dtype::I32, _) => convert_::<i32>(view, device),
        (st::Dtype::I64, _) => convert_::<i64>(view, device),
        (st::Dtype::BF16, None | Some(DType::BF16)) => convert_::<half::bf16>(view, device),
        (st::Dtype::F16, None | Some(DType::F16)) => convert_::<half::f16>(view, device),
        (st::Dtype::F32, _) => convert_::<f32>(view, device),
        (st::Dtype::F64, _) => convert_::<f64>(view, device),
        (st::Dtype::F8_E4M3, _) => convert_::<F8E4M3>(view, device),

        (st::Dtype::BF16, Some(DType::F16)) => {
            let conv = |x: half::bf16| Ok(half::f16::from_f32(x.to_f32()));
            convert_with_cast_::<half::bf16, half::f16, _>(view, device, conv)
        }
        (st::Dtype::BF16, Some(DType::F32)) => {
            let conv = |x: half::bf16| Ok(x.to_f32());
            convert_with_cast_::<half::bf16, f32, _>(view, device, conv)
        }
        (st::Dtype::F16, Some(DType::BF16)) => {
            let conv = |x: half::f16| Ok(half::bf16::from_f32(x.to_f32()));
            convert_with_cast_::<half::f16, half::bf16, _>(view, device, conv)
        }
        (st::Dtype::F16, Some(DType::F32)) => {
            let conv = |x: half::f16| Ok(x.to_f32());
            convert_with_cast_::<half::f16, f32, _>(view, device, conv)
        }
        (dtype, _) => Err(Error::UnsupportedSafeTensorDtype(dtype)),
    }
}
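
// A test sketch (illustrative addition) of the cast arms above via the `Load` trait:
// two bf16 values are wrapped in a borrowed `TensorView` and loaded with an explicit
// `DType::F32` cast, which routes through `convert_with_cast_`.
#[cfg(test)]
mod convert_example {
    use super::*;

    #[test]
    fn bf16_view_loads_as_f32() -> Result<()> {
        let values = [half::bf16::from_f32(1.5), half::bf16::from_f32(-2.0)];
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        let view = st::TensorView::new(st::Dtype::BF16, vec![2], &bytes)?;
        let t = view.load(&Device::Cpu, Some(DType::F32))?;
        assert_eq!(t.dtype(), DType::F32);
        assert_eq!(t.to_vec1::<f32>()?, vec![1.5f32, -2.0]);
        Ok(())
    }
}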

#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);

pub struct MmapedSafetensors {
    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
    routing: Option<HashMap<String, usize>>,
}

impl MmapedSafetensors {
    /// Creates a wrapper around a memory mapped file and deserializes the safetensors header.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`].
    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
        let p = p.as_ref();
        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
        let file = memmap2::MmapOptions::new()
            .map(&file)
            .map_err(|e| Error::from(e).with_path(p))?;
        let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
            file,
            |data: &[u8]| {
                let st = safetensors::SafeTensors::deserialize(data)
                    .map_err(|e| Error::from(e).with_path(p))?;
                Ok::<_, Error>(SafeTensors_(st))
            },
        )?;
        Ok(Self {
            safetensors: vec![safetensors],
            routing: None,
        })
    }

    /// Creates a wrapper around multiple memory mapped files and deserializes the safetensors
    /// headers.
    ///
    /// If a tensor name appears in multiple files, the last entry is returned. See the usage
    /// sketch after this impl block.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`].
    pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
        let mut routing = HashMap::new();
        let mut safetensors = vec![];
        for (index, p) in paths.iter().enumerate() {
            let p = p.as_ref();
            let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
            let file = memmap2::MmapOptions::new()
                .map(&file)
                .map_err(|e| Error::from(e).with_path(p))?;
            let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
                file,
                |data: &[u8]| {
                    let st = safetensors::SafeTensors::deserialize(data)
                        .map_err(|e| Error::from(e).with_path(p))?;
                    Ok::<_, Error>(SafeTensors_(st))
                },
            )?;
            for k in data.get().0.names() {
                routing.insert(k.to_string(), index);
            }
            safetensors.push(data)
        }
        Ok(Self {
            safetensors,
            routing: Some(routing),
        })
    }

    pub fn load(&self, name: &str, dev: &Device, dtype: Option<DType>) -> Result<Tensor> {
        self.get(name)?.load(dev, dtype)
    }

    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
        let mut tensors = vec![];
        for safetensors in self.safetensors.iter() {
            tensors.push(safetensors.get().0.tensors())
        }
        tensors.into_iter().flatten().collect()
    }

    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
        let index = match &self.routing {
            None => 0,
            Some(routing) => {
                let index = routing.get(name).ok_or_else(|| {
                    Error::CannotFindTensor {
                        path: name.to_string(),
                    }
                    .bt()
                })?;
                *index
            }
        };
        Ok(self.safetensors[index].get().0.tensor(name)?)
    }
}
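
// A usage sketch for `MmapedSafetensors` (ignored by default: the shard file names and
// the tensor name below are hypothetical). `multi` routes each tensor name to the last
// file that defines it; `load` then reads from the mapped file and casts to the
// requested dtype.
#[cfg(test)]
mod mmaped_safetensors_example {
    use super::*;

    #[test]
    #[ignore = "requires safetensors files on disk"]
    fn open_and_load() -> Result<()> {
        let st = unsafe {
            MmapedSafetensors::multi(&["model-00001.safetensors", "model-00002.safetensors"])?
        };
        let t = st.load("model.embed_tokens.weight", &Device::Cpu, Some(DType::F32))?;
        assert_eq!(t.dtype(), DType::F32);
        Ok(())
    }
}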

impl SimpleBackend for MmapedSafetensors {
    fn get(
        &self,
        s: Shape,
        name: &str,
        _: candle_nn::Init,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor> {
        let tensor = self.get_unchecked(name, dtype, dev)?;
        if tensor.shape() != &s {
            Err(candle_core::Error::UnexpectedShape {
                msg: format!("shape mismatch for {name}"),
                expected: s,
                got: tensor.shape().clone(),
            }
            .bt())?
        }
        Ok(tensor)
    }

    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
        self.load(name, dev, Some(dtype))
    }

    fn contains_tensor(&self, name: &str) -> bool {
        self.get(name).is_ok()
    }
}

pub enum ShardedSafeTensors {
    Sharded {
        b: MmapedSafetensors,
        make_dummy_regexes: Option<Arc<Vec<Regex>>>,
    },
    SimpleBackend(Box<dyn SimpleBackend + 'static>),
}

pub type ShardedVarBuilder = VarBuilderArgs<'static, ShardedSafeTensors>;

impl ShardedSafeTensors {
    /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
    /// files and makes them usable in a sharded way. See the construction sketch after this
    /// impl block.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`].
    pub unsafe fn sharded<P: AsRef<std::path::Path>>(
        paths: &[P],
        dtype: DType,
        dev: &Device,
        make_dummy_regexes: Option<Arc<Vec<Regex>>>,
    ) -> Result<ShardedVarBuilder> {
        let tensors = MmapedSafetensors::multi(paths)?;
        let backend = ShardedSafeTensors::Sharded {
            b: tensors,
            make_dummy_regexes,
        };
        Ok(VarBuilderArgs::new_with_args(backend, dtype, dev))
    }
}
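
// A construction sketch (ignored by default; the paths are hypothetical): build a
// `ShardedVarBuilder` over a set of shard files, with a regex that makes any `lm_head`
// tensor report "cannot find" so callers can substitute a dummy weight.
#[cfg(test)]
mod sharded_builder_example {
    use super::*;

    #[test]
    #[ignore = "requires safetensors files on disk"]
    fn build_sharded_var_builder() -> Result<()> {
        let dummy = Arc::new(vec![Regex::new(r"^lm_head\.").unwrap()]);
        let vb = unsafe {
            ShardedSafeTensors::sharded(
                &["model-00001.safetensors", "model-00002.safetensors"],
                DType::BF16,
                &Device::Cpu,
                Some(dummy),
            )?
        };
        assert!(!vb.contains_tensor("lm_head.weight"));
        Ok(())
    }
}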

impl ShardedSafeTensors {
    pub fn wrap(
        backend: Box<dyn SimpleBackend + 'static>,
        dtype: DType,
        dev: Device,
    ) -> ShardedVarBuilder {
        VarBuilderArgs::new_with_args(Self::SimpleBackend(backend), dtype, &dev)
    }
}

#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Shard {
    Simple {
        dim: usize,
        rank: usize,
        world_size: usize,
    },
    Offset {
        dim: usize,
        offset: usize,
        len: usize,
    },
}

impl Default for Shard {
    fn default() -> Self {
        Self::Simple {
            dim: 0,
            rank: 0,
            world_size: 1,
        }
    }
}

/// Get part of a tensor, typically used to do Tensor Parallelism sharding.
///
/// Suppose the tensor has shape `(1024, 1024)`:
///
/// - `dim` is the dimension to slice along,
/// - `rank` is the rank of the current process,
/// - `world_size` is the total number of ranks in the process group.
///
/// Then `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`,
/// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`, and
/// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`.
impl Backend for ShardedSafeTensors {
    type Hints = Shard;

    fn get(
        &self,
        target_shape: Shape,
        path: &str,
        h: Self::Hints,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor> {
        if let Shard::Simple { world_size: 1, .. } = &h {
            // There is no sharding to be applied here so we use the default backend to speed
            // things up.
            match self {
                Self::Sharded {
                    b,
                    make_dummy_regexes,
                } => {
                    if let Some(make_dummy_regexes) = make_dummy_regexes {
                        if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
                            return Err(Error::CannotFindTensor {
                                path: path.to_string(),
                            });
                        }
                    }
                    return SimpleBackend::get(
                        b,
                        target_shape,
                        path,
                        Default::default(),
                        dtype,
                        dev,
                    );
                }
                Self::SimpleBackend(b) => {
                    return SimpleBackend::get(
                        b.as_ref(),
                        target_shape,
                        path,
                        Default::default(),
                        dtype,
                        dev,
                    )
                }
            }
        }

        match h {
            Shard::Simple {
                dim,
                rank,
                world_size,
            } => {
                match self {
                    Self::Sharded {
                        b,
                        make_dummy_regexes,
                    } => {
                        use safetensors::slice::IndexOp;

                        if let Some(make_dummy_regexes) = make_dummy_regexes {
                            if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
                                return Err(Error::CannotFindTensor {
                                    path: path.to_string(),
                                });
                            }
                        }

                        let view = b.get(path)?;
                        let view_dtype = view.dtype();
                        let mut shape = view.shape().to_vec();
                        let size = shape[dim];

                        if size % world_size != 0 {
                            return Err(Error::ShapeMismatchSplit {
                                shape: shape.into(),
                                dim,
                                n_parts: world_size,
                            });
                        }
                        let block_size = size / world_size;
                        let start = rank * block_size;
                        let stop = (rank + 1) * block_size;

                        // Everything is expressed in tensor dimensions;
                        // byte offsets are handled automatically by safetensors.

                        let iterator = if dim == 0 {
                            view.slice(start..stop).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else if dim == 1 {
                            view.slice((.., start..stop)).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else {
                            candle_core::bail!("Got sharded on dimensions != 0 or 1")
                        };

                        shape[dim] = block_size;

                        let view_dtype: DType = view_dtype.try_into()?;
                        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
                        Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)
                    }
                    Self::SimpleBackend(b) => {
                        use candle_core::IndexOp;
                        let tensor = b.get(target_shape, path, Default::default(), dtype, dev)?;

                        let size = tensor.dim(dim)?;
                        let shape = tensor.dims().to_vec();

                        if size % world_size != 0 {
                            return Err(Error::ShapeMismatchSplit {
                                shape: shape.into(),
                                dim,
                                n_parts: world_size,
                            });
                        }
                        let block_size = size / world_size;
                        let start = rank * block_size;
                        let stop = (rank + 1) * block_size;

                        if dim == 0 {
                            tensor.i((start..stop, ..))
                        } else if dim == 1 {
                            tensor.i((.., start..stop))
                        } else {
                            candle_core::bail!("Got sharded on dimensions != 0 or 1")
                        }
                    }
                }
            }
            Shard::Offset { dim, offset, len } => {
                match self {
                    Self::Sharded {
                        b,
                        make_dummy_regexes,
                    } => {
                        use safetensors::slice::IndexOp;

                        if let Some(make_dummy_regexes) = make_dummy_regexes {
                            if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
                                return Err(Error::CannotFindTensor {
                                    path: path.to_string(),
                                });
                            }
                        }

                        let view = b.get(path)?;
                        let view_dtype = view.dtype();
                        let mut shape = view.shape().to_vec();

                        let start = offset;
                        let stop = start + len;

                        // Everything is expressed in tensor dimensions;
                        // byte offsets are handled automatically by safetensors.

                        let iterator = if dim == 0 {
                            view.slice(start..stop).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else if dim == 1 {
                            view.slice((.., start..stop)).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else {
                            candle_core::bail!("Got sharded on dimensions != 0 or 1")
                        };

                        shape[dim] = len;

                        let view_dtype: DType = view_dtype.try_into()?;
                        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
                        Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)
                    }
                    Self::SimpleBackend(b) => {
                        use candle_core::IndexOp;
                        let tensor = b.get(target_shape, path, Default::default(), dtype, dev)?;

                        let start = offset;
                        let stop = start + len;

                        if dim == 0 {
                            tensor.i((start..stop, ..))
                        } else if dim == 1 {
                            tensor.i((.., start..stop))
                        } else {
                            candle_core::bail!("Got sharded on dimensions != 0 or 1")
                        }
                    }
                }
            }
        }
    }

    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
        match self {
            Self::Sharded {
                b,
                make_dummy_regexes,
            } => {
                if let Some(make_dummy_regexes) = make_dummy_regexes {
                    if make_dummy_regexes.iter().any(|x| x.is_match(name)) {
                        return Err(Error::CannotFindTensor {
                            path: name.to_string(),
                        });
                    }
                }
                <MmapedSafetensors as SimpleBackend>::get_unchecked(b, name, dtype, dev)
            }
            Self::SimpleBackend(b) => b.as_ref().get_unchecked(name, dtype, dev),
        }
    }

    fn contains_tensor(&self, name: &str) -> bool {
        match self {
            Self::Sharded {
                b,
                make_dummy_regexes,
            } => {
                if let Some(make_dummy_regexes) = make_dummy_regexes {
                    if make_dummy_regexes.iter().any(|x| x.is_match(name)) {
                        return false;
                    }
                }
                b.get(name).is_ok()
            }
            Self::SimpleBackend(b) => b.as_ref().contains_tensor(name),
        }
    }
}
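
// A runnable sketch of `Shard` hints through the `SimpleBackend` variant. It relies on
// candle_nn's `SimpleBackend` impl for `HashMap<String, Tensor>` (an assumption about
// the candle_nn API, not something defined in this file) and shows that `Shard::Simple`
// takes the rank's contiguous block while `Shard::Offset` takes an explicit slice.
#[cfg(test)]
mod shard_hints_example {
    use super::*;

    #[test]
    fn simple_and_offset_shards() -> Result<()> {
        let dev = Device::Cpu;
        let full = Tensor::arange(0f32, 24., &dev)?.reshape((4, 6))?;
        let mut map = HashMap::new();
        map.insert("w".to_string(), full);
        let vb = ShardedSafeTensors::wrap(Box::new(map), DType::F32, dev);

        // Rank 1 of 2 along dim 0: rows 2..4 of the (4, 6) tensor.
        let shard = Shard::Simple { dim: 0, rank: 1, world_size: 2 };
        let t = vb.get_with_hints((4, 6), "w", shard)?;
        assert_eq!(t.dims(), &[2, 6]);

        // Explicit columns 1..4 along dim 1.
        let off = Shard::Offset { dim: 1, offset: 1, len: 3 };
        let t = vb.get_with_hints((4, 6), "w", off)?;
        assert_eq!(t.dims(), &[4, 3]);
        Ok(())
    }
}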