题解:CF1957F2 Frequency Mismatch (Hard Version)

· · 题解

题目大意

给你一棵 n 个节点的树,每个点有点权,q 次询问 u_1,v_1,u_2,v_2,k,要求输出 k 个权值 c,使得 cu_1v_1 以及 u_2v_2 的出现次数不同。如果 c 的个数小于 k,则全部输出。

题目分析

我们先给每个权值一个哈希值。

在树上建可持久化线段树,对于树上的每个点 x,维护根到点 x 的每个权值的出现次数的哈希值。

然后对于每次查询,求出 \text{lca} 后很容易就能维护出两条路径的哈希值的线段树。于是我们从两个根开始,比较当前节点的哈希值,如果一样说明这一段区间内所有权值的出现次数相同,直接 return;不一样则继续往左右子树走。当访问到叶子节点时还没 return,说明两条路经这个权值的出现次数不同,直接加入答案中即可。

当找到的答案数量等于 k 时,也直接 return 并输出。

时间复杂度 O(nk \log n)

需要注意的是单哈希容易被 \text{hack},建议多写几个哈希,或者也可以用一些奇奇怪怪的模数

代码

const int N = 1e5+10;
const ull P = 13131;
ull pw1[N],pw2[N],pw3[N];
int n,a[N],tot;
vector<int> v[N];
int root[N];
struct Segment{
    int lc,rc;
    ull val1,val2,val3;
}t[N<<5];
#define ls t[rt].lc
#define rs t[rt].rc
inline void push_up(int rt){
    t[rt].val1 = t[ls].val1+t[rs].val1;
    t[rt].val2 = t[ls].val2+t[rs].val2;
    t[rt].val3 = t[ls].val3+t[rs].val3;
}
int update(int rt,int l,int r,int pos){
    int u = ++tot;
    t[u] = t[rt];
    if(l == r){
        t[u].val1 += pw1[l];
        t[u].val2 += pw2[l];
        t[u].val3 += pw3[l];
        return u;
    }
    int mid = (l+r)>>1;
    if(pos <= mid) t[u].lc = update(t[u].lc,l,mid,pos);
    else           t[u].rc = update(t[u].rc,mid+1,r,pos);
    push_up(u);
    return u;
}
int fa[N][22],Lg,dep[N];
void dfs(int x,int fr){
    fa[x][0] = fr; dep[x] = dep[fr]+1;
    root[x] = update(root[fr],1,100000,a[x]);
    for(int i=1;i<=Lg;i++) fa[x][i] = fa[fa[x][i-1]][i-1];
    for(int y : v[x]){
        if(y == fr) continue;
        dfs(y,x);
    }
}
int lca(int x,int y){
    if(dep[x] < dep[y]) swap(x,y);
    for(int i=Lg;i>=0;i--)
        if(dep[fa[x][i]] >= dep[y]) x = fa[x][i];
    if(x == y) return x;
    for(int i=Lg;~i;i--)
        if(fa[x][i] != fa[y][i])
            x = fa[x][i],y = fa[y][i];
    return fa[x][0];
}
int k,ans[114514],cnt;
struct node{
    int x,y;
    int u,v;
    node lson(){return node{t[x].lc,t[y].lc,t[u].lc,t[v].lc};}
    node rson(){return node{t[x].rc,t[y].rc,t[u].rc,t[v].rc};}
    ull val1(){return t[x].val1+t[y].val1-t[u].val1-t[v].val1;}
    ull val2(){return t[x].val2+t[y].val2-t[u].val2-t[v].val2;}
    ull val3(){return t[x].val3+t[y].val3-t[u].val3-t[v].val3;}
};
void qry(node x,node y,int l,int r){
    if(cnt == k || (x.val1() == y.val1() && x.val2() == y.val2() && x.val3() == y.val3())) return ;
    if(l == r){ans[++cnt] = l; return ;}
    int mid = (l+r)>>1;
    qry(x.lson(),y.lson(),l,mid);
    qry(x.rson(),y.rson(),mid+1,r);
}
int main(){
    n = read();
    pw1[0] = pw2[0] = pw3[0] = 1; Lg = log(n)/log(2)+1;
    srand(time(0));
    for(int i=1;i<=100000;i++) pw1[i] = pw1[i-1]*P;
    for(int i=1;i<=100000;i++) pw2[i] = pw2[i-1]*131;
    for(int i=1;i<=100000;i++) pw3[i] = rand()*rand();
    for(int i=1;i<=n;i++) a[i] = read();
    for(int i=1;i<n;i++){
        int x = read(),y = read();
        v[x].push_back(y);
        v[y].push_back(x);
    }
    dfs(1,0);
    int Q = read();
    while(Q--){
        int x = read(),y = read(),xx = read(),yy = read();
        k = read(); cnt = 0;
        int d1 = lca(x,y),d2 = lca(xx,yy);
        node A = node{root[x],root[y],root[d1],root[fa[d1][0]]},
             B = node{root[xx],root[yy],root[d2],root[fa[d2][0]]};
        qry(A,B,1,100000);
        printf("%d ",cnt);
        for(int i=1;i<=cnt;i++) printf("%d ",ans[i]);
        printf("\n");
    }
    return 0; 
}