1use std::sync::Arc;
2
3use candle_core::{DType, Device, Result, Tensor};
4use rand_isaac::Isaac64Rng;
5
6use crate::{
7 prefix_cacher::PrefixCacheManagerV2,
8 sampler::Logprobs,
9 sequence::{Sequence, SequenceRecognizer, SequenceState, StopReason},
10 tools::parse_text_tools,
11};
12
13use super::Pipeline;
14
/// Normalize SentencePiece output: the "▁" marker token denotes a word
/// boundary, so it is mapped back to a plain space.
///
/// Two forms:
/// * `fixup_sentencepiece!(expr)` — stringifies `expr` and replaces markers.
/// * `fixup_sentencepiece!(Option expr)` — same, mapped over an `Option`,
///   borrowing the inner value (`None` stays `None`).
macro_rules! fixup_sentencepiece {
    ($txt:expr) => {
        $txt.to_string().replace("▁", " ")
    };
    (Option $txt:expr) => {
        $txt.as_ref().map(|txt| fixup_sentencepiece!(txt))
    };
}
26
/// Record the freshly sampled token on `seq`, then either stream a delta chunk
/// to the client (streaming mode) or, when the sequence has finished, build and
/// send the final chat/completion response.
///
/// * `this` — pipeline that produced the token; must expose a token
///   environment (`tok_env`) so the token id can be decoded to text.
/// * `prefix_cacher` — receives the finished sequence (and is evicted to CPU)
///   when `use_prefix_cacher` is set.
/// * `logprobs` — the sampled token id plus its optional logprob details.
/// * `eos_tok` — EOS token ids to test completion against; `None` disables
///   EOS-based stopping.
///
/// # Errors
/// Fails when the pipeline has no `tok_env`/tokenizer, when tool-call parsing
/// fails, or when sending a done-response to the client fails.
pub(crate) async fn finish_or_add_toks_to_seq(
    this: &dyn Pipeline,
    prefix_cacher: &mut PrefixCacheManagerV2,
    seq: &mut Sequence,
    logprobs: Logprobs,
    eos_tok: Option<&[u32]>,
    use_prefix_cacher: bool,
) -> Result<()> {
    // Determine completion status (EOS / length / stop string / …) BEFORE the
    // token is appended; `is_done` may still be upgraded below by tool parsing.
    let mut is_done = seq.is_done(logprobs.token, eos_tok, this.get_metadata().max_seq_len);
    seq.add_token(
        logprobs.clone(),
        this.get_metadata()
            .tok_env
            .as_ref()
            .ok_or(candle_core::Error::Msg(
                "`finish_or_add_toks_to_seq` requires the pipeline to have a token trie"
                    .to_string(),
            ))?
            .tok_trie()
            .decode(&[logprobs.token]),
        &is_done,
    );

    // If this sequence may emit tool calls: once the accumulated delta can no
    // longer grow into a *new* tool call but also does not parse as plain text
    // (`Ok((None, _))`), treat the tool call as complete and stop with EOS.
    if let Some(ref t) = seq.tools {
        if let Ok(Some(ref d)) = seq.peek_delta() {
            let (_tool_use_still_possible, tool_use_is_done) =
                t.prefix_could_be_tool(this, d.as_str())?;

            if tool_use_is_done
                && matches!(
                    parse_text_tools(this, d, seq.tools.clone()),
                    Ok((None, _tools))
                )
            {
                seq.set_state(SequenceState::Done(StopReason::Eos));
                is_done = Some(StopReason::Eos);
            }
        }
    };

    if seq.get_mut_group().is_streaming {
        // While the delta could still turn into a tool call, hold streaming
        // back; flush once tool use is impossible or fully parsed.
        let mut tool_use_still_possible = false;
        let mut tool_use_is_done = false;
        if let Some(ref t) = seq.tools {
            if let Ok(Some(ref d)) = seq.peek_delta() {
                (tool_use_still_possible, tool_use_is_done) =
                    t.prefix_could_be_tool(this, d.as_str())?;
            }
        };

        if !tool_use_still_possible || tool_use_is_done {
            // `get_delta` consumes the pending text; errors are routed to the
            // responder and abort this call via the macro.
            if let Some(delta) = crate::handle_seq_error_ok!(seq.get_delta(), seq.responder()) {
                if seq.get_mut_group().is_chat {
                    let (text_new, tool_calls) =
                        parse_text_tools(this, delta.as_str(), seq.tools.clone())
                            .map_err(candle_core::Error::msg)?;

                    // A parsed tool call ends the generation even if no stop
                    // condition fired naturally.
                    if !tool_calls.is_empty() && is_done.is_none() {
                        is_done = Some(StopReason::Eos);
                    };
                    seq.add_streaming_chunk_choice_to_group(crate::ChunkChoice {
                        delta: crate::Delta {
                            content: fixup_sentencepiece!(
                                Option text_new.map(ToString::to_string)
                            ),
                            role: "assistant".to_string(),
                            // Omit the field entirely when there are no calls.
                            tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
                        },
                        index: seq.get_response_index(),
                        finish_reason: is_done.map(|x| x.to_string()),
                        logprobs: if seq.return_logprobs() {
                            Some(crate::ResponseLogprob {
                                token: delta,
                                bytes: logprobs.bytes.clone().map(|b| b.into_bytes()),
                                logprob: logprobs.logprob,
                                // NOTE(review): assumes `top_logprobs` is always
                                // populated when logprobs were requested —
                                // would panic otherwise; confirm with sampler.
                                top_logprobs: logprobs.top_logprobs.unwrap().clone(),
                            })
                        } else {
                            None
                        },
                    });
                } else {
                    // Non-chat (plain completion) streaming chunk.
                    seq.add_streaming_completion_chunk_choice_to_group(
                        crate::CompletionChunkChoice {
                            text: fixup_sentencepiece!(delta),
                            index: seq.get_response_index(),
                            finish_reason: is_done.map(|x| x.to_string()),
                            logprobs: if seq.return_logprobs() {
                                Some(crate::ResponseLogprob {
                                    token: delta,
                                    bytes: logprobs.bytes.clone().map(|b| b.into_bytes()),
                                    logprob: logprobs.logprob,
                                    top_logprobs: logprobs.top_logprobs.unwrap().clone(),
                                })
                            } else {
                                None
                            },
                        },
                    );
                }

                // Finalize sequence state (and cache its prefix) before the
                // last chunk is pushed to the client.
                if let Some(reason) = is_done {
                    if use_prefix_cacher {
                        prefix_cacher.add_sequence(seq);
                        prefix_cacher.evict_to_cpu()?;
                    }
                    seq.set_state(crate::sequence::SequenceState::Done(reason));
                    this.reset_non_granular_state();
                }

                // Usage is only attached to the terminal chunk; counters are
                // zeroed afterwards so a reused group starts fresh.
                let usage_opt = if is_done.is_some() {
                    let usage = seq.get_mut_group().get_usage();
                    seq.get_mut_group().total_prompt_toks = 0;
                    seq.get_mut_group().total_toks = 0;
                    Some(usage)
                } else {
                    None
                };

                // A failed send means the client went away: mark canceled
                // instead of propagating the error.
                if seq
                    .get_mut_group()
                    .maybe_send_streaming_response(seq, this.name().clone(), usage_opt)
                    .await
                    .is_err()
                {
                    seq.set_state(crate::sequence::SequenceState::Done(
                        crate::sequence::StopReason::Canceled,
                    ));
                    this.reset_non_granular_state();
                }
            }
        }
    } else if let Some(reason) = is_done {
        // Non-streaming path: the sequence just finished — assemble the full
        // response in one go.
        {
            seq.set_state(crate::sequence::SequenceState::Done(reason));
            let (tokenizer, pipeline_name) = {
                let pipeline_name = this.name();
                let tokenizer = this.tokenizer();
                (tokenizer, pipeline_name)
            };

            // Re-decode every generated token to build the per-token logprob
            // report (only when the request asked for it).
            let logprobs = if seq.return_logprobs() {
                let mut logprobs = Vec::new();
                for logprob in seq.logprobs() {
                    let resp_logprob = crate::ResponseLogprob {
                        token: crate::handle_seq_error_ok!(
                            tokenizer
                                .as_ref()
                                .ok_or(candle_core::Error::Msg(
                                    "`finish_or_add_toks_to_seq` requires the pipeline to have a tokenizer"
                                        .to_string(),
                                ))?.decode(&[logprob.token], false),
                            seq.responder()
                        ),
                        bytes: logprob.bytes.clone().map(|b| b.into_bytes()),
                        logprob: logprob.logprob,
                        top_logprobs: logprob.top_logprobs.clone().unwrap(),
                    };
                    logprobs.push(resp_logprob);
                }
                Some(logprobs)
            } else {
                None
            };

            // Recover the completion text; for a stop-string match, truncate
            // at the byte position where the stop string began.
            let text = match reason {
                crate::sequence::StopReason::Length(_)
                | crate::sequence::StopReason::ModelLength(_)
                | crate::sequence::StopReason::Eos
                | crate::sequence::StopReason::StopTok(_)
                | crate::sequence::StopReason::Canceled => {
                    String::from_utf8_lossy(seq.completion_bytes())
                        .trim_start()
                        .to_string()
                }
                crate::sequence::StopReason::StopString {
                    completion_bytes_pos,
                    ..
                } => {
                    let txt = String::from_utf8_lossy(seq.completion_bytes());
                    txt[..completion_bytes_pos].trim_start().to_string()
                }
                crate::sequence::StopReason::GeneratedImage => {
                    candle_core::bail!("Stop reason was `GeneratedImage`.")
                }
            };

            if seq.get_mut_group().is_chat {
                // Split the final text into plain content vs. tool calls.
                let (text_new, tool_calls) =
                    parse_text_tools(this, text.as_str(), seq.tools.clone())
                        .map_err(candle_core::Error::msg)?;
                let choice = crate::Choice {
                    finish_reason: fixup_sentencepiece!(reason),
                    index: seq.get_response_index(),
                    message: crate::ResponseMessage {
                        content: text_new.map(ToString::to_string),
                        role: "assistant".to_string(),
                        tool_calls: Some(tool_calls).filter(|v| !v.is_empty()),
                    },
                    logprobs: logprobs.map(|l| crate::Logprobs { content: Some(l) }),
                };
                seq.add_choice_to_group(choice);
            } else {
                let choice = crate::CompletionChoice {
                    finish_reason: fixup_sentencepiece!(reason),
                    index: seq.get_response_index(),
                    text,
                    logprobs: None,
                };
                seq.add_completion_choice_to_group(choice);
            }

            if use_prefix_cacher {
                prefix_cacher.add_sequence(seq);
                prefix_cacher.evict_to_cpu()?;
            }

            // Send the done-response once every sequence in the group has
            // contributed its choice (the `maybe_` methods guard that).
            let group = seq.get_mut_group();
            if group.is_chat {
                group
                    .maybe_send_chat_done_response(
                        crate::ChatCompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name,
                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                            object: "chat.completion".to_string(),
                            usage: group.get_usage(),
                        },
                        seq.responder(),
                    )
                    .await
                    .map_err(candle_core::Error::msg)?;
            } else {
                group
                    .maybe_send_completion_done_response(
                        crate::CompletionResponse {
                            id: seq.id().to_string(),
                            choices: group.get_completion_choices().to_vec(),
                            created: seq.creation_time(),
                            model: pipeline_name,
                            system_fingerprint: crate::SYSTEM_FINGERPRINT.to_string(),
                            object: "text_completion".to_string(),
                            usage: group.get_usage(),
                        },
                        seq.responder(),
                    )
                    .await
                    .map_err(candle_core::Error::msg)?;
            }
        }
        this.reset_non_granular_state();
    }

    Ok(())
}
294
295pub async fn sample_and_add_toks(
296 this: &dyn Pipeline,
297 seqs: &mut [&mut Sequence],
298 logits_seq: Vec<Tensor>,
299 prefix_cacher: &mut PrefixCacheManagerV2,
300 disable_eos_stop: bool,
301 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
302) -> Result<()> {
303 let seqs_len = seqs.len();
304 debug_assert_eq!(logits_seq.len(), seqs_len);
305
306 let use_async_pool = seqs_len > 1;
307
308 let sampling_futures: Vec<_> = std::iter::zip(logits_seq, seqs.iter_mut())
309 .map(|(logits_per_seq, seq)| {
310 let return_logprobs = seq.return_logprobs();
311 sample_sequence(
312 logits_per_seq,
313 seq,
314 return_logprobs,
315 rng.clone(),
316 use_async_pool,
317 true, false,
319 )
320 })
321 .collect();
322 let sampled_vec = futures::future::join_all(sampling_futures).await;
323
324 for (sampled, seq) in std::iter::zip(sampled_vec, seqs.iter_mut()) {
325 let next_token = crate::handle_seq_error_stateaware_ok!(sampled, seq);
326
327 let metadata = this.get_metadata();
328 let eos_tok = if disable_eos_stop {
329 None
330 } else {
331 Some(&metadata.eos_tok[..])
332 };
333
334 finish_or_add_toks_to_seq(this, prefix_cacher, seq, next_token, eos_tok, true).await?;
335 }
336
337 Ok(())
338}
339
/// Sample the next token for `seq` from `logits`, optionally enforcing an
/// llguidance grammar constraint.
///
/// Flow: sample once; if the sequence has an `Llguidance` recognizer and the
/// sampled token is disallowed by the current mask (or generation must stop),
/// build an additive bias (`0.0` for allowed ids, `-inf` otherwise), add it to
/// the logits and sample a second time. When `add_to_trie` is set, the chosen
/// token is committed back to the recognizer.
///
/// * `use_async_pool` — run the sampler on the rayon pool instead of inline.
/// * `sample_speculative` — forwarded to the sampler; alters its strategy for
///   speculative decoding.
///
/// # Errors
/// Propagates sampler, tensor, and llguidance errors.
#[allow(clippy::too_many_arguments)]
pub async fn sample_sequence(
    logits: Tensor,
    seq: &mut Sequence,
    return_logprobs: bool,
    rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    use_async_pool: bool,
    add_to_trie: bool,
    sample_speculative: bool,
) -> Result<Logprobs> {
    // Drop the leading batch/seq dims and sample in f32.
    let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;

    // Clones are required because the closure below is moved onto the pool.
    let sampler = seq.sampler();
    let ctx_clone = seq.get_toks().to_vec();
    let rng_clone = rng.clone();
    let logits_clone = logits.clone();
    let first_lobprobs_response = if use_async_pool {
        tokio_rayon::spawn(move || {
            sampler.sample(
                logits_clone,
                &ctx_clone,
                return_logprobs,
                rng_clone,
                sample_speculative,
            )
        })
        .await?
    } else {
        sampler.sample(
            logits_clone,
            &ctx_clone,
            return_logprobs,
            rng_clone,
            sample_speculative,
        )?
    };

    // `Some(bias)` means the first sample violated the grammar and we must
    // re-sample with the bias added; `None` accepts the first sample.
    let bias_if_not_allowed = match &mut seq.recognizer {
        SequenceRecognizer::Llguidance(ref mut llg) => {
            let step_res = llg.compute_mask().map_err(candle_core::Error::msg)?;
            if let Some(mask) = &step_res.sample_mask {
                if mask.is_allowed(first_lobprobs_response.token) {
                    None
                } else {
                    // Start fully forbidden, then zero out the allowed ids.
                    let mut acc = vec![-f32::INFINITY; logits.shape().dims1().unwrap()];
                    mask.iter_set_entries(|idx| {
                        if idx < acc.len() {
                            acc[idx] = 0.0;
                        }
                    });

                    Some(acc)
                }
            } else if step_res.is_stop() {
                // Grammar finished: only EOS tokens remain legal.
                let mut acc = vec![-f32::INFINITY; logits.shape().dims1().unwrap()];
                for eos_tok in seq.eos_tokens() {
                    // NOTE(review): assumes every EOS id < vocab size — would
                    // panic on an out-of-range id; confirm against tokenizer.
                    acc[*eos_tok as usize] = 0.0;
                }
                Some(acc)
            } else {
                None
            }
        }
        SequenceRecognizer::None => None,
    };
    let second_logprobs_response = match bias_if_not_allowed {
        Some(acc) => {
            // Additive bias: -inf eliminates a token, 0.0 leaves it untouched.
            let new_logits = (logits + Tensor::from_slice(&acc, acc.len(), &Device::Cpu)?)?;

            let ctx_clone = seq.get_toks().to_vec();
            let rng_clone = rng.clone();
            let sampler = seq.sampler();
            if use_async_pool {
                tokio_rayon::spawn(move || {
                    sampler.sample(
                        new_logits,
                        &ctx_clone,
                        return_logprobs,
                        rng_clone,
                        sample_speculative,
                    )
                })
                .await?
            } else {
                sampler.sample(
                    new_logits,
                    &ctx_clone,
                    return_logprobs,
                    rng_clone,
                    sample_speculative,
                )?
            }
        }
        None => first_lobprobs_response,
    };

    // Advance the recognizer with the accepted token so the next step's mask
    // is computed from the right state.
    if add_to_trie {
        match seq.recognizer {
            SequenceRecognizer::Llguidance(ref mut llg) => {
                llg.commit_token(Some(second_logprobs_response.token))
                    .map_err(candle_core::Error::msg)?;
            }
            SequenceRecognizer::None => {}
        }
    }
    Ok(second_logprobs_response)
}
448
/// One token sampled during speculative decoding (see
/// `sample_target_sequence_speculative`), wrapping its [`Logprobs`].
#[derive(Clone)]
pub struct SpeculativeSample {
    // The sampled token and its logprob details.
    pub sample: Logprobs,
}
453
454pub async fn sample_target_sequence_speculative(
456 logits: Tensor,
457 seq: &mut Sequence,
458 return_logprobs: bool,
459 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
460 n_toks: usize,
461) -> Result<Vec<SpeculativeSample>> {
462 let mut sampled = Vec::new();
463 for chunk in logits.chunk(n_toks, 1)? {
464 sampled.push(SpeculativeSample {
465 sample: sample_sequence(
466 chunk,
467 seq,
468 return_logprobs,
469 rng.clone(),
470 true, false, true,
473 )
474 .await?,
475 });
476 }
477 Ok(sampled)
478}