0

移动端算法优化

 6 months ago
source link: https://muyuuuu.github.io/2024/03/03/mobile-algorithm-optimize/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

移动端算法优化

2024-03-03

11 8.3k 8 分钟

移动端算法优化是个很庞大的话题。从计算机体系到指令,涉及到非常广而深的东西。本文尝试以常见的算法为例,阐述算法在单线程场景下的加速与优化,多线程是最后的收尾,没啥可说的。而至于具体的场景,如金字塔、滤波、降噪等,优化的思路都是相同的:减少 IO,一次 IO 完成尽可能多的计算。

本文会使用 Neon, OpenCL 来优化算法,如果有可能也会引入 DSP。本文持续更新,整理算法优化相关的经验。额外的,确保打开了 O3 编译选项,打开 release 模式等,否则会影响算法的执行时间。

注:本文不考虑数学角度的优化,如修改计算公式得到相同结果什么的。实现的浮点矩阵计算为:

$$C = A B + \mathrm{bias}$$

简单起见,$A$ 的维度为 $512\times 128$,矩阵 $B$ 的维度为 $128 \times 256$。在高通骁龙某芯片上,目前的加速结果如下:

版本 时间
常规矩阵乘法 59.84ms
Neon 加速版本 1 12.90 ms
Neon 加速版本 2 3.85ms
Cache 友好的矩阵乘法 2.52ms
Neon 加速版本 3 2.77ms
Neon 加速版本 4 2.01ms
Neon 加速版本 5 1.09ms

为什么没 OpenCL?因为还没来得及写,仿佛欠着好多博客。

常规矩阵乘法

以线性代数中的矩阵乘法为例,目标矩阵的第 $i, j$ 个元素是矩阵 $A$ 的第 $i$ 行和矩阵 $B$ 的第 $j$ 列逐元素相乘相加的结果。根据这一原理写出最直观的代码,耗时 59.84ms:

/*
 * Naive scalar GEMM: C += A * B + bias.
 *
 * A is d0 x d1, B is d1 x d2; C and bias are d0 x d2, all row-major.
 * Accumulates into C, so the caller is expected to zero-fill C first.
 * The bias element is added once per output element, after the reduction.
 *
 * This is the textbook i/j/k order: the innermost index strides down a
 * column of B, which is why it is cache-hostile (see article text).
 */
void sgemm_c(float *C, float *A, float *B, float *bias, int d0, int d1, int d2)
{
    for (int i = 0; i < d0; i++) {
        const float *a_row = A + i * d1;
        for (int j = 0; j < d2; j++) {
            float *out = &C[i * d2 + j];
            /* Dot product of row i of A with column j of B. */
            for (int k = 0; k < d1; k++) {
                *out += a_row[k] * B[k * d2 + j];
            }
            *out += bias[i * d2 + j];
        }
    }
}

我们知道矩阵在计算机中是行主序存储的,即访问矩阵 $B[i, j]$ 时,会将 $B[i, j+1], B[i, j+2],…$ 等元素也一同取到 cache 中。当需要 $B[i, j+1]$ 时就从 cache 中读取而不是去内存读取,这样会节省很多时间。

所以上述代码的性能瓶颈在于:

for (m = 0; m < d1; m++) {
C[row * d2 + col] += A[row * d1 + m] * B[m * d2 + col];
}

由于最内层的循环中 m 逐渐增加,矩阵 $B$ 的寻址方式为跳行寻址。在我们看不见的地方,cache 缓存的数据无法使用,每次读取 $B$ 矩阵的元素时还需要刷新 cache,这就导致这份代码很耗时。

Neon 加速版本 1

/*
 * NEON GEMM, version 1: C = A * B + bias, vectorized over 4 output columns.
 *
 * For each 4-wide output block C[row, col..col+3] the reduction dimension m
 * is walked 4 at a time: one vector load of A[row, m..m+3] plus four
 * row-loads of B, combined with per-lane multiply-accumulate.
 *
 * NOTE(review): unlike sgemm_c this does NOT accumulate into C — sum4 starts
 * at zero and vst1q_f32 overwrites C. The two agree only when C enters
 * zero-filled; confirm that is the caller's contract.
 *
 * Assumes d1 % 4 == 0 and d2 % 4 == 0 (no scalar tail loops) — true for the
 * 512x128 * 128x256 benchmark in this article; TODO confirm for other shapes.
 */
