use std::sync::Arc;

use candle_core::{DType, Device, Result, Tensor};
use rand_isaac::Isaac64Rng;

use crate::{
    prefix_cacher::PrefixCacheManagerV2,
    sampler::Logprobs,
    sequence::{Sequence, SequenceRecognizer, SequenceState, StopReason},
    tools::parse_text_tools,
};

use super::Pipeline;

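/// Replace the SentencePiece word-boundary marker `▁` (U+2581) with a plain
/// space in decoded text. The `Option` arm applies the same fixup inside an
/// `Option<String>`, passing `None` through unchanged.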
macro_rules! fixup_sentencepiece {
    ($txt:expr) => {
        $txt.to_string().replace("▁", " ")
    };
    (Option $txt:expr) => {
        match &$txt {
            Some(txt) => Some(fixup_sentencepiece!(txt)),
            None => None,
        }
    };
}

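/// Append a freshly sampled token to `seq` and, when the sequence has
/// finished (EOS, stop token/string, length limit, or a completed tool call),
/// build and dispatch the response.
///
/// Streaming sequences get incremental chunk choices; non-streaming sequences
/// get a full chat or completion response once done. When `use_prefix_cacher`
/// is set, finished sequences are registered with `prefix_cacher` so their
/// prefixes can be reused by later requests.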
pub(crate) async fn finish_or_add_toks_to_seq(
    this: &dyn Pipeline,
    prefix_cacher: &mut PrefixCacheManagerV2,
    seq: &mut Sequence,
    logprobs: Logprobs,
    eos_tok: Option<&[u32]>,
    use_prefix_cacher: bool,
) -> Result<()> {
    let meta = this.get_metadata();
    let max_len = meta.max_seq_len;
    let mut is_done = seq.is_done(logprobs.token, eos_tok, max_len);
    seq.add_token(
        logprobs.clone(),
        this.get_metadata()
            .tok_env()
            .ok_or(candle_core::Error::Msg(
                "`finish_or_add_toks_to_seq` requires the pipeline to have a token trie"
                    .to_string(),
            ))?
            .tok_trie()
            .decode(&[logprobs.token]),
        &is_done,
    );

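    // If tool calling is enabled, peek at the pending delta: once it parses as
    // a finished tool call with no leftover plain text, end the sequence early
    // with an EOS stop reason.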
    if let Some(ref t) = seq.tools {
        if let Ok(Some(ref d)) = seq.peek_delta() {
            let (_tool_use_still_possible, tool_use_is_done) =
                t.prefix_could_be_tool(this, d.as_str())?;

            if tool_use_is_done
                && matches!(
                    parse_text_tools(this, d, seq.tools.clone()),
                    Ok((None, _tools))
                )
            {
                seq.set_state(SequenceState::Done(StopReason::Eos));
                is_done = Some(StopReason::Eos);
            }
        }
    };

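    // Streaming path: emit partial deltas as they accumulate, holding them
    // back while the text could still turn into a tool call.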
    if seq.get_mut_group().is_streaming {
        let mut tool_use_still_possible = false;
        let mut tool_use_is_done = false;
        if let Some(ref t) = seq.tools {
            if let Ok(Some(ref d)) = seq.peek_delta() {
                (tool_use_still_possible, tool_use_is_done) =
                    t.prefix_could_be_tool(this, d.as_str())?;
            }
        };

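        // Throttle chunk emission to every other token, but always flush when
        // the sequence has just finished.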
        let send = seq.get_toks().len() % 2 == 0 || is_done.is_some();
        if !tool_use_still_possible || tool_use_is_done {
            if send {
                if let Some(delta) = crate::handle_seq_error_ok!(seq.get_delta(), seq.responder()) {
                    if seq.get_mut_group().is_chat {
                        let (text_new, tool_calls) =
                            parse_text_tools(this, delta.as_str(), seq.tools.clone())
                                .map_err(candle_core::Error::msg)?;

                        if !tool_calls.is_empty() && is_done.is_none() {
                            is_done = Some(StopReason::Eos);
                        };
                        seq.add_streaming_chunk_choice_to_group(crate::ChunkChoice {
                            delta: crate::Delta {
                                content: fixup_sentencepiece!(
                                    Option text_new.map(ToString::to_string)
                                ),
                                role: "assistant".to_string(),
                                tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
                            },
                            index: seq.get_response_index(),
                            finish_reason: is_done.map(|x| x.to_string()),
                            logprobs: if seq.return_logprobs() {
                                Some(crate::ResponseLogprob {
                                    token: delta,
                                    bytes: logprobs.bytes.clone().map(|b| b.into_bytes()),
                                    logprob: logprobs.logprob,
                                    top_logprobs: logprobs.top_logprobs.clone().unwrap(),
                                })
                            } else {
                                None
                            },
                        });
                    } else {
                        seq.add_streaming_completion_chunk_choice_to_group(
                            crate::CompletionChunkChoice {
                                text: fixup_sentencepiece!(delta),
                                index: seq.get_response_index(),
                                finish_reason: is_done.map(|x| x.to_string()),
                                logprobs: if seq.return_logprobs() {
                                    Some(crate::ResponseLogprob {
                                        token: delta,
                                        bytes: logprobs.bytes.clone().map(|b| b.into_bytes()),
                                        logprob: logprobs.logprob,
                                        top_logprobs: logprobs.top_logprobs.clone().unwrap(),
                                    })
                                } else {
                                    None
                                },
                            },
                        );
                    }
                }
            }

            if let Some(reason) = is_done {
                if use_prefix_cacher {
                    prefix_cacher.add_sequence(seq);
                    prefix_cacher.evict_to_cpu()?;
                }
                seq.set_state(crate::sequence::SequenceState::Done(reason));
                this.reset_non_granular_state();
            }

            let usage_opt = if is_done.is_some() {
                let usage = seq.get_mut_group().get_usage();
                seq.get_mut_group().total_prompt_toks = 0;
                seq.get_mut_group().total_toks = 0;
                Some(usage)
            } else {
                None
            };

            if seq
                .get_mut_group()
                .maybe_send_streaming_response(seq, this.name().clone(), usage_opt)
                .await
                .is_err()
            {
                seq.set_state(crate::sequence::SequenceState::Done(
                    crate::sequence::StopReason::Canceled,
                ));
                this.reset_non_granular_state();
            }
        }
    } else if let Some(reason) = is_done {
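        // Non-streaming: the sequence just finished, so assemble and send the
        // complete response in one shot.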
        {
            seq.set_state(crate::sequence::SequenceState::Done(reason));
            let (tokenizer, pipeline_name) = {
                let pipeline_name = this.name();
                let tokenizer = this.tokenizer();
                (tokenizer, pipeline_name)
            };

            let logprobs = if seq.return_logprobs() {
                let mut logprobs = Vec::new();
                for logprob in seq.logprobs() {
                    let resp_logprob = crate::ResponseLogprob {
                        token: crate::handle_seq_error_ok!(
                            tokenizer
                                .as_ref()
                                .ok_or(candle_core::Error::Msg(
                                    "`finish_or_add_toks_to_seq` requires the pipeline to have a tokenizer"
                                        .to_string(),
                                ))?
                                .decode(&[logprob.token], false),
                            seq.responder()
                        ),
                        bytes: logprob.bytes.clone().map(|b| b.into_bytes()),
                        logprob: logprob.logprob,
                        top_logprobs: logprob.top_logprobs.clone().unwrap(),
                    };
                    logprobs.push(resp_logprob);
                }
                Some(logprobs)
            } else {
                None
            };

            let text = match reason {
                crate::sequence::StopReason::Length(_)
                | crate::sequence::StopReason::ModelLength(_)
                | crate::sequence::StopReason::Eos
                | crate::sequence::StopReason::StopTok(_)
                | crate::sequence::StopReason::Canceled => {
                    String::from_utf8_lossy(seq.completion_bytes())
                        .trim_start()
                        .to_string()
                }
                // For stop strings, cut the text at the position where the
                // stop string began so it is not echoed back to the client.
                crate::sequence::StopReason::StopString {
                    completion_bytes_pos,
                    ..
                } => {
                    let txt = String::from_utf8_lossy(seq.completion_bytes());
                    txt[..completion_bytes_pos].trim_start().to_string()
                }
                crate::sequence::StopReason::GeneratedImage => {
                    candle_core::bail!("Stop reason was `GeneratedImage`.")
                }
            };

            if seq.get_mut_group().is_chat {
                let (text_new, tool_calls) =
                    parse_text_tools(this, text.as_str(), seq.tools.clone())
                        .map_err(candle_core::Error::msg)?;
                let choice = crate::Choice {
                    finish_reason: fixup_sentencepiece!(reason),
                    index: seq.get_response_index(),
                    message: crate::ResponseMessage {
                        content: text_new.map(ToString::to_string),
                        role: "assistant".to_string(),
                        tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
                    },
                    logprobs: logprobs.map(|l| crate::Logprobs { content: Some(l) }),
                };
                seq.add_choice_to_group(choice);
            } else {
                let choice = crate::CompletionChoice {
                    finish_reason: fixup_sentencepiece!(reason),
                    index: seq.get_response_index(),
                    text,
                    logprobs: None,
                };
                seq.add_completion_choice_to_group(choice);
            }

            if use_prefix_cacher {
                prefix_cacher.add_sequence(seq);
                prefix_cacher.evict_to_cpu()?;
            }

            let group = seq.get_mut_group();
            if group.is_chat {
                group
                    .maybe_send_chat_done_response(
                        crate::ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name,
                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        },
                        seq.responder(),
                    )
                    .await
                    .map_err(candle_core::Error::msg)?;
            } else {
                group
                    .maybe_send_completion_done_response(
                        crate::CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name,
                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        },
                        seq.responder(),
                    )
                    .await
                    .map_err(candle_core::Error::msg)?;
            }
        }
        this.reset_non_granular_state();
    }

    Ok(())
}

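/// Sample the next token for every sequence in `seqs` from its corresponding
/// logits in `logits_seq`, then route each result through
/// [`finish_or_add_toks_to_seq`].
///
/// With more than one sequence in the batch, sampling is fanned out to the
/// rayon pool via `tokio_rayon` and the futures are awaited together.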
pub async fn sample_and_add_toks(
    this: &dyn Pipeline,
    seqs: &mut [&mut Sequence],
    logits_seq: Vec<Tensor>,
    prefix_cacher: &mut PrefixCacheManagerV2,
    disable_eos_stop: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
) -> Result<()> {
    let seqs_len = seqs.len();
    debug_assert_eq!(logits_seq.len(), seqs_len);

    let use_async_pool = seqs_len > 1;

    let sampling_futures: Vec<_> = std::iter::zip(logits_seq, seqs.iter_mut())
        .map(|(logits_per_seq, seq)| {
            let return_logprobs = seq.return_logprobs();
            sample_sequence(
                logits_per_seq,
                seq,
                return_logprobs,
                rng.clone(),
                use_async_pool,
                false,
            )
        })
        .collect();
    let sampled_vec = futures::future::join_all(sampling_futures).await;

    for (sampled, seq) in std::iter::zip(sampled_vec, seqs.iter_mut()) {
        let next_token = crate::handle_seq_error_stateaware_ok!(sampled, seq);

        let metadata = this.get_metadata();
        let eos_tok = if disable_eos_stop {
            None
        } else {
            Some(&metadata.eos_tok[..])
        };

        finish_or_add_toks_to_seq(this, prefix_cacher, seq, next_token, eos_tok, true).await?;
    }

    Ok(())
}

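/// Sample a single token from `logits` for `seq`, using the sequence's
/// sampler and, if one is attached, its llguidance grammar recognizer. A
/// token rejected by the grammar is re-sampled once over the allowed set (see
/// the bias logic below).
///
/// A minimal usage sketch; marked `ignore` because it assumes `logits`, a
/// built `seq`, and an `rng` from a loaded pipeline:
///
/// ```ignore
/// let logprobs = sample_sequence(
///     logits,               // last-position logits for this sequence
///     &mut seq,
///     seq.return_logprobs(),
///     rng.clone(),
///     true,                 // sample on the rayon pool
///     false,                // not a speculative sample
/// )
/// .await?;
/// ```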
#[allow(clippy::too_many_arguments)]
pub async fn sample_sequence(
    logits: Tensor,
    seq: &mut Sequence,
    return_logprobs: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    use_async_pool: bool,
    sample_speculative: bool,
) -> Result<Logprobs> {
    let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;

    let sampler = seq.sampler();
    let ctx_clone = seq.get_toks().to_vec();
    let rng_clone = rng.clone();
    let logits_clone = logits.clone();
    let first_logprobs_response = if use_async_pool {
        tokio_rayon::spawn(move || {
            sampler.sample(
                logits_clone,
                &ctx_clone,
                return_logprobs,
                rng_clone,
                sample_speculative,
            )
        })
        .await?
    } else {
        sampler.sample(
            logits_clone,
            &ctx_clone,
            return_logprobs,
            rng_clone,
            sample_speculative,
        )?
    };

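    // Grammar-constrained decoding: if the llguidance recognizer rejects the
    // sampled token, build a bias vector that is 0.0 for every allowed token
    // id and -inf everywhere else.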
    let bias_if_not_allowed = match &mut seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            if !llg.is_stopped()
                && llg
                    .validate_tokens(&[first_logprobs_response.token])
                    .unwrap_or(0)
                    == 1
            {
                None
            } else {
                let mask = llg.compute_mask_or_eos().map_err(candle_core::Error::msg)?;
                if mask.is_allowed(first_logprobs_response.token) {
                    None
                } else {
                    let mut acc = vec![-f32::INFINITY; logits.shape().dims1()?];
                    mask.iter_set_entries(|idx| {
                        if idx < acc.len() {
                            acc[idx] = 0.0;
                        }
                    });

                    Some(acc)
                }
            }
        }
        SequenceRecognizer::None => None,
    };
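    // If the first choice was disallowed, add the bias to the logits and
    // sample again; otherwise keep the first sample.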
    let second_logprobs_response = match bias_if_not_allowed {
        Some(acc) => {
            let new_logits = (logits + Tensor::from_slice(&acc, acc.len(), &Device::Cpu)?)?;

            let ctx_clone = seq.get_toks().to_vec();
            let rng_clone = rng.clone();
            let sampler = seq.sampler();
            if use_async_pool {
                tokio_rayon::spawn(move || {
                    sampler.sample(
                        new_logits,
                        &ctx_clone,
                        return_logprobs,
                        rng_clone,
                        sample_speculative,
                    )
                })
                .await?
            } else {
                sampler.sample(
                    new_logits,
                    &ctx_clone,
                    return_logprobs,
                    rng_clone,
                    sample_speculative,
                )?
            }
        }
        None => first_logprobs_response,
    };

    // Advance the grammar recognizer past the token that was finally chosen.
    match seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            if !llg.is_stopped() {
                llg.consume_token(second_logprobs_response.token)
                    .map_err(candle_core::Error::msg)?;
            }
        }
        SequenceRecognizer::None => {}
    }

    Ok(second_logprobs_response)
}

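/// One sampled token (with its logprobs) produced during speculative
/// decoding, either by the draft model or by the target model during
/// verification.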
#[derive(Clone)]
pub struct SpeculativeSample {
    pub sample: Logprobs,
}

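/// Verify a run of draft tokens against the target model. `logits` holds the
/// target model's logits for all `draft_samples.len()` drafted positions;
/// each position is re-sampled with the target sampler, and verification
/// stops at the first position whose target token disagrees with the draft
/// token. The accepted samples (including that first divergent one) are
/// returned.
///
/// A sketch of the call site; `ignore`d since it assumes draft samples and a
/// target-model forward pass produced elsewhere:
///
/// ```ignore
/// let accepted =
///     sample_target_sequence_speculative(logits, &mut seq, false, rng.clone(), &drafts)
///         .await?;
/// assert!(accepted.len() <= drafts.len());
/// ```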
pub async fn sample_target_sequence_speculative(
    logits: Tensor,
    seq: &mut Sequence,
    return_logprobs: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    draft_samples: &[SpeculativeSample],
) -> Result<Vec<SpeculativeSample>> {
    let n_toks = draft_samples.len();

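    // The grammar recognizer already consumed the draft tokens when they were
    // sampled; roll it back so verification consumes only the tokens that are
    // actually accepted.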
    match &mut seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            llg.rollback(n_toks).map_err(candle_core::Error::msg)?;
        }
        SequenceRecognizer::None => {}
    }

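    // Split the per-position logits along the sequence dimension and verify
    // token by token, stopping at the first mismatch with the draft.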
    let mut sampled = Vec::new();
    for (chunk, draft) in logits
        .chunk(n_toks, 1)?
        .into_iter()
        .zip(draft_samples.iter())
    {
        let sample = sample_sequence(
            chunk,
            seq,
            return_logprobs,
            rng.clone(),
            true, // use_async_pool
            true, // sample_speculative
        )
        .await?;
        let sampled_token = sample.token;
        sampled.push(SpeculativeSample { sample });
        if sampled_token != draft.sample.token {
            break;
        }
    }
    Ok(sampled)
}