题解 P5210 【[ZJOI2017]线段树】

· · 题解

这题能黑了是因为洛谷仅有的题解都是些滥用数据结构的解法么……(

考虑一棵广义线段树的形态,和定位区间所得的结点的性质。

对于定位区间 [l,r] 的过程,可以考虑 [l-1,l-1][r+1,r+1] 对应的结点和它们的 LCA(分别设为 L,R,U)。
注意到定位出来的区间都是 LU 左儿子链上所有结点的在链以外的右儿子和 RU 右儿子链上所有结点的在链以外的左儿子。
那么问题便是算 u 与这些结点的距离和。
不难想到转化为深度和减去 2 倍的 LCA 的深度和。

显然地,可以离线下来,暴力地使用数据结构 O((n+m) \log^2 n)O((n+m)\log n) 维护(
实际上没必要,因为这些结点具有比较优美的性质,所以可以通过分类讨论解决。

u 不在 U 的子树内或恰为 U,则显然这些结点与 u 的 LCA 即为 uU 的 LCA。
uU 的左子树内,设 uL 的 LCA 为 x,则 U 右子树内与 u 的 LCA 均为 U;左子树中 x 以下的结点与 u 的 LCA 显然为 x,而 xU 的左儿子段都是 u 的祖先。
注意特判 ux 的右子树内的情况,此时 LCA 应为 x 的右儿子。
uU 的右子树内,与前一情况类似。

线性预处理出每个结点到根链上在链以外的左 / 右儿子个数和深度和即可。

代码:

#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 4e5 + 3;
const int LG = 19;
int n,m,tot;
int pos[N + 5],rt;
int ch[N + 5][2],dep[N + 5],fa[N + 5],sz[N + 5],id[N + 5];
int f[LG + 5][N + 5];
int cnt[N + 5][2];
long long sum[N + 5][2];
int build(int l,int r)
{
    int p = ++tot;
    if(l == r)
        return pos[l] = p;
    int mid;
    scanf("%d",&mid);
    ch[p][0] = build(l,mid),
    ch[p][1] = build(mid + 1,r);
    return p;
}
void dfs1(int p)
{
    static int tot = 0;
    id[p] = ++tot,sz[p] = 1;
    for(register int i = 1;i <= LG;++i)
        f[i][p] = f[i - 1][f[i - 1][p]];
    for(register int i = 0;i <= 1;++i)
        if(ch[p][i])
            fa[ch[p][i]] = f[0][ch[p][i]] = p,
            dep[ch[p][i]] = dep[p] + 1,
            dfs1(ch[p][i]),
            sz[p] += sz[ch[p][i]];
}
void dfs2(int p)
{
    for(register int i = 0;i <= 1;++i)
        if(ch[p][i])
        {
            for(register int j = 0;j <= 1;++j)
                cnt[ch[p][i]][j] = cnt[p][j],
                sum[ch[p][i]][j] = sum[p][j];
            if(ch[p][i ^ 1])
                ++cnt[ch[p][i]][i ^ 1],
                sum[ch[p][i]][i ^ 1] += dep[ch[p][i ^ 1]];
            dfs2(ch[p][i]);
        }
}
inline int getlca(int x,int y)
{
    if(dep[x] < dep[y])
        swap(x,y);
    for(register int i = LG;~i;--i)
        if(dep[f[i][x]] >= dep[y])
            x = f[i][x];
    if(x == y)
        return x;
    for(register int i = LG;~i;--i)
        if(f[i][x] ^ f[i][y])
            x = f[i][x],y = f[i][y];
    return fa[x];
}
long long ans;
int main()
{
    scanf("%d",&n),build(1,n);
    rt = ++tot,
    ch[rt][0] = pos[0] = ++tot,ch[rt][1] = 1;
    rt = ++tot,
    ch[rt][0] = rt - 2,ch[rt][1] = pos[n + 1] = ++tot;
    dep[rt] = 1,dfs1(rt),dfs2(rt);
    scanf("%d",&m);
    for(int u,l,r,lca,x,ls,rs;m;--m)
    {
        scanf("%d%d%d",&u,&l,&r),
        lca = getlca(l = pos[l - 1],r = pos[r + 1]),ls = ch[lca][0],rs = ch[lca][1],
        ans = (sum[l][1] - sum[ls][1] + sum[r][0] - sum[rs][0]) + (long long)dep[u] * (cnt[l][1] - cnt[ls][1] + cnt[r][0] - cnt[rs][0]);
        if(id[u] <= id[lca] || id[u] >= id[lca] + sz[lca])
            x = getlca(u,lca),
            ans -= 2LL * dep[x] * (cnt[l][1] - cnt[ls][1] + cnt[r][0] - cnt[rs][0]);
        else if(id[u] >= id[ls] && id[u] < id[ls] + sz[ls])
            x = getlca(u,l),
            ans -= 2LL * dep[lca] * (cnt[r][0] - cnt[rs][0]),
            ans -= 2LL * dep[x] * (cnt[l][1] - cnt[x][1]),
            ans -= 2LL * ((sum[x][1] - sum[ls][1]) - (cnt[x][1] - cnt[ls][1])),
            id[u] >= id[ch[x][1]] && id[u] < id[ch[x][1]] + sz[ch[x][1]] && (ans -= 2);
        else
            x = getlca(u,r),
            ans -= 2LL * dep[lca] * (cnt[l][1] - cnt[ls][1]),
            ans -= 2LL * dep[x] * (cnt[r][0] - cnt[x][0]),
            ans -= 2LL * ((sum[x][0] - sum[rs][0]) - (cnt[x][0] - cnt[rs][0])),
            id[u] >= id[ch[x][0]] && id[u] < id[ch[x][0]] + sz[ch[x][0]] && (ans -= 2);
        printf("%lld\n",ans);
    }
}