4

「学习笔记」多项式算法

 2 years ago
source link: https://blunt-axe.github.io/2019/12/11/20191211-Polynomial-Algorithms-Notes/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

在 OI 中,许多涉及到生成函数的计数题都需要使用一些多项式算法,所以掌握多项式算法是必要的。

多项式加减法

加减法只要对应系数加减即可,这是线性的。

多项式乘法

乘法可以使用 FFT 完成。假设要计算 ,那么只需要把 进行 DFT,然后对应点值相乘后再 IDFT 回去就行了,时间复杂度 。

多项式求导与积分

对于多项式 ,我们有:

可以直接线性计算。

泰勒展开是多项式牛顿迭代的前置技能,它的思想就是用多项式来逼近一个函数。定义函数 在点 的泰勒展开为:

其中 表示函数 的 阶导数。 当然 不仅可以以数为参数,也可以以函数为参数。

多项式牛顿迭代

多项式牛顿迭代可以解决多项式求逆和多项式开根,它也是多项式指数函数的前置技能之一。

问题:给定 ,构造 ,满足 。

假设我们已经构造出了 满足 ,现在我们要求出 。考虑 在 上的泰勒展开:

发现 和 的前 项应该是一样的,所以我们就有:

所以就可以递归了,如果一次递归可以做到 那么时间复杂度就是 。

多项式求逆

问题:给定 ,求 。

考虑令 ,并使用牛顿迭代法。那么我们有:

对于 可以直接求逆元,总时间复杂度 。

多项式开方

问题:给定 ,求 满足 。

考虑令 ,并使用牛顿迭代法。那么我们有:

对于 需要用到二次剩余,总复杂度 。

多项式对数函数

已知 ,我们要求 。

考虑 的意义。我们有 ,我们将 展开:

那么有 ,两边同时积分得到:

那么多项式的 的 就是:

注意 的常数项必须为 。怎么求 呢?考虑复合函数求导:

两边同时积分得到:

那么使用多项式求逆计算即可,时间复杂度 。

多项式指数函数

问题:给定 ,求 。

考虑令 ,并使用牛顿迭代法。那么我们有:

时 必须是 ,此时 ,总复杂度 。

