Note
This tutorial can be used interactively with Google Colab! You can also click here to run the Jupyter notebook locally.
Intrinsics and Math Functions¶
Author: Tianqi Chen
TVM supports basic arithmetic operations, but in many cases
we need more complicated built-in functions,
for example exp
to take the exponential of a value.
These functions are target system dependent and may have different names on different target platforms. In this tutorial, we will learn how to invoke these target-specific functions, and how to unify the interface via TVM’s intrinsic API.
from __future__ import absolute_import, print_function
import numpy as np
import tvm
from tvm import te
from tvm.ir import register_op_attr, register_intrin_lowering
Direct Declare Extern Math Call¶
The most straightforward way to call a target-specific function is via
the extern function call construct in TVM.
In the following example, we use tvm.tir.call_pure_extern
to call the
__expf
function, which is only available under CUDA.
# Symbolic variable used as the length of the 1-D placeholder A.
n = te.var("n")
A = te.placeholder((n,), name="A")
# Each output element is a pure extern call to "__expf" with a float32 result.
B = te.compute(A.shape, lambda i: tvm.tir.call_pure_extern("float32", "__expf", A[i]), name="B")
s = te.create_schedule(B.op)
# Split B's only axis by num_thread and bind the two parts to blockIdx.x / threadIdx.x.
num_thread = 64
bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
# Build for the "cuda" target and print the source of the first imported module.
f = tvm.build(s, [A, B], "cuda", name="myexp")
print(f.imported_modules[0].get_source())
#ifdef _WIN32
using uint = unsigned int;
using uchar = unsigned char;
using ushort = unsigned short;
using int64_t = long long;
using uint64_t = unsigned long long;
#else
#define uint unsigned int
#define uchar unsigned char
#define ushort unsigned short
#define int64_t long long
#define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(64) myexp_kernel0(float* __restrict__ B, float* __restrict__ A, int n, int stride, int stride_1) {
if (((int)blockIdx.x) < (n >> 6)) {
B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = __expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]);
} else {
if (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) < n) {
B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = __expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]);
}
}
}
Unified Intrinsic Call¶
The above code verifies that a direct external call can be used to call into device-specific functions. However, this approach only works for the CUDA target with the float type. Ideally, we want to write the same code for any device and any data type.
TVM intrinsics provide the user with a mechanism to achieve this, and this
is the recommended way to solve the problem.
The following code uses te.exp instead, which creates an intrinsic call
(tvm.te.exp())
to do the exponential.
# Symbolic variable used as the length of the 1-D placeholder A.
n = te.var("n")
A = te.placeholder((n,), name="A")
# Use the te.exp intrinsic instead of naming a device-specific function.
B = te.compute(A.shape, lambda i: te.exp(A[i]), name="B")
s = te.create_schedule(B.op)
# Split B's only axis by num_thread and bind the two parts to blockIdx.x / threadIdx.x.
num_thread = 64
bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
# Build for the "cuda" target and print the source of the first imported module.
fcuda = tvm.build(s, [A, B], "cuda", name="myexp")
print(fcuda.imported_modules[0].get_source())
#ifdef _WIN32
using uint = unsigned int;
using uchar = unsigned char;
using ushort = unsigned short;
using int64_t = long long;
using uint64_t = unsigned long long;
#else
#define uint unsigned int
#define uchar unsigned char
#define ushort unsigned short
#define int64_t long long
#define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(64) myexp_kernel0(float* __restrict__ B, float* __restrict__ A, int n, int stride, int stride_1) {
if (((int)blockIdx.x) < (n >> 6)) {
B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = __expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]);
} else {
if (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) < n) {
B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = __expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]);
}
}
}
We can see that the same code works for both CUDA and OpenCL. The same te.exp can also be used for float64 data types.
# Build the same schedule and tensors for the "opencl" target and print the generated source.
fopencl = tvm.build(s, [A, B], "opencl", name="myexp")
print(fopencl.imported_modules[0].get_source())
// Function: myexp_kernel0
__kernel void myexp_kernel0(__global float* restrict B, __global float* restrict A, int n, int stride, int stride_1) {
if (((int)get_group_id(0)) < (n >> 6)) {
B[(((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride)] = exp(A[(((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride_1)]);
} else {
if (((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) < n) {
B[(((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride)] = exp(A[(((((int)get_group_id(0)) * 64) + ((int)get_local_id(0))) * stride_1)]);
}
}
}
Intrinsic Lowering Rule¶
When tvm.te.exp()
is called, TVM creates an intrinsic Call Expr.
TVM uses transformation rules to transform the intrinsic
call to device specific extern calls.
TVM also allows users to customize the rules at runtime.
The following example customizes the CUDA lowering rule for exp.
def my_cuda_math_rule(op):
    """Customized CUDA intrinsic lowering rule.

    Lowers a ``tir.<name>`` intrinsic call to a pure extern call whose
    function name is derived from the intrinsic name: ``<name>f`` for
    float32 and ``<name>`` for float64.

    Parameters
    ----------
    op : tvm.tir.Call
        The intrinsic call to lower. Its op name must start with ``"tir."``.

    Returns
    -------
    The extern call expression, or ``op`` itself when its dtype is neither
    float32 nor float64.
    """
    assert isinstance(op, tvm.tir.Call)
    name = op.op.name
    assert name.startswith("tir.")
    # strip the "tir." prefix to get the bare math function name
    dispatch_name = name[4:]
    if op.dtype == "float32":
        # call float function
        return tvm.tir.call_pure_extern("float32", "%sf" % dispatch_name, op.args[0])
    elif op.dtype == "float64":
        # call double function; the result dtype has to stay float64
        return tvm.tir.call_pure_extern("float64", dispatch_name, op.args[0])
    else:
        # cannot do translation, return self.
        return op
# Register my_cuda_math_rule as the lowering for "tir.exp" on the "cuda" target, at level 99.
register_intrin_lowering("tir.exp", target="cuda", f=my_cuda_math_rule, level=99)
<function my_cuda_math_rule at 0x7f65bd33a9e0>
Register the rule with TVM, passing level=99 so that it overrides the existing rule.
Notice the difference between the printed code and the previous one:
our new rule uses the math function expf
instead of the
fast math version __expf.
# Rebuild the same schedule for "cuda" after the custom "tir.exp" rule has been registered,
# then print the generated source to compare with the earlier build.
fcuda = tvm.build(s, [A, B], "cuda", name="myexp")
print(fcuda.imported_modules[0].get_source())
#ifdef _WIN32
using uint = unsigned int;
using uchar = unsigned char;
using ushort = unsigned short;
using int64_t = long long;
using uint64_t = unsigned long long;
#else
#define uint unsigned int
#define uchar unsigned char
#define ushort unsigned short
#define int64_t long long
#define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(64) myexp_kernel0(float* __restrict__ B, float* __restrict__ A, int n, int stride, int stride_1) {
if (((int)blockIdx.x) < (n >> 6)) {
B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]);
} else {
if (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) < n) {
B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = expf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]);
}
}
}
Add Your Own Intrinsic¶
If there is an intrinsic that is not provided by TVM,
users can easily add a new one by using the intrinsic rule system.
The following example adds an intrinsic mylog
to the system.
def mylog(x):
    """Build a call to the custom ``tir.mylog`` intrinsic on ``x``.

    The call expression is created with the same dtype as ``x``.
    """
    result_dtype = x.dtype
    return tvm.tir.call_intrin(result_dtype, "tir.mylog", x)
def my_cuda_mylog_rule(op):
    """Lower a ``mylog`` call to the matching CUDA log function.

    A float32 call becomes a pure extern call to ``logf`` and a float64 call
    becomes a pure extern call to ``log``. A call of any other dtype is
    returned unchanged.
    """
    extern_by_dtype = (("float32", "logf"), ("float64", "log"))
    for dtype, func_name in extern_by_dtype:
        if op.dtype == dtype:
            return tvm.tir.call_pure_extern(dtype, func_name, op.args[0])
    # no extern function listed for this dtype: hand the call back as-is
    return op
# new op registration is triggered by registering an attribute of the op;
# here "tir.mylog" is given the TCallEffectKind attribute with value Pure.
register_op_attr("tir.mylog", "TCallEffectKind", tvm.tir.CallEffectKind.Pure)
# Register my_cuda_mylog_rule as the lowering for "tir.mylog" on the "cuda" target, at level 99.
register_intrin_lowering("tir.mylog", target="cuda", f=my_cuda_mylog_rule, level=99)
# Symbolic variable used as the length of the 1-D placeholder A.
n = te.var("n")
A = te.placeholder((n,), name="A")
# Each output element applies the custom mylog intrinsic to the matching input element.
B = te.compute(A.shape, lambda i: mylog(A[i]), name="B")
s = te.create_schedule(B.op)
# Split B's only axis by num_thread and bind the two parts to blockIdx.x / threadIdx.x.
num_thread = 64
bx, tx = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
# Build for the "cuda" target and print the source of the first imported module.
fcuda = tvm.build(s, [A, B], "cuda", name="mylog")
print(fcuda.imported_modules[0].get_source())
#ifdef _WIN32
using uint = unsigned int;
using uchar = unsigned char;
using ushort = unsigned short;
using int64_t = long long;
using uint64_t = unsigned long long;
#else
#define uint unsigned int
#define uchar unsigned char
#define ushort unsigned short
#define int64_t long long
#define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(64) mylog_kernel0(float* __restrict__ B, float* __restrict__ A, int n, int stride, int stride_1) {
if (((int)blockIdx.x) < (n >> 6)) {
B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = logf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]);
} else {
if (((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) < n) {
B[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride)] = logf(A[(((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) * stride_1)]);
}
}
}