mistralrs_quant/rotary/mod.rs

#[cfg(feature = "cuda")]
mod ffi;

#[cfg(feature = "cuda")]
mod cuda {
    use candle_core::{DType, Result, Storage, Tensor};
    use half::{bf16, f16};
    use std::ffi::{c_int, c_long};

    use crate::utils::slice_ptr;

    fn apply_rotary_<
        T: candle_core::cuda_backend::CudaDType
            + candle_core::cuda_backend::cudarc::driver::DeviceRepr,
    >(
        query: &Tensor,
        key: &Tensor,
        cos_cache: &Tensor,
        sin_cache: &Tensor,
        is_neox: bool,
    ) -> Result<()> {
        let dtype = query.dtype();
        if key.dtype() != dtype || cos_cache.dtype() != dtype || sin_cache.dtype() != dtype {
            candle_core::bail!("apply-rotary expects all tensors to have the same dtype");
        }

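        // Numeric dtype tag handed across the FFI boundary; the values are assumed
        // to match the dtype switch in the CUDA kernel (0 = f16, 1 = bf16, 2 = f32).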
        let internal_type = match dtype {
            DType::F16 => 0,
            DType::BF16 => 1,
            DType::F32 => 2,
            dtype => candle_core::bail!("dtype {dtype:?} is not supported"),
        };

        let (q, q_l) = query.storage_and_layout();
        let q = match &*q {
            Storage::Cuda(q) => q,
            _ => candle_core::bail!("query must be a cuda tensor"),
        };

        let (k, k_l) = key.storage_and_layout();
        let k = match &*k {
            Storage::Cuda(k) => k,
            _ => candle_core::bail!("key must be a cuda tensor"),
        };

        let (cc, cc_l) = cos_cache.storage_and_layout();
        let cc = match &*cc {
            Storage::Cuda(cc) => cc,
            _ => candle_core::bail!("cos_cache must be a cuda tensor"),
        };

        let (sc, sc_l) = sin_cache.storage_and_layout();
        let sc = match &*sc {
            Storage::Cuda(sc) => sc,
            _ => candle_core::bail!("sin_cache must be a cuda tensor"),
        };

        let q_rank = q_l.stride().len();
        let k_rank = k_l.stride().len();
        let cc_rank = cc_l.stride().len();
        let sc_rank = sc_l.stride().len();

        if q_rank != 3 || k_rank != 3 {
            candle_core::bail!(
                "apply-rotary expects input tensors of rank 3 (q: {q_l:?}, k: {k_l:?})"
            )
        }

        if cc_rank != 2 || sc_rank != 2 {
            candle_core::bail!(
                "apply-rotary expects cache tensors of rank 2 (cos: {cc_l:?}, sin: {sc_l:?})"
            )
        }

        // Get cuda slices for all tensors
        let q = q.as_cuda_slice::<T>()?;
        let k = k.as_cuda_slice::<T>()?;
        let cc = cc.as_cuda_slice::<T>()?;
        let sc = sc.as_cuda_slice::<T>()?;

        // Get raw device pointers (offset by each layout's start) for all tensors
        let (q, _q_guard) = slice_ptr(q, q_l.start_offset());
        let (k, _k_guard) = slice_ptr(k, k_l.start_offset());
        let (cc, _cc_guard) = slice_ptr(cc, cc_l.start_offset());
        let (sc, _sc_guard) = slice_ptr(sc, sc_l.start_offset());

        let (num_tokens, num_heads, head_size) = q_l.shape().dims3()?;
        let (num_tokens_kv, num_kv_heads, head_size_kv) = k_l.shape().dims3()?;

        if (num_tokens, head_size) != (num_tokens_kv, head_size_kv) {
            candle_core::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape())
        }

        let rot_dim = cc_l.dims()[1];
        if (num_tokens, rot_dim) != cc_l.shape().dims2()? {
            candle_core::bail!(
                "shape mismatch cos_cache {:?}, expected {:?}",
                cc_l.shape(),
                (num_tokens, rot_dim)
            )
        }

        if (num_tokens, rot_dim) != sc_l.shape().dims2()? {
            candle_core::bail!(
                "shape mismatch sin_cache {:?}, expected {:?}",
                sc_l.shape(),
                (num_tokens, rot_dim)
            )
        }

        let query_stride = q_l.stride()[0];
        let key_stride = k_l.stride()[0];

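        // Assumption: the CUDA kernel follows the usual vLLM-style convention, where a
        // nonzero flag selects NeoX (rotate-half) rotary and zero selects GPT-J
        // (interleaved) rotary.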
        let neox = if is_neox { 1 } else { 0 };

        unsafe {
            super::ffi::rotary_embedding(
                q as *const core::ffi::c_void,
                k as *const core::ffi::c_void,
                cc as *const core::ffi::c_void,
                sc as *const core::ffi::c_void,
                neox,
                head_size as c_int,
                num_tokens as c_long,
                rot_dim as c_int,
                num_heads as c_int,
                num_kv_heads as c_int,
                query_stride as c_long,
                key_stride as c_long,
                internal_type,
            )
        }
        Ok(())
    }

    /// Apply rotary position encoding in place.
    ///
    /// # Arguments
    ///
    /// * `query` - Query tensor of shape `(num_tokens, num_heads, head_size)`.
    /// * `key` - Key tensor of shape `(num_tokens, num_kv_heads, head_size)`.
    /// * `cos_cache` - Aligned cosine cache of shape `(num_tokens, rot_dim)`.
    /// * `sin_cache` - Aligned sine cache of shape `(num_tokens, rot_dim)`.
    /// * `is_neox` - Use NeoX-style rotary encoding instead of GPT-J-style rotary.
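    ///
    /// # Example
    ///
    /// A minimal sketch of the expected shapes, assuming a CUDA device is available;
    /// the sizes used here (4 tokens, 8 heads, head size 64, `rot_dim` of 64) are
    /// purely illustrative.
    ///
    /// ```ignore
    /// use candle_core::{DType, Device, Tensor};
    ///
    /// let dev = Device::new_cuda(0)?;
    /// let query = Tensor::zeros((4, 8, 64), DType::F16, &dev)?;
    /// let key = Tensor::zeros((4, 8, 64), DType::F16, &dev)?;
    /// let cos_cache = Tensor::zeros((4, 64), DType::F16, &dev)?;
    /// let sin_cache = Tensor::zeros((4, 64), DType::F16, &dev)?;
    ///
    /// // Rotates `query` and `key` in place using the per-token cos/sin rows;
    /// // `true` selects the NeoX-style layout.
    /// apply_rotary_inplace(&query, &key, &cos_cache, &sin_cache, true)?;
    /// ```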
    pub fn apply_rotary_inplace(
        query: &Tensor,
        key: &Tensor,
        cos_cache: &Tensor,
        sin_cache: &Tensor,
        is_neox: bool,
    ) -> Result<()> {
        match key.dtype() {
            DType::F16 => apply_rotary_::<f16>(query, key, cos_cache, sin_cache, is_neox),
            DType::BF16 => apply_rotary_::<bf16>(query, key, cos_cache, sin_cache, is_neox),
            DType::F32 => apply_rotary_::<f32>(query, key, cos_cache, sin_cache, is_neox),
            dt => {
                candle_core::bail!("apply_rotary is only supported for f32, f16 and bf16 ({dt:?})")
            }
        }
    }
}

#[cfg(feature = "cuda")]
pub use cuda::*;

/// Apply rotary position encoding in place.
///
/// # Arguments
///
/// * `query` - Query tensor of shape `(num_tokens, num_heads, head_size)`.
/// * `key` - Key tensor of shape `(num_tokens, num_kv_heads, head_size)`.
/// * `cos_cache` - Aligned cosine cache of shape `(num_tokens, rot_dim)`.
/// * `sin_cache` - Aligned sine cache of shape `(num_tokens, rot_dim)`.
/// * `is_neox` - Use NeoX-style rotary encoding instead of GPT-J-style rotary.
#[cfg(not(feature = "cuda"))]
pub fn apply_rotary_inplace(
    _query: &candle_core::Tensor,
    _key: &candle_core::Tensor,
    _cos_cache: &candle_core::Tensor,
    _sin_cache: &candle_core::Tensor,
    _is_neox: bool,
) -> candle_core::Result<()> {
    candle_core::bail!("apply_rotary is only supported for cuda");
}