题解 P5210 【[ZJOI2017]线段树】
这题能黑了是因为洛谷仅有的题解都是些滥用数据结构的解法么……(
考虑一棵广义线段树的形态,和定位区间所得的结点的性质。
对于定位区间
注意到定位出来的区间都是
那么问题便是算
不难想到转化为深度和减去
显然地,可以离线下来,暴力地使用数据结构
实际上没必要,因为这些结点具有比较优美的性质,所以可以通过分类讨论解决。
若
若
注意特判
若
线性预处理出每个结点到根链上在链以外的左 / 右儿子个数和深度和即可。
代码:
#include <cstdio>
#include <algorithm>
using namespace std;
const int N = 4e5 + 3;
const int LG = 19;
int n,m,tot;
int pos[N + 5],rt;
int ch[N + 5][2],dep[N + 5],fa[N + 5],sz[N + 5],id[N + 5];
int f[LG + 5][N + 5];
int cnt[N + 5][2];
long long sum[N + 5][2];
int build(int l,int r)
{
int p = ++tot;
if(l == r)
return pos[l] = p;
int mid;
scanf("%d",&mid);
ch[p][0] = build(l,mid),
ch[p][1] = build(mid + 1,r);
return p;
}
void dfs1(int p)
{
static int tot = 0;
id[p] = ++tot,sz[p] = 1;
for(register int i = 1;i <= LG;++i)
f[i][p] = f[i - 1][f[i - 1][p]];
for(register int i = 0;i <= 1;++i)
if(ch[p][i])
fa[ch[p][i]] = f[0][ch[p][i]] = p,
dep[ch[p][i]] = dep[p] + 1,
dfs1(ch[p][i]),
sz[p] += sz[ch[p][i]];
}
void dfs2(int p)
{
for(register int i = 0;i <= 1;++i)
if(ch[p][i])
{
for(register int j = 0;j <= 1;++j)
cnt[ch[p][i]][j] = cnt[p][j],
sum[ch[p][i]][j] = sum[p][j];
if(ch[p][i ^ 1])
++cnt[ch[p][i]][i ^ 1],
sum[ch[p][i]][i ^ 1] += dep[ch[p][i ^ 1]];
dfs2(ch[p][i]);
}
}
inline int getlca(int x,int y)
{
if(dep[x] < dep[y])
swap(x,y);
for(register int i = LG;~i;--i)
if(dep[f[i][x]] >= dep[y])
x = f[i][x];
if(x == y)
return x;
for(register int i = LG;~i;--i)
if(f[i][x] ^ f[i][y])
x = f[i][x],y = f[i][y];
return fa[x];
}
long long ans;
int main()
{
scanf("%d",&n),build(1,n);
rt = ++tot,
ch[rt][0] = pos[0] = ++tot,ch[rt][1] = 1;
rt = ++tot,
ch[rt][0] = rt - 2,ch[rt][1] = pos[n + 1] = ++tot;
dep[rt] = 1,dfs1(rt),dfs2(rt);
scanf("%d",&m);
for(int u,l,r,lca,x,ls,rs;m;--m)
{
scanf("%d%d%d",&u,&l,&r),
lca = getlca(l = pos[l - 1],r = pos[r + 1]),ls = ch[lca][0],rs = ch[lca][1],
ans = (sum[l][1] - sum[ls][1] + sum[r][0] - sum[rs][0]) + (long long)dep[u] * (cnt[l][1] - cnt[ls][1] + cnt[r][0] - cnt[rs][0]);
if(id[u] <= id[lca] || id[u] >= id[lca] + sz[lca])
x = getlca(u,lca),
ans -= 2LL * dep[x] * (cnt[l][1] - cnt[ls][1] + cnt[r][0] - cnt[rs][0]);
else if(id[u] >= id[ls] && id[u] < id[ls] + sz[ls])
x = getlca(u,l),
ans -= 2LL * dep[lca] * (cnt[r][0] - cnt[rs][0]),
ans -= 2LL * dep[x] * (cnt[l][1] - cnt[x][1]),
ans -= 2LL * ((sum[x][1] - sum[ls][1]) - (cnt[x][1] - cnt[ls][1])),
id[u] >= id[ch[x][1]] && id[u] < id[ch[x][1]] + sz[ch[x][1]] && (ans -= 2);
else
x = getlca(u,r),
ans -= 2LL * dep[lca] * (cnt[l][1] - cnt[ls][1]),
ans -= 2LL * dep[x] * (cnt[r][0] - cnt[x][0]),
ans -= 2LL * ((sum[x][0] - sum[rs][0]) - (cnt[x][0] - cnt[rs][0])),
id[u] >= id[ch[x][0]] && id[u] < id[ch[x][0]] + sz[ch[x][0]] && (ans -= 2);
printf("%lld\n",ans);
}
}