use candle_core::{DType, Device, Error, Result, Shape, Tensor, WithDType};
use candle_nn::var_builder::{Backend, SimpleBackend, VarBuilderArgs};
use float8::F8E4M3;
use regex::Regex;
use safetensors::tensor as st;
use safetensors::tensor::SafeTensors;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;

fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
    let size_in_bytes = T::DTYPE.size_in_bytes();
    let elem_count = data.len() / size_in_bytes;
    if (data.as_ptr() as usize) % size_in_bytes == 0 {
        // SAFETY: the pointer is suitably aligned for T (checked above) and the
        // slice covers exactly `elem_count` elements of the original byte buffer.
        let data: &[T] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
        Tensor::from_slice(data, shape, device)
    } else {
        // The byte buffer is not aligned for T, so copy it into a freshly
        // allocated (and therefore properly aligned) Vec<T> first.
        let mut c: Vec<T> = Vec::with_capacity(elem_count);
        // SAFETY: `c` was allocated with capacity for `elem_count` elements, the
        // buffers do not overlap, and the copy initializes every element before
        // `set_len`.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
            c.set_len(elem_count)
        }
        Tensor::from_slice(&c, shape, device)
    }
}

fn convert_slice_with_cast<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
    data: &[u8],
    shape: &[usize],
    device: &Device,
    conv: F,
) -> Result<Tensor> {
    let size_in_bytes = std::mem::size_of::<T>();
    let elem_count = data.len() / size_in_bytes;
    if (data.as_ptr() as usize) % size_in_bytes == 0 {
        // SAFETY: the pointer is suitably aligned for T (checked above) and the
        // slice covers exactly `elem_count` elements of the original byte buffer.
        let data: &[T] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
        let data = data.iter().map(|t| conv(*t)).collect::<Result<Vec<_>>>()?;
        Tensor::from_vec(data, shape, device)
    } else {
        // Unaligned input: copy the bytes into an aligned Vec<T> before converting.
        let mut c: Vec<T> = Vec::with_capacity(elem_count);
        // SAFETY: see `convert_slice` above; the copy fully initializes the vector.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
            c.set_len(elem_count)
        }
        let c = c.into_iter().map(conv).collect::<Result<Vec<_>>>()?;
        Tensor::from_vec(c, shape, device)
    }
}

fn convert_with_cast_<T: Sized + Copy, U: WithDType, F: Fn(T) -> Result<U>>(
    view: &st::TensorView<'_>,
    device: &Device,
    conv: F,
) -> Result<Tensor> {
    convert_slice_with_cast::<T, U, F>(view.data(), view.shape(), device, conv)
}

fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
    convert_slice::<T>(view.data(), view.shape(), device)
}

pub trait Load {
    fn load(&self, device: &Device, dtype: Option<DType>) -> Result<Tensor>;
}

impl Load for st::TensorView<'_> {
    fn load(&self, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
        convert(self, device, dtype)
    }
}

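// Usage sketch for the `Load` trait (assumes a deserialized `SafeTensors` value
// named `st` containing a hypothetical key "weight"): a view can be loaded
// as-stored, or cast to a target dtype on the way in.
//
//     let view = st.tensor("weight")?;
//     let as_stored = view.load(&Device::Cpu, None)?;
//     let as_f32 = view.load(&Device::Cpu, Some(DType::F32))?;
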
fn convert(
    view: &st::TensorView<'_>,
    device: &Device,
    cast_dtype: Option<DType>,
) -> Result<Tensor> {
    match (view.dtype(), cast_dtype) {
        (st::Dtype::U8, _) => convert_::<u8>(view, device),
        (st::Dtype::U16, _) => {
            let conv = |x| Ok(u32::from(x));
            convert_with_cast_::<u16, u32, _>(view, device, conv)
        }
        (st::Dtype::U32, _) => convert_::<u32>(view, device),
        (st::Dtype::I16, _) => convert_::<i16>(view, device),
        (st::Dtype::I32, _) => convert_::<i32>(view, device),
        (st::Dtype::I64, _) => convert_::<i64>(view, device),
        (st::Dtype::BF16, None | Some(DType::BF16)) => convert_::<half::bf16>(view, device),
        (st::Dtype::F16, None | Some(DType::F16)) => convert_::<half::f16>(view, device),
        (st::Dtype::F32, _) => convert_::<f32>(view, device),
        (st::Dtype::F64, _) => convert_::<f64>(view, device),
        (st::Dtype::F8_E4M3, _) => convert_::<F8E4M3>(view, device),

        (st::Dtype::BF16, Some(DType::F16)) => {
            let conv = |x: half::bf16| Ok(half::f16::from_f32(x.to_f32()));
            convert_with_cast_::<half::bf16, half::f16, _>(view, device, conv)
        }
        (st::Dtype::BF16, Some(DType::F32)) => {
            let conv = |x: half::bf16| Ok(x.to_f32());
            convert_with_cast_::<half::bf16, f32, _>(view, device, conv)
        }
        (st::Dtype::F16, Some(DType::BF16)) => {
            let conv = |x: half::f16| Ok(half::bf16::from_f32(x.to_f32()));
            convert_with_cast_::<half::f16, half::bf16, _>(view, device, conv)
        }
        (st::Dtype::F16, Some(DType::F32)) => {
            let conv = |x: half::f16| Ok(x.to_f32());
            convert_with_cast_::<half::f16, f32, _>(view, device, conv)
        }
        (dtype, _) => Err(Error::UnsupportedSafeTensorDtype(dtype)),
    }
}

#[derive(yoke::Yokeable)]
struct SafeTensors_<'a>(SafeTensors<'a>);

pub struct MmapedSafetensors {
    safetensors: Vec<yoke::Yoke<SafeTensors_<'static>, memmap2::Mmap>>,
    routing: Option<HashMap<String, usize>>,
}

impl MmapedSafetensors {
    /// Creates a wrapper around a memory mapped file and deserializes the safetensors header.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`]: the caller must guarantee
    /// that the underlying file is not modified while the map is alive.
    pub unsafe fn new<P: AsRef<Path>>(p: P) -> Result<Self> {
        let p = p.as_ref();
        let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
        let file = memmap2::MmapOptions::new()
            .map(&file)
            .map_err(|e| Error::from(e).with_path(p))?;
        let safetensors = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
            file,
            |data: &[u8]| {
                let st = safetensors::SafeTensors::deserialize(data)
                    .map_err(|e| Error::from(e).with_path(p))?;
                Ok::<_, Error>(SafeTensors_(st))
            },
        )?;
        Ok(Self {
            safetensors: vec![safetensors],
            routing: None,
        })
    }

    /// Creates a wrapper around multiple memory mapped files and deserializes their
    /// safetensors headers. Tensor names are routed to the file that contains them.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`]: the caller must guarantee
    /// that the underlying files are not modified while the maps are alive.
    pub unsafe fn multi<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
        let mut routing = HashMap::new();
        let mut safetensors = vec![];
        for (index, p) in paths.iter().enumerate() {
            let p = p.as_ref();
            let file = std::fs::File::open(p).map_err(|e| Error::from(e).with_path(p))?;
            let file = memmap2::MmapOptions::new()
                .map(&file)
                .map_err(|e| Error::from(e).with_path(p))?;
            let data = yoke::Yoke::<SafeTensors_<'static>, memmap2::Mmap>::try_attach_to_cart(
                file,
                |data: &[u8]| {
                    let st = safetensors::SafeTensors::deserialize(data)
                        .map_err(|e| Error::from(e).with_path(p))?;
                    Ok::<_, Error>(SafeTensors_(st))
                },
            )?;
            for k in data.get().0.names() {
                routing.insert(k.to_string(), index);
            }
            safetensors.push(data)
        }
        Ok(Self {
            safetensors,
            routing: Some(routing),
        })
    }

    pub fn load(&self, name: &str, dev: &Device, dtype: Option<DType>) -> Result<Tensor> {
        self.get(name)?.load(dev, dtype)
    }

    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
        let mut tensors = vec![];
        for safetensors in self.safetensors.iter() {
            tensors.push(safetensors.get().0.tensors())
        }
        tensors.into_iter().flatten().collect()
    }

    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
        let index = match &self.routing {
            None => 0,
            Some(routing) => {
                let index = routing.get(name).ok_or_else(|| {
                    Error::CannotFindTensor {
                        path: name.to_string(),
                    }
                    .bt()
                })?;
                *index
            }
        };
        Ok(self.safetensors[index].get().0.tensor(name)?)
    }
}

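// Usage sketch: mapping one or more checkpoint files and pulling a tensor out of
// them. The file names and the tensor key below are hypothetical placeholders.
//
//     let st = unsafe {
//         MmapedSafetensors::multi(&["model-00001.safetensors", "model-00002.safetensors"])?
//     };
//     let w = st.load("model.embed_tokens.weight", &Device::Cpu, Some(DType::F32))?;
//
// The constructors are `unsafe` because the memory maps assume the files are not
// modified while they are in use.
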
impl SimpleBackend for MmapedSafetensors {
    fn get(
        &self,
        s: Shape,
        name: &str,
        _: candle_nn::Init,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor> {
        let tensor = self.get_unchecked(name, dtype, dev)?;
        if tensor.shape() != &s {
            Err(candle_core::Error::UnexpectedShape {
                msg: format!("shape mismatch for {name}"),
                expected: s,
                got: tensor.shape().clone(),
            }
            .bt())?
        }
        Ok(tensor)
    }

    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
        self.load(name, dev, Some(dtype))
    }

    fn contains_tensor(&self, name: &str) -> bool {
        self.get(name).is_ok()
    }
}

pub enum ShardedSafeTensors {
    Sharded {
        b: MmapedSafetensors,
        make_dummy_regexes: Option<Arc<Vec<Regex>>>,
    },
    SimpleBackend(Box<dyn SimpleBackend + 'static>),
}

pub type ShardedVarBuilder = VarBuilderArgs<'static, ShardedSafeTensors>;

impl ShardedSafeTensors {
    /// Initializes a `ShardedVarBuilder` backed by zero-copy memory mapped safetensors.
    ///
    /// Tensors whose name matches any of `make_dummy_regexes` are reported as missing.
    ///
    /// # Safety
    ///
    /// The unsafe is inherited from [`memmap2::MmapOptions`].
    pub unsafe fn sharded<P: AsRef<std::path::Path>>(
        paths: &[P],
        dtype: DType,
        dev: &Device,
        make_dummy_regexes: Option<Arc<Vec<Regex>>>,
    ) -> Result<ShardedVarBuilder> {
        let tensors = MmapedSafetensors::multi(paths)?;
        let backend = ShardedSafeTensors::Sharded {
            b: tensors,
            make_dummy_regexes,
        };
        Ok(VarBuilderArgs::new_with_args(backend, dtype, dev))
    }
}

impl ShardedSafeTensors {
    pub fn wrap(
        backend: Box<dyn SimpleBackend + 'static>,
        dtype: DType,
        dev: Device,
    ) -> ShardedVarBuilder {
        VarBuilderArgs::new_with_args(Self::SimpleBackend(backend), dtype, &dev)
    }
}

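// Usage sketch: building a var builder over sharded checkpoint files, with a
// hypothetical regex that stubs out the language-model head. `paths` is assumed
// to be a slice of safetensors file paths.
//
//     let dummies = Some(Arc::new(vec![Regex::new(r"^lm_head\.").unwrap()]));
//     let vb: ShardedVarBuilder = unsafe {
//         ShardedSafeTensors::sharded(&paths, DType::BF16, &Device::Cpu, dummies)?
//     };
//     let q = vb.get_with_hints((4096, 4096), "q_proj.weight", Default::default())?;
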
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Shard {
    Simple {
        dim: usize,
        rank: usize,
        world_size: usize,
    },
    Offset {
        dim: usize,
        offset: usize,
        len: usize,
    },
}

impl Default for Shard {
    fn default() -> Self {
        Self::Simple {
            dim: 0,
            rank: 0,
            world_size: 1,
        }
    }
}

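// Sharding semantics, illustrated on a hypothetical (1024, 1024) tensor:
//
//     Shard::Simple { dim: 0, rank: 0, world_size: 2 }   // rows    0..512
//     Shard::Simple { dim: 0, rank: 1, world_size: 2 }   // rows  512..1024
//     Shard::Simple { dim: 1, rank: 0, world_size: 2 }   // cols    0..512
//     Shard::Offset { dim: 1, offset: 128, len: 256 }    // cols  128..384
//
// The default shard (`world_size: 1`) loads the full tensor.
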
impl Backend for ShardedSafeTensors {
    /// The hints select either a `world_size`-way split along `dim` (taking the
    /// `rank`-th block) or an explicit `offset..offset + len` slice along `dim`.
    /// Only dimensions 0 and 1 are supported.
    type Hints = Shard;

    fn get(
        &self,
        target_shape: Shape,
        path: &str,
        h: Self::Hints,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor> {
        if let Shard::Simple { world_size: 1, .. } = &h {
            // No sharding to apply here, so defer to the plain (non-sliced) path.
            match self {
                Self::Sharded {
                    b,
                    make_dummy_regexes,
                } => {
                    if let Some(make_dummy_regexes) = make_dummy_regexes {
                        if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
                            return Err(Error::CannotFindTensor {
                                path: path.to_string(),
                            });
                        }
                    }
                    return SimpleBackend::get(
                        b,
                        target_shape,
                        path,
                        Default::default(),
                        dtype,
                        dev,
                    );
                }
                Self::SimpleBackend(b) => {
                    return SimpleBackend::get(
                        b.as_ref(),
                        target_shape,
                        path,
                        Default::default(),
                        dtype,
                        dev,
                    )
                }
            }
        }

        match h {
            Shard::Simple {
                dim,
                rank,
                world_size,
            } => {
                match self {
                    Self::Sharded {
                        b,
                        make_dummy_regexes,
                    } => {
                        use safetensors::slice::IndexOp;

                        if let Some(make_dummy_regexes) = make_dummy_regexes {
                            if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
                                return Err(Error::CannotFindTensor {
                                    path: path.to_string(),
                                });
                            }
                        }

                        let view = b.get(path)?;
                        let view_dtype = view.dtype();
                        let mut shape = view.shape().to_vec();
                        let size = shape[dim];

                        if size % world_size != 0 {
                            return Err(Error::ShapeMismatchSplit {
                                shape: shape.into(),
                                dim,
                                n_parts: world_size,
                            });
                        }
                        let block_size = size / world_size;
                        let start = rank * block_size;
                        let stop = (rank + 1) * block_size;

                        // Slicing is expressed in elements; the safetensors slicing
                        // API handles the conversion to byte offsets.
                        let iterator = if dim == 0 {
                            view.slice(start..stop).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else if dim == 1 {
                            view.slice((.., start..stop)).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else {
                            candle_core::bail!("Sharding is only supported along dimension 0 or 1")
                        };

                        shape[dim] = block_size;

                        let view_dtype: DType = view_dtype.try_into()?;
                        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
                        Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)
                    }
                    Self::SimpleBackend(b) => {
                        use candle_core::IndexOp;
                        let tensor = b.get(target_shape, path, Default::default(), dtype, dev)?;

                        let size = tensor.dim(dim)?;
                        let shape = tensor.dims().to_vec();

                        if size % world_size != 0 {
                            return Err(Error::ShapeMismatchSplit {
                                shape: shape.into(),
                                dim,
                                n_parts: world_size,
                            });
                        }
                        let block_size = size / world_size;
                        let start = rank * block_size;
                        let stop = (rank + 1) * block_size;

                        if dim == 0 {
                            tensor.i((start..stop, ..))
                        } else if dim == 1 {
                            tensor.i((.., start..stop))
                        } else {
                            candle_core::bail!("Sharding is only supported along dimension 0 or 1")
                        }
                    }
                }
            }
            Shard::Offset { dim, offset, len } => {
                match self {
                    Self::Sharded {
                        b,
                        make_dummy_regexes,
                    } => {
                        use safetensors::slice::IndexOp;

                        if let Some(make_dummy_regexes) = make_dummy_regexes {
                            if make_dummy_regexes.iter().any(|x| x.is_match(path)) {
                                return Err(Error::CannotFindTensor {
                                    path: path.to_string(),
                                });
                            }
                        }

                        let view = b.get(path)?;
                        let view_dtype = view.dtype();
                        let mut shape = view.shape().to_vec();

                        let start = offset;
                        let stop = start + len;

                        // As above, offsets are expressed in elements rather than bytes.
                        let iterator = if dim == 0 {
                            view.slice(start..stop).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else if dim == 1 {
                            view.slice((.., start..stop)).map_err(|_| {
                                Error::Msg(format!(
                                    "Cannot slice tensor {path} ({shape:?}) along dim {dim} with {start}..{stop}"
                                ))
                            })?
                        } else {
                            candle_core::bail!("Sharding is only supported along dimension 0 or 1")
                        };

                        shape[dim] = len;

                        let view_dtype: DType = view_dtype.try_into()?;
                        let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
                        Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)
                    }
                    Self::SimpleBackend(b) => {
                        use candle_core::IndexOp;
                        let tensor = b.get(target_shape, path, Default::default(), dtype, dev)?;

                        let start = offset;
                        let stop = start + len;

                        if dim == 0 {
                            tensor.i((start..stop, ..))
                        } else if dim == 1 {
                            tensor.i((.., start..stop))
                        } else {
                            candle_core::bail!("Sharding is only supported along dimension 0 or 1")
                        }
                    }
                }
            }
        }
    }

    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
        match self {
            Self::Sharded {
                b,
                make_dummy_regexes,
            } => {
                if let Some(make_dummy_regexes) = make_dummy_regexes {
                    if make_dummy_regexes.iter().any(|x| x.is_match(name)) {
                        return Err(Error::CannotFindTensor {
                            path: name.to_string(),
                        });
                    }
                }
                <MmapedSafetensors as SimpleBackend>::get_unchecked(b, name, dtype, dev)
            }
            Self::SimpleBackend(b) => b.as_ref().get_unchecked(name, dtype, dev),
        }
    }

    fn contains_tensor(&self, name: &str) -> bool {
        match self {
            Self::Sharded {
                b,
                make_dummy_regexes,
            } => {
                if let Some(make_dummy_regexes) = make_dummy_regexes {
                    if make_dummy_regexes.iter().any(|x| x.is_match(name)) {
                        return false;
                    }
                }
                b.get(name).is_ok()
            }
            Self::SimpleBackend(b) => b.as_ref().contains_tensor(name),
        }
    }
}
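
// A minimal sanity-check sketch for the conversion and sharding helpers above.
// It only exercises `convert_slice` on an in-memory CPU buffer and the default
// `Shard`; it makes no assumptions about any checkpoint files.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn convert_slice_roundtrips_f32() -> Result<()> {
        let vals: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        // Reinterpret the f32 buffer as raw little-endian bytes, mimicking a
        // safetensors view; `convert_slice` handles both aligned and unaligned input.
        let bytes: Vec<u8> = vals.iter().flat_map(|v| v.to_le_bytes()).collect();
        let t = convert_slice::<f32>(&bytes, &[2, 3], &Device::Cpu)?;
        assert_eq!(t.dims(), &[2, 3]);
        assert_eq!(t.flatten_all()?.to_vec1::<f32>()?, vals);
        Ok(())
    }

    #[test]
    fn default_shard_is_full_tensor() {
        assert_eq!(
            Shard::default(),
            Shard::Simple {
                dim: 0,
                rank: 0,
                world_size: 1
            }
        );
    }
}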