mistralrs_quant/
safetensors.rs

1use candle_core::{DType, Device, Error, Result, Shape, Tensor, WithDType};
2use candle_nn::var_builder::{Backend, SimpleBackend, VarBuilderArgs};
3use float8::F8E4M3;
4use regex::Regex;
5use safetensors::tensor as st;
6use safetensors::tensor::SafeTensors;
7use std::collections::HashMap;
8use std::path::Path;
9use std::sync::Arc;
10
11fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
12    let size_in_bytes = T::DTYPE.size_in_bytes();
13    let elem_count = data.len() / size_in_bytes;
14    if (data.as_ptr() as usize) % size_in_bytes == 0 {
15        // SAFETY This is safe because we just checked that this
16        // was correctly aligned.
17        let data: &[T] =
18            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
19        Tensor::from_slice(data, shape, device)
20    } else {
21        // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
22        // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
23        let mut c: Vec<T> = Vec::with_capacity(elem_count);
24        // SAFETY: We just created c, so the allocated memory is necessarily
25        // contiguous and non overlapping with the view's data.
26        // We're downgrading the `c` pointer from T to u8, which removes alignment
27        // constraints.
28        unsafe {
29            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
30            c.set_len(elem_count)
31        }
32        Tensor::from_slice(&c, shape, device)
33    }
34}
35
36fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
37    data: &[u8],
38    shape: &[usize],
39    device: &Device,
40    conv: F,
41) -> Result<Tensor> {
42    let size_in_bytes = std::mem::size_of::<T>();
43    let elem_count = data.len() / size_in_bytes;
44    if (data.as_ptr() as usize) % size_in_bytes == 0 {
45        // SAFETY This is safe because we just checked that this
46        // was correctly aligned.
47        let data: &[T] =
48            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
49        let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
50        Tensor::from_vec(data, shape, device)
51    } else {
52        // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
53        // Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
54        let mut c: Vec<T> = Vec::with_capacity(elem_count);
55        // SAFETY: We just created c, so the allocated memory is necessarily
56        // contiguous and non overlapping with the view's data.
57        // We're downgrading the `c` pointer from T to u8, which removes alignment
58        // constraints.
59        unsafe {
60            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
61            c.set_len(elem_count)
62        }
63        let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
64        Tensor::from_vec(c, shape, device)
65    }
66}
67
68fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
69    view: &st::TensorView<'_>,
70    device: &Device,
71    conv: F,
72) -> Result<Tensor> {
73    convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
74}
75
76fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
77    convert_slice::<T>(view.data(), view.shape(), device)
78}
79
80pub trait Load {
81    fn load(&self, device: &Device, dtype: Option<DType>) -> Result<Tensor>;
82}
83
84impl Load for st::TensorView<'_> {
85    fn load(&self, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
86        convert(self, device, dtype)
87    }
88}
89
90fn convert(
91    view: &st::TensorView<'_>,
92    device: &Device,
93    cast_dtype: Option<DType>,
94) -> Result<Tensor> {
95    match (view.dtype(), cast_dtype) {
96        (st::Dtype::BF16, Some(DType::F16)) => {
97            let conv = |x: half::bf16| Ok(half::f16::from_f32(x.to_f32()));
98            convert_with_cast_::<half::bf16, half::f16, _>(view, device, conv)
99        }
100        (st::Dtype::BF16, Some(DType::F32)) => {
101            let conv = |x: half::bf16| Ok(x.to_f32());
102            convert_with_cast_::<half::bf16, f32, _>(view, device, conv)
103        }
104        (st::Dtype::F16, Some(DType::BF16)) => {
105            let conv = |x: half::f16| Ok(half::bf16::from_f32(x.to_f32()));
106            convert_with_cast_::<half::f16, half::bf16, _>(view, device, conv)
107        }
108        (st::Dtype::F16, Some(DType::F32)) => {
109            let conv = |x: half::f16| Ok(x.to_f32());
110            convert_with_cast_::<half::f16, f32, _>(view, device, conv)
111        }
112        (st::Dtype::F32, Some(DType::BF16)) => {
113            let conv = |x: f32| Ok(half::bf16::from_f32(x));
114            convert_with_cast_::<f32, half::bf16, _>(view, device, conv)
115        }
116        (st::Dtype::F32, Some(DType::F16)) => {
117            let conv = |x: f32| Ok(half::f16::from_f32(x));
118            convert_with_cast_::<f32, half::f16, _>(view, device, conv)
119        }
120
121        (st::Dtype::U8, _) => convert_::<u8>(view, device),
122        (st::Dtype::U16, _) => {
123            let conv = |x| Ok(u32::from(x));
124            convert_with_cast_::<u16, u32, _>(view, device, conv)
125        }
126        (st::Dtype::U32, _) => convert_::<u32>(view, device),
127        (st::Dtype::I16, _) => convert_::<i16>(view, device),
128        (st::Dtype::I32, _) => convert_::<i32>(view, device),
129        (st::Dtype::I64, _) => convert_::<i64>(view, device),
130        (st::Dtype::BF16, None | Some(DType::BF16)) => convert_::<half::bf16>(view, device),
131        (st::Dtype::F16, None | Some(DType::F16)) => convert_::<half::f16>(view, device),
132        (st::Dtype::F32, _) => convert_::<f32>(view, device),
133        (st::Dtype::F64, _) => convert_::<f64>(view, device),
134        (st::Dtype::F8_E4M3, _) => convert_::<F8E4M3>(view, device),
135        (dtype, _) => Err(Error::UnsupportedSafeTensorDtype(dtype)),
136    }
137}
138
139#[derive(yoke::Yokeable)]
140struct SafeTensors_<'a>(SafeTensors<'a>);
141
142pub struct MmapedSafetensors {
143    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
144    routing: Option<HashMap<String, usize>>,
145}
146
147impl MmapedSafetensors {
148    /// Creates a wrapper around a memory mapped file and deserialize the safetensors header.
149    ///
150    /// # Safety
151    ///
152    /// The unsafe is inherited from [`memmap2::MmapOptions`].
153    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
154        let p = p.as_ref();
155        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
156        let file = memmap2::MmapOptions::new()
157            .map(&file)
158            .map_err(|e| Error::from(e).with_path(p))?;
159        let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
160            file,
161            |data: &[u8]| {
162                let st = safetensors::SafeTensors::deserialize(data)
163                    .map_err(|e| Error::from(e).with_path(p))?;
164                Ok::<_, Error>(SafeTensors_(st))
165            },
166        )?;
167        Ok(Self {
168            safetensors: vec![safetensors],
169            routing: None,
170        })
171    }
172
173    /// Creates a wrapper around multiple memory mapped file and deserialize the safetensors headers.
174    ///
175    /// If a tensor name appears in multiple files, the last entry is returned.
176    ///
177    /// # Safety
178    ///
179    /// The unsafe is inherited from [`memmap2::MmapOptions`].
180    pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
181        let mut routing = HashMap::new();
182        let mut safetensors = vec![];
183        for (index, p) in paths.iter().enumerate() {
184            let p = p.as_ref();
185            let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
186            let file = memmap2::MmapOptions::new()
187                .map(&file)
188                .map_err(|e| Error::from(e).with_path(p))?;
189            let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
190                file,
191                |data: &[u8]| {
192                    let st = safetensors::SafeTensors::deserialize(data)
193                        .map_err(|e| Error::from(e).with_path(p))?;
194                    Ok::<_, Error>(SafeTensors_(st))
195                },
196            )?;
197            for k in data.get().0.names() {
198                routing.insert(k.to_string(), index);
199            }
200            safetensors.push(data)
201        }
202        Ok(Self {
203            safetensors,
204            routing: Some(routing),
205        })
206    }
207
208    pub fn load(&self, name: &str, dev: &Device, dtype: Option<DType>) -> Result<Tensor> {
209        self.get(name)?.load(dev, dtype)
210    }
211
212    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
213        let mut tensors = vec![];
214        for safetensors in self.safetensors.iter() {
215            tensors.push(safetensors.get().0.tensors())
216        }
217        tensors.into_iter().flatten().collect()
218    }
219
220    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
221        let index = match &self.routing {
222            None => 0,
223            Some(routing) => {
224                let index = routing.get(name).ok_or_else(|| {
225                    Error::CannotFindTensor {
226                        path: name.to_string(),
227                    }
228                    .bt()
229                })?;
230                *index
231            }
232        };
233        Ok(self.safetensors[index].get().0.tensor(name)?)
234    }
235}
236
237impl SimpleBackend for MmapedSafetensors {
238    fn get(
239        &self,
240        s: Shape,
241        name: &str,
242        _: candle_nn::Init,
243        dtype: DType,
244        dev: &Device,
245    ) -> Result<Tensor> {
246        let tensor = self.get_unchecked(name, dtype, dev)?;
247        if tensor.shape() != &s {
248            Err(candle_core::Error::UnexpectedShape {
249                msg: format!("shape mismatch for {name}"),
250                expected: s,
251                got: tensor.shape().clone(),
252            }
253            .bt())?
254        }
255        Ok(tensor)
256    }
257
258    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
259        self.load(name, dev, Some(dtype))
260    }
261
262    fn contains_tensor(&self, name: &str) -> bool {
263        self.get(name).is_ok()
264    }
265}
266
267pub enum ShardedSafeTensors {
268    Sharded {
269        b: MmapedSafetensors,
270        make_dummy_regexes: Option<Arc<Vec<Regex>>>,
271        predicate: Arc<dyn Fn(String) -> bool + Send + Sync + 'static>,
272    },
273    SimpleBackend(Box<dyn SimpleBackend + 'static>),
274}
275
276pub type ShardedVarBuilder = VarBuilderArgs<'static, ShardedSafeTensors>;
277
278impl ShardedSafeTensors {
279    /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
280    /// files and make them usable in a sharded way.
281    ///
282    /// - If `regexes` is specified, this will be used in `make_dummy_predicate` based on `.any`
283    /// - Only include keys for which predicate evaluates to true.
284    ///
285    /// # Safety
286    ///
287    /// The unsafe is inherited from [`memmap2::MmapOptions`].
288    pub unsafe fn sharded<P: AsRef<std::path::Path>>(
289        paths: &[P],
290        dtype: DType,
291        dev: &Device,
292        make_dummy_regexes: Option<Arc<Vec<Regex>>>,
293        predicate: Arc<dyn Fn(String) -> bool + Send + Sync + 'static>,
294    ) -> Result<ShardedVarBuilder> {
295        let tensors = MmapedSafetensors::multi(paths)?;
296        let backend = ShardedSafeTensors::Sharded {
297            b: tensors,
298            make_dummy_regexes,
299            predicate,
300        };
301        Ok(VarBuilderArgs::new_with_args(backend, dtype, dev))
302    }
303}
304
305impl ShardedSafeTensors {
306    pub fn wrap(
307        backend: Box<dyn SimpleBackend + 'static>,
308        dtype: DType,
309        dev: Device,
310    ) -> ShardedVarBuilder {
311        VarBuilderArgs::new_with_args(Self::SimpleBackend(backend), dtype, &dev)
312    }
313}
314
/// Describes which slice of a tensor should be returned.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Shard {
    /// Split dimension `dim` into `world_size` equally sized blocks and take block
    /// number `rank`.
    Simple {
        /// Dimension to slice along.
        dim: usize,
        /// Index of the block to take.
        rank: usize,
        /// Total number of blocks.
        world_size: usize,
    },
    /// Take `len` entries of dimension `dim`, starting at `offset`.
    Offset {
        /// Dimension to slice along.
        dim: usize,
        /// First index to include.
        offset: usize,
        /// Number of entries to include.
        len: usize,
    },
}
328
329impl Default for Shard {
330    fn default() -> Self {
331        Self::Simple {
332            dim: 0,
333            rank: 0,
334            world_size: 1,
335        }
336    }
337}
338
339/// Get part of a tensor, typically used to do Tensor Parallelism sharding.
340///
341/// If the tensor is of size (1024, 1024).
342///
343/// `dim` corresponds to the dimension to slice into
344/// `rank` is the rank of the current process
345/// `world_size` is the total number of ranks in the process group
346///
347/// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`
348/// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`
349/// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`
350impl Backend for ShardedSafeTensors {
351    type Hints = Shard;
352
353    fn get(
354        &self,
355        target_shape: Shape,
356        path: &str,
357        h: Self::Hints,
358        dtype: DType,
359        dev: &Device,
360    ) -> Result<Tensor> {
361        if let Shard::Simple { world_size: 1, .. } = &h {
362            // There is no sharding to be applied here so we use the default backend to speed
363            // things up.
364            match self {
365                Self::Sharded {
366                    b,
367                    make_dummy_regexes,
368                    predicate,
369                } => {
370                    if let Some(make_dummy_regexes) = make_dummy_regexes {
371                        if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
372                            return Err(Error::CannotFindTensor {
373                                path: path.to_string(),
374                            });
375                        }
376                    }
377                    let should_include = predicate(path.to_string());
378                    if !should_include {
379                        return Err(Error::CannotFindTensor {
380                            path: path.to_string(),
381                        });
382                    }
383
384                    return SimpleBackend::get(
385                        b,
386                        target_shape,
387                        path,
388                        Default::default(),
389                        dtype,
390                        dev,
391                    );
392                }
393                Self::SimpleBackend(b) => {
394                    return SimpleBackend::get(
395                        b.as_ref(),
396                        target_shape,
397                        path,
398                        Default::default(),
399                        dtype,
400                        dev,
401                    )
402                }
403            }
404        }
405
406        let result = match h {
407            Shard::Simple {
408                dim,
409                rank,
410                world_size,
411            } => {
412                match self {
413                    Self::Sharded {
414                        b,
415                        make_dummy_regexes,
416                        predicate,
417                    } => {
418                        use safetensors::slice::IndexOp;
419
420                        if let Some(make_dummy_regexes) = make_dummy_regexes {
421                            if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
422                                return Err(Error::CannotFindTensor {
423                                    path: path.to_string(),
424                                });
425                            }
426                        }
427                        let should_include = predicate(path.to_string());
428                        if !should_include {
429                            return Err(Error::CannotFindTensor {
430                                path: path.to_string(),
431                            });
432                        }
433
434                        let view = b.get(path)?;
435                        let view_dtype = view.dtype();
436                        let mut shape = view.shape().to_vec();
437                        let size = shape[dim];
438
439                        if size % world_size != 0 {
440                            return Err(Error::ShapeMismatchSplit {
441                                shape: shape.into(),
442                                dim,
443                                n_parts: world_size,
444                            });
445                        }
446                        let block_size = size / world_size;
447                        let start = rank * block_size;
448                        let stop = (rank + 1) * block_size;
449
450                        // Everything is expressed in tensor dimension
451                        // bytes offsets is handled automatically for safetensors.
452
453                        let iterator = if dim == 0 {
454                            view.slice(start..stop).map_err(|_| {
455                                Error::Msg(format!(
456                                    "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
457                                ))
458                            })?
459                        } else if dim == 1 {
460                            view.slice((.., start..stop)).map_err(|_| {
461                                Error::Msg(format!(
462                                    "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
463                                ))
464                            })?
465                        } else if dim == 2 {
466                            view.slice((.., .., start..stop)).map_err(|_| {
467                                Error::Msg(format!(
468                                    "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
469                                ))
470                            })?
471                        } else {
472                            candle_core::bail!("Got sharded on dimensions != 0 or 1 or 2")
473                        };
474
475                        shape[dim] = block_size;
476
477                        let view_dtype: DType = view_dtype.try_into()?;
478                        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
479                        Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)?
480                    }
481                    Self::SimpleBackend(b) => {
482                        use candle_core::IndexOp;
483                        let tensor = b.get(target_shape, path, Default::default(), dtype, dev)?;
484
485                        let size = tensor.dim(dim)?;
486                        let shape = tensor.dims().to_vec();
487
488                        if size % world_size != 0 {
489                            return Err(Error::ShapeMismatchSplit {
490                                shape: shape.into(),
491                                dim,
492                                n_parts: world_size,
493                            });
494                        }
495                        let block_size = size / world_size;
496                        let start = rank * block_size;
497                        let stop = (rank + 1) * block_size;
498
499                        if dim == 0 {
500                            tensor.i(start..stop)?
501                        } else if dim == 1 {
502                            tensor.i((.., start..stop))?
503                        } else if dim == 2 {
504                            tensor.i((.., .., start..stop))?
505                        } else {
506                            candle_core::bail!("Got sharded on dimensions != 0 or 1 or 2")
507                        }
508                    }
509                }
510            }
511            Shard::Offset { dim, offset, len } => {
512                match self {
513                    Self::Sharded {
514                        b,
515                        make_dummy_regexes,
516                        predicate,
517                    } => {
518                        use safetensors::slice::IndexOp;
519
520                        if let Some(make_dummy_regexes) = make_dummy_regexes {
521                            if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
522                                return Err(Error::CannotFindTensor {
523                                    path: path.to_string(),
524                                });
525                            }
526                        }
527                        let should_include = predicate(path.to_string());
528                        if !should_include {
529                            return Err(Error::CannotFindTensor {
530                                path: path.to_string(),
531                            });
532                        }
533
534                        let view = b.get(path)?;
535                        let view_dtype = view.dtype();
536                        let mut shape = view.shape().to_vec();
537
538                        let start = offset;
539                        let stop = start + len;
540
541                        // Everything is expressed in tensor dimension
542                        // bytes offsets is handled automatically for safetensors.
543
544                        let iterator = if dim == 0 {
545                            view.slice(start..stop).map_err(|_| {
546                                Error::Msg(format!(
547                                    "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
548                                ))
549                            })?
550                        } else if dim == 1 {
551                            view.slice((.., start..stop)).map_err(|_| {
552                                Error::Msg(format!(
553                                    "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
554                                ))
555                            })?
556                        } else if dim == 2 {
557                            view.slice((.., .., start..stop)).map_err(|_| {
558                                Error::Msg(format!(
559                                    "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
560                                ))
561                            })?
562                        } else {
563                            candle_core::bail!("Got sharded on dimensions != 0 or 1 or 2")
564                        };
565
566                        shape[dim] = len;
567
568                        let view_dtype: DType = view_dtype.try_into()?;
569                        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
570                        Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)?
571                    }
572                    Self::SimpleBackend(b) => {
573                        use candle_core::IndexOp;
574                        let tensor = b.get(target_shape, path, Default::default(), dtype, dev)?;
575
576                        let start = offset;
577                        let stop = start + len;
578
579                        if dim == 0 {
580                            tensor.i(start..stop)?
581                        } else if dim == 1 {
582                            tensor.i((.., start..stop))?
583                        } else if dim == 2 {
584                            tensor.i((.., .., start..stop))?
585                        } else {
586                            candle_core::bail!("Got sharded on dimensions != 0 or 1 or 2")
587                        }
588                    }
589                }
590            }
591        };
592
593        result.contiguous()
594    }
595
596    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
597        match self {
598            Self::Sharded {
599                b,
600                make_dummy_regexes,
601                predicate,
602            } => {
603                if let Some(make_dummy_regexes) = make_dummy_regexes {
604                    if make_dummy_regexes.iter().any(|x| x.is_match(name)) {
605                        return Err(Error::CannotFindTensor {
606                            path: name.to_string(),
607                        });
608                    }
609                }
610                let should_include = predicate(name.to_string());
611                if !should_include {
612                    return Err(Error::CannotFindTensor {
613                        path: name.to_string(),
614                    });
615                }
616                <MmapedSafetensors as SimpleBackend>::get_unchecked(b, name, dtype, dev)
617            }
618            Self::SimpleBackend(b) => b.as_ref().get_unchecked(name, dtype, dev),
619        }
620    }
621
622    fn contains_tensor(&self, name: &str) -> bool {
623        match self {
624            Self::Sharded {
625                b,
626                make_dummy_regexes,
627                predicate,
628            } => {
629                if let Some(make_dummy_regexes) = make_dummy_regexes {
630                    if make_dummy_regexes.iter().any(|x| x.is_match(name)) {
631                        return false;
632                    }
633                }
634                let should_include = predicate(name.to_string());
635                if !should_include {
636                    return false;
637                }
638                b.get(name).is_ok()
639            }
640            Self::SimpleBackend(b) => b.as_ref().contains_tensor(name),
641        }
642    }
643}