sotanishy's competitive programming library

sotanishy's code snippets for competitive programming

View the Project on GitHub sotanishy/cp-library-cpp

:heavy_check_mark: test/yosupo/directedmst.test.cpp

Depends on

Code

#define PROBLEM "https://judge.yosupo.jp/problem/directedmst"

#include <bits/stdc++.h>

#include "../../graph/minimum_spanning_arborescence.hpp"
using namespace std;
using ll = long long;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int N, M, S;
    cin >> N >> M >> S;
    vector<tuple<int, int, ll>> G(M);
    for (int i = 0; i < M; ++i) {
        int s, t, w;
        cin >> s >> t >> w;
        G[i] = {s, t, w};
    }
    auto ans = minimum_spanning_arborescence(G, N, S);
    cout << ans.first << endl;
    for (int i = 0; i < N; ++i)
        cout << ans.second[i] << (i < N - 1 ? " " : "\n");
}
#line 1 "test/yosupo/directedmst.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/directedmst"

#include <bits/stdc++.h>

#line 6 "graph/minimum_spanning_arborescence.hpp"
#include <ranges>
#line 9 "graph/minimum_spanning_arborescence.hpp"

#line 5 "data-structure/leftist_heap.hpp"

/**
 * Meldable min-heap (leftist heap) of (id, value) entries that also supports
 * adding a constant to every stored value lazily.
 *
 * `top()` and `pop()` require a non-empty heap.
 */
template <typename T>
class LeftistHeap {
   public:
    LeftistHeap() = default;

    /// Merges two heaps into one; both arguments are consumed.
    static LeftistHeap meld(LeftistHeap a, LeftistHeap b) {
        return LeftistHeap(meld(std::move(a.root), std::move(b.root)));
    }

    /// Returns the (id, value) entry with the smallest value.
    /// Precondition: the heap is not empty.
    std::pair<int, T> top() const {
        // Apply any pending addition so that root->val is up to date.
        push(root);
        return {root->id, root->val};
    }

    /// Removes the entry with the smallest value.
    /// Precondition: the heap is not empty.
    void pop() {
        // Hand the pending addition down to the children before dropping
        // the root, otherwise it would be lost.
        push(root);
        root = meld(std::move(root->left), std::move(root->right));
    }

    /// Inserts the entry (id, x).
    void push(int id, T x) {
        root = meld(std::move(root), std::make_unique<Node>(id, x));
    }

    bool empty() const { return root == nullptr; }

    /// Adds x to every value currently stored. No-op on an empty heap.
    void add(T x) {
        if (root) root->lazy += x;
    }

   private:
    struct Node;
    using node_ptr = std::unique_ptr<Node>;

    struct Node {
        node_ptr left, right;
        // Rank used to keep the right spine short. A fresh leaf has rank 1,
        // matching the `right ? right->s : 0) + 1` update in meld(); it must
        // be initialised because meld() reads the rank of children that were
        // never themselves the surviving root of a meld.
        int s = 1;
        int id;
        T val, lazy;  // lazy: amount still to be added to this whole subtree
        Node(int id, T x) : id(id), val(x), lazy(0) {}
    };

    node_ptr root = nullptr;

    explicit LeftistHeap(node_ptr root) : root(std::move(root)) {}

    /// Merges two (possibly null) subtrees and returns the new subtree root.
    static node_ptr meld(node_ptr a, node_ptr b) {
        if (!a) return b;
        if (!b) return a;
        // Both values must be current before they are compared.
        push(a);
        push(b);
        if (a->val > b->val) std::swap(a, b);
        a->right = meld(std::move(a->right), std::move(b));
        // Keep the higher-rank child on the left.
        if (!a->left || a->left->s < a->right->s) std::swap(a->left, a->right);
        a->s = (a->right ? a->right->s : 0) + 1;
        return a;
    }

    /// Applies t's pending addition to t itself and defers it to its children.
    static void push(const node_ptr& t) {
        if (t->left) t->left->lazy += t->lazy;
        if (t->right) t->right->lazy += t->lazy;
        t->val += t->lazy;
        t->lazy = 0;
    }
};
#line 4 "data-structure/unionfind/union_find.hpp"

// Disjoint-set forest over the integers [0, n).
// data[v] < 0  : v is a representative and -data[v] is the size of its set.
// data[v] >= 0 : data[v] is the parent of v.
class UnionFind {
   public:
    UnionFind() = default;
    explicit UnionFind(int n) : data(n, -1) {}

    // Returns the representative of x's set, re-pointing every vertex on the
    // walked path directly at that representative.
    int find(int x) {
        int rep = x;
        while (data[rep] >= 0) rep = data[rep];
        while (x != rep) {
            const int parent = data[x];
            data[x] = rep;
            x = parent;
        }
        return rep;
    }

    // Merges the sets containing x and y; the larger set's representative
    // survives (the first one wins on a tie).
    void unite(int x, int y) {
        int keep = find(x);
        int drop = find(y);
        if (keep == drop) return;
        if (data[keep] > data[drop]) std::swap(keep, drop);
        data[keep] += data[drop];
        data[drop] = keep;
    }

    // True when x and y currently belong to the same set.
    bool same(int x, int y) {
        const int rx = find(x);
        const int ry = find(y);
        return rx == ry;
    }

    // Number of elements in the set containing x.
    int size(int x) {
        const int rep = find(x);
        return -data[rep];
    }

