1use std::{
4 collections::HashMap,
5 path::PathBuf,
6 sync::Arc,
7 thread::{self, JoinHandle},
8};
9
10use candle_core::{pickle::PthTensors, DType, Device, Result, Tensor};
11use mistralrs_quant::{safetensors::MmapedSafetensors, ShardedSafeTensors, ShardedVarBuilder};
12use regex::Regex;
13
14use crate::lora::LoraConfig;
15use crate::utils::progress::IterWithProgress;
16use derive_new::new;
17
/// Abstraction over the on-disk tensor formats this loader can read
/// (safetensors and PyTorch pickle), exposing name enumeration and
/// per-tensor loading.
trait TensorLoaderBackend {
    /// Names of every tensor stored in the file.
    fn get_names(&self) -> Vec<String>;
    /// Load tensor `name` onto `device`, casting to `dtype` when the backend
    /// supports it (the pickle backend ignores `dtype`).
    fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor>;
}
22
23struct SafetensorBackend(MmapedSafetensors);
24
25impl TensorLoaderBackend for SafetensorBackend {
26 fn get_names(&self) -> Vec<String> {
27 self.0
28 .tensors()
29 .into_iter()
30 .map(|(name, _)| name)
31 .collect::<Vec<_>>()
32 }
33 fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
34 self.0.load(name, device, dtype)
35 }
36}
37
38struct PickleBackend(PthTensors);
39
40impl TensorLoaderBackend for PickleBackend {
41 fn get_names(&self) -> Vec<String> {
42 self.0.tensor_infos().keys().cloned().collect::<Vec<_>>()
43 }
44 fn load_name(&self, name: &str, device: &Device, _dtype: Option<DType>) -> Result<Tensor> {
45 self.0
46 .get(name)?
47 .ok_or(candle_core::Error::Msg(format!(
48 "Could not load tensor {name}"
49 )))?
50 .to_device(device)
51 }
52}
53
/// Selects which device a given tensor should be loaded onto.
pub enum DeviceForLoadTensor {
    /// Load onto the base device passed to the loader.
    Base,
    /// Load onto the per-layer device at this index; the loader falls back
    /// to the base device when the index is absent or the slot is `None`.
    Idx(usize),
}
58
59#[allow(clippy::too_many_arguments)]
66pub(crate) fn from_mmaped_safetensors(
67 paths: Vec<PathBuf>,
68 xlora_paths: Vec<PathBuf>,
69 dtype: Option<DType>,
70 base_device: &Device,
71 layer_devices: Vec<Option<Device>>,
72 silent: bool,
73 make_dummy_regexes: Option<Arc<Vec<Regex>>>,
74 predicate: impl Fn(String) -> bool + Send + Sync + Clone + 'static,
75 get_device_for_tensor: Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>,
76) -> Result<ShardedVarBuilder> {
77 if xlora_paths.is_empty() && !base_device.is_cuda() || cfg!(feature = "ring") {
79 if !silent {
80 tracing::info!("Loading model using mmap strategy.");
81 }
82 return Ok(unsafe {
83 ShardedSafeTensors::sharded(
84 &paths,
85 dtype.unwrap_or(DType::F16),
86 base_device,
87 make_dummy_regexes,
88 Arc::new(predicate),
89 )?
90 });
91 }
92
93 #[allow(clippy::type_complexity)]
94 let mut handles: Vec<JoinHandle<Result<HashMap<String, Tensor>>>> = Vec::new();
95
96 for path in paths {
97 let base_device = base_device.clone();
98 let layer_devices = layer_devices.clone();
99 let get_device_for_tensor = get_device_for_tensor.clone();
100 if let Some(regexes) = make_dummy_regexes.clone() {
101 let predicate = predicate.clone();
102 handles.push(thread::spawn(Box::new(move || {
103 let loader = Common::new();
104 loader.load_tensors_from_path(
105 &path,
106 &base_device,
107 layer_devices,
108 get_device_for_tensor,
109 dtype,
110 silent,
111 predicate,
112 |key| regexes.iter().any(|r| r.is_match(key)),
113 )
114 })));
115 } else {
116 let predicate = predicate.clone();
117 handles.push(thread::spawn(Box::new(move || {
118 let loader = Common::new();
119 loader.load_tensors_from_path(
120 &path,
121 &base_device,
122 layer_devices,
123 get_device_for_tensor,
124 dtype,
125 silent,
126 predicate,
127 |_| false,
128 )
129 })));
130 }
131 }
132 for (i, path) in xlora_paths.into_iter().enumerate() {
133 let base_device = base_device.clone();
134 let layer_devices = layer_devices.clone();
135 let get_device_for_tensor = get_device_for_tensor.clone();
136 if let Some(regexes) = make_dummy_regexes.clone() {
137 let predicate = predicate.clone();
138 handles.push(thread::spawn(Box::new(move || {
139 let loader = XLora::new(i + 1);
140 loader.load_tensors_from_path(
141 &path,
142 &base_device,
143 layer_devices,
144 get_device_for_tensor,
145 dtype,
146 silent,
147 predicate,
148 |key| regexes.iter().any(|r| r.is_match(key)),
149 )
150 })));
151 } else {
152 let predicate = predicate.clone();
153 handles.push(thread::spawn(Box::new(move || {
154 let loader = XLora::new(i + 1);
155 loader.load_tensors_from_path(
156 &path,
157 &base_device,
158 layer_devices,
159 get_device_for_tensor,
160 dtype,
161 silent,
162 predicate,
163 |_| false,
164 )
165 })));
166 }
167 }
168
169 let mut ws = HashMap::new();
170 while !handles.iter().all(|h| h.is_finished()) {}
172 for h in handles {
173 ws.extend(h.join().unwrap()?);
174 }
175
176 let backend = Box::new(ws);
177
178 Ok(ShardedSafeTensors::wrap(
181 backend,
182 dtype.unwrap_or(DType::F16),
183 base_device.clone(),
184 ))
185}
186
187pub(crate) fn load_preload_adapters(
188 paths: &Option<HashMap<String, (PathBuf, LoraConfig)>>,
189 dtype: DType,
190 device: &Device,
191 silent: bool,
192) -> Result<Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>> {
193 if let Some(paths) = paths {
194 let mut map = HashMap::new();
195 for (name, (path, config)) in paths {
196 let loader = Common::new();
197 let loaded_tensors = loader.load_tensors_from_path(
198 path,
199 device,
200 vec![None],
201 Arc::new(|_| DeviceForLoadTensor::Base),
202 Some(dtype),
203 silent,
204 |_| true,
205 |_| false,
206 )?;
207
208 let backend = Box::new(loaded_tensors);
209
210 let vb = ShardedSafeTensors::wrap(backend, dtype, device.clone());
213
214 map.insert(name.clone(), (vb, config.clone()));
215 }
216 Ok(Some(map))
217 } else {
218 Ok(None)
219 }
220}
221
/// Shared tensor-loading routine, with a hook (`get_name_key_pairs`) that
/// lets implementors rewrite tensor names before insertion into the map.
trait LoadTensors {
    /// Load all tensors from `path` that pass `predicate`, skipping those
    /// matched by `make_dummy_predicate`, and return them keyed by their
    /// rewritten names.
    ///
    /// The backend is chosen from the file extension: `safetensors` maps to
    /// the mmap'd safetensors backend, `pth`/`pt`/`bin` to the pickle
    /// backend; any other extension is an error (callers are expected to
    /// have filtered paths already — see the bail message).
    #[allow(clippy::too_many_arguments)]
    fn load_tensors_from_path(
        &self,
        path: &PathBuf,
        base_device: &Device,
        layer_devices: Vec<Option<Device>>,
        get_device_for_tensor: Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>,
        dtype: Option<DType>,
        is_silent: bool,
        predicate: impl Fn(String) -> bool,
        make_dummy_predicate: impl Fn(&str) -> bool,
    ) -> Result<HashMap<String, Tensor>> {
        // Dispatch on the file extension. A path without an extension, or one
        // that is not valid UTF-8, panics — assumed unreachable given the
        // upstream path selection (TODO confirm against `get_model_paths`).
        let tensors: Box<dyn TensorLoaderBackend> = match path
            .extension()
            .expect("Expected extension")
            .to_str()
            .expect("Expected to convert")
        {
            "safetensors" => Box::new(SafetensorBackend(unsafe {
                MmapedSafetensors::new(path)?
            })),
            "pth" | "pt" | "bin" => Box::new(PickleBackend(
                candle_core::pickle::PthTensors::new(path, None)?
            )),
            other => candle_core::bail!("Unexpected extension `{other}`, this should have been handled by `get_model_paths`."),
        };

        // Keep only the names the caller asked for, then pair each on-disk
        // name with the (possibly rewritten) key used in the output map.
        let names_only = tensors
            .get_names()
            .into_iter()
            .filter(|x| predicate(x.to_string()));
        let iter = self.get_name_key_pairs(names_only).collect::<Vec<_>>();

        let mut loaded_tensors = HashMap::new();
        // The emptiness check avoids rendering a progress bar for zero items.
        if !iter.is_empty() {
            for (load_name, key_name) in iter.into_iter().with_progress(is_silent) {
                // Dummy tensors are skipped entirely — nothing is inserted.
                if !make_dummy_predicate(&load_name) {
                    // Resolve the target device: per-layer when an index is
                    // given and present, otherwise the base device.
                    let dev = match get_device_for_tensor(load_name.clone()) {
                        DeviceForLoadTensor::Base => base_device,
                        DeviceForLoadTensor::Idx(i) => layer_devices
                            .get(i)
                            .and_then(|d| d.as_ref())
                            .unwrap_or(base_device),
                    };
                    let tensor = tensors.load_name(&load_name, dev, dtype)?;

                    loaded_tensors.insert(key_name, tensor);
                }
            }
        }

        Ok(loaded_tensors)
    }

