sotanishy's competitive programming library

sotanishy's code snippets for competitive programming

View the Project on GitHub sotanishy/cp-library-cpp

:warning: Relaxed Convolution
(convolution/relaxed_convolution.hpp)

Description

Relaxed convolution は,畳み込みをオンラインで処理するアルゴリズムである.すなわち,数列 $F, G$ の項が前から順番に与えられたとき,それらの畳み込み $H=F*G$ の各項を前から順番に返す.具体的には,$F_i, G_i$ が与えられた時点で $H_i = \sum_{j=0}^{i} F_j G_{i-j}$ を返す.

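Operations

- `mint get(mint a, mint b)`
    - $F, G$ の次の項として $F_i = a$, $G_i = b$ を追加し ($i$ はそれまでに追加された項数),$H_i = \sum_{j=0}^{i} F_j G_{i-j}$ を返す.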

Reference

Depends on

- convolution/ntt.hpp

Code

#pragma once
#include <bit>
#include <vector>

#include "ntt.hpp"

template <typename mint>
class RelaxedConvolution {
   public:
    /// Online convolution step.
    /// Appends F_k = a and G_k = b, where k is the number of terms given so far.
    /// Returns H_k = sum_{i=0}^{k} F_i * G_{k-i}.
    mint get(mint a, mint b) {
        f.push_back(a);
        g.push_back(b);
        ++n;
        // Largest power of two dividing n + 1; block widths up to this value are
        // the ones whose contributions are added at this step.
        const int limit = 1 << std::countr_zero((unsigned int)n + 1);
        int offset = 0;
        for (int width = 1; width <= limit; width <<= 1) {
            // Newest `width` terms of f against g[offset, offset + width).
            add_product(n - width, n, offset, offset + width);
            // If n + 1 == limit, then at width == limit / 2 we have
            // n - width == offset. The block just added is square, so the mirrored
            // product below would count it a second time.
            if (n + 1 == limit && width == limit / 2) break;
            // Mirrored block: f[offset, offset + width) against the newest g terms.
            add_product(offset, offset + width, n - width, n);
            offset += width;
        }
        return h[n - 1];
    }

   private:
    int n = 0;
    std::vector<mint> f, g, h;

    /// Adds the product f[lf, rf) * g[lg, rg) into h, shifted by lf + lg,
    /// growing h first if needed.
    void add_product(int lf, int rf, int lg, int rg) {
        const int needed = rf + rg - 1;
        if ((int)h.size() < needed) h.resize(needed);
        std::vector<mint> lhs(f.begin() + lf, f.begin() + rf);
        std::vector<mint> rhs(g.begin() + lg, g.begin() + rg);
        auto prod = convolution(lhs, rhs);
        const int shift = lf + lg;
        for (int k = 0; k < (int)prod.size(); ++k) h[shift + k] += prod[k];
    }
};
#line 2 "convolution/relaxed_convolution.hpp"
#include <bit>
#include <vector>

#line 4 "convolution/ntt.hpp"

/// Returns a primitive root modulo `mod`, which must be prime.
///
/// The common NTT-friendly primes are answered from a table. Any other prime is
/// handled generically, returning the smallest primitive root. Previously such
/// a prime made control fall off the end of the function, which is undefined
/// behaviour.
constexpr int get_primitive_root(int mod) {
    if (mod == 2) return 1;
    if (mod == 167772161) return 3;
    if (mod == 469762049) return 3;
    if (mod == 754974721) return 11;
    if (mod == 998244353) return 3;
    if (mod == 1224736769) return 3;

    // Collect the distinct prime factors of mod - 1. An int has at most 9.
    int prime_divs[32] = {};
    int cnt = 0;
    int x = mod - 1;
    for (int p = 2; 1LL * p * p <= x; ++p) {
        if (x % p == 0) {
            prime_divs[cnt++] = p;
            while (x % p == 0) x /= p;
        }
    }
    if (x > 1) prime_divs[cnt++] = x;

    // g is a primitive root iff g^((mod-1)/p) != 1 (mod mod) for every prime p | mod-1.
    for (int g = 2;; ++g) {
        bool is_root = true;
        for (int i = 0; i < cnt && is_root; ++i) {
            // Binary exponentiation; operands stay below 2^31, so products fit in 64 bits.
            unsigned long long base = g;
            unsigned long long result = 1;
            int e = (mod - 1) / prime_divs[i];
            while (e > 0) {
                if (e & 1) result = result * base % mod;
                base = base * base % mod;
                e >>= 1;
            }
            if (result == 1) is_root = false;
        }
        if (is_root) return g;
    }
}

/// In-place forward number-theoretic transform (decimation in frequency).
///
/// Uses Gentleman-Sande butterflies without a bit-reversal pass, so the output
/// is in bit-reversed order. `intt` consumes that order directly.
/// `convolution` always passes a power-of-two length.
template <typename mint>
void ntt(std::vector<mint>& a) {
    constexpr int mod = mint::mod();
    constexpr mint primitive_root = get_primitive_root(mod);

    const int n = a.size();
    int len = n;
    while (len > 1) {
        const int half = len / 2;
        // Primitive len-th root of unity modulo mod.
        mint omega = primitive_root.pow((mod - 1) / len);
        for (int start = 0; start + len <= n; start += len) {
            mint twiddle = 1;
            for (int j = 0; j < half; ++j) {
                mint u = a[start + j];
                mint v = a[start + j + half];
                a[start + j] = u + v;
                a[start + j + half] = (u - v) * twiddle;
                twiddle *= omega;
            }
        }
        len >>= 1;
    }
}

/// In-place inverse number-theoretic transform (decimation in time), without
/// the 1/n scaling.
///
/// Expects input in the bit-reversed order produced by `ntt` and yields natural
/// order, using inverse roots of unity. The caller multiplies by n^{-1}
/// afterwards, as `convolution` does.
template <typename mint>
void intt(std::vector<mint>& a) {
    constexpr int mod = mint::mod();
    constexpr mint primitive_root = get_primitive_root(mod);

    const int n = a.size();
    int len = 2;
    while (len <= n) {
        const int half = len / 2;
        // Inverse of a primitive len-th root of unity modulo mod.
        mint omega_inv = primitive_root.pow((mod - 1) / len).inv();
        for (int start = 0; start + len <= n; start += len) {
            mint twiddle = 1;
            for (int j = 0; j < half; ++j) {
                mint u = a[start + j];
                mint v = a[start + j + half] * twiddle;
                a[start + j] = u + v;
                a[start + j + half] = u - v;
                twiddle *= omega_inv;
            }
        }
        len <<= 1;
    }
}

/// Returns the linear convolution c of a and b, computed in mint arithmetic:
/// c[k] = sum over i + j == k of a[i] * b[j].
///
/// The result has a.size() + b.size() - 1 terms. It is empty when either input
/// is empty.
template <typename mint>
std::vector<mint> convolution(std::vector<mint> a, std::vector<mint> b) {
    // An empty factor gives an empty product. Without this guard, `size` below
    // would be -1 (or one term too few for a real product), and bit_ceil would
    // receive a wrapped-around value.
    if (a.empty() || b.empty()) return {};
    const int size = a.size() + b.size() - 1;

    // When one factor is short, schoolbook multiplication is cheaper than three
    // transforms. Short calls are frequent, e.g. the length-1, 2, 4, ... blocks of
    // an online convolution. The result is identical because the arithmetic is
    // exact modulo mint::mod().
    if (a.size() <= 32 || b.size() <= 32) {
        std::vector<mint> c(size);
        for (int i = 0; i < (int)a.size(); ++i) {
            for (int j = 0; j < (int)b.size(); ++j) {
                c[i + j] += a[i] * b[j];
            }
        }
        return c;
    }

    // Zero-pad to a power of two, transform, multiply pointwise, transform back.
    const int n = std::bit_ceil((unsigned int)size);
    a.resize(n);
    b.resize(n);
    ntt(a);
    ntt(b);
    for (int i = 0; i < n; ++i) a[i] *= b[i];
    intt(a);
    a.resize(size);
    // intt omits the 1/n normalization; apply it to the kept terms only.
    mint n_inv = mint(n).inv();
    for (int i = 0; i < size; ++i) a[i] *= n_inv;
    return a;
}
#line 6 "convolution/relaxed_convolution.hpp"

template <typename mint>
class RelaxedConvolution {
   public:
    /// Online convolution step.
    /// Appends F_k = a and G_k = b, where k is the number of terms given so far.
    /// Returns H_k = sum_{i=0}^{k} F_i * G_{k-i}.
    mint get(mint a, mint b) {
        f.push_back(a);
        g.push_back(b);
        ++n;
        // Largest power of two dividing n + 1; block widths up to this value are
        // the ones whose contributions are added at this step.
        const int limit = 1 << std::countr_zero((unsigned int)n + 1);
        int offset = 0;
        for (int width = 1; width <= limit; width <<= 1) {
            // Newest `width` terms of f against g[offset, offset + width).
            add_product(n - width, n, offset, offset + width);
            // If n + 1 == limit, then at width == limit / 2 we have
            // n - width == offset. The block just added is square, so the mirrored
            // product below would count it a second time.
            if (n + 1 == limit && width == limit / 2) break;
            // Mirrored block: f[offset, offset + width) against the newest g terms.
            add_product(offset, offset + width, n - width, n);
            offset += width;
        }
        return h[n - 1];
    }

   private:
    int n = 0;
    std::vector<mint> f, g, h;

    /// Adds the product f[lf, rf) * g[lg, rg) into h, shifted by lf + lg,
    /// growing h first if needed.
    void add_product(int lf, int rf, int lg, int rg) {
        const int needed = rf + rg - 1;
        if ((int)h.size() < needed) h.resize(needed);
        std::vector<mint> lhs(f.begin() + lf, f.begin() + rf);
        std::vector<mint> rhs(g.begin() + lg, g.begin() + rg);
        auto prod = convolution(lhs, rhs);
        const int shift = lf + lg;
        for (int k = 0; k < (int)prod.size(); ++k) h[shift + k] += prod[k];
    }
};
Back to top page