   private:
    std::vector<int> data;
};
#line 12 "graph/minimum_spanning_arborescence.hpp"

/**
 * @brief Minimum Spanning Arborescence
 *
 * Builds an arborescence rooted at `root` by repeatedly taking the cheapest
 * remaining incoming edge of the current (possibly contracted) vertex,
 * contracting any cycle that forms, and finally expanding the contractions
 * to recover one parent per vertex.
 *
 * @param G    edge list; each entry is (source, target, weight)
 * @param V    number of vertices (all per-vertex arrays have this size)
 * @param root root vertex of the arborescence
 * @return (total weight, par). par starts as the identity and par[t] is set
 *         to s for every selected edge (s, t), so an entry that is never
 *         overwritten keeps its own index. If some vertex being processed
 *         has no incoming edge left, returns
 *         (std::numeric_limits<T>::max(), empty vector).
 */
template <typename T>
std::pair<T, std::vector<int>> minimum_spanning_arborescence(
    std::vector<std::tuple<int, int, T>> G, int V, int root) {
    // One min-heap per vertex holding its incoming edges, keyed by weight and
    // tagged with the edge's index in G.
    std::vector<LeftistHeap<T>> incoming(V);
    for (int i = 0; i < (int)G.size(); ++i) {
        auto [s, t, w] = G[i];
        incoming[t].push(i, w);
    }
    T weight = 0;
    UnionFind uf(V);  // tracks which vertices have been contracted together
    // from[v]      : representative of the source of the edge last chosen for v
    // stem[v]      : first edge ever chosen for v (-1 if none yet)
    // prev_edge[e] : edge to step to from e during reconstruction
    // ord          : every chosen edge, in the order it was chosen
    std::vector<int> from(V), stem(V, -1), prev_edge(G.size()), ord;
    std::vector<T> from_cost(V);  // (reduced) cost of the edge last chosen for v
    std::vector<int> status(V);  // 0: not checked, 1: checking, 2: checked
    status[root] = 2;

    for (int s = 0; s < V; ++s) {
        if (status[s] != 0) continue;
        // cyc: number of vertices visited while contracting the most recent
        // cycle; consumed the next time an edge is chosen.
        int cur = s, cyc = 0;
        std::vector<int> seen, processing;
        // Walk backwards along chosen edges until reaching a finished vertex.
        while (status[cur] != 2) {
            status[cur] = 1;
            processing.push_back(cur);
            if (incoming[cur].empty()) {  // no msa
                return {std::numeric_limits<T>::max(), std::vector<int>()};
            }

            auto [i, c] = incoming[cur].top();
            int v = uf.find(std::get<0>(G[i]));
            incoming[cur].pop();
            // The edge starts inside cur's own contracted component: discard.
            if (cur == v) continue;
            from[cur] = v;
            from_cost[cur] = c;
            weight += c;
            ord.push_back(i);
            if (stem[cur] == -1) stem[cur] = i;
            // Point the last `cyc` entries of `seen` at the edge just chosen,
            // removing them from `seen`.
            while (cyc) {
                prev_edge[seen.back()] = i;
                seen.pop_back();
                --cyc;
            }
            seen.push_back(i);

            if (status[v] == 1) {
                // v is already on the current walk, so the chosen edges close
                // a cycle. Follow `from` around it, merging everything into
                // one component whose heap is the meld of the members' heaps.
                int p = cur;
                do {
                    // Lower every remaining incoming edge of p by the cost
                    // already paid for p's chosen edge.
                    if (!incoming[p].empty()) incoming[p].add(-from_cost[p]);
                    if (p != cur) {
                        uf.unite(p, cur);
                        auto newheap = LeftistHeap<T>::meld(
                            std::move(incoming[cur]), std::move(incoming[p]));
                        // cur becomes the representative of the merged set and
                        // owns the merged heap.
                        incoming[cur = uf.find(cur)] = std::move(newheap);
                    }
                    p = uf.find(from[p]);
                    ++cyc;
                } while (p != cur);
            } else {
                cur = v;
            }
        }
        for (int v : processing) status[v] = 2;
    }
    // Reconstruction: scan chosen edges from last to first. An edge that is
    // still unmarked becomes the parent edge of its target; the edges on the
    // prev_edge chain from stem[t] up to (not including) it are marked so
    // they are skipped when reached later in the scan.
    std::vector<bool> used_edge(G.size());
    std::vector<int> par(V);
    std::iota(par.begin(), par.end(), 0);
    for (int i : ord | std::views::reverse) {
        if (used_edge[i]) continue;
        auto [s, t, w] = G[i];
        par[t] = s;
        int x = stem[t];
        while (x != i) {
            used_edge[x] = true;
            x = prev_edge[x];
        }
    }
    return {weight, par};
}
#line 6 "test/yosupo/directedmst.test.cpp"
using namespace std;
using ll = long long;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int N, M, S;
    cin >> N >> M >> S;
    vector<tuple<int, int, ll>> G(M);
    for (int i = 0; i < M; ++i) {
        int s, t, w;
        cin >> s >> t >> w;
        G[i] = {s, t, w};
    }
    auto ans = minimum_spanning_arborescence(G, N, S);
    cout << ans.first << endl;
    for (int i = 0; i < N; ++i)
        cout << ans.second[i] << (i < N - 1 ? " " : "\n");
}
Back to top page