void sgemm_neon1(float *C, float *A, float *B, float *bias, int d0, int d1, int d2)
{
int row, col, m;
for (row = 0; row < d0; row++) {
for (col = 0; col < d2; col+=4) {
/* Vector accumulator for C[row, col..col+3]. */
float32x4_t sum4 = vdupq_n_f32(0.0f);

/* pa: row of A; pb: 4-column block of B; pc/pd: output and bias blocks. */
float *pa = A + row * d1;
float *pb = B + col;
float *pc = C + row * d2 + col;
float *pd = bias + row * d2 + col;

for (m = 0; m < d1; m+=4) {
/* a4 = A[row, m..m+3]; b0..b3 = rows m..m+3 of B at columns col..col+3. */
float32x4_t a4 = vld1q_f32(pa);
float32x4_t b0 = vld1q_f32(pb + 0 * d2);
float32x4_t b1 = vld1q_f32(pb + 1 * d2);
float32x4_t b2 = vld1q_f32(pb + 2 * d2);
float32x4_t b3 = vld1q_f32(pb + 3 * d2);

/* sum4 += b_k * A[row, m+k] for k = 0..3 (broadcast one lane of a4). */
sum4 = vmlaq_lane_f32(sum4, b0, vget_low_f32(a4), 0);
sum4 = vmlaq_lane_f32(sum4, b1, vget_low_f32(a4), 1);
sum4 = vmlaq_lane_f32(sum4, b2, vget_high_f32(a4), 0);
sum4 = vmlaq_lane_f32(sum4, b3, vget_high_f32(a4), 1);

pa += 4;
pb += 4 * d2; /* advance B by 4 rows */
}

/* Add the bias for this 4-wide block, then store into C. */
float32x4_t d4 = vld1q_f32(pd);
sum4 = vaddq_f32(sum4, d4);
vst1q_f32(pc, sum4);
}
}
}

Neon 加速版本 2

/*
 * NEON GEMM, version 2: C = A * B + bias, computed in 4x4 output tiles.
 *
 * Processes 4 rows of A and 4 columns of B per tile so each vector load of
 * B (b0..b3) is reused across 4 accumulators, quartering B traffic versus
 * version 1.
 *
 * The accumulators are seeded from bias (not zero) and stored with
 * vst1q_f32, so like version 1 this OVERWRITES C instead of accumulating
 * into it — matches sgemm_c only for zero-initialized C.
 *
 * Assumes d0, d1 and d2 are all multiples of 4 (no tail handling) —
 * TODO confirm for shapes other than the article's benchmark.
 */
void sgemm_neon2(float *C, float *A, float *B, float *bias, int d0, int d1, int d2)
{
int row, col, m;
for (row = 0; row < d0; row+=4) {
for (col = 0; col < d2; col+=4) {

/* Base pointers for the 4x4 tile at (row, col). */
float *pa = A + row * d1;
float *pb = B + col;
float *pc = C + row * d2 + col;
float *pd = bias + row * d2 + col;

/* Seed each row's accumulator with its bias block up front. */
float32x4_t sum0 = vld1q_f32(pd + 0 * d2);
float32x4_t sum1 = vld1q_f32(pd + 1 * d2);
float32x4_t sum2 = vld1q_f32(pd + 2 * d2);
float32x4_t sum3 = vld1q_f32(pd + 3 * d2);

for (m = 0; m < d1; m+=4) {
/* b0..b3: rows m..m+3 of B at columns col..col+3 (shared by all 4 rows). */
float32x4_t b0 = vld1q_f32(pb + 0 * d2);
float32x4_t b1 = vld1q_f32(pb + 1 * d2);
float32x4_t b2 = vld1q_f32(pb + 2 * d2);
float32x4_t b3 = vld1q_f32(pb + 3 * d2);

/* a0..a3: A[row+r, m..m+3] for the 4 tile rows. */
float32x4_t a0 = vld1q_f32(pa + 0 * d1);
float32x4_t a1 = vld1q_f32(pa + 1 * d1);
float32x4_t a2 = vld1q_f32(pa + 2 * d1);
float32x4_t a3 = vld1q_f32(pa + 3 * d1);

/* sum_r += b_k * A[row+r, m+k], one lane of a_r per b_k. */
sum0 = vmlaq_lane_f32(sum0, b0, vget_low_f32(a0), 0);
sum0 = vmlaq_lane_f32(sum0, b1, vget_low_f32(a0), 1);
sum0 = vmlaq_lane_f32(sum0, b2, vget_high_f32(a0), 0);
sum0 = vmlaq_lane_f32(sum0, b3, vget_high_f32(a0), 1);

sum1 = vmlaq_lane_f32(sum1, b0, vget_low_f32(a1), 0);
sum1 = vmlaq_lane_f32(sum1, b1, vget_low_f32(a1), 1);
sum1 = vmlaq_lane_f32(sum1, b2, vget_high_f32(a1), 0);
sum1 = vmlaq_lane_f32(sum1, b3, vget_high_f32(a1), 1);

sum2 = vmlaq_lane_f32(sum2, b0, vget_low_f32(a2), 0);
sum2 = vmlaq_lane_f32(sum2, b1, vget_low_f32(a2), 1);
sum2 = vmlaq_lane_f32(sum2, b2, vget_high_f32(a2), 0);
sum2 = vmlaq_lane_f32(sum2, b3, vget_high_f32(a2), 1);

sum3 = vmlaq_lane_f32(sum3, b0, vget_low_f32(a3), 0);
sum3 = vmlaq_lane_f32(sum3, b1, vget_low_f32(a3), 1);
sum3 = vmlaq_lane_f32(sum3, b2, vget_high_f32(a3), 0);
sum3 = vmlaq_lane_f32(sum3, b3, vget_high_f32(a3), 1);

pa += 4;
pb += 4 * d2; /* advance B by 4 rows */
}

/* Write the finished 4x4 tile back to C (overwrite). */
vst1q_f32(pc + 0 * d2, sum0);
vst1q_f32(pc + 1 * d2, sum1);
vst1q_f32(pc + 2 * d2, sum2);
vst1q_f32(pc + 3 * d2, sum3);
}
}
}

