use candle_core::Device;
use clap::Parser;
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, initialize_logging, paged_attn_supported,
    parse_isq_value, Constraint, DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata,
    DeviceMapSetting, DrySamplingParams, Loader, LoaderBuilder, MemoryGpuConfig, MistralRs,
    MistralRsBuilder, ModelSelected, NormalRequest, PagedAttentionConfig, PagedCacheType, Request,
    RequestMessage, Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
};
use std::sync::Arc;
use std::{fmt::Display, num::NonZeroUsize};
use tokio::sync::mpsc::channel;
use tracing::info;

enum TestName {
    Prompt(usize),
    Gen(usize),
}

impl Display for TestName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TestName::Prompt(n) => format!("pp {n}"),
            TestName::Gen(n) => format!("tg {n}"),
        };
        write!(f, "{name}")
    }
}

struct BenchResult {
    usages: Vec<Usage>,
    concurrency: usize,
    test_name: TestName,
}

struct UncertainTokSec {
    mean: f32,
    std_dev: f32,
}

impl Display for UncertainTokSec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:.3}±{:.3}", self.mean, self.std_dev)
    }
}

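/// Sends `concurrency` identical requests per repetition, blocks until every response arrives,
/// and collects the returned `Usage` statistics for later aggregation.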
fn run_bench(
    mistralrs: Arc<MistralRs>,
    prompt: RequestMessage,
    n_gen: usize,
    concurrency: usize,
    repetitions: usize,
    test_name: TestName,
) -> anyhow::Result<BenchResult> {
    let sampling_params = SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        max_len: Some(n_gen),
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    };
    let sender = mistralrs.get_sender(None).unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(Box::new(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: prompt,
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: None,
    }));

    let mut usages = Vec::new();

    for _ in 0..repetitions {
        for _ in 0..concurrency {
            if sender.blocking_send(req.clone()).is_err() {
                eprintln!("Receiver disconnected");
            }
        }
        for _ in 0..concurrency {
            match rx.blocking_recv() {
                Some(r) => match r {
                    Response::InternalError(e) => {
                        unreachable!("Got an internal error: {e:?}");
                    }
                    Response::ModelError(e, resp) => {
                        unreachable!("Got a model error: {e:?}, response: {resp:?}");
                    }
                    Response::ValidationError(e) => {
                        unreachable!("Got a validation error: {e:?}");
                    }
                    Response::Done(res) => {
                        usages.push(res.usage);
                    }
                    Response::Chunk(_) => unreachable!(),
                    Response::CompletionModelError(_, _) => unreachable!(),
                    Response::CompletionDone(res) => {
                        usages.push(res.usage);
                    }
                    Response::CompletionChunk(_) => unreachable!(),
                    Response::ImageGeneration(_) => unreachable!(),
                    Response::Speech { .. } => unreachable!(),
                    Response::Raw { .. } => unreachable!(),
                },
                None => unreachable!("Expected a Done response, got None"),
            }
        }
    }

    Ok(BenchResult {
        usages,
        concurrency,
        test_name,
    })
}

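/// Mean and standard deviation of tokens/second across all usages of a benchmark result,
/// using prompt throughput for `Prompt` tests and completion throughput for `Gen` tests.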
fn get_tok_s(result: &BenchResult) -> UncertainTokSec {
    let ts_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    let mean = ts_measurements.iter().sum::<f32>() / ts_measurements.len() as f32;
    let variance = ts_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ts_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

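/// Mean and standard deviation of milliseconds per token across all usages of a benchmark result.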
fn get_ms_tok(result: &BenchResult) -> UncertainTokSec {
    let ms_tok_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    let mean = ms_tok_measurements.iter().sum::<f32>() / ms_tok_measurements.len() as f32;
    let variance = ms_tok_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ms_tok_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

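/// Prints the benchmark results as a table: model, backend, test, t/s, ms/t, concurrency,
/// and aggregate throughput (t/s times concurrency).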
fn print_usage(model: &str, device: &Device, results: Vec<BenchResult>) {
    let backend = match device {
        Device::Cpu => "CPU",
        Device::Cuda(_) => "CUDA",
        Device::Metal(_) => "Metal",
    };
    let results: Vec<Vec<CellStruct>> = results
        .into_iter()
        .map(|r| {
            vec![
                model.cell(),
                backend.cell(),
                r.test_name.to_string().cell(),
                get_tok_s(&r).cell().justify(Justify::Right),
                get_ms_tok(&r).cell().justify(Justify::Right),
                r.concurrency.cell().justify(Justify::Right),
                (get_tok_s(&r).mean * r.concurrency as f32)
                    .cell()
                    .justify(Justify::Right),
            ]
        })
        .collect();

    let table = results
        .table()
        .title(vec![
            "model".cell().bold(true),
            "backend".cell().bold(true),
            "test".cell().bold(true),
            "t/s".cell().bold(true),
            "ms/t".cell().bold(true),
            "concurrency".cell().bold(true),
            "throughput/s".cell().bold(true),
        ])
        .bold(true);
    print_stdout(table).expect("print table");
}

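/// Runs a single short completion so that one-time setup costs do not skew the timed benchmarks.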
fn warmup_run(mistralrs: Arc<MistralRs>) {
    let sampling_params = SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        max_len: Some(5),
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    };
    let sender = mistralrs.get_sender(None).unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(Box::new(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: RequestMessage::Completion {
            text: "Hello!".to_string(),
            echo_prompt: false,
            best_of: None,
        },
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: None,
    }));

    if sender.blocking_send(req.clone()).is_err() {
        eprintln!("Receiver disconnected");
    }

    let _ = rx.blocking_recv();
}

fn parse_cache_type(s: &str) -> Result<PagedCacheType, String> {
    s.parse()
}

#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// Model to benchmark.
    #[clap(subcommand)]
    model: ModelSelected,

    /// Integer seed to ensure reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Number of prompt tokens to run.
    #[arg(long, short = 'p', default_value_t = 512)]
    n_prompt: usize,

    /// Number of tokens to generate.
    #[arg(long, short = 'g', default_value_t = 128)]
    n_gen: usize,

    /// Concurrency levels to benchmark, as a comma-separated list. Defaults to 1.
    #[clap(short, long, value_parser, value_delimiter = ',')]
    concurrency: Option<Vec<usize>>,

    /// Number of repetitions for each test.
    #[arg(long, short, default_value_t = 5)]
    repetitions: usize,

    /// Number of layers to assign to each device: either a single number for one device,
    /// or a `;`-separated list of `ORD:NUM` pairs (device ordinal and layer count).
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization (ISQ) to apply.
    #[arg(long = "isq")]
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for the PagedAttention KV cache, in MBs.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Fraction of GPU memory to allocate for the PagedAttention KV cache;
    /// takes precedence over `--pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length for which to allocate the PagedAttention KV cache.
    /// If unset, the model's maximum sequence length is used.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// PagedAttention KV cache type.
    #[arg(long = "pa-cache-type", value_parser = parse_cache_type)]
    cache_type: Option<PagedCacheType>,

    /// Block size (tokens per block) for PagedAttention.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention on CUDA.
    #[arg(long = "no-paged-attn", default_value_t = false)]
    no_paged_attn: bool,

    /// Enable PagedAttention on Metal, where it is otherwise disabled by default.
    #[arg(long = "paged-attn", default_value_t = false)]
    paged_attn: bool,

    /// Number of prompt tokens to process at once.
    #[arg(long = "prompt-batchsize")]
    prompt_chunksize: Option<usize>,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let mut args = Args::parse();
    initialize_logging();

    args.concurrency = Some(args.concurrency.unwrap_or(vec![1]));

    let prompt_chunksize = match args.prompt_chunksize {
        Some(0) => {
            anyhow::bail!("`prompt_chunksize` must be a strictly positive integer, got 0.")
        }
        Some(x) => Some(NonZeroUsize::new(x).unwrap()),
        None => None,
    };

    let dtype = get_model_dtype(&args.model)?;
    let auto_device_map_params = get_auto_device_map_params(&args.model)?;

    let max_seq_len = auto_device_map_params.max_seq_len();

    let loader: Box<dyn Loader> = LoaderBuilder::new(args.model)
        .with_prompt_chunksize(prompt_chunksize)
        .build()?;
    let model_name = loader.get_id();

    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = args.seed {
        device.set_seed(seed)?;
    }

    let token_source = TokenSource::CacheToken;
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );
    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    info!("Model kind is: {}", loader.get_kind().to_string());

    let mapper = if let Some(device_layers) = args.num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params)
    };

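    // PagedAttention defaults: enabled on CUDA (and with NCCL) unless --no-paged-attn is passed,
    // opt-in via --paged-attn on Metal, and disabled everywhere else.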
    let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        args.no_paged_attn
    } else if device.is_metal() {
        !args.paged_attn
    } else {
        true
    };

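    // Size the PagedAttention KV cache. An explicit GPU memory utilization fraction or context
    // length takes precedence over a raw MB amount; if nothing is specified, the model's maximum
    // sequence length is used.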
    let cache_config = match (
        args.paged_attn_block_size,
        args.paged_attn_gpu_mem,
        args.paged_attn_gpu_mem_usage,
        args.paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified; defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
                args.cache_type.unwrap_or_default(),
            )?)
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified; defaulting to the context length value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
                args.cache_type.unwrap_or_default(),
            )?)
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified; defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
                args.cache_type.unwrap_or_default(),
            )?)
        }
        (_, _, _, _, _, _) => None,
    };

    let isq = args
        .in_situ_quant
        .as_ref()
        .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

    let pipeline = loader.load_model_from_hf(
        None,
        token_source,
        &dtype,
        &device,
        false,
        mapper,
        isq,
        cache_config,
    )?;
    info!("Model loaded.");

    let scheduler_config = if cache_config.is_some() {
        if let Some(ref cache_config) = pipeline.blocking_lock().get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: *args.concurrency.as_ref().unwrap().iter().max().unwrap(),
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(
                    (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                        .try_into()
                        .unwrap(),
                ),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(
                (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                    .try_into()
                    .unwrap(),
            ),
        }
    };
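    // Prefix caching and EOS-based stopping are disabled so each request processes its full
    // prompt and generates the full requested number of tokens.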
    let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, false, None)
        .with_no_prefix_cache(true)
        .with_disable_eos_stop(true)
        .build()
        .await;

    info!("Starting warmup run.");
    warmup_run(mistralrs.clone());
    info!("Finished warmup run.");
    info!("Starting benchmarks.");

    for concurrency in args.concurrency.as_ref().unwrap() {
        let mut results = vec![];
        if args.n_gen > 0 {
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::Completion {
                    text: "Rust".to_string(),
                    echo_prompt: false,
                    best_of: None,
                },
                args.n_gen - 1,
                *concurrency,
                args.repetitions,
                TestName::Gen(args.n_gen),
            )?;
            results.push(r);
        }

        if args.n_prompt > 0 {
            let tks = (1000..1000 + args.n_prompt as u32).collect();
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::CompletionTokens(tks),
                1,
                *concurrency,
                args.repetitions,
                TestName::Prompt(args.n_prompt),
            )?;

            results.push(r);
        }

        print_usage(&model_name, &device, results);
    }

    Ok(())
}