sotanishy's competitive programming library

sotanishy's code snippets for competitive programming

View the Project on GitHub sotanishy/cp-library-cpp

✔ test/yosupo/discrete_logarithm_mod.test.cpp

Depends on

Code

#define PROBLEM "https://judge.yosupo.jp/problem/discrete_logarithm_mod"

#include "../../math/number-theory/mod_arithmetic.hpp"

#include <bits/stdc++.h>
using namespace std;

/**
 * Test driver: reads the number of queries, then one (X, Y, M) triple per
 * query, and prints mod_log(X, Y, M) on its own line for each.
 */
int main() {
    // Decouple C++ streams from C stdio and from each other for faster I/O.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int T;
    std::cin >> T;
    while (T-- > 0) {
        int X, Y, M;
        std::cin >> X >> Y >> M;
        std::cout << mod_log(X, Y, M) << "\n";
    }
}
#line 1 "test/yosupo/discrete_logarithm_mod.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/discrete_logarithm_mod"

#line 2 "math/number-theory/mod_arithmetic.hpp"
#include <vector>
#include <cmath>
#include <numeric>
#include <unordered_map>

#line 4 "math/number-theory/euler_totient.hpp"

/**
 * Euler's totient function, computed by trial division.
 *
 * For every distinct prime factor p of n the running result is multiplied by
 * (1 - 1/p), written as `ret -= ret / p` to stay in integers.
 *
 * @param n value to evaluate
 * @return the number of integers in [1, n] coprime to n; 0 when n <= 0
 */
long long euler_totient(long long n) {
    // n == 0 would never leave the "strip factors of two" loop below, and
    // negative values have no meaningful totient.
    if (n <= 0) return 0;

    long long ret = n;
    if (n % 2 == 0) {
        ret -= ret / 2;
        while (n % 2 == 0) n /= 2;
    }
    // `i <= n / i` is the overflow-free form of `i * i <= n`.
    for (long long i = 3; i <= n / i; i += 2) {
        if (n % i == 0) {
            ret -= ret / i;
            while (n % i == 0) n /= i;
        }
    }
    // Whatever is left is a single prime factor larger than sqrt(original n).
    if (n != 1) ret -= ret / n;
    return ret;
}

/**
 * Sieve of Euler's totient values.
 *
 * @param n upper bound of the table (inclusive)
 * @return vector `phi` of size n + 1 with phi[k] = totient of k (phi[0] = 0)
 */
std::vector<int> euler_totient_table(int n) {
    std::vector<int> phi(n + 1);
    for (int v = 0; v <= n; ++v) phi[v] = v;

    for (int p = 2; p <= n; ++p) {
        // An entry still equal to its index was never touched by a smaller
        // prime, so p itself is prime.
        if (phi[p] != p) continue;
        for (int multiple = p; multiple <= n; multiple += p) {
            phi[multiple] = phi[multiple] / p * (p - 1);
        }
    }
    return phi;
}

/**
 * Prefix sums of Euler's totient at every quotient floor(n / i).
 *
 * With b = min(n, 10^4) the function returns two tables:
 *   - small[k] = phi(1) + ... + phi(k)           for 0 <= k <= n / b
 *   - large[i] = phi(1) + ... + phi(floor(n/i))  for 1 <= i <= b
 * (large[0] is left value-initialised).
 *
 * `mint` must be constructible from an integer and support +, -, *, / with
 * integers (a modular-integer type or a plain integer type).
 *
 * @param n upper bound of the summation
 * @return the pair {small, large}; {{0}, {0}} when n == 0
 */
template <typename mint>
std::pair<std::vector<mint>, std::vector<mint>> totient_summatory_table(
    long long n) {
    if (n == 0) return {{0}, {0}};
    const int b = std::min(n, (long long)1e4);
    const long long small_limit = n / b;
    std::vector<mint> small(small_limit + 1), large(b + 1);

    // Sieve phi(1..small_limit), then accumulate it into `small`.
    std::vector<int> phi(small_limit + 1);
    std::iota(phi.begin(), phi.end(), 0);
    for (int p = 2; p <= small_limit; ++p) {
        if (phi[p] == p) {
            for (int multiple = p; multiple <= small_limit; multiple += p) {
                phi[multiple] = phi[multiple] / p * (p - 1);
            }
        }
    }
    for (int k = 1; k <= small_limit; ++k) small[k] = small[k - 1] + phi[k];

    // Start every large entry at 1 + 2 + ... + floor(n / i).
    for (int i = 1; i <= b; ++i) {
        mint top = n / i;
        large[i] = top * (top + 1) / 2;
    }

    // Subtract the sums at floor(n / (i * l)) for l >= 2, handling each run
    // of l that shares one quotient in a single step. Going from i = b
    // downwards guarantees large[i * l] is final before it is read.
    for (long long i = b; i >= 1; --i) {
        long long lo = 2;
        while (lo <= n / i) {
            const long long quotient = n / (i * lo);
            // One past the last l whose quotient equals `quotient`.
            const long long hi = n / (i * quotient) + 1;
            large[i] -=
                (i * lo <= b ? large[i * lo] : small[quotient]) * (hi - lo);
            lo = hi;
        }
    }
    return {small, large};
}
#line 8 "math/number-theory/mod_arithmetic.hpp"

