sotanishy's code snippets for competitive programming
#include "data-structure/fenwick_tree.hpp"
Fenwick tree,または binary indexed tree (BIT) は,可換モノイド $(T, \cdot, e)$ の列に対する一点更新と接頭辞和の取得を提供するデータ構造である.
セグメント木より制約が強く,操作が限られているが,実装が簡潔で定数倍速い.浮動小数点の演算をするときは,セグメント木よりも誤差が大きくなる傾向があるので注意が必要である.
空間計算量: $O(n)$
FenwickTree(int n)
n
で要素がすべて単位元 $e$ の Fenwick tree を構築するT prefix_fold(int i)
void update(int i, T x)
int lower_bound(T x)
int lower_bound(T x, Compare cmp)
cmp(prefix_fold(i), x) == false
となる最初の $i$ を返す.そのような $i$ が存在しない場合は $n$ を返す.cmp
を指定しない場合は <
で比較される.列の単調性を仮定する.#pragma once
#include <functional>
#include <vector>
template <typename M>
class FenwickTree {
using T = M::T;
public:
FenwickTree() = default;
explicit FenwickTree(int n) : n(n), data(n + 1, M::id()) {}
T prefix_fold(int i) const {
T ret = M::id();
for (; i > 0; i -= i & -i) ret = M::op(ret, data[i]);
return ret;
}
void update(int i, const T& x) {
for (++i; i <= n; i += i & -i) data[i] = M::op(data[i], x);
}
int lower_bound(const T& x) const { return lower_bound(x, std::less<>()); }
template <typename Compare>
int lower_bound(const T& x, Compare cmp) const {
if (!cmp(M::id(), x)) return 0;
int k = 1;
while (k * 2 <= n) k <<= 1;
int i = 0;
T v = M::id();
for (; k > 0; k >>= 1) {
if (i + k > n) continue;
T nv = M::op(v, data[i + k]);
if (cmp(nv, x)) {
v = nv;
i += k;
}
}
return i + 1;
}
private:
int n;
std::vector<T> data;
};
#line 2 "data-structure/fenwick_tree.hpp"
#include <functional>
#include <vector>
template <typename M>
class FenwickTree {
using T = M::T;
public:
FenwickTree() = default;
explicit FenwickTree(int n) : n(n), data(n + 1, M::id()) {}
T prefix_fold(int i) const {
T ret = M::id();
for (; i > 0; i -= i & -i) ret = M::op(ret, data[i]);
return ret;
}
void update(int i, const T& x) {
for (++i; i <= n; i += i & -i) data[i] = M::op(data[i], x);
}
int lower_bound(const T& x) const { return lower_bound(x, std::less<>()); }
template <typename Compare>
int lower_bound(const T& x, Compare cmp) const {
if (!cmp(M::id(), x)) return 0;
int k = 1;
while (k * 2 <= n) k <<= 1;
int i = 0;
T v = M::id();
for (; k > 0; k >>= 1) {
if (i + k > n) continue;
T nv = M::op(v, data[i + k]);
if (cmp(nv, x)) {
v = nv;
i += k;
}
}
return i + 1;
}
private:
int n;
std::vector<T> data;
};