1use byteorder::{LittleEndian, ReadBytesExt};
2
3use candle_core::{DType, Device, Result, Tensor, WithDType};
4use float8::F8E4M3;
5use half::{bf16, f16};
6
// Individual components of the UQFF format version implemented by this build.
const UQFF_VERSION_MAJOR: u32 = 0;
const UQFF_VERSION_MINOR: u32 = 2;
const UQFF_VERSION_PATCH: u32 = 0;

/// Packed format version: the major component is shifted left by 16 bits, the
/// minor component by 8 bits, and the patch component occupies the low byte.
pub(crate) const UQFF_VERSION: u32 =
    UQFF_VERSION_PATCH | (UQFF_VERSION_MINOR << 8) | (UQFF_VERSION_MAJOR << 16);

/// Size in bytes of one `u32`.
// NOTE(review): presumably the byte offset of the quant-type field that follows
// a leading `u32` version word in a serialized artifact -- confirm with readers.
pub const UQFF_QUANT_TYPE_OFFSET: usize = std::mem::size_of::<u32>();
22
23pub(crate) fn version_is_compatible(version: u32) -> Result<()> {
25 let major = version >> (8 * 2);
26 let minor = (version >> 8) & 0xFF;
27 let patch = version & 0xFF;
28
29 if major != UQFF_VERSION_MAJOR {
30 candle_core::bail!("Major version of ISQ artifact file ({major}) does not match the implementation in this build ({UQFF_VERSION_MAJOR})");
31 }
32
33 if minor > UQFF_VERSION_MINOR {
35 candle_core::bail!("Minor version of ISQ artifact file ({major}.{minor}.{patch}) is newer than this build supports ({UQFF_VERSION_MAJOR}.{UQFF_VERSION_MINOR}.{UQFF_VERSION_PATCH}). Please update mistral.rs.");
36 }
37
38 Ok(())
39}
40
41pub(crate) fn write_dtype(dtype: DType, buffer: &mut Vec<u8>) {
45 let dtype: u32 = match dtype {
46 DType::U8 => 0,
47 DType::U32 => 1,
48 DType::I32 => 2,
49 DType::I64 => 3,
50 DType::F16 => 4,
51 DType::BF16 => 5,
52 DType::F32 => 6,
53 DType::F64 => 7,
54 DType::I16 => 8,
55 DType::F8E4M3 => 9,
56 DType::F6E2M3 => 10,
57 DType::F6E3M2 => 11,
58 DType::F4 => 12,
59 DType::F8E8M0 => 13,
60 };
61 buffer.extend(&dtype.to_le_bytes());
62}
63
64pub(crate) fn read_dtype<R: std::io::Read>(buffer: &mut R) -> Result<DType> {
65 let dtype = buffer.read_u32::<LittleEndian>()?;
66 let dtype = match dtype {
67 0 => DType::U8,
68 1 => DType::U32,
69 2 => DType::I32,
70 3 => DType::I64,
71 4 => DType::F16,
72 5 => DType::BF16,
73 6 => DType::F32,
74 7 => DType::F64,
75 8 => DType::I16,
76 9 => DType::F8E4M3,
77 10 => DType::F6E2M3,
78 11 => DType::F6E3M2,
79 12 => DType::F4,
80 13 => DType::F8E8M0,
81 _ => candle_core::bail!("unknown dtype for quantized tensor {dtype}"),
82 };
83 Ok(dtype)
84}
85
86pub(crate) fn serialize_tensor(buffer: &mut Vec<u8>, tensor: &Tensor) -> Result<()> {
103 let b_shape = tensor.dims();
104 let tensor = tensor.flatten_all()?;
105
106 let bias = match tensor.dtype() {
107 DType::U8 => data_to_bytes::<u8>(tensor.to_vec1()?),
108 DType::U32 => data_to_bytes::<u32>(tensor.to_vec1()?),
109 DType::I16 => data_to_bytes::<i16>(tensor.to_vec1()?),
110 DType::I32 => data_to_bytes::<i32>(tensor.to_vec1()?),
111 DType::I64 => data_to_bytes::<i64>(tensor.to_vec1()?),
112 DType::F16 => data_to_bytes::<half::f16>(tensor.to_vec1()?),
113 DType::BF16 => data_to_bytes::<half::bf16>(tensor.to_vec1()?),
114 DType::F32 => data_to_bytes::<f32>(tensor.to_vec1()?),
115 DType::F64 => data_to_bytes::<f64>(tensor.to_vec1()?),
116 DType::F8E4M3 => data_to_bytes::<F8E4M3>(tensor.to_vec1()?),
117 DType::F4 | DType::F6E3M2 | DType::F6E2M3 | DType::F8E8M0 => {
118 candle_core::bail!("f4/f6e3m2/f6e2m3/f8e8m0 tensors cannot be serialized.")
119 }
120 };
121
122 let data_len = bias.len();
124 if data_len > u32::MAX as usize {
125 candle_core::bail!(
126 "Tensor data too large for UQFF format: {} bytes exceeds u32::MAX",
127 data_len
128 );
129 }
130 buffer.extend(&(data_len as u32).to_le_bytes());
131
132 write_dtype(tensor.dtype(), buffer);
134
135 let shape_len = b_shape.len();
137 if shape_len > u32::MAX as usize {
138 candle_core::bail!(
139 "Tensor has too many dimensions for UQFF format: {} exceeds u32::MAX",
140 shape_len
141 );
142 }
143 buffer.extend((shape_len as u32).to_le_bytes());
144 for dim in b_shape {
145 if *dim > u32::MAX as usize {
146 candle_core::bail!(
147 "Tensor dimension too large for UQFF format: {} exceeds u32::MAX",
148 dim
149 );
150 }
151 buffer.extend((*dim as u32).to_le_bytes());
152 }
153
154 buffer.extend(bias);
155
156 Ok(())
157}
158
159pub(crate) fn deserialize_tensor<R: std::io::Read>(
160 buffer: &mut R,
161 device: &Device,
162) -> Result<Tensor> {
163 let data_len = buffer.read_u32::<LittleEndian>()? as usize;
164
165 let dtype = read_dtype(buffer)?;
167
168 let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
169
170 let mut dims = Vec::with_capacity(n_dims);
171 for _ in 0..n_dims {
172 dims.push(buffer.read_u32::<LittleEndian>()? as usize)
173 }
174
175 let mut tensor_data = vec![0; data_len];
176 buffer.read_exact(&mut tensor_data)?;
177
178 match dtype {
179 DType::F16 => bytes_to_data::<f16>(&tensor_data, &dims, device),
180 DType::BF16 => bytes_to_data::<bf16>(&tensor_data, &dims, device),
181 DType::F32 => bytes_to_data::<f32>(&tensor_data, &dims, device),
182 DType::F64 => bytes_to_data::<f64>(&tensor_data, &dims, device),
183 DType::I32 => bytes_to_data::<i32>(&tensor_data, &dims, device),
184 DType::I64 => bytes_to_data::<i64>(&tensor_data, &dims, device),
185 DType::I16 => bytes_to_data::<i16>(&tensor_data, &dims, device),
186 DType::U32 => bytes_to_data::<u32>(&tensor_data, &dims, device),
187 DType::U8 => bytes_to_data::<u8>(&tensor_data, &dims, device),
188 DType::F8E4M3 => bytes_to_data::<F8E4M3>(&tensor_data, &dims, device),
189 DType::F4 | DType::F6E3M2 | DType::F6E2M3 | DType::F8E8M0 => {
190 candle_core::bail!("f4/f6e3m2/f6e2m3/f8e8m0 tensors cannot be deserialized.")
191 }
192 }
193}
194
195pub(crate) fn fake_deserialize_tensor<R: std::io::Read + std::io::Seek>(
197 buffer: &mut R,
198) -> Result<()> {
199 let data_len = buffer.read_u32::<LittleEndian>()? as usize;
200
201 let _dtype = read_dtype(buffer)?;
203
204 let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
205
206 let mut dims = Vec::with_capacity(n_dims);
207 for _ in 0..n_dims {
208 dims.push(buffer.read_u32::<LittleEndian>()? as usize)
209 }
210
211 buffer.seek_relative(data_len as i64)?;
213
214 Ok(())
215}
216
217fn data_to_bytes<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
218 let size_in_bytes = T::DTYPE.size_in_bytes();
219 let length = vs.len() * size_in_bytes;
220 let capacity = vs.capacity() * size_in_bytes;
221 let ptr = vs.as_mut_ptr() as *mut u8;
222 std::mem::forget(vs);
224 unsafe { Vec::from_raw_parts(ptr, length, capacity) }
229}
230
231fn bytes_to_data<T: WithDType>(
232 data: &[u8],
233 shape: &[usize],
234 device: &candle_core::Device,
235) -> Result<Tensor> {
236 let size_in_bytes = T::DTYPE.size_in_bytes();
237 let elem_count = data.len() / size_in_bytes;
238 if (data.as_ptr() as usize) % size_in_bytes == 0 {
239 let data: &[T] =
242 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
243 Tensor::from_slice(data, shape, device)
244 } else {
245 let mut c: Vec<T> = Vec::with_capacity(elem_count);
248 unsafe {
253 std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
254 c.set_len(elem_count)
255 }
256 Tensor::from_slice(&c, shape, device)
257 }
258}