sotanishy's competitive programming library

sotanishy's code snippets for competitive programming

View the Project on GitHub sotanishy/cp-library-cpp

:warning: Sum of Top K Values
(misc/sum_top_k.hpp)

Code

#pragma once
#include <cassert>
#include <set>
#include <vector>

/**
 * @brief Sum of Top K Values
 */

// Multiset of values that maintains the sum of its k largest elements.
//
// The stored values are split across two multisets: `big` holds the largest
// values (at most k of them) and `small` holds everything else. The sum of
// each half is updated incrementally, so both sums can be read in O(1).
// T must support `+=`, `-=`, `<` and initialization from 0.
template <typename T>
class SumTopK {
   public:
    // Creates a structure with k = 0: every inserted value lands in `small`.
    SumTopK() = default;

    // Creates a structure that tracks the sum of the k largest values.
    explicit SumTopK(int k) : k(k) {}

    // Adds one occurrence of x and rebalances the two halves.
    void insert(T x) {
        small.insert(x);
        sumsmall += x;
        fix();
    }

    // Removes one occurrence of x and rebalances the two halves.
    // Precondition: x is currently stored (checked with assert). If the
    // assertion is compiled out and x is absent, the call is a no-op.
    void erase(T x) {
        // A single find() per multiset; `small` is tried first.
        if (auto it = small.find(x); it != small.end()) {
            small.erase(it);
            sumsmall -= x;
        } else {
            it = big.find(x);
            assert(it != big.end());
            if (it == big.end()) return;  // never erase end()
            big.erase(it);
            sumbig -= x;
        }
        fix();
    }

    // Sum of the values currently held in the top-k half.
    T sum_big() const { return sumbig; }

    // Sum of all remaining values.
    T sum_small() const { return sumsmall; }

   private:
    using size_type = typename std::multiset<T>::size_type;

    std::multiset<T> small, big;
    T sumsmall = 0, sumbig = 0;
    int k = 0;  // initialized so a default-constructed object is well-defined

    // Restores the invariants after an insertion or removal:
    //   1. every value in `small` is <= every value in `big`;
    //   2. `big` holds exactly min(k, total number of values) elements.
    void fix() {
        // k is compared against container sizes as an unsigned count.
        const size_type cap = static_cast<size_type>(k);

        // Invariant 1: while the largest small value exceeds the smallest big
        // value, swap the two between the halves.
        while (!small.empty() && !big.empty() &&
               *small.rbegin() > *big.begin()) {
            auto st = --small.end();
            auto bt = big.begin();
            auto sv = *st, bv = *bt;
            sumsmall -= sv, sumsmall += bv;
            sumbig -= bv, sumbig += sv;
            small.erase(st), small.insert(bv);
            big.erase(bt), big.insert(sv);
        }
        // Invariant 2 (too few in big): promote the largest small values.
        while (!small.empty() && big.size() < cap) {
            auto st = --small.end();
            auto sv = *st;
            sumsmall -= sv;
            sumbig += sv;
            small.erase(st);
            big.insert(sv);
        }
        // Invariant 2 (too many in big): demote the smallest big values.
        while (big.size() > cap) {
            auto bt = big.begin();
            auto bv = *bt;
            sumsmall += bv;
            sumbig -= bv;
            big.erase(bt);
            small.insert(bv);
        }
    }
};
#line 2 "misc/sum_top_k.hpp"
#include <set>
#include <vector>

/**
 * @brief Sum of Top K Values
 */

// Multiset of values that maintains the sum of its k largest elements.
//
// The stored values are split across two multisets: `big` holds the largest
// values (at most k of them) and `small` holds everything else. The sum of
// each half is updated incrementally, so both sums can be read in O(1).
// T must support `+=`, `-=`, `<` and initialization from 0.
template <typename T>
class SumTopK {
   public:
    // Creates a structure with k = 0: every inserted value lands in `small`.
    SumTopK() = default;

    // Creates a structure that tracks the sum of the k largest values.
    explicit SumTopK(int k) : k(k) {}

    // Adds one occurrence of x and rebalances the two halves.
    void insert(T x) {
        small.insert(x);
        sumsmall += x;
        fix();
    }

    // Removes one occurrence of x and rebalances the two halves.
    // Precondition: x is currently stored (checked with assert). If the
    // assertion is compiled out and x is absent, the call is a no-op.
    void erase(T x) {
        // A single find() per multiset; `small` is tried first.
        if (auto it = small.find(x); it != small.end()) {
            small.erase(it);
            sumsmall -= x;
        } else {
            it = big.find(x);
            assert(it != big.end());
            if (it == big.end()) return;  // never erase end()
            big.erase(it);
            sumbig -= x;
        }
        fix();
    }

    // Sum of the values currently held in the top-k half.
    T sum_big() const { return sumbig; }

    // Sum of all remaining values.
    T sum_small() const { return sumsmall; }

   private:
    using size_type = typename std::multiset<T>::size_type;

    std::multiset<T> small, big;
    T sumsmall = 0, sumbig = 0;
    int k = 0;  // initialized so a default-constructed object is well-defined

    // Restores the invariants after an insertion or removal:
    //   1. every value in `small` is <= every value in `big`;
    //   2. `big` holds exactly min(k, total number of values) elements.
    void fix() {
        // k is compared against container sizes as an unsigned count.
        const size_type cap = static_cast<size_type>(k);

        // Invariant 1: while the largest small value exceeds the smallest big
        // value, swap the two between the halves.
        while (!small.empty() && !big.empty() &&
               *small.rbegin() > *big.begin()) {
            auto st = --small.end();
            auto bt = big.begin();
            auto sv = *st, bv = *bt;
            sumsmall -= sv, sumsmall += bv;
            sumbig -= bv, sumbig += sv;
            small.erase(st), small.insert(bv);
            big.erase(bt), big.insert(sv);
        }
        // Invariant 2 (too few in big): promote the largest small values.
        while (!small.empty() && big.size() < cap) {
            auto st = --small.end();
            auto sv = *st;
            sumsmall -= sv;
            sumbig += sv;
            small.erase(st);
            big.insert(sv);
        }
        // Invariant 2 (too many in big): demote the smallest big values.
        while (big.size() > cap) {
            auto bt = big.begin();
            auto bv = *bt;
            sumsmall += bv;
            sumbig -= bv;
            big.erase(bt);
            small.insert(bv);
        }
    }
};
Back to top page