sotanishy's code snippets for competitive programming
#include "misc/sum_top_k.hpp"
#pragma once
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>
/**
* @brief Sum of Top K Values
*/
template <typename T>
class SumTopK {
   public:
    // A default-constructed tracker uses k = 0: every value stays in the
    // "small" group, so sum_big() remains 0.
    SumTopK() = default;

    // Creates a tracker whose "big" group holds up to k of the largest values.
    // k is expected to be non-negative.
    explicit SumTopK(int k) : k(k) {}

    // Adds one occurrence of x and rebalances the two groups.
    void insert(T x) {
        small.insert(x);
        sumsmall += x;
        fix();
    }

    // Removes one occurrence of x and rebalances the two groups.
    // The "small" group is searched first, then the "big" group.
    // x must currently be stored: this is asserted, and when assertions are
    // compiled out an absent x leaves the structure untouched.
    void erase(T x) {
        // One find() per group both locates the element and tells us whether
        // it is present, so no separate membership query is needed.
        if (auto st = small.find(x); st != small.end()) {
            small.erase(st);
            sumsmall -= x;
        } else {
            auto bt = big.find(x);
            assert(bt != big.end());
            if (bt == big.end()) return;  // never erase the end iterator
            big.erase(bt);
            sumbig -= x;
        }
        fix();
    }

    // Sum of the values currently in the "big" group (the largest ones).
    T sum_big() const { return sumbig; }

    // Sum of the values currently in the "small" group (everything else).
    T sum_small() const { return sumsmall; }

   private:
    std::multiset<T> small, big;
    T sumsmall = 0, sumbig = 0;
    int k = 0;  // initialized so that the defaulted constructor is well-defined

    // Re-establishes the invariants after an insertion or removal:
    // no element of `small` exceeds an element of `big`, and `big` holds
    // exactly k elements whenever at least k elements are stored.
    void fix() {
        // Explicit conversion; this is the same conversion the comparison
        // against size() would otherwise perform implicitly.
        const auto cap = static_cast<std::size_t>(k);

        // Swap the largest "small" value with the smallest "big" value while
        // they are out of order.
        while (!small.empty() && !big.empty() &&
               *small.rbegin() > *big.begin()) {
            auto st = --small.end();
            auto bt = big.begin();
            auto sv = *st, bv = *bt;
            sumsmall -= sv, sumsmall += bv;
            sumbig -= bv, sumbig += sv;
            small.erase(st), small.insert(bv);
            big.erase(bt), big.insert(sv);
        }

        // `big` is short: promote the largest remaining "small" values.
        while (!small.empty() && big.size() < cap) {
            auto st = --small.end();
            auto sv = *st;
            sumsmall -= sv;
            sumbig += sv;
            small.erase(st);
            big.insert(sv);
        }

        // `big` is over capacity: demote its smallest values.
        while (big.size() > cap) {
            auto bt = big.begin();
            auto bv = *bt;
            sumsmall += bv;
            sumbig -= bv;
            big.erase(bt);
            small.insert(bv);
        }
    }
};
#line 2 "misc/sum_top_k.hpp"
#include <set>
#include <vector>
/**
* @brief Sum of Top K Values
*/
template <typename T>
class SumTopK {
   public:
    // A default-constructed tracker uses k = 0: every value stays in the
    // "small" group, so sum_big() remains 0.
    SumTopK() = default;

    // Creates a tracker whose "big" group holds up to k of the largest values.
    // k is expected to be non-negative.
    explicit SumTopK(int k) : k(k) {}

    // Adds one occurrence of x and rebalances the two groups.
    void insert(T x) {
        small.insert(x);
        sumsmall += x;
        fix();
    }

    // Removes one occurrence of x and rebalances the two groups.
    // The "small" group is searched first, then the "big" group.
    // x must currently be stored: this is asserted, and when assertions are
    // compiled out an absent x leaves the structure untouched.
    void erase(T x) {
        // One find() per group both locates the element and tells us whether
        // it is present, so no separate membership query is needed.
        if (auto st = small.find(x); st != small.end()) {
            small.erase(st);
            sumsmall -= x;
        } else {
            auto bt = big.find(x);
            assert(bt != big.end());
            if (bt == big.end()) return;  // never erase the end iterator
            big.erase(bt);
            sumbig -= x;
        }
        fix();
    }

    // Sum of the values currently in the "big" group (the largest ones).
    T sum_big() const { return sumbig; }

    // Sum of the values currently in the "small" group (everything else).
    T sum_small() const { return sumsmall; }

   private:
    std::multiset<T> small, big;
    T sumsmall = 0, sumbig = 0;
    int k = 0;  // initialized so that the defaulted constructor is well-defined

    // Re-establishes the invariants after an insertion or removal:
    // no element of `small` exceeds an element of `big`, and `big` holds
    // exactly k elements whenever at least k elements are stored.
    void fix() {
        // Explicit conversion; this is the same conversion the comparison
        // against size() would otherwise perform implicitly.
        const auto cap = static_cast<std::size_t>(k);

        // Swap the largest "small" value with the smallest "big" value while
        // they are out of order.
        while (!small.empty() && !big.empty() &&
               *small.rbegin() > *big.begin()) {
            auto st = --small.end();
            auto bt = big.begin();
            auto sv = *st, bv = *bt;
            sumsmall -= sv, sumsmall += bv;
            sumbig -= bv, sumbig += sv;
            small.erase(st), small.insert(bv);
            big.erase(bt), big.insert(sv);
        }

        // `big` is short: promote the largest remaining "small" values.
        while (!small.empty() && big.size() < cap) {
            auto st = --small.end();
            auto sv = *st;
            sumsmall -= sv;
            sumbig += sv;
            small.erase(st);
            big.insert(sv);
        }

        // `big` is over capacity: demote its smallest values.
        while (big.size() > cap) {
            auto bt = big.begin();
            auto bv = *bt;
            sumsmall += bv;
            sumbig -= bv;
            big.erase(bt);
            small.insert(bv);
        }
    }
};