Cache 友好的矩阵乘法

/*
 * Cache-friendly scalar GEMM: C += A * B + bias.
 *
 * A is d0 x d1, B is d1 x d2; C and bias are d0 x d2, row-major.
 * Loop order is row/m/col so the innermost loop walks B, C and bias
 * sequentially — every access is unit-stride, unlike sgemm_c.
 *
 * Fix vs. the original: the bias used to be added under an `if (0 == m)`
 * branch inside the innermost loop, evaluating the test d1 * d2 times per
 * row while being true only d2 times. The bias pass is now hoisted into a
 * prologue loop. The `d1 > 0` guard preserves the original behavior of
 * adding bias only when the reduction loop actually runs. Floating-point
 * summation order changes slightly (bias first instead of after the first
 * product term).
 */
void rsgemm_c(float *C, float *A, float *B, float *bias, int d0, int d1, int d2)
{
    int row, col, m;
    for (row = 0; row < d0; row++) {
        float *c_row = C + row * d2;

        /* Bias prologue: one sequential pass over the output row. */
        if (d1 > 0) {
            const float *bias_row = bias + row * d2;
            for (col = 0; col < d2; col++) {
                c_row[col] += bias_row[col];
            }
        }

        /* Rank-1 updates: C[row, :] += A[row, m] * B[m, :]. */
        for (m = 0; m < d1; m++) {
            const float a = A[row * d1 + m];
            const float *b_row = B + m * d2;
            for (col = 0; col < d2; col++) {
                c_row[col] += a * b_row[col];
            }
        }
    }
}

Neon 加速版本 3

/*
 * NEON GEMM, version 3: C += A * B + bias with the cache-friendly
 * row/m/col loop order (vectorized form of rsgemm_c).
 *
 * For each scalar A[row, m], performs a 4-wide rank-1 update of the output
 * row: C[row, col..col+3] += A[row, m] * B[m, col..col+3]. All inner-loop
 * accesses (B, C, bias) are sequential. Unlike sgemm_neon1/2, this version
 * ACCUMULATES into C (it loads c4 before adding).
 *
 * The bias is folded in only on the first reduction step (m == 0), so it is
 * added exactly once per output element, matching rsgemm_c.
 *
 * Assumes d2 % 4 == 0 (no column tail) — TODO confirm for other shapes.
 */
void rsgemm_neon1(float *C, float *A, float *B, float *bias, int d0, int d1, int d2)
{
int row, col, m;
for (row = 0; row < d0; row++) {
for (m = 0; m < d1; m++) {

/* Broadcast the scalar A[row, m] across all four lanes. */
float32x4_t a4 = vdupq_n_f32(A[row * d1 + m]);
float *pb = B + m * d2;
float *pc = C + row * d2;
float *pd = bias + row * d2;

for (col = 0; col < d2; col+=4) {
float32x4_t b4 = vld1q_f32(pb);
float32x4_t c4 = vld1q_f32(pc);
float32x4_t val = vmulq_f32(a4, b4);
val = vaddq_f32(c4, val);

/* Bias added once per output element, on the first m iteration. */
if (0 == m) {
val = vaddq_f32(vld1q_f32(pd), val);
}

vst1q_f32(pc, val);

pb += 4;
pc += 4;
pd += 4;
}
}
}
}

