mistralrs_quant/utils/uqff.rs

use byteorder::{LittleEndian, ReadBytesExt};

use candle_core::{DType, Device, Result, Tensor, WithDType};
use float8::F8E4M3;
use half::{bf16, f16};

// v0.1.0: initial release
// v0.1.1: add i16 dtype
// v0.1.2: add F8E4M3

const UQFF_VERSION_MAJOR: u32 = 0;
const UQFF_VERSION_MINOR: u32 = 1;
const UQFF_VERSION_PATCH: u32 = 2;

/// Format: 4 bytes, little endian: [ UNSPECIFIED ] [ MAJOR ] [ MINOR ] [ PATCH ]
pub(crate) const UQFF_VERSION: u32 =
    (UQFF_VERSION_MAJOR << (8 * 2)) | (UQFF_VERSION_MINOR << 8) | UQFF_VERSION_PATCH;
/// Offset for the quant type. UQFF always serializes the version first.
pub const UQFF_QUANT_TYPE_OFFSET: usize = std::mem::size_of::<u32>();

/// Check that the major version matches; the format is backwards compatible within a major version.
pub(crate) fn version_is_compatible(version: u32) -> Result<()> {
    let major = version >> (8 * 2);
    let _minor = version >> 8;
    let _patch = version;

    if major != UQFF_VERSION_MAJOR {
        candle_core::bail!("Major version of ISQ artifact file ({major}) does not match the implementation in this build ({UQFF_VERSION_MAJOR})");
    }

    Ok(())
}
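// A minimal sketch of how the version word behaves (not part of the original file);
// the module name `version_tests` and the bumped value are illustrative.
#[cfg(test)]
mod version_tests {
    use super::*;

    #[test]
    fn current_version_is_compatible() {
        // The version constant built by this crate must pass its own check.
        assert!(version_is_compatible(UQFF_VERSION).is_ok());
    }

    #[test]
    fn different_major_version_is_rejected() {
        // Only the major byte participates in the compatibility decision, so
        // bumping it should make the check fail.
        let bumped =
            ((UQFF_VERSION_MAJOR + 1) << (8 * 2)) | (UQFF_VERSION_MINOR << 8) | UQFF_VERSION_PATCH;
        assert!(version_is_compatible(bumped).is_err());
    }
}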

// -----------------------
// Tensor dtype, u32, little endian
// -----------------------
pub(crate) fn write_dtype(dtype: DType, buffer: &mut Vec<u8>) {
    let dtype: u32 = match dtype {
        DType::U8 => 0,
        DType::U32 => 1,
        DType::I32 => 2,
        DType::I64 => 3,
        DType::F16 => 4,
        DType::BF16 => 5,
        DType::F32 => 6,
        DType::F64 => 7,
        DType::I16 => 8,
        DType::F8E4M3 => 9,
    };
    buffer.extend(&dtype.to_le_bytes());
}

pub(crate) fn read_dtype<R: std::io::Read>(buffer: &mut R) -> Result<DType> {
    let dtype = buffer.read_u32::<LittleEndian>()?;
    let dtype = match dtype {
        0 => DType::U8,
        1 => DType::U32,
        2 => DType::I32,
        3 => DType::I64,
        4 => DType::F16,
        5 => DType::BF16,
        6 => DType::F32,
        7 => DType::F64,
        8 => DType::I16,
        9 => DType::F8E4M3,
        _ => candle_core::bail!("unknown dtype for quantized tensor {dtype}"),
    };
    Ok(dtype)
}
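// A minimal sketch of a dtype tag round trip (not part of the original file); the
// module name `dtype_tests` is illustrative. It only exercises write_dtype/read_dtype.
#[cfg(test)]
mod dtype_tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn dtype_round_trip() -> Result<()> {
        for dtype in [
            DType::U8,
            DType::U32,
            DType::I16,
            DType::I32,
            DType::I64,
            DType::F16,
            DType::BF16,
            DType::F32,
            DType::F64,
            DType::F8E4M3,
        ] {
            let mut buffer = Vec::new();
            write_dtype(dtype, &mut buffer);
            // Each dtype tag is serialized as a little-endian u32.
            assert_eq!(buffer.len(), 4);
            let read = read_dtype(&mut Cursor::new(buffer))?;
            assert_eq!(read, dtype);
        }
        Ok(())
    }
}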

// -----------------------
// Tensor data length, u32, little endian
// -----------------------
// Tensor dtype, u32, little endian
// -----------------------
// Num shape dims, u32, little endian
// -----------------------
// ...
// Array (in original order): shape dims, u32, little endian
// ...
// -----------------------
// ...
// Array: tensor data, u8s
// ...
// -----------------------

pub(crate) fn serialize_tensor(buffer: &mut Vec<u8>, tensor: &Tensor) -> Result<()> {
    let b_shape = tensor.dims();
    let tensor = tensor.flatten_all()?;

    let data = match tensor.dtype() {
        DType::U8 => data_to_bytes::<u8>(tensor.to_vec1()?),
        DType::U32 => data_to_bytes::<u32>(tensor.to_vec1()?),
        DType::I16 => data_to_bytes::<i16>(tensor.to_vec1()?),
        DType::I32 => data_to_bytes::<i32>(tensor.to_vec1()?),
        DType::I64 => data_to_bytes::<i64>(tensor.to_vec1()?),
        DType::F16 => data_to_bytes::<half::f16>(tensor.to_vec1()?),
        DType::BF16 => data_to_bytes::<half::bf16>(tensor.to_vec1()?),
        DType::F32 => data_to_bytes::<f32>(tensor.to_vec1()?),
        DType::F64 => data_to_bytes::<f64>(tensor.to_vec1()?),
        DType::F8E4M3 => data_to_bytes::<F8E4M3>(tensor.to_vec1()?),
    };

    // Tensor data length
    buffer.extend(&(data.len() as u32).to_le_bytes());

    // DType
    write_dtype(tensor.dtype(), buffer);

    // Shape
    buffer.extend((b_shape.len() as u32).to_le_bytes());
    for dim in b_shape {
        buffer.extend((*dim as u32).to_le_bytes());
    }

    // Tensor data
    buffer.extend(data);

    Ok(())
}

pub(crate) fn deserialize_tensor<R: std::io::Read>(
    buffer: &mut R,
    device: &Device,
) -> Result<Tensor> {
    let data_len = buffer.read_u32::<LittleEndian>()? as usize;

    // DType
    let dtype = read_dtype(buffer)?;

    let n_dims = buffer.read_u32::<LittleEndian>()? as usize;

    let mut dims = Vec::with_capacity(n_dims);
    for _ in 0..n_dims {
        dims.push(buffer.read_u32::<LittleEndian>()? as usize);
    }

    let mut tensor_data = vec![0; data_len];
    buffer.read_exact(&mut tensor_data)?;

    match dtype {
        DType::F16 => bytes_to_data::<f16>(&tensor_data, &dims, device),
        DType::BF16 => bytes_to_data::<bf16>(&tensor_data, &dims, device),
        DType::F32 => bytes_to_data::<f32>(&tensor_data, &dims, device),
        DType::F64 => bytes_to_data::<f64>(&tensor_data, &dims, device),
        DType::I32 => bytes_to_data::<i32>(&tensor_data, &dims, device),
        DType::I64 => bytes_to_data::<i64>(&tensor_data, &dims, device),
        DType::I16 => bytes_to_data::<i16>(&tensor_data, &dims, device),
        DType::U32 => bytes_to_data::<u32>(&tensor_data, &dims, device),
        DType::U8 => bytes_to_data::<u8>(&tensor_data, &dims, device),
        DType::F8E4M3 => bytes_to_data::<F8E4M3>(&tensor_data, &dims, device),
    }
}
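// A minimal sketch of a serialize/deserialize round trip on the CPU device (not part
// of the original file); the module name `tensor_roundtrip_tests` and the sample
// values are illustrative.
#[cfg(test)]
mod tensor_roundtrip_tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn f32_tensor_round_trip() -> Result<()> {
        let device = Device::Cpu;
        let original = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], (2, 3), &device)?;

        let mut buffer = Vec::new();
        serialize_tensor(&mut buffer, &original)?;

        let restored = deserialize_tensor(&mut Cursor::new(buffer), &device)?;

        // Shape, dtype, and values should all survive the round trip.
        assert_eq!(restored.dims(), original.dims());
        assert_eq!(restored.dtype(), DType::F32);
        assert_eq!(restored.to_vec2::<f32>()?, original.to_vec2::<f32>()?);
        Ok(())
    }
}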

/// Just seek the reader ahead, skipping over one serialized tensor.
pub(crate) fn fake_deserialize_tensor<R: std::io::Read + std::io::Seek>(
    buffer: &mut R,
) -> Result<()> {
    let data_len = buffer.read_u32::<LittleEndian>()? as usize;

    // DType
    let _dtype = read_dtype(buffer)?;

    // Read past the shape dims without keeping them
    let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
    for _ in 0..n_dims {
        let _ = buffer.read_u32::<LittleEndian>()?;
    }

    // Skip the tensor data in bytes without reading it
    buffer.seek_relative(data_len as i64)?;

    Ok(())
}
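// A minimal sketch showing that fake_deserialize_tensor skips exactly one serialized
// tensor (not part of the original file); the module name `skip_tests` and the sample
// values are illustrative.
#[cfg(test)]
mod skip_tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn skipping_advances_to_the_next_tensor() -> Result<()> {
        let device = Device::Cpu;
        let first = Tensor::from_slice(&[1u32, 2, 3, 4], 4, &device)?;
        let second = Tensor::from_slice(&[9.0f32, 8.0], 2, &device)?;

        // Serialize two tensors back to back.
        let mut buffer = Vec::new();
        serialize_tensor(&mut buffer, &first)?;
        serialize_tensor(&mut buffer, &second)?;

        // Skip the first one; the reader should then be positioned at the second.
        let mut cursor = Cursor::new(buffer);
        fake_deserialize_tensor(&mut cursor)?;
        let restored = deserialize_tensor(&mut cursor, &device)?;

        assert_eq!(restored.dtype(), DType::F32);
        assert_eq!(restored.to_vec1::<f32>()?, vec![9.0f32, 8.0]);
        Ok(())
    }
}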

fn data_to_bytes<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
    let size_in_bytes = T::DTYPE.size_in_bytes();
    let length = vs.len() * size_in_bytes;
    let capacity = vs.capacity() * size_in_bytes;
    let ptr = vs.as_mut_ptr() as *mut u8;
    // Don't run the destructor for Vec<T>
    std::mem::forget(vs);
    // SAFETY:
    //
    // u8 has an alignment of 1, so the allocation is always suitably aligned.
    // This re-interprets the Vec<T> as a Vec<u8>.
    unsafe { Vec::from_raw_parts(ptr, length, capacity) }
}

fn bytes_to_data<T: WithDType>(
    data: &[u8],
    shape: &[usize],
    device: &candle_core::Device,
) -> Result<Tensor> {
    let size_in_bytes = T::DTYPE.size_in_bytes();
    let elem_count = data.len() / size_in_bytes;
    if (data.as_ptr() as usize) % size_in_bytes == 0 {
        // SAFETY: This is safe because we just checked that the pointer
        // is correctly aligned.
        let data: &[T] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
        Tensor::from_slice(data, shape, device)
    } else {
        // XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast,
        // making this vector too small to fit full f16/f32/f64 weights and resulting in out-of-bounds access.
        let mut c: Vec<T> = Vec::with_capacity(elem_count);
        // SAFETY: We just created c, so the allocated memory is necessarily
        // contiguous and non-overlapping with the view's data.
        // We're downgrading the `c` pointer from T to u8, which removes alignment
        // constraints.
        unsafe {
            std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
            c.set_len(elem_count)
        }
        Tensor::from_slice(&c, shape, device)
    }
}
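// A minimal sketch of the raw byte conversion round trip (not part of the original
// file); the module name `byte_conversion_tests` and the sample values are illustrative.
#[cfg(test)]
mod byte_conversion_tests {
    use super::*;

    #[test]
    fn f32_bytes_round_trip() -> Result<()> {
        let values = vec![1.5f32, -2.25, 0.0, 42.0];
        let bytes = data_to_bytes::<f32>(values.clone());
        // Four f32s occupy 16 bytes.
        assert_eq!(bytes.len(), values.len() * std::mem::size_of::<f32>());

        let tensor = bytes_to_data::<f32>(&bytes, &[4], &Device::Cpu)?;
        assert_eq!(tensor.to_vec1::<f32>()?, values);
        Ok(())
    }
}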