use std::sync::Arc;

use candle_core::{DType, Result, Tensor};
use rand_isaac::Isaac64Rng;

use crate::{
    prefix_cacher::PrefixCacheManagerV2,
    sampler::Logprobs,
    sequence::{Sequence, SequenceRecognizer, SequenceState, StopReason},
    tools::{parse_text_tools, ToolCallResponse, ToolCallType},
};
use mistralrs_mcp::CalledFunction;

use super::Pipeline;

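/// Replace the SentencePiece word-boundary marker `▁` with a plain space.
/// The `Option` variant maps the fixup over an `Option`, passing `None`
/// through unchanged.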
macro_rules! fixup_sentencepiece {
    ($txt:expr) => {
        $txt.to_string().replace("▁", " ")
    };
    (Option $txt:expr) => {
        match &$txt {
            Some(txt) => Some(fixup_sentencepiece!(txt)),
            None => None,
        }
    };
}

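/// Record a freshly sampled token on `seq` and either emit a streaming chunk
/// or, when the sequence has finished, build and send the final chat or
/// completion response. Tool-call detection and Harmony-channel splitting are
/// both handled here.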
pub(crate) async fn finish_or_add_toks_to_seq(
    this: &dyn Pipeline,
    prefix_cacher: &mut PrefixCacheManagerV2,
    seq: &mut Sequence,
    logprobs: Logprobs,
    eos_tok: Option<&[u32]>,
    use_prefix_cacher: bool,
) -> Result<()> {
    let mut is_done = seq.is_done(logprobs.token, eos_tok, this.get_metadata().max_seq_len);
    seq.add_token(
        logprobs.clone(),
        this.get_metadata()
            .tok_env()
            .ok_or(candle_core::Error::Msg(
                "`finish_or_add_toks_to_seq` requires the pipeline to have a token trie"
                    .to_string(),
            ))?
            .tok_trie()
            .decode(&[logprobs.token]),
        &is_done,
    );

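    // With tool calling enabled, peek at the pending delta: once it forms a
    // complete tool call with no surrounding plain text, stop the sequence as
    // if an EOS had been generated.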
    if let Some(ref t) = seq.tools {
        if let Ok(Some(ref d)) = seq.peek_delta() {
            let (_tool_use_still_possible, tool_use_is_done) =
                t.prefix_could_be_tool(this, d.as_str())?;

            if tool_use_is_done
                && matches!(
                    parse_text_tools(this, d, seq.tools.clone()),
                    Ok((None, _tools))
                )
            {
                seq.set_state(SequenceState::Done(StopReason::Eos));
                is_done = Some(StopReason::Eos);
            }
        }
    };

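    // Streaming path: hold a chunk back while the delta could still be the
    // prefix of a tool call; emit once that ambiguity resolves or the
    // sequence finishes.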
    if seq.get_mut_group().is_streaming {
        let mut tool_use_still_possible = false;
        let mut tool_use_is_done = false;
        if let Some(ref t) = seq.tools {
            if let Ok(Some(ref d)) = seq.peek_delta() {
                (tool_use_still_possible, tool_use_is_done) =
                    t.prefix_could_be_tool(this, d.as_str())?;
            }
        };

        let send = true;
        if !tool_use_still_possible || tool_use_is_done || is_done.is_some() {
            if send {
                let delta_result = seq.get_delta();
                if let Some(delta) = crate::handle_seq_error_ok!(delta_result, seq.responder()) {
                    if seq.get_mut_group().is_chat {
                        let (content_delta, reasoning_delta) = if seq.is_harmony_mode() {
                            let final_delta = seq.get_harmony_final_delta();
                            let reasoning = seq.get_harmony_reasoning_delta();
                            (final_delta, reasoning)
                        } else {
                            let (text_new, _) =
                                parse_text_tools(this, delta.as_str(), seq.tools.clone())
                                    .map_err(candle_core::Error::msg)?;
                            (text_new.map(ToString::to_string), None)
                        };

                        let tool_calls = if seq.is_harmony_mode() {
                            if is_done.is_some() && seq.has_harmony_tool_calls() {
                                is_done = Some(StopReason::ToolCalls);
                                let harmony_tool_calls = seq.get_harmony_tool_calls();
                                harmony_tool_calls
                                    .into_iter()
                                    .enumerate()
                                    .map(|(i, tc)| ToolCallResponse {
                                        index: i,
                                        id: tc.id,
                                        tp: ToolCallType::Function,
                                        function: CalledFunction {
                                            name: tc.name,
                                            arguments: tc.arguments,
                                        },
                                    })
                                    .collect()
                            } else {
                                vec![]
                            }
                        } else {
                            let (_, tool_calls) =
                                parse_text_tools(this, delta.as_str(), seq.tools.clone())
                                    .map_err(candle_core::Error::msg)?;
                            if !tool_calls.is_empty() {
                                is_done = Some(StopReason::ToolCalls);
                            }
                            tool_calls
                        };

                        seq.add_streaming_chunk_choice_to_group(crate::ChunkChoice {
                            delta: crate::Delta {
                                content: fixup_sentencepiece!(Option content_delta),
                                role: "assistant".to_string(),
                                tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
                                reasoning_content: reasoning_delta,
                            },
                            index: seq.get_response_index(),
                            finish_reason: is_done.map(|x| x.to_string()),
                            logprobs: if seq.return_logprobs() {
                                Some(crate::ResponseLogprob {
                                    token: delta,
                                    bytes: logprobs.bytes.clone().map(|b| b.into_bytes()),
                                    logprob: logprobs.logprob,
                                    top_logprobs: logprobs.top_logprobs.clone().unwrap(),
                                })
                            } else {
                                None
                            },
                        });
                    } else {
                        seq.add_streaming_completion_chunk_choice_to_group(
                            crate::CompletionChunkChoice {
                                text: fixup_sentencepiece!(delta),
                                index: seq.get_response_index(),
                                finish_reason: is_done.map(|x| x.to_string()),
                                logprobs: if seq.return_logprobs() {
                                    Some(crate::ResponseLogprob {
                                        token: delta,
                                        bytes: logprobs.bytes.clone().map(|b| b.into_bytes()),
                                        logprob: logprobs.logprob,
                                        top_logprobs: logprobs.top_logprobs.clone().unwrap(),
                                    })
                                } else {
                                    None
                                },
                            },
                        );
                    }
                }
            }

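            // Attach usage statistics to the final chunk and reset the
            // group's running token counters.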
            let usage_opt = if is_done.is_some() {
                let usage = seq.get_mut_group().get_usage();
                seq.get_mut_group().total_prompt_toks = 0;
                seq.get_mut_group().total_toks = 0;
                Some(usage)
            } else {
                None
            };

            if seq
                .get_mut_group()
                .maybe_send_streaming_response(seq, this.name().clone(), usage_opt)
                .await
                .is_err()
            {
                seq.set_state(crate::sequence::SequenceState::Done(
                    crate::sequence::StopReason::Canceled,
                ));
                this.reset_non_granular_state();
            }
        }

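        // The streamed sequence has finished: optionally cache its prefix,
        // then mark it done and reset pipeline state.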
        if let Some(reason) = is_done {
            if use_prefix_cacher {
                prefix_cacher.add_sequence(seq);
                prefix_cacher.evict_caches()?;
            }
            seq.set_state(crate::sequence::SequenceState::Done(reason));
            this.reset_non_granular_state();
        }
    } else if let Some(mut reason) = is_done {
        {
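            // Non-streaming path: the sequence has finished. Assemble the full
            // response (text, logprobs, tool calls) and send it in one piece.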
            seq.set_state(crate::sequence::SequenceState::Done(reason));
            let (tokenizer, pipeline_name) = {
                let pipeline_name = this.name();
                let tokenizer = this.tokenizer();
                (tokenizer, pipeline_name)
            };

            let logprobs = if seq.return_logprobs() {
                let mut logprobs = Vec::new();
                for logprob in seq.logprobs() {
                    let resp_logprob = crate::ResponseLogprob {
                        token: crate::handle_seq_error_ok!(
                            tokenizer
                                .as_ref()
                                .ok_or(candle_core::Error::Msg(
                                    "`finish_or_add_toks_to_seq` requires the pipeline to have a tokenizer"
                                        .to_string(),
                                ))?
                                .decode(&[logprob.token], false),
                            seq.responder()
                        ),
                        bytes: logprob.bytes.clone().map(|b| b.into_bytes()),
                        logprob: logprob.logprob,
                        top_logprobs: logprob.top_logprobs.clone().unwrap(),
                    };
                    logprobs.push(resp_logprob);
                }
                Some(logprobs)
            } else {
                None
            };

            seq.harmony_process_eos();

            let text = match reason {
                crate::sequence::StopReason::Length(_)
                | crate::sequence::StopReason::ModelLength(_)
                | crate::sequence::StopReason::Eos
                | crate::sequence::StopReason::StopTok(_)
                | crate::sequence::StopReason::Canceled
                | crate::sequence::StopReason::ToolCalls => {
                    String::from_utf8_lossy(seq.completion_bytes())
                        .trim_start()
                        .to_string()
                }
                crate::sequence::StopReason::StopString {
                    completion_bytes_pos,
                    ..
                } => {
                    let txt = String::from_utf8_lossy(seq.completion_bytes());
                    txt[..completion_bytes_pos].trim_start().to_string()
                }
                crate::sequence::StopReason::GeneratedImage
                | crate::sequence::StopReason::GeneratedSpeech => {
                    candle_core::bail!("Stop reason was `GeneratedImage` or `GeneratedSpeech`.")
                }
            };

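            // Chat requests get a `Choice` with parsed tool calls and optional
            // reasoning content; plain completion requests get the raw text.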
            if seq.get_mut_group().is_chat {
                let (text_new, tool_calls, reasoning_content) = if seq.is_harmony_mode() {
                    let final_content = seq.get_harmony_final_content();
                    let reasoning = seq.get_harmony_reasoning_content();

                    let harmony_tool_calls = seq.get_harmony_tool_calls();
                    let tool_calls: Vec<ToolCallResponse> = harmony_tool_calls
                        .into_iter()
                        .enumerate()
                        .map(|(i, tc)| ToolCallResponse {
                            index: i,
                            id: tc.id,
                            tp: ToolCallType::Function,
                            function: CalledFunction {
                                name: tc.name,
                                arguments: tc.arguments,
                            },
                        })
                        .collect();

                    (final_content, tool_calls, reasoning)
                } else {
                    let (text_new, tool_calls) =
                        parse_text_tools(this, text.as_str(), seq.tools.clone())
                            .map_err(candle_core::Error::msg)?;
                    (text_new.map(ToString::to_string), tool_calls, None)
                };

                if !tool_calls.is_empty() {
                    reason = StopReason::ToolCalls;
                }

                let choice = crate::Choice {
                    finish_reason: fixup_sentencepiece!(reason),
                    index: seq.get_response_index(),
                    message: crate::ResponseMessage {
                        content: text_new,
                        role: "assistant".to_string(),
                        tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
                        reasoning_content,
                    },
                    logprobs: logprobs.map(|l| crate::Logprobs { content: Some(l) }),
                };
                seq.add_choice_to_group(choice);
            } else {
                let choice = crate::CompletionChoice {
                    finish_reason: fixup_sentencepiece!(reason),
                    index: seq.get_response_index(),
                    text,
                    logprobs: None,
                };
                seq.add_completion_choice_to_group(choice);
            }

            if use_prefix_cacher {
                prefix_cacher.add_sequence(seq);
                prefix_cacher.evict_caches()?;
            }

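            // Send the completed response to the waiting receiver.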
            let group = seq.get_mut_group();
            if group.is_chat {
                group
                    .maybe_send_chat_done_response(
                        crate::ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name,
                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        },
                        seq.responder(),
                    )
                    .await
                    .map_err(candle_core::Error::msg)?;
            } else {
                group
                    .maybe_send_completion_done_response(
                        crate::CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name,
                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        },
                        seq.responder(),
                    )
                    .await
                    .map_err(candle_core::Error::msg)?;
            }
        }
        this.reset_non_granular_state();
    }

    Ok(())
}

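/// Sample the next token for every sequence in `seqs` from the corresponding
/// logits, then hand each result to `finish_or_add_toks_to_seq`. With more
/// than one sequence, sampling is fanned out to the rayon thread pool.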
pub async fn sample_and_add_toks(
    this: &dyn Pipeline,
    seqs: &mut [&mut Sequence],
    logits_seq: Vec<Tensor>,
    prefix_cacher: &mut PrefixCacheManagerV2,
    disable_eos_stop: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
) -> Result<()> {
    let seqs_len = seqs.len();
    debug_assert_eq!(logits_seq.len(), seqs_len);

    let use_async_pool = seqs_len > 1;

    let sampling_futures: Vec<_> = std::iter::zip(logits_seq, seqs.iter_mut())
        .map(|(logits_per_seq, seq)| {
            let return_logprobs = seq.return_logprobs();
            sample_sequence(
                logits_per_seq,
                seq,
                return_logprobs,
                rng.clone(),
                use_async_pool,
                false,          // sample_speculative
                use_async_pool, // multiple_sequences
            )
        })
        .collect();
    let sampled_vec = futures::future::join_all(sampling_futures).await;

    for (sampled, seq) in std::iter::zip(sampled_vec, seqs.iter_mut()) {
        let next_token = crate::handle_seq_error_stateaware_ok!(sampled, seq);

        let metadata = this.get_metadata();
        let eos_tok = if disable_eos_stop {
            None
        } else {
            Some(&metadata.eos_tok[..])
        };

        finish_or_add_toks_to_seq(this, prefix_cacher, seq, next_token, eos_tok, true).await?;
    }

    Ok(())
}

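/// Sample one token from `logits` for `seq`.
///
/// If the sequence carries an llguidance grammar and the sampled token is not
/// allowed by it, the logits are masked to the allowed token set and sampling
/// is retried; the grammar is then advanced with the committed token.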
#[allow(clippy::too_many_arguments)]
pub async fn sample_sequence(
    logits: Tensor,
    seq: &mut Sequence,
    return_logprobs: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    use_async_pool: bool,
    sample_speculative: bool,
    multiple_sequences: bool,
) -> Result<Logprobs> {
    let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;

    let sampler = seq.sampler();
    let ctx_clone = seq.get_toks().to_vec();
    let rng_clone = rng.clone();
    let logits_clone = logits.clone();
    let first_logprobs_response = if use_async_pool {
        tokio_rayon::spawn(move || {
            sampler.sample(
                logits_clone,
                &ctx_clone,
                return_logprobs,
                rng_clone,
                sample_speculative,
                multiple_sequences,
            )
        })
        .await?
    } else {
        sampler.sample(
            logits_clone,
            &ctx_clone,
            return_logprobs,
            rng_clone,
            sample_speculative,
            multiple_sequences,
        )?
    };

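    // Grammar constraint: if llguidance rejects the sampled token, build a
    // bias vector that is -inf everywhere except the tokens the grammar
    // allows, so the resample below can only pick an allowed token.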
    let bias_if_not_allowed = match &mut seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            if !llg.is_stopped()
                && llg
                    .validate_tokens(&[first_logprobs_response.token])
                    .unwrap_or(0)
                    == 1
            {
                None
            } else {
                let mask = llg.compute_mask_or_eos().map_err(candle_core::Error::msg)?;
                if mask.is_allowed(first_logprobs_response.token) {
                    None
                } else {
                    let mut acc = vec![-f32::INFINITY; logits.shape().dims1()?];
                    mask.iter_set_entries(|idx| {
                        if idx < acc.len() {
                            acc[idx] = 0.0;
                        }
                    });

                    Some(acc)
                }
            }
        }
        SequenceRecognizer::None => None,
    };
    let second_logprobs_response = match bias_if_not_allowed {
        Some(acc) => {
            let new_logits = (&logits + Tensor::from_slice(&acc, acc.len(), logits.device())?)?;

            let ctx_clone = seq.get_toks().to_vec();
            let rng_clone = rng.clone();
            let sampler = seq.sampler();
            if use_async_pool {
                tokio_rayon::spawn(move || {
                    sampler.sample(
                        new_logits,
                        &ctx_clone,
                        return_logprobs,
                        rng_clone,
                        sample_speculative,
                        multiple_sequences,
                    )
                })
                .await?
            } else {
                sampler.sample(
                    new_logits,
                    &ctx_clone,
                    return_logprobs,
                    rng_clone,
                    sample_speculative,
                    multiple_sequences,
                )?
            }
        }
        None => first_logprobs_response,
    };

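    // Advance the grammar with the token that was actually committed.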
    match seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            if !llg.is_stopped() {
                llg.consume_token(second_logprobs_response.token)
                    .map_err(candle_core::Error::msg)?;
            }
        }
        SequenceRecognizer::None => {}
    }

    Ok(second_logprobs_response)
}

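/// A token sampled while verifying a speculative-decoding draft.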
#[derive(Clone)]
pub struct SpeculativeSample {
    pub sample: Logprobs,
}

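/// Re-sample each draft token position with the target model's logits.
/// Verification stops at the first position where the target's sample
/// disagrees with the draft, so the returned vector contains the accepted
/// prefix plus one corrected token.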
pub async fn sample_target_sequence_speculative(
    logits: Tensor,
    seq: &mut Sequence,
    return_logprobs: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    draft_samples: &[SpeculativeSample],
) -> Result<Vec<SpeculativeSample>> {
    let n_toks = draft_samples.len();

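    // Roll the grammar back over the draft tokens so it can re-consume the
    // tokens that the target model actually accepts.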
    match &mut seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            llg.rollback(n_toks).map_err(candle_core::Error::msg)?;
        }
        SequenceRecognizer::None => {}
    }

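    // Verify the draft left to right: resample each position with the target
    // model and stop at the first token that disagrees with the draft.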
    let mut sampled = Vec::new();
    for (chunk, draft) in logits
        .chunk(n_toks, 1)?
        .into_iter()
        .zip(draft_samples.iter())
    {
        let sample = sample_sequence(
            chunk,
            seq,
            return_logprobs,
            rng.clone(),
            true,  // use_async_pool
            true,  // sample_speculative
            false, // multiple_sequences
        )
        .await?;
        let sampled_token = sample.token;
        sampled.push(SpeculativeSample { sample });
        if sampled_token != draft.sample.token {
            break;
        }
    }
    Ok(sampled)
}