Neon 加速版本 4

/*
 * NEON GEMM, version 4: like rsgemm_neon1 but the reduction dimension m is
 * unrolled by 4 — each inner-loop iteration applies four rank-1 updates
 * (rows m..m+3 of B) to the same 4-wide block of C, so each C load/store
 * is amortized over four multiply-adds.
 *
 * Accumulates into C; bias is folded in only on the first m block (m == 0),
 * i.e. once per output element, matching rsgemm_c.
 *
 * Assumes d1 % 4 == 0 and d2 % 4 == 0 (no tail loops) — TODO confirm for
 * shapes other than the article's benchmark.
 */
void rsgemm_neon2(float *C, float *A, float *B, float *bias, int d0, int d1, int d2)
{
int row, col, m;
for (row = 0; row < d0; row++) {
for (m = 0; m < d1; m+=4) {

/* Row pointers for the four B rows consumed by this m block. */
float *pb0 = B + (m + 0) * d2;
float *pb1 = B + (m + 1) * d2;
float *pb2 = B + (m + 2) * d2;
float *pb3 = B + (m + 3) * d2;

float *pc = C + row * d2;
float *pd = bias + row * d2;

/* Load A[row, m..m+3] once, then broadcast each lane to its own vector. */
float32x4_t a4 = vld1q_f32(A + row * d1 + m);
float32x4_t a0 = vdupq_n_f32(vgetq_lane_f32(a4, 0));
float32x4_t a1 = vdupq_n_f32(vgetq_lane_f32(a4, 1));
float32x4_t a2 = vdupq_n_f32(vgetq_lane_f32(a4, 2));
float32x4_t a3 = vdupq_n_f32(vgetq_lane_f32(a4, 3));

for (col = 0; col < d2; col+=4) {
float32x4_t c4 = vld1q_f32(pc);

/* Four rank-1 updates against one load/store of C. */
c4 = vaddq_f32(c4, vmulq_f32(a0, vld1q_f32(pb0)));
c4 = vaddq_f32(c4, vmulq_f32(a1, vld1q_f32(pb1)));
c4 = vaddq_f32(c4, vmulq_f32(a2, vld1q_f32(pb2)));
c4 = vaddq_f32(c4, vmulq_f32(a3, vld1q_f32(pb3)));

/* Bias added once per output element, on the first m block. */
if (0 == m) {
c4 = vaddq_f32(vld1q_f32(pd), c4);
}

vst1q_f32(pc, c4);

pb0 += 4;
pb1 += 4;
pb2 += 4;
pb3 += 4;

pc += 4;
pd += 4;
}
}
}
}

Neon 加速版本 5

/*
 * NEON GEMM, version 5 (fastest in the article): rsgemm_neon2 additionally
 * blocked over 4 output rows. Each inner-loop iteration loads four B
 * vectors (rows m..m+3) and reuses them across four C rows, giving a
 * 4x4x4 (row x m x col) register tile: 16 broadcast A values, 4 B loads,
 * 4 C load/stores per 64 multiply-adds.
 *
 * Accumulates into C; bias is folded in only on the first m block (m == 0),
 * i.e. once per output element, matching rsgemm_c.
 *
 * Assumes d0, d1 and d2 are all multiples of 4 (no tail handling) — TODO
 * confirm for shapes other than the 512x128 * 128x256 benchmark.
 */
