@@ -17,18 +17,8 @@ namespace cp_algo::math::fft {
17
17
using vpoint = complex<vftype>;
18
18
static constexpr size_t flen = vftype::size();
19
19
20
-
21
- template <typename ft>
22
- constexpr ft to_ft (auto x) {
23
- return ft{} + x;
24
- }
25
- template <typename pt>
26
- constexpr pt to_pt (point r) {
27
- using ft = std::conditional_t <std::is_same_v<point, pt>, ftype, vftype>;
28
- return {to_ft<ft>(r.real ()), to_ft<ft>(r.imag ())};
29
- }
30
20
struct cvector {
31
- static constexpr size_t pre_roots = 1 << 17 ;
21
+ static constexpr size_t pre_roots = 1 << 19 ;
32
22
std::vector<vftype> x, y;
33
23
cvector (size_t n) {
34
24
n = std::max (flen, std::bit_ceil (n));
@@ -67,32 +57,28 @@ namespace cp_algo::math::fft {
67
57
}
68
58
}
69
59
static const cvector roots;
70
- template <class pt = point>
71
- static pt root (size_t n, size_t k) {
72
- if (n < pre_roots) {
60
+ template <class pt = point, bool precalc = false >
61
+ static pt root (size_t n, size_t k, auto &&arg ) {
62
+ if (n < pre_roots && !precalc ) {
73
63
return roots.get <pt>(n + k);
74
64
} else {
75
- auto arg = std::numbers::pi / (ftype)n;
76
- if constexpr (std::is_same_v<pt, point>) {
77
- return {cos ((ftype)k * arg), sin ((ftype)k * arg)};
78
- } else {
79
- return pt{vftype{[&](auto i) {return cos (ftype (k + i) * arg);}},
80
- vftype{[&](auto i) {return sin (ftype (k + i) * arg);}}};
81
- }
65
+ return polar<typename pt::value_type>(1 ., arg);
82
66
}
83
67
}
84
- template <class pt = point>
68
+ template <class pt = point, bool precalc = false >
85
69
static void exec_on_roots (size_t n, size_t m, auto &&callback) {
70
+ ftype arg = std::numbers::pi / (ftype)n;
86
71
size_t step = sizeof (pt) / sizeof (point);
87
- pt cur;
88
- pt arg = to_pt<pt>(root<point>(n, step));
89
- for (size_t i = 0 ; i < m; i += step) {
90
- if (i % 32 == 0 || n < pre_roots) {
91
- cur = root<pt>(n, i);
72
+ using ft = pt::value_type;
73
+ auto k = [&]() {
74
+ if constexpr (std::is_same_v<pt, point>) {
75
+ return ft{};
92
76
} else {
93
- cur *= arg ;
77
+ return ft{[]( auto i) { return i;}} ;
94
78
}
95
- callback (i, cur);
79
+ }();
80
+ for (size_t i = 0 ; i < m; i += step, k += (ftype)step) {
81
+ callback (i, root<pt, precalc>(n, i, arg * k));
96
82
}
97
83
}
98
84
@@ -106,15 +92,15 @@ namespace cp_algo::math::fft {
106
92
set (k + i, get<pt>(k) - t);
107
93
set (k, get<pt>(k) + t);
108
94
};
109
- if (2 * i <= flen) {
95
+ if (i < flen) {
110
96
exec_on_roots (i, i, butterfly);
111
97
} else {
112
98
exec_on_roots<vpoint>(i, i, butterfly);
113
99
}
114
100
}
115
101
}
116
102
for (size_t k = 0 ; k < n; k += flen) {
117
- set (k, get<vpoint>(k) /= to_pt<vpoint>(( ftype)n) );
103
+ set (k, get<vpoint>(k) /= ( ftype)n);
118
104
}
119
105
}
120
106
void fft () {
@@ -128,7 +114,7 @@ namespace cp_algo::math::fft {
128
114
set (k, A);
129
115
set (k + i, B * rt);
130
116
};
131
- if (2 * i <= flen) {
117
+ if (i < flen) {
132
118
exec_on_roots (i, i, butterfly);
133
119
} else {
134
120
exec_on_roots<vpoint>(i, i, butterfly);
@@ -140,14 +126,13 @@ namespace cp_algo::math::fft {
140
126
const cvector cvector::roots = []() {
141
127
cvector res (pre_roots);
142
128
for (size_t n = 1 ; n < res.size (); n *= 2 ) {
143
- auto base = polar<ftype>(1 ., std::numbers::pi / (ftype)n);
144
- point cur = 1 ;
145
- for (size_t k = 0 ; k < n; k++) {
146
- if ((k & 15 ) == 0 ) {
147
- cur = polar<ftype>(1 ., std::numbers::pi * (ftype)k / (ftype)n);
148
- }
149
- res.set (n + k, cur);
150
- cur *= base;
129
+ auto propagate = [&](size_t k, auto rt) {
130
+ res.set (n + k, rt);
131
+ };
132
+ if (n < flen) {
133
+ res.exec_on_roots <point, true >(n, n, propagate);
134
+ } else {
135
+ res.exec_on_roots <vpoint, true >(n, n, propagate);
151
136
}
152
137
}
153
138
return res;
0 commit comments