数年之前就听闻莫反,FFT,NTT等数论变换的名称,但是一直未学习相关知识

最近学习后量子密码学,遇到了类似数论变换,辄学习一下

名词区分

1、DFT(Discrete Fourier Transform):离散傅立叶变换 $\rightarrow$ $O(n^2)$计算多项式乘法
2、FFT(Fast Fourier Teansformation):快速傅立叶变换 $\rightarrow$ $O(nlogn)$计算多项式乘法
3、(F)NTT(Number Theoretic Transform):(快速)数论变换 $\rightarrow$ 优化常数和误差,适用于整数域
4、MTT(any Module NTT):NTT的扩展 $\rightarrow$ 任意模数

离散傅里叶变换DFT

这是一个朴素算法,用于将一个多项式在$O(n^2)$时间由系数表示法转化为点值表示法

原理:将一个用系数表示的多项式转化成它的点值表示的算法

对于一个$n-1$次的$n$项多项式$f(x)$可以表示为$f(x)=\sum_{i=0}^{n-1}a_ix^i$

系数表示法:$f(x)={a_0,a_1,…,a_{n-1}}$

点值表示法:$f(x)={(x_0,f(x_0)),(x_1,f(x_1)),…,(x_{n-1},f(x_{n-1}))}$

计算两个多项式相乘 $h(x)=f(x)*g(x)$

对于系数表示法,需要每一项和每一项的系数相乘,时间复杂度显然是$O(n^2)$

对于点值表示法——

$h(x)={(x_0,f(x_0)\cdot g(x_0)),(x_1,f(x_1)\cdot g(x_1)),…,(x_{n-1},f(x_{n-1})\cdot g(x_{n-1}))}$,时间复杂度是$O(n)$

已知 $n$ 个点值,可以唯一确定一个 $n-1$ 阶多项式

证明

已知

可以写成矩阵形式

中间的那列就是一个范德蒙行列式了,秩为1,所以$p_0,…,p_{n-1}$有且仅有一个解,即多项式的系数确定

证毕。

利用单位圆进行转化

将$n$向上填充为2的整数次幂,然后将单位圆平均取$n$个点

在单位元上,我们定义$x_k=w_n^k=(cos\frac{k}{n}2\pi,sin\frac{k}{n}2\pi)$

对于点$x_k$​,$x_k=x_{k-i}*x_i(i\in[0,k])$​成立

此时,$x_1$即为该单位圆上的单位根(即该循环群中的单位元)

单位根的性质

很有用的

性质一(相消引理):$w_{2n}^{2k}=w_n^k$ 这两个说的本质上是一个点

性质二(折半引理):$w_n^{k+\frac{n}{2}}=-w_n^k$ 关于原点对称(向量等大反向)

显而易见的
  1. $w_n^k=cos(2\pi\cdot\frac{k}{n})+isin(2\pi\cdot\frac{k}{n})$
  2. $w_n^0=w_n^n=1$
  3. $w_n^{n-i}=w_n^i$
  4. $w_n^{n+i}=w_n^i$

实现方式

把多项式$A(x)$的离散傅里叶变换结果作为另一个多项式$B(x)$的系数,去单位根的倒数即$w_n^0,w_n^{-1},…,w_n^{-(n-1)}$作为$x$代入$B(x)$,得到的每个数再除以$n$,得到的是$A(x)$的各项系数

实现了傅里叶变换的逆变换——把点值表示转换成多项式系数表示

证明

设$(y_0,y_1,y_2,…,y_{n−1})$为多项式$A(x)=a_0+a_1x+a_2x^2+…+a_{n−1}x^{n−1}$的离散傅里叶变换。

现在我们再设一个多项式$B(x)=y_0+y_1x+y_2x^2+…+y_{n−1}x^{n−1}$,现在我们把上面的$n$个单位根的倒数,即$ω^0_n,ω^{−1}_n,ω^{−2}_n,…,ω^{−(n−1)}_n$作为$x$代入$B(x)$, 得到一个新的离散傅里叶变换$(z_0,z_1,z_2,…,z_{n−1})$。

这个$\sum_{i=0}^{n-1}(w_n^{j-k})^i$是可求的:当 $j-k=0$ 时,原式=$n$;否则,通过等比数列求和可以得知

故 $z_k=n\cdot a_k$

即 $a_i=\frac{z_i}{n}$

证毕

具体代码不做赘述,因为DFT和朴素算法的时间复杂度相同,在这里仅用于为FFT打基础

快速傅里叶变化FFT

用途:1.高精度乘法$O(n^2)\rightarrow O(nlogn)$ 2.分离正弦波

原理

FFT和DFT的不同之处在于,傅里叶的时代并没有计算机,所以没有优化时间复杂度的需求;因而虽然DFT的计算是基于单位圆的,但是(求值和差值的)时间复杂度仍旧是$O(n^2)$;而FFT则采用了分治的思想,将求值和差值的时间复杂度降为$O(nlogn)$

推导

设 $A(x)=a_0+a_1x+a_2x^2+…+a_{n-1}x^{n-1}$,

按照下标奇偶性划分为两部分

$A(x)=(a_0+a_2x^2+…+a_{n-2}x^{n-2})+(a_1x+a_3x^3+…+a_{n-1}x^{n-1})$

设 $A_1(x)=a_0+a_2x+…+a_{n-2}x^{\frac{n}{2}-1}\ A_2(x)=a_1x+a_3x+…+a_{n-1}x^{\frac{n}{2}-1}$(注意这里的次方数)

则 $A(x)=A_1(x^2)+xA_2(x^2)$

已知$x=w_n^k$,不妨设 $k<\frac{n}{2}$,代入得到

对于$A(w_n^{k+\frac{n}{2}})$,有:

故只需要知道 $A_1(x)$ 和 $A_2(x)$ 分别在 $(w_\frac{n}{2}^0,w_\frac{n}{2}^1,…,w_\frac{n}{2}^{\frac{n}{2}-1})$ 的点值表示,就可以 $O(n)$ 计算 $A(x)$ 在 $(w_n^0,w_n^1,…,w_n^{n-1})$ 的点值表示

依此,就可以递归实现

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#这里的代码是ai给的,感觉不太靠谱的亚子
import numpy as np

