题解 P1600 【天天爱跑步】

robin12138

2019-07-10 20:49:15

题解

博客传送门:孑行

线段树合并+树上差分

个人认为这道题线段树合并的做法是最简单的

题意:给一棵树和 m条路径,树上每个点有一个值 W_i 。问对于每一个点,询问有多少条路径的第 W_i+1个点是这个点。n,m \leqslant 1e5

假设当前路径为s->t,显然我们可以预处理出lca

我们考虑对于每一个节点建立一个权值线段树,以dep为下标,每个节点保存数值i的出现次数,其实我们只需要叶子结点上的信息

用树上差分把链的信息转化为点

现在对于s->lca的路径,只需要在s处的线段树让dep[s]+1,然后每个点统计答案时查询dep[x]+w[x]出现了几次即可即可

对于lca->t的路径,我们想办法把s关于lca翻上去,在在每个点统计dep=dep[x]-w[x]的点有几个

注意翻上去后d可能变为负的,所以要整体平移一下值域,都加上n即可(其实不平移应该也可以)

#include <cstdio>
#include <cctype>
#include <iostream>
#include <algorithm>
using namespace std;

#define rint register int
#define il inline
const int N=3e5+5;
il int read(int x=0,int f=1,char ch='0')
{
    while(!isdigit(ch=getchar())) if(ch=='-') f=-1;
    while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
    return f*x;
}
int n,m,w[N];

int head[N],ver[N<<1],nxt[N<<1],tot;
il void add(int x,int y){ ver[++tot]=y; nxt[tot]=head[x]; head[x]=tot; }

int dep[N],top[N],son[N],siz[N],fa[N];
void dfs1(int x,int ff)
{
    fa[x]=ff; dep[x]=dep[ff]+1; siz[x]=1;
    for(rint i=head[x];i;i=nxt[i])
    {
        int y=ver[i]; if(y==ff) continue;
        dfs1(y,x); siz[x]+=siz[y];
        if(siz[y]>siz[son[x]]) son[x]=y;
    }
}
void dfs2(int x,int topf)
{
    top[x]=topf;
    if(!son[x]) return ;
    dfs2(son[x],topf);
    for(rint i=head[x];i;i=nxt[i])
    {
        int y=ver[i]; if(y==fa[x]||y==son[x]) continue;
        dfs2(y,y);
    }
}
il int LCA(int x,int y)
{
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    return dep[x]>dep[y] ? y : x;
}

const int M=N*55;
int root[N],lc[M],rc[M],val[M],num;
void update(int &x,int l,int r,int v,int d)
{
    if(!x) x=++num;
    if(l==r) return (void)(val[x]+=d);
    int mid=l+r>>1;
    if(v<=mid) update(lc[x],l,mid,v,d);
    else update(rc[x],mid+1,r,v,d);
}
int query(int x,int l,int r,int p)
{
    if(!x) return 0;
    if(l==r) return val[x];
    int mid=l+r>>1;
    if(p<=mid) return query(lc[x],l,mid,p);
    else return query(rc[x],mid+1,r,p);
}
int merge(int a,int b,int l,int r)
{
    if(!a || !b) return a+b;
    if(l==r)
        val[a]+=val[b];
    else
    {
        int mid=l+r>>1;
        lc[a]=merge(lc[a],lc[b],l,mid);
        rc[a]=merge(rc[a],rc[b],mid+1,r);
    }
    return a;   
}
int ans[N];
void dfs(int x)
{
    for(rint i=head[x];i;i=nxt[i])
    {
        int y=ver[i]; if(y==fa[x]) continue;
        dfs(y); 
        root[x]=merge(root[x],root[y],1,n<<1);
    }
    if(w[x] && n+dep[x]+w[x]<=2*n)//注意要判断没有越界
        ans[x]+=query(root[x],1,n<<1,n+dep[x]+w[x]);
    ans[x]+=query(root[x],1,n<<1,n+dep[x]-w[x]);
}

int main()
{
    n=read(); m=read();
    for(rint i=1,x,y;i<n;++i) x=read(),y=read(),add(x,y),add(y,x);
    dfs1(1,0); dfs2(1,1);
    for(rint i=1;i<=n;++i) w[i]=read();
    for(rint i=1;i<=m;++i)
    {
        int x=read(),y=read(); int lca=LCA(x,y);
        update(root[x],1,n<<1,n+dep[x],1);
        update(root[y],1,n<<1,n+dep[lca]*2-dep[x],1);
        update(root[lca],1,n<<1,n+dep[x],-1);
        update(root[fa[lca]],1,n<<1,n+dep[lca]*2-dep[x],-1);
    }
    dfs(1);
    for(rint i=1;i<=n;++i)
        printf("%d ",ans[i]);
    return 0;
}