1use std::collections::HashMap;
2use std::sync::Arc;
3
4pub mod rag;
5
6use anyhow::Result;
7use html2text::{config, render::PlainDecorator};
8use rayon::prelude::*;
9use scraper::{Html, Selector};
10use serde::{Deserialize, Serialize};
11use serde_json::{json, Value};
12use std::env::consts::{ARCH, FAMILY, OS};
13use tokenizers::Tokenizer;
14
15use crate::{Function, Tool, ToolType, WebSearchOptions, WebSearchUserLocation};
16
17pub type SearchCallback =
20 dyn Fn(&SearchFunctionParameters) -> Result<Vec<SearchResult>> + Send + Sync;
21
/// Returns `true` when `name` is one of the two built-in web tools
/// ([`SEARCH_TOOL_NAME`] or [`EXTRACT_TOOL_NAME`]). The comparison is exact and case-sensitive.
pub(crate) fn search_tool_called(name: &str) -> bool {
    [SEARCH_TOOL_NAME, EXTRACT_TOOL_NAME].contains(&name)
}

/// Name under which the web search tool is advertised to the model.
pub(crate) const SEARCH_TOOL_NAME: &str = "search_the_web";
/// Name under which the website content extraction tool is advertised to the model.
pub(crate) const EXTRACT_TOOL_NAME: &str = "website_content_extractor";
28
29const APP_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Default description of the web search tool shown to the model.
///
/// `get_search_tools` uses this text unless `WebSearchOptions::search_description` is set.
/// The sample output below is valid JSON (no trailing commas), so the model is not shown a
/// malformed example. The input rules state one requirement only: a query, never a URL.
pub(crate) const SEARCH_DESCRIPTION: &str = r#"This tool is used to search the web given a query.
If the user wants up-to-date information or you want to retrieve new information, call this tool.
If you call this tool, then you MUST complete your answer using the output.
The input must be a search query, not a URL. To read the content of a specific URL, use the website_content_extractor tool instead.
Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.

You should expect output like this:
{
  "output": [
    {
      "title": "...",
      "description": "...",
      "url": "...",
      "content": "..."
    },
    ...
  ]
}
"#;
/// Default description of the website content extraction tool shown to the model.
///
/// `get_search_tools` uses this text unless `WebSearchOptions::extract_description` is set.
pub(crate) const EXTRACT_DESCRIPTION: &str = r#"This tool is used to extract the content of a website.
If the user wants information about a specific site or you want to extract the content of a specific site, call this tool.
The input must be a URL.
Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.

