Skip to content

Simple test script to verify multi-model functionality in mistral.rs

Simple test script to verify multi-model functionality in mistral.rs

This script tests the core multi-model operations:

  • Listing models
  • Getting/setting default model
  • Sending requests to specific models
  • Model unloading/reloading
  • Model removal (commented out for safety)
"""
Simple test script to verify multi-model functionality in mistral.rs
This script tests the core multi-model operations:
- Listing models
- Getting/setting default model
- Sending requests to specific models
- Model unloading/reloading
- Model removal (commented out for safety)
"""
from mistralrs import (
Runner,
Which,
ChatCompletionRequest,
Architecture,
)
import sys
def test_multi_model_operations():
"""Test basic multi-model operations."""
print("Testing Multi-Model Operations\n" + "=" * 50)
try:
# Create a simple runner
print("1. Creating runner with GPT-2 model...")
runner = Runner(
which=Which.Plain(
model_id="gpt2",
arch=Architecture.Gpt2,
)
)
print(" ✓ Runner created successfully")
# Test listing models
print("\n2. Testing list_models()...")
models = runner.list_models()
print(f" ✓ Models found: {models}")
assert isinstance(models, list), "list_models should return a list"
assert len(models) > 0, "Should have at least one model"
# Test getting default model
print("\n3. Testing get_default_model_id()...")
default_model = runner.get_default_model_id()
print(f" ✓ Default model: {default_model}")
# Test setting default model (if we have multiple models)
if len(models) > 1:
print("\n4. Testing set_default_model_id()...")
new_default = models[1] if models[0] == default_model else models[0]
runner.set_default_model_id(new_default)
updated_default = runner.get_default_model_id()
print(f" ✓ Changed default from '{default_model}' to '{updated_default}'")
assert updated_default == new_default, "Default model should have changed"
else:
print("\n4. Skipping set_default_model_id() test (only one model loaded)")
# Test sending request with model_id
print("\n5. Testing send_chat_completion_request with model_id...")
messages = [
{"role": "user", "content": "Say 'test successful' and nothing else."}
]
request = ChatCompletionRequest(
messages=messages, model="default", max_tokens=10
)
if models:
response = runner.send_chat_completion_request(
request=request, model_id=models[0]
)
print(f" ✓ Response received: {response.choices[0].message.content}")
# Test list_models_with_status
print("\n6. Testing list_models_with_status()...")
models_with_status = runner.list_models_with_status()
print(f" ✓ Models with status: {models_with_status}")
assert isinstance(models_with_status, list), "Should return a list"
# Test is_model_loaded
print("\n7. Testing is_model_loaded()...")
if models:
is_loaded = runner.is_model_loaded(models[0])
print(f" ✓ Model '{models[0]}' loaded: {is_loaded}")
assert is_loaded, "Model should be loaded initially"
print("\n✅ All tests passed!")
except Exception as e:
print(f"\n❌ Test failed with error: {e}")
import traceback
traceback.print_exc()
return False
return True
def test_model_id_in_requests():
"""Test that model_id is properly passed in requests."""
print("\n\nTesting Model ID in Requests\n" + "=" * 50)
try:
runner = Runner(
which=Which.Plain(
model_id="gpt2",
arch=Architecture.Gpt2,
)
)
models = runner.list_models()
if not models:
print("No models available to test")
return False
model_id = models[0]
print(f"Using model: {model_id}")
# Test different request types with model_id
messages = [{"role": "user", "content": "Hi"}]
# Chat completion
print("\n1. Testing chat completion with model_id...")
request = ChatCompletionRequest(
messages=messages, model="default", max_tokens=5
)
response = runner.send_chat_completion_request(request, model_id=model_id)
print(f" ✓ Chat response: {response.choices[0].message.content}")
# Without model_id (should use default)
print("\n2. Testing chat completion without model_id...")
response = runner.send_chat_completion_request(request)
print(f" ✓ Chat response (default): {response.choices[0].message.content}")
print("\n✅ Model ID request tests passed!")
return True
except Exception as e:
print(f"\n❌ Test failed with error: {e}")
import traceback
traceback.print_exc()
return False
def test_unload_reload():
"""Test model unloading and reloading."""
print("\n\nTesting Model Unload/Reload\n" + "=" * 50)
try:
runner = Runner(
which=Which.Plain(
model_id="gpt2",
arch=Architecture.Gpt2,
)
)
models = runner.list_models()
if not models:
print("No models available to test")
return False
model_id = models[0]
# Check initial status
print("1. Checking initial status...")
assert runner.is_model_loaded(model_id), "Model should be loaded initially"
print(f" ✓ Model '{model_id}' is loaded")
# Unload the model
print("\n2. Unloading model...")
runner.unload_model(model_id)
is_loaded = runner.is_model_loaded(model_id)
print(f" ✓ Model unloaded. is_model_loaded: {is_loaded}")
# Check status after unload
print("\n3. Checking status after unload...")
status = runner.list_models_with_status()
print(f" ✓ Status: {status}")
# Reload the model
print("\n4. Reloading model...")
runner.reload_model(model_id)
is_loaded = runner.is_model_loaded(model_id)
print(f" ✓ Model reloaded. is_model_loaded: {is_loaded}")
assert is_loaded, "Model should be loaded after reload"
print("\n✅ Unload/reload tests passed!")
return True
except Exception as e:
print(f"\n❌ Test failed with error: {e}")
import traceback
traceback.print_exc()
return False
def test_error_handling():
"""Test error handling for multi-model operations."""
print("\n\nTesting Error Handling\n" + "=" * 50)
try:
runner = Runner(
which=Which.Plain(
model_id="gpt2",
arch=Architecture.Gpt2,
)
)
# Test with non-existent model
print("1. Testing request to non-existent model...")
messages = [{"role": "user", "content": "Hi"}]
request = ChatCompletionRequest(messages=messages, model="default")
try:
runner.send_chat_completion_request(request, model_id="non-existent-model")
print(" ❌ Should have raised an error for non-existent model")
except Exception as e:
print(f" ✓ Correctly raised error: {type(e).__name__}")
# Test setting non-existent model as default
print("\n2. Testing set_default_model_id with non-existent model...")
try:
runner.set_default_model_id("non-existent-model")
print(" ❌ Should have raised an error")
except Exception as e:
print(f" ✓ Correctly raised error: {type(e).__name__}")
print("\n✅ Error handling tests passed!")
return True
except Exception as e:
print(f"\n❌ Test setup failed with error: {e}")
return False
if __name__ == "__main__":
print("mistral.rs Multi-Model Test Suite")
print("=" * 60)
all_passed = True
# Run tests
all_passed &= test_multi_model_operations()
all_passed &= test_model_id_in_requests()
all_passed &= test_unload_reload()
all_passed &= test_error_handling()
print("\n" + "=" * 60)
if all_passed:
print("✅ All tests passed!")
sys.exit(0)
else:
print("❌ Some tests failed!")
sys.exit(1)

Source: examples/python/test_multi_model.py