2

现代硬件算法[7.4]: 蒙哥马利乘法

 11 months ago
source link: https://no5-aaron-wu.github.io/2023/09/18/HPC-7-4-AMH-MontgomeryMultiplication/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

旭穹の陋室

现代硬件算法[7.4]: 蒙哥马利乘法

发表于2023-09-18|更新于2023-10-10|高性能计算
阅读量:2

蒙哥马利乘法

不出所料,模运算中的大部分计算通常用于取模运算,其速度与一般的整数除法一样慢,通常需要15-20个周期,具体取决于操作数的大小。

解决这个问题的最好方法是完全避免使用取模运算,通过使用分支预测来推迟或代替它,例如在计算模之和时就可以这样做:

const int M = 1e9 + 7;

// input: array of n integers in the [0, M) range
// output: sum modulo M
int slow_sum(int *a, int n) {
int s = 0;
for (int i = 0; i < n; i++)
s = (s + a[i]) % M;
return s;
}

int fast_sum(int *a, int n) {
int s = 0;
for (int i = 0; i < n; i++) {
s += a[i]; // s < 2 * M
s = (s >= M ? s - M : s); // will be replaced with cmov
}
return s;
}

int faster_sum(int *a, int n) {
long long s = 0; // 64-bit integer to handle overflow
for (int i = 0; i < n; i++)
s += a[i]; // will be vectorized
return s % M;
}

然而,有时你需要进行一连串的模乘法(指对两个数进行乘法运算,然后取结果的模)运算,这时没有好的方法可以避免取模运算,除了使用一些整数除法技巧(这要求模值为常数)和一些预计算。

但是,这里有一种专为模运算设计的技术,称为蒙哥马利乘法(Montgomery multiplication)。

蒙哥马利空间

蒙哥马利乘法首先将乘数转换到蒙哥马利空间(Montgomery space),在这个空间中可以以较低的代价执行模乘法,然后当需要实际的值时,再将它们转换回去。与常规的整数除法方法不同,对于只进行一次模运算而言,蒙哥马利乘法并不高效,只有在执行一系列模操作时才值得使用。

这个空间是由模数nnn和一个与nnn互质的正整数r≥nr≥nr≥n定义的。算法涉及到对r的模运算和除法,所以在实践中,rrr通常被选择为2322^{32}232或2642^{64}264,这样这些操作可以分别通过右移和位与操作来完成。

定义:一个数xxx在蒙哥马利空间中的表示x‾\overline{x}x被定义为

x‾=x⋅rmodn\overline{x} = x \cdot r \mod n x=x⋅rmodn

这种转换的计算包括一次乘法和一次取模运算——这是一种我们最初想要优化的代价昂贵的操作——这就是为什么我们只在数字与蒙哥马利空间表示相互转换的额外开销值得时才使用这种方法,而不是用于一般的模乘法。

在蒙哥马利空间中,加法、减法和相等性检查与平常相同:

x⋅r+y⋅r≡(x+y)⋅rmodnx \cdot r + y \cdot r ≡ (x + y) \cdot r \mod n x⋅r+y⋅r≡(x+y)⋅rmodn

但是,乘法却并非如此。我们将蒙哥马利空间中的乘法表示为 ∗*∗ ,将普通的乘法表示 ⋅\cdot⋅ ,我们期望的结果是:

x‾∗y‾=x⋅y‾=(x⋅y)⋅rmodn\overline{x} * \overline{y} = \overline{x\cdot y} = (x \cdot y) \cdot r \mod n x∗y​=x⋅y​=(x⋅y)⋅rmodn

但是对于蒙哥马利空间中的普通乘法有:

x‾⋅y‾=(x⋅y)⋅r⋅rmodn\overline{x} \cdot \overline{y} = (x \cdot y) \cdot r \cdot r \mod n x⋅y​=(x⋅y)⋅r⋅rmodn

因此,蒙哥马利空间中的乘法被定义为:

x‾∗y‾=x‾⋅y‾⋅r−1modn\overline{x} * \overline{y} = \overline{x} \cdot \overline{y} \cdot r^{-1} \mod n x∗y​=x⋅y​⋅r−1modn

这意味着,在蒙哥马利空间中正常乘两个数后,我们需要通过乘以r−1r^{−1}r−1来“减小”结果并取模 —— 并且有一种有效的方式来执行这个特定的“减小”操作。

蒙哥马利模余法(Montgomery reduction)

假设r=232r = 2^{32}r=232,模数nnn是32位的,我们需要“减小”的数xxx是64位的(两个32位数的乘积),我们的目标是计算y=x⋅r−1modny = x \cdot r^{-1} \mod ny=x⋅r−1modn。

由于rrr与nnn互质,我们知道在[0,n)[0,n)[0,n)范围内有两个数字r−1r^{-1}r−1和n′n'n′满足:

r⋅r−1+n⋅n′=1r \cdot r^{-1} + n \cdot n' = 1 r⋅r−1+n⋅n′=1

其中r−1r^{-1}r−1和n′n'n′都是可以计算的,使用拓展欧几里得算法

