use candle_core::{
    from_storage_no_op, DType, Device, Error, IndexOp, Result, Shape, Storage, Tensor, WithDType,
};
use candle_nn::var_builder::{Backend, SimpleBackend, VarBuilderArgs};
use float8::F8E4M3;
use regex::Regex;
use safetensors::tensor as st;
use safetensors::tensor::SafeTensors;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;

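/// Builds a tensor from raw little-endian bytes interpreted as elements of `T`.
/// If the byte pointer happens to be aligned for `T`, the slice is reinterpreted
/// in place; otherwise the bytes are first copied into an aligned `Vec<T>`.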
fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
    let size_in_bytes = T::DTYPE.size_in_bytes();
    let elem_count = data.len() / size_in_bytes;
    if (data.as_ptr() as usize) % size_in_bytes == 0 {
        // SAFETY: the pointer is aligned for `T` and `elem_count` elements fit
        // within `data`.
        let data: &[T] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
        Tensor::from_slice(data, shape, device)
    } else {
        // Unaligned case: copy the bytes into a freshly allocated, properly
        // aligned buffer before building the tensor.
        let mut c: Vec<T> = Vec::with_capacity(elem_count);
        // SAFETY: the copy fills exactly `elem_count` elements of `c`.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
            c.set_len(elem_count);
        }
        Tensor::from_slice(&c, shape, device)
    }
}

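/// Like [`convert_slice`], but the on-disk elements of type `T` are passed
/// through the fallible conversion `conv` to produce elements of type `U`
/// before the tensor is built.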
fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
    data: &[u8],
    shape: &[usize],
    device: &Device,
    conv: F,
) -> Result<Tensor> {
    let size_in_bytes = std::mem::size_of::<T>();
    let elem_count = data.len() / size_in_bytes;
    if (data.as_ptr() as usize) % size_in_bytes == 0 {
        // SAFETY: the pointer is aligned for `T` and `elem_count` elements fit
        // within `data`.
        let data: &[T] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
        let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
        Tensor::from_vec(data, shape, device)
    } else {
        let mut c: Vec<T> = Vec::with_capacity(elem_count);
        // SAFETY: the copy fills exactly `elem_count` elements of `c`.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
            c.set_len(elem_count);
        }
        let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
        Tensor::from_vec(c, shape, device)
    }
}

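/// Convenience wrapper applying [`convert_slice_with_cast`] to a safetensors
/// [`st::TensorView`].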
fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
    view: &st::TensorView<'_>,
    device: &Device,
    conv: F,
) -> Result<Tensor> {
    convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
}

fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
    convert_slice::<T>(view.data(), view.shape(), device)
}

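/// Loading a tensor from a safetensors view onto a device, optionally casting
/// it to a different dtype on the way.
///
/// A usage sketch (tensor name is illustrative):
///
/// ```ignore
/// let view = safetensors.tensor("model.embed_tokens.weight")?;
/// // Load onto the CPU, casting the stored BF16 data to F32.
/// let t = view.load(&Device::Cpu, Some(DType::F32))?;
/// ```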
pub trait Load {
    fn load(&self, device: &Device, dtype: Option<DType>) -> Result<Tensor>;
}

impl Load for st::TensorView<'_> {
    fn load(&self, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
        convert(self, device, dtype)
    }
}

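/// Dispatches on the on-disk dtype (and the optional target dtype) to build a
/// tensor from a safetensors view. Casts between BF16, F16 and F32 are applied
/// element-wise during the copy, U16 is widened to U32, the packed sub-byte
/// float formats are passed through as raw bytes via [`convert_dummy`], and the
/// remaining supported dtypes load natively.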
fn convert(
    view: &st::TensorView<'_>,
    device: &Device,
    cast_dtype: Option<DType>,
) -> Result<Tensor> {
    match (view.dtype(), cast_dtype) {
        (st::Dtype::BF16, Some(DType::F16)) => {
            let conv = |x: half::bf16| Ok(half::f16::from_f32(x.to_f32()));
            convert_with_cast_::<half::bf16, half::f16, _>(view, device, conv)
        }
        (st::Dtype::BF16, Some(DType::F32)) => {
            let conv = |x: half::bf16| Ok(x.to_f32());
            convert_with_cast_::<half::bf16, f32, _>(view, device, conv)
        }
        (st::Dtype::F16, Some(DType::BF16)) => {
            let conv = |x: half::f16| Ok(half::bf16::from_f32(x.to_f32()));
            convert_with_cast_::<half::f16, half::bf16, _>(view, device, conv)
        }
        (st::Dtype::F16, Some(DType::F32)) => {
            let conv = |x: half::f16| Ok(x.to_f32());
            convert_with_cast_::<half::f16, f32, _>(view, device, conv)
        }
        (st::Dtype::F32, Some(DType::BF16)) => {
            let conv = |x: f32| Ok(half::bf16::from_f32(x));
            convert_with_cast_::<f32, half::bf16, _>(view, device, conv)
        }
        (st::Dtype::F32, Some(DType::F16)) => {
            let conv = |x: f32| Ok(half::f16::from_f32(x));
            convert_with_cast_::<f32, half::f16, _>(view, device, conv)
        }

        (st::Dtype::U8, _) => convert_::<u8>(view, device),
        (st::Dtype::U16, _) => {
            let conv = |x| Ok(u32::from(x));
            convert_with_cast_::<u16, u32, _>(view, device, conv)
        }
        (st::Dtype::U32, _) => convert_::<u32>(view, device),
        (st::Dtype::I16, _) => convert_::<i16>(view, device),
        (st::Dtype::I32, _) => convert_::<i32>(view, device),
        (st::Dtype::I64, _) => convert_::<i64>(view, device),
        (st::Dtype::BF16, None | Some(DType::BF16)) => convert_::<half::bf16>(view, device),
        (st::Dtype::F16, None | Some(DType::F16)) => convert_::<half::f16>(view, device),
        (st::Dtype::F32, _) => convert_::<f32>(view, device),
        (st::Dtype::F64, _) => convert_::<f64>(view, device),
        (st::Dtype::F8_E4M3, _) => convert_::<F8E4M3>(view, device),
        (st::Dtype::F6_E2M3, _)
        | (st::Dtype::F6_E3M2, _)
        | (st::Dtype::F4, _)
        | (st::Dtype::F8_E8M0, _) => convert_dummy(view, device),
        (dtype, _) => Err(Error::UnsupportedSafeTensorDtype(dtype)),
    }
}

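/// Builds a tensor over the raw bytes of the packed sub-byte float formats
/// (F6_E2M3, F6_E3M2, F4, F8_E8M0). No element-wise conversion is possible for
/// these dtypes, so the bytes are copied into device storage as-is and wrapped
/// into a tensor via `from_storage_no_op`.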
fn convert_dummy(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
    let (dtype, _dtype_name) = match view.dtype() {
        st::Dtype::F6_E2M3 => (DType::F6E2M3, "F6_E2M3 (MX6)"),
        st::Dtype::F6_E3M2 => (DType::F6E3M2, "F6_E3M2 (MX6)"),
        st::Dtype::F4 => (DType::F4, "F4 (MX4)"),
        st::Dtype::F8_E8M0 => (DType::F8E8M0, "F8_E8M0"),
        _ => unreachable!("convert_dummy called with non-dummy dtype"),
    };

    let data = view.data();
    let shape = view.shape();

    let storage = match device {
        Device::Cpu => {
            let cpu_storage = match dtype {
                DType::F6E2M3 => candle_core::cpu_backend::CpuStorage::F6E2M3(data.to_vec()),
                DType::F6E3M2 => candle_core::cpu_backend::CpuStorage::F6E3M2(data.to_vec()),
                DType::F4 => candle_core::cpu_backend::CpuStorage::F4(data.to_vec()),
                DType::F8E8M0 => candle_core::cpu_backend::CpuStorage::F8E8M0(data.to_vec()),
                _ => unreachable!(),
            };
            Storage::Cpu(cpu_storage)
        }
        #[cfg(feature = "cuda")]
        Device::Cuda(device) => {
            // Allocate device memory and copy the raw bytes host-to-device.
            let mut slice = unsafe { device.alloc::<u8>(data.len())? };
            device.memcpy_htod(data, &mut slice)?;

            let slice = match dtype {
                DType::F6E2M3 => candle_core::cuda_backend::CudaStorageSlice::F6E2M3(slice),
                DType::F6E3M2 => candle_core::cuda_backend::CudaStorageSlice::F6E3M2(slice),
                DType::F4 => candle_core::cuda_backend::CudaStorageSlice::F4(slice),
                DType::F8E8M0 => candle_core::cuda_backend::CudaStorageSlice::F8E8M0(slice),
                _ => unreachable!(),
            };
            let storage = candle_core::cuda_backend::CudaStorage {
                slice,
                device: device.clone(),
            };
            Storage::Cuda(storage)
        }
        #[cfg(not(feature = "cuda"))]
        Device::Cuda(_) => {
            return Err(Error::Msg("CUDA support not compiled".to_string()));
        }
        #[cfg(feature = "metal")]
        Device::Metal(device) => {
            let buffer = device.new_buffer_with_data(data)?;

            let storage = candle_core::metal_backend::MetalStorage::new(
                buffer,
                device.clone(),
                data.len(),
                dtype,
            );
            Storage::Metal(storage)
        }
        #[cfg(not(feature = "metal"))]
        Device::Metal(_) => {
            return Err(Error::Msg("Metal support not compiled".to_string()));
        }
    };

    Ok(from_storage_no_op(storage, shape, false))
}

#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);

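/// One or more safetensors files accessed through memory mapping. The `yoke`
/// crate ties each parsed [`SafeTensors`] header to the `Mmap` it borrows from,
/// and `routing` maps tensor names to the file that contains them (absent for a
/// single file).
///
/// A minimal usage sketch (the file and tensor names are hypothetical):
///
/// ```ignore
/// // SAFETY: the file must not be modified while it is mapped.
/// let st = unsafe { MmapedSafetensors::new("model.safetensors")? };
/// let t = st.load("lm_head.weight", &Device::Cpu, None)?;
/// ```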
pub struct MmapedSafetensors {
    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
    routing: Option<HashMap<String, usize>>,
}

impl MmapedSafetensors {
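    /// Creates a wrapper around a memory mapped file and deserializes the
    /// safetensors header.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions::map`]: the mapped
    /// file must not be modified while the mapping is alive.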
    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
        let p = p.as_ref();
        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
        let file = memmap2::MmapOptions::new()
            .map(&file)
            .map_err(|e| Error::from(e).with_path(p))?;
        let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
            file,
            |data: &[u8]| {
                let st = safetensors::SafeTensors::deserialize(data)
                    .map_err(|e| Error::from(e).with_path(p))?;
                Ok::<_, Error>(SafeTensors_(st))
            },
        )?;
        Ok(Self {
            safetensors: vec![safetensors],
            routing: None,
        })
    }

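    /// Creates a wrapper around multiple memory mapped files and deserializes
    /// all the safetensors headers, recording for each tensor name the index of
    /// the file that contains it.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions::map`]: the mapped
    /// files must not be modified while the mappings are alive.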
    pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
        let mut routing = HashMap::new();
        let mut safetensors = vec![];
        for (index, p) in paths.iter().enumerate() {
            let p = p.as_ref();
            let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
            let file = memmap2::MmapOptions::new()
                .map(&file)
                .map_err(|e| Error::from(e).with_path(p))?;
            let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
                file,
                |data: &[u8]| {
                    let st = safetensors::SafeTensors::deserialize(data)
                        .map_err(|e| Error::from(e).with_path(p))?;
                    Ok::<_, Error>(SafeTensors_(st))
                },
            )?;
            for k in data.get().0.names() {
                routing.insert(k.to_string(), index);
            }
            safetensors.push(data)
        }
        Ok(Self {
            safetensors,
            routing: Some(routing),
        })
    }

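    /// Loads the named tensor onto `dev`, optionally casting it to `dtype` on
    /// the way (see [`convert`] for the supported conversions).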
    pub fn load(&self, name: &str, dev: &Device, dtype: Option<DType>) -> Result<Tensor> {
        self.get(name)?.load(dev, dtype)
    }

    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
        let mut tensors = vec![];
        for safetensors in self.safetensors.iter() {
            tensors.push(safetensors.get().0.tensors())
        }
        tensors.into_iter().flatten().collect()
    }

    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
        let index = match &self.routing {
            None => 0,
            Some(routing) => {
                let index = routing.get(name).ok_or_else(|| {
                    Error::CannotFindTensor {
                        path: name.to_string(),
                    }
                    .bt()
                })?;
                *index
            }
        };
        Ok(self.safetensors[index].get().0.tensor(name)?)
    }
}

impl SimpleBackend for MmapedSafetensors {
    fn get(
        &self,
        s: Shape,
        name: &str,
        _: candle_nn::Init,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor> {
        let tensor = self.get_unchecked(name, dtype, dev)?;
        if tensor.shape() != &s {
            Err(candle_core::Error::UnexpectedShape {
                msg: format!("shape mismatch for {name}"),
                expected: s,
                got: tensor.shape().clone(),
            }
            .bt())?
        }
        Ok(tensor)
    }

    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
        self.load(name, dev, Some(dtype))
    }

    fn contains_tensor(&self, name: &str) -> bool {
        self.get(name).is_ok()
    }
}

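/// Tensor storage backend for [`ShardedVarBuilder`]: either a set of mmaped
/// safetensors files together with optional "make dummy" regexes and a name
/// predicate, or an arbitrary boxed [`SimpleBackend`]. Tensors whose names
/// match a dummy regex, or are rejected by the predicate, are reported as not
/// found.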
pub enum ShardedSafeTensors {
    Sharded {
        b: MmapedSafetensors,
        make_dummy_regexes: Option<Arc<Vec<Regex>>>,
        predicate: Arc<dyn Fn(String) -> bool + Send + Sync + 'static>,
    },
    SimpleBackend(Box<dyn SimpleBackend + 'static>),
}

pub type ShardedVarBuilder = VarBuilderArgs<'static, ShardedSafeTensors>;

impl ShardedSafeTensors {
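    /// Initializes a `VarBuilder` that retrieves tensors stored in a collection
    /// of mmaped safetensors files. Tensors whose names match one of the
    /// `make_dummy_regexes`, or for which `predicate` returns `false`, are
    /// reported as missing.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions::map`]: the mapped
    /// files must not be modified while the mappings are alive.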
    pub unsafe fn sharded<P: AsRef<std::path::Path>>(
        paths: &[P],
        dtype: DType,
        dev: &Device,
        make_dummy_regexes: Option<Arc<Vec<Regex>>>,
        predicate: Arc<dyn Fn(String) -> bool + Send + Sync + 'static>,
    ) -> Result<ShardedVarBuilder> {
        let tensors = MmapedSafetensors::multi(paths)?;
        let backend = ShardedSafeTensors::Sharded {
            b: tensors,
            make_dummy_regexes,
            predicate,
        };
        Ok(VarBuilderArgs::new_with_args(backend, dtype, dev))
    }
}

impl ShardedSafeTensors {
    /// Wraps an existing [`SimpleBackend`] in a `VarBuilder`.
    pub fn wrap(
        backend: Box<dyn SimpleBackend + 'static>,
        dtype: DType,
        dev: Device,
    ) -> ShardedVarBuilder {
        VarBuilderArgs::new_with_args(Self::SimpleBackend(backend), dtype, &dev)
    }
}

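/// Description of how a tensor should be sliced out of a checkpoint when it is
/// loaded: `Simple` splits dimension `dim` into `world_size` equal blocks and
/// selects the block at `rank`, while `Offset` selects `len` elements starting
/// at `offset` along `dim`.
///
/// A sketch of the semantics (shapes are illustrative):
///
/// ```ignore
/// let full = Tensor::zeros((8, 16), DType::F32, &Device::Cpu)?;
/// let shard = Shard::Simple { dim: 0, rank: 1, world_size: 4 };
/// // Rank 1 of 4 keeps rows 2..4 of the 8 rows.
/// assert_eq!(shard.apply_to(&full)?.dims(), &[2, 16]);
/// ```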
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Shard {
    Simple {
        dim: usize,
        rank: usize,
        world_size: usize,
    },
    Offset {
        dim: usize,
        offset: usize,
        len: usize,
    },
}

impl Shard {
    /// Slices `tensor` according to this shard. Only dimensions 0, 1 and 2 are
    /// supported.
    pub fn apply_to(&self, tensor: &Tensor) -> Result<Tensor> {
        match *self {
            Shard::Simple {
                dim,
                rank,
                world_size,
            } => {
                let size = tensor.dim(dim)?;
                let shape = tensor.dims().to_vec();

                if size % world_size != 0 {
                    return Err(Error::ShapeMismatchSplit {
                        shape: shape.into(),
                        dim,
                        n_parts: world_size,
                    });
                }
                let block_size = size / world_size;
                let start = rank * block_size;
                let stop = (rank + 1) * block_size;

                if dim == 0 {
                    tensor.i(start..stop)
                } else if dim == 1 {
                    tensor.i((.., start..stop))
                } else if dim == 2 {
                    tensor.i((.., .., start..stop))
                } else {
                    candle_core::bail!("Sharding is only supported along dimensions 0, 1 and 2, got {dim}")
                }
            }
            Shard::Offset { dim, offset, len } => {
                let start = offset;
                let stop = start + len;

                if dim == 0 {
                    tensor.i(start..stop)
                } else if dim == 1 {
                    tensor.i((.., start..stop))
                } else if dim == 2 {
                    tensor.i((.., .., start..stop))
                } else {
                    candle_core::bail!("Sharding is only supported along dimensions 0, 1 and 2, got {dim}")
                }
            }
        }
    }
}

impl Default for Shard {
    fn default() -> Self {
        Self::Simple {
            dim: 0,
            rank: 0,
            world_size: 1,
        }
    }
}

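/// The [`Shard`] hint describes which slice of the stored tensor to load. For
/// mmaped safetensors only the bytes of the requested shard are copied out of
/// the memory map, while plain [`SimpleBackend`]s load the full tensor and
/// slice it afterwards via [`Shard::apply_to`].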
impl Backend for ShardedSafeTensors {
    type Hints = Shard;

    fn get(
        &self,
        target_shape: Shape,
        path: &str,
        h: Self::Hints,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor> {
        if let Shard::Simple { world_size: 1, .. } = &h {
            // There is no sharding to be applied here, so we use the default
            // path and load the full tensor.
            match self {
                Self::Sharded {
                    b,
                    make_dummy_regexes,
                    predicate,
                } => {
                    if let Some(make_dummy_regexes) = make_dummy_regexes {
                        if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
                            return Err(Error::CannotFindTensor {
                                path: path.to_string(),
                            });
                        }
                    }
                    let should_include = predicate(path.to_string());
                    if !should_include {
                        return Err(Error::CannotFindTensor {
                            path: path.to_string(),
                        });
                    }

                    return SimpleBackend::get(
                        b,
                        target_shape,
                        path,
                        Default::default(),
                        dtype,
                        dev,
                    );
                }
                Self::SimpleBackend(b) => {
                    return SimpleBackend::get(
                        b.as_ref(),
                        target_shape,
                        path,
                        Default::default(),
                        dtype,
                        dev,
                    )
                }
            }
        }

        let result = match h {
            Shard::Simple {
                dim,
                rank,
                world_size,
            } => {
                match self {
                    Self::Sharded {
                        b,
                        make_dummy_regexes,
                        predicate,
                    } => {
                        use safetensors::slice::IndexOp;

                        if let Some(make_dummy_regexes) = make_dummy_regexes {
                            if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
                                return Err(Error::CannotFindTensor {
                                    path: path.to_string(),
                                });
                            }
                        }
                        let should_include = predicate(path.to_string());
                        if !should_include {
                            return Err(Error::CannotFindTensor {
                                path: path.to_string(),
                            });
                        }

                        let view = b.get(path)?;
                        let view_dtype = view.dtype();
                        let mut shape = view.shape().to_vec();
                        let size = shape[dim];

                        if size % world_size != 0 {
                            return Err(Error::ShapeMismatchSplit {
                                shape: shape.into(),
                                dim,
                                n_parts: world_size,
                            });
                        }
                        let block_size = size / world_size;
                        let start = rank * block_size;
                        let stop = (rank + 1) * block_size;

                        // Everything is expressed in tensor dimensions,
                        // not in byte offsets.
                        let iterator = if dim == 0 {
                            view.slice(start..stop).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else if dim == 1 {
                            view.slice((.., start..stop)).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else if dim == 2 {
                            view.slice((.., .., start..stop)).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else {
                            candle_core::bail!("Sharding is only supported along dimensions 0, 1 and 2, got {dim}")
                        };

                        shape[dim] = block_size;

                        let view_dtype: DType = view_dtype.try_into()?;
                        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
                        Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)?
                    }
                    Self::SimpleBackend(b) => {
                        let tensor = b.get(target_shape, path, Default::default(), dtype, dev)?;
                        h.apply_to(&tensor)?
                    }
                }
            }
            Shard::Offset { dim, offset, len } => {
                match self {
                    Self::Sharded {
                        b,
                        make_dummy_regexes,
                        predicate,
                    } => {
                        use safetensors::slice::IndexOp;

                        if let Some(make_dummy_regexes) = make_dummy_regexes {
                            if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
                                return Err(Error::CannotFindTensor {
                                    path: path.to_string(),
                                });
                            }
                        }
                        let should_include = predicate(path.to_string());
                        if !should_include {
                            return Err(Error::CannotFindTensor {
                                path: path.to_string(),
                            });
                        }

                        let view = b.get(path)?;
                        let view_dtype = view.dtype();
                        let mut shape = view.shape().to_vec();

                        let start = offset;
                        let stop = start + len;

                        // Everything is expressed in tensor dimensions,
                        // not in byte offsets.
                        let iterator = if dim == 0 {
                            view.slice(start..stop).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else if dim == 1 {
                            view.slice((.., start..stop)).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else if dim == 2 {
                            view.slice((.., .., start..stop)).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else {
                            candle_core::bail!("Sharding is only supported along dimensions 0, 1 and 2, got {dim}")
                        };

                        shape[dim] = len;

                        let view_dtype: DType = view_dtype.try_into()?;
                        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
                        Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)?
                    }
                    Self::SimpleBackend(b) => {
                        let tensor = b.get(target_shape, path, Default::default(), dtype, dev)?;
                        h.apply_to(&tensor)?
                    }
                }
            }
        };

        result.contiguous()
    }

    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
        match self {
            Self::Sharded {
                b,
                make_dummy_regexes,
                predicate,
            } => {
                if let Some(make_dummy_regexes) = make_dummy_regexes {
                    if make_dummy_regexes.iter().any(|x| x.is_match(name)) {
                        return Err(Error::CannotFindTensor {
                            path: name.to_string(),
                        });
                    }
                }
                let should_include = predicate(name.to_string());
                if !should_include {
                    return Err(Error::CannotFindTensor {
                        path: name.to_string(),
                    });
                }
                <MmapedSafetensors as SimpleBackend>::get_unchecked(b, name, dtype, dev)
            }
            Self::SimpleBackend(b) => b.as_ref().get_unchecked(name, dtype, dev),
        }
    }

    fn contains_tensor(&self, name: &str) -> bool {
        match self {
            Self::Sharded {
                b,
                make_dummy_regexes,
                predicate,
            } => {
                if let Some(make_dummy_regexes) = make_dummy_regexes {
                    if make_dummy_regexes.iter().any(|x| x.is_match(name)) {
                        return false;
                    }
                }
                let should_include = predicate(name.to_string());
                if !should_include {
                    return false;
                }
                b.get(name).is_ok()
            }
            Self::SimpleBackend(b) => b.as_ref().contains_tensor(name),
        }
    }
}