You should expect output like this:
{
  "output": [
    {
      "url": "...",
      "content": "..."
    },
    ...
  ]
}
"#;
65
66#[derive(Debug, Serialize, Deserialize, Default)]
67pub struct SearchResult {
68 pub title: String,
69 pub description: String,
70 pub url: String,
71 pub content: String,
72}
73
74#[derive(Debug, Serialize, Deserialize, Default)]
75pub struct ExtractResult {
76 pub url: String,
77 pub content: String,
78}
79
80impl SearchResult {
81 pub fn cap_content_len(self, tokenizer: &Tokenizer, size: usize) -> Result<Self> {
82 let tokenized_content = tokenizer
83 .encode_fast(self.content, false)
84 .map_err(anyhow::Error::msg)?;
85 let ids = tokenized_content.get_ids();
86 let content = tokenizer
87 .decode(&ids[..size.min(ids.len())], false)
88 .map_err(anyhow::Error::msg)?;
89
90 Ok(Self {
91 title: self.title,
92 description: self.description,
93 url: self.url,
94 content,
95 })
96 }
97}
98
99impl ExtractResult {
100 pub fn cap_content_len(self, tokenizer: &Tokenizer, size: usize) -> Result<Self> {
101 let tokenized_content = tokenizer
102 .encode_fast(self.content, false)
103 .map_err(anyhow::Error::msg)?;
104 let ids = tokenized_content.get_ids();
105 let content = tokenizer
106 .decode(&ids[..size.min(ids.len())], false)
107 .map_err(anyhow::Error::msg)?;
108
109 Ok(Self {
110 url: self.url,
111 content,
112 })
113 }
114}
115
116#[derive(Debug, Serialize, Deserialize)]
117pub struct SearchFunctionParameters {
118 pub query: String,
119}
120
121#[derive(Debug, Serialize, Deserialize)]
122pub struct ExtractFunctionParameters {
123 pub url: String,
124}
125
126pub fn get_search_tools(web_search_options: &WebSearchOptions) -> Result<Vec<Tool>> {
127 let search_tool = {
128 let parameters: HashMap<String, Value> = serde_json::from_value(json!({
129 "type": "object",
130 "properties": {
131 "query": {
132 "type": "string",
133 "description": "A query for web searching.",
134 },
135 },
136 "required": ["query"],
137 }))?;
138
139 let location_details = match &web_search_options.user_location {
140 Some(WebSearchUserLocation::Approximate { approximate }) => {
141 format!(
142 "\nThe user's location is: {}, {}, {}, {}.",
143 approximate.city, approximate.region, approximate.country, approximate.timezone
144 )
145 }
146 None => "".to_string(),
147 };
148 let description = web_search_options
149 .search_description
150 .as_deref()
151 .unwrap_or(SEARCH_DESCRIPTION);
152 Tool {
153 tp: ToolType::Function,
154 function: Function {
155 description: Some(format!("{}{}", description, location_details)),
156 name: SEARCH_TOOL_NAME.to_string(),
157 parameters: Some(parameters),
158 },
159 }
160 };
161
162 let extract_tool = {
163 let parameters: HashMap<String, Value> = serde_json::from_value(json!({
164 "type": "object",
165 "properties": {
166 "url": {
167 "type": "string",
168 "description": "A URL to extract the content of the website from.",
169 },
170 },
171 "required": ["url"],
172 }))?;
173
174 let description = web_search_options
175 .extract_description
176 .as_deref()
177 .unwrap_or(EXTRACT_DESCRIPTION);
178 Tool {
179 tp: ToolType::Function,
180 function: Function {
181 description: Some(description.to_string()),
182 name: EXTRACT_TOOL_NAME.to_string(),
183 parameters: Some(parameters),
184 },
185 }
186 };
187
188 Ok(vec![search_tool, extract_tool])
189}
190
191pub fn run_search_tool(params: &SearchFunctionParameters) -> Result<Vec<SearchResult>> {
192 let client = reqwest::blocking::Client::new();
193
194 let encoded_query = urlencoding::encode(¶ms.query);
195 let url = format!("https://html.duckduckgo.com/html/?q={}", encoded_query);
196
197 let user_agent = format!("mistralrs/{APP_VERSION} ({OS}; {ARCH}; {FAMILY})");
198 let response = client.get(&url).header("User-Agent", &user_agent).send()?;
199
200 if !response.status().is_success() {
202 anyhow::bail!("Failed to fetch search results: {}", response.status())
203 }
204
205 let html = response.text()?;
206
207 let document = Html::parse_document(&html);
208
209 let result_selector = Selector::parse(".result").unwrap();
210 let title_selector = Selector::parse(".result__title").unwrap();
211 let snippet_selector = Selector::parse(".result__snippet").unwrap();
212 let url_selector = Selector::parse(".result__url").unwrap();
213
214 let partials: Vec<(String, String, String)> = document
216 .select(&result_selector)
217 .filter_map(|element| {
218 let title = element
219 .select(&title_selector)
220 .next()
221 .map(|e| e.text().collect::<String>().trim().to_string())
222 .unwrap_or_default();
223 let description = element
224 .select(&snippet_selector)
225 .next()
226 .map(|e| e.text().collect::<String>().trim().to_string())
227 .unwrap_or_default();
228 let mut url = element
229 .select(&url_selector)
230 .next()
231 .map(|e| e.text().collect::<String>().trim().to_string())
232 .unwrap_or_default();
233 if title.is_empty() || description.is_empty() || url.is_empty() {
234 return None;
235 }
236 if !url.starts_with("http") {
237 url = format!("https://{}", url);
238 }
239 Some((title, description, url))
240 })
241 .collect();
242
243 let client = Arc::new(client);
245 let results: Vec<SearchResult> = partials
246 .into_par_iter()
247 .filter_map(|(title, description, url)| {
248 let content = match client.get(&url).header("User-Agent", &user_agent).send() {
249 Ok(response) => {
250 let html = response.text().ok()?;
251 config::with_decorator(PlainDecorator::new())
252 .do_decorate()
253 .string_from_read(html.as_bytes(), 80)
254 .ok()?
255 }
256 Err(_) => return None,
257 };
258 Some(SearchResult {
259 title,
260 description,
261 url,
262 content,
263 })
264 })
265 .collect();
266
267 Ok(results)
268}
269
270pub fn run_extract_tool(params: &ExtractFunctionParameters) -> Result<ExtractResult> {
271 let client = reqwest::blocking::Client::new();
272
273 let user_agent = format!("mistralrs/{APP_VERSION} ({OS}; {ARCH}; {FAMILY})");
274
275 let content = match client
276 .get(¶ms.url)
277 .header("User-Agent", &user_agent)
278 .send()
279 {
280 Ok(response) => response.text().ok().and_then(|html| {
281 config::with_decorator(PlainDecorator::new())
282 .do_decorate()
283 .string_from_read(html.as_bytes(), 80)
284 .ok()
285 }),
286 Err(_) => None,
287 };
288 Ok(ExtractResult {
289 url: params.url.clone(),
290 content: content.unwrap_or("ERROR: failed to extract content".to_string()),
291 })
292}