mistralrs_quant/rotary/mod.rs

#[cfg(feature = "cuda")]
mod ffi;

#[cfg(feature = "cuda")]
mod cuda {
    use candle_core::cuda_backend::cudarc::driver::DevicePtr;
    use candle_core::{DType, Result, Storage, Tensor};
    use half::{bf16, f16};
    use std::ffi::{c_int, c_long};

    fn apply_rotary_<
        T: candle_core::cuda_backend::CudaDType
            + candle_core::cuda_backend::cudarc::driver::DeviceRepr,
    >(
        query: &Tensor,
        key: &Tensor,
        cos_cache: &Tensor,
        sin_cache: &Tensor,
        is_neox: bool,
    ) -> Result<()> {
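        // All four tensors must share a dtype; the kernel is dispatched per precision.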
        let dtype = query.dtype();
        if key.dtype() != dtype || cos_cache.dtype() != dtype || sin_cache.dtype() != dtype {
            candle_core::bail!("apply-rotary expects all tensors to have the same dtype");
        }

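        // Map the candle dtype to the integer tag the FFI kernel uses to pick its instantiation.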
        let internal_type = match dtype {
            DType::F16 => 0,
            DType::BF16 => 1,
            DType::F32 => 2,
            dtype => candle_core::bail!("dtype {dtype:?} is not supported"),
        };

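        // Unpack the CUDA storage and layout of each tensor, rejecting non-CUDA inputs.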
        let (q, q_l) = query.storage_and_layout();
        let q = match &*q {
            Storage::Cuda(q) => q,
            _ => candle_core::bail!("query must be a cuda tensor"),
        };

        let (k, k_l) = key.storage_and_layout();
        let k = match &*k {
            Storage::Cuda(k) => k,
            _ => candle_core::bail!("key must be a cuda tensor"),
        };

        let (cc, cc_l) = cos_cache.storage_and_layout();
        let cc = match &*cc {
            Storage::Cuda(cc) => cc,
            _ => candle_core::bail!("cos_cache must be a cuda tensor"),
        };

        let (sc, sc_l) = sin_cache.storage_and_layout();
        let sc = match &*sc {
            Storage::Cuda(sc) => sc,
            _ => candle_core::bail!("sin_cache must be a cuda tensor"),
        };

        let q_rank = q_l.stride().len();
        let k_rank = k_l.stride().len();
        let cc_rank = cc_l.stride().len();
        let sc_rank = sc_l.stride().len();

        if q_rank != 3 || k_rank != 3 {
            candle_core::bail!(
                "apply-rotary expects input tensors of rank 3 (q: {q_l:?}, k: {k_l:?})"
            )
        }

        if cc_rank != 2 || sc_rank != 2 {
            candle_core::bail!(
                "apply-rotary expects cache tensors of rank 2 (cos: {cc_l:?}, sin: {sc_l:?})"
            )
        }

        // Get cuda slices for all tensors
        let q = q.as_cuda_slice::<T>()?;
        let k = k.as_cuda_slice::<T>()?;
        let cc = cc.as_cuda_slice::<T>()?;
        let sc = sc.as_cuda_slice::<T>()?;

        // Get cuda views for all tensors
        let q = q.slice(q_l.start_offset()..);
        let k = k.slice(k_l.start_offset()..);
        let cc = cc.slice(cc_l.start_offset()..);
        let sc = sc.slice(sc_l.start_offset()..);

        let (num_tokens, num_heads, head_size) = q_l.shape().dims3()?;
        let (num_tokens_kv, num_kv_heads, head_size_kv) = k_l.shape().dims3()?;

        if (num_tokens, head_size) != (num_tokens_kv, head_size_kv) {
            candle_core::bail!("shape mismatch q {:?} and k {:?}", q_l.shape(), k_l.shape())
        }

        let rot_dim = cc_l.dims()[1];
        if (num_tokens, rot_dim) != cc_l.shape().dims2()? {
            candle_core::bail!(
                "shape mismatch cos_cache {:?}, expected {:?}",
                cc_l.shape(),
                (num_tokens, rot_dim)
            )
        }

        if (num_tokens, rot_dim) != sc_l.shape().dims2()? {
            candle_core::bail!(
                "shape mismatch sin_cache {:?}, expected {:?}",
                sc_l.shape(),
                (num_tokens, rot_dim)
            )
        }

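        // Per-token strides (in elements) of q and k; the kernel uses them to step between rows.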
        let query_stride = q_l.stride()[0];
        let key_stride = k_l.stride()[0];

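        // Raw device pointers handed across the FFI boundary.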
        let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
        let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
        let cc_ptr = *cc.device_ptr() as *const core::ffi::c_void;
        let sc_ptr = *sc.device_ptr() as *const core::ffi::c_void;

        let neox = if is_neox { 1 } else { 0 };

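        // SAFETY: the pointers refer to live device allocations whose dtypes, ranks, and
        // shapes were validated above; the kernel writes the rotated values back into
        // `query` and `key` in place.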
        unsafe {
            super::ffi::rotary_embedding(
                q_ptr,
                k_ptr,
                cc_ptr,
                sc_ptr,
                neox,
                head_size as c_int,
                num_tokens as c_long,
                rot_dim as c_int,
                num_heads as c_int,
                num_kv_heads as c_int,
                query_stride as c_long,
                key_stride as c_long,
                internal_type,
            )
        }
        Ok(())
    }

    /// Apply rotary position embedding in-place.
    ///
    /// # Arguments
    ///
    /// * `query` - Query tensor of shape `(num_tokens, num_heads, head_size)`.
    /// * `key` - Key tensor of shape `(num_tokens, num_kv_heads, head_size)`.
    /// * `cos_cache` - Aligned cache of shape `(num_tokens, rot_dim)`.
    /// * `sin_cache` - Aligned cache of shape `(num_tokens, rot_dim)`.
    /// * `is_neox` - Use NeoX-style rotary encoding instead of GPT-J style.
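    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a CUDA device is available; the shapes and zero-filled
    /// cos/sin tables below are illustrative placeholders, not a real RoPE schedule:
    ///
    /// ```ignore
    /// use candle_core::{DType, Device, Tensor};
    ///
    /// let dev = Device::new_cuda(0)?;
    /// let (num_tokens, num_heads, head_size, rot_dim) = (4, 8, 64, 64);
    /// let q = Tensor::zeros((num_tokens, num_heads, head_size), DType::F16, &dev)?;
    /// let k = Tensor::zeros((num_tokens, num_heads, head_size), DType::F16, &dev)?;
    /// let cos = Tensor::zeros((num_tokens, rot_dim), DType::F16, &dev)?;
    /// let sin = Tensor::zeros((num_tokens, rot_dim), DType::F16, &dev)?;
    /// // Rotates q and k in place using the NeoX layout.
    /// apply_rotary_inplace(&q, &k, &cos, &sin, true)?;
    /// ```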
    pub fn apply_rotary_inplace(
        query: &Tensor,
        key: &Tensor,
        cos_cache: &Tensor,
        sin_cache: &Tensor,
        is_neox: bool,
    ) -> Result<()> {
        match key.dtype() {
            DType::F16 => apply_rotary_::<f16>(query, key, cos_cache, sin_cache, is_neox),
            DType::BF16 => apply_rotary_::<bf16>(query, key, cos_cache, sin_cache, is_neox),
            DType::F32 => apply_rotary_::<f32>(query, key, cos_cache, sin_cache, is_neox),
            dt => {
                candle_core::bail!("apply_rotary is only supported for f32, f16 and bf16 ({dt:?})")
            }
        }
    }
}

#[cfg(feature = "cuda")]
pub use cuda::*;

/// Apply rotary position embedding in-place.
///
/// # Arguments
///
/// * `query` - Query tensor of shape `(num_tokens, num_heads, head_size)`.
/// * `key` - Key tensor of shape `(num_tokens, num_kv_heads, head_size)`.
/// * `cos_cache` - Aligned cache of shape `(num_tokens, rot_dim)`.
/// * `sin_cache` - Aligned cache of shape `(num_tokens, rot_dim)`.
/// * `is_neox` - Use NeoX-style rotary encoding instead of GPT-J style.
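///
/// Without the `cuda` feature this fallback always returns an error.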
#[cfg(not(feature = "cuda"))]
pub fn apply_rotary_inplace(
    _query: &candle_core::Tensor,
    _key: &candle_core::Tensor,
    _cos_cache: &candle_core::Tensor,
    _sin_cache: &candle_core::Tensor,
    _is_neox: bool,
) -> candle_core::Result<()> {
    candle_core::bail!("apply_rotary is only supported for cuda");
}