void rsgemm_neon3(float *C, float *A, float *B, float *bias, int d0, int d1, int d2)
{
int row, col, m;
for (row = 0; row < d0; row+=4) {
for (m = 0; m < d1; m+=4) {

/* The four B rows shared by all four output rows of this tile. */
float *pb0 = B + (m + 0) * d2;
float *pb1 = B + (m + 1) * d2;
float *pb2 = B + (m + 2) * d2;
float *pb3 = B + (m + 3) * d2;

/* Output and bias pointers for the four tile rows. */
float *pc0 = C + (0 + row) * d2;
float *pc1 = C + (1 + row) * d2;
float *pc2 = C + (2 + row) * d2;
float *pc3 = C + (3 + row) * d2;

float *pd0 = bias + (0 + row) * d2;
float *pd1 = bias + (1 + row) * d2;
float *pd2 = bias + (2 + row) * d2;
float *pd3 = bias + (3 + row) * d2;

/* A[row+r, m..m+3] for each tile row r. */
float32x4_t a0 = vld1q_f32(A + (row + 0) * d1 + m);
float32x4_t a1 = vld1q_f32(A + (row + 1) * d1 + m);
float32x4_t a2 = vld1q_f32(A + (row + 2) * d1 + m);
float32x4_t a3 = vld1q_f32(A + (row + 3) * d1 + m);

/* a<r><k> = broadcast of A[row+r, m+k]; hoisted out of the col loop. */
float32x4_t a00 = vdupq_n_f32(vgetq_lane_f32(a0, 0));
float32x4_t a01 = vdupq_n_f32(vgetq_lane_f32(a0, 1));
float32x4_t a02 = vdupq_n_f32(vgetq_lane_f32(a0, 2));
float32x4_t a03 = vdupq_n_f32(vgetq_lane_f32(a0, 3));

float32x4_t a10 = vdupq_n_f32(vgetq_lane_f32(a1, 0));
float32x4_t a11 = vdupq_n_f32(vgetq_lane_f32(a1, 1));
float32x4_t a12 = vdupq_n_f32(vgetq_lane_f32(a1, 2));
float32x4_t a13 = vdupq_n_f32(vgetq_lane_f32(a1, 3));

float32x4_t a20 = vdupq_n_f32(vgetq_lane_f32(a2, 0));
float32x4_t a21 = vdupq_n_f32(vgetq_lane_f32(a2, 1));
float32x4_t a22 = vdupq_n_f32(vgetq_lane_f32(a2, 2));
float32x4_t a23 = vdupq_n_f32(vgetq_lane_f32(a2, 3));

float32x4_t a30 = vdupq_n_f32(vgetq_lane_f32(a3, 0));
float32x4_t a31 = vdupq_n_f32(vgetq_lane_f32(a3, 1));
float32x4_t a32 = vdupq_n_f32(vgetq_lane_f32(a3, 2));
float32x4_t a33 = vdupq_n_f32(vgetq_lane_f32(a3, 3));

for (col = 0; col < d2; col+=4) {
/* Current accumulators: C[row+r, col..col+3]. */
float32x4_t c04 = vld1q_f32(pc0);
float32x4_t c14 = vld1q_f32(pc1);
float32x4_t c24 = vld1q_f32(pc2);
float32x4_t c34 = vld1q_f32(pc3);

/* B rows m..m+3 at columns col..col+3, shared by all four C rows. */
float32x4_t b0 = vld1q_f32(pb0);
float32x4_t b1 = vld1q_f32(pb1);
float32x4_t b2 = vld1q_f32(pb2);
float32x4_t b3 = vld1q_f32(pb3);

/* c<r>4 += sum_k a<r><k> * b<k> — the 4x4 register tile update. */
c04 = vaddq_f32(c04, vmulq_f32(a00, b0));
c04 = vaddq_f32(c04, vmulq_f32(a01, b1));
c04 = vaddq_f32(c04, vmulq_f32(a02, b2));
c04 = vaddq_f32(c04, vmulq_f32(a03, b3));

c14 = vaddq_f32(c14, vmulq_f32(a10, b0));
c14 = vaddq_f32(c14, vmulq_f32(a11, b1));
c14 = vaddq_f32(c14, vmulq_f32(a12, b2));
c14 = vaddq_f32(c14, vmulq_f32(a13, b3));

c24 = vaddq_f32(c24, vmulq_f32(a20, b0));
c24 = vaddq_f32(c24, vmulq_f32(a21, b1));
c24 = vaddq_f32(c24, vmulq_f32(a22, b2));
c24 = vaddq_f32(c24, vmulq_f32(a23, b3));

c34 = vaddq_f32(c34, vmulq_f32(a30, b0));
c34 = vaddq_f32(c34, vmulq_f32(a31, b1));
c34 = vaddq_f32(c34, vmulq_f32(a32, b2));
c34 = vaddq_f32(c34, vmulq_f32(a33, b3));

/* Bias added once per output element, on the first m block. */
if (0 == m) {
c04 = vaddq_f32(vld1q_f32(pd0), c04);
c14 = vaddq_f32(vld1q_f32(pd1), c14);
c24 = vaddq_f32(vld1q_f32(pd2), c24);
c34 = vaddq_f32(vld1q_f32(pd3), c34);
}

vst1q_f32(pc0, c04);
vst1q_f32(pc1, c14);
vst1q_f32(pc2, c24);
vst1q_f32(pc3, c34);

pb0 += 4;
pb1 += 4;
pb2 += 4;
pb3 += 4;

pc0 += 4;
pc1 += 4;
pc2 += 4;
pc3 += 4;

pd0 += 4;
pd1 += 4;
pd2 += 4;
pd3 += 4;
}
}
}
}

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK