FFT可以用来快速计算卷积,有时候出题人会给出像998244353之类的良心模数,那么我们NTT就好了。
但是有些毒瘤偏不,他们不但给了模数,他们还给了不可以被拆成
x
∗
2
k
+
1
x*2^k+1
x∗2k+1形式的模数
这个时候就需要一些黑科技了。
拆系数FFT
嗯,不能边NTT边取模怎么办呢?
那我们就把它直接FFT完了再取模
高兴地写完代码一交,WA到自闭,额,好像溢出了,算一下最大值
n
p
2
=
1
0
23
np^2=10^{23}
np2=1023 显然已经爆long long了。
那么怎么把数弄的稍微小一点呢?这个时候就需要拆系数了
设一个阀值为
m
=
p
m=\sqrt {p}
m=p
可以把第一个多项式里每个系数都拆成
f
i
=
a
i
∗
m
+
b
i
f_i=a_i*m+b_i
fi=ai∗m+bi的形式
第二个多项式拆成
g
i
=
c
i
∗
m
+
d
i
g_i=c_i*m+d_i
gi=ci∗m+di的形式
因为比较懒不想写卷积形式,我们就算一下
f
i
∗
g
i
f_i*g_i
fi∗gi的值吧,两者显然是等价的
f
i
∗
g
i
=
(
a
i
∗
m
+
b
i
)
∗
(
c
i
∗
m
+
d
i
)
=
a
i
∗
c
i
∗
m
2
+
(
a
i
∗
d
i
+
b
i
∗
c
i
)
m
+
b
i
∗
d
i
f_i*g_i=(a_i*m+b_i)*(c_i*m+d_i)=a_i*c_i*m^2+(a_i*d_i+b_i*c_i)m+b_i*d_i
fi∗gi=(ai∗m+bi)∗(ci∗m+di)=ai∗ci∗m2+(ai∗di+bi∗ci)m+bi∗di
那么对ac、ad、bc、bd各跑一遍FFT,显然此时最大值变成了
n
p
=
1
0
14
np=10^{14}
np=1014,然后按照上面公式合并取模肯定不会溢出
但是吧,上面跑了整整8遍FFT,这个常数有点大啊……
神仙毛爷爷在他的集训队论文里给出了神仙的优化方法,听说能够把FFT次数变到4次
学不动了……
三模数NTT
嗯,看到三模数就知道为什么了,还是要解决爆long long的问题,那我们就取3个1e9左右的模数,跑NTT,再用crt合并一下,还原出原数就可以了(学CRT强烈安利z(hou)z(hi)d(ao)的博客)
嗯,还原出原数,又爆long long了
这里可能需要点技巧
我们先合并前两组
a
n
s
≡
A
(
m
o
d
P
)
ans\equiv A(mod P)
ans≡A(modP)
a
n
s
≡
a
3
(
m
o
d
p
3
)
ans\equiv a_3(mod p_3)
ans≡a3(modp3)
可以设
a
n
s
=
t
P
+
A
=
k
p
3
+
a
3
ans=tP+A=kp_3+a_3
ans=tP+A=kp3+a3
t
P
≡
a
3
−
A
(
m
o
d
p
3
)
tP\equiv a_3-A(mod p_3)
tP≡a3−A(modp3)
所以
t
≡
(
a
3
−
A
)
P
−
1
(
m
o
d
p
3
)
t\equiv (a_3-A)P^{-1}(modp_3)
t≡(a3−A)P−1(modp3)
这样假设右边是
x
x
x
t
=
k
p
3
+
x
t=kp_3+x
t=kp3+x
代入
a
n
s
=
t
P
+
A
ans=tP+A
ans=tP+A
a
n
s
=
(
k
p
3
+
x
)
p
1
p
2
+
A
=
k
p
1
p
2
p
3
+
x
P
+
A
ans=(kp_3+x)p_1p_2+A=kp_1p_2p_3+xP+A
ans=(kp3+x)p1p2+A=kp1p2p3+xP+A
因为crt的范围在
[
0
,
p
1
p
2
p
3
)
[0,p_1p_2p_3)
[0,p1p2p3) 所以k=0
即
a
n
s
=
x
P
+
A
ans=xP+A
ans=xP+A 直接模p就可以啦
代入以后可以得到ans,合理使用龟速乘就可以不爆精度
比较一波可以感受到三模数NTT的常数(六次)应该会比拆系数FFT(四次)来得大
然而,我选择三模数NTT……
————upd
算错NTT次数了qwq
被直接卡爆
拆系数FFT不带黑科技的也来一份吧……
代码如下:
#include<bits/stdc++.h>
#define gg 3
#define N 300030
using namespace std;
long long ans[N],f[3][N],g[3][N],mod1[]={998244353,469762049,1004535809};
int r[N],n,m,p,lim;
inline long long mul(long long a,long long b,long long mod)
{
long long res=a*b-(long long)((long double)a*b/mod+0.5)*mod;
return res<0?res+mod:res;
}
long long kasumi(long long a,long long b,long long mod)
{
long long ans=1;
while(b)
{
if(b&1) ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans;
}
void NTT(long long *a,int kd,int mod)
{
for(int i=0;i<lim;i++)
{
if(i<r[i]) swap(a[i],a[r[i]]);
}
for(int mid=1;mid<lim;mid<<=1)
{
long long wn=kasumi(gg,(mod-1)/(mid<<1),mod);
if(kd) wn=kasumi(wn,mod-2,mod);
for(int i=0;i<lim;i+=mid<<1)
{
long long w=1;
for(int j=0;j<mid;j++,w=wn*w%mod)
{
long long x=a[i+j];
long long y=a[i+j+mid]*w%mod;
a[i+j]=(x+y)%mod;
a[i+j+mid]=(x-y+mod)%mod;
}
}
}
if(kd)
{
int inv=kasumi(lim,mod-2,mod);
for(int i=0;i<lim;i++) a[i]=a[i]*inv%mod;
}
}
int main()
{
// freopen("ha.in","r",stdin);
// freopen("ha.out","w",stdout);
lim=1;
scanf("%d%d%d",&n,&m,&p);
for(int i=0;i<=n;i++)
{
scanf("%lld",&f[0][i]);
f[0][i]=f[1][i]=f[2][i]=f[0][i]%p;
}
for(int i=0;i<=m;i++)
{
scanf("%lld",&g[0][i]);
g[0][i]=g[1][i]=g[2][i]=g[0][i]%p;
}
int cnt=0;
while(lim<=(n+m)) lim<<=1,cnt++;
for(int i=0;i<lim;i++)
{
r[i]=(r[i>>1]>>1)|((i&1)<<(cnt-1));
}
for(int i=0;i<=2;i++)
{
NTT(f[i],0,mod1[i]);NTT(g[i],0,mod1[i]);
for(int j=0;j<lim;j++)
{
f[i][j]=f[i][j]*g[i][j]%mod1[i];
}
NTT(f[i],1,mod1[i]);
}
long long inv1=kasumi(mod1[0],mod1[1]-2,mod1[1]);
long long inv2=kasumi(mod1[1],mod1[0]-2,mod1[0]);
long long mul1=mod1[0]*mod1[1];
for(int i=0;i<lim;i++)
{
ans[i]+=mul(f[0][i]*inv2%mul1,mod1[1],mul1);
ans[i]+=mul(f[1][i]*inv1%mul1,mod1[0],mul1);
ans[i]%=mul1;
}
long long inv3=kasumi(mul1%mod1[2],mod1[2]-2,mod1[2]);
for(int i=0;i<lim;i++)
{
ans[i]=((f[2][i]-ans[i]%mod1[2]+mod1[2])%mod1[2]*inv3%mod1[2]*(mul1%p)%p+ans[i]%p)%p;
}
for(int i=0;i<=n+m;i++) printf("%lld ",ans[i]%p);
}
#include<cstdio>
#include<string>
#include<cmath>
#include<algorithm>
#define sz 32768
#define N 600030
using std::swap;
long long ans[N];
int r[N],n,m,p,lim,cnt;
long long ff[N],gg[N];
const long double pi=std::acos(-1);
struct comp
{
long double r,i;
comp(){}
comp(long double a,long double b):r(a),i(b){}
}f[2][N],g[2][N],t1[N],t2[N],t3[N];
inline comp operator +(const comp a,const comp b) {return comp(a.r+b.r,a.i+b.i);}
inline comp operator -(const comp a,const comp b) {return comp(a.r-b.r,a.i-b.i);}
inline comp operator *(const comp a,const comp b) {return comp(a.r*b.r-a.i*b.i,a.r*b.i+b.r*a.i);}
void FFT(comp *a,int kd,int lim)
{
for(int i=0;i<lim;i++) if(i<r[i]) swap(a[i],a[r[i]]);
for(int mid=1;mid<lim;mid<<=1)
{
comp wn=comp(std::cos(pi/mid),kd*std::sin(pi/mid));
for(int i=0;i<lim;i+=(mid<<1))
{
comp w=comp(1.0,0.0);
for(int j=0;j<mid;j++,w=wn*w)
{
comp x=a[i+j];
comp y=a[i+j+mid]*w;
a[i+j]=x+y;
a[i+j+mid]=x-y;
}
}
}
if(kd==-1)
{
for(int i=0;i<lim;i++)
{
a[i].r/=lim;
}
}
}
void mul1(long long *a,long long *b,int cnt)
{
int lim=1<<cnt;
for(int i=0;i<lim;i++)
{
f[0][i].r=a[i]/sz;
f[1][i].r=a[i]%sz;
g[0][i].r=b[i]/sz;
g[1][i].r=b[i]%sz;
ans[i]=0;
}
for(int i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(cnt-1));
FFT(f[0],1,lim);FFT(f[1],1,lim);
FFT(g[0],1,lim);FFT(g[1],1,lim);
for(int i=0;i<lim;i++)
{
t1[i]=f[0][i]*g[0][i];
t2[i]=f[0][i]*g[1][i]+g[0][i]*f[1][i];
t3[i]=f[1][i]*g[1][i];
}
FFT(t1,-1,lim);FFT(t2,-1,lim);FFT(t3,-1,lim);
for(int i=0;i<lim;i++)
{
ans[i]=(((long long)(t1[i].r+0.5))%p*sz%p*sz%p+(((long long)(t2[i].r+0.5))%p*sz%p)+(long long)(t3[i].r+0.5)%p)%p;
}
}
int main()
{
int lim=1,cnt=0;
scanf("%d%d%d",&n,&m,&p);
for(int i=0;i<=n;i++) scanf("%lld",&ff[i]);
for(int i=0;i<=m;i++) scanf("%lld",&gg[i]);
while(lim<(n+m)) lim<<=1,cnt++;
mul1(ff,gg,cnt);
for(int i=0;i<=n+m;i++) printf("%lld ",ans[i]);
}