sotanishy's competitive programming library

sotanishy's code snippets for competitive programming

View the Project on GitHub sotanishy/cp-library-cpp

✔ Arbitrary Mod Convolution
(convolution/arbitrary_mod_convolution.hpp)

Description

任意の mod で畳み込みを計算する.

$p_1 = 167772161, p_2 = 469762049, p_3 = 754974721$ の3つの素数を用いた数論変換による畳み込みを計算した後,Garner のアルゴリズムで答えを復元する.

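入力の各値が $[0, \mathrm{mod})$ に収まっていれば,mod を取る前の各係数は $n \cdot \mathrm{mod}^2$ 未満である.これが $p_1 \times p_2 \times p_3 \approx 5.9 \times 10^{25}$ 未満なら正しく復元できるので,mod が 32-bit 整数の場合 $n \leq 2^{22}$ 程度までなら計算できる.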

Operations

Depends on

Verified with

Code

#pragma once
#include <vector>

#include "../math/number-theory/garner.hpp"
#include "../math/modint.hpp"
#include "ntt.hpp"

/**
 * Convolution modulo an arbitrary positive `mod`.
 *
 * Convolves the inputs modulo the three NTT primes p1 = 167772161,
 * p2 = 469762049 and p3 = 754974721, then recombines every coefficient with
 * Garner's algorithm. Inputs are first reduced into [0, mod), so each exact
 * coefficient is less than min(|a|, |b|) * mod^2. Recombination is exact while
 * that stays below p1 * p2 * p3 (~5.9e25), i.e. lengths up to about 2^22 for a
 * 31-bit mod.
 *
 * @param a   first sequence (any int values; negatives are allowed)
 * @param b   second sequence (any int values; negatives are allowed)
 * @param mod positive modulus of the result
 * @return c with c[k] = (sum_{i+j=k} a[i] * b[j]) mod `mod`, each in [0, mod).
 *         Returns an empty vector if either input is empty.
 */
std::vector<int> convolution(const std::vector<int>& a,
                             const std::vector<int>& b, int mod) {
    using mint1 = Modint<167772161>;
    using mint2 = Modint<469762049>;
    using mint3 = Modint<754974721>;

    // The NTT convolution cannot size an empty result; the answer is empty.
    if (a.empty() || b.empty()) return {};

    // Garner recovers the unique non-negative value below p1*p2*p3. A negative
    // (or oversized) input would make the true coefficient fall outside that
    // range and be recovered as the wrong integer, so normalize to [0, mod).
    const auto reduce = [mod](const std::vector<int>& v) {
        std::vector<int> r;
        r.reserve(v.size());
        for (const int x : v) {
            r.push_back(static_cast<int>(
                (static_cast<long long>(x) % mod + mod) % mod));
        }
        return r;
    };
    const std::vector<int> ra = reduce(a);
    const std::vector<int> rb = reduce(b);

    std::vector<mint1> a1(ra.begin(), ra.end()), b1(rb.begin(), rb.end());
    std::vector<mint2> a2(ra.begin(), ra.end()), b2(rb.begin(), rb.end());
    std::vector<mint3> a3(ra.begin(), ra.end()), b3(rb.begin(), rb.end());

    const auto c1 = convolution(a1, b1);
    const auto c2 = convolution(a2, b2);
    const auto c3 = convolution(a3, b3);

    const int len = static_cast<int>(c1.size());
    std::vector<int> c(len);
    // Residues of one coefficient modulo p1, p2, p3 (reused across iterations).
    std::vector<long long> d(3);
    const std::vector<long long> mods = {167772161, 469762049, 754974721};
    for (int i = 0; i < len; ++i) {
        d[0] = c1[i].val();
        d[1] = c2[i].val();
        d[2] = c3[i].val();
        c[i] = static_cast<int>(garner(d, mods, mod));
    }
    return c;
}
#line 2 "convolution/arbitrary_mod_convolution.hpp"
#include <vector>

#line 3 "math/number-theory/garner.hpp"

#line 2 "math/number-theory/extgcd.hpp"
#include <algorithm>
#include <utility>

/**
 * Extended Euclidean algorithm.
 *
 * Returns a pair (x, y) satisfying a * x + b * y == g, where g is the last
 * non-zero remainder of the Euclidean iteration on (a, b). For non-negative
 * inputs, g is gcd(a, b).
 */
std::pair<long long, long long> extgcd(long long a, long long b) {
    // Invariant for both rows: a * x + b * y == r.
    long long r_prev = a, r_cur = b;
    long long x_prev = 1, x_cur = 0;
    long long y_prev = 0, y_cur = 1;
    while (r_cur != 0) {
        const long long q = r_prev / r_cur;
        const long long r_next = r_prev - q * r_cur;
        const long long x_next = x_prev - q * x_cur;
        const long long y_next = y_prev - q * y_cur;
        r_prev = r_cur;
        r_cur = r_next;
        x_prev = x_cur;
        x_cur = x_next;
        y_prev = y_cur;
        y_cur = y_next;
    }
    return std::make_pair(x_prev, y_prev);
}

/**
 * Modular inverse of `a` modulo `mod`, taken from the Bezout coefficient
 * returned by extgcd(a, mod).
 *
 * The result is only an inverse when gcd(a, mod) == 1; otherwise it is
 * whatever the normalized coefficient happens to be.
 */
long long mod_inv(long long a, long long mod) {
    const long long bezout_x = extgcd(a, mod).first;  // a * x + mod * y == g
    const long long partial = bezout_x % mod;
    return (partial + mod) % mod;
}
#line 5 "math/number-theory/garner.hpp"

/**
 * Garner's algorithm: reconstructs x from its residues x ≡ b[k] (mod m[k])
 * and returns x modulo `mod`.
 *
 * The moduli m[0..k-1] are assumed pairwise coprime (each m[k] must admit an
 * inverse of the running product). `mod` itself needs no such property; it is
 * appended as the final "evaluation" modulus and never inverted.
 *
 * @param b   residues, at least m.size() entries
 * @param m   moduli (taken by value because `mod` is appended to it)
 * @param mod modulus of the returned value
 */
long long garner(const std::vector<long long>& b, std::vector<long long> m,
                 long long mod) {
    m.push_back(mod);
    const int cnt = m.size();
    // prod[i]: product of the already-processed moduli, taken modulo m[i].
    // acc[i]:  value reconstructed so far, taken modulo m[i].
    std::vector<long long> prod(cnt, 1), acc(cnt, 0);
    for (int k = 0; k + 1 < cnt; ++k) {
        const long long mk = m[k];
        // Next mixed-radix digit, normalized into [0, mk).
        long long digit = (b[k] - acc[k]) * mod_inv(prod[k], mk) % mk;
        if (digit < 0) digit += mk;
        for (int i = k + 1; i < cnt; ++i) {
            acc[i] = (acc[i] + digit * prod[i]) % m[i];
            prod[i] = prod[i] * mk % m[i];
        }
    }
    return acc[cnt - 1];
}
#line 3 "math/modint.hpp"
#include <iostream>

/**
 * @brief Mod int
 *
 * Integer modulo the compile-time constant `m`, stored canonically in [0, m).
 * Every modulus in (0, INT_MAX] is supported. Addition and subtraction are
 * arranged so that no intermediate leaves the int range; a naive `x += r.x`
 * would overflow for m > 2^30 (e.g. 1224736769). Products are formed in
 * long long.
 *
 * @tparam m the modulus. inv() and operator/ additionally require the divisor
 *           to be coprime to m (e.g. m prime).
 */
template <int m>
class Modint {
    using mint = Modint;
    static_assert(m > 0, "Modulus must be positive");

   public:
    /// Returns the modulus.
    static constexpr int mod() { return m; }

    /// Constructs the residue of `y` modulo m; negative inputs are normalized.
    constexpr Modint(long long y = 0)
        : x(static_cast<int>(y >= 0 ? y % m : (y % m + m) % m)) {}

    /// Returns the canonical representative in [0, m).
    constexpr int val() const { return x; }

    constexpr mint& operator+=(const mint& r) {
        // Subtract first: x - (m - r.x) lies in [-m, m - 2], so it never
        // overflows, whereas x + r.x can exceed INT_MAX when m > 2^30.
        x -= m - r.x;
        if (x < 0) x += m;
        return *this;
    }
    constexpr mint& operator-=(const mint& r) {
        // x - r.x lies in (-m, m): no overflow possible.
        x -= r.x;
        if (x < 0) x += m;
        return *this;
    }
    constexpr mint& operator*=(const mint& r) {
        x = static_cast<int>(1LL * x * r.x % m);
        return *this;
    }
    constexpr mint& operator/=(const mint& r) { return *this *= r.inv(); }

    constexpr bool operator==(const mint& r) const { return x == r.x; }

    constexpr mint operator+() const { return *this; }
    constexpr mint operator-() const { return mint(-x); }

    constexpr friend mint operator+(const mint& l, const mint& r) {
        return mint(l) += r;
    }
    constexpr friend mint operator-(const mint& l, const mint& r) {
        return mint(l) -= r;
    }
    constexpr friend mint operator*(const mint& l, const mint& r) {
        return mint(l) *= r;
    }
    constexpr friend mint operator/(const mint& l, const mint& r) {
        return mint(l) /= r;
    }

    /// Multiplicative inverse via the extended Euclidean algorithm.
    /// Meaningful only when gcd(x, m) == 1; returns 0 when x == 0.
    constexpr mint inv() const {
        long long a = x, b = m, u = 1, v = 0;
        while (b > 0) {
            const long long t = a / b;
            a -= t * b;
            u -= t * v;
            // (a, b) <- (b, a) and (u, v) <- (v, u). Done by hand because
            // std::swap is not constexpr before C++20 and needs <utility>.
            const long long next_a = b;
            b = a;
            a = next_a;
            const long long next_u = v;
            v = u;
            u = next_u;
        }
        return mint(u);
    }

    /// Returns x^n by binary exponentiation (1 for n <= 0).
    constexpr mint pow(long long n) const {
        mint ret(1), mul(x);
        while (n > 0) {
            if (n & 1) ret *= mul;
            mul *= mul;
            n >>= 1;
        }
        return ret;
    }

    friend std::ostream& operator<<(std::ostream& os, const mint& r) {
        return os << r.x;
    }

    friend std::istream& operator>>(std::istream& is, mint& r) {
        long long t;
        is >> t;
        r = mint(t);
        return is;
    }

   private:
    int x;
};
#line 2 "convolution/ntt.hpp"
#include <bit>
#line 4 "convolution/ntt.hpp"

/**
 * Returns the smallest primitive root modulo the prime `mod`.
 *
 * The NTT-friendly primes used by this library are answered from a table.
 * Any other prime is handled by the standard criterion: g is a primitive root
 * iff g^((mod-1)/q) != 1 (mod mod) for every prime factor q of mod - 1.
 * Usable in constant expressions (C++14 relaxed constexpr).
 *
 * Precondition: `mod` is prime. For other inputs the result is unspecified
 * (-1 if the search finds no candidate below mod).
 */
constexpr int get_primitive_root(int mod) {
    if (mod == 2) return 1;
    if (mod == 167772161) return 3;
    if (mod == 469762049) return 3;
    if (mod == 754974721) return 11;
    if (mod == 998244353) return 3;
    if (mod == 1224736769) return 3;

    // Distinct prime factors of mod - 1 (a 31-bit value has at most 9).
    int factors[32] = {};
    int count = 0;
    int rest = mod - 1;
    for (long long p = 2; p * p <= rest; ++p) {
        if (rest % p == 0) {
            factors[count++] = static_cast<int>(p);
            while (rest % p == 0) rest /= static_cast<int>(p);
        }
    }
    if (rest > 1) factors[count++] = rest;

    for (int g = 2; g < mod; ++g) {
        bool is_root = true;
        for (int i = 0; i < count && is_root; ++i) {
            // g^((mod-1)/factors[i]) mod `mod` by binary exponentiation;
            // operands stay below 2^31, so products fit in long long.
            long long base = g, result = 1;
            long long e = (mod - 1) / factors[i];
            while (e > 0) {
                if (e & 1) result = result * base % mod;
                base = base * base % mod;
                e >>= 1;
            }
            if (result == 1) is_root = false;
        }
        if (is_root) return g;
    }
    return -1;
}

/**
 * In-place forward number-theoretic transform over `mint`
 * (decimation in frequency, Gentleman-Sande butterflies).
 *
 * Stages run from block length n down to 2 and no reordering pass is done.
 * The output is therefore left in bit-reversed index order, which is exactly
 * the order intt() consumes. a.size() must be a power of two dividing
 * mint::mod() - 1.
 */
template <typename mint>
void ntt(std::vector<mint>& a) {
    constexpr int mod = mint::mod();
    constexpr mint g = get_primitive_root(mod);

    const int n = a.size();
    for (int len = n; len >= 2; len /= 2) {
        const int half = len / 2;
        // Root of unity of order len.
        const mint root = g.pow((mod - 1) / len);
        for (int base = 0; base + len <= n; base += len) {
            mint twiddle = 1;
            for (int j = 0; j < half; ++j) {
                const mint lo = a[base + j];
                const mint hi = a[base + j + half];
                a[base + j] = lo + hi;
                a[base + j + half] = (lo - hi) * twiddle;
                twiddle *= root;
            }
        }
    }
}

/**
 * In-place inverse transform matching ntt()
 * (decimation in time, Cooley-Tukey butterflies with inverse roots).
 *
 * Takes the bit-reversed order produced by ntt() and returns natural order.
 * The result is NOT divided by a.size(); the caller must multiply by n^{-1}.
 */
template <typename mint>
void intt(std::vector<mint>& a) {
    constexpr int mod = mint::mod();
    constexpr mint g = get_primitive_root(mod);

    const int n = a.size();
    for (int len = 2; len <= n; len *= 2) {
        const int half = len / 2;
        // Inverse of the order-len root of unity used by ntt().
        const mint root = g.pow((mod - 1) / len).inv();
        for (int base = 0; base + len <= n; base += len) {
            mint twiddle = 1;
            for (int j = 0; j < half; ++j) {
                const mint lo = a[base + j];
                const mint hi = a[base + j + half] * twiddle;
                a[base + j] = lo + hi;
                a[base + j + half] = lo - hi;
                twiddle *= root;
            }
        }
    }
}

/**
 * Convolution of two sequences over an NTT-friendly prime field.
 *
 * Returns c with c[k] = sum_{i+j=k} a[i] * b[j], of length
 * a.size() + b.size() - 1, or an empty vector if either input is empty.
 * The padded transform length (next power of two >= the result length) must
 * divide mint::mod() - 1.
 */
template <typename mint>
std::vector<mint> convolution(std::vector<mint> a, std::vector<mint> b) {
    // Without this guard, a.size() + b.size() - 1 wraps around when both are
    // empty (UB in bit_ceil), and yields a spurious vector of zeros when only
    // one is empty.
    if (a.empty() || b.empty()) return {};

    const int size = static_cast<int>(a.size() + b.size() - 1);
    const int n =
        static_cast<int>(std::bit_ceil(static_cast<unsigned int>(size)));
    a.resize(n);
    b.resize(n);
    ntt(a);
    ntt(b);
    for (int i = 0; i < n; ++i) a[i] *= b[i];
    intt(a);
    a.resize(size);
    // intt() leaves every entry scaled by n; undo that here.
    const mint n_inv = mint(n).inv();
    for (auto& v : a) v *= n_inv;
    return a;
}
#line 7 "convolution/arbitrary_mod_convolution.hpp"

/**
 * Convolution modulo an arbitrary positive `mod`.
 *
 * Convolves the inputs modulo the three NTT primes p1 = 167772161,
 * p2 = 469762049 and p3 = 754974721, then recombines every coefficient with
 * Garner's algorithm. Inputs are first reduced into [0, mod), so each exact
 * coefficient is less than min(|a|, |b|) * mod^2. Recombination is exact while
 * that stays below p1 * p2 * p3 (~5.9e25), i.e. lengths up to about 2^22 for a
 * 31-bit mod.
 *
 * @param a   first sequence (any int values; negatives are allowed)
 * @param b   second sequence (any int values; negatives are allowed)
 * @param mod positive modulus of the result
 * @return c with c[k] = (sum_{i+j=k} a[i] * b[j]) mod `mod`, each in [0, mod).
 *         Returns an empty vector if either input is empty.
 */
std::vector<int> convolution(const std::vector<int>& a,
                             const std::vector<int>& b, int mod) {
    using mint1 = Modint<167772161>;
    using mint2 = Modint<469762049>;
    using mint3 = Modint<754974721>;

    // The NTT convolution cannot size an empty result; the answer is empty.
    if (a.empty() || b.empty()) return {};

    // Garner recovers the unique non-negative value below p1*p2*p3. A negative
    // (or oversized) input would make the true coefficient fall outside that
    // range and be recovered as the wrong integer, so normalize to [0, mod).
    const auto reduce = [mod](const std::vector<int>& v) {
        std::vector<int> r;
        r.reserve(v.size());
        for (const int x : v) {
            r.push_back(static_cast<int>(
                (static_cast<long long>(x) % mod + mod) % mod));
        }
        return r;
    };
    const std::vector<int> ra = reduce(a);
    const std::vector<int> rb = reduce(b);

    std::vector<mint1> a1(ra.begin(), ra.end()), b1(rb.begin(), rb.end());
    std::vector<mint2> a2(ra.begin(), ra.end()), b2(rb.begin(), rb.end());
    std::vector<mint3> a3(ra.begin(), ra.end()), b3(rb.begin(), rb.end());

    const auto c1 = convolution(a1, b1);
    const auto c2 = convolution(a2, b2);
    const auto c3 = convolution(a3, b3);

    const int len = static_cast<int>(c1.size());
    std::vector<int> c(len);
    // Residues of one coefficient modulo p1, p2, p3 (reused across iterations).
    std::vector<long long> d(3);
    const std::vector<long long> mods = {167772161, 469762049, 754974721};
    for (int i = 0; i < len; ++i) {
        d[0] = c1[i].val();
        d[1] = c2[i].val();
        d[2] = c3[i].val();
        c[i] = static_cast<int>(garner(d, mods, mod));
    }
    return c;
}
Back to top page