Tensor DSL Summary

Source: https://zhen8838.github.io/2023/12/20/tensordsl/

This article summarizes how the DSLs of several tensor-optimization frameworks are designed and tries to find the commonalities among them. Throughout, I will test the same example, Matmul(Transpose(Conv(lhs)), rhs), in each framework.

1. Jittor

1.1 DSL Syntax

First, let me explain how reindex works, using the example from the Jittor paper:

def conv(x, p):
    N,C,H,W = x.shape
    o,i,h,w = p.shape
    xx = x.reindex(shape=(N,o,H,W,i,h,w),
                   indexes=("i0", "i4", "i2-i5", "i3-i6"))
    pp = p.broadcast(xx.shape, dims=(0,2,3))
    yy = xx*pp
    y = yy.sum(dims=(4,5,6))
    return y

Here the shape is effectively treated as the loop nest: reindex amounts to performing an index of the form xx[N,o,H,W,i,h,w] = x[N,i,H-h,W-w] in the innermost body of a 7-level loop. The weights are then expanded to the same loop nest via broadcast, pp[N,o,H,W,i,h,w] = p[o,i,h,w]. Inside the 7-level loop, xx[N,o,H,W,i,h,w]*pp[N,o,H,W,i,h,w] is computed, which is equivalent to x[N,i,H-h,W-w] * p[o,i,h,w], and the result is then summed over the three loops i, h, w.

In other words, reindex + broadcast together accomplish something like the loop-dimension alignment and the modification of the access relation (specified by indexes) found in the polyhedral 2d+1 representation. Jittor does not let developers schedule operators themselves; all subsequent optimization is handed to the compiler.
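
To make these semantics concrete, here is a minimal plain-Python/NumPy sketch (my own illustration, not Jittor code) of the loop nest that this reindex + broadcast + sum program describes; out-of-range reads are treated as the overflow value, 0 here:

import numpy as np

def conv_reference(x, p):
    # x: (N, C, H, W), p: (o, i, h, w); mirrors the 7-level loop nest that
    # reindex/broadcast describe. The innermost three loops (c, kh, kw) are
    # the ones summed over by yy.sum(dims=(4,5,6)).
    N, C, H, W = x.shape
    o, i, h, w = p.shape
    y = np.zeros((N, o, H, W), dtype=x.dtype)
    for n in range(N):
        for f in range(o):
            for yy in range(H):
                for xx in range(W):
                    acc = 0.0
                    for c in range(i):
                        for kh in range(h):
                            for kw in range(w):
                                ih, iw = yy - kh, xx - kw      # "i2-i5", "i3-i6"
                                if 0 <= ih < H and 0 <= iw < W:
                                    acc += x[n, c, ih, iw] * p[f, c, kh, kw]
                    y[n, f, yy, xx] = acc
    return y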

1.2 Test Example

import jittor as jt

def conv(x, p):
    N,C,H,W = x.shape
    o,i,h,w = p.shape
    xx = x.reindex(shape=(N,o,H,W,i,h,w),
                   indexes=("i0", "i4", "i2-i5", "i3-i6"))
    pp = p.broadcast(xx.shape, dims=(0,2,3))
    yy = xx*pp
    y = yy.sum(dims=(4,5,6))
    return y

def matmul(a, b):
    bc, c, m, k = a.shape
    _, _, _, n = b.shape
    shape = [bc, c, m, k, n]
    a = a.broadcast(shape, [-1]) # [m,k, ] -> [m,k,n]
    b = b.broadcast(shape, [-3]) # [ ,k,n] -> [m,k,n]
    return (a*b).sum(-2)

lhs = jt.randn(8,3,32,32)
kernel = jt.randn(16,3,3,3)
rhs = jt.randn(8,32,16,64)
jt.flags.compile_options={"compile_shapes":1}
with jt.profile_scope() as report:
    output = matmul(jt.transpose(conv(lhs, kernel), [0,2,3,1]), rhs).fetch_sync()
jt.flags.compile_options={}

After compilation we get:

Profile result, sorted by TotalTime
('it/s' represent number of iterations per sec)
Name FileName Count TotalTime %,cum% AvgTime MinTime MaxTime Input Output InOut Compute
Total time: 12.8ms
Total Memory Access: 6.19MB
[opkey0:broadcast_to[Tx:float32][DIM=7][BCAST=d][JIT:1][JIT_cpu:1][index_t:int32]][opkey1:reindex[Tx:float32][XDIM=4][YDIM=7][OVERFLOW:itof(0x0)][INDEX0:i0][INDEX1:i4][INDEX2:i2-i5][INDEX3:i3-i6][OSIZE=0][ESIZE=0][JIT:1][JIT_cpu:1][index_t:int32]][opkey2:binary[Tx:float32][Ty:float32][Tz:float32][OP:multiply][JIT:1][JIT_cpu:1][index_t:int32]][opkey3:reduce[Tx:float32][Ty:float32][Tz:float32][OP:add][DIM=7][REDUCE=70][JIT:1][JIT_cpu:1][index_t:int32]][JIT:1][JIT_cpu:1][graph:040000,062010,010020,000021,020030,][var_info::041704171724][shapes:[10,3,3,3,],[8,10,20,20,3,3,3,],[8,3,20,20,],[8,10,20,20,3,3,3,],[8,10,20,20,3,3,3,],[8,10,20,20,],][choices:compile_shapes:1,]
/root/.cache/jittor/jt1.3.1/g++10.5.0/py3.8.18/Linux-5.4.0-42xae/AMDEPYC7T8364-x8f_debug/default/jit/_opkey0_broadcast_to_Tx_float32__DIM_7__BCAST_d__JIT_1__JIT_cpu_1__index_t_int32___opkey1____hash_a2d65b1fd1c3f3d0_op.cc
1 8.12ms(63.3%,63.3%) 8.12ms 8.12ms 8.12ms 11.7MB/s 61.6MB/s 73.3MB/s 436Mit/s
random[T:float32][R:normal][JIT:1][JIT_cpu:1][index_t:int32]
/root/.cache/jittor/jt1.3.1/g++10.5.0/py3.8.18/Linux-5.4.0-42xae/AMDEPYC7T8364-x8f_debug/default/jit/random_T_float32__R_normal__JIT_1__JIT_cpu_1__index_t_int32__hash_c27874d0aacc5d25_op.cc
3 3.68ms(28.7%,91.9%) 1.23ms 5.58us 3.37ms 0 B/s 298MB/s 298MB/s 78Mit/s
[opkey0:broadcast_to[Tx:float32][DIM=5][BCAST=4][JIT:1][JIT_cpu:1][index_t:int32]][opkey1:broadcast_to[Tx:float32][DIM=5][BCAST=10][JIT:1][JIT_cpu:1][index_t:int32]][opkey2:binary[Tx:float32][Ty:float32][Tz:float32][OP:multiply][JIT:1][JIT_cpu:1][index_t:int32]][opkey3:reduce[Tx:float32][Ty:float32][Tz:float32][OP:add][DIM=5][REDUCE=8][JIT:1][JIT_cpu:1][index_t:int32]][JIT:1][JIT_cpu:1][graph:040000,062010,010020,000021,020030,][var_info::041504151524][shapes:[8,20,10,40,],[8,20,20,10,40,],[8,20,20,10,],[8,20,20,10,40,],[8,20,20,10,40,],[8,20,20,40,],][choices:compile_shapes:1,]
/root/.cache/jittor/jt1.3.1/g++10.5.0/py3.8.18/Linux-5.4.0-42xae/AMDEPYC7T8364-x8f_debug/default/jit/_opkey0_broadcast_to_Tx_float32__DIM_5__BCAST_4__JIT_1__JIT_cpu_1__index_t_int32___opkey1____hash_ef35f9063cf2acdf_op.cc
1 828us(6.45%,98.4%) 828us 828us 828us 1.77GB/s 2.36GB/s 4.13GB/s 10.1Git/s
transpose[Tx:float32][DIM=4][AXES0=0][AXES2=1][AXES3=2][AXES1=3][JIT:1][JIT_cpu:1][index_t:int32]
/root/.cache/jittor/jt1.3.1/g++10.5.0/py3.8.18/Linux-5.4.0-42xae/AMDEPYC7T8364-x8f_debug/default/jit/transpose_Tx_float32__DIM_4__AXES0_0__AXES2_1__AXES3_2__AXES1_3__JIT_1__JIT_cpu_1__index_t_int32__hash_998b34c8052fe15_op.cc
1 208us(1.62%,100%) 208us 208us 208us 2.35GB/s 2.35GB/s 4.71GB/s 632Mit/s

Inspecting the output, I found it was split into three parts: _opkey0_broadcast_to_Tx_float32__DIM_7__BCAST_d__JIT_1__JIT_cpu_1__index_t_int32___opkey1____hash_a2d65b1fd1c3f3d0_op is the convolution implementation, _opkey0_broadcast_to_Tx_float32__DIM_5__BCAST_4__JIT_1__JIT_cpu_1__index_t_int32___opkey1____hash_ef35f9063cf2acdf_op is the matmul, and transpose_Tx_float32__DIM_4__AXES0_0__AXES2_1__AXES3_2__AXES1_3__JIT_1__JIT_cpu_1__index_t_int32__hash_998b34c8052fe15_op is the transpose.

2. Halide

2.1 DSL Syntax

import halide as hl

K = 256  # reduction extent; the original snippet read this from an enclosing class as self.K
inputLhs = hl.ImageParam(hl.Float(32), 2, "inputLhs")
inputRhs = hl.ImageParam(hl.Float(32), 2, "inputRhs")
output = hl.Func("output")
m, n = hl.Var("m"), hl.Var("n")
k = hl.RDom([hl.Range(0, K)], "k")
output[n, m] = 0.0
output[n, m] += inputLhs[k.x, m] * inputRhs[n, k.x]

Halide uses Var to represent loops; reduction loops must be marked with an RDom (and once a reduction loop is defined, the data also needs an initial value). The loop nest is likewise determined by the Vars, and by default the reduction loops are placed innermost. Indexing a tensor with Vars, as in inputLhs[k.x, m], establishes the access relation.

The drawback of declaring loop variables up front is that the developer has to manage all of them manually, which makes programs more verbose to write; the advantage is that the relationship between the loops of upstream and downstream operations is explicit, so automatically fusing adjacent operators becomes easy.
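
As an illustration of that advantage, here is a minimal sketch of my own (the Funcs and Vars are illustrative, not from the example below) of fusing a producer into a consumer with compute_at:

import halide as hl

x, y = hl.Var("x"), hl.Var("y")
producer, consumer = hl.Func("producer"), hl.Func("consumer")
producer[x, y] = x + y
consumer[x, y] = producer[x, y] + producer[x + 1, y]

# Because both Funcs range over the same Vars, the producer can be scheduled
# to be computed inside the consumer's y loop, i.e. the two loop nests fuse.
producer.compute_at(consumer, y)
consumer.print_loop_nest()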

2.2 Test Example

import halide as hl

input = hl.ImageParam(hl.Float(32), 4, "input")
weight = hl.ImageParam(hl.Float(32), 4, "weight")
act = hl.ImageParam(hl.Float(32), 2, "act")
pad_w_before = 0  # hl.Param(hl.Int(32), "pad_w_before")
pad_h_before = 0  # hl.Param(hl.Int(32), "pad_h_before")
stride_w = 1  # hl.Param(hl.Int(32), "stride_w")
stride_h = 1  # hl.Param(hl.Int(32), "stride_h")


WO, HO, CI, B, CO = hl.Var("WO"), hl.Var("HO"), hl.Var("CI"), hl.Var("B"), hl.Var("CO")
Padding, Paded, Conv, Acted, Clamped, Psumed = hl.Func("Padding"), hl.Func("Paded"), \
    hl.Func("Conv"), hl.Func("Acted"), hl.Func("Clamped"), hl.Func("Psumed")

r = hl.RDom([hl.Range(0, weight.width()), hl.Range(0, weight.height()),
             hl.Range(0, weight.dim(2).extent())])  # w,h,ic

Padding = hl.BoundaryConditions.constant_exterior(
    input, 0, [hl.Range(0, input.width()), hl.Range(0, input.height())])

in_channels = input.dim(2).extent()
out_channels = weight.dim(3).extent()

Paded[WO, HO, CI, B] = Padding[WO - pad_w_before, HO - pad_h_before, CI, B]

Conv[WO, HO, CO, B] = 0.0
Conv[WO, HO, CO, B] += weight[r[0], r[1], r[2], CO] * \
    Paded[WO * stride_w + r[0], HO * stride_h + r[1], r[2], B]  # use float to sum

Acted[WO, HO, CO, B] = hl.select(
    Conv[WO, HO, CO, B] < act[0, CO],
    Conv[WO, HO, CO, B] * act[1, CO] + act[2, CO],
    Conv[WO, HO, CO, B] * act[3, CO] + act[4, CO])  # float


Transpose = hl.Func("Transpose")
Transpose[CO, WO, HO, B] = Acted[WO, HO, CO, B]

rhs = hl.ImageParam(hl.Float(32), 4, "rhs")  # [x,x,K,N]

N = hl.Var("N")

kdom = hl.RDom([hl.Range(0, rhs.dim(2).extent())], "k")

Matmul = hl.Func("Matmul")
Matmul[N, WO, HO, B] = 0.0
Matmul[N, WO, HO, B] += Transpose[kdom.x, WO, HO, B] * rhs[N, kdom.x, HO, B]

Matmul.print_loop_nest()

The resulting loop nest is:

produce Matmul:
  for B:
    for HO:
      for WO:
        for N:
          Matmul(...) = ...
  for B:
    for HO:
      for WO:
        for N:
          for k:
            produce Conv:
              Conv(...) = ...
              for r14:
                for r14:
                  for r14:
                    Conv(...) = ...
            consume Conv:
              Matmul(...) = ...

The initialization of the matmul is placed at the root level by default; below it, the Transpose has been inlined automatically, and the matmul and the convolution have been fused automatically as well.

3. TVM

TVM grew out of Halide; it provides a Tensor Expression DSL to help us define operator compute logic.

3.1 DSL Syntax

from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.placeholder((n,), name="B")
C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")

It likewise uses the shape to express a perfect loop nest; the parameters of the fcompute callback map to the iteration variables, and the body is executed in the innermost loop. This is similar to Jittor, but the callback form adds flexibility. reduce_axis, analogous to Halide's RDom, automatically appends reduction loops at the innermost level, and by default the initialization is placed outside the reduction loops.
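
For instance, here is a minimal reduction sketch of my own (independent of the test case below) showing where reduce_axis and the initialization end up after lowering:

import tvm
from tvm import te

n = 1024
A = te.placeholder((n, n), name="A")
k = te.reduce_axis((0, n), name="k")
# Row sum: i comes from the shape/lambda and forms the outer loop, k becomes
# the reduction loop appended innermost; B's zero-initialization is emitted
# outside the k loop by default.
B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")
s = te.create_schedule(B.op)
print(tvm.lower(s, [A, B], simple_mode=True))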

3.2 Test Example

import tvm
from tvm import te
from tvm import tir

batch_size = 8
in_channel = 3
out_channel = 16
in_height = 32
in_width = 32
kernel_height = 3
kernel_width = 3

N = 64

Input = te.placeholder(
    (batch_size, in_channel, in_height, in_width), name='Input')
Kernel = te.placeholder(
    (out_channel, in_channel, kernel_height, kernel_width), name='Kernel')

rc = te.reduce_axis((0, in_channel), name='rc')
ry = te.reduce_axis((0, kernel_height), name='ry')
rx = te.reduce_axis((0, kernel_width), name='rx')

Conv = te.compute(
    (batch_size, out_channel, in_height -
     kernel_height + 1, in_width - kernel_width + 1),
    lambda n, f, y, x: te.sum(
        Input[n, rc, y + ry, x + rx] * Kernel[f, rc, ry, rx],
        axis=[rc, ry, rx]
    ),
    name='Conv'
)  # (b,oc,oh,ow) -> (b,oh,ow,oc)

oh, ow = 30, 30
rhs = te.placeholder((batch_size, oh, out_channel, N), name='rhs')

Trans = te.compute(
    (batch_size, oh, ow, out_channel),
    lambda i0, i1, i2, i3: Conv[i0, i3, i1, i2])


rk = te.reduce_axis((0, out_channel), name='rk')
MatMul = te.compute(
    (batch_size, oh, ow, N),
    lambda i0, i1, i2, i3: te.sum(
        Trans[i0, i1, i2, rk] * rhs[i0, i1, rk, i3], axis=[rk]),
    name='MatMul'
)

s: te.Schedule = te.create_schedule([Conv.op, MatMul.op])
ir = tvm.lower(s, [Input, Kernel, rhs])
ir.show()

ir.show() prints the lowered TIR:

@I.ir_module
class Module:
    @T.prim_func
    def main(Input: T.Buffer((8, 3, 32, 32), "float32"), Kernel: T.Buffer((16, 3, 3, 3), "float32"), rhs: T.Buffer((8, 30, 16, 64), "float32")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        Conv = T.allocate([460800], "float32", "global")
        compute = T.allocate([115200], "float32", "global")
        Conv_1 = T.Buffer((115200,), data=Conv)
        for n, f, y, x in T.grid(8, 16, 30, 30):
            Conv_1[n * 14400 + f * 900 + y * 30 + x] = T.float32(0)
            for rc, ry, rx in T.grid(3, 3, 3):
                cse_var_1: T.int32 = n * 14400 + f * 900 + y * 30 + x
                Input_1 = T.Buffer((24576,), data=Input.data)
                Kernel_1 = T.Buffer((432,), data=Kernel.data)
                Conv_1[cse_var_1] = Conv_1[cse_var_1] + Input_1[n * 3072 + rc * 1024 + y * 32 + ry * 32 + x + rx] * Kernel_1[f * 27 + rc * 9 + ry * 3 + rx]
        compute_1 = T.Buffer((115200,), data=compute)
        for i0, i1, i2, i3 in T.grid(8, 30, 30, 16):
            cse_var_2: T.int32 = i0 * 14400
            compute_1[cse_var_2 + i1 * 480 + i2 * 16 + i3] = Conv_1[cse_var_2 + i3 * 900 + i1 * 30 + i2]
        for i0, i1, i2, i3 in T.grid(8, 30, 30, 64):
            Conv_2 = T.Buffer((460800,), data=Conv)
            Conv_2[i0 * 57600 + i1 * 1920 + i2 * 64 + i3] = T.float32(0)
            for rk in range(16):
                cse_var_3: T.int32 = i0 * 57600 + i1 * 1920 + i2 * 64 + i3
                rhs_1 = T.Buffer((245760,), data=rhs.data)
                Conv_2[cse_var_3] = Conv_2[cse_var_3] + compute_1[i0 * 14400 + i1 * 480 + i2 * 16 + rk] * rhs_1[i0 * 30720 + i1 * 1024 + rk * 64 + i3]

Unlike Halide, the loop variables do not have to be declared in advance; instead you can obtain the axes from the output tensor and apply Halide-style scheduling, or schedule on TensorIR after lowering. Lowering with the default pipeline here, fusion does not happen automatically.
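
As a hedged sketch of the first path (continuing the variables from the code above; whether a particular compute_at is legal still has to be checked case by case), one could pull the transpose stage under the matmul's loops like this:

# Grab the consumer's axes from its op and move the Trans stage under the
# matmul's innermost spatial loop; Conv stays at root since it was also
# passed to create_schedule as an output.
b_ax, h_ax, w_ax, n_ax = MatMul.op.axis
s[Trans].compute_at(s[MatMul], w_ax)
ir = tvm.lower(s, [Input, Kernel, rhs])
ir.show()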

4. MLIR

MLIR provides an OpDSL on top of the linalg.generic op of the linalg dialect.

4.1 DSL Syntax

@linalg_structured_op
def conv_2d_nhwc_hwcf(
    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
    K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D convolution.

    Layout:
      * Input: NHWC.
      * Kernel: HWCF.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
    O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed(
        U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
    ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f])

MLIR does not use shape here; it uses domain, a term much closer to polyhedral parlance, to express the loop nest. What I find more radical is that it drops in-loop initialization entirely, i.e., it faithfully translates exactly what the OpDSL describes and nothing more.
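
To spell out what that means, here is a plain-Python reference sketch (my own illustration, not the MLIR API) of the computation the OpDSL body above describes; note there is no implicit zeroing of O, which is why the test below pairs the op with linalg.fill:

import numpy as np

def conv_2d_nhwc_hwcf_ref(I, K, O, SH=1, SW=1, DH=1, DW=1):
    # Accumulates into O exactly as the OpDSL "+=" describes.
    N, OH, OW, F = O.shape
    KH, KW, C, _ = K.shape
    for n in range(N):
        for oh in range(OH):
            for ow in range(OW):
                for f in range(F):
                    for kh in range(KH):
                        for kw in range(KW):
                            for c in range(C):
                                O[n, oh, ow, f] += (
                                    I[n, oh * SH + kh * DH, ow * SW + kw * DW, c]
                                    * K[kh, kw, c, f])

O = np.zeros((1, 30, 30, 16), dtype=np.float32)  # the explicit "fill"
conv_2d_nhwc_hwcf_ref(np.ones((1, 32, 32, 3), np.float32),
                      np.ones((3, 3, 3, 16), np.float32), O)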

4.2 Test Example

from mlir.dialects import arith, builtin, func, linalg, tensor, memref
from mlir.dialects.linalg.opdsl.lang import *
from mlir.ir import *


@linalg_structured_op
def transpose_nchw_nhwc(
    I=TensorDef(TV.T1, S.d0, S.d1, S.d2, S.d3),
    O=TensorDef(TV.T1, S.d0, S.d2, S.d3, S.d1, output=True)
):
    domain(D.d0, D.d1, D.d2, D.d3)
    implements(ContractionOpInterface)
    O[D.d0, D.d2, D.d3, D.d1] = I[D.d0, D.d1, D.d2, D.d3]


@linalg_structured_op
def matmul_4d(
    A=TensorDef(TV.T1, S.d0, S.d1, S.M, S.K),
    B=TensorDef(TV.T1, S.d0, S.d1, S.K, S.N),
    C=TensorDef(TV.T1, S.d0, S.d2, S.M, S.N, output=True)
):
    domain(D.d0, D.d1, D.m, D.n, D.k)
    implements(ContractionOpInterface)
    C[D.d0, D.d1, D.m, D.n] += A[D.d0, D.d1, D.m, D.k] * B[D.d0, D.d1, D.k, D.n]


def testOpResultFromOtherOp():
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        index_type = IndexType.get()
        with InsertionPoint(module.body):
            batch_size = 8
            in_channel = 3
            out_channel = 16
            in_height = 32
            out_height = 30
            in_width = 32
            out_width = 30
            kernel_height = 3
            kernel_width = 3
            N = 64

            @func.FuncOp.from_py_func(
                MemRefType.get(
                    (batch_size, in_channel, in_height, in_width), f32),
                MemRefType.get(
                    (out_channel, in_channel, kernel_height, kernel_width), f32),
                MemRefType.get((batch_size, out_height, out_channel, N), f32),
            )
            def main(lhs, weight, rhs):
                # conv = tensor.EmptyOp([batch_size, out_channel, out_height, out_width], f32)
                zero = arith.ConstantOp(F32Type.get(), 0.0)
                # CHECK: %[[LHS:.*]] = linalg.fill
                conv = memref.AllocOp(MemRefType.get(
                    [batch_size, out_channel, out_height, out_width], f32), [], [])
                linalg.fill(zero, outs=[conv])
                linalg.conv_2d_nchw_fchw(lhs, weight, outs=[conv])
                trans = memref.AllocOp(MemRefType.get(
                    [batch_size, out_height, out_width, out_channel], f32), [], [])
                transpose_nchw_nhwc(conv, outs=[trans])
                matmul = memref.AllocOp(MemRefType.get(
                    [batch_size, out_height, out_width, N], f32), [], [])
                matmul_4d(trans, rhs, outs=[matmul])
                return matmul

        print(module)


testOpResultFromOtherOp()

This yields convmatmul.mlir:

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4, d3)>
#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
module {
  func.func @main(%arg0: memref<8x3x32x32xf32>, %arg1: memref<16x3x3x3xf32>, %arg2: memref<8x30x16x64xf32>) -> memref<8x30x30x64xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %alloc = memref.alloc() : memref<8x16x30x30xf32>
    linalg.fill ins(%cst : f32) outs(%alloc : memref<8x16x30x30xf32>)
    linalg.conv_2d_nchw_fchw ins(%arg0, %arg1 : memref<8x3x32x32xf32>, memref<16x3x3x3xf32>) outs(%alloc : memref<8x16x30x30xf32>)
    %alloc_0 = memref.alloc() : memref<8x30x30x16xf32>
    linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%alloc : memref<8x16x30x30xf32>) outs(%alloc_0 : memref<8x30x30x16xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    }
    %alloc_1 = memref.alloc() : memref<8x30x30x64xf32>
    linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%alloc_0, %arg2 : memref<8x30x30x16xf32>, memref<8x30x16x64xf32>) outs(%alloc_1 : memref<8x30x30x64xf32>) {
    ^bb0(%in: f32, %in_2: f32, %out: f32):
      %0 = arith.mulf %in, %in_2 : f32
      %1 = arith.addf %out, %0 : f32
      linalg.yield %1 : f32
    }
    return %alloc_1 : memref<8x30x30x64xf32>
  }
}

Use mlir-opt to perform fusion:

mlir-opt -allow-unregistered-dialect convmatmul.mlir --convert-linalg-to-affine-loops -o convmatmul1.mlir
mlir-opt -allow-unregistered-dialect convmatmul1.mlir -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -o convmatmul2.mlir

The fused result, convmatmul2.mlir:

#map = affine_map<(d0, d1) -> (d0 + d1)>
module {
  func.func @main(%arg0: memref<8x3x32x32xf32>, %arg1: memref<16x3x3x3xf32>, %arg2: memref<8x30x16x64xf32>) -> memref<8x30x30x64xf32> {
    %alloc = memref.alloc() : memref<1x1x1x16xf32>
    %alloc_0 = memref.alloc() : memref<1x1x1x1xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %alloc_1 = memref.alloc() : memref<8x30x30x64xf32>
    affine.for %arg3 = 0 to 8 {
      affine.for %arg4 = 0 to 30 {
        affine.for %arg5 = 0 to 30 {
          affine.for %arg6 = 0 to 16 {
            affine.store %cst, %alloc_0[0, 0, 0, 0] : memref<1x1x1x1xf32>
            affine.for %arg7 = 0 to 3 {
              affine.for %arg8 = 0 to 3 {
                affine.for %arg9 = 0 to 3 {
                  %1 = affine.apply #map(%arg4, %arg8)
                  %2 = affine.apply #map(%arg5, %arg9)
                  %3 = affine.load %arg0[%arg3, %arg7, %1, %2] : memref<8x3x32x32xf32>
                  %4 = affine.load %arg1[%arg6, %arg7, %arg8, %arg9] : memref<16x3x3x3xf32>
                  %5 = affine.load %alloc_0[0, 0, 0, 0] : memref<1x1x1x1xf32>
                  %6 = arith.mulf %3, %4 : f32
                  %7 = arith.addf %5, %6 : f32
                  affine.store %7, %alloc_0[0, 0, 0, 0] : memref<1x1x1x1xf32>
                }
              }
            }
            %0 = affine.load %alloc_0[0, 0, 0, 0] : memref<1x1x1x1xf32>
            affine.store %0, %alloc[0, 0, 0, %arg6] : memref<1x1x1x16xf32>
          }
          affine.for %arg6 = 0 to 64 {
            affine.for %arg7 = 0 to 16 {
              %0 = affine.load %alloc[0, 0, 0, %arg7] : memref<1x1x1x16xf32>
              %1 = affine.load %arg2[%arg3, %arg4, %arg7, %arg6] : memref<8x30x16x64xf32>
              %2 = affine.load %alloc_1[%arg3, %arg4, %arg5, %arg6] : memref<8x30x30x64xf32>
              %3 = arith.mulf %0, %1 : f32
              %4 = arith.addf %2, %3 : f32
              affine.store %4, %alloc_1[%arg3, %arg4, %arg5, %arg6] : memref<8x30x30x64xf32>
            }
          }
        }
      }
    }
    return %alloc_1 : memref<8x30x30x64xf32>
  }
}

The IR after affine-loop-fusion basically matches my expectations.

5. Tiramisu

5.1 DSL Syntax

import tiramisu as tm
import shutil
import os

tm.init("matmul")

M = 64
K = 256
N = 128

# Level I: specifies "what" should be computed

A = tm.input("A", ['m', 'k'], [M, K], tm.primitive_t.p_float32)
B = tm.input("B", ['k', 'n'], [K, N], tm.primitive_t.p_float32)

m, k, n = tm.var("m", 0, M), tm.var("k", 0, K), tm.var("n", 0, N)
C_init = tm.computation("C_init", [m, n], tm.expr(0.0))
C = tm.computation("C", [m, n, k], tm.primitive_t.p_float32)
C.set_expression(C[m, n, k - 1] + A[m, k] * B[k, n])

# Level II: specifies "when" and "where"

# schedule the computation order

# Level III: specifies where data is "stored"

bufA = tm.buffer("bufA", [M, K], tm.primitive_t.p_float32, tm.argument_t.a_input)
bufB = tm.buffer("bufB", [K, N], tm.primitive_t.p_float32, tm.argument_t.a_input)
bufC = tm.buffer("bufC", [M, N], tm.primitive_t.p_float32, tm.argument_t.a_output)
A.store_in(bufA)
B.store_in(bufB)
C_init.store_in(bufC, [m, n])
C.store_in(bufC, [m, n])


f = tm.get_implicit_function()
f.codegen([bufA, bufB, bufC], "matmul.o", 0, False)
f.dump_halide_stmt()

Tiramisu actually hews even closer to polyhedral thinking: a computation is essentially a statement. Like Halide, it uses var to denote loops, which specify where the current computation sits in the loop nest. Internally, however, everything is polyhedral: defining a computation directly yields its iteration domain, and after loop-dimension alignment the schedule is obtained.
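
To make the Level I / Level III split concrete, here is a plain-Python sketch (my own illustration) of the loop nest that the matmul definition above, together with its store_in calls, describes:

import numpy as np

M, K, N = 64, 256, 128
bufA = np.random.rand(M, K).astype(np.float32)
bufB = np.random.rand(K, N).astype(np.float32)
bufC = np.empty((M, N), dtype=np.float32)

# C_init and C share bufC; the recursive definition over (m, n, k) combined
# with store_in(bufC, [m, n]) collapses the k dimension into an accumulation.
for m in range(M):
    for n in range(N):
        bufC[m, n] = 0.0                                  # C_init[m, n]
        for k in range(K):
            bufC[m, n] += bufA[m, k] * bufB[k, n]         # C[m, n, k]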

5.2 Test Example

import tiramisu as tm
import shutil
import os

# f = tm.function("matmul")
tm.init("convmatmul")

B = 8
IC = 3
OC = 16
IH, OH = 32, 30
IW, OW = 32, 30
KH = 3
KW = 3
N = 64

# Level I: specifies "what" should be computed

lhs = tm.input("lhs", ['B', 'IC', 'IH', 'IW'], [
    B, IC, IH, IW], tm.primitive_t.p_float64)
rhs = tm.input("rhs", ['B', 'OH', 'OC', 'N'], [
    B, OH, OC, N], tm.primitive_t.p_float64)
kernel = tm.input("kernel", ['OC', 'IC', 'KH', 'KW'], [
    OC, IC, KH, KW], tm.primitive_t.p_float64)

b, oc, oh, ow, ic, kh, kw = tm.var('b', 0, B), tm.var('oc', 0, OC), tm.var('oh', 0, OH), tm.var(
    'ow', 0, OW), tm.var('ic', 0, IC), tm.var('kh', 0, KH), tm.var('kw', 0, KW)
ConvInit = tm.computation("conv_init", [b, oc, oh, ow], tm.expr(0.0))
Conv = tm.computation(
    "Conv", [b, oc, oh, ow, ic, kh, kw], tm.primitive_t.p_float64)
Conv.set_expression(Conv[b, oc, oh, ow, ic, kh, kw] +
                    lhs[b, oc, oh + kh, ow + kw] * kernel[oc, ic, kh, kw])
n = tm.var('n', 0, N)
Transpose = tm.computation("transpose", [b, oh, ow, oc], ConvInit[b, oc, oh, ow])

MatmulInit = tm.computation("matmul_init", [b, oh, ow, n], tm.expr(0.0))
Matmul = tm.computation("Matmul", [b, oh, ow, n], tm.primitive_t.p_float64)
Matmul.set_expression(
    MatmulInit[b, oh, ow, n] + Transpose[b, oh, ow, oc] * rhs[b, oh, oc, n])

# Level II: specifies "when" and "where"

# Level III: specifies where data is "stored"

buflhs = tm.buffer("buflhs", [B, IC, IH, IW], tm.primitive_t.p_float64, tm.argument_t.a_input)
bufrhs = tm.buffer("bufrhs", [B, OH, OC, N], tm.primitive_t.p_float64, tm.argument_t.a_input)
bufkernel = tm.buffer("bufkernel", [OC, IC, KH, KW], tm.primitive_t.p_float64, tm.argument_t.a_input)
bufconv = tm.buffer("bufconv", [B, OC, OH, OW], tm.primitive_t.p_float64, tm.argument_t.a_temporary)
bufmatmul = tm.buffer("bufmatmul", [B, OH, OW, N], tm.primitive_t.p_float64, tm.argument_t.a_output)

lhs.store_in(buflhs)
rhs.store_in(bufrhs)
kernel.store_in(bufkernel)

ConvInit.store_in(bufconv, [b, oc, oh, ow])
Conv.store_in(bufconv, [b, oc, oh, ow])
Transpose.store_in(bufconv, [b, oh, ow, oc])
MatmulInit.store_in(bufmatmul, [b, oh, ow, n])
Matmul.store_in(bufmatmul, [b, oh, ow, n])

f = tm.get_implicit_function()
f.codegen([buflhs, bufrhs, bufkernel, bufconv, bufmatmul], "matmul.o", 0, False)
f.dump_halide_stmt()

The dumped Halide statement:

produce  {
  allocate _transpose_b5[float64 * (16 - 0) * (30 - 0) * (30 - 0) * (8 - 0)] in Heap
  allocate _rhs_b1[float64 * (64 - 0) * (16 - 0) * (30 - 0) * (8 - 0)] in Heap
  allocate _matmul_init_b6[float64 * (64 - 0) * (30 - 0) * (30 - 0) * (8 - 0)] in Heap
  allocate _lhs_b0[float64 * (32 - 0) * (32 - 0) * (3 - 0) * (8 - 0)] in Heap
  allocate _kernel_b2[float64 * (3 - 0) * (3 - 0) * (3 - 0) * (16 - 0)] in Heap
  allocate _conv_init_b3[float64 * (30 - 0) * (30 - 0) * (16 - 0) * (8 - 0)] in Heap
  allocate _Matmul_b7[float64 * (64 - 0) * (30 - 0) * (30 - 0) * (8 - 0)] in Heap
  allocate _Conv_b4[float64 * (3 - 0) * (3 - 0) * (3 - 0) * (30 - 0) * (30 - 0) * (16 - 0) * (8 - 0)] in Heap
  for (c1, 0, 8 - 0) {
    for (c3, 0, 30 - 0) {
      for (c5, 0, 30 - 0) {
        for (c7, 0, 64 - 0) {
          if (c3 >= 16) {
            bufmatmul[(((0 + (c7*1)) + (c5*64)) + (c3*1920)) + (c1*57600)] = 0.000000
            bufmatmul[(((0 + (c7*1)) + (c5*64)) + (c3*1920)) + (c1*57600)] = (float64)bufmatmul[(((0 + (c7*1)) + (c5*64)) + (c3*1920)) + (c1*57600)] + ((float64)bufconv[(((0 + (oc*1)) + (c5*30)) + (c3*900)) + (c1*14400)]*(float64)bufrhs[(((0 + (c7*1)) + (oc*64)) + (c3*1024)) + (c1*30720)])
            if (c7 <= 15) {
              bufconv[(((0 + (c7*1)) + (c5*30)) + (c3*900)) + (c1*14400)] = (float64)bufconv[(((0 + (c5*1)) + (c3*30)) + (c7*900)) + (c1*14400)]
            }
          } else if (c7 >= 30) {
            bufmatmul[(((0 + (c7*1)) + (c5*64)) + (c3*1920)) + (c1*57600)] = 0.000000
            bufmatmul[(((0 + (c7*1)) + (c5*64)) + (c3*1920)) + (c1*57600)] = (float64)bufmatmul[(((0 + (c7*1)) + (c5*64)) + (c3*1920)) + (c1*57600)] + ((float64)bufconv[(((0 + (oc*1)) + (c5*30)) + (c3*900)) + (c1*14400)]*(float64)bufrhs[(((0 + (c7*1)) + (oc*64)) + (c3*1024)) + (c1*30720)])
          } else {
            for (c9, 0, 3 - 0) {
              for (c11, 0, 3 - 0) {
                for (c13, 0, 3 - 0) {
                  if (((c9 == 0) && (c11 == 0)) && (c13 == 0)) {
                    bufmatmul[(((0 + (c7*1)) + (c5*64)) + (c3*1920)) + (c1*57600)] = 0.000000
                  }
                  bufconv[(((0 + (c7*1)) + (c5*30)) + (c3*900)) + (c1*14400)] = (float64)bufconv[(((0 + (c7*1)) + (c5*30)) + (c3*900)) + (c1*14400)] + ((float64)buflhs[(((0 + ((c7 + c13)*1)) + ((c5 + c11)*32)) + (c3*1024)) + (c1*3072)]*(float64)bufkernel[(((0 + (c13*1)) + (c11*3)) + (c9*9)) + (c3*27)])
                  if (((c9 == 0) && (c11 == 0)) && (c13 == 0)) {
                    bufmatmul[(((0 + (c7*1)) + (c5*64)) + (c3*1920)) + (c1*57600)] = (float64)bufmatmul[(((0 + (c7*1)) + (c5*64)) + (c3*1920)) + (c1*57600)] + ((float64)bufconv[(((0 + (oc*1)) + (c5*30)) + (c3*900)) + (c1*14400)]*(float64)bufrhs[(((0 + (c7*1)) + (oc*64)) + (c3*1024)) + (c1*30720)])
                    if (c7 <= 15) {
                      bufconv[(((0 + (c7*1)) + (c5*30)) + (c3*900)) + (c1*14400)] = (float64)bufconv[(((0 + (c5*1)) + (c3*30)) + (c7*900)) + (c1*14400)]
                    }
                    bufconv[(((0 + (c7*1)) + (c5*30)) + (c3*900)) + (c1*14400)] = 0.000000
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}

Because Tiramisu relies entirely on manual scheduling, the buffers used for fusion here have to be specified by hand in advance. That is convenient when hand-optimizing an operator, but it is not well suited to being integrated into a compiler.

I believe an operator implementation is determined by three things: compute order, tiling, and buffer binding. A compiler or a developer needs to exploit as much information as possible to make choices in these three design spaces and thereby optimize performance. We expect a Tensor DSL to carry enough information to support that.

All of these DSLs provide the basic iteration domain and access relation information, but only Halide directly records which loops are shared; MLIR presumably derives that information from the affine maps during the fusion optimization. In the linalg rationale document, MLIR distills the lessons of different compiler implementations: the representation needs to offer flexible scheduling, yet must not be as complex as Polyhedral, and it also has to stay in SSA form. Judging by the results of the fusion pass, the design MLIR arrived at is quite practical.

