Skip to content

Commit

Permalink
Merge pull request #77 from luckylat/fft/implement
Browse files Browse the repository at this point in the history
feat(fft): implement
  • Loading branch information
luckylat authored Jan 29, 2024
2 parents 972a7f6 + 5f6763c commit f196723
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 0 deletions.
1 change: 1 addition & 0 deletions cpp/data-structure/mod-int/arbitrary-mod-int.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#pragma once

long long modadd(long long A, long long B){
return (A+B)%MOD;
Expand Down
6 changes: 6 additions & 0 deletions cpp/data-structure/mod-int/mod-int.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#pragma once

template <int mod>
struct ModInt{
int n;
Expand Down Expand Up @@ -55,6 +57,10 @@ struct ModInt{
return ret;
}

int getMod() const {
return mod;
}

friend ostream &operator<<(ostream &os, const ModInt &p){
return os << p.n;
}
Expand Down
2 changes: 2 additions & 0 deletions cpp/math/binary-power-method.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#pragma once

template <typename T>
T uPow(T z,T n, T mod){
T ans = 1;
Expand Down
92 changes: 92 additions & 0 deletions cpp/math/convolution.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Ref: https://qiita.com/AngrySadEight/items/0dfde26060daaf6a2fda

#include "../template/template.cpp"
#include "./binary-power-method.cpp"
#include "../data-structure/mod-int/mod-int.cpp"
using namespace std;

template<typename MINT>
vector<MINT> ntt(vector<MINT> X, int depth, vector<MINT> root) {
long long n = X.size();
if(n == 1){
return X;
}else{
vector<MINT> val(0);
vector<MINT> even(0);
vector<MINT> odd(0);
for(int i = 0; n > i; i++){
if(i % 2 == 0)even.push_back(X[i]);
else odd.push_back(X[i]);
}

auto ntt_even = ntt(even, depth-1, root);
auto ntt_odd = ntt(odd, depth-1, root);

mint r = root[depth];

MINT now = 1;
for(int i = 0; n > i; i++){
val.push_back(ntt_even[i%(n/2)] + (now * ntt_odd[i%(n/2)]));
now *= r;
}
return val;
}
}

template<typename MINT> // 998244353 mod
vector<MINT> make_root(long long p){
vector<MINT> val(0);
mint r = uPow(3LL, 119LL, p);
for(int i = 0; 23 > i; i++){
val.push_back(r);
r *= r;
}
reverse(val.begin(), val.end());
return val;
}

template<typename MINT>
vector<MINT> make_invroot(vector<MINT> root){
vector<MINT> val(0);
for(int i = 0; 23 > i; i++){
val.push_back(root[i].inverse());
}
return val;
}

template<typename MINT>
vector<MINT> convolution(vector<MINT> A, vector<MINT> B){
long long p = A[0].getMod(); // each mod must be same

vector<MINT> root = make_root<MINT>(p);
vector<MINT> invroot = make_invroot<MINT>(root);

size_t size = (A.size()+B.size()-1);
int n = 1;
int log2_n = 0;
while(n < size){
n *= 2;
log2_n++;
}

while(A.size() < n)A.push_back(0);
while(B.size() < n)B.push_back(0);

// AとBのNTTを求める
auto nttA = ntt(A, log2_n-1, root);
auto nttB = ntt(B, log2_n-1, root);

vector<MINT> nttC(n);
for(int i = 0; n > i; i++){
nttC[i] = nttA[i]*nttB[i];
}

auto nC = ntt(nttC, log2_n-1, invroot);
vector<MINT> C(size);
for(int i = 0; size > i; i++){
C[i] = nC[i]/(mint)n;
}

return C;
}

16 changes: 16 additions & 0 deletions cpp/z_test/yosupo-convolution_mod.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod"

#include "../../cpp/template/template.cpp"

#include "../../cpp/math/convolution.cpp"

int main(){
int na,nb;cin>>na>>nb;
vector<mint> A(na), B(nb);
for(int i = 0; na > i; i++)cin>>A[i];
for(int i = 0; nb > i; i++)cin>>B[i];
auto C = convolution(A, B);
for(int i = 0; C.size() > i; i++){
cout << C[i] << " \n"[i+1 == C.size()];
}
}

0 comments on commit f196723

Please sign in to comment.