5

TVM dynamic shape 学习

 8 months ago
source link: https://zhen8838.github.io/2023/11/15/dynamic-shape/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

TVM dynamic shape 学习

探究 TVM dynamic shape 的实现。

TVM IR design

ir.png

将 Relax IR 的语法结构 dump 出来可以看到，它与 Relay 那种纯数据流形式的 IR 不同：dataflow 中每个操作的结果都通过一个 var binding（即下面 dump 中的 `VarBinding`）来保存。

# Example Relax function with a symbolic leading dimension `n`.
# `a` has shape (n, 10) and `b` has shape (1,); the broadcast add produces
# a tensor of shape (n, 10). Inside the dataflow block the result of the
# operation is bound to a variable (`c`), which appears as a VarBinding in
# the IR dump that follows.
@R.function
def fn1(a: R.Tensor(("n", 10), 'float32'), b: R.Tensor((1,), 'float32')):
    with R.dataflow():
        # Bind a Python-level name for the symbolic dim so it can be used in
        # the annotation of `c`. The dump shows the same `n` in both places.
        n = T.int64()
        # No dtype is given in this annotation, so the dumped struct info
        # of `c` has an empty dtype.
        c: R.Tensor((n, 10)) = a + b
        R.output(c)
    return c
Function(
params=[
Var(
name_hint="a",
struct_info=TensorStructInfo(
dtype=float32,
shape=ShapeExpr(
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
],
struct_info=ShapeStructInfo(
ndim=2,
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
]
)
)
)
),
Var(
name_hint="b",
struct_info=TensorStructInfo(
dtype=float32,
shape=ShapeExpr(
values=[PrimExpr(value=`T.int64(1)`)],
struct_info=ShapeStructInfo(
ndim=1,
values=[PrimExpr(value=`T.int64(1)`)]
)
)
)
)
],
body=SeqExpr(
blocks=[
BindingBlock(
bindings=[
VarBinding(
var=Var(
name_hint="c",
struct_info=TensorStructInfo(
dtype=,
shape=ShapeExpr(
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
],
struct_info=ShapeStructInfo(
ndim=2,
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
]
)
)
)
),
value=Call(
op=Op(name="relax.add"),
args=[
Var(
name_hint="a",
struct_info=TensorStructInfo(
dtype=float32,
shape=ShapeExpr(
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
],
struct_info=ShapeStructInfo(
ndim=2,
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
]
)
)
)
),
Var(
name_hint="b",
struct_info=TensorStructInfo(
dtype=float32,
shape=ShapeExpr(
values=[PrimExpr(value=`T.int64(1)`)],
struct_info=ShapeStructInfo(
ndim=1,
values=[PrimExpr(value=`T.int64(1)`)]
)
)
)
)
],
struct_info=TensorStructInfo(
dtype=,
shape=ShapeExpr(
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
],
struct_info=ShapeStructInfo(
ndim=2,
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
]
)
)
)
)
)
]
)
],
body=Var(
name_hint="c",
struct_info=TensorStructInfo(
dtype=,
shape=ShapeExpr(
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
],
struct_info=ShapeStructInfo(
ndim=2,
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
]
)
)
)
),
struct_info=TensorStructInfo(
dtype=,
shape=ShapeExpr(
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
],
struct_info=ShapeStructInfo(
ndim=2,
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
]
)
)
)
),
ret_struct_info=TensorStructInfo(
dtype=,
shape=ShapeExpr(
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
],
struct_info=ShapeStructInfo(
ndim=2,
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
]
)
)
),
is_pure=1,
attrs={"global_symbol": "fn1"},
struct_info=FuncStructInfo(
params=[
TensorStructInfo(
dtype=float32,
shape=ShapeExpr(
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
],
struct_info=ShapeStructInfo(
ndim=2,
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
]
)
)
),
TensorStructInfo(
dtype=float32,
shape=ShapeExpr(
values=[PrimExpr(value=`T.int64(1)`)],
struct_info=ShapeStructInfo(
ndim=1,
values=[PrimExpr(value=`T.int64(1)`)]
)
)
)
],
ret=TensorStructInfo(
dtype=,
shape=ShapeExpr(
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
],
struct_info=ShapeStructInfo(
ndim=2,
values=[
PrimExpr(value=`n`),
PrimExpr(value=`T.int64(10)`)
]
)
)
),
purity=True
)
)

而根据 Relax 的 shape 设计文档，下面这种情况应该是无法支持的。这里先用 `R.shape_of` 取出张量的形状，再按下标取出各个维度，用来构造新的 shape：

# Counter-example: per the Relax shape design document, this pattern is
# expected NOT to be supported. The shape of `c` is fetched as a first-class
# Shape value via `R.shape_of`. Individual dimensions are then indexed out of
# it (`cshape[0]`, `cshape[1]`) to build the target shape of a reshape.
@R.function
def fn2(a: R.Tensor(("n", 10), 'float32'), b: R.Tensor((1,), 'float32')):
    with R.dataflow():
        n = T.int64()  # declared but not used in this example
        c = a + b
        cshape: R.Shape() = R.shape_of(c)
        # Intended result shape: (1, <dim 0 of c>, <dim 1 of c>, 1).
        d = R.reshape(c, [1, cshape[0], cshape[1], 1])
        R.output(d)
    return d

我在思考：是否应该有一种直接基于数据流的方式来添加 symbolic shape 信息？


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK