use candle_core::Device;
use clap::Parser;
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, initialize_logging, paged_attn_supported,
    parse_isq_value, Constraint, DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata,
    DeviceMapSetting, DrySamplingParams, Loader, LoaderBuilder, MemoryGpuConfig, MistralRs,
    MistralRsBuilder, ModelSelected, NormalRequest, PagedAttentionConfig, PagedCacheType, Request,
    RequestMessage, Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
};
use std::fmt::Display;
use std::sync::Arc;
use tokio::sync::mpsc::channel;
use tracing::info;
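
/// Which benchmark a result belongs to: prompt processing over `n` tokens
/// ("pp n") or generation of `n` tokens ("tg n"), following the familiar
/// llama.cpp bench naming.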
enum TestName {
    Prompt(usize),
    Gen(usize),
}

impl Display for TestName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TestName::Prompt(n) => format!("pp {n}"),
            TestName::Gen(n) => format!("tg {n}"),
        };
        write!(f, "{name}")
    }
}
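
/// Raw per-request [`Usage`] stats gathered for one test at one concurrency level.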
struct BenchResult {
    usages: Vec<Usage>,
    concurrency: usize,
    test_name: TestName,
}
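
/// A throughput measurement reported as mean ± standard deviation.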
struct UncertainTokSec {
    mean: f32,
    std_dev: f32,
}

impl Display for UncertainTokSec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:.3}±{:.3}", self.mean, self.std_dev)
    }
}
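
/// Sends `concurrency` copies of `prompt` per repetition, drains the same
/// number of responses, and collects each response's [`Usage`] for later
/// aggregation by [`get_tok_s`] and [`get_ms_tok`].
///
/// A sketch of a call site (mirrors the `tg` benchmark in `main`; not
/// compiled as a doctest):
///
/// ```ignore
/// let result = run_bench(
///     mistralrs.clone(),
///     RequestMessage::Completion {
///         text: "Rust".to_string(),
///         echo_prompt: false,
///         best_of: None,
///     },
///     128,                // n_gen: tokens to generate per request
///     1,                  // concurrency
///     5,                  // repetitions
///     TestName::Gen(128),
/// )
/// .await?;
/// ```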
async fn run_bench(
    mistralrs: Arc<MistralRs>,
    prompt: RequestMessage,
    n_gen: usize,
    concurrency: usize,
    repetitions: usize,
    test_name: TestName,
) -> anyhow::Result<BenchResult> {
    let sampling_params = SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        repetition_penalty: None,
        max_len: Some(n_gen),
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    };
    let sender = mistralrs.get_sender(None).unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(Box::new(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: prompt,
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: None,
    }));

    let mut usages = Vec::new();

    for _ in 0..repetitions {
        // Enqueue `concurrency` identical requests, then drain the same number
        // of responses, so each repetition measures a fully loaded scheduler.
        for _ in 0..concurrency {
            if sender.send(req.clone()).await.is_err() {
                eprintln!("Receiver disconnected");
            }
        }
        for _ in 0..concurrency {
            match rx.recv().await {
                Some(r) => match r {
                    Response::InternalError(e) => {
                        unreachable!("Got an internal error: {e:?}");
                    }
                    Response::ModelError(e, resp) => {
                        unreachable!("Got a model error: {e:?}, response: {resp:?}");
                    }
                    Response::ValidationError(e) => {
                        unreachable!("Got a validation error: {e:?}");
                    }
                    Response::Done(res) => {
                        usages.push(res.usage);
                    }
                    Response::Chunk(_) => unreachable!(),
                    Response::CompletionModelError(_, _) => unreachable!(),
                    Response::CompletionDone(res) => {
                        usages.push(res.usage);
                    }
                    Response::CompletionChunk(_) => unreachable!(),
                    Response::ImageGeneration(_) => unreachable!(),
                    Response::Speech { .. } => unreachable!(),
                    Response::Raw { .. } => unreachable!(),
                },
                None => unreachable!("Expected a Done response, got None"),
            }
        }
    }

    Ok(BenchResult {
        usages,
        concurrency,
        test_name,
    })
}
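
/// Mean ± population standard deviation (sqrt(Σ(x − mean)² / n)) of
/// tokens/sec across all usages: prompt throughput for `pp` tests,
/// completion throughput for `tg` tests.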
fn get_tok_s(result: &BenchResult) -> UncertainTokSec {
    let ts_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    let mean = ts_measurements.iter().sum::<f32>() / ts_measurements.len() as f32;
    let variance = ts_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ts_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}
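
/// Same aggregation as [`get_tok_s`], but over milliseconds per token
/// (1000 / tokens-per-second).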
fn get_ms_tok(result: &BenchResult) -> UncertainTokSec {
    let ms_tok_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    let mean = ms_tok_measurements.iter().sum::<f32>() / ms_tok_measurements.len() as f32;
    let variance = ms_tok_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ms_tok_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}
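
/// Prints one results table (model, backend, test, t/s, ms/t, concurrency,
/// and aggregate throughput = mean t/s × concurrency).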
fn print_usage(model: &str, device: &Device, results: Vec<BenchResult>) {
    let backend = match device {
        Device::Cpu => "CPU",
        Device::Cuda(_) => "CUDA",
        Device::Metal(_) => "Metal",
    };
    let results: Vec<Vec<CellStruct>> = results
        .into_iter()
        .map(|r| {
            vec![
                model.cell(),
                backend.cell(),
                r.test_name.to_string().cell(),
                get_tok_s(&r).cell().justify(Justify::Right),
                get_ms_tok(&r).cell().justify(Justify::Right),
                r.concurrency.cell().justify(Justify::Right),
                (get_tok_s(&r).mean * r.concurrency as f32)
                    .cell()
                    .justify(Justify::Right),
            ]
        })
        .collect();

    let table = results
        .table()
        .title(vec![
            "model".cell().bold(true),
            "backend".cell().bold(true),
            "test".cell().bold(true),
            "t/s".cell().bold(true),
            "ms/t".cell().bold(true),
            "concurrency".cell().bold(true),
            "throughput/s".cell().bold(true),
        ])
        .bold(true);
    print_stdout(table).expect("print table");
}
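
/// Runs a single one-token request before benchmarking so that one-time
/// startup costs are paid outside the measured runs.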
async fn warmup_run(mistralrs: Arc<MistralRs>) {
    let sampling_params = SamplingParams {
        max_len: Some(1),
        ..SamplingParams::deterministic()
    };
    let sender = mistralrs.get_sender(None).unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(Box::new(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: RequestMessage::Completion {
            text: "Hello!".to_string(),
            echo_prompt: false,
            best_of: None,
        },
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: None,
    }));

    if sender.send(req.clone()).await.is_err() {
        eprintln!("Receiver disconnected");
    }

    // Block until the warmup response arrives; its contents are irrelevant.
    let _ = rx.recv().await;
}
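
/// Adapter so clap's `value_parser` can use `PagedCacheType`'s `FromStr` impl.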
fn parse_cache_type(s: &str) -> Result<PagedCacheType, String> {
    s.parse()
}
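
/// Command-line arguments for the benchmark binary.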
#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// Model selector (subcommand).
    #[clap(subcommand)]
    model: ModelSelected,

    /// Integer seed to ensure reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Number of prompt tokens to run.
    #[arg(long, short = 'p', default_value_t = 512)]
    n_prompt: usize,

    /// Number of generation tokens to run.
    #[arg(long, short = 'g', default_value_t = 128)]
    n_gen: usize,

    /// Comma-separated list of concurrency levels to benchmark at. Defaults to 1.
    #[clap(short, long, value_parser, value_delimiter = ',')]
    concurrency: Option<Vec<usize>>,

    /// Number of times to repeat each test.
    #[arg(long, short, default_value_t = 5)]
    repetitions: usize,

    /// Number of layers to place on each device: either a single count for
    /// device 0, or a ';'-separated list of ORD:NUM pairs.
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    #[arg(long = "isq")]
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for the PagedAttention KV cache, in MBs.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Fraction of GPU memory to use for the PagedAttention KV cache (0..1).
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Context length to size the PagedAttention KV cache for.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// PagedAttention KV cache data type.
    #[arg(long = "pa-cache-type", value_parser = parse_cache_type)]
    cache_type: Option<PagedCacheType>,

    /// PagedAttention block size.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention (enabled by default on CUDA).
    #[arg(long = "no-paged-attn", default_value_t = false)]
    no_paged_attn: bool,

    /// Enable PagedAttention (opt-in on Metal).
    #[arg(long = "paged-attn", default_value_t = false)]
    paged_attn: bool,
}
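
/// Flow: parse args, build the loader, pick a device, resolve device mapping
/// and the PagedAttention cache config, load the model, then run a warmup
/// followed by the tg/pp benchmarks at each requested concurrency level.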
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let mut args = Args::parse();
    initialize_logging();

    args.concurrency = Some(args.concurrency.unwrap_or(vec![1]));

    let dtype = get_model_dtype(&args.model)?;
    let auto_device_map_params = get_auto_device_map_params(&args.model)?;

    let max_seq_len = auto_device_map_params.max_seq_len();

    let loader: Box<dyn Loader> = LoaderBuilder::new(args.model).build()?;
    let model_name = loader.get_id();

    // Prefer Metal when compiled with it; otherwise use CUDA if available.
    // NCCL-based distributed runs fall back to CPU here, since per-rank
    // devices are handled by the distributed backend.
    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = args.seed {
        device.set_seed(seed)?;
    }

    let token_source = TokenSource::CacheToken;
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );
    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    info!("Model kind is: {}", loader.get_kind());

    // Either honor an explicit per-device layer mapping (`ORD:NUM;...`, or a
    // bare layer count for device 0), or fall back to automatic mapping.
    let mapper = if let Some(device_layers) = args.num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params)
    };
    // PagedAttention is on by default for CUDA (and NCCL), opt-in via
    // `--paged-attn` on Metal, and unsupported elsewhere.
    let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        args.no_paged_attn
    } else if device.is_metal() {
        !args.paged_attn
    } else {
        true
    };

    // Resolve the PagedAttention KV cache sizing. The three sizing flags are
    // meant to be mutually exclusive; when several are given, usage wins over
    // both MB amount and context length, and context length wins over MB
    // amount. With no sizing flag, size the cache for `max_seq_len`.
    let cache_config = match (
        args.paged_attn_block_size,
        args.paged_attn_gpu_mem,
        args.paged_attn_gpu_mem_usage,
        args.paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified; defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
                args.cache_type.unwrap_or_default(),
            )?)
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified; defaulting to the context length value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
                args.cache_type.unwrap_or_default(),
            )?)
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified; defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
                args.cache_type.unwrap_or_default(),
            )?)
        }
        (_, _, _, _, _, _) => None,
    };
    // An unparsable ISQ value is silently dropped here (`.ok()`), which
    // simply disables in-situ quantization.
    let isq = args
        .in_situ_quant
        .as_ref()
        .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

    let pipeline = loader.load_model_from_hf(
        None,
        token_source,
        &dtype,
        &device,
        false,
        mapper,
        isq,
        cache_config,
    )?;
    info!("Model loaded.");
    // With PagedAttention, size the scheduler from the pipeline's cache
    // metadata; otherwise use a fixed-size scheduler. Either way, admit as
    // many sequences as the largest requested concurrency level.
    let scheduler_config = if cache_config.is_some() {
        // The pipeline may still decline PagedAttention even when requested,
        // in which case its metadata carries no cache config.
        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: *args.concurrency.as_ref().unwrap().iter().max().unwrap(),
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(
                    (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                        .try_into()
                        .unwrap(),
                ),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(
                (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                    .try_into()
                    .unwrap(),
            ),
        }
    };
    let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, false, None)
        .with_no_prefix_cache(true)
        .with_disable_eos_stop(true)
        .build()
        .await;
    info!("Starting warmup run.");
    warmup_run(mistralrs.clone()).await;
    info!("Finished warmup run.");
    info!("Starting benchmarks.");

    for concurrency in args.concurrency.as_ref().unwrap() {
        let mut results = vec![];
        if args.n_gen > 0 {
            // Generation (tg) test: a one-token prompt, then `n_gen - 1`
            // decoded tokens, so prompt plus completions total `n_gen`.
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::Completion {
                    text: "Rust".to_string(),
                    echo_prompt: false,
                    best_of: None,
                },
                args.n_gen - 1,
                *concurrency,
                args.repetitions,
                TestName::Gen(args.n_gen),
            )
            .await?;
            results.push(r);
        }

        if args.n_prompt > 0 {
            // Prompt (pp) test: feed `n_prompt` arbitrary token ids and
            // generate a single token, so prefill dominates the measurement.
            let tks = (1000..1000 + args.n_prompt as u32).collect();
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::CompletionTokens(tks),
                1,
                *concurrency,
                args.repetitions,
                TestName::Prompt(args.n_prompt),
            )
            .await?;

            results.push(r);
        }

        print_usage(&model_name, &device, results);
    }

    Ok(())
}