sotanishy's code snippets for competitive programming
View the Project on GitHub sotanishy/cp-library-cpp
#include "convolution/arbitrary_mod_convolution.hpp"
任意の mod で畳み込みを計算する.
$p_1 = 167772161, p_2 = 469762049, p_3 = 754974721$ の3つの素数を用いた数論変換による畳み込みを計算した後,Garner のアルゴリズムで答えを復元する.
mod を取る前の値が $p_1 \times p_2 \times p_3$ 未満なら正しく復元できるので,mod が 32-bit 整数の場合 $n \leq 2^{22}$ 程度までなら計算できる.
vector<int> convolution(vector<int> a, vector<int> b, int mod)
#pragma once #include <vector> #include "../math/number-theory/garner.hpp" #include "../math/modint.hpp" #include "ntt.hpp" std::vector<int> convolution(const std::vector<int>& a, const std::vector<int>& b, int mod) { using mint1 = Modint<167772161>; using mint2 = Modint<469762049>; using mint3 = Modint<754974721>; std::vector<mint1> a1(a.begin(), a.end()), b1(b.begin(), b.end()); std::vector<mint2> a2(a.begin(), a.end()), b2(b.begin(), b.end()); std::vector<mint3> a3(a.begin(), a.end()), b3(b.begin(), b.end()); auto c1 = convolution(a1, b1); auto c2 = convolution(a2, b2); auto c3 = convolution(a3, b3); std::vector<int> c(c1.size()); std::vector<long long> d(3); const std::vector<long long> mods = {167772161, 469762049, 754974721}; for (int i = 0; i < (int)c1.size(); ++i) { d[0] = c1[i].val(); d[1] = c2[i].val(); d[2] = c3[i].val(); c[i] = garner(d, mods, mod); } return c; }
#line 2 "convolution/arbitrary_mod_convolution.hpp"
#include <vector>
#line 3 "math/number-theory/garner.hpp"
#line 2 "math/number-theory/extgcd.hpp"
#include <algorithm>
#include <utility>

/**
 * @brief Extended Euclidean algorithm.
 *
 * Runs the Euclidean algorithm on (a, b) while tracking the coefficients of
 * a and b, so that the returned pair {x, y} satisfies a*x + b*y = g, where g
 * is the last non-zero remainder (the gcd of a and b, up to sign).
 */
std::pair<long long, long long> extgcd(long long a, long long b) {
    long long s = a, sx = 1, sy = 0, t = b, tx = 0, ty = 1;
    while (t) {
        long long q = s / t;
        // (s, t) <- (t, s - q*t), and the same update for both coefficients.
        std::swap(s -= t * q, t);
        std::swap(sx -= tx * q, tx);
        std::swap(sy -= ty * q, ty);
    }
    return {sx, sy};
}

/**
 * @brief Modular inverse of a modulo mod, normalised into [0, mod).
 *
 * Assumes a and mod are coprime; the result is meaningless otherwise.
 */
long long mod_inv(long long a, long long mod) {
    long long inv = extgcd(a, mod).first;
    return (inv % mod + mod) % mod;
}
#line 5 "math/number-theory/garner.hpp"

/**
 * @brief Garner's algorithm.
 *
 * Given residues b[k] modulo m[k], returns the value x with x = b[k] (mod m[k])
 * for every k, reduced modulo `mod`.  The moduli must be pairwise coprime,
 * because the inverse of the product of the earlier moduli is taken modulo
 * each m[k].
 *
 * @param b   residues, one per modulus
 * @param m   moduli (taken by value; `mod` is appended internally)
 * @param mod modulus of the returned value
 */
long long garner(const std::vector<long long>& b, std::vector<long long> m,
                 long long mod) {
    m.push_back(mod);
    const int n = m.size();
    // For every modulus m[i]:
    //   coeffs[i] = product of the moduli processed so far (mod m[i])
    //   consts[i] = partial solution built so far          (mod m[i])
    std::vector<long long> coeffs(n, 1);
    std::vector<long long> consts(n, 0);
    for (int k = 0; k < n - 1; ++k) {
        // Mixed-radix digit t solves consts[k] + t * coeffs[k] = b[k] (mod m[k]).
        long long t = (b[k] - consts[k]) * mod_inv(coeffs[k], m[k]) % m[k];
        if (t < 0) t += m[k];
        for (int i = k + 1; i < n; ++i) {
            consts[i] = (consts[i] + t * coeffs[i]) % m[i];
            coeffs[i] = coeffs[i] * m[k] % m[i];
        }
    }
    // The last slot corresponds to the appended `mod`.
    return consts.back();
}
#line 3 "math/modint.hpp"
#include <iostream>

