Skip to content

Commit 63235e1

Browse files
committed
Refactor + do not accumulate error
1 parent 0f73e43 commit 63235e1

File tree

2 files changed

+27
-41
lines changed

2 files changed

+27
-41
lines changed

cp-algo/math/fft.hpp

+25-40
Original file line number | Diff line number | Diff line change
@@ -17,18 +17,8 @@ namespace cp_algo::math::fft {
1717
using vpoint = complex<vftype>;
1818
static constexpr size_t flen = vftype::size();
1919

20-
21-
template<typename ft>
22-
constexpr ft to_ft(auto x) {
23-
return ft{} + x;
24-
}
25-
template<typename pt>
26-
constexpr pt to_pt(point r) {
27-
using ft = std::conditional_t<std::is_same_v<point, pt>, ftype, vftype>;
28-
return {to_ft<ft>(r.real()), to_ft<ft>(r.imag())};
29-
}
3020
struct cvector {
31-
static constexpr size_t pre_roots = 1 << 17;
21+
static constexpr size_t pre_roots = 1 << 19;
3222
std::vector<vftype> x, y;
3323
cvector(size_t n) {
3424
n = std::max(flen, std::bit_ceil(n));
@@ -67,32 +57,28 @@ namespace cp_algo::math::fft {
6757
}
6858
}
6959
static const cvector roots;
70-
template<class pt = point>
71-
static pt root(size_t n, size_t k) {
72-
if(n < pre_roots) {
60+
template<class pt = point, bool precalc = false>
61+
static pt root(size_t n, size_t k, auto &&arg) {
62+
if(n < pre_roots && !precalc) {
7363
return roots.get<pt>(n + k);
7464
} else {
75-
auto arg = std::numbers::pi / (ftype)n;
76-
if constexpr(std::is_same_v<pt, point>) {
77-
return {cos((ftype)k * arg), sin((ftype)k * arg)};
78-
} else {
79-
return pt{vftype{[&](auto i) {return cos(ftype(k + i) * arg);}},
80-
vftype{[&](auto i) {return sin(ftype(k + i) * arg);}}};
81-
}
65+
return polar<typename pt::value_type>(1., arg);
8266
}
8367
}
84-
template<class pt = point>
68+
template<class pt = point, bool precalc = false>
8569
static void exec_on_roots(size_t n, size_t m, auto &&callback) {
70+
ftype arg = std::numbers::pi / (ftype)n;
8671
size_t step = sizeof(pt) / sizeof(point);
87-
pt cur;
88-
pt arg = to_pt<pt>(root<point>(n, step));
89-
for(size_t i = 0; i < m; i += step) {
90-
if(i % 32 == 0 || n < pre_roots) {
91-
cur = root<pt>(n, i);
72+
using ft = pt::value_type;
73+
auto k = [&]() {
74+
if constexpr(std::is_same_v<pt, point>) {
75+
return ft{};
9276
} else {
93-
cur *= arg;
77+
return ft{[](auto i) {return i;}};
9478
}
95-
callback(i, cur);
79+
}();
80+
for(size_t i = 0; i < m; i += step, k += (ftype)step) {
81+
callback(i, root<pt, precalc>(n, i, arg * k));
9682
}
9783
}
9884

@@ -106,15 +92,15 @@ namespace cp_algo::math::fft {
10692
set(k + i, get<pt>(k) - t);
10793
set(k, get<pt>(k) + t);
10894
};
109-
if(2 * i <= flen) {
95+
if(i < flen) {
11096
exec_on_roots(i, i, butterfly);
11197
} else {
11298
exec_on_roots<vpoint>(i, i, butterfly);
11399
}
114100
}
115101
}
116102
for(size_t k = 0; k < n; k += flen) {
117-
set(k, get<vpoint>(k) /= to_pt<vpoint>((ftype)n));
103+
set(k, get<vpoint>(k) /= (ftype)n);
118104
}
119105
}
120106
void fft() {
@@ -128,7 +114,7 @@ namespace cp_algo::math::fft {
128114
set(k, A);
129115
set(k + i, B * rt);
130116
};
131-
if(2 * i <= flen) {
117+
if(i < flen) {
132118
exec_on_roots(i, i, butterfly);
133119
} else {
134120
exec_on_roots<vpoint>(i, i, butterfly);
@@ -140,14 +126,13 @@ namespace cp_algo::math::fft {
140126
const cvector cvector::roots = []() {
141127
cvector res(pre_roots);
142128
for(size_t n = 1; n < res.size(); n *= 2) {
143-
auto base = polar<ftype>(1., std::numbers::pi / (ftype)n);
144-
point cur = 1;
145-
for(size_t k = 0; k < n; k++) {
146-
if((k & 15) == 0) {
147-
cur = polar<ftype>(1., std::numbers::pi * (ftype)k / (ftype)n);
148-
}
149-
res.set(n + k, cur);
150-
cur *= base;
129+
auto propagate = [&](size_t k, auto rt) {
130+
res.set(n + k, rt);
131+
};
132+
if(n < flen) {
133+
res.exec_on_roots<point, true>(n, n, propagate);
134+
} else {
135+
res.exec_on_roots<vpoint, true>(n, n, propagate);
151136
}
152137
}
153138
return res;

cp-algo/util/complex.hpp

+2-1
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
namespace cp_algo {
55
template<typename T>
66
struct complex {
7+
using value_type = T;
78
T x, y;
89
constexpr complex() {}
910
constexpr complex(T x): x(x), y(0) {}
@@ -26,7 +27,7 @@ namespace cp_algo {
2627
T abs() const {return std::sqrt(norm());}
2728
T real() const {return x;}
2829
T imag() const {return y;}
29-
static complex polar(T r, T theta) {return {r * std::cos(theta), r * std::sin(theta)};}
30+
static complex polar(T r, T theta) {return {r * cos(theta), r * sin(theta)};}
3031
auto operator <=> (complex const& t) const = default;
3132
};
3233
template<typename T>

0 commit comments

Comments
 (0)