1use candle_core::{DType, Device, Error, IndexOp, Result, Shape, Tensor, WithDType};
2use candle_nn::var_builder::{Backend, SimpleBackend, VarBuilderArgs};
3use float8::F8E4M3;
4use regex::Regex;
5use safetensors::tensor as st;
6use safetensors::tensor::SafeTensors;
7use std::collections::HashMap;
8use std::path::Path;
9use std::sync::Arc;
10
11fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
12 let size_in_bytes = T::DTYPE.size_in_bytes();
13 let elem_count = data.len() / size_in_bytes;
14 if (data.as_ptr() as usize) % size_in_bytes == 0 {
15 let data: &[T] =
18 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
19 Tensor::from_slice(data, shape, device)
20 } else {
21 let mut c: Vec<T> = Vec::with_capacity(elem_count);
24 unsafe {
29 std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
30 c.set_len(elem_count)
31 }
32 Tensor::from_slice(&c, shape, device)
33 }
34}
35
36fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
37 data: &[u8],
38 shape: &[usize],
39 device: &Device,
40 conv: F,
41) -> Result<Tensor> {
42 let size_in_bytes = std::mem::size_of::<T>();
43 let elem_count = data.len() / size_in_bytes;
44 if (data.as_ptr() as usize) % size_in_bytes == 0 {
45 let data: &[T] =
48 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
49 let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
50 Tensor::from_vec(data, shape, device)
51 } else {
52 let mut c: Vec<T> = Vec::with_capacity(elem_count);
55 unsafe {
60 std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
61 c.set_len(elem_count)
62 }
63 let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
64 Tensor::from_vec(c, shape, device)
65 }
66}
67
68fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
69 view: &st::TensorView<'_>,
70 device: &Device,
71 conv: F,
72) -> Result<Tensor> {
73 convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
74}
75
76fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
77 convert_slice::<T>(view.data(), view.shape(), device)
78}
79
/// Types that can be loaded into a candle [`Tensor`].
pub trait Load {
    /// Loads the value onto `device`, optionally casting to `dtype` while the
    /// data is converted (see `convert` for the supported cast pairs).
    fn load(&self, device: &Device, dtype: Option<DType>) -> Result<Tensor>;
}
83
impl Load for st::TensorView<'_> {
    fn load(&self, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
        // Delegates to `convert`, which handles both the no-cast and the
        // cast-while-loading paths.
        convert(self, device, dtype)
    }
}
89
/// Converts a safetensors view into a candle tensor on `device`.
///
/// When `cast_dtype` is `Some`, the supported float pairs (bf16/f16/f32) are
/// converted element-wise while loading; other dtypes either ignore the
/// requested cast (integers, f64, f8) or fail with
/// `UnsupportedSafeTensorDtype`.
fn convert(
    view: &st::TensorView<'_>,
    device: &Device,
    cast_dtype: Option<DType>,
) -> Result<Tensor> {
    // NOTE: arm order matters — the explicit (source, cast) pairs must come
    // before the catch-all `(st::Dtype::F32, _)`-style arms further down.
    match (view.dtype(), cast_dtype) {
        (st::Dtype::BF16, Some(DType::F16)) => {
            let conv = |x: half::bf16| Ok(half::f16::from_f32(x.to_f32()));
            convert_with_cast_::<half::bf16, half::f16, _>(view, device, conv)
        }
        (st::Dtype::BF16, Some(DType::F32)) => {
            let conv = |x: half::bf16| Ok(x.to_f32());
            convert_with_cast_::<half::bf16, f32, _>(view, device, conv)
        }
        (st::Dtype::F16, Some(DType::BF16)) => {
            let conv = |x: half::f16| Ok(half::bf16::from_f32(x.to_f32()));
            convert_with_cast_::<half::f16, half::bf16, _>(view, device, conv)
        }
        (st::Dtype::F16, Some(DType::F32)) => {
            let conv = |x: half::f16| Ok(x.to_f32());
            convert_with_cast_::<half::f16, f32, _>(view, device, conv)
        }
        (st::Dtype::F32, Some(DType::BF16)) => {
            let conv = |x: f32| Ok(half::bf16::from_f32(x));
            convert_with_cast_::<f32, half::bf16, _>(view, device, conv)
        }
        (st::Dtype::F32, Some(DType::F16)) => {
            let conv = |x: f32| Ok(half::f16::from_f32(x));
            convert_with_cast_::<f32, half::f16, _>(view, device, conv)
        }

        (st::Dtype::U8, _) => convert_::<u8>(view, device),
        // candle has no u16 dtype: widen to u32.
        (st::Dtype::U16, _) => {
            let conv = |x| Ok(u32::from(x));
            convert_with_cast_::<u16, u32, _>(view, device, conv)
        }
        (st::Dtype::U32, _) => convert_::<u32>(view, device),
        (st::Dtype::I16, _) => convert_::<i16>(view, device),
        (st::Dtype::I32, _) => convert_::<i32>(view, device),
        (st::Dtype::I64, _) => convert_::<i64>(view, device),
        (st::Dtype::BF16, None | Some(DType::BF16)) => convert_::<half::bf16>(view, device),
        (st::Dtype::F16, None | Some(DType::F16)) => convert_::<half::f16>(view, device),
        (st::Dtype::F32, _) => convert_::<f32>(view, device),
        (st::Dtype::F64, _) => convert_::<f64>(view, device),
        // NOTE(review): F8_E4M3 ignores `cast_dtype` and always loads as f8 —
        // presumably intentional for quantized weights; confirm before adding
        // cast pairs for it.
        (st::Dtype::F8_E4M3, _) => convert_::<F8E4M3>(view, device),
        // e.g. BF16 with a requested cast to F64 lands here.
        (dtype, _) => Err(Error::UnsupportedSafeTensorDtype(dtype)),
    }
}
138
// Newtype so we can derive `Yokeable` for `SafeTensors`, letting the parsed
// view be stored alongside the mmap it borrows from (see `MmapedSafetensors`).
#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);
141
/// One or more memory-mapped safetensors files, with an optional routing
/// table mapping each tensor name to the file that contains it.
pub struct MmapedSafetensors {
    // One parsed `SafeTensors` view per file; the yoke keeps each mmap alive
    // for as long as its parsed view.
    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
    // `None` for the single-file case (every lookup goes to index 0);
    // `Some` maps tensor name -> index into `safetensors`.
    routing: Option<HashMap<String, usize>>,
}
146
impl MmapedSafetensors {
    /// Memory-maps a single safetensors file and parses its header.
    ///
    /// # Safety
    ///
    /// The file must not be modified or truncated for the lifetime of the
    /// returned value: `memmap2` maps it directly, and concurrent
    /// modification of a mapped file is undefined behavior.
    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
        let p = p.as_ref();
        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
        let file = memmap2::MmapOptions::new()
            .map(&file)
            .map_err(|e| Error::from(e).with_path(p))?;
        // `yoke` ties the lifetime of the parsed `SafeTensors` view to the
        // mmap (the "cart"), so both can live in one self-referential value.
        let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
            file,
            |data: &[u8]| {
                let st = safetensors::SafeTensors::deserialize(data)
                    .map_err(|e| Error::from(e).with_path(p))?;
                Ok::<_, Error>(SafeTensors_(st))
            },
        )?;
        Ok(Self {
            safetensors: vec![safetensors],
            // Single file: `get` always looks in index 0.
            routing: None,
        })
    }

    /// Memory-maps several safetensors files (e.g. sharded model weights) and
    /// builds a name -> file-index routing table.
    ///
    /// If the same tensor name appears in several files, the last file in
    /// `paths` wins.
    ///
    /// # Safety
    ///
    /// Same as [`Self::new`]: none of the files may be modified while mapped.
    pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
        let mut routing = HashMap::new();
        let mut safetensors = vec![];
        for (index, p) in paths.iter().enumerate() {
            let p = p.as_ref();
            let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
            let file = memmap2::MmapOptions::new()
                .map(&file)
                .map_err(|e| Error::from(e).with_path(p))?;
            let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
                file,
                |data: &[u8]| {
                    let st = safetensors::SafeTensors::deserialize(data)
                        .map_err(|e| Error::from(e).with_path(p))?;
                    Ok::<_, Error>(SafeTensors_(st))
                },
            )?;
            // Record which file each tensor name lives in.
            for k in data.get().0.names() {
                routing.insert(k.to_string(), index);
            }
            safetensors.push(data)
        }
        Ok(Self {
            safetensors,
            routing: Some(routing),
        })
    }

    /// Loads tensor `name` onto `dev`, optionally casting it to `dtype` while
    /// the data is converted.
    pub fn load(&self, name: &str, dev: &Device, dtype: Option<DType>) -> Result<Tensor> {
        self.get(name)?.load(dev, dtype)
    }

    /// Returns `(name, view)` pairs for every tensor across all mapped files.
    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
        let mut tensors = vec![];
        for safetensors in self.safetensors.iter() {
            tensors.push(safetensors.get().0.tensors())
        }
        tensors.into_iter().flatten().collect()
    }

    /// Returns the raw (borrowed) view for tensor `name`, routing to the
    /// correct file when multiple files are mapped.
    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
        let index = match &self.routing {
            None => 0,
            Some(routing) => {
                let index = routing.get(name).ok_or_else(|| {
                    Error::CannotFindTensor {
                        path: name.to_string(),
                    }
                    .bt()
                })?;
                *index
            }
        };
        Ok(self.safetensors[index].get().0.tensor(name)?)
    }
}
236
237impl SimpleBackend for MmapedSafetensors {
238 fn get(
239 &self,
240 s: Shape,
241 name: &str,
242 _: candle_nn::Init,
243 dtype: DType,
244 dev: &Device,
245 ) -> Result<Tensor> {
246 let tensor = self.get_unchecked(name, dtype, dev)?;
247 if tensor.shape() != &s {
248 Err(candle_core::Error::UnexpectedShape {
249 msg: format!("shape mismatch for {name}"),
250 expected: s,
251 got: tensor.shape().clone(),
252 }
253 .bt())?
254 }
255 Ok(tensor)
256 }
257
258 fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
259 self.load(name, dev, Some(dtype))
260 }
261
262 fn contains_tensor(&self, name: &str) -> bool {
263 self.get(name).is_ok()
264 }
265}
266
/// Backend for a [`ShardedVarBuilder`]: either memory-mapped safetensors with
/// optional name filtering, or an arbitrary wrapped [`SimpleBackend`].
pub enum ShardedSafeTensors {
    Sharded {
        // The memory-mapped files holding the weights.
        b: MmapedSafetensors,
        // Tensor names matching any of these regexes are reported as missing
        // (`CannotFindTensor`).
        make_dummy_regexes: Option<Arc<Vec<Regex>>>,
        // Only tensor names for which this returns `true` are visible.
        predicate: Arc<dyn Fn(String) -> bool + Send + Sync + 'static>,
    },
    SimpleBackend(Box<dyn SimpleBackend + 'static>),
}
275
/// A `VarBuilder` whose backend supports sharded (tensor-parallel) loading.
pub type ShardedVarBuilder = VarBuilderArgs<'static, ShardedSafeTensors>;
277
impl ShardedSafeTensors {
    /// Builds a [`ShardedVarBuilder`] backed by the given safetensors files.
    ///
    /// - `make_dummy_regexes`: tensor names matching any of these regexes are
    ///   reported as missing (`CannotFindTensor`).
    /// - `predicate`: only tensor names for which this returns `true` are
    ///   visible through the builder.
    ///
    /// # Safety
    ///
    /// The files must not be modified while memory-mapped (see
    /// [`MmapedSafetensors::multi`]).
    pub unsafe fn sharded<P: AsRef<std::path::Path>>(
        paths: &[P],
        dtype: DType,
        dev: &Device,
        make_dummy_regexes: Option<Arc<Vec<Regex>>>,
        predicate: Arc<dyn Fn(String) -> bool + Send + Sync + 'static>,
    ) -> Result<ShardedVarBuilder> {
        let tensors = MmapedSafetensors::multi(paths)?;
        let backend = ShardedSafeTensors::Sharded {
            b: tensors,
            make_dummy_regexes,
            predicate,
        };
        Ok(VarBuilderArgs::new_with_args(backend, dtype, dev))
    }
}
304
305impl ShardedSafeTensors {
306 pub fn wrap(
307 backend: Box<dyn SimpleBackend + 'static>,
308 dtype: DType,
309 dev: Device,
310 ) -> ShardedVarBuilder {
311 VarBuilderArgs::new_with_args(Self::SimpleBackend(backend), dtype, &dev)
312 }
313}
314
/// Describes which part of a tensor a given rank should load.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Shard {
    /// Split dimension `dim` into `world_size` equal blocks and take the
    /// `rank`-th one.
    Simple {
        dim: usize,
        rank: usize,
        world_size: usize,
    },
    /// Take the explicit range `offset..offset + len` along `dim`.
    Offset {
        dim: usize,
        offset: usize,
        len: usize,
    },
}
328
329impl Shard {
330 pub fn apply_to(&self, tensor: &Tensor) -> Result<Tensor> {
331 match *self {
332 Shard::Simple {
333 dim,
334 rank,
335 world_size,
336 } => {
337 let size = tensor.dim(dim)?;
338 let shape = tensor.dims().to_vec();
339
340 if size % world_size != 0 {
341 return Err(Error::ShapeMismatchSplit {
342 shape: shape.into(),
343 dim,
344 n_parts: world_size,
345 });
346 }
347 let block_size = size / world_size;
348 let start = rank * block_size;
349 let stop = (rank + 1) * block_size;
350
351 if dim == 0 {
352 tensor.i(start..stop)
353 } else if dim == 1 {
354 tensor.i((.., start..stop))
355 } else if dim == 2 {
356 tensor.i((.., .., start..stop))
357 } else {
358 candle_core::bail!("Got sharded on dimensions != 0 or 1 or 2")
359 }
360 }
361 Shard::Offset { dim, offset, len } => {
362 let start = offset;
363 let stop = start + len;
364
365 if dim == 0 {
366 tensor.i(start..stop)
367 } else if dim == 1 {
368 tensor.i((.., start..stop))
369 } else if dim == 2 {
370 tensor.i((.., .., start..stop))
371 } else {
372 candle_core::bail!("Got sharded on dimensions != 0 or 1 or 2")
373 }
374 }
375 }
376 }
377}
378
impl Default for Shard {
    /// The identity shard: the whole tensor (rank 0 of a world of size 1).
    fn default() -> Self {
        Self::Simple {
            dim: 0,
            rank: 0,
            world_size: 1,
        }
    }
}
388
389impl Backend for ShardedSafeTensors {
401 type Hints = Shard;
402
403 fn get(
404 &self,
405 target_shape: Shape,
406 path: &str,
407 h: Self::Hints,
408 dtype: DType,
409 dev: &Device,
410 ) -> Result<Tensor> {
411 if let Shard::Simple { world_size: 1, .. } = &h {
412 match self {
415 Self::Sharded {
416 b,
417 make_dummy_regexes,
418 predicate,
419 } => {
420 if let Some(make_dummy_regexes) = make_dummy_regexes {
421 if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
422 return Err(Error::CannotFindTensor {
423 path: path.to_string(),
424 });
425 }
426 }
427 let should_include = predicate(path.to_string());
428 if !should_include {
429 return Err(Error::CannotFindTensor {
430 path: path.to_string(),
431 });
432 }
433
434 return SimpleBackend::get(
435 b,
436 target_shape,
437 path,
438 Default::default(),
439 dtype,
440 dev,
441 );
442 }
443 Self::SimpleBackend(b) => {
444 return SimpleBackend::get(
445 b.as_ref(),
446 target_shape,
447 path,
448 Default::default(),
449 dtype,
450 dev,
451 )
452 }
453 }
454 }
455
456 let result = match h {
457 Shard::Simple {
458 dim,
459 rank,
460 world_size,
461 } => {
462 match self {
463 Self::Sharded {
464 b,
465 make_dummy_regexes,
466 predicate,
467 } => {
468 use safetensors::slice::IndexOp;
469
470 if let Some(make_dummy_regexes) = make_dummy_regexes {
471 if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
472 return Err(Error::CannotFindTensor {
473 path: path.to_string(),
474 });
475 }
476 }
477 let should_include = predicate(path.to_string());
478 if !should_include {
479 return Err(Error::CannotFindTensor {
480 path: path.to_string(),
481 });
482 }
483
484 let view = b.get(path)?;
485 let view_dtype = view.dtype();
486 let mut shape = view.shape().to_vec();
487 let size = shape[dim];
488
489 if size % world_size != 0 {
490 return Err(Error::ShapeMismatchSplit {
491 shape: shape.into(),
492 dim,
493 n_parts: world_size,
494 });
495 }
496 let block_size = size / world_size;
497 let start = rank * block_size;
498 let stop = (rank + 1) * block_size;
499
500 let iterator = if dim == 0 {
504 view.slice(start..stop).map_err(|_| {
505 Error::Msg(format!(
506 "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
507 ))
508 })?
509 } else if dim == 1 {
510 view.slice((.., start..stop)).map_err(|_| {
511 Error::Msg(format!(
512 "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
513 ))
514 })?
515 } else if dim == 2 {
516 view.slice((.., .., start..stop)).map_err(|_| {
517 Error::Msg(format!(
518 "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
519 ))
520 })?
521 } else {
522 candle_core::bail!("Got sharded on dimensions != 0 or 1 or 2")
523 };
524
525 shape[dim] = block_size;
526
527 let view_dtype: DType = view_dtype.try_into()?;
528 let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
529 Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)?
530 }
531 Self::SimpleBackend(b) => {
532 let tensor = b.get(target_shape, path, Default::default(), dtype, dev)?;
533 h.apply_to(&tensor)?
534 }
535 }
536 }
537 Shard::Offset { dim, offset, len } => {
538 match self {
539 Self::Sharded {
540 b,
541 make_dummy_regexes,
542 predicate,
543 } => {
544 use safetensors::slice::IndexOp;
545
546 if let Some(make_dummy_regexes) = make_dummy_regexes {
547 if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
548 return Err(Error::CannotFindTensor {
549 path: path.to_string(),
550 });
551 }
552 }
553 let should_include = predicate(path.to_string());
554 if !should_include {
555 return Err(Error::CannotFindTensor {
556 path: path.to_string(),
557 });
558 }
559
560 let view = b.get(path)?;
561 let view_dtype = view.dtype();
562 let mut shape = view.shape().to_vec();
563
564 let start = offset;
565 let stop = start + len;
566
567 let iterator = if dim == 0 {
571 view.slice(start..stop).map_err(|_| {
572 Error::Msg(format!(
573 "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
574 ))
575 })?
576 } else if dim == 1 {
577 view.slice((.., start..stop)).map_err(|_| {
578 Error::Msg(format!(
579 "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
580 ))
581 })?
582 } else if dim == 2 {
583 view.slice((.., .., start..stop)).map_err(|_| {
584 Error::Msg(format!(
585 "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
586 ))
587 })?
588 } else {
589 candle_core::bail!("Got sharded on dimensions != 0 or 1 or 2")
590 };
591
592 shape[dim] = len;
593
594 let view_dtype: DType = view_dtype.try_into()?;
595 let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
596 Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)?
597 }
598 Self::SimpleBackend(b) => {
599 let tensor = b.get(target_shape, path, Default::default(), dtype, dev)?;
600 h.apply_to(&tensor)?
601 }
602 }
603 }
604 };
605
606 result.contiguous()
607 }
608
609 fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
610 match self {
611 Self::Sharded {
612 b,
613 make_dummy_regexes,
614 predicate,
615 } => {
616 if let Some(make_dummy_regexes) = make_dummy_regexes {
617 if make_dummy_regexes.iter().any(|x| x.is_match(name)) {
618 return Err(Error::CannotFindTensor {
619 path: name.to_string(),
620 });
621 }
622 }
623 let should_include = predicate(name.to_string());
624 if !should_include {
625 return Err(Error::CannotFindTensor {
626 path: name.to_string(),
627 });
628 }
629 <MmapedSafetensors as SimpleBackend>::get_unchecked(b, name, dtype, dev)
630 }
631 Self::SimpleBackend(b) => b.as_ref().get_unchecked(name, dtype, dev),
632 }
633 }
634
635 fn contains_tensor(&self, name: &str) -> bool {
636 match self {
637 Self::Sharded {
638 b,
639 make_dummy_regexes,
640 predicate,
641 } => {
642 if let Some(make_dummy_regexes) = make_dummy_regexes {
643 if make_dummy_regexes.iter().any(|x| x.is_match(name)) {
644 return false;
645 }
646 }
647 let should_include = predicate(name.to_string());
648 if !should_include {
649 return false;
650 }
651 b.get(name).is_ok()
652 }
653 Self::SimpleBackend(b) => b.as_ref().contains_tensor(name),
654 }
655 }
656}