Optimize for Speed and Savings: High-Performance Model Serving Strategies
Achieve Up to 9x Faster & 13x Smaller Model Serving Compared to Naive Setups
Training a machine learning model is just the beginning when it comes to solving a business problem. The next steps involve deploying it effectively in production and ensuring the serving strategy can scale to meet demand.
In this article, we'll delve into different model serving strategies and explore technologies that can enhance their efficiency. We'll walk through building three lightweight model services from scratch and compare their performance in a benchmark test. The implementation focuses on performing inference on CPUs, though the same concepts extend to GPUs, as the technology proposed here (ONNX Runtime) supports various hardware platforms, including GPUs and NPUs.
All source code can be found in the model-serving repository. For readers who are not interested in the technical details, I suggest jumping directly to the “Benchmark Results” and “Conclusions” sections.
Technical Background
Before diving into implementation examples, let's first cover a few technical concepts: Open Neural Network Exchange (ONNX) and ONNX Runtime.
Open Neural Network Exchange
ONNX is a specification1 (standard format) designed to represent machine learning models as computational graphs, providing a common language across different frameworks. It defines necessary operations (operators), data types, and serialization methods (using Protocol Buffers) to enable interoperability and ease of deployment in various environments. The ONNX specification supports extensibility through custom operators and functions and includes tools for model visualization, metadata storage, etc.
ONNX Runtime
ONNX Runtime is a high-performance inference engine designed to execute machine learning models in the ONNX format efficiently across various hardware platforms. It serves as a cross-platform accelerator that enables developers to deploy models trained in different frameworks—such as PyTorch, TensorFlow, Keras, and scikit-learn—into production environments with minimal overhead.
One of the key benefits of ONNX Runtime is its flexible architecture that supports both kernel-based and runtime-based Execution Providers. Kernel-based Execution Providers implement specific ONNX operations optimized for particular hardware (e.g., CPUs with CPUExecutionProvider, GPUs with CUDAExecutionProvider), while runtime-based Execution Providers can execute entire or partial computational graphs using specialized accelerators like TensorRT or nGraph.
Lastly, ONNX Runtime performs various levels of graph optimizations—such as constant folding, node fusions, and redundant node eliminations—that modify the computational graph for faster execution. These optimizations can be applied both online and offline, further reducing computational overhead and improving inference speed. By integrating these capabilities, ONNX Runtime allows for efficient, scalable, and flexible deployment of machine learning models in production environments.
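To make this concrete, here is a minimal sketch of creating an inference session with an explicit execution provider and graph optimization level; the model path and input shape are placeholders, not part of the project:

import numpy as np
import onnxruntime as rt

sess_options = rt.SessionOptions()
# Enable all graph optimizations (constant folding, node fusions, etc.).
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL

# Choose the execution provider explicitly; CPUExecutionProvider is always available.
session = rt.InferenceSession(
    "model.onnx",  # placeholder path
    sess_options,
    providers=["CPUExecutionProvider"],
)

# Run inference on a dummy input matching the model's expected shape.
input_name = session.get_inputs()[0].name
dummy = np.random.rand(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {input_name: dummy})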
Problem Context
With the technical background in place, let's move on to a real-world application: serving a machine learning model in a production environment. We'll build upon the previous article on "ML Training Pipelines," where we developed a model to predict weather conditions. To recap briefly, we fine-tuned the MobileNet V3-small model (~1.53 million parameters) that identifies 11 distinct weather patterns.
Now that the model is “trained”, the next step is to serve it efficiently in a production environment. To streamline integration with the serving application, we can make a few improvements to the training pipeline itself.
Adding Input Transformations to the Model Graph
In the previous article, we saved both PyTorch and ONNX models as Kubeflow Pipelines artifacts for downstream use or direct production deployment. A useful adjustment to this approach is embedding image transformations directly within the model’s computation graph. This provides two key advantages:
Modularity and Simplification: By incorporating input transformations into the model graph, we separate input logic from serving logic, making the setup more modular and easier to integrate. This also minimizes third-party dependencies on the serving side, resulting in leaner Docker images and faster startup times.
Optimized Processing Speed: With input transformations embedded, ONNX Runtime can optimize them as well, further enhancing overall request processing speed.
To implement this improvement, we first need to check which transformations MobileNet_V3_Small_Weights.DEFAULT.transforms() applies:
ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)
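For reference, this listing can be reproduced with a quick check (assuming torchvision is installed):

from torchvision.models import MobileNet_V3_Small_Weights

# Print the preprocessing pipeline bundled with the pretrained weights.
print(MobileNet_V3_Small_Weights.DEFAULT.transforms())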
The next step is to implement this transformer in a way that it can be correctly exported into ONNX format. This typically involves using native PyTorch operations and tensors throughout. Additionally, we need to create a new model that incorporates the transformer as part of its computation graph. Below is an example implementation2:
...
import torch
import torch.nn.functional as F
from torch.nn import Module
from torch.nn.functional import pad
from torchvision.models import MobileNetV3


class ModelWithTransforms(Module):  # type: ignore[misc]
    def __init__(self, model: MobileNetV3) -> None:
        super(ModelWithTransforms, self).__init__()
        self.model = model
        # Normalization constants and target crop size taken from the original
        # MobileNet_V3_Small_Weights.DEFAULT.transforms() pipeline.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.register_buffer("targ_h", torch.tensor(224))
        self.register_buffer("targ_w", torch.tensor(224))

    def transform(self, img: torch.Tensor) -> torch.Tensor:
        # Add batch dimension if needed.
        if img.dim() == 3:
            img = img.unsqueeze(0)
        # Resize, pad if necessary, and center-crop to the 224x224 target size.
        resized = F.interpolate(img, size=256, mode="bilinear", align_corners=False)
        _, _, curr_h, curr_w = resized.shape
        pad_h = torch.clamp(self.targ_h - curr_h, min=0)
        pad_w = torch.clamp(self.targ_w - curr_w, min=0)
        padding = [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
        padded = pad(resized, padding)
        start_h = torch.clamp((curr_h + pad_h - self.targ_h) // 2, min=0)
        start_w = torch.clamp((curr_w + pad_w - self.targ_w) // 2, min=0)
        cropped = padded[..., start_h : start_h + self.targ_h, start_w : start_w + self.targ_w]
        # Normalize with ImageNet statistics.
        normalized = (cropped - self.mean.to(cropped.device)) / self.std.to(cropped.device)
        return normalized

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.transform(x)
        return self.model(x)
Here, model refers to the original trained model. To save the new model with the integrated transformations:
...
model_with_transform = ModelWithTransforms(model)
model_with_transform.to(device)

torch.onnx.export(
    model_with_transform,
    model_input,
    f"{onnx_with_transform_model.path}.onnx",
    opset_version=opset_version,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size", 2: "height", 3: "width"},  # Dynamic batch size, height, and width
        "output": {0: "batch_size"},  # Dynamic batch size for output
    },
)
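After exporting, it's worth sanity-checking that the ONNX graph (with the embedded transform) matches the PyTorch model's outputs. Below is a minimal sketch of such a check, assuming model_input is the same tensor used for the export and that numpy and onnxruntime are available:

import numpy as np
import onnxruntime as rt

# Reference output from the PyTorch model.
model_with_transform.eval()
with torch.inference_mode():
    torch_out = model_with_transform(model_input).cpu().numpy()

# Output from the exported ONNX graph (including the embedded transform).
session = rt.InferenceSession(
    f"{onnx_with_transform_model.path}.onnx",
    providers=["CPUExecutionProvider"],
)
onnx_out = session.run(None, {"input": model_input.cpu().numpy()})[0]

# Small numerical differences between the two runtimes are expected.
print("max abs diff:", np.abs(torch_out - onnx_out).max())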
Performing Offline ONNX Graph Optimizations
As the official ONNX runtime documentation states:
All optimizations can be performed either online or offline. In online mode, when initializing an inference session, we also apply all enabled graph optimizations before performing model inference. Applying all optimizations each time we initiate a session can add overhead to the model startup time (especially for complex models), which can be critical in production scenarios. This is where the offline mode can bring a lot of benefit. In offline mode, after performing graph optimizations, ONNX Runtime serializes the resulting model to disk. Subsequently, we can reduce startup time by using the already optimized model and disabling all optimizations.
Depending on the model size, this optimization can significantly reduce instance start times, improving instance scaling speed in production systems under high loads. The implementation is straightforward—all we need to do is add a small component to the original pipeline that takes the ONNX model with transformations as input. Here's an example of how the implementation might look:
from kfp.dsl import Input, Metrics, Model, Output


def onnx_optimize(
    onnx_with_transform_model: Input[Model],
    optimization_metrics: Output[Metrics],
    optimized_onnx_with_transform_model: Output[Model],
) -> None:
    import time

    import onnxruntime as rt

    start_time = time.time()

    sess_options = rt.SessionOptions()
    # Enable all graph optimizations and serialize the optimized graph to disk.
    sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
    sess_options.optimized_model_filepath = optimized_onnx_with_transform_model.path

    # Creating the session applies the optimizations and writes the optimized model.
    rt.InferenceSession(f"{onnx_with_transform_model.path}.onnx", sess_options)

    optimized_onnx_with_transform_model.framework = (
        f"onnxruntime-{rt.__version__}, graphOptimizationLevel-{str(sess_options.graph_optimization_level)}"
    )
    optimization_metrics.log_metric("timeTakenSeconds", round(time.time() - start_time, 2))
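On the serving side, the pre-optimized model can then be loaded with graph optimizations disabled, which is where the startup-time savings come from. A minimal sketch (the model path is a placeholder):

import onnxruntime as rt

sess_options = rt.SessionOptions()
# The graph was already optimized offline, so skip online optimizations at startup.
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_DISABLE_ALL

session = rt.InferenceSession(
    "optimized_model_with_transforms.onnx",  # placeholder path to the offline-optimized model
    sess_options,
    providers=["CPUExecutionProvider"],
)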
After this step, the optimized ONNX model will be ready for deployment in production. As highlighted in the official documentation, a critical consideration when selecting the offline optimization approach is:
When running in offline mode, make sure to use the exact same options (e.g., execution providers, optimization level) and hardware as the target machine that the model inference will run on (e.g., you cannot run a model pre-optimized for a GPU execution provider on a machine that is equipped only with CPU).
When layout optimizations are enabled, the offline mode can only be used on compatible hardware to the environment when the offline model is saved. For example, if model has layout optimized for AVX2, the offline model would require CPUs that support AVX2.
Model Serving Strategies
With the optimized model ready, we can start building model-serving applications. In this article, we’ll benchmark three different serving strategies to compare their performance:
Naive Model Serving with PyTorch and FastAPI (Python): This setup uses PyTorch with model.eval() and torch.inference_mode() enabled. No ONNX or ONNX Runtime optimizations are applied; instead, we serve the model directly from its saved state_dict after training. Although this approach is less optimized, it remains common in practice, with Flask or Django being possible alternatives to FastAPI, making it a valuable baseline for our benchmarks.

Optimized Model Serving with ONNX Runtime and FastAPI (Python): In this approach, we leverage ONNX Runtime for serving. Input transformation logic is embedded directly into the model’s computation graph, and graph optimizations are applied offline, providing a more efficient alternative to the naive approach.

Optimized Model Serving with ONNX Runtime and Actix-Web (Rust): Here, we use a Rust-based setup with ONNX Runtime (built from source and utilizing the pykeio/ort wrapper) and Actix-Web for serving. Like the previous setup, input transformation logic is embedded in the model graph, and offline graph optimizations are applied, aiming for maximum performance.
Benchmark Setup
When interpreting benchmark results, avoid treating them as universally applicable values, as absolute performance can vary significantly with different hardware, operating systems (OS), and C standard library implementations (e.g., glibc or musl), which affect the Application Binary Interface (ABI).
Furthermore, performance metrics can differ based on the sizes of the input images; therefore, in a production environment, it would be important to understand the distribution of image sizes. For the purposes of this exercise, the focus should be on the relative performance differences between different serving strategies.
The most reliable way to assess model service performance on a specific host machine is to conduct direct testing in that environment.
Host System
Hardware: Apple M2 Max
OS: macOS 15.0.1
Docker:
    Engine v27.2.0
    Desktop 4.34.3
Containers
CPU Allocation: Each container was allocated 4 CPU cores.
Memory Allocation: Memory was allocated dynamically, providing each container with as much memory as it required.
Worker and Thread Configuration: Workers and threads were configured to fully utilize each container's CPU allocation (up to 400% usage, corresponding to 4 CPU cores), while CPU oversubscription was closely monitored and prevented. The following settings were used (the ONNX Runtime thread options map to session settings, as sketched after the list):
onnx_serving:
    Uvicorn Workers: 4
    ONNX Runtime Session Threads:
torch_serving:
    Uvicorn Workers: 4
rust_onnx_serving:
    Actix Web Workers: 4
    ONNX Runtime Session Threads:
        Intra-Op Threads: 3
        Inter-Op Threads: 1
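For reference, in the Python services the intra-op and inter-op thread counts map to ONNX Runtime SessionOptions fields; the sketch below uses the rust_onnx_serving values (3 and 1) purely for illustration (the Rust ort crate exposes analogous settings on its session builder):

import onnxruntime as rt

sess_options = rt.SessionOptions()
# Threads used to parallelize work within individual operators.
sess_options.intra_op_num_threads = 3
# Threads used to run independent parts of the graph in parallel.
sess_options.inter_op_num_threads = 1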
Benchmark Configuration
Benchmarking tool: ApacheBench (ab).
ab -n 40000 -c 50 -p images/rime_5868.json -T 'application/json' -s 3600 "http://localhost:$port/predict/"
-n 40000: a total of 40,000 requests.
-c 50: a concurrency level of 50.
Payload image images/rime_5868.jpg:
    Original size: 393 KB.
    Payload size after PIL compression and base64 encoding (~33% increase): 304 KB.
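For context, a payload file like images/rime_5868.json can be produced along these lines (a sketch only; the exact JSON schema and field name, e.g. "image", depend on the service contract and are assumptions here):

import base64
import io
import json

from PIL import Image

# Load the benchmark image, re-encode it with PIL (JPEG compression), and base64-encode the bytes.
img = Image.open("images/rime_5868.jpg").convert("RGB")
buffer = io.BytesIO()
img.save(buffer, format="JPEG")
encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")

# Write the JSON body that ab posts to the /predict/ endpoint.
with open("images/rime_5868.json", "w") as f:
    json.dump({"image": encoded}, f)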
Implementations
Due to the volume of code involved in model serving, I’ll provide links to the corresponding GitHub repository directories. This keeps the Substack article compact while allowing readers to view the code with GitHub’s syntax highlighting and clear project structure. A simplified sketch of the ONNX Runtime/FastAPI variant follows the links below.
Naive Model Serving Using PyTorch/FastAPI
Model Serving Using ONNX-Runtime/FastAPI (Python)
Model Serving Using ONNX-Runtime/Actix Web (Rust)
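To give a feel for the overall shape without reproducing the repositories, here is a heavily simplified sketch of the ONNX Runtime/FastAPI variant (not the repository code; the payload field name, response format, and model path are assumptions):

import base64
import io

import numpy as np
import onnxruntime as rt
from fastapi import FastAPI
from PIL import Image
from pydantic import BaseModel

app = FastAPI()

sess_options = rt.SessionOptions()
# The model is assumed to be pre-optimized offline, so online graph optimizations are disabled.
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_DISABLE_ALL
session = rt.InferenceSession("model.onnx", sess_options, providers=["CPUExecutionProvider"])


class PredictRequest(BaseModel):
    image: str  # base64-encoded image bytes


@app.post("/predict/")
def predict(request: PredictRequest) -> dict:
    # Decode the image and convert it to a float NCHW tensor in [0, 1];
    # resizing, cropping, and normalization happen inside the model graph.
    img = Image.open(io.BytesIO(base64.b64decode(request.image))).convert("RGB")
    array = np.asarray(img, dtype=np.float32) / 255.0
    array = np.ascontiguousarray(np.transpose(array, (2, 0, 1))[np.newaxis, ...])

    logits = session.run(None, {"input": array})[0]
    return {"class_index": int(np.argmax(logits, axis=1)[0])}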
Benchmark Results
Performance Metrics
Deployment Metrics
Conclusions
ONNX Runtime Significantly Improves Performance: Converting models to ONNX and serving them with ONNX Runtime greatly enhances throughput and reduces latency compared to serving with PyTorch. Specifically:
onnx-serving (Python) handles approximately 7.18 times more requests per second than torch-serving (255.53 vs. 35.62 requests/sec).
rust-onnx-serving (Rust) achieves about 9.23 times higher throughput than torch-serving (328.94 vs. 35.62 requests/sec).
Rust Implementation Delivers Highest Performance: Despite higher memory usage than Python ONNX serving, the Rust implementation offers higher performance and advantages in deployment size and startup time:
Throughput: rust-onnx-serving is about 1.29 times faster than onnx-serving (328.94 vs. 255.53 requests/sec).
Startup Time: The Rust application starts in 0.348 seconds, over 12 times faster than torch-serving (4.342 seconds) and about 4 times faster than onnx-serving (1.396 seconds).
Docker Image Size: The Rust image is 48.3 MB, approximately 13 times smaller than torch-serving (650 MB) and about 6 times smaller than onnx-serving (296 MB).
Memory Usage Difference: The higher memory usage in Rust compared to Python ONNX serving stems from differences in implementations and libraries used:
Image Processing Differences: The Rust implementation uses less optimized image processing compared to Python's PIL and NumPy libraries, leading to higher memory consumption.
Library Efficiency: The Rust ort crate is an unofficial wrapper and might manage memory differently than the official ONNX Runtime SDK for Python, which is mature and highly optimized.
Threading Configuration: The Rust implementation uses more intra-op threads, which contributes some additional memory consumption; however, this accounts for only a small portion of the overall difference observed.
The last memory point is really a consequence of a more important factor: Python’s mature and extensive ecosystem for machine learning. Rewriting these serving components in Rust can introduce challenges, such as increased development effort, potential performance trade-offs where optimized crates are unavailable (or have to be written), and added complexity. However, Rust's benefits may still justify the effort, depending on specific business needs.
Using inference-optimized solutions like ONNX Runtime can significantly enhance model serving performance, especially for larger models. While this article uses a small model (MobileNet V3-small, ~1.53 million parameters), the benefits of ONNX Runtime become more pronounced with more complex architectures. Its ability to optimize computation graphs and streamline resource usage leads to higher throughput and reduced latency, making it invaluable for scaling model-serving solutions.
Some literature refers to ONNX as an intermediate representation (IR) of models.
When implementing transformations like this, always ensure they perform identically to the original transformation you're replicating. Comparing the before-and-after images for both transformations can be a helpful validation step.