use candle_core::Device;
use clap::Parser;
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, initialize_logging, paged_attn_supported,
    parse_isq_value, Constraint, DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata,
    DeviceMapSetting, DrySamplingParams, IsqType, Loader, LoaderBuilder, MemoryGpuConfig,
    MistralRs, MistralRsBuilder, ModelSelected, NormalRequest, PagedAttentionConfig, Request,
    RequestMessage, Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
};
use std::sync::Arc;
use std::{fmt::Display, num::NonZeroUsize};
use tokio::sync::mpsc::channel;
use tracing::{info, warn};

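/// One benchmark case: prompt processing over `n` tokens (`pp n`) or text
/// generation of `n` tokens (`tg n`).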
enum TestName {
    Prompt(usize),
    Gen(usize),
}

impl Display for TestName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TestName::Prompt(n) => format!("pp {}", n),
            TestName::Gen(n) => format!("tg {}", n),
        };
        write!(f, "{}", name)
    }
}

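/// Raw per-request `Usage` reports collected for one test at one concurrency level.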
struct BenchResult {
    usages: Vec<Usage>,
    concurrency: usize,
    test_name: TestName,
}

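/// A rate measurement reported as mean ± standard deviation.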
struct UncertainTokSec {
    mean: f32,
    std_dev: f32,
}

impl Display for UncertainTokSec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:.3}±{:.3}", self.mean, self.std_dev)
    }
}

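/// Sends `concurrency` identical copies of `prompt` (capped at `n_gen` generated
/// tokens), waits for all of them, and repeats this `repetitions` times,
/// collecting the engine-reported `Usage` of every completed request. Sampling
/// settings are fixed so that every run performs comparable work.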
fn run_bench(
    mistralrs: Arc<MistralRs>,
    prompt: RequestMessage,
    n_gen: usize,
    concurrency: usize,
    repetitions: usize,
    test_name: TestName,
) -> anyhow::Result<BenchResult> {
    let sampling_params = SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        max_len: Some(n_gen),
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    };
    let sender = mistralrs.get_sender().unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: prompt,
        sampling_params,
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
    });

    let mut usages = Vec::new();

    for _ in 0..repetitions {
        for _ in 0..concurrency {
            sender
                .blocking_send(req.clone())
                .expect("Expected receiver.");
        }
        for _ in 0..concurrency {
            match rx.blocking_recv() {
                Some(r) => match r {
                    Response::InternalError(e) => {
                        unreachable!("Got an internal error: {e:?}");
                    }
                    Response::ModelError(e, resp) => {
                        unreachable!("Got a model error: {e:?}, response: {resp:?}");
                    }
                    Response::ValidationError(e) => {
                        unreachable!("Got a validation error: {e:?}");
                    }
                    Response::Done(res) => {
                        usages.push(res.usage);
                    }
                    Response::Chunk(_) => unreachable!(),
                    Response::CompletionModelError(_, _) => unreachable!(),
                    Response::CompletionDone(res) => {
                        usages.push(res.usage);
                    }
                    Response::CompletionChunk(_) => unreachable!(),
                    Response::ImageGeneration(_) => unreachable!(),
                    Response::Raw { .. } => unreachable!(),
                },
                None => unreachable!("Expected a Done response, got None"),
            }
        }
    }

    Ok(BenchResult {
        usages,
        concurrency,
        test_name,
    })
}

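/// Mean and standard deviation of tokens/second across all collected usages,
/// using prompt or completion throughput depending on the test.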
fn get_tok_s(result: &BenchResult) -> UncertainTokSec {
    let ts_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    let mean = ts_measurements.iter().sum::<f32>() / ts_measurements.len() as f32;
    let variance = ts_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ts_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

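/// Mean and standard deviation of milliseconds/token across all collected usages.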
fn get_ms_tok(result: &BenchResult) -> UncertainTokSec {
    let ms_tok_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    let mean = ms_tok_measurements.iter().sum::<f32>() / ms_tok_measurements.len() as f32;
    let variance = ms_tok_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ms_tok_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

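/// Renders the results as a table, one row per test; `throughput/s` is the
/// per-request t/s mean multiplied by the concurrency.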
fn print_usage(model: &str, device: &Device, results: Vec<BenchResult>) {
    let backend = match device {
        Device::Cpu => "CPU",
        Device::Cuda(_) => "CUDA",
        Device::Metal(_) => "Metal",
    };
    let results: Vec<Vec<CellStruct>> = results
        .into_iter()
        .map(|r| {
            vec![
                model.cell(),
                backend.cell(),
                r.test_name.to_string().cell(),
                get_tok_s(&r).cell().justify(Justify::Right),
                get_ms_tok(&r).cell().justify(Justify::Right),
                r.concurrency.cell().justify(Justify::Right),
                (get_tok_s(&r).mean * r.concurrency as f32)
                    .cell()
                    .justify(Justify::Right),
            ]
        })
        .collect();

    let table = results
        .table()
        .title(vec![
            "model".cell().bold(true),
            "backend".cell().bold(true),
            "test".cell().bold(true),
            "t/s".cell().bold(true),
            "ms/t".cell().bold(true),
            "concurrency".cell().bold(true),
            "throughput/s".cell().bold(true),
        ])
        .bold(true);
    print_stdout(table).expect("print table");
}

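/// Runs one tiny 5-token completion so that one-time setup costs are paid
/// before anything is measured.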
fn warmup_run(mistralrs: Arc<MistralRs>) {
    let sampling_params = SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        max_len: Some(5),
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    };
    let sender = mistralrs.get_sender().unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: RequestMessage::Completion {
            text: "Hello!".to_string(),
            echo_prompt: false,
            best_of: None,
        },
        sampling_params,
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
    });

    sender.blocking_send(req).expect("Expected receiver.");

    let _ = rx.blocking_recv();
}

#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// Model to benchmark, selected via subcommand.
    #[clap(subcommand)]
    model: ModelSelected,

    /// Integer seed for reproducible sampling.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Number of prompt tokens for the prompt-processing (pp) test. 0 disables it.
    #[arg(long, short = 'p', default_value_t = 512)]
    n_prompt: usize,

    /// Number of tokens to generate for the text-generation (tg) test. 0 disables it.
    #[arg(long, short = 'g', default_value_t = 128)]
    n_gen: usize,

    /// Comma-separated list of concurrency levels to benchmark. Defaults to 1.
    #[clap(short, long, value_parser, value_delimiter = ',')]
    concurrency: Option<Vec<usize>>,

    /// Number of times to repeat each test.
    #[arg(long, short, default_value_t = 5)]
    repetitions: usize,

    /// Device layer mapping: either a single layer count for device 0, or
    /// semicolon-separated ORD:NUM pairs (device ordinal and layer count).
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply to the model.
    #[arg(long = "isq", value_parser = parse_isq_value)]
    in_situ_quant: Option<IsqType>,

    /// GPU memory to allocate for the PagedAttention KV cache, in MB.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Fraction of GPU memory to use for the PagedAttention KV cache.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Context length to size the PagedAttention KV cache for.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// PagedAttention block size.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention (otherwise on by default on CUDA).
    #[arg(long = "no-paged-attn", default_value_t = false)]
    no_paged_attn: bool,

    /// Enable PagedAttention on Metal (off by default there).
    #[arg(long = "paged-attn", default_value_t = false)]
    paged_attn: bool,

    /// Maximum prompt chunk size; longer prompts are processed in chunks.
    #[arg(long = "prompt-batchsize")]
    prompt_chunksize: Option<usize>,
}

fn main() -> anyhow::Result<()> {
    let mut args = Args::parse();
    initialize_logging();

    args.concurrency = Some(args.concurrency.unwrap_or(vec![1]));

    let use_flash_attn = mistralrs_core::using_flash_attn();

    let prompt_chunksize = match args.prompt_chunksize {
        Some(0) => {
            anyhow::bail!("`prompt_chunksize` must be a strictly positive integer, got 0.")
        }
        Some(x) => Some(NonZeroUsize::new(x).unwrap()),
        None => None,
    };

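    // Resolve the model dtype and automatic device-mapping parameters from the
    // selected model before `args.model` is consumed by the loader builder.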
    let dtype = get_model_dtype(&args.model)?;
    let auto_device_map_params = get_auto_device_map_params(&args.model)?;

    let max_seq_len = auto_device_map_params.max_seq_len();

    let loader: Box<dyn Loader> = LoaderBuilder::new(args.model)
        .with_use_flash_attn(use_flash_attn)
        .with_prompt_chunksize(prompt_chunksize)
        .build()?;
    let model_name = loader.get_id();

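    // Pick the host device: Metal when compiled with the `metal` feature;
    // otherwise CUDA if available, except under NCCL distribution where the
    // host-side device stays on CPU.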
    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = args.seed {
        device.set_seed(seed)?;
    }

    let token_source = TokenSource::CacheToken;
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );
    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    if use_flash_attn {
        info!("Using flash attention.");
    }
    if use_flash_attn && loader.get_kind().is_quantized() {
        warn!("Using flash attention with a quantized model has no effect!")
    }
    info!("Model kind is: {}", loader.get_kind());

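    // Build the device mapper: a bare integer assigns that many layers to device 0;
    // otherwise each `ORD:NUM` entry assigns NUM layers to device ordinal ORD.
    // Without an explicit mapping, automatic device mapping is used.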
    let mapper = if let Some(device_layers) = args.num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params)
    };

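    // PagedAttention is on by default on CUDA (and with NCCL), opt-in on Metal,
    // and disabled everywhere else.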
    let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        args.no_paged_attn
    } else if device.is_metal() {
        !args.paged_attn
    } else {
        true
    };

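    // Size the PagedAttention KV cache from whichever option was given: an
    // explicit context length, a GPU memory utilization fraction, or an absolute
    // MB amount, defaulting to the model's maximum context length. When several
    // are given, a utilization fraction wins over the others, and a context
    // length wins over an MB amount. Every arm passes the same fixed CPU cache
    // reservation of 512 (MB) as the second argument.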
    let cache_config = match (
        args.paged_attn_block_size,
        args.paged_attn_gpu_mem,
        args.paged_attn_gpu_mem_usage,
        args.paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
        )?),
        (block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
        )?),
        (block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
        )?),
        (block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
        )?),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified; defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?)
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified; defaulting to the context length value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
            )?)
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified; defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?)
        }
        (_, _, _, _, _, _) => None,
    };

    let pipeline = loader.load_model_from_hf(
        None,
        token_source,
        &dtype,
        &device,
        false,
        mapper,
        args.in_situ_quant,
        cache_config,
    )?;
    info!("Model loaded.");

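    // Schedule for the largest requested concurrency. Use the PagedAttention
    // scheduler only if the loaded pipeline actually kept a KV-cache config;
    // otherwise fall back to the fixed default scheduler.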
    let scheduler_config = if cache_config.is_some() {
        if let Some(ref cache_config) = pipeline.blocking_lock().get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: *args.concurrency.as_ref().unwrap().iter().max().unwrap(),
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(
                    (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                        .try_into()
                        .unwrap(),
                ),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(
                (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                    .try_into()
                    .unwrap(),
            ),
        }
    };
    let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, false, None)
        .with_no_prefix_cache(true)
        .with_disable_eos_stop(true)
        .build();

    info!("Starting warmup run.");
    warmup_run(mistralrs.clone());
    info!("Finished warmup run.");
    info!("Starting benchmarks.");

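    // For each concurrency level: run the generation (tg) benchmark from a short
    // prompt, then the prompt-processing (pp) benchmark over synthetic token ids,
    // and print a table of both.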
    for concurrency in args.concurrency.as_ref().unwrap() {
        let mut results = vec![];
        if args.n_gen > 0 {
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::Completion {
                    text: "Rust".to_string(),
                    echo_prompt: false,
                    best_of: None,
                },
                args.n_gen - 1,
                *concurrency,
                args.repetitions,
                TestName::Gen(args.n_gen),
            )?;
            results.push(r);
        }

        if args.n_prompt > 0 {
            let tks = (1000..1000 + args.n_prompt as u32).collect();
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::CompletionTokens(tks),
                1,
                *concurrency,
                args.repetitions,
                TestName::Prompt(args.n_prompt),
            )?;

            results.push(r);
        }

        print_usage(&model_name, &device, results);
    }

    Ok(())
}