//「Luogu 5245」Polynomial Power
2
#include <bits/stdc++.h>
3
using namespace std;
4
5
const int maxn = 1 << 18, g = 3, mod = 998244353;
6
int n, m, k, a[maxn + 3], b[maxn + 3];
7
char s[maxn + 3];
8
9
namespace poly {
	int lim, bit, rev[maxn + 3];
	int A[maxn + 3], B[maxn + 3], C[maxn + 3], D[maxn + 3];
	int E[maxn + 3], F[maxn + 3], G[maxn + 3], H[maxn + 3];
	int qpow(int a, int b) {
		if (b < 0) b += mod - 1;
		int c = 1;
		for (; b; b >>= 1, a = 1ll * a * a % mod) {
			if (b & 1) c = 1ll * a * c % mod;
		}
20
		return c;
21
	}
22
23
	int func(int x) {
24
		return x < mod ? x : x - mod;
25
	}
26
27
	void dft(int a[], int n, int type) {
28
		for (int i = 0; i < n; i++) {
29
			if (i < rev[i]) swap(a[i], a[rev[i]]);
30
		}
31
		for (int k = 1; k < n; k <<= 1) {
32
			int x = qpow(g, (mod - 1) / (k << 1) * type);
33
			for (int i = 0; i < n; i += k << 1) {
34
				int y = 1;
35
				for (int j = i; j < i + k; j++, y = 1ll * x * y % mod) {
36
					int p = a[j], q = 1ll * a[j + k] * y % mod;
37
					a[j] = func(p + q), a[j + k] = func(p - q + mod);
38
				}
39
			}
40
		}
41
		if (type == -1) {
42
			int x = qpow(n, mod - 2);
43
			for (int i = 0; i < n; i++) {
44
				a[i] = 1ll * a[i] * x % mod;
45
			}
46
		}
47
	}
48
49
	void mult(int a[], int b[], int c[], int n) {
50
		// a = b = c is ok
51
		for (lim = 1, bit = 0; lim <= n * 2; lim <<= 1) bit++;
52
		for (int i = 1; i < lim; i++) {
53
			rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
54
		}
55
		copy(a, a + lim, A);
56
		copy(b, b + lim, B);
57
		fill(A + n + 1, A + lim, 0);
58
		fill(B + n + 1, B + lim, 0);
59
		dft(A, lim, 1), dft(B, lim, 1);
60
		for (int i = 0; i < lim; i++) {
61
			A[i] = 1ll * A[i] * B[i] % mod;
62
		}
63
		dft(A, lim, -1);
64
		copy(A, A + n * 2 + 1, c);
65
	}
66
67
	void inv(int a[], int b[], int n) {
68
		// a != b
69
		if (n == 0) {
70
			b[0] = qpow(a[0], mod - 2);
71
			return;
72
		}
73
		int m = n >> 1;
74
		inv(a, b, m);
75
		copy(a, a + n + 1, C);
76
		copy(b, b + m + 1, D);
77
		fill(D + m + 1, D + n + 1, 0);
78
		mult(D, D, D, m);
79
		mult(C, D, C, n);
80
		for (int i = 0; i <= n; i++) {
81
			b[i] = func(b[i] * 2 % mod - C[i] + mod);
82
		}
83
	}
84
85
	void sqrt(int a[], int b[], int n) {
86
		// a != b
87
		if (n == 0) {
88
			assert(a[0] == 1);
89
			b[0] = 1;
90
			return;
91
		}
92
		int m = n >> 1;
93
		sqrt(a, b, m);
94
		copy(b, b + m + 1, E);
95
		fill(E + m + 1, E + n + 1, 0);
96
		fill(F, F + n + 1, 0);
97
		mult(E, E, E, m);
98
		for (int i = 0; i <= n; i++) {
99
			E[i] = func(E[i] + a[i]);
		}
		inv(b, F, n);
		mult(E, F, E, n);
		for (int i = 0; i <= n; i++) {
			b[i] = 1ll * (mod + 1) / 2 * E[i] % mod;
		}
	}
	void ln(int a[], int b[], int n) {
		// a = b is ok
		for (int i = 1; i <= n; i++) {
			E[i - 1] = 1ll * a[i] * i % mod;
		}
		E[n] = 0;
		fill(F, F + n + 1, 0);
		inv(a, F, n);
		mult(E, F, E, n);
		b[0] = 0;
		for (int i = 1; i <= n; i++) {
			b[i] = 1ll * E[i - 1] * qpow(i, mod - 2) % mod;
		}
	}
	void exp(int a[], int b[], int n) {
		// a != b
		if (n == 0) {
			assert(a[0] == 0);
			b[0] = 1;
			return;
		}
		int m = n >> 1;
		exp(a, b, m);
		copy(b, b + m + 1, G);
		fill(G + m + 1, G + n + 1, 0);
		copy(G, G + n + 1, H);
		ln(H, H, n);
		for (int i = 0; i <= n; i++) {
			H[i] = func(a[i] - H[i] + mod);
		}
		H[0] = func(H[0] + 1);
		mult(G, H, G, n);
		copy(G, G + n + 1, b);
	}
	void pow(int a[], int b[], int n, int k) {
		// a != b
		ln(a, a, n);
		for (int i = 0; i <= n; i++) {
			a[i] = 1ll * a[i] * k % mod;
		}
		exp(a, b, n);
	}
}
int main() {
	scanf("%d %s", &n, s + 1);
	n--, m = strlen(s + 1);
	for (int i = 1; i <= m; i++) {
		k = ((10ll * k) + (s[i] ^ '0')) % mod;
	}
	for (int i = 0; i <= n; i++) {
		scanf("%d", &a[i]);
	}
	poly::pow(a, b, n, k);
	for (int i = 0; i <= n; i++) {
		printf("%d%c", b[i], " \n"[i == n]);
	}
	return 0;
}

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK