多项式初步学习

数年之前就听闻莫反,FFT,NTT等数论变换的名称,但是一直未学习相关知识

最近学习后量子密码学,遇到了类似数论变换,辄学习一下

名词区分

1、DFT(Discrete Fourier Transform):离散傅立叶变换 计算多项式乘法 2、FFT(Fast Fourier Teansformation):快速傅立叶变换 计算多项式乘法 3、(F)NTT(Number Theoretic Transform):(快速)数论变换 优化常数和误差,适用于整数域 4、MTT(any Module NTT):NTT的扩展 任意模数

离散傅里叶变换DFT

这是一个朴素算法,用于将一个多项式在时间由系数表示法转化为点值表示法

原理:将一个用系数表示的多项式转化成它的点值表示的算法

对于一个次的项多项式可以表示为

系数表示法:

点值表示法: 计算两个多项式相乘

对于系数表示法,需要每一项和每一项的系数相乘,时间复杂度显然是

对于点值表示法——

,时间复杂度是

已知 个点值,可以唯一确定一个 阶多项式

证明

已知 可以写成矩阵形式 中间的那列就是一个范德蒙行列式了,秩为1,所以有且仅有一个解,即多项式的系数确定

证毕。

利用单位圆进行转化

向上填充为2的整数次幂,然后将单位圆平均取个点

在单位元上,我们定义

对于点​,​成立

此时,即为该单位圆上的单位根(即该循环群中的单位元)

单位根的性质

很有用的

性质一(相消引理): 这两个说的本质上是一个点

性质二(折半引理): 关于原点对称(向量等大反向)

显而易见的

实现方式

把多项式的离散傅里叶变换结果作为另一个多项式的系数,去单位根的倒数即作为代入,得到的每个数再除以,得到的是的各项系数

实现了傅里叶变换的逆变换——把点值表示转换成多项式系数表示

证明

为多项式的离散傅里叶变换。

现在我们再设一个多项式,现在我们把上面的个单位根的倒数,即作为代入, 得到一个新的离散傅里叶变换 这个是可求的:当 时,原式=;否则,通过等比数列求和可以得知

证毕

具体代码不做赘述,因为DFT和朴素算法的时间复杂度相同,在这里仅用于为FFT打基础

快速傅里叶变化FFT

用途:1.高精度乘法 2.分离正弦波

原理

FFT和DFT的不同之处在于,傅里叶的时代并没有计算机,所以没有优化时间复杂度的需求;因而虽然DFT的计算是基于单位圆的,但是(求值和差值的)时间复杂度仍旧是;而FFT则采用了分治的思想,将求值和差值的时间复杂度降为

推导

按照下标奇偶性划分为两部分

(注意这里的次方数)

已知,不妨设 ,代入得到 对于,有: 故只需要知道 分别在 的点值表示,就可以 计算 的点值表示

依此,就可以递归实现

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#这里的代码是ai给的,感觉不太靠谱的亚子
import numpy as np

