跳转至

Shape and MDSpan

概述

在 Choreo 中,形状(shape)是一等公民。本节将介绍如何在 Choreo 代码中对形状进行编程。

一等公民:形状

Choreo 的主要职能是管理数据移动,这对于高效组织与处理大规模数据集至关重要,尤其在机器学习与高性能计算场景中。然而,大多数 C++ 编程环境对数据的处理较为随意——要么视为扁平结构(指针),要么视为层次结构(数组),而缺乏与形状相关联的原生表示。

与之相对,Choreo 通过要求凡声明或使用的数据均须关联形状,来强化代码安全性并简化带形状数据的编程,从而使形状成为一等公民。

使用 mdspan 定义形状

在 Choreo 中,形状可通过关键字 mdspan 定义,其含义为 Multi-Dimensional Span(多维跨度)。该关键字表示用于计算的多维数据。

可按如下方式显式定义 mdspan 变量:

mdspan s0 : [7, 8]; // Defines a 2D shape with dimensions [7, 8]
mdspan<1> s1 : [3]; // Defines a 1D shape with dimensions [3]

本例中,前置关键字 mdspan 表示声明 mdspan 变量,其后为用户指定的变量名。每个 mdspan 变量均由初始化表达式初始化,该表达式由置于 [] 内的、以逗号分隔的整数值构成。紧随变量名之后的符号 : 引出初始化表达式。 因此,s0 定义为二维形状,具有 7 行与 8 列;s1 为一维形状,维度为 3

可在 mdspan 关键字后使用 <> 可选地指定 rank(秩)。若显式给出 rank,则告知 Choreo 编译器进行秩一致性检查;若 rank 与对应的初始化表达式不一致,将在编译期报错。示例如下:

mdspan<3> s2 : [64, 32]; // error: the rank of mdspan is inconsistent

除显式 mdspan 声明外,Choreo 编译器亦可从初始化表达式推断 mdspan 的类型,从而无需显式书写 mdspan 关键字:

s3 : [7, 8, 9]

在该代码中,s3 为秩为 3mdspan,维度为 7, 8, 9。由于 Choreo 要求 mdspan 必须在声明处完成初始化,在实践中更倾向使用类型推断写法。

推导 mdspan

在实际编码中,常需由已有形状推导新形状。例如,进行数据分块(tiling)阻塞(blocking)时需对维度进行划分;或对特定维度进行填充(pad),即在形状维度上增加增量。

在 Choreo 中,此类形状推导可通过对 mdspan 进行算术运算而简便完成。下列代码展示一例:

shape : [128, 64]; // initial shape 
new-shape0 : shape [(0) / 2, (1) / 4, 1];  // tile and reshape: [1, 64, 16]
new-shape1 : shape [(1) + 2, (0) / 16];    // pad and reshape: [66, 8]

本例中,new-shape0shape 推导得到:第 0 维除以 2,第 1 维除以 4,对应高层语义中的分块操作。此外,代码向 new-shape 增加了一个新维度;在高层语义中,该操作常称为重塑(reshape)

在 Choreo 中,new-shape0 的定义等价于:

new-shape0: [shape(0) / 2, shape(1) / 4, 1];

此处初始 shape 以逐元素形式显式列于 [] 内,而非在 [] 外单独指定。但与前一写法类似,元素访问操作以 () 标注,作用于已有形状以取得各维数值。显然,该写法代码量更大但结果相同,故前一写法可视为完整初始化表达式语法糖

在代码示例中,new-shape1 亦由 shape 推导:对第 1 维填充 2,并在推导形状中交换了维度顺序。

此外,亦可整体使用 mdspan 进行推导:

shape : [32, 72]
new-shape0 : shape;  // [32, 72]
new-shape1 : shape + 1;  // [33, 73]
new-shape2 : shape / 4;  // [8, 18]
new-shape3 : [shape, 6]; // [32, 72, 6]

注意,对 mdspan 变量的算术运算是按维应用的。因此语句

new-shape1 : shape + 1;

等价于

new-shape1 : shape [(0) + 1, (1) + 1];

new-shape3 的推导定义表明:在 mdspan 初始化表达式中使用 mdspan 变量会产生拼接行为。因此声明

new-shape3 : [shape, 6];

等价于

new-shape3 : shape [(0), (1), 6];

前一写法同样可视为完整定义的语法糖

mdspan 的求值

截至目前,我们见到的 mdspan 均为常量取值,这在诸多场景下已足够,因为高性能设备核函数常需依据固定的输入数据形状进行精细调优。此时 mdspan编译期求值,其取值不产生额外的运行时间或存储开销。

然而,某些场景需要运行时形状(部分维度在执行时确定)以构建核函数。Choreo 通过为 mdspan 提供符号维度(Symbolic Dimensions)支持该需求,相关内容将于后文介绍。此类情形下,维度取值可能需要在运行时求值。所幸 Choreo 在进入 choreo 函数后、于 Choreo 生成的主机代码中立即完成该求值,启动开销可忽略不计。因此,在通常情况下,程序员可忽略与 mdspan 相关的开销