def fft(a):
n = len(a)
if n <= 1:
return a
even = fft(a[0::2])
odd = fft(a[1::2])
t = [np.exp(-2j * np.pi * k / n) * odd[k] for k in range(n // 2)]
return [even[k] + t[k] for k in range(n // 2)] + [even[k] - t[k] for k in range(n // 2)]

def ifft(a):
n = len(a)
a_conj = [np.conjugate(x) for x in a]
y = fft(a_conj)
return [np.conjugate(x) / n for x in y]

def polynomial_multiply(p, q):
n = len(p) + len(q) - 1
m = 1 << (n - 1).bit_length() # Next power of two
p += [0] * (m - len(p))
q += [0] * (m - len(q))

fft_p = fft(p)
fft_q = fft(q)
fft_result = [fft_p[i] * fft_q[i] for i in range(m)]
result = ifft(fft_result)

return [round(x.real) for x in result]

# 示例
p = [1, 2, 3] # 1 + 2x + 3x^2
q = [4, 5] # 4 + 5x
result = polynomial_multiply(p, q)
print(result) # 输出: [4, 13, 22, 15]

后续还有优化版FFT,插个眼,以后更新

(快速)数论变换 (F)NTT

好的,假设我学会了fft,可以开始学ntt了(大雾

不同点 FFT NTT
定义域 主要在复数域中进行,利用复数的旋转性质 使用复数的单位根,通常是复数的n次方根
根的选择 在有限域(通常是素数模数)中进行,适合用于整数运算 使用模p的原根,这些根在有限域中是整数
应用领域 信号处理、图像处理、数字滤波等 密码学、计算机代数和一些整数计算问题,如大数乘法等
计算方式 浮点数运算 完全整数

对于质数$p=qn+1,(n=2^m)$,原根$g$满足$g^{qn}\equiv 1(\mod p)$

将$g_n\equiv g^q(\mod p)$看做$w_n$的等价,其满足相应的性质,如$g_n^n\equiv 1(\mod p),g_n^\frac{n}{2}\equiv -1(\mod p)$等

快速数论变化(FNTT),是数论变换(NTT)增加分治操作之后的快速算法,与快速傅里叶变换使用的分治办法完全一致

FFT中用到复数,需要使用$double$类型来计算,导致精度降低,所以需要使用原根来替代单位根

就像998244353,469762049,1004535809,它们的原根都是3

任意模数NTT MTT

如果模数不是以上几种,我们需要自己取模数

取一些模数$p_1,p_2,…,p_k$使得答案多项式的系数在取模之前不会超过 $\prod_{i=1}^{k}p_i$

一般而言取三个质数即可(998244353,469762049,1004535809)

先计算答案对每个 $p_i$ 取模的结果,利用中国剩余定理就可以求得答案对 $\prod_{i=1}^{k}p_i$ 取模的结果,这个结果就是答案,最后将这个答案对题目中的模数取一次模即可

例题:洛谷 P4245

一道MTT模板题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#include <algorithm>
#include <cstdio>
#define int long long
int mod;

// 这种写法我也第一次见,边写边学吧
// 直接不开在全局std里了,单独一个Math
namespace Math{
inline int qpow( int base , int p , const int mod ){
static int res;
for( res = 1 ; p ; p >>= 1 , base = base * base % mod )
if( p & 1 )
res = res * base % mod;
return res;
}
inline int inv( int x , const int mod ){
return qpow( x , mod-2 , mod );//费马小定理求逆元
}
}

const int mod1 = 998244353, mod2 = 1004535809, mod3 = 469762049, G = 3;
const int mod_1_2 = mod1 * mod2;
const int inv_1 = Math::inv(mod1, mod2), inv_2 = Math::inv(mod_1_2 % mod3, mod3);

struct Int{
int A , B , C;
// 空的默认构造函数
Int(){}
// 这个构造函数允许使用一个整数来初始化 A, B, C,它们都将被初始化为同一个值 __num
Int( int __num ): A(__num) , B(__num) , C(__num) {}
// 允许通过三个整数分别初始化 A, B, C
Int( int __A , int __B , int __C ): A(__A) , B(__B) , C(__C) {}
// 好神奇的操作,研究半天也没明白原理
// 只知道它可以做减法,出现负数就加上一个模数,只需要传入一个指针
static Int reduce( const Int & x ){
return Int( x.A + (x.A >> 31 & mod1) , x.B + (x.B >> 31 & mod2) , x.C + (x.C >> 31 & mod3) );
}
// 加减乘除的重载运算符,很精妙的写法
// 不太懂lhs和rhs是什么,只知道大概是两个input量,不像是数据结构里树的左孩子和右孩子
friend Int operator + ( const Int &lhs , const Int & rhs ){
return reduce(Int(lhs.A + rhs.A - mod1, lhs.B + rhs.B - mod2, lhs.C + rhs.C - mod3));
}
friend Int operator - ( const Int &lhs , const Int & rhs ){
return reduce(Int(lhs.A - rhs.A, lhs.B - rhs.B, lhs.C - rhs.C));
}
friend Int operator * ( const Int &lhs , const Int & rhs ){
return Int( lhs.A * rhs.A % mod1 , lhs.B * rhs.B % mod2 , lhs.C * rhs.C % mod3 );
}
int get(){
int x = (B - A + mod2) % mod2 * inv_1 % mod2 * mod1 + A;
return ((C - x % mod3 + mod3) % mod3 * inv_2 % mod3 * (mod_1_2 % mod) % mod + x) % mod;
}
};

#define maxn 200010 //maxn表示处理的最大元素数量
namespace Poly{
#define N (maxn << 1) //N 定义为 maxn 的两倍,表示用于 NTT 的数组大小。
/*
lim:表示当前处理的长度,是最小的2的幂大于等于输入大小的值
s:记录 lim 的二进制位数
rev:用于存储每个索引的反转(bit-reversal)值
Wn:预计算的旋转因子(根单位元),用于 NTT 计算
*/
int lim , s , rev[N];
Int Wn[N|1];
/*
初始化 NTT 相关参数
计算 lim 为不小于 n 的最小的2的幂
生成反转索引 rev,用于在 NTT 中重排数据
计算旋转因子 t,用于每个模数(mod1, mod2, mod3),通过幂函数 Math::pw 计算
初始化 Wn 数组,预计算旋转因子
*/
void init( int n ){
s = -1 , lim = 1;
while( lim < n ) lim <<= 1 , s ++;//填充
for(int i = 1;i < lim;i ++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
const Int t(Math::qpow(G, (mod1 - 1) / lim, mod1), Math::qpow(G, (mod2 - 1) / lim, mod2), Math::qpow(G, (mod3 - 1) / lim, mod3));
*Wn = Int(1);
for (Int *i = Wn; i != Wn + lim; ++i) *(i + 1) = *i * t;
}
/*
执行 NTT 转换;首先进行反转操作,将数组 A 中的元素按 rev 数组重排
进行蝶形操作(butterfly operation),对数组进行逐层计算,利用预计算的旋转因子 Wn
*/
inline void NTT(Int *A, const int op = 1) {
for (int i = 1; i < lim; ++i) if (i < rev[i]) std::swap(A[i], A[rev[i]]);
for (int mid = 1; mid < lim; mid <<= 1) {
const int t = lim / mid >> 1;
for (int i = 0; i < lim; i += mid << 1) {
for (int j = 0; j < mid; ++j) {
const Int W = op ? Wn[t * j] : Wn[lim - t * j];
const Int X = A[i + j], Y = A[i + j + mid] * W;
A[i + j] = X + Y, A[i + j + mid] = X - Y;
}
}
}
// 如果 op 为 0,表示是反向变换,则需要进行归一化,将结果除以 lim
if (!op) {
const Int ilim(Math::inv(lim, mod1), Math::inv(lim, mod2), Math::inv(lim, mod3));
for (Int *i = A; i != A + lim; ++i) *i = (*i) * ilim;
}
}
#undef N
}

int n , m , x;
Int A[maxn << 1], B[maxn << 1], C[maxn << 1];
signed main() {
scanf("%lld%lld%lld", &n, &m, &mod); ++n, ++m;//因为要考虑常数项,所以+1
for(int i = 0;i < n;i ++) scanf("%lld", &x), A[i] = Int(x % mod);
for(int i = 0;i < m;i ++) scanf("%lld", &x), B[i] = Int(x % mod);
Poly::init(n + m);
Poly::NTT(A), Poly::NTT(B);//系数转化为点值
for(int i = 0;i < Poly::lim;i ++) C[i] = A[i] * B[i];//点值逐项相乘
Poly::NTT(C,0);//反向转化(即op=0),点值转回系数
for(int i = 0;i < n+m-1;i ++) printf("%lld ",C[i].get());
printf("\n");return 0;
}

参考文档:

FFT的学习主要看的这篇https://www.cnblogs.com/RabbitHu/p/FFT.html

写的很*,但是参考文献很好https://blog.csdn.net/Ciellee/article/details/108336914

这篇是后续补充用的(fft)https://www.cnblogs.com/pam-sh/p/15976275.html

同上,不过这里是ntt了https://www.cnblogs.com/windymoon/p/17124857.html

https://www.cnblogs.com/xxeray/p/fast-fourier-transform.html