Deploying yolort on ONNX Runtime

The ONNX model exported by yolort differs from those of other pipelines in the following three ways.

  • We embed the pre-processing (mainly letterboxing) into the graph, so the exported model expects a Tensor[C, H, W] in RGB channel order, rescaled to float32 in the range [0, 1].

  • We embed the post-processing into the model graph with torchvision.ops.batched_nms, so the outputs of the exported model are directly the boxes, labels and scores of the image.

  • We export the ONNX models with dynamic shapes, so input images do not need to have a fixed size.
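For reference, the geometry of a yolov5-style letterbox can be sketched in a few lines of pure Python: scale the image by a single ratio so it fits inside the target size while keeping its aspect ratio, then split the leftover space evenly as padding. This is an illustration under that assumption, not yolort's in-graph implementation; the function name `letterbox_params` is hypothetical, and yolort's actual rounding and padding rules may differ.

```python
def letterbox_params(img_h, img_w, target_h, target_w):
    """Compute the resize ratio, resized shape and padding of a letterbox (sketch)."""
    # One ratio for both axes keeps the aspect ratio unchanged
    ratio = min(target_h / img_h, target_w / img_w)
    new_h = round(img_h * ratio)
    new_w = round(img_w * ratio)
    # Remaining space is split (almost) evenly between the two sides
    pad_h = target_h - new_h
    pad_w = target_w - new_w
    top, left = pad_h // 2, pad_w // 2
    return {
        "ratio": ratio,
        "resized": (new_h, new_w),
        "pad": (top, pad_h - top, left, pad_w - left),  # (top, bottom, left, right)
    }


# A portrait 810 x 1080 (W x H) image letterboxed into 640 x 640
params = letterbox_params(1080, 810, 640, 640)
print(params["resized"], params["pad"])  # (640, 480) (0, 0, 80, 80)
```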

Set up environment and function utilities

First, install ONNX Runtime to run this tutorial. See the ONNX Runtime installation matrix for the recommended instructions for your combination of target operating system, hardware, accelerator, and language.

A quick solution is to install it via pip on x64:

pip install onnxruntime
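To confirm what is installed before running the rest of the notebook, a small sketch using the standard library's importlib.metadata (Python 3.8+) can report package versions. The helper name `installed_versions` is hypothetical. Note that the GPU build of ONNX Runtime is published as a separate onnxruntime-gpu distribution, so query that name if you installed it instead.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


print(installed_versions(["onnx", "onnxruntime", "torch"]))
```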
[1]:
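import os
import torch

# Pin CUDA device numbering to PCI bus order and expose only the first GPU
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# This tutorial runs the PyTorch model on CPU
device = torch.device('cpu')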
[2]:
import cv2
import onnx
import onnxruntime

from yolort.models import YOLOv5
from yolort.v5 import attempt_download

from yolort.utils import get_image_from_url, read_image_to_tensor
from yolort.utils.image_utils import to_numpy

Define the parameters used to build the model, export it to ONNX, and run inference on ONNX Runtime.

[3]:
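img_size = 640
size = (img_size, img_size)  # Used for pre-processing
size_divisible = 64
score_thresh = 0.35  # Minimum confidence score for a detection to be kept
nms_thresh = 0.45  # IoU threshold used by non-maximum suppression
opset_version = 11  # ONNX opset version used for export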
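Assuming `size_divisible` works like the padding step of torchvision's GeneralizedRCNNTransform (yolort's transform in yolort/models/transform.py appears to follow a similar design, but check the source to confirm), each spatial side of the batched image is rounded up to the nearest multiple of it. A value of 64 matches the largest feature stride of a P6 model such as yolov5n6. The helper names below are hypothetical.

```python
import math


def round_up_to_multiple(value, divisor):
    """Round `value` up to the nearest multiple of `divisor`."""
    return int(math.ceil(value / divisor) * divisor)


def padded_shape(height, width, size_divisible=64):
    """Spatial shape after padding both sides up to a multiple of `size_divisible`."""
    return (round_up_to_multiple(height, size_divisible),
            round_up_to_multiple(width, size_divisible))


print(padded_shape(640, 480))  # (640, 512)
print(padded_shape(360, 640))  # (384, 640)
```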

Get the images for inference.

[4]:
img_src1 = "https://huggingface.co/spaces/zhiqwang/assets/resolve/main/bus.jpg"
img_one = get_image_from_url(img_src1)
img_one = read_image_to_tensor(img_one, is_half=False)
img_one = img_one.to(device)

img_src2 = "https://huggingface.co/spaces/zhiqwang/assets/resolve/main/zidane.jpg"
img_two = get_image_from_url(img_src2)
img_two = read_image_to_tensor(img_two, is_half=False)
img_two = img_two.to(device)
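The notebook relies on read_image_to_tensor for this step. As a sketch of the input layout described at the top (Tensor[C, H, W], RGB, float32 in [0, 1]), a NumPy equivalent for an image decoded by OpenCV, which returns HWC arrays in BGR order, could look like this. The helper name `bgr_hwc_to_rgb_chw` is hypothetical, and this is not necessarily what read_image_to_tensor does internally.

```python
import numpy as np


def bgr_hwc_to_rgb_chw(image):
    """Convert an HWC uint8 BGR image to a CHW float32 RGB array in [0, 1]."""
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(f"expected an HxWx3 image, got shape {image.shape}")
    rgb = image[:, :, ::-1]       # BGR -> RGB by reversing the channel axis
    chw = rgb.transpose(2, 0, 1)  # HWC -> CHW
    return np.ascontiguousarray(chw, dtype=np.float32) / 255.0


fake = np.zeros((4, 5, 3), dtype=np.uint8)
fake[..., 2] = 255  # pure red, since channel 2 is red in BGR order
out = bgr_hwc_to_rgb_chw(fake)
print(out.shape, out.dtype, out[0].max())  # (3, 4, 5) float32 1.0
```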

Load the model trained with yolov5

The model used below is an official yolov5 release trained on the COCO 2017 dataset.

[5]:
# yolov5n6.pt is downloaded from 'https://github.com/ultralytics/yolov5/releases/download/v6.0/yolov5n6.pt'
model_path = "yolov5n6.pt"
onnx_path = "yolov5n6.onnx"
checkpoint_path = attempt_download(model_path)
[6]:
model = YOLOv5.load_from_yolov5(
    checkpoint_path,
    size=size,
    size_divisible=size_divisible,
    score_thresh=score_thresh,
    nms_thresh=nms_thresh,
)

model = model.eval()
model = model.to(device)

                 from  n    params  module                                  arguments
  0                -1  1      1760  yolort.v5.models.common.Conv            [3, 16, 6, 2, 2]
  1                -1  1      4672  yolort.v5.models.common.Conv            [16, 32, 3, 2]
  2                -1  1      4800  yolort.v5.models.common.C3              [32, 32, 1]
  3                -1  1     18560  yolort.v5.models.common.Conv            [32, 64, 3, 2]
  4                -1  2     29184  yolort.v5.models.common.C3              [64, 64, 2]
  5                -1  1     73984  yolort.v5.models.common.Conv            [64, 128, 3, 2]
  6                -1  3    156928  yolort.v5.models.common.C3              [128, 128, 3]
  7                -1  1    221568  yolort.v5.models.common.Conv            [128, 192, 3, 2]
  8                -1  1    167040  yolort.v5.models.common.C3              [192, 192, 1]
  9                -1  1    442880  yolort.v5.models.common.Conv            [192, 256, 3, 2]
 10                -1  1    296448  yolort.v5.models.common.C3              [256, 256, 1]
 11                -1  1    164608  yolort.v5.models.common.SPPF            [256, 256, 5]
 12                -1  1     49536  yolort.v5.models.common.Conv            [256, 192, 1, 1]
 13                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 14           [-1, 8]  1         0  yolort.v5.models.common.Concat          [1]
 15                -1  1    203904  yolort.v5.models.common.C3              [384, 192, 1, False]
 16                -1  1     24832  yolort.v5.models.common.Conv            [192, 128, 1, 1]
 17                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 18           [-1, 6]  1         0  yolort.v5.models.common.Concat          [1]
 19                -1  1     90880  yolort.v5.models.common.C3              [256, 128, 1, False]
 20                -1  1      8320  yolort.v5.models.common.Conv            [128, 64, 1, 1]
 21                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 22           [-1, 4]  1         0  yolort.v5.models.common.Concat          [1]
 23                -1  1     22912  yolort.v5.models.common.C3              [128, 64, 1, False]
 24                -1  1     36992  yolort.v5.models.common.Conv            [64, 64, 3, 2]
 25          [-1, 20]  1         0  yolort.v5.models.common.Concat          [1]
 26                -1  1     74496  yolort.v5.models.common.C3              [128, 128, 1, False]
 27                -1  1    147712  yolort.v5.models.common.Conv            [128, 128, 3, 2]
 28          [-1, 16]  1         0  yolort.v5.models.common.Concat          [1]
 29                -1  1    179328  yolort.v5.models.common.C3              [256, 192, 1, False]
 30                -1  1    332160  yolort.v5.models.common.Conv            [192, 192, 3, 2]
 31          [-1, 12]  1         0  yolort.v5.models.common.Concat          [1]
 32                -1  1    329216  yolort.v5.models.common.C3              [384, 256, 1, False]
 33  [23, 26, 29, 32]  1    164220  yolort.v5.models.yolo.Detect            [80, [[19, 27, 44, 40, 38, 94], [96, 68, 86, 152, 180, 137], [140, 301, 303, 264, 238, 542], [436, 615, 739, 380, 925, 792]], [64, 128, 192, 256]]
/opt/conda/lib/python3.8/site-packages/torch/functional.py:445: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2157.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Model Summary: 355 layers, 3246940 parameters, 3246940 gradients, 4.6 GFLOPs
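The Detect row of the summary lists 80 classes and four groups of six anchor values, i.e. three (w, h) anchors for each of four output levels. Assuming the usual yolov5 P6 strides of 8, 16, 32 and 64, a quick calculation shows how many raw candidate boxes (each with 4 box values, 1 objectness score and 80 class scores) the head produces before the embedded NMS, and why this number changes with the input shape. The function name is hypothetical.

```python
def num_candidates(height, width, strides=(8, 16, 32, 64), anchors_per_level=3):
    """Count raw predictions of a YOLOv5-style head (assumes sides divisible by each stride)."""
    total = 0
    for stride in strides:
        grid_h, grid_w = height // stride, width // stride
        total += grid_h * grid_w * anchors_per_level
    return total


print(num_candidates(640, 640))  # 25500
print(num_candidates(640, 512))  # 20400
```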

Inference on PyTorch backend

[7]:
images = [img_one]
[8]:
with torch.no_grad():
    model_out = model(images)
[9]:
%%timeit
with torch.no_grad():
    model_out = model(images)
The slowest run took 5.09 times longer than the fastest. This could mean that an intermediate result is being cached.
115 ms ± 71 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
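`%%timeit` is an IPython magic and is not available in a plain Python script. A roughly equivalent measurement can be made with the standard timeit module; the sketch below reports the mean and standard deviation per call over several repeats. Because the first calls can be noticeably slower (as the caching warning above hints), it makes an untimed warm-up call first. The helper name `benchmark` is hypothetical; in this notebook it could be called as `benchmark(lambda: model(images))` inside `torch.no_grad()`, or as `benchmark(lambda: y_runtime.predict(inputs))`.

```python
import statistics
import timeit


def benchmark(fn, number=10, repeat=7, warmup=1):
    """Time `fn` similarly to %%timeit: `repeat` runs of `number` calls each.

    Returns the mean and standard deviation of the per-call time in milliseconds.
    """
    for _ in range(warmup):
        fn()  # untimed warm-up calls
    runs = timeit.repeat(fn, number=number, repeat=repeat)
    per_call_ms = [total / number * 1000.0 for total in runs]
    return statistics.mean(per_call_ms), statistics.stdev(per_call_ms)


mean_ms, std_ms = benchmark(lambda: sum(range(10_000)))
print(f"{mean_ms:.3f} ms ± {std_ms:.3f} ms per call")
```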
[10]:
model_out[0]['boxes']
[10]:
tensor([[ 32.27846, 225.15266, 811.47729, 740.91071],
        [ 50.42178, 387.48898, 241.54399, 897.61041],
        [219.03331, 386.14346, 345.77689, 869.02582],
        [678.05023, 374.65326, 809.80334, 874.80621]])
[11]:
model_out[0]['scores']
[11]:
tensor([0.88238, 0.84486, 0.72629, 0.70077])
[12]:
model_out[0]['labels']
[12]:
tensor([5, 0, 0, 0])
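These boxes, scores and labels are already the output of the NMS embedded in the graph via torchvision.ops.batched_nms, which suppresses overlapping boxes only among boxes that share the same label. The pure-Python sketch below illustrates that behaviour with greedy NMS applied per label. It is an illustration, not torchvision's implementation, and it does not apply score_thresh, which the model presumably handles as a separate filtering step.

```python
def iou(a, b):
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def batched_nms(boxes, scores, labels, iou_threshold):
    """Greedy NMS applied independently per label; returns kept indices by decreasing score."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        # Drop box i if a higher-scoring kept box with the same label overlaps it too much
        if all(labels[j] != labels[i] or iou(boxes[i], boxes[j]) <= iou_threshold
               for j in keep):
            keep.append(i)
    return keep


boxes = [(0, 0, 10, 10), (1, 1, 10, 10), (0, 0, 10, 10)]
scores = [0.9, 0.8, 0.7]
labels = [0, 0, 1]
print(batched_nms(boxes, scores, labels, iou_threshold=0.45))  # [0, 2]
```

Box 1 overlaps box 0 with an IoU of 0.81 and shares its label, so it is suppressed; box 2 is identical to box 0 but has a different label, so it survives.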

Export the model to ONNX

[13]:
from yolort.runtime.ort_helper import export_onnx
[14]:
print(f'We are using opset version: {opset_version}')
We are using opset version: 11
[15]:
export_onnx(model=model, onnx_path=onnx_path, opset_version=opset_version)
/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py:3701: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  (torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float()))
/coding/yolort/yolort/models/transform.py:282: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  img_h, img_w = _get_shape_onnx(img)
/coding/yolort/yolort/models/anchor_utils.py:45: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  anchors = torch.as_tensor(self.anchor_grids, dtype=torch.float32, device=device).to(dtype=dtype)
/coding/yolort/yolort/models/anchor_utils.py:46: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  strides = torch.as_tensor(self.strides, dtype=torch.float32, device=device).to(dtype=dtype)
/coding/yolort/yolort/models/box_head.py:402: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  strides = torch.as_tensor(self.strides, dtype=torch.float32, device=device).to(dtype=dtype)
/coding/yolort/yolort/models/box_head.py:333: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  for head_output, grid, shift, stride in zip(head_outputs, grids, shifts, strides):
/opt/conda/lib/python3.8/site-packages/torch/onnx/symbolic_opset9.py:2815: UserWarning: Exporting aten::index operator of advanced indexing in opset 11 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.
  warnings.warn("Exporting aten::index operator of advanced indexing in opset " +

Check that the exported ONNX model is well formed

[16]:
# Load the ONNX model
onnx_model = onnx.load(onnx_path)

# Check that the model is well formed
onnx.checker.check_model(onnx_model)

# Print a human readable representation of the graph
# print(onnx.helper.printable_graph(onnx_model.graph))

Inference on ONNX Runtime backend

Check the version of ONNX Runtime first.

[17]:
print(f'Starting with onnx {onnx.__version__}, onnxruntime {onnxruntime.__version__}...')
Starting with onnx 1.10.2, onnxruntime 1.10.0...

Prepare the inputs for ONNX Runtime.

[18]:
inputs, _ = torch.jit._flatten(images)
outputs, _ = torch.jit._flatten(model_out)
[19]:
inputs = list(map(to_numpy, inputs))
outputs = list(map(to_numpy, outputs))
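torch.jit._flatten is a private PyTorch helper, so its exact behaviour may change between releases. Roughly, it turns a nested structure (a list of images, or a list of dicts of tensors) into a flat list, which matches the positional form of ONNX Runtime's inputs and outputs. The comparison loops further down rely on the PyTorch and ONNX Runtime lists lining up by position. A pure-Python analogue, assuming dicts are traversed in insertion order, could look like this; the strings stand in for tensors and the key order is only illustrative.

```python
def flatten(obj):
    """Recursively flatten nested lists, tuples and dicts into a flat list of leaves."""
    if isinstance(obj, (list, tuple)):
        return [leaf for item in obj for leaf in flatten(item)]
    if isinstance(obj, dict):
        return [leaf for value in obj.values() for leaf in flatten(value)]
    return [obj]


detections = [{"boxes": "B", "scores": "S", "labels": "L"}]  # placeholders for tensors
print(flatten(detections))  # ['B', 'S', 'L']
```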

We provide PredictorORT, a pipeline for deploying yolort with ONNX Runtime.

[20]:
from yolort.runtime import PredictorORT
[21]:
y_runtime = PredictorORT(onnx_path, device="cpu")
Providers was initialized.
Set inference device to CPU
[22]:
ort_outs1 = y_runtime.predict(inputs)

Let’s measure the inference speed of ONNX Runtime.

[23]:
%%timeit
y_runtime.predict(inputs)
47.7 ms ± 614 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Verify whether the inference results are consistent with PyTorch’s.

[24]:
for i in range(0, len(outputs)):
    torch.testing.assert_allclose(outputs[i], ort_outs1[i], rtol=1e-04, atol=1e-07)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")
Exported model has been tested with ONNXRuntime, and the result looks good!
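The check above passes when every element satisfies |actual - expected| <= atol + rtol * |expected|, the same criterion as torch.isclose and numpy.isclose, with the ONNX Runtime output playing the role of expected. With rtol=1e-04 and atol=1e-07, a box coordinate around 800 may differ by up to about 0.08 pixels. A minimal pure-Python sketch of the element-wise test on flat sequences (the helper name is hypothetical):

```python
def allclose(actual, expected, rtol=1e-4, atol=1e-7):
    """Element-wise |actual - expected| <= atol + rtol * |expected| on flat sequences."""
    if len(actual) != len(expected):
        return False  # e.g. the two backends kept a different number of boxes
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


print(allclose([811.47729], [811.44]))  # True: difference 0.037 is within about 0.081
print(allclose([811.47729], [811.30]))  # False: difference 0.177 exceeds about 0.081
```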

Verify another image

When exporting with dynamic shapes in tracing mode, shape inference may not work correctly for some operators, so we also verify the model on a second image with a different shape.

[25]:
images = [img_two]
[26]:
with torch.no_grad():
    out_pytorch = model(images)
[27]:
inputs, _ = torch.jit._flatten(images)
outputs, _ = torch.jit._flatten(out_pytorch)
[28]:
inputs = list(map(to_numpy, inputs))
outputs = list(map(to_numpy, outputs))

Compute the ONNX Runtime predictions.

[29]:
ort_outs2 = y_runtime.predict(inputs)

Let’s measure the inference speed of ONNX Runtime.

[30]:
%%timeit
y_runtime.predict(inputs)
37.5 ms ± 767 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Verify whether the inference results are consistent with PyTorch’s.

[31]:
for i in range(0, len(outputs)):
    torch.testing.assert_allclose(outputs[i], ort_outs2[i], rtol=1e-04, atol=1e-07)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")
Exported model has been tested with ONNXRuntime, and the result looks good!

View this document as a notebook: https://github.com/zhiqwang/yolort/blob/main/notebooks/export-onnx-inference-onnxruntime.ipynb