sotanishy's competitive programming library

sotanishy's code snippets for competitive programming

View the Project on GitHub sotanishy/cp-library-cpp

:heavy_check_mark: Number Theoretic Transform
(convolution/ntt.hpp)

Description

数論変換 (NTT) は,有限体 $\mathbb{F}_p$ 上の高速 Fourier 変換である.

Operations

Note

数列の長さ $n$ が $p - 1$ を割り切るとき,1の原始 $n$ 乗根が必ず存在する.長さ $n$ を2冪にとれるようにするため,素数 $p$ は $a \times 2^k + 1$ の形に表したとき $k$ が十分大きいものを用いる.

以下の素数 $p$ と原始根の組み合わせがよく用いられる: $(167772161, 3)$, $(469762049, 3)$, $(754974721, 11)$, $(998244353, 3)$, $(1224736769, 3)$.

$k$ の値はそれぞれ,25, 26, 24, 23, 24である.

Required by

Verified with

Code

#pragma once
#include <bit>
#include <vector>

/**
 * @brief Returns a primitive root modulo the prime `mod`.
 *
 * The NTT-friendly primes that are used most often are answered from a fixed
 * table. For any other prime the smallest primitive root is searched for, so
 * the function never falls off its end without returning a value.
 *
 * @param mod A prime modulus (the result is meaningless for non-primes).
 * @return A generator of the multiplicative group modulo `mod`.
 */
constexpr int get_primitive_root(int mod) {
    // Fast path: well-known NTT primes.
    if (mod == 167772161) return 3;
    if (mod == 469762049) return 3;
    if (mod == 754974721) return 11;
    if (mod == 998244353) return 3;
    if (mod == 1224736769) return 3;
    if (mod == 2) return 1;

    // Collect the distinct prime factors of mod - 1. A value below 2^31 has
    // at most 9 distinct prime factors, so 16 slots are always enough.
    int factors[16] = {};
    int count = 0;
    int rest = mod - 1;
    if (rest > 0 && rest % 2 == 0) {
        factors[count++] = 2;
        while (rest % 2 == 0) rest /= 2;
    }
    for (long long p = 3; p * p <= rest; p += 2) {
        if (rest % p != 0) continue;
        factors[count++] = static_cast<int>(p);
        while (rest % p == 0) rest /= static_cast<int>(p);
    }
    if (rest > 1) factors[count++] = rest;

    // g is a primitive root iff g^((mod-1)/q) != 1 for every prime q | mod-1.
    for (int g = 2;; ++g) {
        bool is_root = true;
        for (int i = 0; i < count && is_root; ++i) {
            long long power = 1;
            long long base = g % mod;
            for (int e = (mod - 1) / factors[i]; e > 0; e >>= 1) {
                if (e & 1) power = power * base % mod;
                base = base * base % mod;
            }
            if (power == 1) is_root = false;
        }
        if (is_root) return g;
    }
}

/**
 * @brief In-place forward number theoretic transform.
 *
 * Runs the butterfly stages from the widest block (the whole array) down to
 * blocks of width 2. No reordering pass is performed afterwards; `intt`
 * walks the stages in the opposite order and accepts this layout directly.
 *
 * @tparam mint Modular integer type providing `mod()`, `pow()` and arithmetic.
 * @param a Sequence to transform; its length is expected to be a power of two
 *          that divides `mint::mod() - 1`.
 */
template <typename mint>
void ntt(std::vector<mint>& a) {
    constexpr int mod = mint::mod();
    constexpr mint primitive_root = get_primitive_root(mod);

    const int len = a.size();
    int width = len;
    while (width > 1) {
        const int half = width / 2;
        // Root of unity whose order equals the current block width.
        mint step = primitive_root.pow((mod - 1) / width);
        for (int base = 0; base + width <= len; base += width) {
            mint twiddle = 1;
            for (int lo = base, hi = base + half; lo < base + half; ++lo, ++hi) {
                mint x = a[lo];
                mint y = a[hi];
                a[lo] = x + y;
                a[hi] = (x - y) * twiddle;
                twiddle *= step;
            }
        }
        width >>= 1;
    }
}