利用这一特性,我们可以将r⋅r−1r \cdot r^{-1}r⋅r−1表示为(1−n⋅n′)(1 - n \cdot n')(1−n⋅n′),并将x⋅r−1x \cdot r^{-1}x⋅r−1写为

x⋅r−1=x⋅r⋅r−1/r=x⋅(1−n⋅n′)/r=(x−x⋅n⋅n′)/r≡(x−x⋅n⋅n′+k⋅r⋅n)/r(modn)(foranyintegerk)≡(x−(x⋅n′−k⋅r)⋅n)/r(modn)\begin{align} x \cdot r^{-1} &= x \cdot r \cdot r^{-1} / r \\ &= x \cdot (1 - n \cdot n') / r \\ &= (x - x \cdot n \cdot n') / r \\ &≡ (x - x \cdot n \cdot n' + k \cdot r \cdot n) / r \quad (\mod n) \quad (for \space any \space integer \space k) \\ &≡ (x - (x \cdot n' - k \cdot r)\cdot n) /r \quad (\mod n ) \end{align} x⋅r−1​=x⋅r⋅r−1/r=x⋅(1−n⋅n′)/r=(x−x⋅n⋅n′)/r≡(x−x⋅n⋅n′+k⋅r⋅n)/r(modn)(for any integer k)≡(x−(x⋅n′−k⋅r)⋅n)/r(modn)​​

现在,如果我们令kkk为⌊x⋅n′/r⌋⌊x⋅n′/r⌋⌊x⋅n′/r⌋(乘积x⋅n′x \cdot n'x⋅n′的高32位(原文中是64位)),就能进行约简,k⋅r−x⋅n′k \cdot r - x \cdot n'k⋅r−x⋅n′就等于x⋅n′modrx \cdot n' \mod rx⋅n′modr(x⋅n′x \cdot n'x⋅n′的低32位),有:

