P6782 [Ynoi2008] rplexq 题解

· · 题解

题目让求的值即为:

( x 子树内 [l,r] 点个数的平方 - x 每个儿子子树内 [l,r] 点个数的平方) /2

首先根号分治,将每个点按儿子个数大于 \sqrt{n} 和小于 \sqrt{n} 分成两组。

先考虑小于 \sqrt{n} 的,我们需要求出每个点的每个儿子子树内满足条件的点,顺便求出子树内满足条件的点。 即在dfs序与取值上进行二维数点+扫描线,但一共有m个询问,每个点最多有 \sqrt{n} 个儿子,如果全部存需要 O(m\sqrt{n}) ,空间不够,发现同一个点上每个询问的询问区间相同,则可以 O(n) 的存下每个点的每个子树区间进行扫描线,扫描到一点x再将它的询问全部进行二维数点。

大于 \sqrt{n} 首先用与上面一样的二维数点求出点x子树内所有满足条件的点, 然后考虑怎么求每个儿子子树内合法的点,如果直接求复杂度肯定上天,所以考虑再平衡一下复杂度。

我在这里对于 size 小于 \sqrt{n} 的儿子将他们的子树内所有点全部提出成一个序列,然后将询问在该序列上离散化,在上面跑莫队,对 size 大于 \sqrt{n} 的儿子维护出一个块状前缀和数组,上面维护出每个值出现的个数,每次 O(1) 统计答案,对于边角块再维护一个每个值在哪里出现的数组,统一一次 O(\sqrt{n}) 扫完,这里的复杂度总体是 O(m\sqrt{n}) 的。

考虑如何维护这个块状数组,我用的是类似于Dsu on tree的东西,在该点子树大小大于 \sqrt{n} 时开始上传,上传的时候选择最大的子树信息继承,如果最大的子树大小小于 \sqrt{n} 就大力dfs一遍统计,因为前面已经根号分治过了,统计答案的点 size 一定大于 \sqrt{n} ,所以一个点统计完答案后就会维护出来这个数组,轻边不会被莫队重复统计,空间复杂度是 O(n) ,不会分析这个时间复杂度,感觉这里类应该是 O(n\sqrt{n}) ?

代码(变量不够用了所以变量名很丑):

