sotanishy's code snippets for competitive programming
View the Project on GitHub sotanishy/cp-library-cpp
#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod" #include "../../math/modint.hpp" #include "../../convolution/ntt.hpp" #include <bits/stdc++.h> using namespace std; using ll = long long; using mint = Modint<998244353>; int main() { int N, M; cin >> N >> M; vector<mint> a(N), b(M); for (int i = 0; i < N; i++) cin >> a[i]; for (int i = 0; i < M; i++) cin >> b[i]; auto c = convolution(a, b); for (int i = 0; i < N + M - 1; i++) cout << c[i] << (i < N + M - 2 ? " " : "\n"); }
#line 1 "test/yosupo/convolution_mod.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod"
#line 2 "math/modint.hpp"
#include <algorithm>
#include <iostream>

/**
 * @brief Mod int
 *
 * Integer modulo the compile-time modulus m; the stored value is always kept
 * in [0, m). Every m in [1, INT_MAX] is supported: addition and subtraction
 * avoid signed overflow even when 2 * m exceeds INT_MAX.
 */
template <int m>
class Modint {
    using mint = Modint;
    static_assert(m > 0, "Modulus must be positive");

   public:
    static constexpr int mod() { return m; }

    constexpr Modint(long long y = 0) : x(y >= 0 ? y % m : (y % m + m) % m) {}

    constexpr int val() const { return x; }

    constexpr mint& operator+=(const mint& r) {
        // x + r.x can reach 2 * m - 2, which overflows int when
        // m > INT_MAX / 2 (e.g. m = 1224736769); sum in unsigned instead.
        unsigned int s = static_cast<unsigned int>(x) + static_cast<unsigned int>(r.x);
        if (s >= static_cast<unsigned int>(m)) s -= static_cast<unsigned int>(m);
        x = static_cast<int>(s);
        return *this;
    }
    constexpr mint& operator-=(const mint& r) {
        // Both operands are in [0, m), so x - r.x lies in (-m, m): no overflow.
        x -= r.x;
        if (x < 0) x += m;
        return *this;
    }
    constexpr mint& operator*=(const mint& r) {
        x = static_cast<int>(1LL * x * r.x % m);
        return *this;
    }
    constexpr mint& operator/=(const mint& r) { return *this *= r.inv(); }

    constexpr bool operator==(const mint& r) const { return x == r.x; }

    constexpr mint operator+() const { return *this; }
    constexpr mint operator-() const { return mint(-x); }

    constexpr friend mint operator+(const mint& l, const mint& r) { return mint(l) += r; }
    constexpr friend mint operator-(const mint& l, const mint& r) { return mint(l) -= r; }
    constexpr friend mint operator*(const mint& l, const mint& r) { return mint(l) *= r; }
    constexpr friend mint operator/(const mint& l, const mint& r) { return mint(l) /= r; }

    /// Modular inverse via the extended Euclidean algorithm (returns 0 for 0).
    constexpr mint inv() const {
        int a = x, b = m, u = 1, v = 0;
        while (b > 0) {
            int t = a / b;
            std::swap(a -= t * b, b);
            std::swap(u -= t * v, v);
        }
        return mint(u);
    }

    /// Binary exponentiation; returns 1 for n <= 0.
    constexpr mint pow(long long n) const {
        mint ret(1), mul(x);
        while (n > 0) {
            if (n & 1) ret *= mul;
            mul *= mul;
            n >>= 1;
        }
        return ret;
    }

    friend std::ostream& operator<<(std::ostream& os, const mint& r) { return os << r.x; }

    friend std::istream& operator>>(std::istream& is, mint& r) {
        long long t;
        is >> t;
        r = mint(t);
        return is;
    }

   private:
    int x;
};
#line 2 "convolution/ntt.hpp"
#include <bit>
#include <vector>

/**
 * Computes x^n mod m; usable in constant expressions.
 * Requires x >= 0, n >= 0 and m >= 1.
 */
constexpr long long pow_mod_constexpr(long long x, long long n, int m) {
    long long result = 1 % m;
    x %= m;
    while (n > 0) {
        if (n & 1) result = result * x % m;
        x = x * x % m;
        n >>= 1;
    }
    return result;
}

/**
 * Returns a primitive root modulo the prime `mod`.
 *
 * Common NTT primes are answered from a table. Any other prime is handled by
 * searching for the smallest g with g^((mod - 1) / p) != 1 for every prime
 * factor p of mod - 1. Precondition: mod is prime. For a composite modulus
 * the search does not terminate, which surfaces as a compile error when the
 * call is constant-evaluated.
 */