/*
 * Modular Exponentiation
 */
/**
 * Modular exponentiation by repeated squaring.
 *
 * @param a   base; any sign and magnitude, reduced into [0, mod) before use so
 *            that the intermediate products stay within 64 bits
 * @param e   exponent; a non-positive exponent yields 1 % mod
 * @param mod modulus, expected to be positive
 * @return a^e modulo mod, in the range [0, mod)
 */
long long mod_pow(long long a, long long e, int mod) {
    // Reduce first: an unreduced base would overflow on the first squaring
    // once |a| >= 2^32, and a negative base would leak a negative remainder.
    a %= mod;
    if (a < 0) a += mod;

    // 1 % mod rather than 1 so that the result is 0 when mod == 1.
    long long ret = 1 % mod;
    while (e > 0) {
        if (e & 1) ret = ret * a % mod;
        a = a * a % mod;
        e >>= 1;
    }
    return ret;
}

/**
 * Modular inverse computed as a^(mod - 2) modulo mod.
 *
 * By Fermat's little theorem this is the inverse of a when mod is prime and
 * a is not a multiple of mod; neither condition is checked here.
 */
long long mod_inv(long long a, int mod) {
    const int fermat_exponent = mod - 2;
    return mod_pow(a, fermat_exponent, mod);
}

/*
 * Discrete Logarithm
 */
/**
 * Discrete logarithm: the smallest x >= 0 with a^x == b (mod mod), using the
 * convention 0^0 = 1.
 *
 * The modulus does not have to be prime. Common factors of a and mod are
 * peeled off one exponent at a time, after which baby-step giant-step solves
 * the remaining coprime instance.
 *
 * @param a   base (any sign; normalised into [0, mod))
 * @param b   target (any sign; normalised into [0, mod))
 * @param mod modulus
 * @return the smallest exponent x, or -1 if no exponent works or mod <= 0
 */
int mod_log(long long a, long long b, int mod) {
    if (mod <= 0) return -1;

    // Bring both operands into [0, mod); `%` keeps the sign of the dividend,
    // and a negative value would never match a table entry below.
    a %= mod;
    if (a < 0) a += mod;
    b %= mod;
    if (b < 0) b += mod;

    // Make a and mod coprime. After `add` rounds the problem reads
    //   k * a^(x - add) == b  (mod mod)
    // for the reduced b and mod.
    long long k = 1, add = 0, g;
    while ((g = std::gcd(a, mod)) > 1) {
        if (b == k) return static_cast<int>(add);  // x == add already works
        if (b % g) return -1;                      // b must share the factor
        b /= g;
        mod /= g;
        ++add;
        k = k * a / g % mod;
    }

    // Baby steps: remember b * a^j for j = 0..m. Later writes overwrite
    // earlier ones, so the largest j is kept, which gives the smallest
    // exponent i * m - j for a given giant step i.
    const int m = static_cast<int>(std::sqrt(mod)) + 1;
    std::unordered_map<long long, int> baby_index;
    baby_index.reserve(static_cast<std::size_t>(m) + 1);
    long long baby = b;
    for (int j = 0; j <= m; ++j) {
        baby_index[baby] = j;
        baby = baby * a % mod;
    }

    // Giant steps: look for k * a^(i * m) among the baby values.
    long long am = 1;
    for (int i = 0; i < m; ++i) am = am * a % mod;
    long long giant = k;
    for (int i = 1; i <= m; ++i) {
        giant = giant * am % mod;
        const auto hit = baby_index.find(giant);  // single lookup
        if (hit != baby_index.end()) {
            // Widen before multiplying: i * m can exceed INT_MAX for moduli
            // close to 2^31.
            return static_cast<int>(1LL * i * m - hit->second + add);
        }
    }
    return -1;
}

/*
 * Quadratic Residue
 */
/**
 * Square root modulo a prime (Tonelli-Shanks).
 *
 * @param n   value whose root is wanted (any sign; reduced into [0, mod))
 * @param mod modulus; the algorithm assumes it is prime (not verified here)
 * @return some R with R * R == n (mod mod), or -1 when n is not a quadratic
 *         residue, shares a factor with mod, or mod <= 0
 */
