mistralrs_bench/main.rs

use candle_core::Device;
use clap::Parser;
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
use mistralrs_core::{
    get_auto_device_map_params, get_model_dtype, initialize_logging, paged_attn_supported,
    parse_isq_value, Constraint, DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata,
    DeviceMapSetting, DrySamplingParams, Loader, LoaderBuilder, MemoryGpuConfig, MistralRs,
    MistralRsBuilder, ModelSelected, NormalRequest, PagedAttentionConfig, PagedCacheType, Request,
    RequestMessage, Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
};
use std::fmt::Display;
use std::sync::Arc;
use tokio::sync::mpsc::channel;
use tracing::info;

/// Which benchmark a result belongs to: prompt processing (`pp`) or token generation (`tg`).
enum TestName {
    Prompt(usize),
    Gen(usize),
}

impl Display for TestName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TestName::Prompt(n) => format!("pp {n}"),
            TestName::Gen(n) => format!("tg {n}"),
        };
        write!(f, "{name}")
    }
}

/// Usage statistics collected for one test at a given concurrency level.
struct BenchResult {
    usages: Vec<Usage>,
    concurrency: usize,
    test_name: TestName,
}

/// A measurement reported as mean ± standard deviation, e.g. `250.123±3.456`.
struct UncertainTokSec {
    mean: f32,
    std_dev: f32,
}

impl Display for UncertainTokSec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:.3}±{:.3}", self.mean, self.std_dev)
    }
}

/// Sends `concurrency` identical requests per repetition and collects the `Usage`
/// statistics from each response.
async fn run_bench(
    mistralrs: Arc<MistralRs>,
    prompt: RequestMessage,
    n_gen: usize,
    concurrency: usize,
    repetitions: usize,
    test_name: TestName,
) -> anyhow::Result<BenchResult> {
    let sampling_params = SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        repetition_penalty: None,
        max_len: Some(n_gen),
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    };
    let sender = mistralrs.get_sender(None).unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(Box::new(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: prompt,
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: None,
    }));

    let mut usages = Vec::new();

    for _ in 0..repetitions {
        for _ in 0..concurrency {
            if sender.send(req.clone()).await.is_err() {
                eprintln!("Receiver disconnected");
            }
        }
        for _ in 0..concurrency {
            match rx.recv().await {
                Some(r) => match r {
                    Response::InternalError(e) => {
                        unreachable!("Got an internal error: {e:?}");
                    }
                    Response::ModelError(e, resp) => {
                        unreachable!("Got a model error: {e:?}, response: {resp:?}");
                    }
                    Response::ValidationError(e) => {
                        unreachable!("Got a validation error: {e:?}");
                    }
                    Response::Done(res) => {
                        usages.push(res.usage);
                    }
                    Response::Chunk(_) => unreachable!(),
                    Response::CompletionModelError(_, _) => unreachable!(),
                    Response::CompletionDone(res) => {
                        usages.push(res.usage);
                    }
                    Response::CompletionChunk(_) => unreachable!(),
                    Response::ImageGeneration(_) => unreachable!(),
                    Response::Speech { .. } => unreachable!(),
                    Response::Raw { .. } => unreachable!(),
                },
                None => unreachable!("Expected a Done response, got None"),
            }
        }
    }

    Ok(BenchResult {
        usages,
        concurrency,
        test_name,
    })
}

fn get_tok_s(result: &BenchResult) -> UncertainTokSec {
    let ts_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    // Calculate uncertainty
    let mean = ts_measurements.iter().sum::<f32>() / ts_measurements.len() as f32;
    let variance = ts_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ts_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

fn get_ms_tok(result: &BenchResult) -> UncertainTokSec {
    let ms_tok_measurements = match result.test_name {
        TestName::Prompt(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_prompt_tok_per_sec)
            .collect::<Vec<_>>(),
        TestName::Gen(_) => result
            .usages
            .iter()
            .map(|u| 1000. / u.avg_compl_tok_per_sec)
            .collect::<Vec<_>>(),
    };
    // Calculate uncertainty
    let mean = ms_tok_measurements.iter().sum::<f32>() / ms_tok_measurements.len() as f32;
    let variance = ms_tok_measurements
        .iter()
        .map(|e| (mean - e).powf(2.))
        .sum::<f32>()
        / ms_tok_measurements.len() as f32;
    let std_dev = variance.sqrt();
    UncertainTokSec { mean, std_dev }
}

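// Relationship between the two metrics above (illustrative numbers, not measured output):
// `get_ms_tok` is the reciprocal of the per-request token rate scaled to milliseconds,
// so a run averaging 250 tok/s reports 1000.0 / 250.0 = 4.0 ms/t. Both helpers take the
// population standard deviation over the per-request averages as the uncertainty.
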
fn print_usage(model: &str, device: &Device, results: Vec<BenchResult>) {
    let backend = match device {
        Device::Cpu => "CPU",
        Device::Cuda(_) => "CUDA",
        Device::Metal(_) => "Metal",
    };
    let results: Vec<Vec<CellStruct>> = results
        .into_iter()
        .map(|r| {
            vec![
                model.cell(),
                backend.cell(),
                r.test_name.to_string().cell(),
                get_tok_s(&r).cell().justify(Justify::Right),
                get_ms_tok(&r).cell().justify(Justify::Right),
                r.concurrency.cell().justify(Justify::Right),
                (get_tok_s(&r).mean * r.concurrency as f32)
                    .cell()
                    .justify(Justify::Right),
            ]
        })
        .collect();

    let table = results
        .table()
        .title(vec![
            "model".cell().bold(true),
            // "size".cell().bold(true),
            // "params".cell().bold(true),
            "backend".cell().bold(true),
            // "ngl".cell().bold(true),
            "test".cell().bold(true),
            "t/s".cell().bold(true),
            "ms/t".cell().bold(true),
            "concurrency".cell().bold(true),
            "throughput/s".cell().bold(true),
        ])
        .bold(true);
    print_stdout(table).expect("print table");
}

/// Runs a single one-token completion to warm up the pipeline before any timed
/// benchmark starts.
async fn warmup_run(mistralrs: Arc<MistralRs>) {
    let sampling_params = SamplingParams {
        max_len: Some(1),
        ..SamplingParams::deterministic()
    };
    let sender = mistralrs.get_sender(None).unwrap();
    let (tx, mut rx) = channel(10_000);

    let req = Request::Normal(Box::new(NormalRequest {
        id: mistralrs.next_request_id(),
        messages: RequestMessage::Completion {
            text: "Hello!".to_string(),
            echo_prompt: false,
            best_of: None,
        },
        sampling_params: sampling_params.clone(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        constraint: Constraint::None,
        suffix: None,
        tools: None,
        tool_choice: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: None,
    }));

    if sender.send(req.clone()).await.is_err() {
        eprintln!("Receiver disconnected");
    }

    let _ = rx.recv().await;
}

fn parse_cache_type(s: &str) -> Result<PagedCacheType, String> {
    s.parse()
}

#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
    /// Model
    #[clap(subcommand)]
    model: ModelSelected,

    /// Integer seed to ensure reproducible random number generation.
    #[arg(short, long)]
    seed: Option<u64>,

    /// Number of prompt tokens to run.
    #[arg(long, short = 'p', default_value_t = 512)]
    n_prompt: usize,

    /// Number of generation tokens to run.
    #[arg(long, short = 'g', default_value_t = 128)]
    n_gen: usize,

    /// Number of concurrent requests to run. Default is 1.
    #[clap(short, long, value_parser, value_delimiter = ',')]
    concurrency: Option<Vec<usize>>,

    /// Number of times to repeat each test.
    #[arg(long, short, default_value_t = 5)]
    repetitions: usize,

    /// NOTE: This can be omitted to use automatic device mapping!
    /// Number of device layers to load and run on GPU(s). All others will be on the CPU.
    /// If one GPU is used, then this value should be an integer. Otherwise, it follows the
    /// pattern ORD:NUM;... where ORD is a unique device ordinal and NUM is the number of
    /// layers for that device.
    #[arg(short, long, value_parser, value_delimiter = ';')]
    num_device_layers: Option<Vec<String>>,
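    // Illustrative values (not defaults): `--num-device-layers 24` puts 24 layers on
    // device ordinal 0, while `--num-device-layers "0:16;1:16"` puts 16 layers on
    // ordinal 0 and 16 on ordinal 1, leaving any remaining layers on the CPU.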

    /// In-situ quantization to apply.
    #[arg(long = "isq")]
    in_situ_quant: Option<String>,

    /// GPU memory to allocate for the KV cache with PagedAttention, in MBs.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem")]
    paged_attn_gpu_mem: Option<usize>,

    /// Percentage of GPU memory to utilize after allocation of KV cache with PagedAttention, from 0 to 1.
    /// If this is not set and the device is CUDA, it will default to `0.9`.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    #[arg(long = "pa-gpu-mem-usage")]
    paged_attn_gpu_mem_usage: Option<f32>,

    /// Total context length to allocate the KV cache for (total number of tokens which the KV cache can hold).
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    /// The priority is as follows: `pa-ctxt-len` > `pa-gpu-mem-usage` > `pa-gpu-mem`.
    /// This is the default setting, and it defaults to the `max-seq-len` specified after the model type.
    #[arg(long = "pa-ctxt-len")]
    paged_ctxt_len: Option<usize>,

    /// PagedAttention KV cache type (auto or f8e4m3).
    /// Defaults to `auto`.
    #[arg(long = "pa-cache-type", value_parser = parse_cache_type)]
    cache_type: Option<PagedCacheType>,

    /// Block size (number of tokens per block) for PagedAttention. If this is not set and the device is CUDA, it will default to 32.
    /// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
    #[arg(long = "pa-blk-size")]
    paged_attn_block_size: Option<usize>,

    /// Disable PagedAttention on CUDA. Because PagedAttention is already disabled by default on Metal, this is only applicable on CUDA.
    #[arg(long = "no-paged-attn", default_value_t = false)]
    no_paged_attn: bool,

    /// Enable PagedAttention on Metal. Because PagedAttention is already enabled by default on CUDA, this is only applicable on Metal.
    #[arg(long = "paged-attn", default_value_t = false)]
    paged_attn: bool,
}

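// Example invocations (illustrative only; the binary name, `plain` subcommand, and model ID
// are assumptions — `ModelSelected` defines the actual subcommands, so check `--help`):
//
//   mistralrs-bench -p 512 -g 128 -c 1,2,4 -r 5 plain -m mistralai/Mistral-7B-Instruct-v0.1
//   mistralrs-bench --pa-ctxt-len 4096 plain -m mistralai/Mistral-7B-Instruct-v0.1
//
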
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let mut args = Args::parse();
    initialize_logging();

    args.concurrency = Some(args.concurrency.unwrap_or(vec![1]));

    let dtype = get_model_dtype(&args.model)?;
    let auto_device_map_params = get_auto_device_map_params(&args.model)?;

    let max_seq_len = auto_device_map_params.max_seq_len();

    let loader: Box<dyn Loader> = LoaderBuilder::new(args.model).build()?;
    let model_name = loader.get_id();

    #[cfg(feature = "metal")]
    let device = Device::new_metal(0)?;
    #[cfg(not(feature = "metal"))]
    let device = if mistralrs_core::distributed::use_nccl() {
        Device::Cpu
    } else {
        Device::cuda_if_available(0)?
    };

    if let Some(seed) = args.seed {
        device.set_seed(seed)?;
    }

    let token_source = TokenSource::CacheToken;
    info!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle_core::utils::with_avx(),
        candle_core::utils::with_neon(),
        candle_core::utils::with_simd128(),
        candle_core::utils::with_f16c()
    );
    info!("Sampling method: penalties -> temperature -> topk -> topp -> minp -> multinomial");
    info!("Model kind is: {}", loader.get_kind().to_string());

    // Parse device mapper
    let mapper = if let Some(device_layers) = args.num_device_layers {
        if device_layers.len() == 1 && device_layers[0].parse::<usize>().is_ok() {
            let layers = device_layers[0].parse::<usize>().unwrap();
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
                DeviceLayerMapMetadata { ordinal: 0, layers },
            ]))
        } else {
            let mut mapping = Vec::new();
            for layer in device_layers {
                let split = layer.splitn(2, ':').collect::<Vec<_>>();
                if split.len() < 2 {
                    panic!("Expected layer to be of format ORD:NUM, got {layer}");
                }
                let ord = split[0]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[0]));
                let num = split[1]
                    .parse::<usize>()
                    .unwrap_or_else(|_| panic!("Failed to parse {} as integer.", split[1]));
                for DeviceLayerMapMetadata { ordinal, layers: _ } in &mapping {
                    if *ordinal == ord {
                        panic!("Duplicate ordinal {ord}");
                    }
                }
                mapping.push(DeviceLayerMapMetadata {
                    ordinal: ord,
                    layers: num,
                });
            }
            DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(mapping))
        }
    } else {
        DeviceMapSetting::Auto(auto_device_map_params)
    };

    // PagedAttention is on by default for CUDA, opt-in via `--paged-attn` on Metal,
    // and unavailable on other devices.
    let no_paged_attn = if device.is_cuda() || mistralrs_core::distributed::use_nccl() {
        args.no_paged_attn
    } else if device.is_metal() {
        !args.paged_attn
    } else {
        true
    };

    // Allocate 0.5 GB of CPU memory just as a placeholder.
    // Nothing happens here as we have no `swap_out`, see `_preempt_by_swap`.
    let cache_config = match (
        args.paged_attn_block_size,
        args.paged_attn_gpu_mem,
        args.paged_attn_gpu_mem_usage,
        args.paged_ctxt_len,
        paged_attn_supported(),
        no_paged_attn,
    ) {
        (block_size, None, None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(max_seq_len),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, None, None, Some(ctxt), true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::ContextSize(ctxt),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, None, Some(f), None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::Utilization(f),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, Some(m), None, None, true, false) => Some(PagedAttentionConfig::new(
            block_size,
            512,
            MemoryGpuConfig::MbAmount(m),
            args.cache_type.unwrap_or_default(),
        )?),
        (block_size, Some(_m), Some(f), None, true, false) => {
            info!("Both memory size and usage were specified; defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
                args.cache_type.unwrap_or_default(),
            )?)
        }
        (block_size, Some(_m), None, Some(ctxt), true, false) => {
            info!("Both memory size and context length were specified; defaulting to the context length value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::ContextSize(ctxt),
                args.cache_type.unwrap_or_default(),
            )?)
        }
        (block_size, None, Some(f), Some(_ctxt), true, false) => {
            info!("Both context length and usage were specified; defaulting to the usage value.");
            Some(PagedAttentionConfig::new(
                block_size,
                512,
                MemoryGpuConfig::Utilization(f),
                args.cache_type.unwrap_or_default(),
            )?)
        }
        (_, _, _, _, _, _) => None,
    };

    let isq = args
        .in_situ_quant
        .as_ref()
        .and_then(|isq| parse_isq_value(isq, Some(&device)).ok());

    let pipeline = loader.load_model_from_hf(
        None,
        token_source,
        &dtype,
        &device,
        false,
        mapper,
        isq,
        cache_config,
    )?;
    info!("Model loaded.");

    let scheduler_config = if cache_config.is_some() {
        // Handle case where we may have device mapping
        if let Some(ref cache_config) = pipeline.lock().await.get_metadata().cache_config {
            SchedulerConfig::PagedAttentionMeta {
                max_num_seqs: *args.concurrency.as_ref().unwrap().iter().max().unwrap(),
                config: cache_config.clone(),
            }
        } else {
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(
                    (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                        .try_into()
                        .unwrap(),
                ),
            }
        }
    } else {
        SchedulerConfig::DefaultScheduler {
            method: DefaultSchedulerMethod::Fixed(
                (*args.concurrency.as_ref().unwrap().iter().max().unwrap())
                    .try_into()
                    .unwrap(),
            ),
        }
    };
    let mistralrs = MistralRsBuilder::new(pipeline, scheduler_config, false, None)
        .with_no_prefix_cache(true)
        .with_disable_eos_stop(true)
        .build()
        .await;

    info!("Starting warmup run.");
    warmup_run(mistralrs.clone()).await;
    info!("Finished warmup run.");
    info!("Starting benchmarks.");

    for concurrency in args.concurrency.as_ref().unwrap() {
        let mut results = vec![];
        if args.n_gen > 0 {
            // Token-generation (tg) test: a short text prompt with `n_gen - 1` generated tokens.
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::Completion {
                    text: "Rust".to_string(),
                    echo_prompt: false,
                    best_of: None,
                },
                args.n_gen - 1,
                *concurrency,
                args.repetitions,
                TestName::Gen(args.n_gen),
            )
            .await?;
            results.push(r);
        }

        if args.n_prompt > 0 {
            // Prompt-processing (pp) test: an `n_prompt`-token prompt with a single generated token.
            let tks = (1000..1000 + args.n_prompt as u32).collect();
            let r = run_bench(
                mistralrs.clone(),
                RequestMessage::CompletionTokens(tks),
                1,
                *concurrency,
                args.repetitions,
                TestName::Prompt(args.n_prompt),
            )
            .await?;

            results.push(r);
        }

        print_usage(&model_name, &device, results);
    }

    Ok(())
}