def fft(a):
n = len(a)
if n <= 1:
return a
even = fft(a[0::2])
odd = fft(a[1::2])
t = [np.exp(-2j * np.pi * k / n) * odd[k] for k in range(n // 2)]
return [even[k] + t[k] for k in range(n // 2)] + [even[k] - t[k] for k in range(n // 2)]

def ifft(a):
n = len(a)
a_conj = [np.conjugate(x) for x in a]
y = fft(a_conj)
return [np.conjugate(x) / n for x in y]

def polynomial_multiply(p, q):
n = len(p) + len(q) - 1
m = 1 << (n - 1).bit_length() # Next power of two
p += [0] * (m - len(p))
q += [0] * (m - len(q))

fft_p = fft(p)
fft_q = fft(q)
fft_result = [fft_p[i] * fft_q[i] for i in range(m)]
result = ifft(fft_result)

return [round(x.real) for x in result]

# 示例
p = [1, 2, 3] # 1 + 2x + 3x^2
q = [4, 5] # 4 + 5x
result = polynomial_multiply(p, q)
print(result) # 输出: [4, 13, 22, 15]

后续还有优化版FFT,插个眼,以后更新

(快速)数论变换 (F)NTT

好的,假设我学会了fft,可以开始学ntt了(大雾

不同点 FFT NTT
定义域 主要在复数域中进行,利用复数的旋转性质 使用复数的单位根,通常是复数的n次方根
根的选择 在有限域(通常是素数模数)中进行,适合用于整数运算 使用模p的原根,这些根在有限域中是整数
应用领域 信号处理、图像处理、数字滤波等 密码学、计算机代数和一些整数计算问题,如大数乘法等
计算方式 浮点数运算 完全整数

对于质数,原根满足

看做的等价,其满足相应的性质,如

快速数论变化(FNTT),是数论变换(NTT)增加分治操作之后的快速算法,与快速傅里叶变换使用的分治办法完全一致

FFT中用到复数,需要使用类型来计算,导致精度降低,所以需要使用原根来替代单位根 就像998244353,469762049,1004535809,它们的原根都是3

任意模数NTT MTT

如果模数不是以上几种,我们需要自己取模数

取一些模数使得答案多项式的系数在取模之前不会超过

一般而言取三个质数即可(998244353,469762049,1004535809)

先计算答案对每个 取模的结果,利用中国剩余定理就可以求得答案对 取模的结果,这个结果就是答案,最后将这个答案对题目中的模数取一次模即可

例题:洛谷 P4245

一道MTT模板题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#include <algorithm>
#include <cstdio>
#define int long long
int mod;

// 这种写法我也第一次见,边写边学吧
// 直接不开在全局std里了,单独一个Math
namespace Math{
inline int qpow( int base , int p , const int mod ){
static int res;
for( res = 1 ; p ; p >>= 1 , base = base * base % mod )
if( p & 1 )
res = res * base % mod;
return res;
}
inline int inv( int x , const int mod ){
return qpow( x , mod-2 , mod );//费马小定理求逆元
}
}

const int mod1 = 998244353, mod2 = 1004535809, mod3 = 469762049, G = 3;
const int mod_1_2 = mod1 * mod2;
const int inv_1 = Math::inv(mod1, mod2), inv_2 = Math::inv(mod_1_2 % mod3, mod3);

struct Int{
int A , B , C;
// 空的默认构造函数
Int(){}
// 这个构造函数允许使用一个整数来初始化 A, B, C,它们都将被初始化为同一个值 __num
Int( int __num ): A(__num) , B(__num) , C(__num) {}
// 允许通过三个整数分别初始化 A, B, C
Int( int __A , int __B , int __C ): A(__A) , B(__B) , C(__C) {}
// 好神奇的操作,研究半天也没明白原理
// 只知道它可以做减法,出现负数就加上一个模数,只需要传入一个指针
static Int reduce( const Int & x ){
return Int( x.A + (x.A >> 31 & mod1) , x.B + (x.B >> 31 & mod2) , x.C + (x.C >> 31 & mod3) );
}
// 加减乘除的重载运算符,很精妙的写法
// 不太懂lhs和rhs是什么,只知道大概是两个input量,不像是数据结构里树的左孩子和右孩子
friend Int operator + ( const Int &lhs , const Int & rhs ){
return reduce(Int(lhs.A + rhs.A - mod1, lhs.B + rhs.B - mod2, lhs.C + rhs.C - mod3));
}
friend Int operator - ( const Int &lhs , const Int & rhs ){
return reduce(Int(lhs.A - rhs.A, lhs.B - rhs.B, lhs.C - rhs.C));
}
friend Int operator * ( const Int &lhs , const Int & rhs ){
return Int( lhs.A * rhs.A % mod1 , lhs.B * rhs.B % mod2 , lhs.C * rhs.C % mod3 );
}
int get(){
int x = (B - A + mod2) % mod2 * inv_1 % mod2 * mod1 + A;
return ((C - x % mod3 + mod3) % mod3 * inv_2 % mod3 * (mod_1_2 % mod) % mod + x) % mod;
}
};

#define maxn 200010 //maxn表示处理的最大元素数量
namespace Poly{
#define N (maxn << 1) //N 定义为 maxn 的两倍,表示用于 NTT 的数组大小。
/*
lim:表示当前处理的长度,是最小的2的幂大于等于输入大小的值
s:记录 lim 的二进制位数
rev:用于存储每个索引的反转(bit-reversal)值
Wn:预计算的旋转因子(根单位元),用于 NTT 计算
*/
int lim , s , rev[N];
Int Wn[N|1];
/*
初始化 NTT 相关参数
计算 lim 为不小于 n 的最小的2的幂
生成反转索引 rev,用于在 NTT 中重排数据
计算旋转因子 t,用于每个模数(mod1, mod2, mod3),通过幂函数 Math::pw 计算
初始化 Wn 数组,预计算旋转因子
*/
void init( int n ){
s = -1 , lim = 1;
while( lim < n ) lim <<= 1 , s ++;//填充
for(int i = 1;i < lim;i ++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
const Int t(Math::qpow(G, (mod1 - 1) / lim, mod1), Math::qpow(G, (mod2 - 1) / lim, mod2), Math::qpow(G, (mod3 - 1) / lim, mod3));
*Wn = Int(1);
for (Int *i = Wn; i != Wn + lim; ++i) *(i + 1) = *i * t;
}
/*
执行 NTT 转换;首先进行反转操作,将数组 A 中的元素按 rev 数组重排
进行蝶形操作(butterfly operation),对数组进行逐层计算,利用预计算的旋转因子 Wn
*/
inline void NTT(Int *A, const int op = 1) {
for (int i = 1; i < lim; ++i) if (i < rev[i]) std::swap(A[i], A[rev[i]]);
for (int mid = 1; mid < lim; mid <<= 1) {
const int t = lim / mid >> 1;
for (int i = 0; i < lim; i += mid << 1) {
for (int j = 0; j < mid; ++j) {
const Int W = op ? Wn[t * j] : Wn[lim - t * j];
const Int X = A[i + j], Y = A[i + j + mid] * W;
A[i + j] = X + Y, A[i + j + mid] = X - Y;
}
}
}
// 如果 op 为 0,表示是反向变换,则需要进行归一化,将结果除以 lim
if (!op) {
const Int ilim(Math::inv(lim, mod1), Math::inv(lim, mod2), Math::inv(lim, mod3));
for (Int *i = A; i != A + lim; ++i) *i = (*i) * ilim;
}
}
#undef N
}

int n , m , x;
Int A[maxn << 1], B[maxn << 1], C[maxn << 1];
signed main() {
scanf("%lld%lld%lld", &n, &m, &mod); ++n, ++m;//因为要考虑常数项,所以+1
for(int i = 0;i < n;i ++) scanf("%lld", &x), A[i] = Int(x % mod);
for(int i = 0;i < m;i ++) scanf("%lld", &x), B[i] = Int(x % mod);
Poly::init(n + m);
Poly::NTT(A), Poly::NTT(B);//系数转化为点值
for(int i = 0;i < Poly::lim;i ++) C[i] = A[i] * B[i];//点值逐项相乘
Poly::NTT(C,0);//反向转化(即op=0),点值转回系数
for(int i = 0;i < n+m-1;i ++) printf("%lld ",C[i].get());
printf("\n");return 0;
}

参考文档:

FFT的学习主要看的这篇https://www.cnblogs.com/RabbitHu/p/FFT.html

写的很*,但是参考文献很好https://blog.csdn.net/Ciellee/article/details/108336914

这篇是后续补充用的(fft)https://www.cnblogs.com/pam-sh/p/15976275.html

同上,不过这里是ntt了https://www.cnblogs.com/windymoon/p/17124857.html

https://www.cnblogs.com/xxeray/p/fast-fourier-transform.html