/**
 * @brief Mod int
 *
 * Integer modulo the compile-time constant m, stored as a value in [0, m).
 */
template <int m>
class Modint {
    using mint = Modint;
    static_assert(m > 0, "Modulus must be positive");

   public:
    static constexpr int mod() { return m; }

    /// Builds the residue of y; negative values are mapped into [0, m).
    constexpr Modint(long long y = 0)
        : x(y >= 0 ? y % m : (y % m + m) % m) {}

    constexpr int val() const { return x; }

    constexpr mint& operator+=(const mint& r) {
        // x + r.x can exceed INT_MAX when m > 2^30, so decide which branch
        // applies before adding instead of adding first and reducing later.
        if (x >= m - r.x) {
            x -= m - r.x;
        } else {
            x += r.x;
        }
        return *this;
    }
    constexpr mint& operator-=(const mint& r) {
        // Same overflow concern as operator+=: never form x + m - r.x >= m.
        if (x < r.x) {
            x += m - r.x;
        } else {
            x -= r.x;
        }
        return *this;
    }
    constexpr mint& operator*=(const mint& r) {
        x = static_cast<int>(1LL * x * r.x % m);
        return *this;
    }
    constexpr mint& operator/=(const mint& r) { return *this *= r.inv(); }

    constexpr bool operator==(const mint& r) const { return x == r.x; }

    constexpr mint operator+() const { return *this; }
    constexpr mint operator-() const { return mint(-x); }

    constexpr friend mint operator+(const mint& l, const mint& r) {
        return mint(l) += r;
    }
    constexpr friend mint operator-(const mint& l, const mint& r) {
        return mint(l) -= r;
    }
    constexpr friend mint operator*(const mint& l, const mint& r) {
        return mint(l) *= r;
    }
    constexpr friend mint operator/(const mint& l, const mint& r) {
        return mint(l) /= r;
    }

    /// Multiplicative inverse via the extended Euclidean algorithm.
    /// Assumes the value is coprime to m.
    constexpr mint inv() const {
        int a = x, b = m, u = 1, v = 0;
        while (b > 0) {
            int t = a / b;
            std::swap(a -= t * b, b);
            std::swap(u -= t * v, v);
        }
        return mint(u);
    }

    /// Exponentiation by squaring; n is expected to be non-negative.
    constexpr mint pow(long long n) const {
        mint ret(1), mul(x);
        while (n > 0) {
            if (n & 1) ret *= mul;
            mul *= mul;
            n >>= 1;
        }
        return ret;
    }

    friend std::ostream& operator<<(std::ostream& os, const mint& r) {
        return os << r.x;
    }

    friend std::istream& operator>>(std::istream& is, mint& r) {
        long long t;
        is >> t;
        r = mint(t);
        return is;
    }

   private:
    int x;
};
#line 2 "convolution/ntt.hpp"
#include <bit>
#include <stdexcept>
#line 4 "convolution/ntt.hpp"

/**
 * @brief Primitive root of one of the supported NTT primes.
 *
 * Throws std::invalid_argument for any other modulus.  In a constant
 * expression (as used by ntt()/intt()) an unsupported modulus is therefore
 * rejected at compile time instead of falling off the end of the function.
 */
constexpr int get_primitive_root(int mod) {
    if (mod == 167772161) return 3;
    if (mod == 469762049) return 3;
    if (mod == 754974721) return 11;
    if (mod == 998244353) return 3;
    if (mod == 1224736769) return 3;
    throw std::invalid_argument("get_primitive_root: unsupported modulus");
}

/**
 * @brief In-place forward number-theoretic transform.
 *
 * Assumes a.size() is a power of two (convolution() guarantees this).  No
 * reordering pass is performed; the output is left in the order that intt()
 * consumes.
 */
template <typename mint>
void ntt(std::vector<mint>& a) {
    constexpr int mod = mint::mod();
    constexpr mint primitive_root = get_primitive_root(mod);

    const int n = a.size();
    for (int m = n; m > 1; m >>= 1) {
        // omega is a primitive m-th root of unity.
        mint omega = primitive_root.pow((mod - 1) / m);
        for (int s = 0; s < n / m; ++s) {
            mint w = 1;
            for (int i = 0; i < m / 2; ++i) {
                mint l = a[s * m + i];
                mint r = a[s * m + i + m / 2];
                a[s * m + i] = l + r;
                a[s * m + i + m / 2] = (l - r) * w;
                w *= omega;
            }
        }
    }
}

/**
 * @brief In-place inverse transform matching ntt().
 *
 * The result is NOT divided by the length; the caller has to multiply by the
 * inverse of a.size() (convolution() does this).
 */
template <typename mint>
void intt(std::vector<mint>& a) {
    constexpr int mod = mint::mod();
    constexpr mint primitive_root = get_primitive_root(mod);

    const int n = a.size();
    for (int m = 2; m <= n; m <<= 1) {
        // Inverse of the root used by the forward transform at this level.
        mint omega = primitive_root.pow((mod - 1) / m).inv();
        for (int s = 0; s < n / m; ++s) {
            mint w = 1;
            for (int i = 0; i < m / 2; ++i) {
                mint l = a[s * m + i];
                mint r = a[s * m + i + m / 2] * w;
                a[s * m + i] = l + r;
                a[s * m + i + m / 2] = l - r;
                w *= omega;
            }
        }
    }
}

/**
 * @brief Convolution over an NTT-friendly prime field.
 *
 * @return sequence of length a.size() + b.size() - 1, or an empty vector when
 *         either input is empty
 */
template <typename mint>
std::vector<mint> convolution(std::vector<mint> a, std::vector<mint> b) {
    // a.size() + b.size() - 1 is zero or wraps around for an empty operand.
    if (a.empty() || b.empty()) return {};

    const int size = static_cast<int>(a.size() + b.size() - 1);
    // Smallest power of two that is >= size.
#ifdef __cpp_lib_int_pow2
    const int n =
        static_cast<int>(std::bit_ceil(static_cast<unsigned int>(size)));
#else
    // Standard libraries older than C++20 do not provide std::bit_ceil.
    int n = 1;
    while (n < size) n <<= 1;
#endif
    a.resize(n);
    b.resize(n);
    ntt(a);
    ntt(b);
    for (int i = 0; i < n; ++i) a[i] *= b[i];
    intt(a);
    a.resize(size);
    // intt() leaves every entry scaled by n.
    const mint n_inv = mint(n).inv();
    for (mint& v : a) v *= n_inv;
    return a;
}
#line 7 "convolution/arbitrary_mod_convolution.hpp"

/**
 * @brief Arbitrary Mod Convolution
 *
 * Convolves a and b and returns every coefficient modulo `mod`.  The product
 * is computed over three NTT primes and each coefficient is rebuilt from its
 * residues with garner(); this is exact while the unreduced coefficient is
 * smaller than the product of the three primes.
 *
 * @return sequence of length a.size() + b.size() - 1, or an empty vector when
 *         either input is empty
 */
std::vector<int> convolution(const std::vector<int>& a,
                             const std::vector<int>& b, int mod) {
    if (a.empty() || b.empty()) return {};

    using mint1 = Modint<167772161>;
    using mint2 = Modint<469762049>;
    using mint3 = Modint<754974721>;

    // Reduce the inputs into each of the three prime fields.
    std::vector<mint1> a1(a.begin(), a.end()), b1(b.begin(), b.end());
    std::vector<mint2> a2(a.begin(), a.end()), b2(b.begin(), b.end());
    std::vector<mint3> a3(a.begin(), a.end()), b3(b.begin(), b.end());

    const auto c1 = convolution(a1, b1);
    const auto c2 = convolution(a2, b2);
    const auto c3 = convolution(a3, b3);

    const std::vector<long long> mods = {167772161, 469762049, 754974721};
    std::vector<long long> d(3);  // residues of one coefficient, reused
    std::vector<int> c(c1.size());
    for (int i = 0; i < (int)c1.size(); ++i) {
        d[0] = c1[i].val();
        d[1] = c2[i].val();
        d[2] = c3[i].val();
        c[i] = garner(d, mods, mod);
    }
    return c;
}