mistralrs_core/search/
mod.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4pub mod rag;
5
6use anyhow::Result;
7use html2text::{config, render::PlainDecorator};
8use rayon::prelude::*;
9use scraper::{Html, Selector};
10use serde::{Deserialize, Serialize};
11use serde_json::{json, Value};
12use std::env::consts::{ARCH, FAMILY, OS};
13use tokenizers::Tokenizer;
14
15use crate::{Function, Tool, ToolType, WebSearchOptions, WebSearchUserLocation};
16
/// Callback used to override how search results are gathered. The returned
/// vector must be sorted in decreasing order of relevance.
///
/// Implementations receive the model-supplied [`SearchFunctionParameters`]
/// (the raw query string) and may fail with any [`anyhow::Error`].
pub type SearchCallback =
    dyn Fn(&SearchFunctionParameters) -> Result<Vec<SearchResult>> + Send + Sync;
21
22pub(crate) fn search_tool_called(name: &str) -> bool {
23    name == SEARCH_TOOL_NAME || name == EXTRACT_TOOL_NAME
24}
25
26pub(crate) const SEARCH_TOOL_NAME: &str = "search_the_web";
27pub(crate) const EXTRACT_TOOL_NAME: &str = "website_content_extractor";
28
29const APP_VERSION: &str = env!("CARGO_PKG_VERSION");
30pub(crate) const SEARCH_DESCRIPTION: &str = r#"This tool is used to search the web given a query.
31If the user wants up-to-date information or you want to retrieve new information, call this tool.
32If you call this tool, then you MUST complete your answer using the output.
33The input can be a query. It should not be a URL. Either is fine.
34Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.
35
36You should expect output like this:
37{
38    "output": [
39        {
40            "title": "...",
41            "description": "...",
42            "url": "...",
43            "content": "...",
44        },
45        ...
46    ]
47}
48"#;
49pub(crate) const EXTRACT_DESCRIPTION: &str = r#"This tool is used to extract the content of a website.
50If the user wants information about a specific site or you want to extract the content of a specific site, call this tool.
51The input must be a URL.
52Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.
53
54You should expect output like this:
55{
56    "output": [
57        {
58            "url": "...",
59            "content": "...",
60        },
61        ...
62    ]
63}
64"#;
65
/// One web-search hit: metadata scraped from the results page plus the
/// linked page's body rendered to plain text.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct SearchResult {
    /// Result title from the search results page.
    pub title: String,
    /// Snippet/description from the search results page.
    pub description: String,
    /// Absolute URL of the result page.
    pub url: String,
    /// Page body rendered to plain text.
    pub content: String,
}
73
/// Output of the website-content-extraction tool: the requested URL and its
/// body rendered to plain text.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct ExtractResult {
    /// URL the content was extracted from.
    pub url: String,
    /// Page body rendered to plain text.
    pub content: String,
}
79
80impl SearchResult {
81    pub fn cap_content_len(self, tokenizer: &Tokenizer, size: usize) -> Result<Self> {
82        let tokenized_content = tokenizer
83            .encode_fast(self.content, false)
84            .map_err(anyhow::Error::msg)?;
85        let ids = tokenized_content.get_ids();
86        let content = tokenizer
87            .decode(&ids[..size.min(ids.len())], false)
88            .map_err(anyhow::Error::msg)?;
89
90        Ok(Self {
91            title: self.title,
92            description: self.description,
93            url: self.url,
94            content,
95        })
96    }
97}
98
99impl ExtractResult {
100    pub fn cap_content_len(self, tokenizer: &Tokenizer, size: usize) -> Result<Self> {
101        let tokenized_content = tokenizer
102            .encode_fast(self.content, false)
103            .map_err(anyhow::Error::msg)?;
104        let ids = tokenized_content.get_ids();
105        let content = tokenizer
106            .decode(&ids[..size.min(ids.len())], false)
107            .map_err(anyhow::Error::msg)?;
108
109        Ok(Self {
110            url: self.url,
111            content,
112        })
113    }
114}
115
/// Arguments the model supplies when calling the web-search tool.
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchFunctionParameters {
    /// Free-text search query (not a URL).
    pub query: String,
}
120
/// Arguments the model supplies when calling the content-extraction tool.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExtractFunctionParameters {
    /// URL of the website to extract content from.
    pub url: String,
}
125
126pub fn get_search_tools(web_search_options: &WebSearchOptions) -> Result<Vec<Tool>> {
127    let search_tool = {
128        let parameters: HashMap<String, Value> = serde_json::from_value(json!({
129            "type": "object",
130            "properties": {
131                "query": {
132                    "type": "string",
133                    "description": "A query for web searching.",
134                },
135            },
136            "required": ["query"],
137        }))?;
138
139        let location_details = match &web_search_options.user_location {
140            Some(WebSearchUserLocation::Approximate { approximate }) => {
141                format!(
142                    "\nThe user's location is: {}, {}, {}, {}.",
143                    approximate.city, approximate.region, approximate.country, approximate.timezone
144                )
145            }
146            None => "".to_string(),
147        };
148        let description = web_search_options
149            .search_description
150            .as_deref()
151            .unwrap_or(SEARCH_DESCRIPTION);
152        Tool {
153            tp: ToolType::Function,
154            function: Function {
155                description: Some(format!("{}{}", description, location_details)),
156                name: SEARCH_TOOL_NAME.to_string(),
157                parameters: Some(parameters),
158            },
159        }
160    };
161
162    let extract_tool = {
163        let parameters: HashMap<String, Value> = serde_json::from_value(json!({
164            "type": "object",
165            "properties": {
166                "url": {
167                    "type": "string",
168                    "description": "A URL to extract the content of the website from.",
169                },
170            },
171            "required": ["url"],
172        }))?;
173
174        let description = web_search_options
175            .extract_description
176            .as_deref()
177            .unwrap_or(EXTRACT_DESCRIPTION);
178        Tool {
179            tp: ToolType::Function,
180            function: Function {
181                description: Some(description.to_string()),
182                name: EXTRACT_TOOL_NAME.to_string(),
183                parameters: Some(parameters),
184            },
185        }
186    };
187
188    Ok(vec![search_tool, extract_tool])
189}
190
191pub fn run_search_tool(params: &SearchFunctionParameters) -> Result<Vec<SearchResult>> {
192    let client = reqwest::blocking::Client::new();
193
194    let encoded_query = urlencoding::encode(&params.query);
195    let url = format!("https://html.duckduckgo.com/html/?q={}", encoded_query);
196
197    let user_agent = format!("mistralrs/{APP_VERSION} ({OS}; {ARCH}; {FAMILY})");
198    let response = client.get(&url).header("User-Agent", &user_agent).send()?;
199
200    // Check the response status
201    if !response.status().is_success() {
202        anyhow::bail!("Failed to fetch search results: {}", response.status())
203    }
204
205    let html = response.text()?;
206
207    let document = Html::parse_document(&html);
208
209    let result_selector = Selector::parse(".result").unwrap();
210    let title_selector = Selector::parse(".result__title").unwrap();
211    let snippet_selector = Selector::parse(".result__snippet").unwrap();
212    let url_selector = Selector::parse(".result__url").unwrap();
213
214    // Phase 1: collect title, description, and url serially into a Vec of tuples
215    let partials: Vec<(String, String, String)> = document
216        .select(&result_selector)
217        .filter_map(|element| {
218            let title = element
219                .select(&title_selector)
220                .next()
221                .map(|e| e.text().collect::<String>().trim().to_string())
222                .unwrap_or_default();
223            let description = element
224                .select(&snippet_selector)
225                .next()
226                .map(|e| e.text().collect::<String>().trim().to_string())
227                .unwrap_or_default();
228            let mut url = element
229                .select(&url_selector)
230                .next()
231                .map(|e| e.text().collect::<String>().trim().to_string())
232                .unwrap_or_default();
233            if title.is_empty() || description.is_empty() || url.is_empty() {
234                return None;
235            }
236            if !url.starts_with("http") {
237                url = format!("https://{}", url);
238            }
239            Some((title, description, url))
240        })
241        .collect();
242
243    // Phase 2: fetch content in parallel using Rayon
244    let client = Arc::new(client);
245    let results: Vec<SearchResult> = partials
246        .into_par_iter()
247        .filter_map(|(title, description, url)| {
248            let content = match client.get(&url).header("User-Agent", &user_agent).send() {
249                Ok(response) => {
250                    let html = response.text().ok()?;
251                    config::with_decorator(PlainDecorator::new())
252                        .do_decorate()
253                        .string_from_read(html.as_bytes(), 80)
254                        .ok()?
255                }
256                Err(_) => return None,
257            };
258            Some(SearchResult {
259                title,
260                description,
261                url,
262                content,
263            })
264        })
265        .collect();
266
267    Ok(results)
268}
269
270pub fn run_extract_tool(params: &ExtractFunctionParameters) -> Result<ExtractResult> {
271    let client = reqwest::blocking::Client::new();
272
273    let user_agent = format!("mistralrs/{APP_VERSION} ({OS}; {ARCH}; {FAMILY})");
274
275    let content = match client
276        .get(&params.url)
277        .header("User-Agent", &user_agent)
278        .send()
279    {
280        Ok(response) => response.text().ok().and_then(|html| {
281            config::with_decorator(PlainDecorator::new())
282                .do_decorate()
283                .string_from_read(html.as_bytes(), 80)
284                .ok()
285        }),
286        Err(_) => None,
287    };
288    Ok(ExtractResult {
289        url: params.url.clone(),
290        content: content.unwrap_or("ERROR: failed to extract content".to_string()),
291    })
292}