1use candle_core::{DType, Device, Error, IndexOp, Result, Shape, Storage, Tensor, WithDType};
2use candle_nn::var_builder::{Backend, SimpleBackend, VarBuilderArgs};
3use float8::F8E4M3;
4use regex::Regex;
5use safetensors::tensor as st;
6use safetensors::tensor::SafeTensors;
7use std::collections::HashMap;
8use std::path::Path;
9use std::sync::Arc;
10
11fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
12 let size_in_bytes = T::DTYPE.size_in_bytes();
13 let elem_count = data.len() / size_in_bytes;
14 if (data.as_ptr() as usize) % size_in_bytes == 0 {
15 let data: &[T] =
18 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
19 Tensor::from_slice(data, shape, device)
20 } else {
21 let mut c: Vec<T> = Vec::with_capacity(elem_count);
24 unsafe {
29 std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
30 c.set_len(elem_count)
31 }
32 Tensor::from_slice(&c, shape, device)
33 }
34}
35
36fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
37 data: &[u8],
38 shape: &[usize],
39 device: &Device,
40 conv: F,
41) -> Result<Tensor> {
42 let size_in_bytes = std::mem::size_of::<T>();
43 let elem_count = data.len() / size_in_bytes;
44 if (data.as_ptr() as usize) % size_in_bytes == 0 {
45 let data: &[T] =
48 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
49 let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
50 Tensor::from_vec(data, shape, device)
51 } else {
52 let mut c: Vec<T> = Vec::with_capacity(elem_count);
55 unsafe {
60 std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
61 c.set_len(elem_count)
62 }
63 let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
64 Tensor::from_vec(c, shape, device)
65 }
66}
67
68fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
69 view: &st::TensorView<'_>,
70 device: &Device,
71 conv: F,
72) -> Result<Tensor> {
73 convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
74}
75
76fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
77 convert_slice::<T>(view.data(), view.shape(), device)
78}
79
/// Something that can be materialized as a [`Tensor`] on a given device.
pub trait Load {
    /// Loads `self` onto `device`.
    ///
    /// `dtype` is an optional target dtype. The implementation in this file (for
    /// `st::TensorView`) forwards it to `convert` as the cast request.
    fn load(&self, device: &Device, dtype: Option<DType>) -> Result<Tensor>;
}
83
84impl Load for st::TensorView<'_> {
85 fn load(&self, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
86 convert(self, device, dtype)
87 }
88}
89
90fn convert(
91 view: &st::TensorView<'_>,
92 device: &Device,
93 cast_dtype: Option<DType>,
94) -> Result<Tensor> {
95 match (view.dtype(), cast_dtype) {
96 (st::Dtype::BF16, Some(DType::F16)) => {
97 let conv = |x: half::bf16| Ok(half::f16::from_f32(x.to_f32()));
98 convert_with_cast_::<half::bf16, half::f16, _>(view, device, conv)
99 }
100 (st::Dtype::BF16, Some(DType::F32)) => {
101 let conv = |x: half::bf16| Ok(x.to_f32());
102 convert_with_cast_::<half::bf16, f32, _>(view, device, conv)
103 }
104 (st::Dtype::F16, Some(DType::BF16)) => {
105 let conv = |x: half::f16| Ok(half::bf16::from_f32(x.to_f32()));
106 convert_with_cast_::<half::f16, half::bf16, _>(view, device, conv)
107 }
108 (st::Dtype::F16, Some(DType::F32)) => {
109 let conv = |x: half::f16| Ok(x.to_f32());
110 convert_with_cast_::<half::f16, f32, _>(view, device, conv)
111 }
112 (st::Dtype::F32, Some(DType::BF16)) => {
113 let conv = |x: f32| Ok(half::bf16::from_f32(x));
114 convert_with_cast_::<f32, half::bf16, _>(view, device, conv)
115 }
116 (st::Dtype::F32, Some(DType::F16)) => {
117 let conv = |x: f32| Ok(half::f16::from_f32(x));
118 convert_with_cast_::<f32, half::f16, _>(view, device, conv)
119 }
120
121 (st::Dtype::U8, _) => convert_::<u8>(view, device),
122 (st::Dtype::U16, _) => {
123 let conv = |x| Ok(u32::from(x));
124 convert_with_cast_::<u16, u32, _>(view, device, conv)
125 }
126 (st::Dtype::U32, _) => convert_::<u32>(view, device),
127 (st::Dtype::I16, _) => convert_::<i16>(view, device),
128 (st::Dtype::I32, _) => convert_::<i32>(view, device),
129 (st::Dtype::I64, _) => convert_::<i64>(view, device),
130 (st::Dtype::BF16, None | Some(DType::BF16)) => convert_::<half::bf16>(view, device),
131 (st::Dtype::F16, None | Some(DType::F16)) => convert_::<half::f16>(view, device),
132 (st::Dtype::F32, _) => convert_::<f32>(view, device),
133 (st::Dtype::F64, _) => convert_::<f64>(view, device),
134 (st::Dtype::F8_E4M3, _) => convert_::<F8E4M3>(view, device),
135 (st::Dtype::F6_E2M3, _)
136 | (st::Dtype::F6_E3M2, _)
137 | (st::Dtype::F4, _)
138 | (st::Dtype::F8_E8M0, _) => {
139 convert_dummy(view, device)
143 }
144 (dtype, _) => Err(Error::UnsupportedSafeTensorDtype(dtype)),
145 }
146}
147
148fn convert_dummy(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
149 let (dtype, _dtype_name) = match view.dtype() {
152 st::Dtype::F6_E2M3 => (DType::F6E2M3, "F6_E2M3 (MX6)"),
153 st::Dtype::F6_E3M2 => (DType::F6E3M2, "F6_E3M2 (MX6)"),
154 st::Dtype::F4 => (DType::F4, "F4 (MX4)"),
155 st::Dtype::F8_E8M0 => (DType::F8E8M0, "F8_E8M0"),
156 _ => unreachable!("convert_dummy called with non-dummy dtype"),
157 };
158
159 let data = view.data();
161 let shape = view.shape();
162
163 let storage = match device {
165 Device::Cpu => {
166 let cpu_storage = match dtype {
167 DType::F6E2M3 => candle_core::cpu_backend::CpuStorage::F6E2M3(data.to_vec()),
168 DType::F6E3M2 => candle_core::cpu_backend::CpuStorage::F6E3M2(data.to_vec()),
169 DType::F4 => candle_core::cpu_backend::CpuStorage::F4(data.to_vec()),
170 DType::F8E8M0 => candle_core::cpu_backend::CpuStorage::F8E8M0(data.to_vec()),
171 _ => unreachable!(),
172 };
173 Storage::Cpu(cpu_storage)
174 }
175 #[cfg(feature = "cuda")]
176 Device::Cuda(device) => {
177 let mut slice = unsafe { device.alloc::<u8>(data.len())? };
178 device.memcpy_htod(data, &mut slice)?;
179
180 let slice = match dtype {
181 DType::F6E2M3 => candle_core::cuda_backend::CudaStorageSlice::F6E2M3(slice),
182 DType::F6E3M2 => candle_core::cuda_backend::CudaStorageSlice::F6E3M2(slice),
183 DType::F4 => candle_core::cuda_backend::CudaStorageSlice::F4(slice),
184 DType::F8E8M0 => candle_core::cuda_backend::CudaStorageSlice::F8E8M0(slice),
185 _ => unreachable!(),
186 };
187 let storage = candle_core::cuda_backend::CudaStorage {
188 slice,
189 device: device.clone(),
190 };
191 Storage::Cuda(storage)
192 }
193 #[cfg(not(feature = "cuda"))]
194 Device::Cuda(_) => {
195 return Err(Error::Msg("CUDA support not compiled".to_string()));
196 }
197 #[cfg(feature = "metal")]
198 Device::Metal(device) => {
199 let buffer = device.new_buffer_with_data(data)?;
200
201 let storage = candle_core::metal_backend::MetalStorage::new(
202 buffer,
203 device.clone(),
204 data.len(),
205 dtype,
206 );
207 Storage::Metal(storage)
208 }
209 #[cfg(not(feature = "metal"))]
210 Device::Metal(_) => {
211 return Err(Error::Msg("Metal support not compiled".to_string()));
212 }
213 };
214
215 Ok(Tensor::from((storage, shape)))
216}
217
/// Newtype around [`SafeTensors`] that derives `yoke::Yokeable`, so a parsed safetensors view
/// can be stored next to the memory map it borrows from (see `MmapedSafetensors`).
#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);
220
/// One or more memory-mapped safetensors files, addressable by tensor name.
pub struct MmapedSafetensors {
    // One entry per mapped file: the parsed safetensors view yoked to the mmap that backs it.
    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
    // Tensor name -> index into `safetensors`. `None` when built through `new` (single file),
    // in which case every lookup goes to index 0.
    routing: Option<HashMap<String, usize>>,
}
225
226impl MmapedSafetensors {
227 pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
233 let p = p.as_ref();
234 let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
235 let file = memmap2::MmapOptions::new()
236 .map(&file)
237 .map_err(|e| Error::from(e).with_path(p))?;
238 let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
239 file,
240 |data: &[u8]| {
241 let st = safetensors::SafeTensors::deserialize(data)
242 .map_err(|e| Error::from(e).with_path(p))?;
243 Ok::<_, Error>(SafeTensors_(st))
244 },
245 )?;
246 Ok(Self {
247 safetensors: vec![safetensors],
248 routing: None,
249 })
250 }
251
252 pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
260 let mut routing = HashMap::new();
261 let mut safetensors = vec![];
262 for (index, p) in paths.iter().enumerate() {
263 let p = p.as_ref();
264 let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
265 let file = memmap2::MmapOptions::new()
266 .map(&file)
267 .map_err(|e| Error::from(e).with_path(p))?;
268 let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
269 file,
270 |data: &[u8]| {
271 let st = safetensors::SafeTensors::deserialize(data)
272 .map_err(|e| Error::from(e).with_path(p))?;
273 Ok::<_, Error>(SafeTensors_(st))
274 },
275 )?;
276 for k in data.get().0.names() {
277 routing.insert(k.to_string(), index);
278 }
279 safetensors.push(data)
280 }
281 Ok(Self {
282 safetensors,
283 routing: Some(routing),
284 })
285 }
286
287 pub fn load(&self, name: &str, dev: &Device, dtype: Option<DType>) -> Result<Tensor> {
288 self.get(name)?.load(dev, dtype)
289 }
290
291 pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
292 let mut tensors = vec![];
293 for safetensors in self.safetensors.iter() {
294 tensors.push(safetensors.get().0.tensors())
295 }
296 tensors.into_iter().flatten().collect()
297 }
298
299 pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
300 let index = match &self.routing {
301 None => 0,
302 Some(routing) => {
303 let index = routing.get(name).ok_or_else(|| {
304 Error::CannotFindTensor {
305 path: name.to_string(),
306 }
307 .bt()
308 })?;
309 *index
310 }
311 };
312 Ok(self.safetensors[index].get().0.tensor(name)?)
313 }
314}
315
316impl SimpleBackend for MmapedSafetensors {
317 fn get(
318 &self,
319 s: Shape,
320 name: &str,
321 _: candle_nn::Init,
322 dtype: DType,
323 dev: &Device,
324 ) -> Result<Tensor> {
325 let tensor = self.get_unchecked(name, dtype, dev)?;
326 if tensor.shape() != &s {
327 Err(candle_core::Error::UnexpectedShape {
328 msg: format!("shape mismatch for {name}"),
329 expected: s,
330 got: tensor.shape().clone(),
331 }
332 .bt())?
333 }
334 Ok(tensor)
335 }
336
337 fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
338 self.load(name, dev, Some(dtype))
339 }
340
341 fn contains_tensor(&self, name: &str) -> bool {
342 self.get(name).is_ok()
343 }
344}
345
/// Var-builder backend that can hand out whole tensors or shards of them (see [`Shard`]).
pub enum ShardedSafeTensors {
    /// Tensors come from memory-mapped safetensors files, subject to name filtering.
    Sharded {
        // The mapped files that tensors are read from.
        b: MmapedSafetensors,
        // A tensor whose name matches any of these regexes is reported as not found.
        make_dummy_regexes: Option<Arc<Vec<Regex>>>,
        // Called with the tensor name; a tensor for which this returns `false` is reported as
        // not found.
        predicate: Arc<dyn Fn(String) -> bool + Send + Sync + 'static>,
    },
    /// Tensors come from an arbitrary `SimpleBackend`; no name filtering is applied.
    SimpleBackend(Box<dyn SimpleBackend + 'static>),
}
354
/// Var builder backed by [`ShardedSafeTensors`]; produced by `ShardedSafeTensors::sharded` and
/// `ShardedSafeTensors::wrap`.
pub type ShardedVarBuilder = VarBuilderArgs<'static, ShardedSafeTensors>;
356
357impl ShardedSafeTensors {
358 pub unsafe fn sharded<P: AsRef<std::path::Path>>(
368 paths: &[P],
369 dtype: DType,
370 dev: &Device,
371 make_dummy_regexes: Option<Arc<Vec<Regex>>>,
372 predicate: Arc<dyn Fn(String) -> bool + Send + Sync + 'static>,
373 ) -> Result<ShardedVarBuilder> {
374 let tensors = MmapedSafetensors::multi(paths)?;
375 let backend = ShardedSafeTensors::Sharded {
376 b: tensors,
377 make_dummy_regexes,
378 predicate,
379 };
380 Ok(VarBuilderArgs::new_with_args(backend, dtype, dev))
381 }
382}
383
384impl ShardedSafeTensors {
385 pub fn wrap(
386 backend: Box<dyn SimpleBackend + 'static>,
387 dtype: DType,
388 dev: Device,
389 ) -> ShardedVarBuilder {
390 VarBuilderArgs::new_with_args(Self::SimpleBackend(backend), dtype, &dev)
391 }
392}
393
/// Describes which part of a tensor to keep along one dimension.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Shard {
    /// Split `dim` into `world_size` equally sized blocks and keep block number `rank`.
    Simple {
        // Dimension to split.
        dim: usize,
        // Index of the block to keep.
        rank: usize,
        // Number of blocks; the size of `dim` must be divisible by it.
        world_size: usize,
    },
    /// Keep the `len` entries starting at `offset` along `dim`.
    Offset {
        // Dimension to slice.
        dim: usize,
        // First index to keep.
        offset: usize,
        // Number of entries to keep.
        len: usize,
    },
}
407
408impl Shard {
409 pub fn apply_to(&self, tensor: &Tensor) -> Result<Tensor> {
410 match *self {
411 Shard::Simple {
412 dim,
413 rank,
414 world_size,
415 } => {
416 let size = tensor.dim(dim)?;
417 let shape = tensor.dims().to_vec();
418
419 if size % world_size != 0 {
420 return Err(Error::ShapeMismatchSplit {
421 shape: shape.into(),
422 dim,
423 n_parts: world_size,
424 });
425 }
426 let block_size = size / world_size;
427 let start = rank * block_size;
428 let stop = (rank + 1) * block_size;
429
430 if dim == 0 {
431 tensor.i(start..stop)
432 } else if dim == 1 {
433 tensor.i((.., start..stop))
434 } else if dim == 2 {
435 tensor.i((.., .., start..stop))
436 } else {
437 candle_core::bail!("Got sharded on dimensions != 0 or 1 or 2")
438 }
439 }
440 Shard::Offset { dim, offset, len } => {
441 let start = offset;
442 let stop = start + len;
443
444 if dim == 0 {
445 tensor.i(start..stop)
446 } else if dim == 1 {
447 tensor.i((.., start..stop))
448 } else if dim == 2 {
449 tensor.i((.., .., start..stop))
450 } else {
451 candle_core::bail!("Got sharded on dimensions != 0 or 1 or 2")
452 }
453 }
454 }
455 }
456}
457
458impl Default for Shard {
459 fn default() -> Self {
460 Self::Simple {
461 dim: 0,
462 rank: 0,
463 world_size: 1,
464 }
465 }
466}
467
468impl Backend for ShardedSafeTensors {
480 type Hints = Shard;
481
482 fn get(
483 &self,
484 target_shape: Shape,
485 path: &str,
486 h: Self::Hints,
487 dtype: DType,
488 dev: &Device,
489 ) -> Result<Tensor> {
490 if let Shard::Simple { world_size: 1, .. } = &h {
491 match self {
494 Self::Sharded {
495 b,
496 make_dummy_regexes,
497 predicate,
498 } => {
499 if let Some(make_dummy_regexes) = make_dummy_regexes {
500 if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
501 return Err(Error::CannotFindTensor {
502 path: path.to_string(),
503 });
504 }
505 }
506 let should_include = predicate(path.to_string());
507 if !should_include {
508 return Err(Error::CannotFindTensor {
509 path: path.to_string(),
510 });
511 }
512
513 return SimpleBackend::get(
514 b,
515 target_shape,
516 path,
517 Default::default(),
518 dtype,
519 dev,
520 );
521 }
522 Self::SimpleBackend(b) => {
523 return SimpleBackend::get(
524 b.as_ref(),
525 target_shape,
526 path,
527 Default::default(),
528 dtype,
529 dev,
530 )
531 }
532 }
533 }
534
535 let result = match h {
536 Shard::Simple {
537 dim,
538 rank,
539 world_size,
540 } => {
541 match self {
542 Self::Sharded {
543 b,
544 make_dummy_regexes,
545 predicate,
546 } => {
547 use safetensors::slice::IndexOp;
548
549 if let Some(make_dummy_regexes) = make_dummy_regexes {
550 if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
551 return Err(Error::CannotFindTensor {
552 path: path.to_string(),
553 });
554 }
555 }
556 let should_include = predicate(path.to_string());
557 if !should_include {
558 return Err(Error::CannotFindTensor {
559 path: path.to_string(),
560 });
561 }
562
563 let view = b.get(path)?;
564 let view_dtype = view.dtype();
565 let mut shape = view.shape().to_vec();
566 let size = shape[dim];
567
568 if size % world_size != 0 {
569 return Err(Error::ShapeMismatchSplit {
570 shape: shape.into(),
571 dim,
572 n_parts: world_size,
573 });
574 }
575 let block_size = size / world_size;
576 let start = rank * block_size;
577 let stop = (rank + 1) * block_size;
578
579 let iterator = if dim == 0 {
583 view.slice(start..stop).map_err(|_| {
584 Error::Msg(format!(
585 "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
586 ))
587 })?
588 } else if dim == 1 {
589 view.slice((.., start..stop)).map_err(|_| {
590 Error::Msg(format!(
591 "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
592 ))
593 })?
594 } else if dim == 2 {
595 view.slice((.., .., start..stop)).map_err(|_| {
596 Error::Msg(format!(
597 "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
598 ))
599 })?
600 } else {
601 candle_core::bail!("Got sharded on dimensions != 0 or 1 or 2")
602 };
603
604 shape[dim] = block_size;
605
606 let view_dtype: DType = view_dtype.try_into()?;
607 let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
608 Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)?
609 }
610 Self::SimpleBackend(b) => {
611 let tensor = b.get(target_shape, path, Default::default(), dtype, dev)?;
612 h.apply_to(&tensor)?
613 }
614 }
615 }
616 Shard::Offset { dim, offset, len } => {
617 match self {
618 Self::Sharded {
619 b,
620 make_dummy_regexes,
621 predicate,
622 } => {
623 use safetensors::slice::IndexOp;
624
625 if let Some(make_dummy_regexes) = make_dummy_regexes {
626 if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
627 return Err(Error::CannotFindTensor {
628 path: path.to_string(),
629 });
630 }
631 }
632 let should_include = predicate(path.to_string());
633 if !should_include {
634 return Err(Error::CannotFindTensor {
635 path: path.to_string(),
636 });
637 }
638
639 let view = b.get(path)?;
640 let view_dtype = view.dtype();
641 let mut shape = view.shape().to_vec();
642
643 let start = offset;
644 let stop = start + len;
645
646 let iterator = if dim == 0 {
650 view.slice(start..stop).map_err(|_| {
651 Error::Msg(format!(
652 "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
653 ))
654 })?
655 } else if dim == 1 {
656 view.slice((.., start..stop)).map_err(|_| {
657 Error::Msg(format!(
658 "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
659 ))
660 })?
661 } else if dim == 2 {
662 view.slice((.., .., start..stop)).map_err(|_| {
663 Error::Msg(format!(
664 "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
665 ))
666 })?
667 } else {
668 candle_core::bail!("Got sharded on dimensions != 0 or 1 or 2")
669 };
670
671 shape[dim] = len;
672
673 let view_dtype: DType = view_dtype.try_into()?;
674 let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
675 Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)?
676 }
677 Self::SimpleBackend(b) => {
678 let tensor = b.get(target_shape, path, Default::default(), dtype, dev)?;
679 h.apply_to(&tensor)?
680 }
681 }
682 }
683 };
684
685 result.contiguous()
686 }
687
688 fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
689 match self {
690 Self::Sharded {
691 b,
692 make_dummy_regexes,
693 predicate,
694 } => {
695 if let Some(make_dummy_regexes) = make_dummy_regexes {
696 if make_dummy_regexes.iter().any(|x| x.is_match(name)) {
697 return Err(Error::CannotFindTensor {
698 path: name.to_string(),
699 });
700 }
701 }
702 let should_include = predicate(name.to_string());
703 if !should_include {
704 return Err(Error::CannotFindTensor {
705 path: name.to_string(),
706 });
707 }
708 <MmapedSafetensors as SimpleBackend>::get_unchecked(b, name, dtype, dev)
709 }
710 Self::SimpleBackend(b) => b.as_ref().get_unchecked(name, dtype, dev),
711 }
712 }
713
714 fn contains_tensor(&self, name: &str) -> bool {
715 match self {
716 Self::Sharded {
717 b,
718 make_dummy_regexes,
719 predicate,
720 } => {
721 if let Some(make_dummy_regexes) = make_dummy_regexes {
722 if make_dummy_regexes.iter().any(|x| x.is_match(name)) {
723 return false;
724 }
725 }
726 let should_include = predicate(name.to_string());
727 if !should_include {
728 return false;
729 }
730 b.get(name).is_ok()
731 }
732 Self::SimpleBackend(b) => b.as_ref().contains_tensor(name),
733 }
734 }
735}