using namespace std;
typedef long long ll;
typedef unsigned long long ull;
#define pb push_back
int read(){
    char c=getchar();ll x=1,s=0;
    while(c<'0'||c>'9'){if(c=='-')x=-1;c=getchar();}
    while(c>='0'&&c<='9'){s=s*10+c-'0';c=getchar();}
    return s*x;
}
const int N=205;
const int B=2;
int n,m,rt,x,y,z,dfn[N],pos[N],ed[N],id[N],bh[N],siz[N],fu[N],zs[N],top,tot;
ll ans[N],ls[N];
vi v[N],cd[N],jian[N],jia[N];
vector<pair<int,int>>q[N];
int cmp(int x,int y){
    if(id[dfn[x]]!=id[dfn[y]]){
        return id[dfn[x]]<id[dfn[y]];
    }
    else{
        if(id[dfn[x]]&1)return ed[x]<ed[y];
        else return ed[x]>ed[y];
    }
}
void dfs(int u,int fa){
    dfn[u]=++tot;pos[tot]=u;fu[u]=fa;siz[u]=1;
    for(int i=0;i<v[u].size();i++){
        if(v[u][i]==fa)continue;
        dfs(v[u][i],u);siz[u]+=siz[v[u][i]];
    }
    ed[u]=tot;
}
int kuai[N/B+5],ss[N],you[N/B+5];
void add(int u,int x){
    for(int i=u;i<=you[id[u]];i++)ss[i]+=x;
    for(int i=id[u]+1;i<=id[n];i++)kuai[i]+=x;
}
int query(int x){
    return ss[x]+kuai[id[x]];
}
int vis[N],son[N/B+5][N/B+5],mix[N/B+5][N/B+5],wz[N],nx[N],fz[N],jsq,tou,head,bj[N/B+10];
struct Query{
    int l,r,id;
}p[N];
struct node{
    int col,id;
    friend bool operator < (node x,node y){
        return x.id<y.id;
    }
}w[N];
int cmpp(Query x,Query y){
    if(id[x.l]!=id[y.l])return id[x.l]<id[y.l];
    else{
        if(id[x.l]&1)return x.r<y.r;
        else return x.r>y.r;
    }
}
int tong[N];
ll summ,f[N/B+10],dt[N];
void add(int u){
    summ-=1ll*tong[w[u].col]*tong[w[u].col];
    tong[w[u].col]++;
    summ+=1ll*tong[w[u].col]*tong[w[u].col];
}
void del(int u){
    summ-=1ll*tong[w[u].col]*tong[w[u].col];
    tong[w[u].col]--;
    summ+=1ll*tong[w[u].col]*tong[w[u].col];
}
vector<int>ms[N];
void solve(int u,int fa){
    int maxx=0,jl;
    for(int i=0;i<v[u].size();i++){
        if(v[u][i]==fa)continue;
        solve(v[u][i],u);
        if(siz[v[u][i]]>B)ms[u].pb(v[u][i]);
        if(siz[v[u][i]]>maxx)maxx=siz[v[u][i]],jl=v[u][i];
    }
    if(v[u].size()>B){
        if(q[u].size()){
            tou=0;head=q[u].size();
            memset(kuai,0,sizeof(kuai));
            memset(ss,0,sizeof(ss));
            memset(tong,0,sizeof(tong));
            memset(f,0,sizeof(f));
            for(int i=0;i<v[u].size();i++){
                if(v[u][i]==fa)continue;
                if(siz[v[u][i]]>B)continue;
                for(int k=dfn[v[u][i]];k<=ed[v[u][i]];k++){
                    w[++tou].col=v[u][i];
                    w[tou].id=pos[k];
                }
            }
            sort(w+1,w+tou+1);
            for(int i=1;i<=tou;i++)fz[i]=w[i].id;
            head=0;
            int he1=0,he2=0;
            for(int i=0;i<q[u].size();i++){
                int ll=lower_bound(fz+1,fz+tou+1,q[u][i].fi)-fz;
                int rr=upper_bound(fz+1,fz+tou+1,q[u][i].se)-fz -1;
                if(ll and ll<=rr){
                    p[++head].l=ll;
                    p[head].r=rr;
                    p[head].id=i;
                }
            }
            sort(p+1,p+head+1,cmpp);
            int l=1,r=0;summ=0;
            for(int i=1;i<=head;i++){
                while(r<p[i].r)add(++r);
                while(l>p[i].l)add(--l);
                while(r>p[i].r)del(r--);
                while(l<p[i].l)del(l++);
                ans[cd[u][p[i].id]]-=summ;
            }
            for(int i=0;i<q[u].size();i++){
                memset(f,0,sizeof(f));
                int stt=id[q[u][i].fi];
                int edd=id[q[u][i].se];
                for(int k=q[u][i].fi;k<=min(you[stt],q[u][i].se);k++)f[vis[k]]++;
                if(stt!=edd)for(int k=you[edd-1]+1;k<=q[u][i].se;k++)f[vis[k]]++;
                for(int k=0;k<ms[u].size();k++){
                    if(siz[ms[u][k]]>B){
                        int mz=nx[ms[u][k]];
                        if(edd!=stt)f[mz]+=son[mz][edd-1]-son[mz][stt];
                        ans[cd[u][i]]-=1ll*f[mz]*f[mz];
                    }
                }
            }
        }
    }
    if(siz[u]>B){
        if(maxx<=B){
            nx[u]=++jsq;
            for(int k=dfn[u];k<=ed[u];k++){
                mix[nx[u]][id[pos[k]]]++;
                vis[pos[k]]=nx[u];
            }
            son[nx[u]][1]=mix[nx[u]][1];
            for(int i=2;i<=id[n];i++)son[nx[u]][i]=son[nx[u]][i-1]+mix[nx[u]][i];
        }
        else{
            nx[u]=nx[jl];
            for(int k=dfn[u];k<dfn[jl];k++){
                mix[nx[u]][id[pos[k]]]++;
                vis[pos[k]]=nx[u];
            }
            for(int k=ed[jl]+1;k<=ed[u];k++){
                mix[nx[u]][id[pos[k]]]++;
                vis[pos[k]]=nx[u];
            }
            son[nx[u]][1]=mix[nx[u]][1];
            for(int i=2;i<=id[n];i++)son[nx[u]][i]=son[nx[u]][i-1]+mix[nx[u]][i];
        }
    }
}
int pre[N];
int main(){
    n=read();m=read();rt=read();
    for(int i=1;i<=n;i++)id[i]=((i-1)/B+1),you[id[i]]=i,bh[i]=i;
    for(int i=1;i<n;i++){
        x=read();y=read();
        v[x].pb(y);
        v[y].pb(x);
    }
    dfs(rt,0);
    for(int i=1;i<=m;i++){
        x=read();y=read();z=read();
        q[z].pb(make_pair(x,y));cd[z].pb(i);
    }
    for(int i=1;i<=n;i++){
        if(v[i].size()<=B){
            for(int k=0;k<v[i].size();k++){
                if(v[i][k]==fu[i])continue;
                jian[dfn[v[i][k]]-1].pb(i);
                jia[ed[v[i][k]]].pb(i);
            }
        }
        else{
            jian[dfn[i]].pb(i);
            jia[ed[i]].pb(i);
        }
    }
    for(int i=1;i<=n;i++){
        int u=pos[i];
        add(u,1);
        for(int k=0;k<jia[i].size();k++){
            int vv=jia[i][k];
            if(v[vv].size()>B){
                for(int j=0;j<q[vv].size();j++){
                    int x=query(q[vv][j].se)-query(q[vv][j].fi-1);
                    ls[cd[vv][j]]+=x;
                }
                continue;
            }
            for(int j=0;j<q[vv].size();j++){
                int x=query(q[vv][j].se)-query(q[vv][j].fi-1);
                int mb=cd[vv][j];
                ans[mb]-=1ll*(x-pre[mb])*(x-pre[mb]);
                ls[mb]+=x-pre[mb];
            }
        }
        for(int k=0;k<jian[i].size();k++){
            int vv=jian[i][k];
            if(v[vv].size()>B){
                for(int j=0;j<q[vv].size();j++){
                    int x=query(q[vv][j].se)-query(q[vv][j].fi-1);
                    ls[cd[vv][j]]-=x;
                }
                continue;
            }
            for(int j=0;j<q[vv].size();j++){
                int x=query(q[vv][j].se)-query(q[vv][j].fi-1);
                int mb=cd[vv][j];
                pre[mb]=x;
            }
        }
    }
    for(int i=1;i<=n;i++){
        if(v[i].size()<=B){
            for(int k=0;k<q[i].size();k++){
                if(i>=q[i][k].fi and i<=q[i][k].se)zs[cd[i][k]]+=ls[cd[i][k]];
            }
        }
        else{
            for(int k=0;k<q[i].size();k++){
                if(i>=q[i][k].fi and i<=q[i][k].se)zs[cd[i][k]]+=ls[cd[i][k]];
            }
        }
    }
    solve(rt,0);
    for(int i=1;i<=m;i++)printf("%lld\n",(ans[i]+ls[i]*ls[i])/2ll+1ll*zs[i]);
    return 0;
}