Reduction
Author: Tianqi Chen
This is an introduction to doing reduction in TVM. Associative reduction operators such as sum, max, and min are typical building blocks of linear algebra operations.
In this tutorial, we will demonstrate how to do reduction in TVM.
from __future__ import absolute_import, print_function
import tvm
import tvm.testing
from tvm import te
import numpy as np
Describe Sum of Rows
As our example, assume we want to compute the sum of each row.
In NumPy semantics this can be written as B = numpy.sum(A, axis=1).
The following lines describe the row sum operation. To create a reduction formula, we declare a reduction axis using te.reduce_axis, which takes in the range of the reduction. te.sum takes in the expression to be reduced as well as the reduction axis, and computes the sum of the values over all k in the declared range.
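The equivalent C code is as follows:
for (int i = 0; i < n; ++i) {
  B[i] = 0;
  for (int k = 0; k < m; ++k) {
    B[i] = B[i] + A[i][k];
  }
}
The corresponding TE declaration is:
n = te.var("n")
m = te.var("m")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), "k")
B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")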
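To check the semantics without TVM, the following plain NumPy sketch (not TVM code) spells out what the row-sum reduction described in this section computes: i ranges over the output, and the reduction axis k is looped over and accumulated starting from 0, the identity of sum. The function name row_sum_reference is hypothetical.

```python
import numpy as np


def row_sum_reference(A):
    """Explicit-loop version of B[i] = sum over k in [0, m) of A[i, k]."""
    n, m = A.shape
    B = np.zeros(n, dtype=A.dtype)
    for i in range(n):              # spatial axis of the output B
        acc = A.dtype.type(0)       # 0 is the identity element of sum
        for k in range(m):          # reduction axis, declared range (0, m)
            acc += A[i, k]
        B[i] = acc
    return B


A = np.arange(12, dtype=np.float32).reshape(3, 4)
print(row_sum_reference(A))  # [ 6. 22. 38.]
```

The result agrees with numpy.sum(A, axis=1), the NumPy formulation given above.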
Schedule the Reduction
There are several ways to schedule a reduction. Before doing anything, let us print out the IR code of the default schedule.
s = te.create_schedule(B.op)
print(tvm.lower(s, [A, B], simple_mode=True))
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        n = T.var("int32")
        m = T.var("int32")
        stride = T.var("int32")
        stride_1 = T.var("int32")
        A_1 = T.match_buffer(A, (n, m), strides=(stride, stride_1), type="auto")
        stride_2 = T.var("int32")
        B_1 = T.match_buffer(B, (n,), strides=(stride_2,), type="auto")
        for i in range(n):
            B_2 = T.buffer_decl((stride_2 * n,), data=B_1.data, type="auto")
            B_2[i * stride_2] = T.float32(0)
            for k in range(m):
                A_2 = T.buffer_decl((stride * n,), data=A_1.data, type="auto")
                B_2[i * stride_2] = B_2[i * stride_2] + A_2[i * stride + k * stride_1]
The IR code closely resembles the C code. The reduction axis is similar to a normal axis; it can be split.
In the following code, we split both the row axis of B and the reduction axis by different factors. The result is a nested reduction.
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
xo, xi = s[B].split(B.op.axis[0], factor=32)
print(tvm.lower(s, [A, B], simple_mode=True))
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        n = T.var("int32")
        m = T.var("int32")
        stride = T.var("int32")
        stride_1 = T.var("int32")
        A_1 = T.match_buffer(A, (n, m), strides=(stride, stride_1), type="auto")
        stride_2 = T.var("int32")
        B_1 = T.match_buffer(B, (n,), strides=(stride_2,), type="auto")
        for i_outer, i_inner in T.grid((n + 31) // 32, 32):
            B_2 = T.buffer_decl((stride_2 * n,), data=B_1.data, type="auto")
            if T.likely(i_outer * 32 + i_inner < n):
                B_2[(i_outer * 32 + i_inner) * stride_2] = T.float32(0)
            if T.likely(i_outer * 32 + i_inner < n):
                for k_outer, k_inner in T.grid((m + 15) // 16, 16):
                    if T.likely(k_outer * 16 + k_inner < m):
                        A_2 = T.buffer_decl((stride * n,), data=A_1.data, type="auto")
                        cse_var_1: T.int32 = i_outer * 32 + i_inner
                        B_2[cse_var_1 * stride_2] = B_2[cse_var_1 * stride_2] + A_2[cse_var_1 * stride + (k_outer * 16 + k_inner) * stride_1]
If we are building a GPU kernel, we can bind the rows of B to GPU threads.
s[B].bind(xo, te.thread_axis("blockIdx.x"))
s[B].bind(xi, te.thread_axis("threadIdx.x"))
print(tvm.lower(s, [A, B], simple_mode=True))
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        n = T.var("int32")
        m = T.var("int32")
        stride = T.var("int32")
        stride_1 = T.var("int32")
        A_1 = T.match_buffer(A, (n, m), strides=(stride, stride_1), type="auto")
        stride_2 = T.var("int32")
        B_1 = T.match_buffer(B, (n,), strides=(stride_2,), type="auto")
        blockIdx_x = T.env_thread("blockIdx.x")
        T.launch_thread(blockIdx_x, (n + 31) // 32)
        threadIdx_x = T.env_thread("threadIdx.x")
        T.launch_thread(threadIdx_x, 32)
        B_2 = T.buffer_decl((stride_2 * n,), data=B_1.data, type="auto")
        if T.likely(blockIdx_x * 32 + threadIdx_x < n):
            B_2[(blockIdx_x * 32 + threadIdx_x) * stride_2] = T.float32(0)
        for k_outer, k_inner in T.grid((m + 15) // 16, 16):
            if T.likely(blockIdx_x * 32 + threadIdx_x < n):
                if T.likely(k_outer * 16 + k_inner < m):
                    A_2 = T.buffer_decl((stride * n,), data=A_1.data, type="auto")
                    B_2[(blockIdx_x * 32 + threadIdx_x) * stride_2] = B_2[(blockIdx_x * 32 + threadIdx_x) * stride_2] + A_2[(blockIdx_x * 32 + threadIdx_x) * stride + (k_outer * 16 + k_inner) * stride_1]
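The split and bound listings above change how the row sum iterates, not what it computes. The NumPy sketch below (plain Python, not TVM code) replays that loop nest under the factors shown in the IR: the row axis split by 32, with its outer and inner parts playing the roles of blockIdx.x and threadIdx.x, the reduction axis split by 16, and the T.likely guards that skip out-of-range tail iterations when n or m is not a multiple of its factor. The function name row_sum_split_bound is hypothetical.

```python
import numpy as np


def row_sum_split_bound(A, x_factor=32, k_factor=16):
    """Replay the split/bound loop nest of the row sum in plain Python."""
    n, m = A.shape
    B = np.zeros(n, dtype=A.dtype)
    rows_done = []
    num_blocks = (n + x_factor - 1) // x_factor      # extent of blockIdx.x
    for block in range(num_blocks):                  # i_outer, bound to blockIdx.x
        for thread in range(x_factor):               # i_inner, bound to threadIdx.x
            i = block * x_factor + thread
            if i >= n:                               # guard: blockIdx_x * 32 + threadIdx_x < n
                continue
            rows_done.append(i)
            B[i] = 0
            for k_outer in range((m + k_factor - 1) // k_factor):
                for k_inner in range(k_factor):
                    k = k_outer * k_factor + k_inner
                    if k < m:                        # guard: k_outer * 16 + k_inner < m
                        B[i] += A[i, k]              # each "thread" walks all of k serially
    return B, rows_done


A = np.random.default_rng(0).random((45, 37), dtype=np.float32)
B, rows_done = row_sum_split_bound(A)
print(sorted(rows_done) == list(range(45)))       # True
print(np.allclose(B, A.sum(axis=1), rtol=1e-4))   # True
```

Every row is handled by exactly one (block, thread) pair, but each simulated thread still performs the whole reduction over k on its own, which is the limitation addressed by reduction factoring.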
Reduction Factoring and Parallelization
One problem with building a reduction is that we cannot simply parallelize over the reduction axis. We need to divide the computation of the reduction, store the local reduction results in a temporary array, and then perform a reduction over that temporary array.
The rfactor primitive performs such a rewrite of the computation. In the following schedule, the result of B is written to a temporary result B.rf. The factored dimension becomes the first dimension of B.rf.
s = te.create_schedule(B.op)
ko, ki = s[B].split(B.op.reduce_axis[0], factor=16)
BF = s.rfactor(B, ki)
print(tvm.lower(s, [A, B], simple_mode=True))
@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle):
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        n = T.var("int32")
        m = T.var("int32")
        stride = T.var("int32")
        stride_1 = T.var("int32")
        A_1 = T.match_buffer(A, (n, m), strides=(stride, stride_1), type="auto")
        stride_2 = T.var("int32")
        B_1 = T.match_buffer(B, (n,), strides=(stride_2,), type="auto")
        B_rf = T.allocate([n * 16], "float32", "global")
        B_rf_1 = T.buffer_decl((16 * n,), data=B_rf)
        for k_inner, i in T.grid(16, n):
            B_rf_1[k_inner * n + i] = T.float32(0)
            for k_outer in range((m + 15) // 16):
                if T.likely(k_outer * 16 + k_inner < m):
                    A_2 = T.buffer_decl((stride * n,), data=A_1.data, type="auto")
                    B_rf_1[k_inner * n + i] = B_rf_1[k_inner * n + i] + A_2[i * stride + (k_outer * 16 + k_inner) * stride_1]
        for ax0 in range(n):
            B_2 = T.buffer_decl((stride_2 * n,), data=B_1.data, type="auto")
            B_2[ax0 * stride_2] = T.float32(0)
            for k_inner_v in range(16):
                B_2[ax0 * stride_2] = B_2[ax0 * stride_2] + B_rf_1[k_inner_v * n + ax0]
The scheduled operator of B is also rewritten to be a sum over the first axis of the reduced result B.rf.
print(s[B].op.body)
[T.reduce(T.comm_reducer(lambda x, y: x + y, [T.float32(0)]), source=[B.rf[k_inner_v, ax0]], init=[], axis=[T.iter_var(k_inner_v, T.Range(0, 16), "CommReduce", "")], condition=True, value_index=0)]
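The two-stage structure produced by rfactor can be mimicked in NumPy. In the sketch below (plain Python, not TVM code; the function name is hypothetical), stage 1 fills B_rf[k_inner, i] with the partial sum of every 16th column of row i, mirroring the first loop nest in the listing above, and stage 2 reduces over the first axis of B_rf, mirroring the rewritten body of B.

```python
import numpy as np


def row_sum_rfactor(A, factor=16):
    """Two-stage row sum in the style of rfactor: partial sums, then a final reduction."""
    n, m = A.shape
    # Stage 1: B_rf[k_inner, i] accumulates columns k_inner, k_inner + factor, ... of row i.
    B_rf = np.zeros((factor, n), dtype=A.dtype)
    for k_inner in range(factor):
        for i in range(n):
            for k_outer in range((m + factor - 1) // factor):
                k = k_outer * factor + k_inner
                if k < m:  # same tail guard as T.likely(k_outer * 16 + k_inner < m)
                    B_rf[k_inner, i] += A[i, k]
    # Stage 2: B[i] = sum over k_inner_v of B_rf[k_inner_v, i].
    B = B_rf.sum(axis=0)
    return B, B_rf


B, B_rf = row_sum_rfactor(np.ones((3, 20), dtype=np.float32))
print(B)  # [20. 20. 20.]
```

The stage-1 iterations for different k_inner values write to disjoint entries of B_rf, so they are independent and can run in parallel; only the short stage-2 reduction over 16 values remains.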
Cross Thread Reduction
We can now parallelize over the factored axis. Here the reduction axis of B is marked to be a thread. TVM allows a reduction axis to be marked as a thread if it is the only axis in the reduction and cross-thread reduction is possible on the device.
This is indeed the case after factoring. We can also compute BF directly at the reduction axis. The final generated kernel divides the rows by blockIdx.x and threadIdx.y, divides the columns by threadIdx.x, and finally performs a cross-thread reduction over threadIdx.x.
xo, xi = s[B].split(s[B].op.axis[0], factor=32)
s[B].bind(xo, te.thread_axis("blockIdx.x"))
s[B].bind(xi, te.thread_axis("threadIdx.y"))
tx = te.thread_axis("threadIdx.x")
s[B].bind(s[B].op.reduce_axis[0], tx)
s[BF].compute_at(s[B], s[B].op.reduce_axis[0])
s[B].set_store_predicate(tx.var.equal(0))
fcuda = tvm.build(s, [A, B], "cuda")
print(fcuda.imported_modules[0].get_source())
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)
#define __shfl_sync(mask, var, lane, width) \
__shfl((var), (lane), (width))
#define __shfl_down_sync(mask, var, offset, width) \
__shfl_down((var), (offset), (width))
#define __shfl_up_sync(mask, var, offset, width) \
__shfl_up((var), (offset), (width))
#endif
#ifdef _WIN32
using uint = unsigned int;
using uchar = unsigned char;
using ushort = unsigned short;
using int64_t = long long;
using uint64_t = unsigned long long;
#else
#define uint unsigned int
#define uchar unsigned char
#define ushort unsigned short
#define int64_t long long
#define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(512) default_function_kernel0(float* __restrict__ A, float* __restrict__ B, int m, int n, int stride, int stride_1, int stride_2) {
  float B_rf[1];
  float red_buf0[1];
  B_rf[0] = 0.000000e+00f;
  for (int k_outer = 0; k_outer < (m >> 4); ++k_outer) {
    if (((((int)blockIdx.x) * 32) + ((int)threadIdx.y)) < n) {
      B_rf[0] = (B_rf[0] + A[((((((int)blockIdx.x) * 32) + ((int)threadIdx.y)) * stride) + (((k_outer * 16) + ((int)threadIdx.x)) * stride_1))]);
    }
  }
  for (int k_outer_1 = 0; k_outer_1 < (((m & 15) + 15) >> 4); ++k_outer_1) {
    if (((((int)blockIdx.x) * 32) + ((int)threadIdx.y)) < n) {
      if (((((m >> 4) * 16) + (k_outer_1 * 16)) + ((int)threadIdx.x)) < m) {
        B_rf[0] = (B_rf[0] + A[((((((int)blockIdx.x) * 32) + ((int)threadIdx.y)) * stride) + (((((m >> 4) * 16) + (k_outer_1 * 16)) + ((int)threadIdx.x)) * stride_1))]);
      }
    }
  }
  uint mask[1];
  float t0[1];
  red_buf0[0] = B_rf[0];
  mask[0] = (__activemask() & ((uint)(65535 << (((int)threadIdx.y) * 16))));
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 8, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 4, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 2, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 1, 32);
  red_buf0[0] = (red_buf0[0] + t0[0]);
  red_buf0[0] = __shfl_sync(mask[0], red_buf0[0], (((int)threadIdx.y) * 16), 32);
  if (((int)threadIdx.x) == 0) {
    B[(((((int)blockIdx.x) * 32) + ((int)threadIdx.y)) * stride_2)] = red_buf0[0];
  }
}
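In the generated kernel, each thread first accumulates its own partial sum in B_rf[0]. The 16 threads of a row (threadIdx.x from 0 to 15) then combine these values with __shfl_down_sync at offsets 8, 4, 2 and 1, and __shfl_sync broadcasts the total from lane threadIdx.y * 16. The plain-Python sketch below emulates that exchange on a list standing in for one 32-lane warp that holds two rows. It simplifies the CUDA semantics by treating each 16-lane group as an independent segment and ignoring the mask; the helper names are hypothetical.

```python
def shfl_down(vals, offset, segment):
    """Simplified __shfl_down_sync: a lane reads lane + offset inside its segment, else keeps its value."""
    out = list(vals)
    for lane in range(len(vals)):
        if lane % segment + offset < segment:
            out[lane] = vals[lane + offset]
    return out


def shfl(vals, src_lane_of):
    """Simplified __shfl_sync: every lane reads the value held by lane src_lane_of(lane)."""
    return [vals[src_lane_of(lane)] for lane in range(len(vals))]


def warp_row_sums(partials, lanes_per_row=16):
    """Tree-reduce each group of lanes_per_row lanes, then broadcast the group's first lane."""
    red_buf = list(partials)
    offset = lanes_per_row // 2
    while offset >= 1:  # offsets 8, 4, 2, 1 for 16 lanes
        t0 = shfl_down(red_buf, offset, lanes_per_row)
        red_buf = [a + b for a, b in zip(red_buf, t0)]
        offset //= 2
    # Lane (threadIdx.y * 16) of each group now holds the row total.
    return shfl(red_buf, lambda lane: (lane // lanes_per_row) * lanes_per_row)


result = warp_row_sums(list(range(32)))
print(result[0], result[16])  # 120 376
```

Only the first lane of each group needs the correct total after the tree; the other lanes hold partial values that the final broadcast overwrites, which is why a fixed sequence of halving offsets suffices. The kernel then lets only threadIdx.x == 0 store the result, matching set_store_predicate(tx.var.equal(0)).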
Verify the correctness of the resulting kernel by comparing it with NumPy.
nn = 128
dev = tvm.cuda(0)
a = tvm.nd.array(np.random.uniform(size=(nn, nn)).astype(A.dtype), dev)
b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), dev)
fcuda(a, b)
tvm.testing.assert_allclose(b.numpy(), np.sum(a.numpy(), axis=1), rtol=1e-4)
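The comparison uses a relative tolerance rather than exact equality because the kernel adds the elements of each row in a different order (16 strided partial sums followed by a shuffle tree) than numpy.sum does, and float32 addition is not associative. The NumPy-only sketch below sums one float32 row in both a left-to-right order and a kernel-like order, then checks each against a float64 reference with the same rtol=1e-4; the row length 128 matches nn above, while the random seed is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(42)
row = rng.uniform(size=128).astype(np.float32)

# Left-to-right float32 accumulation.
sequential = np.float32(0)
for x in row:
    sequential += x

# Kernel-like order: 16 strided partial sums (one per threadIdx.x) ...
partials = []
for j in range(16):
    acc = np.float32(0)
    for x in row[j::16]:
        acc += x
    partials.append(acc)
# ... then a halving tree that pairs lane t with lane t + half (offsets 8, 4, 2, 1).
while len(partials) > 1:
    half = len(partials) // 2
    partials = [partials[t] + partials[t + half] for t in range(half)]
tree = partials[0]

reference = row.astype(np.float64).sum()
print(np.isclose(sequential, reference, rtol=1e-4), np.isclose(tree, reference, rtol=1e-4))  # True True
```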
Describe Convolution via 2D Reduction
In TVM, we can describe convolution via 2D reduction in a simple way. Here is an example of a 2D convolution with filter size [3, 3] and strides [1, 1].
n = te.var("n")
Input = te.placeholder((n, n), name="Input")
Filter = te.placeholder((3, 3), name="Filter")
di = te.reduce_axis((0, 3), name="di")
dj = te.reduce_axis((0, 3), name="dj")
Output = te.compute(
    (n - 2, n - 2),
    lambda i, j: te.sum(Input[i + di, j + dj] * Filter[di, dj], axis=[di, dj]),
    name="Output",
)
s = te.create_schedule(Output.op)
print(tvm.lower(s, [Input, Filter, Output], simple_mode=True))
@I.ir_module
class Module:
    @T.prim_func
    def main(Input: T.handle, Filter: T.handle, Output: T.handle):
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        n = T.var("int32")
        stride = T.var("int32")
        stride_1 = T.var("int32")
        Input_1 = T.match_buffer(Input, (n, n), strides=(stride, stride_1), type="auto")
        Filter_1 = T.match_buffer(Filter, (3, 3))
        Output_1 = T.match_buffer(Output, (n - 2, n - 2))
        for i, j in T.grid(n - 2, n - 2):
            Output_2 = T.buffer_decl(((n - 2) * (n - 2),), data=Output_1.data)
            Output_2[i * (n - 2) + j] = T.float32(0)
            for di, dj in T.grid(3, 3):
                Input_2 = T.buffer_decl((stride * n,), data=Input_1.data, type="auto")
                Filter_2 = T.buffer_decl((9,), data=Filter_1.data)
                Output_2[i * (n - 2) + j] = Output_2[i * (n - 2) + j] + Input_2[(i + di) * stride + (j + dj) * stride_1] * Filter_2[di * 3 + dj]
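The 2D reduction over (di, dj) can be cross-checked in NumPy. The sketch below (plain Python, not TVM code; function names are hypothetical) computes the same valid, stride-1 convolution twice: once with the four nested loops visible in the IR, and once by reducing each sliding window against the filter with numpy.lib.stride_tricks.sliding_window_view (available in NumPy 1.20 and later) and numpy.einsum. It generalizes the 3 x 3 filter of the example to any filter shape.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv2d_loops(Input, Filter):
    """Output[i, j] = sum over (di, dj) of Input[i + di, j + dj] * Filter[di, dj]."""
    n = Input.shape[0]
    kh, kw = Filter.shape
    Output = np.zeros((n - kh + 1, n - kw + 1), dtype=Input.dtype)
    for i in range(n - kh + 1):
        for j in range(n - kw + 1):
            for di in range(kh):      # first reduction axis
                for dj in range(kw):  # second reduction axis
                    Output[i, j] += Input[i + di, j + dj] * Filter[di, dj]
    return Output


def conv2d_windows(Input, Filter):
    """Same 2D reduction, vectorized: contract each window with the filter."""
    windows = sliding_window_view(Input, Filter.shape)  # shape (n-kh+1, n-kw+1, kh, kw)
    return np.einsum("ijab,ab->ij", windows, Filter)


Input = np.arange(25, dtype=np.float64).reshape(5, 5)
Filter = np.ones((3, 3))
print(conv2d_loops(Input, Filter)[0, 0])  # 54.0
```

With n = 5 and a 3 x 3 filter the output is 3 x 3, matching the (n - 2, n - 2) shape in the TE declaration.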
Define General Commutative Reduction Operation
Besides the built-in reduction operations like te.sum, tvm.te.min, and tvm.te.max, you can also define your own commutative reduction operation with te.comm_reducer.
n = te.var("n")
m = te.var("m")
product = te.comm_reducer(lambda x, y: x * y, lambda t: tvm.tir.const(1, dtype=t), name="product")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: product(A[i, k], axis=k), name="B")
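te.comm_reducer is built from two pieces: a combiner (here x * y) and a function that yields the identity element for a given dtype (here 1). The pure-Python sketch below mimics that pairing to show what each piece does; make_reducer is a hypothetical helper, not a TVM API, and it only reduces a single Python sequence, with the Python type int or float standing in for the dtype.

```python
from functools import reduce


def make_reducer(fcombine, fidentity):
    """Return a reduction function built from a combiner and an identity element."""
    def reducer(values, dtype=float):
        # Start from the identity so that an empty range reduces to it.
        return reduce(fcombine, values, fidentity(dtype))
    return reducer


product = make_reducer(lambda x, y: x * y, lambda t: t(1))

A = [[1, 2, 3], [4, 5, 6]]
B = [product(row, dtype=int) for row in A]
print(B)            # [6, 120]
print(product([]))  # 1.0
```

The identity plays the same role as the T.float32(0) initialization of sum in the IR listings: it is the starting value of every accumulator, so it must leave any value unchanged when combined with it.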
Note
Sometimes we would like to perform a reduction that involves multiple values, like argmax, which can be done with tuple inputs. See Describe Reduction with Collaborative Inputs for more detail.
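As an illustration of a tuple-valued reduction, the pure-Python sketch below treats argmax as a commutative reducer over (index, value) pairs. The combiner, the tie-breaking rule (prefer the smaller index, which keeps the combiner commutative) and the identity (-1, -inf) are assumptions made for this sketch, not the definitions used by TVM.

```python
import math
from functools import reduce


def argmax_combine(x, y):
    """Combine two (index, value) pairs, keeping the one with the larger value.

    Ties are broken toward the smaller index so the result does not depend on argument order.
    """
    (ix, vx), (iy, vy) = x, y
    if vx > vy or (vx == vy and ix < iy):
        return x
    return y


ARGMAX_IDENTITY = (-1, -math.inf)  # hypothetical identity: no index, lowest possible value

row = [3.0, 7.0, 1.0, 7.0]
idx, val = reduce(argmax_combine, enumerate(row), ARGMAX_IDENTITY)
print(idx, val)  # 1 7.0
```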
Summary
This tutorial provides a walkthrough of reduction scheduling.
Describe a reduction with reduce_axis.
Use rfactor to factor out an axis if we need parallelism.
Define a new reduction operation with te.comm_reducer.