use candle_core::Device;
use clap::Parser;
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, initialize_logging, paged_attn_supported,
    parse_isq_value, Constraint, DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata,
    DeviceMapSetting, DrySamplingParams, IsqType, Loader, LoaderBuilder, MemoryGpuConfig,
    MistralRs, MistralRsBuilder, ModelSelected, NormalRequest, PagedAttentionConfig, Request,
    RequestMessage, Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
};
use std::sync::Arc;
use std::{fmt::Display, num::NonZeroUsize};
use tokio::sync::mpsc::channel;
use tracing::{info, warn};

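/// A benchmark workload: prompt processing over `n` tokens ("pp n") or text
/// generation of `n` tokens ("tg n"), following llama.cpp's `llama-bench` naming.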
enum TestName {
    Prompt(usize),
    Gen(usize),
}

impl Display for TestName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TestName::Prompt(n) => format!("pp {}", n),
            TestName::Gen(n) => format!("tg {}", n),
        };
        write!(f, "{}", name)
    }
}

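/// Per-request `Usage` statistics collected for one test at one concurrency level.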
struct BenchResult {
    usages: Vec<Usage>,
    concurrency: usize,
    test_name: TestName,
}

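/// A rate reported as mean ± standard deviation, rendered like `123.456±7.890`.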
struct UncertainTokSec {
    mean: f32,
    std_dev: f32,
}

impl Display for UncertainTokSec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:.3}±{:.3}", self.mean, self.std_dev)
    }
}

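/// Runs one benchmark: for each repetition, submit `concurrency` copies of the
/// same request, then block until all responses arrive, collecting each `Usage`.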
fn run_bench(
    mistralrs: Arc<MistralRs>,
    prompt: RequestMessage,
    n_gen: usize,
    concurrency: usize,
    repetitions: usize,
    test_name: TestName,
) -> anyhow::Result<BenchResult> {
    let sampling_params = SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        max_len: Some(n_gen),
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    };
    let sender = mistralrs.get_sender().unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: prompt,
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        adapters: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
    });

    let mut usages = Vec::new();

    for _ in 0..repetitions {
        // Enqueue `concurrency` identical requests back to back...
        for _ in 0..concurrency {
            sender
                .blocking_send(req.clone())
                .expect("Expected receiver.");
        }
        // ...then drain exactly that many responses before the next repetition.
        for _ in 0..concurrency {
            match rx.blocking_recv() {
                Some(r) => match r {
                    Response::InternalError(e) => {
                        unreachable!("Got an internal error: {e:?}");
                    }
                    Response::ModelError(e, resp) => {
                        unreachable!("Got a model error: {e:?}, response: {resp:?}");
                    }
                    Response::ValidationError(e) => {
                        unreachable!("Got a validation error: {e:?}");
                    }
                    Response::Done(res) => {
                        usages.push(res.usage);
                    }
                    Response::Chunk(_) => unreachable!(),
                    Response::CompletionModelError(_, _) => unreachable!(),
                    Response::CompletionDone(res) => {
                        usages.push(res.usage);
                    }
                    Response::CompletionChunk(_) => unreachable!(),
                    Response::ImageGeneration(_) => unreachable!(),
                    Response::Raw { .. } => unreachable!(),
                },
                None => unreachable!("Expected a Done response, got None"),
            }
        }
    }

    Ok(BenchResult {
        usages,
        concurrency,
        test_name,
    })
}

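/// Mean ± population standard deviation of tokens/sec across all collected
/// usages: the prompt rate for pp tests, the completion rate for tg tests.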
fn get_tok_s(result: &BenchResult) -> UncertainTokSec {
    let ts_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    let mean = ts_measurements.iter().sum::<f32>() / ts_measurements.len() as f32;
    let variance = ts_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ts_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

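/// Same statistic as `get_tok_s`, but inverted into milliseconds per token.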
fn get_ms_tok(result: &BenchResult) -> UncertainTokSec {
    let ms_tok_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    let mean = ms_tok_measurements.iter().sum::<f32>() / ms_tok_measurements.len() as f32;
    let variance = ms_tok_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ms_tok_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

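/// Prints one results table per concurrency level: model, backend, test name,
/// t/s, ms/t, concurrency, and aggregate throughput (t/s × concurrency).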
fn print_usage(model: &str, device: &Device, results: Vec<BenchResult>) {
    let backend = match device {
        Device::Cpu => "CPU",
        Device::Cuda(_) => "CUDA",
        Device::Metal(_) => "Metal",
    };
    let results: Vec<Vec<CellStruct>> = results
        .into_iter()
        .map(|r| {
            vec![
                model.cell(),
                backend.cell(),
                r.test_name.to_string().cell(),
                get_tok_s(&r).cell().justify(Justify::Right),
                get_ms_tok(&r).cell().justify(Justify::Right),
                r.concurrency.cell().justify(Justify::Right),
                (get_tok_s(&r).mean * r.concurrency as f32)
                    .cell()
                    .justify(Justify::Right),
            ]
        })
        .collect();

    let table = results
        .table()
        .title(vec![
            "model".cell().bold(true),
            "backend".cell().bold(true),
            "test".cell().bold(true),
            "t/s".cell().bold(true),
            "ms/t".cell().bold(true),
            "concurrency".cell().bold(true),
            "throughput/s".cell().bold(true),
        ])
        .bold(true);
    print_stdout(table).expect("print table");
}

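/// Issues a single tiny completion request (5 tokens) so that one-time startup
/// costs do not skew the first timed run.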
fn warmup_run(mistralrs: Arc<MistralRs>) {
    let sampling_params = SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        max_len: Some(5),
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    };
    let sender = mistralrs.get_sender().unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: RequestMessage::Completion {
            text: "Hello!".to_string(),
            echo_prompt: false,
            best_of: None,
        },
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        adapters: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
    });

    sender.blocking_send(req).expect("Expected receiver.");

    let _ = rx.blocking_recv();
}

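/// Command-line arguments. The `model` subcommand selects the architecture and
/// weights; the remaining flags control the benchmark shape and PagedAttention.
///
/// Hypothetical invocation (the exact subcommand flags depend on `ModelSelected`):
/// `mistralrs-bench -p 512 -g 128 -c 1,2 plain -m mistralai/Mistral-7B-Instruct-v0.1`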
#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// Model to benchmark.
    #[clap(subcommand)]
    model: ModelSelected,

    /// Integer seed for reproducible sampling.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Number of prompt tokens to run.
    #[arg(long, short = 'p', default_value_t = 512)]
    n_prompt: usize,

    /// Number of tokens to generate.
    #[arg(long, short = 'g', default_value_t = 128)]
    n_gen: usize,

    /// Comma-separated list of concurrency levels to benchmark.
    #[clap(short, long, value_parser, value_delimiter = ',')]
    concurrency: Option<Vec<usize>>,

    /// Number of repetitions of each test.
    #[arg(long, short, default_value_t = 5)]
    repetitions: usize,

    /// Device layer mapping: either a single number of layers for GPU ordinal 0,
    /// or a ';'-separated list of ORD:NUM pairs.
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply at load time.
    #[arg(long = "isq", value_parser = parse_isq_value)]
    in_situ_quant: Option<IsqType>,

    /// GPU memory (in MB) to allot for the PagedAttention KV cache.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Fraction of GPU memory to use for the PagedAttention KV cache; takes
    /// precedence over `--pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to size the PagedAttention KV cache for.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// PagedAttention block size.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention (otherwise on by default for CUDA).
    #[arg(long = "no-paged-attn", default_value_t = false)]
    no_paged_attn: bool,

    /// Enable PagedAttention on Metal (off by default there).
    #[arg(long = "paged-attn", default_value_t = false)]
    paged_attn: bool,

    /// Number of prompt tokens to process in one chunk; must be at least 1.
    #[arg(long = "prompt-batchsize")]
    prompt_chunksize: Option<usize>,
}

fn main() -> anyhow::Result<()> {
    let mut args = Args::parse();
    initialize_logging();

    args.concurrency = Some(args.concurrency.unwrap_or(vec![1]));

    let use_flash_attn = mistralrs_core::using_flash_attn();

    let prompt_chunksize = match args.prompt_chunksize {
        Some(0) => {
            anyhow::bail!("`prompt_chunksize` must be a strictly positive integer, got 0.")
        }
        Some(x) => Some(NonZeroUsize::new(x).unwrap()),
        None => None,
    };

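    // Resolve the model's dtype and automatic device-mapping parameters from
    // the selected model before the loader builder consumes `args.model`.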
    let dtype = get_model_dtype(&args.model)?;
    let auto_device_map_params = get_auto_device_map_params(&args.model)?;

    let max_seq_len = auto_device_map_params.max_seq_len();

    let loader: Box<dyn Loader> = LoaderBuilder::new(args.model)
        .with_use_flash_attn(use_flash_attn)
        .with_prompt_chunksize(prompt_chunksize)
        .build()?;
    let model_name = loader.get_id();

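    // Device selection: Metal when built with the `metal` feature; otherwise
    // CPU under NCCL distributed mode, else the first CUDA device if available.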
    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = args.seed {
        device.set_seed(seed)?;
    }

    let token_source = TokenSource::CacheToken;
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );
    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    if use_flash_attn {
        info!("Using flash attention.");
    }
    if use_flash_attn && loader.get_kind().is_quantized() {
        warn!("Using flash attention with a quantized model has no effect!")
    }
    info!("Model kind is: {}", loader.get_kind());

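    // Device mapper: `--num-device-layers` is either a bare layer count (all
    // layers on GPU ordinal 0) or ';'-separated ORD:NUM pairs; without the
    // flag, layers are distributed automatically.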
    let mapper = if let Some(device_layers) = args.num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params)
    };

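    // PagedAttention defaults: on for CUDA/NCCL unless --no-paged-attn, opt-in
    // via --paged-attn on Metal, and unavailable on any other backend.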
    let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        args.no_paged_attn
    } else if device.is_metal() {
        !args.paged_attn
    } else {
        true
    };

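    // Size the PagedAttention KV cache. Precedence when several flags are set:
    // GPU memory utilization > context length > absolute MB amount; with no
    // flags, the cache is sized for the model's maximum sequence length.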
    let cache_config = match (
        args.paged_attn_block_size,
        args.paged_attn_gpu_mem,
        args.paged_attn_gpu_mem_usage,
        args.paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
        )?),
        (block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
        )?),
        (block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
        )?),
        (block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
        )?),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified; defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?)
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified; defaulting to the context length value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
            )?)
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified; defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
            )?)
        }
        (_, _, _, _, _, _) => None,
    };

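    // Load the weights (Hugging Face token read from the cache), applying the
    // device map and any in-situ quantization at load time.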
    let pipeline = loader.load_model_from_hf(
        None,
        token_source,
        &dtype,
        &device,
        false,
        mapper,
        args.in_situ_quant,
        cache_config,
    )?;
    info!("Model loaded.");

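    // Scheduler: PagedAttention scheduling when the pipeline actually built a
    // KV-cache config, otherwise a fixed-capacity default scheduler; both are
    // sized to the largest requested concurrency.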
    let scheduler_config = if cache_config.is_some() {
        if let Some(ref cache_config) = pipeline.blocking_lock().get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: *args.concurrency.as_ref().unwrap().iter().max().unwrap(),
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(
                    (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                        .try_into()
                        .unwrap(),
                ),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(
                (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                    .try_into()
                    .unwrap(),
            ),
        }
    };
    // Disable the prefix cache and EOS stopping so every request does the full
    // amount of work regardless of prompt overlap or sampled stop tokens.
    let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, false, None)
        .with_no_prefix_cache(true)
        .with_disable_eos_stop(true)
        .build();

545 info!("Starting warmup run.");
546 warmup_run(mistralrs.clone());
547 info!("Finished warmup run.");
548 info!("Starting benchmarks.");
549
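    // For each concurrency level, run the generation (tg) test and then the
    // prompt (pp) test, and print a results table for that level.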
    for concurrency in args.concurrency.as_ref().unwrap() {
        let mut results = vec![];
        if args.n_gen > 0 {
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::Completion {
                    text: "Rust".to_string(),
                    echo_prompt: false,
                    best_of: None,
                },
                // One prompt token plus (n_gen - 1) generated tokens totals n_gen.
                args.n_gen - 1,
                *concurrency,
                args.repetitions,
                TestName::Gen(args.n_gen),
            )?;
            results.push(r);
        }

        if args.n_prompt > 0 {
            // Benchmark prompt processing with a synthetic n_prompt-token
            // prompt, generating only a single token.
            let tks = (1000..1000 + args.n_prompt as u32).collect();
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::CompletionTokens(tks),
                1,
                *concurrency,
                args.repetitions,
                TestName::Prompt(args.n_prompt),
            )?;
            results.push(r);
        }

        print_usage(&model_name, &device, results);
    }

    Ok(())
}