long long mod_sqrt(long long n, int mod) {
    if (mod <= 0) return -1;

    // Reduce first so that multiples of mod are recognised as 0 and the
    // mod == 2 shortcut only ever sees n == 1.
    n %= mod;
    if (n < 0) n += mod;

    if (n == 0) return 0;
    if (mod == 2) return 1;
    if (std::gcd(n, mod) != 1) return -1;
    // Euler's criterion: n^((mod-1)/2) == -1 means n has no square root.
    if (mod_pow(n, (mod - 1) / 2, mod) == mod - 1) return -1;

    // Write mod - 1 = Q * 2^S with Q odd.
    int Q = mod - 1, S = 0;
    while (!(Q & 1)) Q >>= 1, ++S;

    // Smallest z >= 2 that fails Euler's criterion, i.e. a non-residue.
    long long z = 2;
    while (true) {
        if (mod_pow(z, (mod - 1) / 2, mod) == mod - 1) break;
        ++z;
    }

    int M = S;
    long long c = mod_pow(z, Q, mod);
    long long t = mod_pow(n, Q, mod);
    long long R = mod_pow(n, (Q + 1) / 2, mod);
    // Each round lowers the order of t until t == 1, at which point R is a
    // square root of n.
    while (t != 1) {
        // Find the least i with t^(2^i) == 1.
        int i = 0;
        long long s = t;
        while (s != 1) {
            s = s * s % mod;
            ++i;
        }
        long long b = mod_pow(c, 1 << (M - i - 1), mod);
        M = i;
        c = b * b % mod;
        t = t * c % mod;
        R = R * b % mod;
    }
    return R;
}

/**
 * Modular Tetration
 */
/**
 * Power tower of height b with base a (a^a^...^a, b copies), reduced modulo
 * mod.
 *
 * Exponents are reduced through the chain mod -> totient(mod) -> ..., and
 * every intermediate value v that reaches the current modulus m is stored as
 * v % m + m instead of v % m, so the next level up can still tell that the
 * true exponent was at least m.
 *
 * @param a   base of the tower
 * @param b   height of the tower
 * @param mod modulus
 * @return the tower value modulo mod
 */
long long mod_tetration(long long a, long long b, int mod) {
    if (mod == 1) return 0;
    if (a == 0) return 1 - (b % 2);
    if (a == 1 || b == 0) return 1;

    // Leaves values below m untouched and folds larger ones into [m, 2m).
    const auto fold = [](long long value, int m) {
        return value >= m ? value % m + m : value;
    };

    // Binary exponentiation in which every product goes through `fold`.
    const auto folded_pow = [&fold](long long base, long long exponent,
                                    int m) {
        base = fold(base, m);
        long long result = 1;
        for (; exponent > 0; exponent >>= 1) {
            if (exponent & 1) result = fold(result * base, m);
            base = fold(base * base, m);
        }
        return result;
    };

    // Folded value of the tower of the given height under modulus m.
    const auto tower = [&](auto& self, long long height, int m) -> long long {
        if (height == 1) return a;
        if (m == 1) return 1;
        const long long exponent = self(self, height - 1, euler_totient(m));
        return folded_pow(a, exponent, m);
    };

    return tower(tower, b, mod) % mod;
}

/**
 * Table of Modular Inverses
 */
/**
 * Modular inverses of 1..n in one linear pass.
 *
 * Uses mod = q * v + r  =>  v^{-1} == -q * r^{-1} (mod mod), so each entry
 * is derived from an earlier one. Entries 0 and 1 are both initialised to 1.
 * Intended for prime mod with n < mod; this is not checked.
 *
 * @param n   largest value to invert
 * @param mod modulus
 * @return vector `table` of size n + 1 with table[v] = inverse of v
 */
std::vector<int> mod_inv_table(int n, int mod) {
    std::vector<int> table(n + 1, 1);
    for (int v = 2; v <= n; ++v) {
        const int quotient = mod / v;
        const int remainder = mod % v;
        const long long negated = 1LL * table[remainder] * quotient % mod;
        table[v] = static_cast<int>(mod - negated);
    }
    return table;
}
#line 4 "test/yosupo/discrete_logarithm_mod.test.cpp"

#include <bits/stdc++.h>
using namespace std;

/**
 * Test driver: reads the number of queries, then one (X, Y, M) triple per
 * query, and prints mod_log(X, Y, M) on its own line for each.
 */
int main() {
    // Decouple C++ streams from C stdio and from each other for faster I/O.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int T;
    std::cin >> T;
    while (T-- > 0) {
        int X, Y, M;
        std::cin >> X >> Y >> M;
        std::cout << mod_log(X, Y, M) << "\n";
    }
}
Back to top page