    /// Map each on-disk tensor name to `(load_name, key_name)`. The default
    /// rewrite strips a `base_model.model.` wrapper by replacing
    /// `base_model.model.model` with `model`.
    fn get_name_key_pairs(
        &self,
        tensors: impl Iterator<Item = String>,
    ) -> impl Iterator<Item = (String, String)> {
        tensors.map(|name| {
            let new_name = name.replace("base_model.model.model", "model");

            (name, new_name)
        })
    }
}
292
/// Loader for ordinary checkpoints; uses the default `LoadTensors` renaming.
#[derive(new)]
struct Common {}
impl LoadTensors for Common {}
296
/// Loader for X-LoRA adapter checkpoints.
#[derive(new)]
struct XLora {
    // 1-based index of this adapter (callers construct with `XLora::new(i + 1)`);
    // inserted into each LoRA tensor name by `get_name_key_pairs`.
    adapter_index: usize,
}
302
303impl LoadTensors for XLora {
304 fn get_name_key_pairs(
305 &self,
306 tensors: impl Iterator<Item = String>,
307 ) -> impl Iterator<Item = (String, String)> {
308 let expectation = "tensor name `{new_name}` should have substring `.lora`";
309
310 tensors
311 .filter(|name| !name.contains("internal_xlora_classifier"))
312 .map(|name| {
313 let mut new_name = name.replace("base_model.model.model", "model");
314 let pos = new_name.find(".lora").expect(expectation);
316 new_name.insert_str(pos + 7, &format!(".{}", self.adapter_index));
317
318 (name, new_name)
319 })
320 }
321}