constexpr int get_primitive_root(int mod) {
    if (mod == 2) return 1;
    if (mod == 167772161) return 3;
    if (mod == 469762049) return 3;
    if (mod == 754974721) return 11;
    if (mod == 998244353) return 3;
    if (mod == 1224736769) return 3;

    // Distinct prime factors of mod - 1 (an int has at most 9 of them).
    int factors[20] = {};
    int count = 0;
    int rest = mod - 1;
    for (int p = 2; 1LL * p * p <= rest; ++p) {
        if (rest % p == 0) {
            factors[count++] = p;
            while (rest % p == 0) rest /= p;
        }
    }
    if (rest > 1) factors[count++] = rest;

    for (int g = 2;; ++g) {
        bool is_root = true;
        for (int i = 0; i < count; ++i) {
            if (pow_mod_constexpr(g, (mod - 1) / factors[i], mod) == 1) {
                is_root = false;
                break;
            }
        }
        if (is_root) return g;
    }
}

/**
 * In-place forward number-theoretic transform (decimation in frequency).
 *
 * a.size() must be a power of two that divides mint::mod() - 1. The output is
 * left in bit-reversed order; intt() consumes that order directly.
 */
template <typename mint>
void ntt(std::vector<mint>& a) {
    constexpr int mod = mint::mod();
    constexpr mint primitive_root = get_primitive_root(mod);
    const int n = a.size();
    for (int m = n; m > 1; m >>= 1) {
        // omega is a primitive m-th root of unity.
        mint omega = primitive_root.pow((mod - 1) / m);
        for (int s = 0; s < n / m; ++s) {
            mint w = 1;
            for (int i = 0; i < m / 2; ++i) {
                mint l = a[s * m + i];
                mint r = a[s * m + i + m / 2];
                a[s * m + i] = l + r;
                a[s * m + i + m / 2] = (l - r) * w;
                w *= omega;
            }
        }
    }
}

/**
 * In-place inverse transform (decimation in time) for the output of ntt().
 *
 * Takes bit-reversed input and produces natural order. The result is NOT
 * normalized: every element is n times the true inverse, so callers must
 * multiply by n^{-1}.
 */
template <typename mint>
void intt(std::vector<mint>& a) {
    constexpr int mod = mint::mod();
    constexpr mint primitive_root = get_primitive_root(mod);
    const int n = a.size();
    for (int m = 2; m <= n; m <<= 1) {
        // Inverse of the primitive m-th root of unity used by ntt().
        mint omega = primitive_root.pow((mod - 1) / m).inv();
        for (int s = 0; s < n / m; ++s) {
            mint w = 1;
            for (int i = 0; i < m / 2; ++i) {
                mint l = a[s * m + i];
                mint r = a[s * m + i + m / 2] * w;
                a[s * m + i] = l + r;
                a[s * m + i + m / 2] = l - r;
                w *= omega;
            }
        }
    }
}

/**
 * Returns the convolution c[k] = sum_{i + j = k} a[i] * b[j] modulo
 * mint::mod(), of length a.size() + b.size() - 1. Returns an empty vector
 * if either input is empty.
 *
 * The padded length (next power of two) must divide mint::mod() - 1.
 */
template <typename mint>
std::vector<mint> convolution(std::vector<mint> a, std::vector<mint> b) {
    // Without this guard, a.size() + b.size() - 1 would give a wrong size
    // (b.size() - 1 when a is empty, or -1 when both are empty).
    if (a.empty() || b.empty()) return {};
    const int size = static_cast<int>(a.size() + b.size() - 1);
    const int n = static_cast<int>(std::bit_ceil(static_cast<unsigned int>(size)));
    a.resize(n);
    b.resize(n);
    ntt(a);
    ntt(b);
    // Pointwise product; both spectra share the same (bit-reversed) order.
    for (int i = 0; i < n; ++i) a[i] *= b[i];
    intt(a);
    a.resize(size);
    mint n_inv = mint(n).inv();
    for (int i = 0; i < size; ++i) a[i] *= n_inv;
    return a;
}
#line 5 "test/yosupo/convolution_mod.test.cpp"
#include <bits/stdc++.h>
using namespace std;
using mint = Modint<998244353>;

// Reads two sequences a (length N) and b (length M) and prints their
// convolution modulo 998244353 (N + M - 1 values, space separated).
int main() {
    // Bulk numeric I/O: decouple iostreams from C stdio and untie cin from
    // cout so reads are not preceded by flushes.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N, M;
    cin >> N >> M;
    vector<mint> a(N), b(M);
    for (int i = 0; i < N; i++) cin >> a[i];
    for (int i = 0; i < M; i++) cin >> b[i];

    auto c = convolution(a, b);
    for (int i = 0; i < N + M - 1; i++) cout << c[i] << (i < N + M - 2 ? " " : "\n");
}