CF1681F Unique Occurrences

orz_z

2022-07-29 16:00:28

题解

CF1681F Unique Occurrences

线段树分治练手题,半小时秒了。

考虑对每个颜色 w 计算有多少路径经过恰好一条颜色为 w 的边。

具体地,将所有颜色为 w 的边从原树中删去,对于每一条颜色为 w 的边的两端点连通块大小乘积的和即为这个颜色的答案。

将颜色看作时间线,那么一条颜色为 w 的边建在时间 [1,w-1][w+1,n] 上,线段树优化即可。

计算连通块大小用并查集即可,线段树分治考虑可撤销并查集。

时间复杂度 \mathcal O(n\log^2n)

#include <bits/stdc++.h>

using namespace std;

namespace Fread
{
    const int SIZE = 1 << 23;
    char buf[SIZE], *S, *T;
    inline char getchar()
    {
        if (S == T)
        {
            T = (S = buf) + fread(buf, 1, SIZE, stdin);
            if (S == T)
                return '\n';
        }
        return *S++;
    }
}

namespace Fwrite
{
    const int SIZE = 1 << 23;
    char buf[SIZE], *S = buf, *T = buf + SIZE;
    inline void flush()
    {
        fwrite(buf, 1, S - buf, stdout);
        S = buf;
    }
    inline void putchar(char c)
    {
        *S++ = c;
        if (S == T)
            flush();
    }
    struct NTR
    {
        ~NTR()
        {
            flush();
        }
    } ztr;
}

#ifdef ONLINE_JUDGE
#define getchar Fread::getchar
#define putchar Fwrite::putchar
#endif

//#define int long long
typedef long long ll;
#define ha putchar(' ')
#define he putchar('\n')

inline int read()
{
    int x = 0, f = 1;
    char c = getchar();
    while (c < '0' || c > '9')
    {
        if (c == '-')
            f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
        x = x * 10 + c - '0', c = getchar();
    return x * f;
}

inline void write(ll x)
{
    if (x < 0)
    {
        putchar('-');
        x = -x;
    }
    if (x > 9)
        write(x / 10);
    putchar(x % 10 + '0');
}

const int _ = 5e5 + 10;

int n, x, y, z, sz[_], fa[_], top;

ll ans;

vector<pair<int, int>> d[_], e[_ << 2];

int find(int x)
{
    while (x ^ fa[x])
        x = fa[x];
    return x;
}

void upd(int o, int l, int r, int L, int R, pair<int, int> pi)
{
    if (L <= l && r <= R)
        return e[o].push_back(pi), void();
    int mid = (l + r) >> 1;
    if (L <= mid)
        upd(o << 1, l, mid, L, R, pi);
    if (R > mid)
        upd(o << 1 | 1, mid + 1, r, L, R, pi);
}

void solve(int o, int l, int r)
{
    vector<int> dl;
    for (auto v : e[o])
    {
        int p = find(v.first), q = find(v.second);
        if (p == q)
            continue;
        if (sz[p] > sz[q])
            swap(p, q);
        dl.push_back(p), sz[q] += sz[p], fa[p] = q;
    }
    if (l == r)
        for (auto v : d[l])
            ans += 1ll * sz[find(v.first)] * sz[find(v.second)];
    else
    {
        int mid = (l + r) >> 1;
        solve(o << 1, l, mid), solve(o << 1 | 1, mid + 1, r);
    }
    reverse(dl.begin(), dl.end());
    for (auto v : dl)
        sz[fa[v]] -= sz[v], fa[v] = v;
}

signed main()
{
    n = read();
    for (int i = 1; i < n; ++i)
    {
        x = read(), y = read(), z = read();
        d[z].push_back({x, y});
        if (z > 1)
            upd(1, 1, n, 1, z - 1, {x, y});
        if (z < n)
            upd(1, 1, n, z + 1, n, {x, y});
    }
    for (int i = 1; i <= n; ++i)
        sz[fa[i] = i] = 1;
    solve(1, 1, n);
    write(ans), he;
    return 0;
}