x⋅r−1≡(x−x⋅n′modr⋅n)/rx \cdot r^{-1} ≡ (x - x \cdot n' \mod r \cdot n)/r x⋅r−1≡(x−x⋅n′modr⋅n)/r

算法本身就是在计算这个公式,执行两次乘法来计算$q=x⋅n’ \mod r $和 m=q⋅nm=q⋅nm=q⋅n,然后从xxx中减去结果,然后通过右移执行除以rrr操作。

唯一需要注意的是,结果可能不在[0,n)[0,n)[0,n)范围内,但是由于

x<n⋅n<r⋅n⟹x/r<nx < n \cdot n < r \cdot n \Longrightarrow x/r < n x<n⋅n<r⋅n⟹x/r<n
m=q⋅n<r∗n⟹m/r<nm = q \cdot n < r * n \Longrightarrow m /r < n m=q⋅n<r∗n⟹m/r<n

这就能保证

−n<(x−m)/r<n-n < (x-m)/r < n −n<(x−m)/r<n

因此,我们可以简单地检查结果是否为负,若为负,则加上n,有以下算法:

typedef __uint32_t u32;
typedef __uint64_t u64;

const u32 n = 1e9 + 7, nr = inverse(n, 1ull << 32);

u32 reduce(u64 x) {
u32 q = u32(x) * nr; // q = x * n' mod r
u64 m = (u64) q * n; // m = q * n
u32 y = (x - m) >> 32; // y = (x - m) / r
return x < m ? y + n : y; // if y < 0, add n to make it be in the [0, n) range
}

最后一次的检查相对便宜,但仍然在关键路径上。如果我们可以接受结果在[0,2n−2][0,2n-2][0,2n−2]范围内,而不是[0,n)[0,n)[0,n)范围内,我们可以移除这个检查,并无条件地将nnn添加到结果中。

u32 reduce(u64 x) {
u32 q = u32(x) * nr;
u64 m = (u64) q * n;
u32 y = (x - m) >> 32;
return y + n
}

我们也可以将>>32操作在计算图中提前一步,计算⌊x/r⌋−⌊m/r⌋⌊x/r⌋−⌊m/r⌋⌊x/r⌋−⌊m/r⌋,而不是计算(x−m)/r(x-m)/r(x−m)/r。这样做是没问题的,因为xxx和mmm的低32位在任何情况下都是相等的,因为有

m=x⋅n′⋅n≡x(modr)m = x \cdot n' \cdot n ≡ x (\mod r) m=x⋅n′⋅n≡x(modr)

为什么我们会主动选择进行两次右移,而不是只进行一次呢?这样做是有利的,因为对于((u64) q * n) >> 32,我们需要执行一个32位乘以32位的乘法,并取结果的上32位(x86的mul指令已经将其写入到一个单独的寄存器 ,所以这不会有任何额外的代价),而另一个右移 x >> 32 不在关键路径上。

u32 reduce(u64 x) {
u32 q = u32(x) * nr;
u32 m = ((u64) q * n) >> 32;
return (x >> 32) + n - m;
}

蒙哥马利乘法相比其他模余方法的主要优势之一是它不需要非常大的数据类型:它只需要一个r×rr × rr×r乘法,该乘法提取结果的低rrr位和高rrr位,这在大多数硬件上都有特定的支持,也使得它容易推广到SIMD和更大的数据类型。

typedef __uint128_t u128;

u64 reduce(u128 x) const {
u64 q = u64(x) * nr;
u64 m = ((u128) q * n) >> 64;
return (x >> 64) + n - m;
}

请注意,一般的整数除法技巧无法实现128位对64位的模除运算:编译器会退化成调用一个慢速的长整数运算库函数来支持它。

更快的逆元变换

Montgomery乘法本身很快,但它需要一些预计算:

  • 对nnn模rrr求逆元以计算n′n'n′,
  • 将一个数转换到Montgomery空间,
  • 将一个数从Montgomery空间转换出来。

上面实现的reduce方法已经可以有效地执行最后一步操作,但是前两步还可以稍微优化一下。

计算n′=n−1modrn'=n^{-1} \mod rn′=n−1modr,有比使用扩展欧几里得算法更快的方法,这是因为rrr是222的幂,可以利用以下的特性:

a⋅x≡1mod2k⟹a⋅x⋅(2−a⋅x)≡1mod22ka \cdot x ≡ 1 \mod 2^k \Longrightarrow a \cdot x \cdot (2 - a \cdot x) ≡ 1 \mod 2^{2k} a⋅x≡1mod2k⟹a⋅x⋅(2−a⋅x)≡1mod22k

证明如下:

a⋅x⋅(2−a⋅x)=2⋅a⋅x−(a⋅x)2=2⋅(1+m⋅2k)−(1+m⋅2k)2=2+2⋅m⋅2k−1−2⋅m⋅2k−m2⋅22k=1−m2⋅22k≡1mod22k\begin{align} a \cdot x \cdot (2 - a \cdot x) &= 2 \cdot a \cdot x - (a \cdot x)^2 \\ &=2 \cdot (1 + m \cdot 2^k)-(1 + m \cdot 2^k)^2 \\ &=2 + 2\cdot m \cdot 2^k - 1 - 2 \cdot m \cdot 2^k - m^2 \cdot 2^{2k} \\ &=1 - m^2 \cdot 2^{2k} \\ &≡ 1 \mod 2^{2k} \end{align} a⋅x⋅(2−a⋅x)​=2⋅a⋅x−(a⋅x)2=2⋅(1+m⋅2k)−(1+m⋅2k)2=2+2⋅m⋅2k−1−2⋅m⋅2k−m2⋅22k=1−m2⋅22k≡1mod22k​​

我们一开始可以用x=1x=1x=1作为amod21a \mod 2^1amod21的逆(因为a−1=a21−2=1a^{-1} = a^{2^1-2} = 1a−1=a21−2=1),然后应用上面这个等式log2rlog_2rlog2​r次,每次都会将逆中的bit数翻倍 - 这有点类似于牛顿法

将一个数转换到蒙哥马利空间

可以通过将其乘以rrr并进行取模运算来实现,但我们也可以利用下面这个等式:

x‾=x⋅rmodn=x∗r2\overline{x} = x \cdot r \mod n = x * r^2 x=x⋅rmodn=x∗r2

将一个数字转换到蒙哥马利空间只需要乘以r2r^2r2。因此,我们可以预先计算r2modnr^2 \mod nr2modn,然后执行乘法和模余操作,这样做速度可能会更快,也可能不会更快,因为将一个数字乘以r=2kr=2^kr=2k可以用左移位实现,而将一个数乘以r2modnr^2 \mod nr2modn则无法利用左移位。

将所有内容封装到单个constexpr结构体:

struct Montgomery {
u32 n, nr;

constexpr Montgomery(u32 n) : n(n), nr(1) {
// log(2^32) = 5
for (int i = 0; i < 5; i++)
nr *= 2 - n * nr;
}

u32 reduce(u64 x) const {
u32 q = u32(x) * nr;
u32 m = ((u64) q * n) >> 32;
return (x >> 32) + n - m;
// returns a number in the [0, 2 * n - 2] range
// (add a "x < n ? x : x - n" type of check if you need a proper modulo)
}

u32 multiply(u32 x, u32 y) const {
return reduce((u64) x * y);
}

u32 transform(u32 x) const {
return (u64(x) << 32) % n;
// can also be implemented as multiply(x, r^2 mod n)
}
};

为了测试其性能,我们可以将蒙哥马利乘法插入到二进制幂运算中:

constexpr Montgomery space(M);

int inverse(int _a) {
u64 a = space.transform(_a);
u64 r = space.transform(1);

#pragma GCC unroll(30)
for (int l = 0; l < 30; l++) {
if ( (M - 2) >> l & 1 )
r = space.multiply(r, a);
a = space.multiply(a, a);
}

return space.reduce(r);
}

编译器生成的普通二进制幂运算,即使是使用快速模运算技巧,每次inverse也需要大约170纳秒,而这个实现只需要大约166纳秒,如果我们忽略transformreduce(一个合理的用例是用inverse作为更大的模运算中的子过程),这个时间可以降低到大约158纳秒。这是一个小的改进,但对于SIMD应用程序和更大的数据类型,蒙哥马利乘法变得更有优势。

练习题:实现高效的模矩阵乘法


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK