sotanishy's competitive programming library

sotanishy's code snippets for competitive programming

View the Project on GitHub sotanishy/cp-library-cpp

:warning: Combination (Arbitrary mod)
(math/combination_arbitrary_mod.hpp)

Description

二項係数を任意の mod で計算する. (Computes binomial coefficients modulo an arbitrary, not necessarily prime, modulus.)

Reference

Depends on

Code

#pragma once
#include <vector>

#include "number-theory/prime.hpp"

/*
 * Row of binomial coefficients for an arbitrary modulus.
 *
 * Returns a vector of n + 1 values whose k-th entry is C(n, k) reduced
 * modulo mint::mod(). The row is built with the recurrence
 *     C(n, k) = C(n, k - 1) * (n - k + 1) / k.
 * Every prime p <= n that divides the modulus is stripped out of the
 * numerator and denominator and tracked as an explicit exponent, so the
 * value that is actually divided has no prime factor in common with the
 * modulus.
 *
 * NOTE(review): this relies on mint's operator/ being able to divide by an
 * integer coprime to the modulus (presumably via a modular inverse) and on
 * mint being implicitly constructible from int -- confirm against the mint
 * type in use.
 */
template <typename mint>
std::vector<mint> combination_arbitrary_modint(int n) {
    if (n == 0) return {1};
    const int m = mint::mod();

    // Primes p <= n dividing the modulus, in increasing order. For each
    // i <= n, owner[i] is the index (into shared_primes) of the largest such
    // prime dividing i, or -1 when i has none.
    const auto sieve = prime_table(n);
    std::vector<int> shared_primes;
    std::vector<int> owner(n + 1, -1);
    for (int p = 2; p <= n; ++p) {
        if (!sieve[p] || m % p != 0) continue;
        const int idx = (int)shared_primes.size();
        shared_primes.push_back(p);
        for (int q = p; q <= n; q += p) owner[q] = idx;
    }
    const int cnt = (int)shared_primes.size();

    // power_of[i][e] = shared_primes[i]^e for every exponent that can occur.
    std::vector<std::vector<mint>> power_of(cnt);
    for (int i = 0; i < cnt; ++i) {
        std::vector<mint>& table = power_of[i];
        table.resize(n / (shared_primes[i] - 1) + 1);
        table[0] = 1;
        const int len = (int)table.size();
        for (int e = 1; e < len; ++e) {
            table[e] = table[e - 1] * shared_primes[i];
        }
    }

    // exponent[i] is the current power of shared_primes[i] in C(n, k);
    // coprime_part is C(n, k) with all of those primes removed.
    std::vector<int> exponent(cnt);
    mint coprime_part = 1;

    // Removes every shared prime from x, adding delta to the matching
    // exponent once per removed factor, and returns what is left of x.
    auto strip = [&](int x, int delta) -> int {
        while (owner[x] != -1) {
            const int idx = owner[x];
            exponent[idx] += delta;
            x /= shared_primes[idx];
        }
        return x;
    };

    std::vector<mint> row(n + 1);
    row[0] = 1;
    for (int k = 1; k <= n; ++k) {
        // The denominator is handled first, so an exponent may dip below
        // zero here; it is only read after the numerator has been applied.
        int den = strip(k, -1);
        int num = strip(n - k + 1, +1);
        coprime_part = coprime_part / den * num;

        mint value = coprime_part;
        for (int i = 0; i < cnt; ++i) {
            value *= power_of[i][exponent[i]];
        }
        row[k] = value;
    }

    return row;
}
#line 2 "math/combination_arbitrary_mod.hpp"
#include <vector>

#line 2 "math/number-theory/prime.hpp"
#include <map>
#include <numeric>
#line 5 "math/number-theory/prime.hpp"

/*
 * Primality Test
 *
 * Deterministic trial division: returns true iff n is a prime number.
 * Every n <= 1 (including negative values) is reported as not prime.
 * Runs in O(sqrt(n)) time.
 */
bool is_prime(long long n) {
    if (n <= 1) return false;
    if (n <= 3) return true;
    if (n % 2 == 0 || n % 3 == 0) return false;
    // With multiples of 2 and 3 excluded, only divisors of the form 6k - 1
    // (i) and 6k + 1 (i + 2) remain. The bound is written as i <= n / i
    // instead of i * i <= n so the check cannot overflow long long when n is
    // close to its maximum value.
    for (long long i = 5; i <= n / i; i += 6) {
        if (n % i == 0 || n % (i + 2) == 0) return false;
    }
    return true;
}

/*
 * Prime Table
 *
 * Sieve of Eratosthenes. Returns a table of n + 1 flags where entry i is
 * true iff i is prime, for 0 <= i <= n. A negative n yields an empty table.
 */
std::vector<bool> prime_table(int n) {
    if (n < 0) return {};
    // Compute the size in the vector's own size type so that n + 1 cannot
    // overflow int.
    std::vector<bool> prime(static_cast<std::vector<bool>::size_type>(n) + 1,
                            true);
    prime[0] = false;
    if (n >= 1) prime[1] = false;  // index 1 does not exist when n == 0
    // Loop counters are long long so that stepping past n (or squaring i)
    // never overflows, even for n close to INT_MAX.
    const long long limit = n;
    for (long long j = 4; j <= limit; j += 2) prime[j] = false;
    for (long long i = 3; i * i <= limit; i += 2) {
        if (!prime[i]) continue;
        // Even multiples are already cleared, so step by 2 * i.
        for (long long j = i * i; j <= limit; j += 2 * i) prime[j] = false;
    }
    return prime;
}

/*
 * Table of Minimum Prime Factors
 *
 * Returns a table of n + 1 entries in which entry i holds the smallest
 * prime factor of i for i >= 2. Entries 0 and 1 keep their own index.
 */
std::vector<int> min_factor_table(int n) {
    std::vector<int> spf(n + 1);
    // Start with every number marked as its own smallest factor.
    for (int v = 0; v <= n; ++v) spf[v] = v;

    for (int p = 2; p * p <= n; ++p) {
        // p is prime exactly when no smaller prime has claimed it.
        if (spf[p] == p) {
            for (int q = p * p; q <= n; q += p) {
                // Only the first (smallest) prime to reach q is recorded.
                if (spf[q] == q) spf[q] = p;
            }
        }
    }
    return spf;
}

/*
 * Prime Factorization
 *
 * Factorizes n by trial division and returns a map from each prime factor
 * to its multiplicity. Inputs with no prime factorization (n <= 1,
 * including 0 and negative values) yield an empty map.
 * Runs in O(sqrt(n)) time.
 */
std::map<long long, int> prime_factor(long long n) {
    std::map<long long, int> ret;
    // Without this guard n == 0 would never leave the "divide by 2" loop,
    // and negative inputs would produce meaningless keys.
    if (n <= 1) return ret;

    // Divides every copy of p out of n and records how many there were.
    auto extract = [&ret, &n](long long p) {
        int cnt = 0;
        while (n % p == 0) {
            ++cnt;
            n /= p;
        }
        if (cnt > 0) ret[p] = cnt;
    };

    extract(2);
    // n shrinks as factors are removed, so the bound tightens as we go.
    // i <= n / i is the overflow-safe form of i * i <= n.
    for (long long i = 3; i <= n / i; i += 2) {
        extract(i);
    }
    // Whatever remains above 1 is a single prime larger than sqrt(n).
    if (n != 1) ret[n] = 1;
    return ret;
}
#line 5 "math/combination_arbitrary_mod.hpp"

/*
 * Row of binomial coefficients for an arbitrary modulus.
 *
 * Returns a vector of n + 1 values whose k-th entry is C(n, k) reduced
 * modulo mint::mod(). The row is built with the recurrence
 *     C(n, k) = C(n, k - 1) * (n - k + 1) / k.
 * Every prime p <= n that divides the modulus is stripped out of the
 * numerator and denominator and tracked as an explicit exponent, so the
 * value that is actually divided has no prime factor in common with the
 * modulus.
 *
 * NOTE(review): this relies on mint's operator/ being able to divide by an
 * integer coprime to the modulus (presumably via a modular inverse) and on
 * mint being implicitly constructible from int -- confirm against the mint
 * type in use.
 */
template <typename mint>
std::vector<mint> combination_arbitrary_modint(int n) {
    if (n == 0) return {1};
    const int m = mint::mod();

    // Primes p <= n dividing the modulus, in increasing order. For each
    // i <= n, owner[i] is the index (into shared_primes) of the largest such
    // prime dividing i, or -1 when i has none.
    const auto sieve = prime_table(n);
    std::vector<int> shared_primes;
    std::vector<int> owner(n + 1, -1);
    for (int p = 2; p <= n; ++p) {
        if (!sieve[p] || m % p != 0) continue;
        const int idx = (int)shared_primes.size();
        shared_primes.push_back(p);
        for (int q = p; q <= n; q += p) owner[q] = idx;
    }
    const int cnt = (int)shared_primes.size();

    // power_of[i][e] = shared_primes[i]^e for every exponent that can occur.
    std::vector<std::vector<mint>> power_of(cnt);
    for (int i = 0; i < cnt; ++i) {
        std::vector<mint>& table = power_of[i];
        table.resize(n / (shared_primes[i] - 1) + 1);
        table[0] = 1;
        const int len = (int)table.size();
        for (int e = 1; e < len; ++e) {
            table[e] = table[e - 1] * shared_primes[i];
        }
    }

    // exponent[i] is the current power of shared_primes[i] in C(n, k);
    // coprime_part is C(n, k) with all of those primes removed.
    std::vector<int> exponent(cnt);
    mint coprime_part = 1;

    // Removes every shared prime from x, adding delta to the matching
    // exponent once per removed factor, and returns what is left of x.
    auto strip = [&](int x, int delta) -> int {
        while (owner[x] != -1) {
            const int idx = owner[x];
            exponent[idx] += delta;
            x /= shared_primes[idx];
        }
        return x;
    };

    std::vector<mint> row(n + 1);
    row[0] = 1;
    for (int k = 1; k <= n; ++k) {
        // The denominator is handled first, so an exponent may dip below
        // zero here; it is only read after the numerator has been applied.
        int den = strip(k, -1);
        int num = strip(n - k + 1, +1);
        coprime_part = coprime_part / den * num;

        mint value = coprime_part;
        for (int i = 0; i < cnt; ++i) {
            value *= power_of[i][exponent[i]];
        }
        row[k] = value;
    }

    return row;
}
Back to top page