/**
 * @brief In-place inverse number theoretic transform (without the 1/n factor).
 *
 * Runs the butterfly stages from blocks of width 2 up to the whole array,
 * using the inverse of each stage's root of unity. The result is NOT divided
 * by the length; the caller has to apply that scaling itself.
 *
 * @tparam mint Modular integer type providing `mod()`, `pow()`, `inv()` and
 *              arithmetic.
 * @param a Sequence to transform; its length is expected to be a power of two
 *          that divides `mint::mod() - 1`.
 */
template <typename mint>
void intt(std::vector<mint>& a) {
    constexpr int mod = mint::mod();
    constexpr mint primitive_root = get_primitive_root(mod);

    const int len = a.size();
    for (int width = 2; width <= len; width <<= 1) {
        const int half = width / 2;
        // Inverse of the root of unity whose order equals the block width.
        mint step = primitive_root.pow((mod - 1) / width).inv();
        for (int base = 0; base + width <= len; base += width) {
            mint twiddle = 1;
            for (int lo = base, hi = base + half; lo < base + half; ++lo, ++hi) {
                mint x = a[lo];
                mint y = a[hi] * twiddle;
                a[lo] = x + y;
                a[hi] = x - y;
                twiddle *= step;
            }
        }
    }
}

/**
 * @brief Multiplies two sequences (polynomial product) using the NTT.
 *
 * @tparam mint Modular integer type usable with `ntt` / `intt`.
 * @param a First sequence (taken by value; used as the work buffer).
 * @param b Second sequence (taken by value; used as the work buffer).
 * @return The convolution of `a` and `b`, of length
 *         `a.size() + b.size() - 1`, or an empty vector if either input is
 *         empty.
 */
template <typename mint>
std::vector<mint> convolution(std::vector<mint> a, std::vector<mint> b) {
    // An empty operand yields an empty product. Without this guard `size`
    // would become -1 for two empty inputs and std::bit_ceil would be called
    // with a value whose result is not representable (undefined behaviour).
    if (a.empty() || b.empty()) return {};

    const int size = a.size() + b.size() - 1;
    // Pad to a power of two so the radix-2 transforms apply.
    const int n = std::bit_ceil(static_cast<unsigned int>(size));
    a.resize(n);
    b.resize(n);
    ntt(a);
    ntt(b);
    for (int i = 0; i < n; ++i) a[i] *= b[i];
    intt(a);
    // Drop the padding first so only the kept entries get rescaled.
    a.resize(size);
    // intt leaves out the 1/n factor; apply it here.
    mint n_inv = mint(n).inv();
    for (int i = 0; i < size; ++i) a[i] *= n_inv;
    return a;
}
#line 2 "convolution/ntt.hpp"
#include <bit>
#include <vector>

/**
 * @brief Returns a primitive root modulo the prime `mod`.
 *
 * The NTT-friendly primes that are used most often are answered from a fixed
 * table. For any other prime the smallest primitive root is searched for, so
 * the function never falls off its end without returning a value.
 *
 * @param mod A prime modulus (the result is meaningless for non-primes).
 * @return A generator of the multiplicative group modulo `mod`.
 */
constexpr int get_primitive_root(int mod) {
    // Fast path: well-known NTT primes.
    if (mod == 167772161) return 3;
    if (mod == 469762049) return 3;
    if (mod == 754974721) return 11;
    if (mod == 998244353) return 3;
    if (mod == 1224736769) return 3;
    if (mod == 2) return 1;

    // Collect the distinct prime factors of mod - 1. A value below 2^31 has
    // at most 9 distinct prime factors, so 16 slots are always enough.
    int factors[16] = {};
    int count = 0;
    int rest = mod - 1;
    if (rest > 0 && rest % 2 == 0) {
        factors[count++] = 2;
        while (rest % 2 == 0) rest /= 2;
    }
    for (long long p = 3; p * p <= rest; p += 2) {
        if (rest % p != 0) continue;
        factors[count++] = static_cast<int>(p);
        while (rest % p == 0) rest /= static_cast<int>(p);
    }
    if (rest > 1) factors[count++] = rest;

    // g is a primitive root iff g^((mod-1)/q) != 1 for every prime q | mod-1.
    for (int g = 2;; ++g) {
        bool is_root = true;
        for (int i = 0; i < count && is_root; ++i) {
            long long power = 1;
            long long base = g % mod;
            for (int e = (mod - 1) / factors[i]; e > 0; e >>= 1) {
                if (e & 1) power = power * base % mod;
                base = base * base % mod;
            }
            if (power == 1) is_root = false;
        }
        if (is_root) return g;
    }
}

