1use openai_harmony::{
16 chat::Role, load_harmony_encoding, HarmonyEncoding, HarmonyEncodingName, StreamableParser,
17};
18use std::sync::OnceLock;
19use uuid::Uuid;
20
/// Derives the tool name to report for a Harmony message recipient.
///
/// A leading `functions.` namespace is removed once (`functions.get_weather`
/// becomes `get_weather`). Any recipient without that prefix is returned as-is.
fn extract_tool_name(recipient: &str) -> String {
    recipient
        .strip_prefix("functions.")
        .unwrap_or(recipient)
        .to_owned()
}
33
/// Output channel of a Harmony message.
///
/// The variants correspond to the channel names `analysis`, `commentary` and
/// `final` accepted by `HarmonyChannel::parse`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HarmonyChannel {
    /// The `analysis` channel; its text is reported as reasoning content.
    Analysis,
    /// The `commentary` channel; its text is also reported as reasoning content.
    Commentary,
    /// The `final` channel; its text is the final response content.
    Final,
}
44
45impl HarmonyChannel {
46 pub fn parse(s: &str) -> Option<Self> {
48 match s {
49 "analysis" => Some(Self::Analysis),
50 "commentary" => Some(Self::Commentary),
51 "final" => Some(Self::Final),
52 _ => None,
53 }
54 }
55}
56
57#[derive(Debug, Clone, Default)]
59pub struct HarmonyDelta {
60 pub analysis_delta: Option<String>,
62 pub commentary_delta: Option<String>,
64 pub final_delta: Option<String>,
66 pub current_channel: Option<HarmonyChannel>,
68}
69
70impl HarmonyDelta {
71 pub fn has_content(&self) -> bool {
73 self.analysis_delta.is_some()
74 || self.commentary_delta.is_some()
75 || self.final_delta.is_some()
76 }
77
78 pub fn reasoning_content(&self) -> Option<String> {
80 match (&self.analysis_delta, &self.commentary_delta) {
81 (Some(a), Some(c)) => Some(format!("{}{}", a, c)),
82 (Some(a), None) => Some(a.clone()),
83 (None, Some(c)) => Some(c.clone()),
84 (None, None) => None,
85 }
86 }
87}
88
/// Running totals of the text received so far, one buffer per channel.
///
/// All three buffers start out empty.
#[derive(Clone, Debug, Default)]
pub struct HarmonyAccumulated {
    /// Everything received on the analysis channel.
    pub analysis: String,
    /// Everything received on the commentary channel.
    pub commentary: String,
    /// Everything received as final response text.
    pub final_content: String,
}
99
/// A tool invocation extracted from the model output.
#[derive(Clone, Debug)]
pub struct HarmonyToolCall {
    /// Identifier of this call; elsewhere in this file it is generated as
    /// `call_<uuid-v4>`.
    pub id: String,
    /// Name of the tool being called.
    pub name: String,
    /// Argument text of the call, stored verbatim.
    /// NOTE(review): presumably a JSON document — confirm against the consumer.
    pub arguments: String,
}
110
111impl HarmonyAccumulated {
112 pub fn reasoning_content(&self) -> Option<String> {
114 let combined = format!("{}{}", self.analysis, self.commentary);
115 if combined.is_empty() {
116 None
117 } else {
118 Some(combined)
119 }
120 }
121}
122
123pub struct HarmonyContext {
128 parser: StreamableParser,
129 last_analysis_len: usize,
131 last_commentary_len: usize,
132 last_final_len: usize,
133 accumulated: HarmonyAccumulated,
135 channel: Option<HarmonyChannel>,
137 sent_reasoning_len: usize,
139 sent_final_len: usize,
140 tool_calls: Vec<HarmonyToolCall>,
142 current_tool_call: Option<(String, String)>,
144 sent_tool_args_len: usize,
146}
147
148impl HarmonyContext {
149 pub fn new() -> Result<Self, anyhow::Error> {
151 let encoding = get_harmony_encoding().clone();
152 let parser = StreamableParser::new(encoding, Some(Role::Assistant))
153 .map_err(|e| anyhow::anyhow!("Failed to create Harmony parser: {:?}", e))?;
154 Ok(Self {
155 parser,
156 last_analysis_len: 0,
157 last_commentary_len: 0,
158 last_final_len: 0,
159 accumulated: HarmonyAccumulated::default(),
160 channel: None,
161 sent_reasoning_len: 0,
162 sent_final_len: 0,
163 tool_calls: Vec::new(),
164 current_tool_call: None,
165 sent_tool_args_len: 0,
166 })
167 }
168
169 pub fn process_token(&mut self, token_id: u32) -> HarmonyDelta {
171 let _ = self.parser.process(token_id);
173 self.extract_delta()
174 }
175
176 fn extract_delta(&mut self) -> HarmonyDelta {
178 let mut delta = HarmonyDelta::default();
179
180 if let Some(channel_str) = self.parser.current_channel() {
182 if let Some(channel) = HarmonyChannel::parse(&channel_str) {
183 self.channel = Some(channel);
184 delta.current_channel = Some(channel);
185 }
186 }
187
188 let current_recipient = self.parser.current_recipient();
191
192 if let Ok(content) = self.parser.current_content() {
195 if let Some(ref recipient) = current_recipient {
201 let is_tool_call = recipient.starts_with("functions.")
202 || recipient.starts_with("browser.")
203 || recipient == "python";
204
205 if is_tool_call {
206 let is_same_tool_call = self
209 .current_tool_call
210 .as_ref()
211 .is_some_and(|(existing, _)| existing == recipient);
212
213 if is_same_tool_call {
214 if let Some((_, ref mut args)) = self.current_tool_call {
216 *args = content.clone();
217 }
218 } else {
219 if let Some((prev_recipient, prev_args)) = self.current_tool_call.take() {
222 let prev_name = extract_tool_name(&prev_recipient);
223 self.tool_calls.push(HarmonyToolCall {
224 id: format!("call_{}", Uuid::new_v4()),
225 name: prev_name,
226 arguments: prev_args,
227 });
228 }
229 self.current_tool_call = Some((recipient.clone(), content.clone()));
231 self.sent_tool_args_len = 0;
232 }
233 return delta;
235 }
236 }
237
238 match self.channel {
240 Some(HarmonyChannel::Analysis) => {
241 if content.len() > self.last_analysis_len {
242 let new_content = content[self.last_analysis_len..].to_string();
243 self.accumulated.analysis.push_str(&new_content);
244 delta.analysis_delta = Some(new_content);
245 self.last_analysis_len = content.len();
246 }
247 }
248 Some(HarmonyChannel::Commentary) => {
249 if content.len() > self.last_commentary_len {
250 let new_content = content[self.last_commentary_len..].to_string();
251 self.accumulated.commentary.push_str(&new_content);
252 delta.commentary_delta = Some(new_content);
253 self.last_commentary_len = content.len();
254 }
255 }
256 Some(HarmonyChannel::Final) | None => {
257 if content.len() > self.last_final_len {
261 let new_content = content[self.last_final_len..].to_string();
262 self.accumulated.final_content.push_str(&new_content);
263 delta.final_delta = Some(new_content);
264 self.last_final_len = content.len();
265 }
266 }
267 }
268 }
269
270 delta
271 }
272
273 pub fn current_channel(&self) -> Option<HarmonyChannel> {
275 self.channel
276 }
277
278 pub fn accumulated(&self) -> &HarmonyAccumulated {
280 &self.accumulated
281 }
282
283 pub fn reasoning_content(&self) -> Option<String> {
285 self.accumulated.reasoning_content()
286 }
287
288 pub fn final_content(&self) -> Option<String> {
290 if self.accumulated.final_content.is_empty() {
291 None
292 } else {
293 Some(self.accumulated.final_content.clone())
294 }
295 }
296
297 pub fn get_reasoning_delta(&mut self) -> Option<String> {
300 let reasoning = format!(
301 "{}{}",
302 self.accumulated.analysis, self.accumulated.commentary
303 );
304 if reasoning.len() > self.sent_reasoning_len {
305 let delta = reasoning[self.sent_reasoning_len..].to_string();
306 self.sent_reasoning_len = reasoning.len();
307 if delta.is_empty() {
308 None
309 } else {
310 Some(delta)
311 }
312 } else {
313 None
314 }
315 }
316
317 pub fn get_final_delta(&mut self) -> Option<String> {
320 if self.accumulated.final_content.len() > self.sent_final_len {
321 let delta = self.accumulated.final_content[self.sent_final_len..].to_string();
322 self.sent_final_len = self.accumulated.final_content.len();
323 if delta.is_empty() {
324 None
325 } else {
326 Some(delta)
327 }
328 } else {
329 None
330 }
331 }
332
333 pub fn process_eos(&mut self) {
335 let _ = self.parser.process_eos();
336
337 if let Some((recipient, args)) = self.current_tool_call.take() {
339 let name = extract_tool_name(&recipient);
340 self.tool_calls.push(HarmonyToolCall {
341 id: format!("call_{}", Uuid::new_v4()),
342 name,
343 arguments: args,
344 });
345 }
346 }
347
348 pub fn current_recipient(&self) -> Option<String> {
350 self.parser.current_recipient()
351 }
352
353 pub fn has_tool_call(&self) -> bool {
355 self.current_tool_call.is_some() || !self.tool_calls.is_empty()
356 }
357
358 pub fn get_tool_calls(&self) -> &[HarmonyToolCall] {
360 &self.tool_calls
361 }
362
363 pub fn get_current_tool_call(&self) -> Option<(&str, &str)> {
366 self.current_tool_call
367 .as_ref()
368 .map(|(recipient, args)| (recipient.as_str(), args.as_str()))
369 }
370
371 pub fn finalize_tool_calls(&mut self) -> Vec<HarmonyToolCall> {
376 if let Some((recipient, args)) = self.current_tool_call.take() {
378 let name = extract_tool_name(&recipient);
379 self.tool_calls.push(HarmonyToolCall {
380 id: format!("call_{}", Uuid::new_v4()),
381 name,
382 arguments: args,
383 });
384 }
385 std::mem::take(&mut self.tool_calls)
387 }
388}
389
390static HARMONY_ENCODING: OnceLock<HarmonyEncoding> = OnceLock::new();
392
393pub fn prewarm_harmony_encoding() {
398 let _ = HARMONY_ENCODING.get_or_init(|| {
399 load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss)
400 .expect("Failed to load Harmony encoding")
401 });
402}
403
404pub fn is_harmony_encoding_ready() -> bool {
406 HARMONY_ENCODING.get().is_some()
407}
408
409fn get_harmony_encoding() -> &'static HarmonyEncoding {
410 HARMONY_ENCODING
411 .get()
412 .expect("Harmony encoding not initialized. Call prewarm_harmony_encoding() first.")
413}
414
/// Decides from its marker tokens whether a chat template is in Harmony format.
///
/// A template qualifies when it contains `<|channel|>`, or when it contains all
/// of `<|start|>`, `<|message|>` and `<|end|>`.
pub fn is_harmony_template(template: &str) -> bool {
    const FRAMING_MARKERS: [&str; 3] = ["<|start|>", "<|message|>", "<|end|>"];

    template.contains("<|channel|>")
        || FRAMING_MARKERS
            .iter()
            .all(|&marker| template.contains(marker))
}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an owned, present delta string.
    fn text(value: &str) -> Option<String> {
        Some(value.to_string())
    }

    /// Harmony templates are recognised; other chat formats are not.
    #[test]
    fn test_is_harmony_template() {
        let harmony = [
            "<|start|>system<|message|>content<|end|>",
            "some prefix <|channel|>analysis<|message|>thinking",
        ];
        for template in harmony.iter() {
            assert!(is_harmony_template(template));
        }

        let other = ["<|im_start|>system<|im_end|>", "regular chat template"];
        for template in other.iter() {
            assert!(!is_harmony_template(template));
        }
    }

    /// Each known channel name maps to its variant; anything else is rejected.
    #[test]
    fn test_harmony_channel_from_str() {
        let cases = [
            ("analysis", Some(HarmonyChannel::Analysis)),
            ("commentary", Some(HarmonyChannel::Commentary)),
            ("final", Some(HarmonyChannel::Final)),
            ("unknown", None),
        ];
        for (name, expected) in cases.iter() {
            assert_eq!(HarmonyChannel::parse(name), *expected);
        }
    }

    /// A delta has content as soon as one text field is set.
    #[test]
    fn test_harmony_delta_has_content() {
        assert!(!HarmonyDelta::default().has_content());

        let with_analysis = HarmonyDelta {
            analysis_delta: text("thinking"),
            ..Default::default()
        };
        assert!(with_analysis.has_content());

        let with_final = HarmonyDelta {
            final_delta: text("response"),
            ..Default::default()
        };
        assert!(with_final.has_content());
    }

    /// Reasoning content is analysis followed by commentary, or `None`.
    #[test]
    fn test_harmony_delta_reasoning_content() {
        let both = HarmonyDelta {
            analysis_delta: text("thinking "),
            commentary_delta: text("about tools"),
            ..Default::default()
        };
        assert_eq!(both.reasoning_content(), text("thinking about tools"));

        let only_analysis = HarmonyDelta {
            analysis_delta: text("just thinking"),
            ..Default::default()
        };
        assert_eq!(only_analysis.reasoning_content(), text("just thinking"));

        assert_eq!(HarmonyDelta::default().reasoning_content(), None);
    }
}