use std::{
    collections::HashMap,
    path::{Path, PathBuf},
    sync::Arc,
    thread::{self, JoinHandle},
};

use candle_core::{pickle::PthTensors, DType, Device, Result, Tensor};
use derive_new::new;
use mistralrs_quant::{safetensors::MmapedSafetensors, ShardedSafeTensors, ShardedVarBuilder};
use regex::Regex;

use crate::lora::LoraConfig;
use crate::utils::progress::IterWithProgress;
17
/// Minimal abstraction over a tensor container file (safetensors or a
/// PyTorch pickle archive): enumerate tensor names and load one by name.
trait TensorLoaderBackend {
    /// Names of all tensors stored in the file.
    fn get_names(&self) -> Vec<String>;
    /// Load the tensor called `name` onto `device`. `dtype` is an optional
    /// cast hint; backends may ignore it (the pickle backend does).
    fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor>;
}
22
/// Backend reading tensors from a memory-mapped `.safetensors` file.
struct SafetensorBackend(MmapedSafetensors);
24
25impl TensorLoaderBackend for SafetensorBackend {
26 fn get_names(&self) -> Vec<String> {
27 self.0
28 .tensors()
29 .into_iter()
30 .map(|(name, _)| name)
31 .collect::<Vec<_>>()
32 }
33 fn load_name(&self, name: &str, device: &Device, dtype: Option<DType>) -> Result<Tensor> {
34 self.0.load(name, device, dtype)
35 }
36}
37
/// Backend reading tensors from a PyTorch pickle checkpoint (`.pth`/`.pt`/`.bin`).
struct PickleBackend(PthTensors);
39
40impl TensorLoaderBackend for PickleBackend {
41 fn get_names(&self) -> Vec<String> {
42 self.0.tensor_infos().keys().cloned().collect::<Vec<_>>()
43 }
44 fn load_name(&self, name: &str, device: &Device, _dtype: Option<DType>) -> Result<Tensor> {
45 self.0
46 .get(name)?
47 .ok_or(candle_core::Error::Msg(format!(
48 "Could not load tensor {name}"
49 )))?
50 .to_device(device)
51 }
52}
53
/// Which device a tensor should be loaded onto.
pub enum DeviceForLoadTensor {
    /// Use the base device passed to the loader.
    Base,
    /// Use the layer device at this index; the loader falls back to the base
    /// device when the index is out of range or the slot is `None`.
    Idx(usize),
}
58
59#[allow(clippy::too_many_arguments)]
66pub(crate) fn from_mmaped_safetensors<'a>(
67 paths: Vec<PathBuf>,
68 xlora_paths: Vec<PathBuf>,
69 dtype: Option<DType>,
70 base_device: &Device,
71 layer_devices: Vec<Option<Device>>,
72 silent: bool,
73 make_dummy_regexes: Option<Arc<Vec<Regex>>>,
74 predicate: impl Fn(String) -> bool + Send + Sync + Clone + 'static,
75 get_device_for_tensor: Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>,
76) -> Result<ShardedVarBuilder<'a>> {
77 if base_device.is_cuda() {
78 return Ok(unsafe {
79 ShardedSafeTensors::sharded(
80 &paths,
81 dtype.unwrap_or(DType::F16),
82 base_device,
83 make_dummy_regexes,
84 )?
85 });
86 }
87
88 #[allow(clippy::type_complexity)]
89 let mut handles: Vec<JoinHandle<Result<HashMap<String, Tensor>>>> = Vec::new();
90
91 for path in paths {
92 let base_device = base_device.clone();
93 let layer_devices = layer_devices.clone();
94 let get_device_for_tensor = get_device_for_tensor.clone();
95 if let Some(regexes) = make_dummy_regexes.clone() {
96 let predicate = predicate.clone();
97 handles.push(thread::spawn(Box::new(move || {
98 let loader = Common::new();
99 loader.load_tensors_from_path(
100 &path,
101 &base_device,
102 layer_devices,
103 get_device_for_tensor,
104 dtype,
105 silent,
106 predicate,
107 |key| regexes.iter().any(|r| r.is_match(key)),
108 )
109 })));
110 } else {
111 let predicate = predicate.clone();
112 handles.push(thread::spawn(Box::new(move || {
113 let loader = Common::new();
114 loader.load_tensors_from_path(
115 &path,
116 &base_device,
117 layer_devices,
118 get_device_for_tensor,
119 dtype,
120 silent,
121 predicate,
122 |_| false,
123 )
124 })));
125 }
126 }
127 for (i, path) in xlora_paths.into_iter().enumerate() {
128 let base_device = base_device.clone();
129 let layer_devices = layer_devices.clone();
130 let get_device_for_tensor = get_device_for_tensor.clone();
131 if let Some(regexes) = make_dummy_regexes.clone() {
132 let predicate = predicate.clone();
133 handles.push(thread::spawn(Box::new(move || {
134 let loader = XLora::new(i + 1);
135 loader.load_tensors_from_path(
136 &path,
137 &base_device,
138 layer_devices,
139 get_device_for_tensor,
140 dtype,
141 silent,
142 predicate,
143 |key| regexes.iter().any(|r| r.is_match(key)),
144 )
145 })));
146 } else {
147 let predicate = predicate.clone();
148 handles.push(thread::spawn(Box::new(move || {
149 let loader = XLora::new(i + 1);
150 loader.load_tensors_from_path(
151 &path,
152 &base_device,
153 layer_devices,
154 get_device_for_tensor,
155 dtype,
156 silent,
157 predicate,
158 |_| false,
159 )
160 })));
161 }
162 }
163
164 let mut ws = HashMap::new();
165 while !handles.iter().all(|h| h.is_finished()) {}
167 for h in handles {
168 ws.extend(h.join().unwrap()?);
169 }
170
171 let backend = Box::new(ws);
172
173 Ok(ShardedSafeTensors::wrap(
176 backend,
177 dtype.unwrap_or(DType::F16),
178 base_device.clone(),
179 ))
180}
181
182pub(crate) fn load_preload_adapters<'a>(
183 paths: &Option<HashMap<String, (PathBuf, LoraConfig)>>,
184 dtype: DType,
185 device: &Device,
186 silent: bool,
187) -> Result<Option<HashMap<String, (ShardedVarBuilder<'a>, LoraConfig)>>> {
188 if let Some(paths) = paths {
189 let mut map = HashMap::new();
190 for (name, (path, config)) in paths {
191 let loader = Common::new();
192 let loaded_tensors = loader.load_tensors_from_path(
193 path,
194 device,
195 vec![None],
196 Arc::new(|_| DeviceForLoadTensor::Base),
197 Some(dtype),
198 silent,
199 |_| true,
200 |_| false,
201 )?;
202
203 let backend = Box::new(loaded_tensors);
204
205 let vb = ShardedSafeTensors::wrap(backend, dtype, device.clone());
208
209 map.insert(name.clone(), (vb, config.clone()));
210 }
211 Ok(Some(map))
212 } else {
213 Ok(None)
214 }
215}
216
217trait LoadTensors {
219 #[allow(clippy::too_many_arguments)]
220 fn load_tensors_from_path(
221 &self,
222 path: &PathBuf,
223 base_device: &Device,
224 layer_devices: Vec<Option<Device>>,
225 get_device_for_tensor: Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>,
226 dtype: Option<DType>,
227 is_silent: bool,
228 predicate: impl Fn(String) -> bool,
229 make_dummy_predicate: impl Fn(&str) -> bool,
230 ) -> Result<HashMap<String, Tensor>> {
231 let tensors: Box<dyn TensorLoaderBackend> = match path
232 .extension()
233 .expect("Expected extension")
234 .to_str()
235 .expect("Expected to convert")
236 {
237 "safetensors" => Box::new(SafetensorBackend(unsafe {
238 MmapedSafetensors::new(path)?
239 })),
240 "pth" | "pt" | "bin" => Box::new(PickleBackend(
241 candle_core::pickle::PthTensors::new(path, None)?
242 )),
243 other => candle_core::bail!("Unexpected extension `{other}`, this should have been handled by `get_model_paths`."),
244 };
245
246 let names_only = tensors
248 .get_names()
249 .into_iter()
250 .filter(|x| predicate(x.to_string()));
251 let iter = self.get_name_key_pairs(names_only).collect::<Vec<_>>();
252
253 let mut loaded_tensors = HashMap::new();
255 if !iter.is_empty() {
256 for (load_name, key_name) in iter.into_iter().with_progress(is_silent) {
257 if !make_dummy_predicate(&load_name) {
258 let dev = match get_device_for_tensor(load_name.clone()) {
259 DeviceForLoadTensor::Base => base_device,
260 DeviceForLoadTensor::Idx(i) => layer_devices
261 .get(i)
262 .and_then(|d| d.as_ref())
263 .unwrap_or(base_device),
264 };
265 let tensor = tensors.load_name(&load_name, dev, dtype)?;
267
268 loaded_tensors.insert(key_name, tensor);
269 }
270 }
271 }
272
273 Ok(loaded_tensors)
274 }
275
276 fn get_name_key_pairs(
277 &self,
278 tensors: impl Iterator<Item = String>,
279 ) -> impl Iterator<Item = (String, String)> {
280 tensors.map(|name| {
281 let new_name = name.replace("base_model.model.model", "model");
282
283 (name, new_name)
284 })
285 }
286}
287
/// Loader for base-model checkpoint files; uses the default name mapping.
#[derive(new)]
struct Common {}
impl LoadTensors for Common {}
291
/// Loader for X-LoRA adapter checkpoint files.
#[derive(new)]
struct XLora {
    // 1-based index of the adapter these tensors belong to; spliced into the
    // renamed tensor key by `get_name_key_pairs`.
    adapter_index: usize,
}
297
298impl LoadTensors for XLora {
299 fn get_name_key_pairs(
300 &self,
301 tensors: impl Iterator<Item = String>,
302 ) -> impl Iterator<Item = (String, String)> {
303 let expectation = "tensor name `{new_name}` should have substring `.lora`";
304
305 tensors
306 .filter(|name| !name.contains("internal_xlora_classifier"))
307 .map(|name| {
308 let mut new_name = name.replace("base_model.model.model", "model");
309 let pos = new_name.find(".lora").expect(expectation);
311 new_name.insert_str(pos + 7, &format!(".{}", self.adapter_index));
312
313 (name, new_name)
314 })
315 }
316}