use candle_core::Device;
use clap::Parser;
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, initialize_logging, paged_attn_supported,
    parse_isq_value, Constraint, DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata,
    DeviceMapSetting, DrySamplingParams, Loader, LoaderBuilder, MemoryGpuConfig, MistralRs,
    MistralRsBuilder, ModelSelected, NormalRequest, PagedAttentionConfig, PagedCacheType, Request,
    RequestMessage, Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
};
use std::fmt::Display;
use std::sync::Arc;
use tokio::sync::mpsc::channel;
use tracing::info;

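/// A single benchmark variant: prompt processing (`pp`) over `n` prompt tokens,
/// or text generation (`tg`) of `n` tokens.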
enum TestName {
    Prompt(usize),
    Gen(usize),
}

impl Display for TestName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TestName::Prompt(n) => format!("pp {n}"),
            TestName::Gen(n) => format!("tg {n}"),
        };
        write!(f, "{name}")
    }
}

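/// The usage statistics collected for one test at a given concurrency level.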
struct BenchResult {
    usages: Vec<Usage>,
    concurrency: usize,
    test_name: TestName,
}

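/// A measurement expressed as mean ± standard deviation.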
struct UncertainTokSec {
    mean: f32,
    std_dev: f32,
}

impl Display for UncertainTokSec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:.3}±{:.3}", self.mean, self.std_dev)
    }
}

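/// Runs one benchmark: for each repetition, submits `concurrency` copies of `prompt`
/// (generating up to `n_gen` tokens each) and collects the `Usage` stats from every response.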
async fn run_bench(
    mistralrs: Arc<MistralRs>,
    prompt: RequestMessage,
    n_gen: usize,
    concurrency: usize,
    repetitions: usize,
    test_name: TestName,
) -> anyhow::Result<BenchResult> {
    let sampling_params = SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        repetition_penalty: None,
        max_len: Some(n_gen),
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    };
    let sender = mistralrs.get_sender(None).unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(Box::new(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: prompt,
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: None,
        truncate_sequence: false,
    }));

    let mut usages = Vec::new();

    for _ in 0..repetitions {
        // Submit all concurrent requests for this repetition up front...
        for _ in 0..concurrency {
            if sender.send(req.clone()).await.is_err() {
                eprintln!("Receiver disconnected");
            }
        }
        // ...then collect one response per request, keeping only the usage stats.
        for _ in 0..concurrency {
            match rx.recv().await {
                Some(r) => match r {
                    Response::InternalError(e) => {
                        unreachable!("Got an internal error: {e:?}");
                    }
                    Response::ModelError(e, resp) => {
                        unreachable!("Got a model error: {e:?}, response: {resp:?}");
                    }
                    Response::ValidationError(e) => {
                        unreachable!("Got a validation error: {e:?}");
                    }
                    Response::Done(res) => {
                        usages.push(res.usage);
                    }
                    Response::Chunk(_) => unreachable!(),
                    Response::CompletionModelError(_, _) => unreachable!(),
                    Response::CompletionDone(res) => {
                        usages.push(res.usage);
                    }
                    Response::CompletionChunk(_) => unreachable!(),
                    Response::ImageGeneration(_) => unreachable!(),
                    Response::Speech { .. } => unreachable!(),
                    Response::Raw { .. } => unreachable!(),
                    Response::Embeddings { .. } => unreachable!(),
                },
                None => unreachable!("Expected a Done response, got None"),
            }
        }
    }

    Ok(BenchResult {
        usages,
        concurrency,
        test_name,
    })
}

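/// Mean ± standard deviation of tokens/sec across all collected usages.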
fn get_tok_s(result: &BenchResult) -> UncertainTokSec {
    let ts_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    let mean = ts_measurements.iter().sum::<f32>() / ts_measurements.len() as f32;
    let variance = ts_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ts_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

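/// Mean ± standard deviation of milliseconds/token across all collected usages.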
fn get_ms_tok(result: &BenchResult) -> UncertainTokSec {
    let ms_tok_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    let mean = ms_tok_measurements.iter().sum::<f32>() / ms_tok_measurements.len() as f32;
    let variance = ms_tok_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ms_tok_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

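/// Prints the benchmark results as a table on stdout.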
fn print_usage(model: &str, device: &Device, results: Vec<BenchResult>) {
    let backend = match device {
        Device::Cpu => "CPU",
        Device::Cuda(_) => "CUDA",
        Device::Metal(_) => "Metal",
    };
    let results: Vec<Vec<CellStruct>> = results
        .into_iter()
        .map(|r| {
            vec![
                model.cell(),
                backend.cell(),
                r.test_name.to_string().cell(),
                get_tok_s(&r).cell().justify(Justify::Right),
                get_ms_tok(&r).cell().justify(Justify::Right),
                r.concurrency.cell().justify(Justify::Right),
                (get_tok_s(&r).mean * r.concurrency as f32)
                    .cell()
                    .justify(Justify::Right),
            ]
        })
        .collect();

    let table = results
        .table()
        .title(vec![
            "model".cell().bold(true),
            "backend".cell().bold(true),
            "test".cell().bold(true),
            "t/s".cell().bold(true),
            "ms/t".cell().bold(true),
            "concurrency".cell().bold(true),
            "throughput/s".cell().bold(true),
        ])
        .bold(true);
    print_stdout(table).expect("print table");
}

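/// Runs a single one-token completion so that one-time setup costs do not
/// skew the timed benchmark runs.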
async fn warmup_run(mistralrs: Arc<MistralRs>) {
    let sampling_params = SamplingParams {
        max_len: Some(1),
        ..SamplingParams::deterministic()
    };
    let sender = mistralrs.get_sender(None).unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(Box::new(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: RequestMessage::Completion {
            text: "Hello!".to_string(),
            echo_prompt: false,
            best_of: None,
        },
        sampling_params,
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: None,
        truncate_sequence: false,
    }));

    if sender.send(req).await.is_err() {
        eprintln!("Receiver disconnected");
    }

    let _ = rx.recv().await;
}

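/// Parses a `PagedCacheType` from its string form; used as a clap value parser.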
fn parse_cache_type(s: &str) -> Result<PagedCacheType, String> {
    s.parse()
}

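/// Command-line arguments for the benchmark harness.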
#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// The model to benchmark.
    #[clap(subcommand)]
    model: ModelSelected,

    /// Integer seed for reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Number of prompt tokens to run.
    #[arg(long, short = 'p', default_value_t = 512)]
    n_prompt: usize,

    /// Number of generation tokens to run.
    #[arg(long, short = 'g', default_value_t = 128)]
    n_gen: usize,

    /// Comma-separated list of concurrency levels (concurrent requests) to benchmark.
    #[clap(short, long, value_parser, value_delimiter = ',')]
    concurrency: Option<Vec<usize>>,

    /// Number of times to repeat each test.
    #[arg(long, short, default_value_t = 5)]
    repetitions: usize,

    /// Number of layers to place on each device: either a single integer (all layers on
    /// device ordinal 0) or a `;`-separated list of ORD:NUM pairs, where ORD is a device
    /// ordinal and NUM is the number of layers for that device.
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,

    /// In-situ quantization to apply.
    #[arg(long = "isq")]
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for the PagedAttention KV cache, in MB.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Fraction of GPU memory (0.0 to 1.0) to use for the PagedAttention KV cache.
    /// Takes precedence over the other sizing options if several are set.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to size the PagedAttention KV cache for.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// PagedAttention KV cache type.
    #[arg(long = "pa-cache-type", value_parser = parse_cache_type)]
    cache_type: Option<PagedCacheType>,

    /// Block size (tokens per block) for PagedAttention.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention (it is enabled by default on CUDA).
    #[arg(long = "no-paged-attn", default_value_t = false)]
    no_paged_attn: bool,

    /// Enable PagedAttention (it is disabled by default on Metal).
    #[arg(long = "paged-attn", default_value_t = false)]
    paged_attn: bool,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let mut args = Args::parse();
    initialize_logging();

    args.concurrency = Some(args.concurrency.unwrap_or_else(|| vec![1]));

    let dtype = get_model_dtype(&args.model)?;
    let auto_device_map_params = get_auto_device_map_params(&args.model)?;

    let max_seq_len = auto_device_map_params.max_seq_len();

    let loader: Box<dyn Loader> = LoaderBuilder::new(args.model).build()?;
    let model_name = loader.get_id();

    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = if mistralrs_core::distributed::use_nccl() {
        // NCCL distributed runs manage their own devices; use a CPU handle here.
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = args.seed {
        device.set_seed(seed)?;
    }

    let token_source = TokenSource::CacheToken;
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );
    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    info!("Model kind is: {}", loader.get_kind());

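    // Build the device mapping: explicit per-device layer counts if provided, otherwise automatic.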
    let mapper = if let Some(device_layers) = args.num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params)
    };

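    // PagedAttention is on by default for CUDA, opt-in via --paged-attn for Metal,
    // and unsupported elsewhere.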
    let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        args.no_paged_attn
    } else if device.is_metal() {
        !args.paged_attn
    } else {
        true
    };

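    // Size the PagedAttention KV cache. When several sizing options are given, the
    // precedence is: memory usage fraction > context length > MB amount; the default
    // is the model's maximum sequence length.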
    let cache_config = match (
        args.paged_attn_block_size,
        args.paged_attn_gpu_mem,
        args.paged_attn_gpu_mem_usage,
        args.paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::ContextSize(max_seq_len),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::ContextSize(ctxt),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::Utilization(f),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            MemoryGpuConfig::MbAmount(m),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified; defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                MemoryGpuConfig::Utilization(f),
                args.cache_type.unwrap_or_default(),
            )?)
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified; defaulting to the context length value.");
            Some(PagedAttentionConfig::new(
                block_size,
                MemoryGpuConfig::ContextSize(ctxt),
                args.cache_type.unwrap_or_default(),
            )?)
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified; defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                MemoryGpuConfig::Utilization(f),
                args.cache_type.unwrap_or_default(),
            )?)
        }
        (_, _, _, _, _, _) => None,
    };

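    // Parse the optional in-situ quantization type against the target device.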
    let isq = args
        .in_situ_quant
        .as_ref()
        .map(|isq| parse_isq_value(isq, Some(&device)).expect("Failed to parse ISQ value"));

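    // Load the model with the settings resolved above.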
    let pipeline = loader.load_model_from_hf(
        None,
        token_source,
        &dtype,
        &device,
        false,
        mapper,
        isq,
        cache_config,
    )?;
    info!("Model loaded.");

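    // The scheduler must admit the largest concurrency level we intend to benchmark.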
    let max_concurrency = *args.concurrency.as_ref().unwrap().iter().max().unwrap();
    let scheduler_config = if cache_config.is_some() {
        // The pipeline only reports a cache config if PagedAttention was actually activated.
        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: max_concurrency,
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(max_concurrency.try_into().unwrap()),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(max_concurrency.try_into().unwrap()),
        }
    };
    let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, false, None)
        .with_no_prefix_cache(true)
        .with_disable_eos_stop(true)
        .build()
        .await;

    info!("Starting warmup run.");
    warmup_run(mistralrs.clone()).await;
    info!("Finished warmup run.");
    info!("Starting benchmarks.");

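    // For each concurrency level, run the text-generation (tg) benchmark followed by
    // the prompt-processing (pp) benchmark, then print a combined results table.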
    for concurrency in args.concurrency.as_ref().unwrap() {
        let mut results = vec![];
        if args.n_gen > 0 {
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::Completion {
                    text: "Rust".to_string(),
                    echo_prompt: false,
                    best_of: None,
                },
                args.n_gen - 1,
                *concurrency,
                args.repetitions,
                TestName::Gen(args.n_gen),
            )
            .await?;
            results.push(r);
        }

        if args.n_prompt > 0 {
            // Benchmark prompt processing with a synthetic prompt of `n_prompt` arbitrary
            // token ids, generating only a single token.
            let tks = (1000..1000 + args.n_prompt as u32).collect();
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::CompletionTokens(tks),
                1,
                *concurrency,
                args.repetitions,
                TestName::Prompt(args.n_prompt),
            )
            .await?;

            results.push(r);
        }

        print_usage(&model_name, &device, results);
    }

    Ok(())
}