Deploy the Pretrained Model on Adreno

Author: Daniil Barinov

This article is a step-by-step tutorial on deploying a pretrained PyTorch ResNet-18 model on an Adreno GPU at different precisions.

To begin, PyTorch must be installed. TorchVision is also required, since we use it as our model zoo.

A quick way to install both is via pip:

pip install torch
pip install torchvision

Besides that, you should have TVM built for Android. See the following instructions on how to build it.

Deploy to Adreno GPU

After completing the build, there should be two files in the build directory: «libtvm_runtime.so» and «tvm_rpc». Let’s push them to the device and run the TVM RPC server.

TVM RPC Server

To get the hash (serial number) of the device, use:

adb devices
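If you prefer to script this step, the serial can also be extracted in Python. This is a minimal sketch that assumes the usual `adb devices` output format (a header line followed by `<serial>\t<state>` rows). `parse_adb_devices`, `list_device_hashes`, and the serial `1a2b3c4d` are hypothetical names used only for illustration:

```python
import subprocess
from typing import List


def parse_adb_devices(output: str) -> List[str]:
    """Return the serials ("hashes") of devices in the 'device' state."""
    serials = []
    for line in output.splitlines():
        line = line.strip()
        # Skip the header line and blank lines
        if not line or line.startswith("List of devices"):
            continue
        parts = line.split()
        # Rows look like "<serial>\t<state>"; ignore offline/unauthorized devices
        if len(parts) >= 2 and parts[1] == "device":
            serials.append(parts[0])
    return serials


def list_device_hashes() -> List[str]:
    # Requires adb on PATH
    result = subprocess.run(["adb", "devices"], stdout=subprocess.PIPE,
                            universal_newlines=True, check=True)
    return parse_adb_devices(result.stdout)


# Hypothetical sample output with one usable device and one offline emulator
sample = "List of devices attached\n1a2b3c4d\tdevice\nemulator-5554\toffline\n"
print(parse_adb_devices(sample))  # ['1a2b3c4d']
```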

Then, to upload these two files to the device, use:

adb -s <device_hash> push {libtvm_runtime.so,tvm_rpc} /data/local/tmp

At this point, «libtvm_runtime.so» and «tvm_rpc» are in /data/local/tmp on your device. Sometimes CMake can’t find «libc++_shared.so»; in that case, use:

find ${ANDROID_NDK_HOME} -name libc++_shared.so

to locate it, and push it to the device with adb as well:

adb -s <device_hash> push libc++_shared.so /data/local/tmp
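An NDK usually ships one «libc++_shared.so» per target ABI, so `find` typically returns several paths, and the copy you push must match the device (for a 64-bit ARM device, the arm64/aarch64 one). Below is a minimal Python sketch of the same search with a filter on the parent directory name. It assumes NDK layouts that name that directory after the ABI or target triple (for example `arm64-v8a` or `aarch64-linux-android`). The helper names are hypothetical:

```python
import os
from pathlib import Path
from typing import Iterable, List


def find_libcxx_shared(ndk_home: str) -> List[Path]:
    # Same search as `find ${ANDROID_NDK_HOME} -name libc++_shared.so`
    return sorted(Path(ndk_home).rglob("libc++_shared.so"))


def pick_for_abi(paths: Iterable[Path], hints=("aarch64", "arm64")) -> List[Path]:
    # Keep only copies whose parent directory names the wanted ABI
    # (assumed layouts: .../aarch64-linux-android/ or .../arm64-v8a/)
    return [p for p in paths if any(h in p.parent.name for h in hints)]


if __name__ == "__main__":
    ndk_home = os.environ.get("ANDROID_NDK_HOME", "")
    if ndk_home:
        for path in pick_for_abi(find_libcxx_shared(ndk_home)):
            print(path)
```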

We are now ready to run the TVM RPC server. In the first console, launch the RPC tracker:

python3 -m tvm.exec.rpc_tracker --port 9190

Then, in a second console, set up port forwarding and start the tvm_rpc server on the device:

adb -s <device_hash> reverse tcp:9190 tcp:9190
adb -s <device_hash> forward tcp:9090 tcp:9090
adb -s <device_hash> forward tcp:9091 tcp:9091
adb -s <device_hash> forward tcp:9092 tcp:9092
adb -s <device_hash> forward tcp:9093 tcp:9093
adb -s <device_hash> shell LD_LIBRARY_PATH=/data/local/tmp /data/local/tmp/tvm_rpc server --host=0.0.0.0 --port=9090 --tracker=127.0.0.1:9190 --key=android --port-end=9190
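The port setup above can also be scripted. The `reverse` rule lets tvm_rpc on the device reach the tracker at 127.0.0.1:9190 on the host. The `forward` rules let the host reach the RPC server's ports on the device. Forwarding 9090-9093 presumably leaves room for the server to bind a port other than 9090, since it may pick any port up to `--port-end`. This sketch only builds and prints the commands unless `dry_run=False`. `adb_port_commands` and `run_all` are hypothetical helpers:

```python
import subprocess
from typing import List, Sequence


def adb_port_commands(device_hash: str, tracker_port: int = 9190,
                      rpc_ports: Sequence[int] = (9090, 9091, 9092, 9093)) -> List[List[str]]:
    # Device -> host: tvm_rpc on the device connects to the tracker via this port
    cmds = [["adb", "-s", device_hash, "reverse", f"tcp:{tracker_port}", f"tcp:{tracker_port}"]]
    # Host -> device: host-side connections to each RPC port reach the device
    for port in rpc_ports:
        cmds.append(["adb", "-s", device_hash, "forward", f"tcp:{port}", f"tcp:{port}"])
    return cmds


def run_all(cmds: List[List[str]], dry_run: bool = True) -> None:
    for cmd in cmds:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


run_all(adb_port_commands("<device_hash>"))  # prints the five adb commands
```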

Before compiling the model and running inference, set TVM_TRACKER_HOST and TVM_TRACKER_PORT:

export TVM_TRACKER_HOST=0.0.0.0
export TVM_TRACKER_PORT=9190
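Python code can read these two variables back with defaults. This is a minimal sketch. The fallback values and the `tracker_address` helper are assumptions. The resulting (host, port) pair is the kind of argument that TVM's `tvm.rpc.connect_tracker` takes:

```python
import os
from typing import Tuple


def tracker_address(default_host: str = "127.0.0.1", default_port: int = 9190) -> Tuple[str, int]:
    # Fall back to a local tracker when the variables are not exported
    host = os.environ.get("TVM_TRACKER_HOST", default_host)
    port = int(os.environ.get("TVM_TRACKER_PORT", default_port))
    return host, port


print(tracker_address())
```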

Then check that the tracker is running and the device is available:

python -m tvm.exec.query_rpc_tracker --port 9190

For example, with one Android device connected, the output might look like this:

Queue Status
----------------------------------
key          total  free  pending
----------------------------------
android      1      1     0
----------------------------------
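To check availability from a script instead of by eye, you can parse the table. This is a minimal sketch that assumes the `key total free pending` layout shown above. `parse_queue_status`, `device_available`, and `query_tracker` are hypothetical helpers:

```python
import subprocess
import sys
from typing import Dict

SAMPLE = """Queue Status
----------------------------------
key          total  free  pending
----------------------------------
android      1      1     0
----------------------------------
"""


def parse_queue_status(text: str) -> Dict[str, Dict[str, int]]:
    status = {}
    for line in text.splitlines():
        parts = line.split()
        # Data rows look like "<key> <total> <free> <pending>"; skip headers and rules
        if len(parts) == 4 and all(p.isdigit() for p in parts[1:]):
            key, total, free, pending = parts
            status[key] = {"total": int(total), "free": int(free), "pending": int(pending)}
    return status


def device_available(status: Dict[str, Dict[str, int]], key: str = "android") -> bool:
    return status.get(key, {}).get("free", 0) > 0


def query_tracker(port: int = 9190) -> str:
    # Runs the same command as above and captures its output (requires TVM)
    cmd = [sys.executable, "-m", "tvm.exec.query_rpc_tracker", "--port", str(port)]
    return subprocess.run(cmd, stdout=subprocess.PIPE, universal_newlines=True, check=True).stdout


print(device_available(parse_queue_status(SAMPLE)))  # True
```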

Load a test image

As an example, we will use the classic cat image from ImageNet:

from PIL import Image
from tvm.contrib.download import download_testdata
from matplotlib import pyplot as plt
import numpy as np

img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))
plt.imshow(img)
plt.show()

# Preprocess the image and convert to tensor
from torchvision import transforms

my_preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
img = my_preprocess(img)
# Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
img = np.expand_dims(img, 0)
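For reference, the transforms above can be mirrored in plain NumPy. This is a minimal sketch that assumes the input is already a 224x224 RGB uint8 array, so Resize and CenterCrop are skipped. `preprocess_hwc_uint8` is a hypothetical helper, not part of torchvision:

```python
import numpy as np

# ImageNet statistics, the same values passed to transforms.Normalize above
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess_hwc_uint8(img: np.ndarray) -> np.ndarray:
    """Mimic ToTensor + Normalize + expand_dims on a 224x224x3 uint8 image."""
    x = img.astype(np.float32) / 255.0                   # ToTensor: scale to [0, 1]
    x = x.transpose(2, 0, 1)                             # HWC -> CHW
    x = (x - MEAN[:, None, None]) / STD[:, None, None]   # per-channel normalization
    return np.expand_dims(x, 0)                          # add batch dim -> NCHW


dummy = np.full((224, 224, 3), 255, dtype=np.uint8)
print(preprocess_hwc_uint8(dummy).shape)  # (1, 3, 224, 224)
```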

Load a pretrained PyTorch model

Create a Relay graph from a PyTorch ResNet-18 model:

import os
import torch
import torchvision
import tvm
from tvm import te
from tvm import relay, rpc
from tvm.contrib import utils, ndk
from tvm.contrib import graph_executor

model_name = "resnet18"
# Note: torchvision >= 0.13 deprecates `pretrained=True` in favor of `weights=`
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.eval()

# We grab the TorchScripted model via tracing
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()

# Input name can be arbitrary
input_name = "input0"
shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
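The tensor shapes in the Relay module printed below follow from the standard convolution/pooling output-size formula, out = (in + 2*pad - kernel) // stride + 1, applied with the strides, padding, and kernel sizes visible in the IR. This short sketch applies the formula to ResNet-18's stem and the first convolution of each stage for the 224x224 input used here. The helper names are hypothetical:

```python
def conv_out(size: int, kernel: int, stride: int = 1, pad: int = 0) -> int:
    # Output size for conv / pooling with dilation 1 and floor rounding
    return (size + 2 * pad - kernel) // stride + 1


def resnet18_spatial_sizes(size: int = 224) -> dict:
    sizes = {}
    size = conv_out(size, kernel=7, stride=2, pad=3)  # conv1
    sizes["conv1"] = size
    size = conv_out(size, kernel=3, stride=2, pad=1)  # max_pool2d
    sizes["maxpool"] = size
    # layer1 keeps the resolution; layer2-4 start with a stride-2 3x3 conv
    for name, stride in [("layer1", 1), ("layer2", 2), ("layer3", 2), ("layer4", 2)]:
        size = conv_out(size, kernel=3, stride=stride, pad=1)
        sizes[name] = size
    return sizes


print(resnet18_spatial_sizes())
# {'conv1': 112, 'maxpool': 56, 'layer1': 56, 'layer2': 28, 'layer3': 14, 'layer4': 7}
```

The 1x1 stride-2 `downsample` convolutions give the same result, for example `conv_out(56, kernel=1, stride=2)` is 28, which is why the residual `add` shapes line up.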

Precisions

Since TVM supports mixed precision, we register custom conversion rules for nn.conv2d and nn.dense with register_mixed_precision_conversion:

from tvm.relay.op import register_mixed_precision_conversion

# Accumulation dtype shared by the conv2d and dense rules; set by convert_to_dtype()
conv2d_acc = "float32"


# level=11 lets these rules take precedence over TVM's default ones
@register_mixed_precision_conversion("nn.conv2d", level=11)
def conv2d_mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
    global conv2d_acc
    # [conversion category, accumulation dtype, output dtype]
    return [
        relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
        conv2d_acc,
        mixed_precision_type,
    ]


@register_mixed_precision_conversion("nn.dense", level=11)
def dense_mixed_precision_rule(call_node: "relay.Call", mixed_precision_type: str):
    global conv2d_acc
    # Same policy for dense layers, reusing the conv2d accumulation dtype
    return [
        relay.transform.mixed_precision.MIXED_PRECISION_ALWAYS,
        conv2d_acc,
        mixed_precision_type,
    ]
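The second element returned by these rules is the accumulation dtype, and that is what distinguishes "float16" from "float16_acc32". The NumPy sketch below is not TVM code, but it shows why the choice matters. It sums 16384 float16 products of 0.5 * 0.5 one at a time and rounds the running total to the accumulation dtype after each step. Real GPU kernels reduce in a different order, so treat this only as an illustration of rounding error. `accumulate` is a hypothetical helper:

```python
import numpy as np


def accumulate(values: np.ndarray, acc_dtype) -> float:
    """Sum values sequentially, rounding the running total to acc_dtype each step."""
    acc = acc_dtype(0)
    for v in values:
        acc = acc_dtype(acc + acc_dtype(v))
    return float(acc)


# 16384 float16 products of 0.5 * 0.5; the exact sum is 4096
prods = np.full(16384, 0.5, dtype=np.float16) * np.float16(0.5)
exact = float(prods.astype(np.float64).sum())

acc16 = accumulate(prods, np.float16)  # analogous to conv2d_acc = "float16"
acc32 = accumulate(prods, np.float32)  # analogous to conv2d_acc = "float32"
print(exact, acc16, acc32)
```

With float16 accumulation, the running total stalls once its spacing between representable values grows larger than the increments, so it ends far below 4096. The float32 total stays exact.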

We also define the conversion function itself:

from tvm.ir import IRModule


def convert_to_dtype(mod, dtype):
    # Wrap the function in an IRModule so every dtype returns the same type
    mod = IRModule.from_expr(mod)
    # downcast to float16
    if dtype == "float16" or dtype == "float16_acc32":
        global conv2d_acc
        # "float16" accumulates conv2d/dense in float16, "float16_acc32" in float32
        conv2d_acc = "float16" if dtype == "float16" else "float32"
        seq = tvm.transform.Sequential(
            [relay.transform.InferType(), relay.transform.ToMixedPrecision()]
        )
        with tvm.transform.PassContext(opt_level=3):
            mod = seq(mod)
    return mod
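The following hypothetical helper summarizes the three precision strings this function understands. It also shows the storage dtype that the code below collapses each one to. The helper is derived from the logic above and is not part of TVM:

```python
from typing import Tuple


def describe_precision(dtype: str) -> Tuple[str, str]:
    """Return (conv2d/dense accumulation dtype, activation storage dtype)."""
    if dtype == "float32":
        return "float32", "float32"  # no mixed-precision pass is run
    if dtype == "float16":
        return "float16", "float16"  # compute and accumulate in float16
    if dtype == "float16_acc32":
        return "float32", "float16"  # float16 tensors, float32 accumulation
    raise ValueError(f"unsupported dtype: {dtype}")


for choice in ("float32", "float16", "float16_acc32"):
    print(choice, describe_precision(choice))
```

For "float16_acc32", this matches the printed module below: each `nn.conv2d` has `out_dtype="float32"` and is immediately followed by a `cast` back to float16.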

As an example, let’s choose “float16_acc32” (float16 tensors with float32 accumulation in conv2d and dense):

dtype = "float16_acc32"
mod = convert_to_dtype(mod["main"], dtype)
# Collapse the precision choice into the storage dtype used from here on
dtype = "float32" if dtype == "float32" else "float16"

print(mod)
def @main(%input0: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] */, %conv1.weight: Tensor[(64, 3, 7, 7), float32] /* ty=Tensor[(64, 3, 7, 7), float32] */, %bn1.weight: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %bn1.bias: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %bn1.running_mean: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %bn1.running_var: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %layer1.0.conv1.weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %layer1.0.bn1.weight: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %layer1.0.bn1.bias: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %layer1.0.bn1.running_mean: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %layer1.0.bn1.running_var: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %layer1.0.conv2.weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %layer1.0.bn2.weight: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %layer1.0.bn2.bias: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %layer1.0.bn2.running_mean: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %layer1.0.bn2.running_var: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %layer1.1.conv1.weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %layer1.1.bn1.weight: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %layer1.1.bn1.bias: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %layer1.1.bn1.running_mean: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %layer1.1.bn1.running_var: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %layer1.1.conv2.weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %layer1.1.bn2.weight: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %layer1.1.bn2.bias: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %layer1.1.bn2.running_mean: Tensor[(64), float32] /* 
ty=Tensor[(64), float32] */, %layer1.1.bn2.running_var: Tensor[(64), float32] /* ty=Tensor[(64), float32] */, %layer2.0.conv1.weight: Tensor[(128, 64, 3, 3), float32] /* ty=Tensor[(128, 64, 3, 3), float32] */, %layer2.0.bn1.weight: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.0.bn1.bias: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.0.bn1.running_mean: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.0.bn1.running_var: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.0.conv2.weight: Tensor[(128, 128, 3, 3), float32] /* ty=Tensor[(128, 128, 3, 3), float32] */, %layer2.0.bn2.weight: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.0.bn2.bias: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.0.bn2.running_mean: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.0.bn2.running_var: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.0.downsample.0.weight: Tensor[(128, 64, 1, 1), float32] /* ty=Tensor[(128, 64, 1, 1), float32] */, %layer2.0.downsample.1.weight: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.0.downsample.1.bias: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.0.downsample.1.running_mean: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.0.downsample.1.running_var: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.1.conv1.weight: Tensor[(128, 128, 3, 3), float32] /* ty=Tensor[(128, 128, 3, 3), float32] */, %layer2.1.bn1.weight: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.1.bn1.bias: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.1.bn1.running_mean: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.1.bn1.running_var: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.1.conv2.weight: Tensor[(128, 128, 3, 3), float32] /* ty=Tensor[(128, 128, 3, 3), float32] */, %layer2.1.bn2.weight: Tensor[(128), float32] /* 
ty=Tensor[(128), float32] */, %layer2.1.bn2.bias: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.1.bn2.running_mean: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer2.1.bn2.running_var: Tensor[(128), float32] /* ty=Tensor[(128), float32] */, %layer3.0.conv1.weight: Tensor[(256, 128, 3, 3), float32] /* ty=Tensor[(256, 128, 3, 3), float32] */, %layer3.0.bn1.weight: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.0.bn1.bias: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.0.bn1.running_mean: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.0.bn1.running_var: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.0.conv2.weight: Tensor[(256, 256, 3, 3), float32] /* ty=Tensor[(256, 256, 3, 3), float32] */, %layer3.0.bn2.weight: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.0.bn2.bias: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.0.bn2.running_mean: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.0.bn2.running_var: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.0.downsample.0.weight: Tensor[(256, 128, 1, 1), float32] /* ty=Tensor[(256, 128, 1, 1), float32] */, %layer3.0.downsample.1.weight: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.0.downsample.1.bias: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.0.downsample.1.running_mean: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.0.downsample.1.running_var: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.1.conv1.weight: Tensor[(256, 256, 3, 3), float32] /* ty=Tensor[(256, 256, 3, 3), float32] */, %layer3.1.bn1.weight: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.1.bn1.bias: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.1.bn1.running_mean: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.1.bn1.running_var: Tensor[(256), float32] /* ty=Tensor[(256), 
float32] */, %layer3.1.conv2.weight: Tensor[(256, 256, 3, 3), float32] /* ty=Tensor[(256, 256, 3, 3), float32] */, %layer3.1.bn2.weight: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.1.bn2.bias: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.1.bn2.running_mean: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer3.1.bn2.running_var: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %layer4.0.conv1.weight: Tensor[(512, 256, 3, 3), float32] /* ty=Tensor[(512, 256, 3, 3), float32] */, %layer4.0.bn1.weight: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.0.bn1.bias: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.0.bn1.running_mean: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.0.bn1.running_var: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.0.conv2.weight: Tensor[(512, 512, 3, 3), float32] /* ty=Tensor[(512, 512, 3, 3), float32] */, %layer4.0.bn2.weight: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.0.bn2.bias: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.0.bn2.running_mean: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.0.bn2.running_var: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.0.downsample.0.weight: Tensor[(512, 256, 1, 1), float32] /* ty=Tensor[(512, 256, 1, 1), float32] */, %layer4.0.downsample.1.weight: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.0.downsample.1.bias: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.0.downsample.1.running_mean: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.0.downsample.1.running_var: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.1.conv1.weight: Tensor[(512, 512, 3, 3), float32] /* ty=Tensor[(512, 512, 3, 3), float32] */, %layer4.1.bn1.weight: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.1.bn1.bias: Tensor[(512), float32] /* ty=Tensor[(512), float32] 
*/, %layer4.1.bn1.running_mean: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.1.bn1.running_var: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.1.conv2.weight: Tensor[(512, 512, 3, 3), float32] /* ty=Tensor[(512, 512, 3, 3), float32] */, %layer4.1.bn2.weight: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.1.bn2.bias: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.1.bn2.running_mean: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %layer4.1.bn2.running_var: Tensor[(512), float32] /* ty=Tensor[(512), float32] */, %fc.weight: Tensor[(1000, 512), float32] /* ty=Tensor[(1000, 512), float32] */, %fc.bias: Tensor[(1000), float32] /* ty=Tensor[(1000), float32] */) -> Tensor[(1, 1000), float16] {
  %0 = cast(%input0, dtype="float16") /* ty=Tensor[(1, 3, 224, 224), float16] */;
  %1 = cast(%conv1.weight, dtype="float16") /* ty=Tensor[(64, 3, 7, 7), float16] */;
  %2 = nn.conv2d(%0, %1, strides=[2, 2], padding=[3, 3, 3, 3], channels=64, kernel_size=[7, 7], out_dtype="float32") /* ty=Tensor[(1, 64, 112, 112), float32] */;
  %3 = cast(%2, dtype="float16") /* ty=Tensor[(1, 64, 112, 112), float16] */;
  %4 = cast(%bn1.weight, dtype="float16") /* ty=Tensor[(64), float16] */;
  %5 = cast(%bn1.bias, dtype="float16") /* ty=Tensor[(64), float16] */;
  %6 = cast(%bn1.running_mean, dtype="float16") /* ty=Tensor[(64), float16] */;
  %7 = cast(%bn1.running_var, dtype="float16") /* ty=Tensor[(64), float16] */;
  %8 = nn.batch_norm(%3, %4, %5, %6, %7) /* ty=(Tensor[(1, 64, 112, 112), float16], Tensor[(64), float16], Tensor[(64), float16]) */;
  %9 = %8.0 /* ty=Tensor[(1, 64, 112, 112), float16] */;
  %10 = nn.relu(%9) /* ty=Tensor[(1, 64, 112, 112), float16] */;
  %11 = nn.max_pool2d(%10, pool_size=[3, 3], strides=[2, 2], padding=[1, 1, 1, 1]) /* ty=Tensor[(1, 64, 56, 56), float16] */;
  %12 = cast(%layer1.0.conv1.weight, dtype="float16") /* ty=Tensor[(64, 64, 3, 3), float16] */;
  %13 = nn.conv2d(%11, %12, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %14 = cast(%13, dtype="float16") /* ty=Tensor[(1, 64, 56, 56), float16] */;
  %15 = cast(%layer1.0.bn1.weight, dtype="float16") /* ty=Tensor[(64), float16] */;
  %16 = cast(%layer1.0.bn1.bias, dtype="float16") /* ty=Tensor[(64), float16] */;
  %17 = cast(%layer1.0.bn1.running_mean, dtype="float16") /* ty=Tensor[(64), float16] */;
  %18 = cast(%layer1.0.bn1.running_var, dtype="float16") /* ty=Tensor[(64), float16] */;
  %19 = nn.batch_norm(%14, %15, %16, %17, %18) /* ty=(Tensor[(1, 64, 56, 56), float16], Tensor[(64), float16], Tensor[(64), float16]) */;
  %20 = %19.0 /* ty=Tensor[(1, 64, 56, 56), float16] */;
  %21 = nn.relu(%20) /* ty=Tensor[(1, 64, 56, 56), float16] */;
  %22 = cast(%layer1.0.conv2.weight, dtype="float16") /* ty=Tensor[(64, 64, 3, 3), float16] */;
  %23 = nn.conv2d(%21, %22, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %24 = cast(%23, dtype="float16") /* ty=Tensor[(1, 64, 56, 56), float16] */;
  %25 = cast(%layer1.0.bn2.weight, dtype="float16") /* ty=Tensor[(64), float16] */;
  %26 = cast(%layer1.0.bn2.bias, dtype="float16") /* ty=Tensor[(64), float16] */;
  %27 = cast(%layer1.0.bn2.running_mean, dtype="float16") /* ty=Tensor[(64), float16] */;
  %28 = cast(%layer1.0.bn2.running_var, dtype="float16") /* ty=Tensor[(64), float16] */;
  %29 = nn.batch_norm(%24, %25, %26, %27, %28) /* ty=(Tensor[(1, 64, 56, 56), float16], Tensor[(64), float16], Tensor[(64), float16]) */;
  %30 = %29.0 /* ty=Tensor[(1, 64, 56, 56), float16] */;
  %31 = add(%30, %11) /* ty=Tensor[(1, 64, 56, 56), float16] */;
  %32 = nn.relu(%31) /* ty=Tensor[(1, 64, 56, 56), float16] */;
  %33 = cast(%layer1.1.conv1.weight, dtype="float16") /* ty=Tensor[(64, 64, 3, 3), float16] */;
  %34 = nn.conv2d(%32, %33, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %35 = cast(%34, dtype="float16") /* ty=Tensor[(1, 64, 56, 56), float16] */;
  %36 = cast(%layer1.1.bn1.weight, dtype="float16") /* ty=Tensor[(64), float16] */;
  %37 = cast(%layer1.1.bn1.bias, dtype="float16") /* ty=Tensor[(64), float16] */;
  %38 = cast(%layer1.1.bn1.running_mean, dtype="float16") /* ty=Tensor[(64), float16] */;
  %39 = cast(%layer1.1.bn1.running_var, dtype="float16") /* ty=Tensor[(64), float16] */;
  %40 = nn.batch_norm(%35, %36, %37, %38, %39) /* ty=(Tensor[(1, 64, 56, 56), float16], Tensor[(64), float16], Tensor[(64), float16]) */;
  %41 = %40.0 /* ty=Tensor[(1, 64, 56, 56), float16] */;
  %42 = nn.relu(%41) /* ty=Tensor[(1, 64, 56, 56), float16] */;
  %43 = cast(%layer1.1.conv2.weight, dtype="float16") /* ty=Tensor[(64, 64, 3, 3), float16] */;
  %44 = nn.conv2d(%42, %43, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 64, 56, 56), float32] */;
  %45 = cast(%44, dtype="float16") /* ty=Tensor[(1, 64, 56, 56), float16] */;
  %46 = cast(%layer1.1.bn2.weight, dtype="float16") /* ty=Tensor[(64), float16] */;
  %47 = cast(%layer1.1.bn2.bias, dtype="float16") /* ty=Tensor[(64), float16] */;
  %48 = cast(%layer1.1.bn2.running_mean, dtype="float16") /* ty=Tensor[(64), float16] */;
  %49 = cast(%layer1.1.bn2.running_var, dtype="float16") /* ty=Tensor[(64), float16] */;
  %50 = nn.batch_norm(%45, %46, %47, %48, %49) /* ty=(Tensor[(1, 64, 56, 56), float16], Tensor[(64), float16], Tensor[(64), float16]) */;
  %51 = %50.0 /* ty=Tensor[(1, 64, 56, 56), float16] */;
  %52 = add(%51, %32) /* ty=Tensor[(1, 64, 56, 56), float16] */;
  %53 = nn.relu(%52) /* ty=Tensor[(1, 64, 56, 56), float16] */;
  %54 = cast(%layer2.0.conv1.weight, dtype="float16") /* ty=Tensor[(128, 64, 3, 3), float16] */;
  %55 = nn.conv2d(%53, %54, strides=[2, 2], padding=[1, 1, 1, 1], channels=128, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 128, 28, 28), float32] */;
  %56 = cast(%55, dtype="float16") /* ty=Tensor[(1, 128, 28, 28), float16] */;
  %57 = cast(%layer2.0.bn1.weight, dtype="float16") /* ty=Tensor[(128), float16] */;
  %58 = cast(%layer2.0.bn1.bias, dtype="float16") /* ty=Tensor[(128), float16] */;
  %59 = cast(%layer2.0.bn1.running_mean, dtype="float16") /* ty=Tensor[(128), float16] */;
  %60 = cast(%layer2.0.bn1.running_var, dtype="float16") /* ty=Tensor[(128), float16] */;
  %61 = nn.batch_norm(%56, %57, %58, %59, %60) /* ty=(Tensor[(1, 128, 28, 28), float16], Tensor[(128), float16], Tensor[(128), float16]) */;
  %62 = %61.0 /* ty=Tensor[(1, 128, 28, 28), float16] */;
  %63 = nn.relu(%62) /* ty=Tensor[(1, 128, 28, 28), float16] */;
  %64 = cast(%layer2.0.conv2.weight, dtype="float16") /* ty=Tensor[(128, 128, 3, 3), float16] */;
  %65 = nn.conv2d(%63, %64, padding=[1, 1, 1, 1], channels=128, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 128, 28, 28), float32] */;
  %66 = cast(%65, dtype="float16") /* ty=Tensor[(1, 128, 28, 28), float16] */;
  %67 = cast(%layer2.0.bn2.weight, dtype="float16") /* ty=Tensor[(128), float16] */;
  %68 = cast(%layer2.0.bn2.bias, dtype="float16") /* ty=Tensor[(128), float16] */;
  %69 = cast(%layer2.0.bn2.running_mean, dtype="float16") /* ty=Tensor[(128), float16] */;
  %70 = cast(%layer2.0.bn2.running_var, dtype="float16") /* ty=Tensor[(128), float16] */;
  %71 = nn.batch_norm(%66, %67, %68, %69, %70) /* ty=(Tensor[(1, 128, 28, 28), float16], Tensor[(128), float16], Tensor[(128), float16]) */;
  %72 = cast(%layer2.0.downsample.0.weight, dtype="float16") /* ty=Tensor[(128, 64, 1, 1), float16] */;
  %73 = nn.conv2d(%53, %72, strides=[2, 2], padding=[0, 0, 0, 0], channels=128, kernel_size=[1, 1], out_dtype="float32") /* ty=Tensor[(1, 128, 28, 28), float32] */;
  %74 = cast(%73, dtype="float16") /* ty=Tensor[(1, 128, 28, 28), float16] */;
  %75 = cast(%layer2.0.downsample.1.weight, dtype="float16") /* ty=Tensor[(128), float16] */;
  %76 = cast(%layer2.0.downsample.1.bias, dtype="float16") /* ty=Tensor[(128), float16] */;
  %77 = cast(%layer2.0.downsample.1.running_mean, dtype="float16") /* ty=Tensor[(128), float16] */;
  %78 = cast(%layer2.0.downsample.1.running_var, dtype="float16") /* ty=Tensor[(128), float16] */;
  %79 = nn.batch_norm(%74, %75, %76, %77, %78) /* ty=(Tensor[(1, 128, 28, 28), float16], Tensor[(128), float16], Tensor[(128), float16]) */;
  %80 = %71.0 /* ty=Tensor[(1, 128, 28, 28), float16] */;
  %81 = %79.0 /* ty=Tensor[(1, 128, 28, 28), float16] */;
  %82 = add(%80, %81) /* ty=Tensor[(1, 128, 28, 28), float16] */;
  %83 = nn.relu(%82) /* ty=Tensor[(1, 128, 28, 28), float16] */;
  %84 = cast(%layer2.1.conv1.weight, dtype="float16") /* ty=Tensor[(128, 128, 3, 3), float16] */;
  %85 = nn.conv2d(%83, %84, padding=[1, 1, 1, 1], channels=128, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 128, 28, 28), float32] */;
  %86 = cast(%85, dtype="float16") /* ty=Tensor[(1, 128, 28, 28), float16] */;
  %87 = cast(%layer2.1.bn1.weight, dtype="float16") /* ty=Tensor[(128), float16] */;
  %88 = cast(%layer2.1.bn1.bias, dtype="float16") /* ty=Tensor[(128), float16] */;
  %89 = cast(%layer2.1.bn1.running_mean, dtype="float16") /* ty=Tensor[(128), float16] */;
  %90 = cast(%layer2.1.bn1.running_var, dtype="float16") /* ty=Tensor[(128), float16] */;
  %91 = nn.batch_norm(%86, %87, %88, %89, %90) /* ty=(Tensor[(1, 128, 28, 28), float16], Tensor[(128), float16], Tensor[(128), float16]) */;
  %92 = %91.0 /* ty=Tensor[(1, 128, 28, 28), float16] */;
  %93 = nn.relu(%92) /* ty=Tensor[(1, 128, 28, 28), float16] */;
  %94 = cast(%layer2.1.conv2.weight, dtype="float16") /* ty=Tensor[(128, 128, 3, 3), float16] */;
  %95 = nn.conv2d(%93, %94, padding=[1, 1, 1, 1], channels=128, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 128, 28, 28), float32] */;
  %96 = cast(%95, dtype="float16") /* ty=Tensor[(1, 128, 28, 28), float16] */;
  %97 = cast(%layer2.1.bn2.weight, dtype="float16") /* ty=Tensor[(128), float16] */;
  %98 = cast(%layer2.1.bn2.bias, dtype="float16") /* ty=Tensor[(128), float16] */;
  %99 = cast(%layer2.1.bn2.running_mean, dtype="float16") /* ty=Tensor[(128), float16] */;
  %100 = cast(%layer2.1.bn2.running_var, dtype="float16") /* ty=Tensor[(128), float16] */;
  %101 = nn.batch_norm(%96, %97, %98, %99, %100) /* ty=(Tensor[(1, 128, 28, 28), float16], Tensor[(128), float16], Tensor[(128), float16]) */;
  %102 = %101.0 /* ty=Tensor[(1, 128, 28, 28), float16] */;
  %103 = add(%102, %83) /* ty=Tensor[(1, 128, 28, 28), float16] */;
  %104 = nn.relu(%103) /* ty=Tensor[(1, 128, 28, 28), float16] */;
  %105 = cast(%layer3.0.conv1.weight, dtype="float16") /* ty=Tensor[(256, 128, 3, 3), float16] */;
  %106 = nn.conv2d(%104, %105, strides=[2, 2], padding=[1, 1, 1, 1], channels=256, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 256, 14, 14), float32] */;
  %107 = cast(%106, dtype="float16") /* ty=Tensor[(1, 256, 14, 14), float16] */;
  %108 = cast(%layer3.0.bn1.weight, dtype="float16") /* ty=Tensor[(256), float16] */;
  %109 = cast(%layer3.0.bn1.bias, dtype="float16") /* ty=Tensor[(256), float16] */;
  %110 = cast(%layer3.0.bn1.running_mean, dtype="float16") /* ty=Tensor[(256), float16] */;
  %111 = cast(%layer3.0.bn1.running_var, dtype="float16") /* ty=Tensor[(256), float16] */;
  %112 = nn.batch_norm(%107, %108, %109, %110, %111) /* ty=(Tensor[(1, 256, 14, 14), float16], Tensor[(256), float16], Tensor[(256), float16]) */;
  %113 = %112.0 /* ty=Tensor[(1, 256, 14, 14), float16] */;
  %114 = nn.relu(%113) /* ty=Tensor[(1, 256, 14, 14), float16] */;
  %115 = cast(%layer3.0.conv2.weight, dtype="float16") /* ty=Tensor[(256, 256, 3, 3), float16] */;
  %116 = nn.conv2d(%114, %115, padding=[1, 1, 1, 1], channels=256, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 256, 14, 14), float32] */;
  %117 = cast(%116, dtype="float16") /* ty=Tensor[(1, 256, 14, 14), float16] */;
  %118 = cast(%layer3.0.bn2.weight, dtype="float16") /* ty=Tensor[(256), float16] */;
  %119 = cast(%layer3.0.bn2.bias, dtype="float16") /* ty=Tensor[(256), float16] */;
  %120 = cast(%layer3.0.bn2.running_mean, dtype="float16") /* ty=Tensor[(256), float16] */;
  %121 = cast(%layer3.0.bn2.running_var, dtype="float16") /* ty=Tensor[(256), float16] */;
  %122 = nn.batch_norm(%117, %118, %119, %120, %121) /* ty=(Tensor[(1, 256, 14, 14), float16], Tensor[(256), float16], Tensor[(256), float16]) */;
  %123 = cast(%layer3.0.downsample.0.weight, dtype="float16") /* ty=Tensor[(256, 128, 1, 1), float16] */;
  %124 = nn.conv2d(%104, %123, strides=[2, 2], padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1], out_dtype="float32") /* ty=Tensor[(1, 256, 14, 14), float32] */;
  %125 = cast(%124, dtype="float16") /* ty=Tensor[(1, 256, 14, 14), float16] */;
  %126 = cast(%layer3.0.downsample.1.weight, dtype="float16") /* ty=Tensor[(256), float16] */;
  %127 = cast(%layer3.0.downsample.1.bias, dtype="float16") /* ty=Tensor[(256), float16] */;
  %128 = cast(%layer3.0.downsample.1.running_mean, dtype="float16") /* ty=Tensor[(256), float16] */;
  %129 = cast(%layer3.0.downsample.1.running_var, dtype="float16") /* ty=Tensor[(256), float16] */;
  %130 = nn.batch_norm(%125, %126, %127, %128, %129) /* ty=(Tensor[(1, 256, 14, 14), float16], Tensor[(256), float16], Tensor[(256), float16]) */;
  %131 = %122.0 /* ty=Tensor[(1, 256, 14, 14), float16] */;
  %132 = %130.0 /* ty=Tensor[(1, 256, 14, 14), float16] */;
  %133 = add(%131, %132) /* ty=Tensor[(1, 256, 14, 14), float16] */;
  %134 = nn.relu(%133) /* ty=Tensor[(1, 256, 14, 14), float16] */;
  %135 = cast(%layer3.1.conv1.weight, dtype="float16") /* ty=Tensor[(256, 256, 3, 3), float16] */;
  %136 = nn.conv2d(%134, %135, padding=[1, 1, 1, 1], channels=256, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 256, 14, 14), float32] */;
  %137 = cast(%136, dtype="float16") /* ty=Tensor[(1, 256, 14, 14), float16] */;
  %138 = cast(%layer3.1.bn1.weight, dtype="float16") /* ty=Tensor[(256), float16] */;
  %139 = cast(%layer3.1.bn1.bias, dtype="float16") /* ty=Tensor[(256), float16] */;
  %140 = cast(%layer3.1.bn1.running_mean, dtype="float16") /* ty=Tensor[(256), float16] */;
  %141 = cast(%layer3.1.bn1.running_var, dtype="float16") /* ty=Tensor[(256), float16] */;
  %142 = nn.batch_norm(%137, %138, %139, %140, %141) /* ty=(Tensor[(1, 256, 14, 14), float16], Tensor[(256), float16], Tensor[(256), float16]) */;
  %143 = %142.0 /* ty=Tensor[(1, 256, 14, 14), float16] */;
  %144 = nn.relu(%143) /* ty=Tensor[(1, 256, 14, 14), float16] */;
  %145 = cast(%layer3.1.conv2.weight, dtype="float16") /* ty=Tensor[(256, 256, 3, 3), float16] */;
  %146 = nn.conv2d(%144, %145, padding=[1, 1, 1, 1], channels=256, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 256, 14, 14), float32] */;
  %147 = cast(%146, dtype="float16") /* ty=Tensor[(1, 256, 14, 14), float16] */;
  %148 = cast(%layer3.1.bn2.weight, dtype="float16") /* ty=Tensor[(256), float16] */;
  %149 = cast(%layer3.1.bn2.bias, dtype="float16") /* ty=Tensor[(256), float16] */;
  %150 = cast(%layer3.1.bn2.running_mean, dtype="float16") /* ty=Tensor[(256), float16] */;
  %151 = cast(%layer3.1.bn2.running_var, dtype="float16") /* ty=Tensor[(256), float16] */;
  %152 = nn.batch_norm(%147, %148, %149, %150, %151) /* ty=(Tensor[(1, 256, 14, 14), float16], Tensor[(256), float16], Tensor[(256), float16]) */;
  %153 = %152.0 /* ty=Tensor[(1, 256, 14, 14), float16] */;
  %154 = add(%153, %134) /* ty=Tensor[(1, 256, 14, 14), float16] */;
  %155 = nn.relu(%154) /* ty=Tensor[(1, 256, 14, 14), float16] */;
  %156 = cast(%layer4.0.conv1.weight, dtype="float16") /* ty=Tensor[(512, 256, 3, 3), float16] */;
  %157 = nn.conv2d(%155, %156, strides=[2, 2], padding=[1, 1, 1, 1], channels=512, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %158 = cast(%157, dtype="float16") /* ty=Tensor[(1, 512, 7, 7), float16] */;
  %159 = cast(%layer4.0.bn1.weight, dtype="float16") /* ty=Tensor[(512), float16] */;
  %160 = cast(%layer4.0.bn1.bias, dtype="float16") /* ty=Tensor[(512), float16] */;
  %161 = cast(%layer4.0.bn1.running_mean, dtype="float16") /* ty=Tensor[(512), float16] */;
  %162 = cast(%layer4.0.bn1.running_var, dtype="float16") /* ty=Tensor[(512), float16] */;
  %163 = nn.batch_norm(%158, %159, %160, %161, %162) /* ty=(Tensor[(1, 512, 7, 7), float16], Tensor[(512), float16], Tensor[(512), float16]) */;
  %164 = %163.0 /* ty=Tensor[(1, 512, 7, 7), float16] */;
  %165 = nn.relu(%164) /* ty=Tensor[(1, 512, 7, 7), float16] */;
  %166 = cast(%layer4.0.conv2.weight, dtype="float16") /* ty=Tensor[(512, 512, 3, 3), float16] */;
  %167 = nn.conv2d(%165, %166, padding=[1, 1, 1, 1], channels=512, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %168 = cast(%167, dtype="float16") /* ty=Tensor[(1, 512, 7, 7), float16] */;
  %169 = cast(%layer4.0.bn2.weight, dtype="float16") /* ty=Tensor[(512), float16] */;
  %170 = cast(%layer4.0.bn2.bias, dtype="float16") /* ty=Tensor[(512), float16] */;
  %171 = cast(%layer4.0.bn2.running_mean, dtype="float16") /* ty=Tensor[(512), float16] */;
  %172 = cast(%layer4.0.bn2.running_var, dtype="float16") /* ty=Tensor[(512), float16] */;
  %173 = nn.batch_norm(%168, %169, %170, %171, %172) /* ty=(Tensor[(1, 512, 7, 7), float16], Tensor[(512), float16], Tensor[(512), float16]) */;
  %174 = cast(%layer4.0.downsample.0.weight, dtype="float16") /* ty=Tensor[(512, 256, 1, 1), float16] */;
  %175 = nn.conv2d(%155, %174, strides=[2, 2], padding=[0, 0, 0, 0], channels=512, kernel_size=[1, 1], out_dtype="float32") /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %176 = cast(%175, dtype="float16") /* ty=Tensor[(1, 512, 7, 7), float16] */;
  %177 = cast(%layer4.0.downsample.1.weight, dtype="float16") /* ty=Tensor[(512), float16] */;
  %178 = cast(%layer4.0.downsample.1.bias, dtype="float16") /* ty=Tensor[(512), float16] */;
  %179 = cast(%layer4.0.downsample.1.running_mean, dtype="float16") /* ty=Tensor[(512), float16] */;
  %180 = cast(%layer4.0.downsample.1.running_var, dtype="float16") /* ty=Tensor[(512), float16] */;
  %181 = nn.batch_norm(%176, %177, %178, %179, %180) /* ty=(Tensor[(1, 512, 7, 7), float16], Tensor[(512), float16], Tensor[(512), float16]) */;
  %182 = %173.0 /* ty=Tensor[(1, 512, 7, 7), float16] */;
  %183 = %181.0 /* ty=Tensor[(1, 512, 7, 7), float16] */;
  %184 = add(%182, %183) /* ty=Tensor[(1, 512, 7, 7), float16] */;
  %185 = nn.relu(%184) /* ty=Tensor[(1, 512, 7, 7), float16] */;
  %186 = cast(%layer4.1.conv1.weight, dtype="float16") /* ty=Tensor[(512, 512, 3, 3), float16] */;
  %187 = nn.conv2d(%185, %186, padding=[1, 1, 1, 1], channels=512, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %188 = cast(%187, dtype="float16") /* ty=Tensor[(1, 512, 7, 7), float16] */;
  %189 = cast(%layer4.1.bn1.weight, dtype="float16") /* ty=Tensor[(512), float16] */;
  %190 = cast(%layer4.1.bn1.bias, dtype="float16") /* ty=Tensor[(512), float16] */;
  %191 = cast(%layer4.1.bn1.running_mean, dtype="float16") /* ty=Tensor[(512), float16] */;
  %192 = cast(%layer4.1.bn1.running_var, dtype="float16") /* ty=Tensor[(512), float16] */;
  %193 = nn.batch_norm(%188, %189, %190, %191, %192) /* ty=(Tensor[(1, 512, 7, 7), float16], Tensor[(512), float16], Tensor[(512), float16]) */;
  %194 = %193.0 /* ty=Tensor[(1, 512, 7, 7), float16] */;
  %195 = nn.relu(%194) /* ty=Tensor[(1, 512, 7, 7), float16] */;
  %196 = cast(%layer4.1.conv2.weight, dtype="float16") /* ty=Tensor[(512, 512, 3, 3), float16] */;
  %197 = nn.conv2d(%195, %196, padding=[1, 1, 1, 1], channels=512, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %198 = cast(%197, dtype="float16") /* ty=Tensor[(1, 512, 7, 7), float16] */;
  %199 = cast(%layer4.1.bn2.weight, dtype="float16") /* ty=Tensor[(512), float16] */;
  %200 = cast(%layer4.1.bn2.bias, dtype="float16") /* ty=Tensor[(512), float16] */;
  %201 = cast(%layer4.1.bn2.running_mean, dtype="float16") /* ty=Tensor[(512), float16] */;
  %202 = cast(%layer4.1.bn2.running_var, dtype="float16") /* ty=Tensor[(512), float16] */;
  %203 = nn.batch_norm(%198, %199, %200, %201, %202) /* ty=(Tensor[(1, 512, 7, 7), float16], Tensor[(512), float16], Tensor[(512), float16]) */;
  %204 = %203.0 /* ty=Tensor[(1, 512, 7, 7), float16] */;
  %205 = add(%204, %185) /* ty=Tensor[(1, 512, 7, 7), float16] */;
  %206 = nn.relu(%205) /* ty=Tensor[(1, 512, 7, 7), float16] */;
  %207 = cast(%206, dtype="float32") /* ty=Tensor[(1, 512, 7, 7), float32] */;
  %208 = nn.adaptive_avg_pool2d(%207, output_size=[1, 1]) /* ty=Tensor[(1, 512, 1, 1), float32] */;
  %209 = reshape(%208, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 512, 1, 1), float32] */;
  %210 = squeeze(%209, axis=[2, 3]) /* ty=Tensor[(1, 512), float32] */;
  %211 = cast(%210, dtype="float16") /* ty=Tensor[(1, 512), float16] */;
  %212 = cast(%fc.weight, dtype="float16") /* ty=Tensor[(1000, 512), float16] */;
  %213 = nn.dense(%211, %212, units=None, out_dtype="float32") /* ty=Tensor[(1, 1000), float32] */;
  %214 = cast(%213, dtype="float16") /* ty=Tensor[(1, 1000), float16] */;
  %215 = cast(%fc.bias, dtype="float16") /* ty=Tensor[(1000), float16] */;
  nn.bias_add(%214, %215, axis=-1) /* ty=Tensor[(1, 1000), float16] */
}

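As you can see in the IR, the graph now contains cast operations that convert weights and activations to FP16 precision. Note that nn.conv2d and nn.dense still produce float32 results (out_dtype="float32"), which are then cast back to float16, so this is a mixed-precision configuration. You can also choose “float16” or “float32” as the dtype.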
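Why keep out_dtype="float32" on nn.conv2d and nn.dense when their inputs are already float16? One likely reason is that long reductions lose accuracy when the running sum itself is stored in FP16. The sketch below is a plain-Python illustration of that effect, not TVM code. It emulates half and single precision with the standard struct module (format codes "e" and "f") and rounds the accumulator after every addition. This is a simplification of how a real GPU kernel behaves.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def dot(a, b, round_acc):
    """Dot product whose running sum is rounded with `round_acc` after each step."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = round_acc(acc + x * y)
    return acc


# FP16 inputs, as after the cast ops in the IR (4096 terms, each about 0.01)
n = 4096
a = [to_fp16(0.01)] * n
b = [to_fp16(1.0)] * n

exact = sum(x * y for x, y in zip(a, b))
acc16 = dot(a, b, to_fp16)  # accumulate in float16
acc32 = dot(a, b, to_fp32)  # accumulate in float32, like out_dtype="float32"

print(f"exact={exact:.5f} fp16-acc={acc16:.5f} fp32-acc={acc32:.5f}")
```

On this input the half-precision accumulator stops growing once each ~0.01 increment is smaller than half the spacing between neighbouring FP16 values, so it ends far below the exact sum. The float32 accumulator stays close to the exact value.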

Compile the model with Relay

Specify the Adreno target before compiling to generate kernels that leverage textures and get all the benefits of textures. Note: this example was generated on our x86 server for demonstration, so it compiles for and runs on the local CPU. When running on an Android device, the target must also specify the device's instruction set (see arch below). Set local_demo to False if you want to run this tutorial on a real device.

local_demo = True

# By default, the model is compiled for and executed on the CPU.
# Select one of 'cpu', 'opencl' or 'vulkan'.
test_target = "cpu"

# Change target configuration.
# Run `adb shell cat /proc/cpuinfo` to find the arch.
arch = "arm64"
target = tvm.target.Target("llvm -mtriple=%s-linux-android" % arch)

if local_demo:
    target = tvm.target.Target("llvm")
elif test_target == "opencl":
    target = tvm.target.Target("opencl", host=target)
elif test_target == "vulkan":
    target = tvm.target.Target("vulkan", host=target)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)
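The branching above can be captured as a small pure function. This makes it easy to check which target string each combination of local_demo and test_target produces. This is a sketch. select_target is a hypothetical helper, not part of TVM, and it returns plain strings that you would pass to tvm.target.Target(target, host=host).

```python
from typing import Optional, Tuple


def select_target(local_demo: bool, test_target: str, arch: str = "arm64") -> Tuple[str, Optional[str]]:
    """Return (target, host) strings following the branching in the compile step.

    Hypothetical helper; pass the result to tvm.target.Target(target, host=host).
    """
    android_cpu = f"llvm -mtriple={arch}-linux-android"
    if local_demo:
        # Demo mode: compile for the machine running this script
        return "llvm", None
    if test_target in ("opencl", "vulkan"):
        # GPU kernels run on the device; host-side code is built for the Android CPU
        return test_target, android_cpu
    # 'cpu' (or anything else): compile everything for the Android CPU
    return android_cpu, None


print(select_target(False, "opencl"))
```

Keeping the decision in one function means the same (target, host) pair can be reused when choosing the device after connecting over RPC.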

Deploy the Model Remotely by RPC

Using RPC, you can deploy the model from the host machine to the remote Adreno device.

rpc_tracker_host = os.environ.get("TVM_TRACKER_HOST", "127.0.0.1")
rpc_tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190))
key = "android"

if local_demo:
    remote = rpc.LocalSession()
else:
    tracker = rpc.connect_tracker(rpc_tracker_host, rpc_tracker_port)
    # When running a heavy model, we should increase the `session_timeout`
    remote = tracker.request(key, priority=0, session_timeout=60)

if local_demo:
    dev = remote.cpu(0)
elif test_target == "opencl":
    dev = remote.cl(0)
elif test_target == "vulkan":
    dev = remote.vulkan(0)
else:
    dev = remote.cpu(0)

temp = utils.tempdir()
dso_binary = "dev_lib_cl.so"
dso_binary_path = temp.relpath(dso_binary)
# For a real device, cross-compile with the Android NDK; the local demo uses the default compiler
fcompile = ndk.create_shared if not local_demo else None
lib.export_library(dso_binary_path, fcompile=fcompile)
# Upload the compiled library through the RPC session and load it on the remote side
remote.upload(dso_binary_path)
rlib = remote.load_module(dso_binary)
m = graph_executor.GraphModule(rlib["default"](dev))
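The if/elif chain that picks dev can also be written as a lookup from test_target to the session method that creates the device handle (cpu, cl and vulkan, the same methods called above). A minimal sketch follows. select_device, DEVICE_METHOD and FakeRemote are hypothetical names. FakeRemote only stands in for the RPC session so that the dispatch can be exercised without a device.

```python
class FakeRemote:
    """Hypothetical stand-in for an RPC session; returns (kind, id) instead of a device."""

    def cpu(self, dev_id=0):
        return ("cpu", dev_id)

    def cl(self, dev_id=0):
        return ("cl", dev_id)

    def vulkan(self, dev_id=0):
        return ("vulkan", dev_id)


# Map test_target to the name of the session method that creates the device handle
DEVICE_METHOD = {"opencl": "cl", "vulkan": "vulkan"}


def select_device(remote, local_demo, test_target, dev_id=0):
    if local_demo:
        return remote.cpu(dev_id)
    # Unknown targets, including 'cpu', fall back to the CPU as in the tutorial
    return getattr(remote, DEVICE_METHOD.get(test_target, "cpu"))(dev_id)


print(select_device(FakeRemote(), local_demo=False, test_target="opencl"))
```

With a real session you would pass the remote object obtained from tracker.request(...) or rpc.LocalSession() instead of FakeRemote().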

Run inference

We can now set the input, run inference and read the prediction scores from the output.

m.set_input(input_name, tvm.nd.array(img.astype("float32")))
m.run()
tvm_output = m.get_output(0)
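The IR above ends with nn.bias_add and has no softmax, so tvm_output holds one row of 1000 raw scores (logits) rather than probabilities. The stdlib sketch below shows how to turn such scores into probabilities and pick the top-k classes. It uses made-up toy logits, and softmax and top_k are illustrative helpers, not TVM functions.

```python
import heapq
import math


def softmax(logits):
    """Convert raw scores (logits) into probabilities that sum to 1."""
    peak = max(logits)  # subtracting the max keeps math.exp from overflowing
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(scores, k=5):
    """Return the indices of the k largest scores, best first."""
    return heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)


# Toy logits standing in for one row of model output
logits = [0.5, 3.0, -1.0, 2.0, 2.5, 0.0]
probs = softmax(logits)
best = top_k(logits, k=3)
for class_id in best:
    print(class_id, round(probs[class_id], 3))
```

On the real output the scores would come from tvm_output.numpy()[0].tolist(). Because softmax is monotonic, the ranking is the same as the NumPy argsort used in the next step; only the reported confidences are added.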

Get predictions and performance statistics

The following code prints the top-1 and top-5 predictions and reports the model’s inference time.

import ast

import numpy as np
from tvm.contrib import download

# Download the ImageNet class names (a Python dict literal mapping class id -> name)
categ_url = "https://github.com/uwsampl/web-data/raw/main/vta/models/"
categ_fn = "synset.txt"
download.download(categ_url + categ_fn, categ_fn)
with open(categ_fn) as f:
    synset = ast.literal_eval(f.read())

# Sort class ids by score (ascending), then take the five highest
top_categories = np.argsort(tvm_output.numpy()[0])
top5 = np.flip(top_categories, axis=0)[:5]

# Report top-1 classification result
print("Top-1 id: {}, class name: {}".format(top5[0], synset[top5[0]]))

# Report top-5 classification results
print("\nTop5 predictions: \n")
for rank, class_id in enumerate(top5, start=1):
    print("\t#{}:".format(rank), synset[class_id])
print("\t", top5)

# Sanity check: at least one of the top-5 class names should contain "cat"
ImageNetClassifier = any("cat" in synset[k] for k in top_categories[-5:])
assert ImageNetClassifier, "Failed ImageNet classifier validation check"

print("Evaluate inference time cost...")
print(m.benchmark(dev, number=1, repeat=10))
Top-1 id: 281, class name: tabby, tabby cat

Top5 predictions:

        #1: tabby, tabby cat
        #2: tiger cat
        #3: lynx, catamount
        #4: red fox, Vulpes vulpes
        #5: Egyptian cat
         [281 282 287 277 285]
Evaluate inference time cost...
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)
 2545.4545    2545.3663    2547.9816    2543.9309      1.2446
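With repeat=10 and number=1, the benchmark takes ten measurements of a single run each, and the summary row aggregates those per-repeat timings. The small stdlib sketch below computes the same columns from a list of timings in milliseconds. The use of the population standard deviation is an assumption about how the reported "std" is computed, and the timings list is hypothetical. Also note that the numbers above came from the local x86 demo (local_demo = True), not from an Adreno GPU.

```python
import statistics


def summarize_ms(times_ms):
    """Summarize per-repeat latencies (ms) using the benchmark table's columns."""
    return {
        "mean": statistics.mean(times_ms),
        "median": statistics.median(times_ms),
        "max": max(times_ms),
        "min": min(times_ms),
        # Population standard deviation (assumed to match the reported "std")
        "std": statistics.pstdev(times_ms),
    }


def throughput_per_s(mean_ms):
    """Convert a mean latency in milliseconds into inferences per second."""
    return 1000.0 / mean_ms


timings = [1.0, 2.0, 3.0, 4.0]  # hypothetical timings, not from the run above
summary = summarize_ms(timings)
print(summary)
print(f"{throughput_per_s(2545.4545):.2f} inferences/s")
```

For example, the reported mean of 2545.4545 ms corresponds to roughly 1000 / 2545.4545 ≈ 0.39 inferences per second on the demo host.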
