use std::sync::Arc;

use candle_core::{DType, Result, Tensor};
use rand_isaac::Isaac64Rng;

use crate::{
    prefix_cacher::PrefixCacheManagerV2,
    sampler::Logprobs,
    sequence::{Sequence, SequenceRecognizer, SequenceState, StopReason},
    tools::parse_text_tools,
};

use super::Pipeline;

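/// Replace the SentencePiece whitespace marker (`▁`, U+2581) with a plain
/// space. The `Option` arm applies the same fixup through an `Option`,
/// passing `None` through unchanged.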
macro_rules! fixup_sentencepiece {
    ($txt:expr) => {
        $txt.to_string().replace("▁", " ")
    };
    (Option $txt:expr) => {
        match &$txt {
            Some(txt) => Some(fixup_sentencepiece!(txt)),
            None => None,
        }
    };
}

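/// Append the newly sampled token to `seq`, then either stream out a chunk or,
/// if the sequence has now finished, assemble and send the final chat or
/// completion response.
///
/// When tools are attached to the sequence, the accumulated delta is checked
/// with `prefix_could_be_tool` so that generation can stop as soon as the text
/// can no longer be (or already is) a complete tool call.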
pub(crate) async fn finish_or_add_toks_to_seq(
    this: &dyn Pipeline,
    prefix_cacher: &mut PrefixCacheManagerV2,
    seq: &mut Sequence,
    logprobs: Logprobs,
    eos_tok: Option<&[u32]>,
    use_prefix_cacher: bool,
) -> Result<()> {
    let mut is_done = seq.is_done(logprobs.token, eos_tok, this.get_metadata().max_seq_len);
    seq.add_token(
        logprobs.clone(),
        this.get_metadata()
            .tok_env()
            .ok_or_else(|| {
                candle_core::Error::Msg(
                    "`finish_or_add_toks_to_seq` requires the pipeline to have a token trie"
                        .to_string(),
                )
            })?
            .tok_trie()
            .decode(&[logprobs.token]),
        &is_done,
    );

    // If the delta now parses fully as a tool call (no leftover text), end the
    // sequence early with an EOS stop reason.
    if let Some(ref t) = seq.tools {
        if let Ok(Some(ref d)) = seq.peek_delta() {
            let (_tool_use_still_possible, tool_use_is_done) =
                t.prefix_could_be_tool(this, d.as_str())?;

            if tool_use_is_done
                && matches!(
                    parse_text_tools(this, d, seq.tools.clone()),
                    Ok((None, _tools))
                )
            {
                seq.set_state(SequenceState::Done(StopReason::Eos));
                is_done = Some(StopReason::Eos);
            }
        }
    }

    if seq.get_mut_group().is_streaming {
        let mut tool_use_still_possible = false;
        let mut tool_use_is_done = false;
        if let Some(ref t) = seq.tools {
            if let Ok(Some(ref d)) = seq.peek_delta() {
                (tool_use_still_possible, tool_use_is_done) =
                    t.prefix_could_be_tool(this, d.as_str())?;
            }
        }

        // Only emit a chunk once the delta can no longer be the prefix of a
        // tool call, or the tool call is complete.
        if !tool_use_still_possible || tool_use_is_done {
            if let Some(delta) = crate::handle_seq_error_ok!(seq.get_delta(), seq.responder()) {
                if seq.get_mut_group().is_chat {
                    let (text_new, tool_calls) =
                        parse_text_tools(this, delta.as_str(), seq.tools.clone())
                            .map_err(candle_core::Error::msg)?;
                    if !tool_calls.is_empty() {
                        is_done = Some(StopReason::ToolCalls);
                    }

                    seq.add_streaming_chunk_choice_to_group(crate::ChunkChoice {
                        delta: crate::Delta {
                            content: fixup_sentencepiece!(
                                Option text_new.map(ToString::to_string)
                            ),
                            role: "assistant".to_string(),
                            tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
                        },
                        index: seq.get_response_index(),
                        finish_reason: is_done.map(|x| x.to_string()),
                        logprobs: if seq.return_logprobs() {
                            Some(crate::ResponseLogprob {
                                token: delta,
                                bytes: logprobs.bytes.clone().map(|b| b.into_bytes()),
                                logprob: logprobs.logprob,
                                top_logprobs: logprobs.top_logprobs.clone().unwrap(),
                            })
                        } else {
                            None
                        },
                    });
                } else {
                    seq.add_streaming_completion_chunk_choice_to_group(
                        crate::CompletionChunkChoice {
                            text: fixup_sentencepiece!(delta),
                            index: seq.get_response_index(),
                            finish_reason: is_done.map(|x| x.to_string()),
                            logprobs: if seq.return_logprobs() {
                                Some(crate::ResponseLogprob {
                                    token: delta,
                                    bytes: logprobs.bytes.clone().map(|b| b.into_bytes()),
                                    logprob: logprobs.logprob,
                                    top_logprobs: logprobs.top_logprobs.clone().unwrap(),
                                })
                            } else {
                                None
                            },
                        },
                    );
                }
            }

            if let Some(reason) = is_done {
                if use_prefix_cacher {
                    prefix_cacher.add_sequence(seq);
                    prefix_cacher.evict_caches()?;
                }
                seq.set_state(crate::sequence::SequenceState::Done(reason));
                this.reset_non_granular_state();
            }

            // Report usage exactly once: take it when the sequence finishes
            // and zero the group counters.
            let usage_opt = if is_done.is_some() {
                let usage = seq.get_mut_group().get_usage();
                seq.get_mut_group().total_prompt_toks = 0;
                seq.get_mut_group().total_toks = 0;
                Some(usage)
            } else {
                None
            };

            if seq
                .get_mut_group()
                .maybe_send_streaming_response(seq, this.name(), usage_opt)
                .await
                .is_err()
            {
                // The receiver was dropped; mark the sequence as canceled.
                seq.set_state(crate::sequence::SequenceState::Done(
                    crate::sequence::StopReason::Canceled,
                ));
                this.reset_non_granular_state();
            }
        }
    } else if let Some(mut reason) = is_done {
        seq.set_state(crate::sequence::SequenceState::Done(reason));
        let pipeline_name = this.name();
        let tokenizer = this.tokenizer();

        let logprobs = if seq.return_logprobs() {
            let mut logprobs = Vec::new();
            for logprob in seq.logprobs() {
                let resp_logprob = crate::ResponseLogprob {
                    token: crate::handle_seq_error_ok!(
                        tokenizer
                            .as_ref()
                            .ok_or_else(|| candle_core::Error::Msg(
                                "`finish_or_add_toks_to_seq` requires the pipeline to have a tokenizer"
                                    .to_string(),
                            ))?
                            .decode(&[logprob.token], false),
                        seq.responder()
                    ),
                    bytes: logprob.bytes.clone().map(|b| b.into_bytes()),
                    logprob: logprob.logprob,
                    top_logprobs: logprob.top_logprobs.clone().unwrap(),
                };
                logprobs.push(resp_logprob);
            }
            Some(logprobs)
        } else {
            None
        };

        let text = match reason {
            crate::sequence::StopReason::Length(_)
            | crate::sequence::StopReason::ModelLength(_)
            | crate::sequence::StopReason::Eos
            | crate::sequence::StopReason::StopTok(_)
            | crate::sequence::StopReason::Canceled
            | crate::sequence::StopReason::ToolCalls => {
                String::from_utf8_lossy(seq.completion_bytes())
                    .trim_start()
                    .to_string()
            }
            crate::sequence::StopReason::StopString {
                completion_bytes_pos,
                ..
            } => {
                // Truncate the completion at the point where the stop string began.
                let txt = String::from_utf8_lossy(seq.completion_bytes());
                txt[..completion_bytes_pos].trim_start().to_string()
            }
            crate::sequence::StopReason::GeneratedImage
            | crate::sequence::StopReason::GeneratedSpeech => {
                candle_core::bail!(
                    "Stop reasons `GeneratedImage` and `GeneratedSpeech` cannot be handled here."
                )
            }
        };

        if seq.get_mut_group().is_chat {
            let (text_new, tool_calls) =
                parse_text_tools(this, text.as_str(), seq.tools.clone())
                    .map_err(candle_core::Error::msg)?;
            if !tool_calls.is_empty() {
                reason = StopReason::ToolCalls;
            }

            let choice = crate::Choice {
                finish_reason: fixup_sentencepiece!(reason),
                index: seq.get_response_index(),
                message: crate::ResponseMessage {
                    content: text_new.map(ToString::to_string),
                    role: "assistant".to_string(),
                    tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
                },
                logprobs: logprobs.map(|l| crate::Logprobs { content: Some(l) }),
            };
            seq.add_choice_to_group(choice);
        } else {
            let choice = crate::CompletionChoice {
                finish_reason: fixup_sentencepiece!(reason),
                index: seq.get_response_index(),
                text,
                logprobs: None,
            };
            seq.add_completion_choice_to_group(choice);
        }

        if use_prefix_cacher {
            prefix_cacher.add_sequence(seq);
            prefix_cacher.evict_caches()?;
        }

        let group = seq.get_mut_group();
        if group.is_chat {
            group
                .maybe_send_chat_done_response(
                    crate::ChatCompletionResponse {
                        id: seq.id().to_string(),
                        choices: group.get_choices().to_vec(),
                        created: seq.creation_time(),
                        model: pipeline_name,
                        system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                        object: "chat.completion".to_string(),
                        usage: group.get_usage(),
                    },
                    seq.responder(),
                )
                .await
                .map_err(candle_core::Error::msg)?;
        } else {
            group
                .maybe_send_completion_done_response(
                    crate::CompletionResponse {
                        id: seq.id().to_string(),
                        choices: group.get_completion_choices().to_vec(),
                        created: seq.creation_time(),
                        model: pipeline_name,
                        system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                        object: "text_completion".to_string(),
                        usage: group.get_usage(),
                    },
                    seq.responder(),
                )
                .await
                .map_err(candle_core::Error::msg)?;
        }
        this.reset_non_granular_state();
    }

    Ok(())
}

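/// Sample the next token for every sequence in `seqs` from the matching logits
/// tensor in `logits_seq`, then feed each result to
/// [`finish_or_add_toks_to_seq`]. Sampling is dispatched to the blocking
/// `tokio_rayon` pool only when more than one sequence is present.
///
/// Illustrative call site (a sketch, not a test from this crate; assumes the
/// scheduler has already produced `pipeline`, `seqs`, `logits_seq`, `cacher`,
/// and `rng`):
///
/// ```ignore
/// sample_and_add_toks(&*pipeline, &mut seqs, logits_seq, &mut cacher, false, rng.clone())
///     .await?;
/// ```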
pub async fn sample_and_add_toks(
    this: &dyn Pipeline,
    seqs: &mut [&mut Sequence],
    logits_seq: Vec<Tensor>,
    prefix_cacher: &mut PrefixCacheManagerV2,
    disable_eos_stop: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
) -> Result<()> {
    let seqs_len = seqs.len();
    debug_assert_eq!(logits_seq.len(), seqs_len);

    // The rayon pool only pays off when there is more than one sequence to
    // sample concurrently.
    let use_async_pool = seqs_len > 1;

    let sampling_futures: Vec<_> = std::iter::zip(logits_seq, seqs.iter_mut())
        .map(|(logits_per_seq, seq)| {
            let return_logprobs = seq.return_logprobs();
            sample_sequence(
                logits_per_seq,
                seq,
                return_logprobs,
                rng.clone(),
                use_async_pool,
                false,
                use_async_pool,
            )
        })
        .collect();
    let sampled_vec = futures::future::join_all(sampling_futures).await;

    for (sampled, seq) in std::iter::zip(sampled_vec, seqs.iter_mut()) {
        let next_token = crate::handle_seq_error_stateaware_ok!(sampled, seq);

        let metadata = this.get_metadata();
        let eos_tok = if disable_eos_stop {
            None
        } else {
            Some(&metadata.eos_tok[..])
        };

        finish_or_add_toks_to_seq(this, prefix_cacher, seq, next_token, eos_tok, true).await?;
    }

    Ok(())
}

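/// Sample a single token from `logits` for `seq`.
///
/// Sampling is done in up to two passes: first over the raw logits, and, if an
/// active llguidance grammar rejects that token, a second time over logits
/// biased with `-inf` on all disallowed tokens. The accepted token is then fed
/// back into the grammar.
///
/// Illustrative usage (a sketch; assumes `logits`, `seq`, and `rng` come from
/// a running pipeline):
///
/// ```ignore
/// let return_logprobs = seq.return_logprobs();
/// let logprobs = sample_sequence(
///     logits,      // logits for this sequence, shape [1, 1, vocab]
///     &mut seq,
///     return_logprobs,
///     rng.clone(), // Arc<Mutex<Isaac64Rng>>
///     true,        // use_async_pool
///     false,       // sample_speculative
///     false,       // multiple_sequences
/// )
/// .await?;
/// ```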
#[allow(clippy::too_many_arguments)]
pub async fn sample_sequence(
    logits: Tensor,
    seq: &mut Sequence,
    return_logprobs: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    use_async_pool: bool,
    sample_speculative: bool,
    multiple_sequences: bool,
) -> Result<Logprobs> {
    let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;

    let sampler = seq.sampler();
    let ctx_clone = seq.get_toks().to_vec();
    let rng_clone = rng.clone();
    let logits_clone = logits.clone();
    let first_logprobs_response = if use_async_pool {
        tokio_rayon::spawn(move || {
            sampler.sample(
                logits_clone,
                &ctx_clone,
                return_logprobs,
                rng_clone,
                sample_speculative,
                multiple_sequences,
            )
        })
        .await?
    } else {
        sampler.sample(
            logits_clone,
            &ctx_clone,
            return_logprobs,
            rng_clone,
            sample_speculative,
            multiple_sequences,
        )?
    };

    // If a grammar is active and rejects the sampled token, build an additive
    // bias that sets every disallowed logit to -inf.
    let bias_if_not_allowed = match &mut seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            if !llg.is_stopped()
                && llg
                    .validate_tokens(&[first_logprobs_response.token])
                    .unwrap_or(0)
                    == 1
            {
                None
            } else {
                let mask = llg.compute_mask_or_eos().map_err(candle_core::Error::msg)?;
                if mask.is_allowed(first_logprobs_response.token) {
                    None
                } else {
                    let mut acc = vec![-f32::INFINITY; logits.shape().dims1()?];
                    mask.iter_set_entries(|idx| {
                        if idx < acc.len() {
                            acc[idx] = 0.0;
                        }
                    });

                    Some(acc)
                }
            }
        }
        SequenceRecognizer::None => None,
    };
    let second_logprobs_response = match bias_if_not_allowed {
        Some(acc) => {
            // Resample over the masked logits.
            let new_logits = (&logits + Tensor::from_slice(&acc, acc.len(), logits.device())?)?;

            let ctx_clone = seq.get_toks().to_vec();
            let rng_clone = rng.clone();
            let sampler = seq.sampler();
            if use_async_pool {
                tokio_rayon::spawn(move || {
                    sampler.sample(
                        new_logits,
                        &ctx_clone,
                        return_logprobs,
                        rng_clone,
                        sample_speculative,
                        multiple_sequences,
                    )
                })
                .await?
            } else {
                sampler.sample(
                    new_logits,
                    &ctx_clone,
                    return_logprobs,
                    rng_clone,
                    sample_speculative,
                    multiple_sequences,
                )?
            }
        }
        None => first_logprobs_response,
    };

    // Advance the grammar by the accepted token.
    match seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            if !llg.is_stopped() {
                llg.consume_token(second_logprobs_response.token)
                    .map_err(candle_core::Error::msg)?;
            }
        }
        SequenceRecognizer::None => {}
    }

    Ok(second_logprobs_response)
}

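/// A sample produced during speculative decoding, for either the draft or the
/// target model.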
#[derive(Clone)]
pub struct SpeculativeSample {
    pub sample: Logprobs,
}

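/// Re-sample each draft token position using the target model's `logits`
/// (split into `draft_samples.len()` chunks along dimension 1), accepting
/// draft tokens until the target model first disagrees. Any grammar state is
/// rolled back first so the re-sampled tokens are validated from the correct
/// position.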
pub async fn sample_target_sequence_speculative(
    logits: Tensor,
    seq: &mut Sequence,
    return_logprobs: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    draft_samples: &[SpeculativeSample],
) -> Result<Vec<SpeculativeSample>> {
    let n_toks = draft_samples.len();

    // Roll the grammar state back past the draft tokens before re-validating.
    match &mut seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            llg.rollback(n_toks).map_err(candle_core::Error::msg)?;
        }
        SequenceRecognizer::None => {}
    }

    let mut sampled = Vec::new();
    for (chunk, draft) in logits
        .chunk(n_toks, 1)?
        .into_iter()
        .zip(draft_samples.iter())
    {
        let sample = sample_sequence(
            chunk,
            seq,
            return_logprobs,
            rng.clone(),
            true, // use_async_pool
            true, // sample_speculative
            false,
        )
        .await?;
        let sampled_token = sample.token;
        sampled.push(SpeculativeSample { sample });
        // Accept draft tokens until the target model disagrees.
        if sampled_token != draft.sample.token {
            break;
        }
    }
    Ok(sampled)
}