/**
 * @brief In-place forward number theoretic transform.
 *
 * Runs the butterfly stages from the widest block (the whole array) down to
 * blocks of width 2. No reordering pass is performed afterwards; `intt`
 * walks the stages in the opposite order and accepts this layout directly.
 *
 * @tparam mint Modular integer type providing `mod()`, `pow()` and arithmetic.
 * @param a Sequence to transform; its length is expected to be a power of two
 *          that divides `mint::mod() - 1`.
 */
template <typename mint>
void ntt(std::vector<mint>& a) {
    constexpr int mod = mint::mod();
    constexpr mint primitive_root = get_primitive_root(mod);

    const int len = a.size();
    int width = len;
    while (width > 1) {
        const int half = width / 2;
        // Root of unity whose order equals the current block width.
        mint step = primitive_root.pow((mod - 1) / width);
        for (int base = 0; base + width <= len; base += width) {
            mint twiddle = 1;
            for (int lo = base, hi = base + half; lo < base + half; ++lo, ++hi) {
                mint x = a[lo];
                mint y = a[hi];
                a[lo] = x + y;
                a[hi] = (x - y) * twiddle;
                twiddle *= step;
            }
        }
        width >>= 1;
    }
}

/**
 * @brief In-place inverse number theoretic transform (without the 1/n factor).
 *
 * Runs the butterfly stages from blocks of width 2 up to the whole array,
 * using the inverse of each stage's root of unity. The result is NOT divided
 * by the length; the caller has to apply that scaling itself.
 *
 * @tparam mint Modular integer type providing `mod()`, `pow()`, `inv()` and
 *              arithmetic.
 * @param a Sequence to transform; its length is expected to be a power of two
 *          that divides `mint::mod() - 1`.
 */
template <typename mint>
void intt(std::vector<mint>& a) {
    constexpr int mod = mint::mod();
    constexpr mint primitive_root = get_primitive_root(mod);

    const int len = a.size();
    for (int width = 2; width <= len; width <<= 1) {
        const int half = width / 2;
        // Inverse of the root of unity whose order equals the block width.
        mint step = primitive_root.pow((mod - 1) / width).inv();
        for (int base = 0; base + width <= len; base += width) {
            mint twiddle = 1;
            for (int lo = base, hi = base + half; lo < base + half; ++lo, ++hi) {
                mint x = a[lo];
                mint y = a[hi] * twiddle;
                a[lo] = x + y;
                a[hi] = x - y;
                twiddle *= step;
            }
        }
    }
}

/**
 * @brief Multiplies two sequences (polynomial product) using the NTT.
 *
 * @tparam mint Modular integer type usable with `ntt` / `intt`.
 * @param a First sequence (taken by value; used as the work buffer).
 * @param b Second sequence (taken by value; used as the work buffer).
 * @return The convolution of `a` and `b`, of length
 *         `a.size() + b.size() - 1`, or an empty vector if either input is
 *         empty.
 */
template <typename mint>
std::vector<mint> convolution(std::vector<mint> a, std::vector<mint> b) {
    // An empty operand yields an empty product. Without this guard `size`
    // would become -1 for two empty inputs and std::bit_ceil would be called
    // with a value whose result is not representable (undefined behaviour).
    if (a.empty() || b.empty()) return {};

    const int size = a.size() + b.size() - 1;
    // Pad to a power of two so the radix-2 transforms apply.
    const int n = std::bit_ceil(static_cast<unsigned int>(size));
    a.resize(n);
    b.resize(n);
    ntt(a);
    ntt(b);
    for (int i = 0; i < n; ++i) a[i] *= b[i];
    intt(a);
    // Drop the padding first so only the kept entries get rescaled.
    a.resize(size);
    // intt leaves out the 1/n factor; apply it here.
    mint n_inv = mint(n).inv();
    for (int i = 0; i < size; ++i) a[i] *= n_inv;
    return a;
}
Back to top page