robin12138
2019-07-10 20:49:15
博客传送门:孑行
个人认为这道题线段树合并的做法是最简单的
题意:给一棵树和 m条路径,树上每个点有一个值
假设当前路径为
我们考虑对于每一个节点建立一个权值线段树,以
用树上差分把链的信息转化为点
现在对于
对于
注意翻上去后d可能变为负的,所以要整体平移一下值域,都加上n即可(其实不平移应该也可以)
#include <cstdio>
#include <cctype>
#include <iostream>
#include <algorithm>
using namespace std;
#define rint register int
#define il inline
const int N=3e5+5;
il int read(int x=0,int f=1,char ch='0')
{
while(!isdigit(ch=getchar())) if(ch=='-') f=-1;
while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
return f*x;
}
int n,m,w[N];
int head[N],ver[N<<1],nxt[N<<1],tot;
il void add(int x,int y){ ver[++tot]=y; nxt[tot]=head[x]; head[x]=tot; }
int dep[N],top[N],son[N],siz[N],fa[N];
void dfs1(int x,int ff)
{
fa[x]=ff; dep[x]=dep[ff]+1; siz[x]=1;
for(rint i=head[x];i;i=nxt[i])
{
int y=ver[i]; if(y==ff) continue;
dfs1(y,x); siz[x]+=siz[y];
if(siz[y]>siz[son[x]]) son[x]=y;
}
}
void dfs2(int x,int topf)
{
top[x]=topf;
if(!son[x]) return ;
dfs2(son[x],topf);
for(rint i=head[x];i;i=nxt[i])
{
int y=ver[i]; if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
il int LCA(int x,int y)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
return dep[x]>dep[y] ? y : x;
}
const int M=N*55;
int root[N],lc[M],rc[M],val[M],num;
void update(int &x,int l,int r,int v,int d)
{
if(!x) x=++num;
if(l==r) return (void)(val[x]+=d);
int mid=l+r>>1;
if(v<=mid) update(lc[x],l,mid,v,d);
else update(rc[x],mid+1,r,v,d);
}
int query(int x,int l,int r,int p)
{
if(!x) return 0;
if(l==r) return val[x];
int mid=l+r>>1;
if(p<=mid) return query(lc[x],l,mid,p);
else return query(rc[x],mid+1,r,p);
}
int merge(int a,int b,int l,int r)
{
if(!a || !b) return a+b;
if(l==r)
val[a]+=val[b];
else
{
int mid=l+r>>1;
lc[a]=merge(lc[a],lc[b],l,mid);
rc[a]=merge(rc[a],rc[b],mid+1,r);
}
return a;
}
int ans[N];
void dfs(int x)
{
for(rint i=head[x];i;i=nxt[i])
{
int y=ver[i]; if(y==fa[x]) continue;
dfs(y);
root[x]=merge(root[x],root[y],1,n<<1);
}
if(w[x] && n+dep[x]+w[x]<=2*n)//注意要判断没有越界
ans[x]+=query(root[x],1,n<<1,n+dep[x]+w[x]);
ans[x]+=query(root[x],1,n<<1,n+dep[x]-w[x]);
}
int main()
{
n=read(); m=read();
for(rint i=1,x,y;i<n;++i) x=read(),y=read(),add(x,y),add(y,x);
dfs1(1,0); dfs2(1,1);
for(rint i=1;i<=n;++i) w[i]=read();
for(rint i=1;i<=m;++i)
{
int x=read(),y=read(); int lca=LCA(x,y);
update(root[x],1,n<<1,n+dep[x],1);
update(root[y],1,n<<1,n+dep[lca]*2-dep[x],1);
update(root[lca],1,n<<1,n+dep[x],-1);
update(root[fa[lca]],1,n<<1,n+dep[lca]*2-dep[x],-1);
}
dfs(1);
for(rint i=1;i<=n;++i)
printf("%d ",ans[i]);
return 0;
}