Use Tensorize to Leverage Hardware Intrinsics
Author: Yizhi Liu
This is an introductory tutorial on how to perform tensorization in TVM.
By using the schedule primitive tensorize,
users can replace a unit of computation with the corresponding intrinsics,
making it easy to leverage handcrafted micro-kernels
as well as to extend TVM to support new hardware architectures.
The purpose of this tutorial is to show the functionality and usage of tensorize rather than to provide an efficient solution.
from __future__ import absolute_import, print_function
import tvm
from tvm import te
import tvm.testing
import numpy as np
Define Matrix Multiplication
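Take matrix multiplication as our example.
Matmul first multiplies the corresponding elements of two matrices,
then accumulates across a certain axis.
Here we compute A * B^T, that is, C[i, j] = sum over k of A[i, k] * B[j, k],
with A of shape (1024, 64) and B of shape (512, 64). In TVM this computation lowers to the following IR: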
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        A_1 = T.match_buffer(A, (1024, 64))
        B_1 = T.match_buffer(B, (512, 64))
        C_1 = T.match_buffer(C, (1024, 512))
        for i, j in T.grid(1024, 512):
            C_2 = T.buffer_decl((524288,), data=C_1.data)
            C_2[i * 512 + j] = T.float32(0)
            for k in range(64):
                cse_var_1: T.int32 = i * 512 + j
                A_2 = T.buffer_decl((65536,), data=A_1.data)
                B_2 = T.buffer_decl((32768,), data=B_1.data)
                C_2[cse_var_1] = C_2[cse_var_1] + A_2[i * 64 + k] * B_2[j * 64 + k]
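The IR above stores every matrix as a flat, row-major buffer, so A[i, k] becomes A_2[i * 64 + k] and B[j, k] becomes B_2[j * 64 + k]. As a plain-Python sketch (not TVM code; the helper names and the tiny sizes are chosen only for illustration), the same loop nest over flat lists reproduces A * B^T:

```python
def matmul_abt_flat(a, b, n, m, l):
    """C = A * B^T on flat row-major lists, mirroring the lowered IR.

    a holds n*l elements (A is n x l); b holds m*l elements (B is m x l).
    """
    c = [0.0] * (n * m)
    for i in range(n):
        for j in range(m):
            c[i * m + j] = 0.0  # reduction init, like C_2[i * 512 + j] = 0
            for k in range(l):
                c[i * m + j] += a[i * l + k] * b[j * l + k]
    return c


def matmul_abt_nested(a_rows, b_rows):
    """Reference on nested lists: C[i][j] = sum_k A[i][k] * B[j][k]."""
    return [[sum(x * y for x, y in zip(ar, br)) for br in b_rows] for ar in a_rows]


# A is 2 x 3 and B is 2 x 3, so C is 2 x 2.
A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
B = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
flat_c = matmul_abt_flat([x for r in A for x in r], [x for r in B for x in r], 2, 2, 3)
print(flat_c)  # [4.0, 2.0, 10.0, 5.0]
```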
Schedule the Matmul
Now, suppose we have an accelerator that supports matrix-vector multiplication (GEMV) as a hardware primitive. It can take a reduce axis of arbitrary size, but its other axis must be no larger than 16. We therefore split the j loop by a factor of 16 so that the innermost loops form a (16x64) GEMV, as the IR below shows:
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        A_1 = T.match_buffer(A, (1024, 64))
        B_1 = T.match_buffer(B, (512, 64))
        C_1 = T.match_buffer(C, (1024, 512))
        for i, j_outer, j_inner in T.grid(1024, 32, 16):
            C_2 = T.buffer_decl((524288,), data=C_1.data)
            C_2[i * 512 + j_outer * 16 + j_inner] = T.float32(0)
            for k in range(64):
                cse_var_1: T.int32 = i * 512 + j_outer * 16 + j_inner
                A_2 = T.buffer_decl((65536,), data=A_1.data)
                B_2 = T.buffer_decl((32768,), data=B_1.data)
                C_2[cse_var_1] = C_2[cse_var_1] + A_2[i * 64 + k] * B_2[j_outer * 1024 + j_inner * 64 + k]
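Splitting j by 16 only re-indexes the iteration space: j = j_outer * 16 + j_inner. That is also why the B index j_outer * 1024 + j_inner * 64 + k in the IR above equals the unsplit j * 64 + k, since 1024 = 16 * 64. A quick plain-Python check of both claims for this tutorial's sizes (a sketch, not TVM code):

```python
M, L, FACTOR = 512, 64, 16  # columns of C, reduce length, split factor

# Every original j is produced exactly once by (j_outer, j_inner).
visited = [jo * FACTOR + ji for jo in range(M // FACTOR) for ji in range(FACTOR)]
assert visited == list(range(M))

# The B offset printed in the split IR equals the unsplit offset j * L + k.
for jo in range(M // FACTOR):
    for ji in range(FACTOR):
        for k in (0, L - 1):
            j = jo * FACTOR + ji
            assert jo * 1024 + ji * 64 + k == j * L + k
print("split preserves the iteration space")
```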
As shown in the IR printed above,
the inner loops j.inner (printed as j_inner) and k
together form a GEMV: within the innermost two loops, the index i
is fixed, so the access to matrix A
varies only with k,
which makes the access pattern of A
a “vector”.
To leverage our hypothetical hardware’s GEMV instruction,
we can tensorize over j.inner.
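For a fixed i and j_outer, the two inner loops read one 64-element row of A and a 16 x 64 block of B, and write 16 consecutive elements of C. That is a GEMV: C[i, j_outer*16 : j_outer*16 + 16] = B_block · A[i, :]. The plain-Python sketch below models this; gemv is a hypothetical stand-in for the hardware instruction, and the sizes are kept tiny:

```python
def gemv(a_vec, b_block):
    """Hypothetical GEMV 'instruction': out[r] = sum_k b_block[r][k] * a_vec[k]."""
    return [sum(bk * ak for bk, ak in zip(row, a_vec)) for row in b_block]


def matmul_abt_tiled(A, B, factor):
    """C = A * B^T, with the inner (factor x L) loops replaced by gemv calls."""
    n, m = len(A), len(B)
    assert m % factor == 0, "this sketch assumes the split divides M evenly"
    C = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for jo in range(m // factor):
            block = B[jo * factor:(jo + 1) * factor]  # `factor` rows of B
            C[i][jo * factor:(jo + 1) * factor] = gemv(A[i], block)
    return C


A = [[1, 2], [3, 4]]
B = [[1, 0], [0, 1], [1, 1], [2, 0]]
print(matmul_abt_tiled(A, B, factor=2))  # [[1, 2, 3, 2], [3, 4, 7, 6]]
```

Each gemv call writes a disjoint slice of C, which is what lets one hardware instruction replace the whole j_inner/k loop pair.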
Define GEMV Tensorization Intrinsic
Before scheduling the tensorization, we first need to define the intrinsic function for GEMV.
It has two parts. The first is a compute definition of GEMV,
which TVM uses to match the computation pattern in the original matmul schedule.
The second specifies how to execute GEMV on the device,
which is done in intrin_func below.
def intrin_gemv(m, l):
    # Compute definition of GEMV that TVM matches against the scheduled loops:
    # c[i] = sum_k a[k] * b[i, k]
    a = te.placeholder((l,), name="a")
    b = te.placeholder((m, l), name="b")
    k = te.reduce_axis((0, l), name="k")
    c = te.compute((m,), lambda i: te.sum(a[k] * b[i, k], axis=k), name="c")
    # Buffers bound to the intrinsic's operands; "s1" is a placeholder for B's row stride
    Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A", offset_factor=1, strides=[1])
    Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="B", offset_factor=1, strides=[te.var("s1"), 1])
    Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="C", offset_factor=1, strides=[1])

    def intrin_func(ins, outs):
        # Replace the matched computation with a call to the external gemv_update
        ib = tvm.tir.ir_builder.create()
        aa, bb = ins
        cc = outs[0]
        ib.emit(
            tvm.tir.call_extern(
                "int32",
                "gemv_update",
                cc.access_ptr("w"),
                aa.access_ptr("r"),
                bb.access_ptr("r"),
                m,
                l,
                bb.strides[0],
            )
        )
        return ib.get()

    return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})
Here te.decl_tensor_intrin
declares how to execute the computation c.op.
Our implementation simply takes the inputs and outputs,
converts them to pointers, and emits an external function call.
Note that tensorization requires the user to specify offset_factor.
With this information, TVM knows whether the data passed to tensorize is aligned:
the element offset of that data, measured from the start address of the original data structure,
is guaranteed to be a multiple of offset_factor,
which gives TVM a chance to optimize with vectorized loading.
We set the factor to 1, which imposes no alignment constraint, for simplicity.
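In tvm.tir.decl_buffer, offset_factor states that the buffer's element offset must be a multiple of that factor. For this schedule, the offsets actually passed to the intrinsic appear in the tensorized IR later in this tutorial: i * 512 + j_outer * 16 for C, i * 64 for A and j_outer * 1024 for B. A plain-Python sketch (the candidate factors are arbitrary, and offsets/satisfies are hypothetical helpers) checks which factors every such offset would satisfy:

```python
def offsets(n=1024, m_tiles=32):
    """Element offsets passed to the intrinsic in the tensorized loop nest."""
    for i in range(n):
        for jo in range(m_tiles):
            yield {"C": i * 512 + jo * 16, "A": i * 64, "B": jo * 1024}


def satisfies(factor):
    """True if every offset of A, B and C is a multiple of `factor`."""
    return all(off % factor == 0 for offs in offsets() for off in offs.values())


for f in (1, 4, 16, 32):
    print(f, satisfies(f))
# 1 True
# 4 True
# 16 True
# 32 False
```

Under these assumptions a factor of 16 would also hold for this particular schedule, while 32 would not, because C's offset advances in steps of 16 elements.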
Buffers are also declared for inputs and outputs. Although this is not required,
we benefit from the extra information that buffers provide. For example, we pass
bb.strides[0]
as an argument to the external function gemv_update.
For now bb.strides[0] == l,
but later we will see how the two can differ with more complicated schedules.
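For a compact row-major tensor, each stride is the product of the extents of all dimensions after it, so the (512, 64) matrix B has strides [64, 1], and 64 is the value that appears as the last argument of gemv_update in the tensorized IR further below. A plain-Python sketch of that rule (compact_strides is a hypothetical helper, not a TVM API):

```python
def compact_strides(shape):
    """Row-major strides (in elements) of a compact tensor with the given shape."""
    strides = [1] * len(shape)
    # Walk from the second-to-last dimension outward, multiplying extents.
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return strides


print(compact_strides((512, 64)))    # [64, 1]
print(compact_strides((1024, 512)))  # [512, 1]
```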
Note that we use te.var("s1")
as the first stride dimension for B.
If the strides can be inferred
(in this case, TVM knows tensor B is compact, so its strides are [L, 1] with L = 64),
such a placeholder lets TVM bind the inferred value for us automatically.
Tensorizing over j.inner with the intrinsic returned by intrin_gemv(16, 64) lowers the schedule to the following IR:
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        A_1 = T.match_buffer(A, (1024, 64))
        B_1 = T.match_buffer(B, (512, 64))
        C_1 = T.match_buffer(C, (1024, 512))
        for i, j_outer in T.grid(1024, 32):
            T.call_extern("int32", "gemv_update", T.tvm_access_ptr(T.type_annotation("float32"), C_1.data, i * 512 + j_outer * 16, 16, 2), T.tvm_access_ptr(T.type_annotation("float32"), A_1.data, i * 64, 64, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_1.data, j_outer * 1024, 1024, 1), 16, 64, 64)
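By tensorizing over yi (the j.inner axis), the innermost two loops are
now replaced by the intrinsic function we defined before: each (i, j_outer) iteration
issues one call to gemv_update with m = 16, l = 64 and a B row stride of 64.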
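Each T.tvm_access_ptr(type_annotation, data, offset, extent, rw_mask) argument in the IR above describes a window into a flat buffer, with rw_mask 1 for reads ("r") and 2 for writes ("w"). The plain-Python sketch below recomputes those windows from the printed offsets and extents (windows and covered are hypothetical helpers) and checks that, for one row i, the 32 output windows of C tile that row without gaps or overlap:

```python
M, L, FACTOR = 512, 64, 16


def windows(i, jo):
    """(offset, extent) pairs copied from the tensorized IR for one call."""
    return {
        "C": (i * M + jo * FACTOR, FACTOR),  # written, rw_mask 2
        "A": (i * L, L),                     # read, rw_mask 1
        "B": (jo * FACTOR * L, FACTOR * L),  # read, rw_mask 1
    }


def covered(i):
    """Flat C indices written by the 32 calls for row i, in call order."""
    idx = []
    for jo in range(M // FACTOR):
        off, ext = windows(i, jo)["C"]
        idx.extend(range(off, off + ext))
    return idx


print(covered(3) == list(range(3 * M, 4 * M)))  # True
```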
To build and run the module, let’s define the external function gemv_update.
It is a naive implementation of GEMV, for demonstration only.
def gemv_impl():
    cc_code = """
      extern "C" int gemv_update(float *cc, float *aa, float *bb, int m, int l, int stride) {
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < l; ++j) {
                cc[i] += aa[j] * bb[i * stride + j];
            }
        }
        return 0;
      }
    """
    from tvm.contrib import utils, clang

    temp = utils.tempdir()
    ll_path = temp.relpath("temp.ll")
    # Create LLVM ir from c source code
    ll_code = clang.create_llvm(cc_code, output=ll_path)
    return ll_code
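Now we use the pragma attribute import_llvm
to import the LLVM assembly produced by gemv_impl inline.
The import needs to happen before the tensorized GEMV is executed; in the IR below, the pragma_import_llvm attribute is attached to the outermost loop i, ahead of every gemv_update call.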
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle):
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        A_1 = T.match_buffer(A, (1024, 64))
        B_1 = T.match_buffer(B, (512, 64))
        C_1 = T.match_buffer(C, (1024, 512))
        i = T.var("int32")
        T.attr(T.iter_var(i, None, "DataPar", ""), "pragma_import_llvm", "; ModuleID = '/tmp/tmpuyepqa17/input0.cc'\nsource_filename = \"/tmp/tmpuyepqa17/input0.cc\"\ntarget datalayout = \"e-m:e-i64:64-f80:128-n8:16:32:64-S128\"\ntarget triple = \"x86_64-pc-linux-gnu\"\n\n; Function Attrs: noinline nounwind optnone uwtable\ndefine dso_local i32 @gemv_update(float*, float*, float*, i32, i32, i32) #0 {\n %7 = alloca float*, align 8\n %8 = alloca float*, align 8\n %9 = alloca float*, align 8\n %10 = alloca i32, align 4\n %11 = alloca i32, align 4\n %12 = alloca i32, align 4\n %13 = alloca i32, align 4\n %14 = alloca i32, align 4\n store float* %0, float** %7, align 8\n store float* %1, float** %8, align 8\n store float* %2, float** %9, align 8\n store i32 %3, i32* %10, align 4\n store i32 %4, i32* %11, align 4\n store i32 %5, i32* %12, align 4\n store i32 0, i32* %13, align 4\n br label %15\n\n15: ; preds = %50, %6\n %16 = load i32, i32* %13, align 4\n %17 = load i32, i32* %10, align 4\n %18 = icmp slt i32 %16, %17\n br i1 %18, label %19, label %53\n\n19: ; preds = %15\n store i32 0, i32* %14, align 4\n br label %20\n\n20: ; preds = %46, %19\n %21 = load i32, i32* %14, align 4\n %22 = load i32, i32* %11, align 4\n %23 = icmp slt i32 %21, %22\n br i1 %23, label %24, label %49\n\n24: ; preds = %20\n %25 = load float*, float** %8, align 8\n %26 = load i32, i32* %14, align 4\n %27 = sext i32 %26 to i64\n %28 = getelementptr inbounds float, float* %25, i64 %27\n %29 = load float, float* %28, align 4\n %30 = load float*, float** %9, align 8\n %31 = load i32, i32* %13, align 4\n %32 = load i32, i32* %12, align 4\n %33 = mul nsw i32 %31, %32\n %34 = load i32, i32* %14, align 4\n %35 = add nsw i32 %33, %34\n %36 = sext i32 %35 to i64\n %37 = getelementptr inbounds float, float* %30, i64 %36\n %38 = load float, float* %37, align 4\n %39 = fmul float %29, %38\n %40 = load float*, float** %7, align 8\n %41 = load i32, i32* %13, align 4\n %42 = sext i32 %41 to i64\n %43 = getelementptr inbounds float, float* %40, i64 %42\n %44 = load float, float* %43, align 4\n %45 = fadd float %44, %39\n store float %45, float* %43, align 4\n br label %46\n\n46: ; preds = %24\n %47 = load i32, i32* %14, align 4\n %48 = add nsw i32 %47, 1\n store i32 %48, i32* %14, align 4\n br label %20\n\n49: ; preds = %20\n br label %50\n\n50: ; preds = %49\n %51 = load i32, i32* %13, align 4\n %52 = add nsw i32 %51, 1\n store i32 %52, i32* %13, align 4\n br label %15\n\n53: ; preds = %15\n ret i32 0\n}\n\nattributes #0 = { noinline nounwind optnone uwtable \"correctly-rounded-divide-sqrt-fp-math\"=\"false\" \"disable-tail-calls\"=\"false\" \"less-precise-fpmad\"=\"false\" \"min-legal-vector-width\"=\"0\" \"no-frame-pointer-elim\"=\"true\" \"no-frame-pointer-elim-non-leaf\" \"no-infs-fp-math\"=\"false\" \"no-jump-tables\"=\"false\" \"no-nans-fp-math\"=\"false\" \"no-signed-zeros-fp-math\"=\"false\" \"no-trapping-math\"=\"false\" \"stack-protector-buffer-size\"=\"8\" \"target-cpu\"=\"x86-64\" \"target-features\"=\"+cx8,+fxsr,+mmx,+sse,+sse2,+x87\" \"unsafe-fp-math\"=\"false\" \"use-soft-float\"=\"false\" }\n\n!llvm.module.flags = !{!0}\n!llvm.ident = !{!1}\n\n!0 = !{i32 1, !\"wchar_size\", i32 4}\n!1 = !{!\"clang version 9.0.0-2~ubuntu18.04.2 (tags/RELEASE_900/final)\"}\n")
        for i, j_outer in T.grid(1024, 32):
            T.call_extern("int32", "gemv_update", T.tvm_access_ptr(T.type_annotation("float32"), C_1.data, i * 512 + j_outer * 16, 16, 2), T.tvm_access_ptr(T.type_annotation("float32"), A_1.data, i * 64, 64, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_1.data, j_outer * 1024, 1024, 1), 16, 64, 64)
Finally, we compare the output of the tensorized version with what numpy.dot
produces
to ensure our implementation is correct.
# s is the tensorized schedule of C; A, B and C are the tensors defined earlier
func = tvm.build(s, [A, B, C], target="llvm", name="gemv")

from tvm.topi.utils import get_const_tuple

dtype = A.dtype
dev = tvm.device("cpu", 0)
# Random inputs and a zero-initialized output buffer on the CPU
a = np.random.uniform(size=get_const_tuple(A.shape)).astype(dtype)
b = np.random.uniform(size=get_const_tuple(B.shape)).astype(dtype)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=dtype), dev)
func(tvm.nd.array(a, dev), tvm.nd.array(b, dev), c)
# The kernel computes A * B^T, so compare against np.dot(a, b.T)
tvm.testing.assert_allclose(c.numpy(), np.dot(a, b.T), rtol=1e-3)
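The check uses a relative tolerance because the tensorized kernel and np.dot are not guaranteed to add the float32 products in the same order, and float32 addition is not associative. The plain-Python sketch below (f32, dot_f32 and allclose are hypothetical helpers; atol is left at 0 for simplicity) rounds every step to float32 and applies the elementwise criterion used by NumPy's assert_allclose, |actual - desired| <= atol + rtol * |desired|:

```python
import random
import struct


def f32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def dot_f32(a, b, reverse=False):
    """Dot product with every product and partial sum rounded to float32."""
    order = range(len(a) - 1, -1, -1) if reverse else range(len(a))
    acc = 0.0
    for k in order:
        acc = f32(acc + f32(a[k] * b[k]))
    return acc


def allclose(actual, desired, rtol=1e-3, atol=0.0):
    """NumPy-style criterion: |actual - desired| <= atol + rtol * |desired|."""
    return abs(actual - desired) <= atol + rtol * abs(desired)


random.seed(0)
a = [f32(random.uniform(0.0, 1.0)) for _ in range(64)]  # one row of A
b = [f32(random.uniform(0.0, 1.0)) for _ in range(64)]  # one row of B
forward, backward = dot_f32(a, b), dot_f32(a, b, reverse=True)
print(forward == backward)          # may be True or False, depending on rounding
print(allclose(forward, backward))  # True
```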
Reduce-update for Tensorize
So far you have learned the basic idea of tensorize; now let’s move one step further to a more complicated case.
Assume our accelerator can only multiply a vector by a square matrix, where the vector size must be no larger than 16. Given this hardware constraint, we now also need to split the reduce axis by a factor of 16 and reorder the loops so that the outer part of the reduce axis sits outside yi, leaving a 16 x 16 GEMV below it.
However, since the tensorize intrinsic now covers only part of the reduce axis,
TVM requires, instead of a single “body” function, a reduce_reset
function,
which is invoked before the reduce for-loop, and a reduce_update
function,
which defines the “update” computation strategy.
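The effect of splitting the reduce axis can be modeled in plain Python. In this sketch, gemv_reset and gemv_update are Python stand-ins named after the C functions defined below (not the real kernels), and matmul_split_reduce is a hypothetical helper. Each output tile starts as NaN so that only the reset initializes it; it is then updated once per chunk of the outer reduce loop, and the result still equals A * B^T:

```python
def gemv_reset(c_tile):
    """Stand-in for gemv_reset: zero the output tile in place."""
    for r in range(len(c_tile)):
        c_tile[r] = 0.0


def gemv_update(c_tile, a_chunk, b_block):
    """Stand-in for gemv_update: c_tile[r] += sum_k a_chunk[k] * b_block[r][k]."""
    for r, row in enumerate(b_block):
        c_tile[r] += sum(x * y for x, y in zip(a_chunk, row))


def matmul_split_reduce(A, B, factor):
    """C = A * B^T with both j and the reduce axis k split by `factor`."""
    n, m, l = len(A), len(B), len(A[0])
    assert m % factor == 0 and l % factor == 0, "sketch assumes even splits"
    C = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for jo in range(m // factor):
            tile = [float("nan")] * factor  # garbage until reset
            gemv_reset(tile)  # reduce_reset: once per output tile
            for ko in range(l // factor):  # reduce loop left outside the intrinsic
                ks = slice(ko * factor, (ko + 1) * factor)
                block = [B[jo * factor + r][ks] for r in range(factor)]
                gemv_update(tile, A[i][ks], block)  # reduce_update
            C[i][jo * factor:(jo + 1) * factor] = tile
    return C


A = [[1.0, 2.0, 3.0, 4.0]]
B = [[1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 1.0, 0.0]]
print(matmul_split_reduce(A, B, factor=2))  # [[10.0, 4.0]]
```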
def gemv_impl():
    cc_code = """
      extern "C" int gemv_update(float *cc, float *aa, float *bb, int m, int l, int stride) {
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < l; ++j) {
                cc[i] += aa[j] * bb[i * stride + j];
            }
        }
        return 0;
      }
      extern "C" int gemv_reset(float *cc, int m) {
        for (int i = 0; i < m; ++i) {
            cc[i] = 0.0;
        }
        return 0;
      }
    """
    from tvm.contrib import utils, clang

    temp = utils.tempdir()
    ll_path = temp.relpath("temp.ll")
    # Create LLVM ir from c source code
    ll_code = clang.create_llvm(cc_code, output=ll_path)
    return ll_code
def intrin_gemv(m, l):
    a = te.placeholder((l,), name="a")
    b = te.placeholder((m, l), name="b")
    k = te.reduce_axis((0, l), name="k")
    c = te.compute((m,), lambda i: te.sum(a[k] * b[i, k], axis=k), name="c")
    Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A", offset_factor=1, strides=[1])
    Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="B", offset_factor=1, strides=[te.var("s1"), 1])
    Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="C", offset_factor=1, strides=[1])

    def intrin_func(ins, outs):
        aa, bb = ins
        cc = outs[0]

        def _body():
            # Accumulate a GEMV into cc via the external gemv_update
            ib = tvm.tir.ir_builder.create()
            ib.emit(
                tvm.tir.call_extern(
                    "int32",
                    "gemv_update",
                    cc.access_ptr("w"),
                    aa.access_ptr("r"),
                    bb.access_ptr("r"),
                    m,
                    l,
                    bb.strides[0],
                )
            )
            return ib.get()

        def _reduce_reset():
            # Zero the output tile before the outer reduce loop starts
            ib = tvm.tir.ir_builder.create()
            ib.emit(tvm.tir.call_extern("int32", "gemv_reset", cc.access_ptr("w"), m))
            return ib.get()

        def _reduce_update():
            # On this hardware, an update is the same accumulation as the body
            return _body()

        return _body(), _reduce_reset(), _reduce_update()

    return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})
Note that intrin_func
now returns a triplet:
(body, reduce_reset, reduce_update).
If tensorization covers all the reduce axes, the function body()
is invoked;
otherwise reduce_reset()
and reduce_update()
are used together.
In our example, body()
and reduce_update()
share the same implementation,
while in other cases hardware may have different instructions for the two functions.
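This rule can be pictured by counting calls. The plain-Python sketch below (call_counts is a hypothetical helper that models the generated call sequence, not TVM's lowering code) uses this tutorial's sizes: a full-reduction intrinsic gets one body() call per (i, j_outer) pair, while the split version gets one reset per pair plus one update per chunk of the 64-long reduce axis:

```python
def call_counts(n, m, l, factor, split_reduce):
    """Count intrinsic calls for C = A * B^T (n x m output, reduce length l)."""
    counts = {"body": 0, "reduce_reset": 0, "reduce_update": 0}
    for _ in range(n):
        for _ in range(m // factor):
            if not split_reduce:
                # The intrinsic covers the whole reduce axis: one body() call.
                counts["body"] += 1
            else:
                # Only part of the reduce axis is covered: reset once, then
                # one update per iteration of the outer reduce loop.
                counts["reduce_reset"] += 1
                counts["reduce_update"] += l // factor
    return counts


print(call_counts(1024, 512, 64, 16, split_reduce=False))
# {'body': 32768, 'reduce_reset': 0, 'reduce_update': 0}
print(call_counts(1024, 512, 64, 16, split_reduce=True))
# {'body': 0, 'reduce_reset': 32768, 'reduce_update': 131072}
```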
Moreover, bb.strides[0]
now differs from l
due to the tiling: l is 16, while consecutive rows of B remain 64 elements apart.
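A plain-Python port of the C gemv_update shows why the stride has to be passed separately from l. The port uses flat lists plus explicit start offsets standing in for the pointers that access_ptr produces, and a small 4 x 8 matrix stands in for the 512 x 64 matrix B (both are assumptions of this sketch):

```python
def gemv_update(cc, c_off, aa, a_off, bb, b_off, m, l, stride):
    """Port of the C gemv_update: cc[i] += aa[j] * bb[i * stride + j].

    The *_off arguments stand in for the pointer offsets that access_ptr passes.
    """
    for i in range(m):
        for j in range(l):
            cc[c_off + i] += aa[a_off + j] * bb[b_off + i * stride + j]
    return 0


ROWS, COLS = 4, 8  # a small stand-in for the 512 x 64 matrix B
B = [float(r * COLS + c) for r in range(ROWS) for c in range(COLS)]
a = [1.0] * COLS
out = [0.0, 0.0]

# A 2 x 4 tile: rows 2..3, columns 4..7 of B, so m = 2 and l = 4, but the
# rows of the tile are still COLS = 8 elements apart in memory.
gemv_update(out, 0, a, 4, B, 2 * COLS + 4, 2, 4, COLS)
print(out)  # [86.0, 118.0]

# Passing l as the stride reads the wrong elements for every row after the first.
wrong = [0.0, 0.0]
gemv_update(wrong, 0, a, 4, B, 2 * COLS + 4, 2, 4, 4)
print(wrong)  # [86.0, 102.0]
```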
Tensorize with the square GEMV, then build and check the results:
# factor is the split factor (16); yi and yo are the inner and outer parts of C's j axis
gemv = intrin_gemv(factor, factor)
s[C].tensorize(yi, gemv)
s[C].pragma(yo, "import_llvm", gemv_impl())

func = tvm.build(s, [A, B, C], target="llvm", name="gemv")
# Reuse dtype, dev and get_const_tuple from the previous check
a = np.random.uniform(size=get_const_tuple(A.shape)).astype(dtype)
b = np.random.uniform(size=get_const_tuple(B.shape)).astype(dtype)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=dtype), dev)
func(tvm.nd.array(a, dev), tvm.nd.array(b, dev), c)
tvm.testing.assert_allclose(c.numpy(), np.dot(a, b.T), rtol=1e-3)
Summary
This tutorial demonstrates the usage of the tensorize intrinsic in TVM. Tensorize provides a way for users to get a fully optimized schedule via micro-kernels. For example, INT8 quantization on Intel CPUs uses tensorization to invoke AVX instructions directly. It also enables TVM to compile to ASICs; check out VTA: Versatile Tensor Accelerator for details. We also demonstrated how to use inline assembly importing, which helps users inject assembly easily into the schedule.