暴力美学——浅谈根号分治

RedLycoris

2022-09-25 16:12:57

算法·理论

根号分治,是暴力美学的集大成体现。与其说是一种算法,我们不如称它为一个常用的trick。

首先,我们引入一道入门题目 CF1207F Remainder Problem:

给你一个长度为 5\times10^5 的序列,初值为 0 ,你要完成 q 次操作,操作有如下两种:

  1. 1 x y: 将下标为 x 的位置的值加上 y
  2. 2 x y: 询问所有下标模 x 的结果为 y 的位置的值之和。

考虑这题的暴力是什么。

首先有一种暴力就是按照题目所说的去做,开一个 5\times10^5 大小的数组 a 去存,1 操作就对 a_x 加上 y2 操作就枚举所有下标模 x 的结果为 y 的位置,统计他们的和。

对于这种暴力,1 操作的时间复杂度为 O(1)2 操作的时间复杂度为 O(n),所以在最坏情况下总时间复杂度可达 O(nq)

经过思考,我们可以发现另外一种暴力:新开一个大小为 n\times n 的二维数组 bb_{i,j} 当前所有下标模 i 的结果为 j 的数的和是什么。对于每个 1 操作,动态的去维护这个 b 数组,在每次询问的时候直接输出答案即可。

对于这种暴力,1 操作的时间复杂度是枚举模数的 O(n)2 操作的时间复杂度为 O(1),总的时间复杂度为 O(nq)

现在我们发现,这两种暴力对应了两种极端:一个是 1 操作的时间复杂度为 O(1)2 操作的时间复杂度为 O(n);另一个是 1 操作的时间复杂度是枚举模数的 O(n)2 操作的时间复杂度为 O(1)。那么,有没有办法让这两种暴力融合一下,均摊时间复杂度,达到一个平衡呢?

其实是有的。我们设定一个阈值 b

对于所有 \le b 的数,我们动态的维护暴力 2b 数组。每次 1 操作只需要枚举 b 个模数即可,故单次操作 1 的时间复杂度降为 O(b)

对于所有 >b 的数,我们就不在操作 1 中维护 b,直接再询问答案时暴力枚举下标即可。显然,这 n 个下标中最多有 \lceil \frac{n}{b}\rceil 个下标对 x 取模余 y 找到第一个 y 后每次跳 x,即可做到单次操作 2 时间复杂度为 O(\frac{n}{b})

所以,总时间复杂度就成为了 O(q\times(b+\frac{n}{b}))。由基本不等式可得,b+\frac{n}{b} \geq 2\sqrt{b\times\frac{n}{b}}=2\sqrt{n},当 b=\sqrt{n} 时取等。所以我们只需要让 b=\sqrt{n},就可以做到时间和空间复杂度均为 O(q\sqrt{n}) 的优秀算法了,可以通过此题。

代码:

#include<bits/stdc++.h>
#define ll long long
#define mp make_pair
using namespace std;
ll sum[755][755],a[500005];  //这里我的阈值取了700
int main(){
    ios_base::sync_with_stdio(false);cin.tie(0),cout.tie(0);
    int q;cin>>q;
    for(;q--;){
        int tp,x,y;cin>>tp>>x>>y;
        if(tp==1){  
            for(int i=1;i<700;++i)sum[i][x%i]+=y;   //枚举模数
            a[x]+=y;
        }else{
            if(x<700){
                cout<<sum[x][y]<<endl;
            }else{
                ll rt=0;
                for(int i=y;i<=500000;i+=x)rt+=a[i];  //暴力统计
                cout<<rt<<endl;
            }
        }
    }
}

通过一道例题,我们可以发现,根号分治不就是两个暴力揉到一起嘛!对于小于一个阈值的数据我们采取一种暴力形式,对于大于阈值的数据我们采用另外一种形式。通常,这两种暴力,一种是直接模拟,另外一种是动态的维护什么东西,或者预处理。暴力与暴力的完美结合,就能过题,有时候甚至可以爆踩标算。

根号分治的用处很广,不仅可以用于数据结构,也可以用在其他方面。现在由笔者继续介绍几道例题。

1.CF710D Two Arithmetic Progressions

题目大意:

现在有两个等差数列,形如 a_1k+b_1a_2k+b_2,其中 k 要满足是自然数。现在再给你两个正整数 l,r,求出 [l,r] 间有多少个数同时出现在两个等差数列中。数据满足 0<a_1,b_1\le2\times10^9,-2\times10*9\le b_1,b_2,l,r\le 2\times10^9,l\le r

题解:

正解要用到 exgcd 等数论知识,且细节较多比较麻烦。现在我们考虑如何用根号分治解决该数论问题。

现在钦定 a_1\geq a_2,再令 t=\sqrt{2\times 10^9}

  1. a_1>t。那么有 \frac{2\times10^9}{a_1}\le t。也就是说,在 [l,r] 这段区间内,属于等差数列 1 的数不会超过根号个。我们只需要枚举这个根号个数,依次判断其是否在等差数列 2 中即可。

代码:

#include<bits/stdc++.h>
#define ll long long
using namespace std;
ll a,b,c,d,l,r;
int main(){
    ios_base::sync_with_stdio(false);
    cin>>a>>b>>c>>d>>l>>r;
    ll m=max(a,c);
    if(m<=50000){     //存在循环节,找到第一个就行了
        for(int i=-m*2;i<=m*2;++i){
            ll p=i*a+b;
            if((abs(d-p)%c==0)){
                ll fg=__gcd(a,c);
                fg=a*c/fg;
                ll bg=max(max(b,d),l);
                if(p>bg){
                    p=bg+(p-bg)%fg;
                }else{
                    p+=((bg-p)/fg+1)*fg;
                    p=bg+(p-bg)%fg;
                }
                if(r<p)continue;
                cout<<(r-p)/fg+1<<'\n';
                return 0;
            }
        }
        cout<<0<<'\n';
        return 0;
    }else{
        if(a<c)swap(a,c),swap(b,d);
        ll cnt=0;
        for(ll i=b;i<=r;i+=a){     //直接暴力去找
            if(i>=l and i>=d){
                ll del=i-d;
                if(del%c==0)++cnt;
            }
        }
        cout<<cnt<<endl;;
    }
}

2.洛谷P8250 交友问题

题目大意:

洛谷上有 n 位用户,这些用户组成了一个双向的网络。

洛谷的图片分享机制如下:如果第 i 个用户向他的好友 j 分享了一张照片,那么,j 的所有好友 k 就能看到这张照片。j 也可以看到这张照片。

现在,用户 u_i 想分享一张照片,但是TA不想让用户 v_i 看到这张照片。在不发送给自己的情况下,TA想知道,他最多可以发送给多少位好友?

对于 100\% 的数据,满足 1 \le n,q \le 2\times10^51\le m \le 7\times 10^5

题解:

这题看上去就很不可做的样子,没有什么优秀的 \log 算法可以解决。

我们注意到这题其实还有一个限制:所有点的度数之和为 2m

这句话看起来是废话,但 某个值的总数量一定 是根号分治的一个明显的提示条件。

所以我们可以这样分治:

暴力 1

利用 map 存储所有的出边,每次询问时暴力枚举所有邻居并在 map 中判断是否出现过。

时间复杂度: O(n^2\log n)

暴力 2

对于每一个节点都开一个 bitset 存储邻接点。(设为 b_i)

查询用 b_u \ xor \ (b_u \ and \ b_v) 即可。xor 为按位异或,and 为按位与。(事实上就是在 b_u 的所有邻居中除掉了所有既在 b_u 中又在 b_v 中的点)

时间&空间复杂度: O(\frac{n^2}{\omega})

正解:

考虑将 暴力 1 和 暴力 2 结合起来。

我们设定一个阈值 limit,并将所有度数 \geq limit 的点称为重点,所有度数 <limit 的点称为轻点。

显然,由于总度数是 O(m) 的,所以重点的个数是 O(\frac{m}{limit}) 级别的。

查询的时候分类讨论(注意 uv 不是等效的,所以要分 4 类而不是 3 类):

  1. 重点对重点。可以完全利用暴力 2 ,单次询问时间复杂度为 O(\frac{n}{\omega})
  2. 重点对轻点。枚举轻点的所有邻居,并让初始答案等于 u 的邻居数。如果该邻居是 u 的邻居则不合法,就让答案减一。单次询问时间复杂度 O(limit)
  3. 轻点对重点。枚举轻点的所有邻居,如果其不在 v 的 bitset 中则让答案加一。单次询问时间复杂度 O(limit)
  4. 轻点对轻点。枚举 u 的所有邻居,在 v 的 map 中判断是否出现。单次询问时间复杂度 O(limit\times \log n)

通过对 limit 的调整进行复杂度的均摊,可以做到时间&空间复杂度为 O(n\times\sqrt{\frac{n\log n}{\omega}})

Code:

#include<bits/stdc++.h>
#define ll long long
#define reg register
#define mp make_pair
#define ri register int
#define ld long double
using namespace std;
const int mxn=2e5+5;
vector<int>g[mxn];
int n,m,q,lim;
vector<int>heavy;
int id[mxn];  //存是否为重点,并存是第几个重点
vector<bitset<mxn> >T;
map<int,int>have[mxn];
int deg[mxn];
inline void solve(){
    scanf("%d%d%d",&n,&m,&q);
    for(ri i=1,u,v;i<=m;++i){
        scanf("%d%d",&u,&v);
        g[u].push_back(v);
        g[v].push_back(u);
        ++deg[u];
        ++deg[v];
    }
    sort(deg+1,deg+n+1);reverse(deg+1,deg+n+1);
    lim=deg[min(n,(int)(sqrt(n)*6))];//根号处分治
    memset(id,-1,sizeof(id));
    for(ri i=1;i<=n;++i){
        if(g[i].size()>lim){
            heavy.push_back(i);
            id[i]=heavy.size()-1;bitset<mxn>newb;
            newb&=0;
            for(ri j=0;j<g[i].size();++j)newb[g[i][j]]=1;
            T.push_back(newb);
        }else{
            for(ri j=0;j<g[i].size();++j)have[i][g[i][j]]=1;
        }
    }
    for(ri i=1;i<=q;++i){
        ri u,v;scanf("%d%d",&u,&v);
        if(~id[u] and ~id[v]){  //对两个点是否是重点分类讨论
            bitset<mxn>tmp=T[id[u]]^(T[id[u]]&T[id[v]]);
            tmp[u]=0;tmp[v]=0;
            printf("%d\n",tmp.count());
        }else if(~id[u]){
            ri ans=0;
            for(ri j=0;j<g[v].size();++j)if(g[v][j]==u or g[v][j]==v or T[id[u]][g[v][j]])++ans;
            printf("%d\n",T[id[u]].count()-ans);
        }else if(~id[v]){
            ri ans=0;
            for(ri j=0;j<g[u].size();++j)if(g[u][j]!=u and g[u][j]!=v and !T[id[v]][g[u][j]])++ans;
            printf("%d\n",ans);
        }else{
            ri ans=0;
            for(ri j=0;j<g[u].size();++j)if(g[u][j]!=u and g[u][j]!=v and !have[v][g[u][j]])++ans;
            printf("%d\n",ans); 
        }
    }
}
int main(){
    solve();
    return 0;
}

3.洛谷P5901 [IOI2009] regions

题目大意:

联合国区域发展委员会(The United Nations Regional Development Agency,UNRDA)有一个良好的组织结构。

它任用了 N 名委员,每名委员都属于几个地区中的一个。

委员们按照其资历被编号为 1N1 号委员是主席,资历最高。委员所属地区被编号为 1R

除了主席之外所有委员都有一个直接导师。任何直接导师的资历都比他所指导的委员的资历要高。

我们说委员 A 是委员 B 的导师当且仅当 AB 的直接导师或者 AB 的直接导师的导师。显然,主席是所有其他委员的导师,没有任何两名委员互为导师。

现在,联合国区域发展委员会想要建立一个计算机系统:在给定委员之间的直接导师关系的情况下,该系统可以自动地回答下述形式的问题:给定两个地区 r_1r_2,要求系统回答委员会中有多少对委员 e_1e_2,满足 e_1 属于 r_1,而 e_2 属于 r_2,并且 e_1e_2 的导师。

对于 100\% 的数据,1 \le N \le 2 \times 10^51 \le R \le 2.5 \times 10^41 \le Q \le 2 \times 10^5

题解:

看到每个点都有颜色,然后询问和颜色的种类有关,时限开了很【】的 8s,就可以往根号分治方面想了。

按照常理,我们钦定一个 B,所有颜色为 i 的点的集合为 col_i

如果 |col_x| \le B|col_y|<B,则他们都是小块,可以直接建立出虚树暴力跑,单次询问时间复杂度 O(\sqrt{n})

如果 |col_x| > B|col_y| \le B,那么我们可以进行预处理。对于所有的大块(设其颜色为 c),我们处理处其 up 数组,up_i 表示在第 i 个节点到根结点的路上,有多少个颜色为 c 的节点,然后在询问的时候直接计算 \sum\limits_{v \in col_y}up_v 即可。预处理时间复杂度 O(n\sqrt{n}),单次询问时间复杂度 O(\sqrt{n})

同理,如果 |col_x| \le B|col_y| > B,那么我们可以预处理 dw_i 表示第 i 个节点的子树中有多少个节点的颜色为 c,询问时计算 \sum\limits_{v \in col_x}dw_v。时间复杂度同上。

最后,如果 |col_x| > B|col_y| > B,那么和小块对小块一样,可以预处理,对于所有可能的大块对建出虚树算一遍答案,询问时直接调用即可。时间复杂度 O(n\sqrt{n})

以上是会根号分治的人都可以想到的做法,但如果实现不精细的话,时间复杂度会写成 O(n\sqrt{n\log n}) 且常数巨大,导致你可能只有 60pts 的好成绩。

这就是很多根号分治要面对的问题:卡 常

针对这题可以如此实现:

  1. 普通的倍增LCA的询问是 O(\log n) 的,我们要用 ST 表做到 O(n\log n) 预处理,O(1) 求 LCA。
  2. 在建立虚树的时候,有一步要对所有点按照 dfs 序排序。如果不想写基数排序怎么办?可以在询问前先对每个 col_i 中的元素先按照字典序排序,在询问的时候用类似归并排序的方法 merge 两个有序的 col_xcol_y,就可以把这个 \log n 砍掉了。

现在的时间复杂度已经降到了正好的 O(n\sqrt{n}),但是巨大的常数仍让你在 94pts 和 97pts 之间徘徊。

为了减小常数:

  1. 加入快读快输
  2. 小对小和大对大中暴力统计答案的 dfs 看起来很累赘,那么短。如果我们能够想办法把这一步放入建立虚树的过程中,那么就可以避免存虚树的边,从而大幅减小常数(感谢 w23c3c3 的指导)

代码:

const int mxn=2e5+5;
vector<int>g[mxn];
int n,R,q,r[mxn];
int dep[mxn*2];
int ord[mxn*2],cco,cord[mxn*2];
int mi[mxn*2][22];
int lg[mxn*2],co;
int st[mxn*2],ed[mxn*2];
inline void dfs(int x,int par=0,int deep=1){
    cord[++cco]=x;
    st[x]=cco;ed[x]=cco;
    dep[x]=deep;ord[x]=++co;
    for(int y:g[x]){
        if(par==y)continue;
        dfs(y,x,deep+1);
        cord[++cco]=x;
        ed[x]=cco;
    }
}
inline void init(){     //预处理ST表
    lg[1]=0;
    for(int i=2;i<mxn*2;++i)lg[i]=lg[i>>1]+1;
    for(int i=1;i<=cco;++i)mi[i][0]=cord[i];
    for(int k=1;k<22;++k){
        for(int i=1;i<=cco-(1<<k)+1;++i){
            int x=mi[i][k-1],y=mi[i+(1<<(k-1))][k-1];
            if(dep[x]<dep[y])mi[i][k]=x;
            else mi[i][k]=y;
        }
    }
}
inline int lca(int x,int y){
    int fx=st[x],fy=ed[y];
    if(fx>fy)swap(fx,fy);
    int ax=mi[fx][lg[fy-fx]],ay=mi[fy-(1<<lg[fy-fx])+1][lg[fy-fx]];
    if(dep[ax]<dep[ay])return ax;
    else return ay;
} 
inline bool cmp(int x,int y){return ord[x]<ord[y];}
int sta[mxn],top=0;
int sz[mxn];
int up[404][mxn],dw[404][mxn];
inline vector<int>mer(vector<int>x,vector<int>y){  //类似归并的操作
    vector<int>rt;rt.clear();
    int i=0,j=0;
    for(;i<x.size() and j<y.size();)
        if(cmp(x[i],y[j]))rt.push_back(x[i]),++i;
        else rt.push_back(y[j]),++j;
    for(;i<x.size();++i)rt.push_back(x[i]);
    for(;j<y.size();++j)rt.push_back(y[j]);
    return rt;
}
inline void gen(vector<int>v,int y){  //建立虚树
    top=0;
    sta[++top]=1;
    sz[1]=r[1]==y;
    for(int i:v){
        if(i!=1){
            int l=lca(sta[top],i);
            if(l!=sta[top]){
                for(;ord[l]<ord[sta[top-1]];)sz[sta[top-1]]+=sz[sta[top]],--top;
                if(ord[l]>ord[sta[top-1]]){
                    sz[l]=r[l]==y;
                    sz[l]+=sz[sta[top]];
                    sta[top]=l;
                }else sz[l]+=sz[sta[top]],top--;
            }
            sz[i]=r[i]==y;
            sta[++top]=i;   
        }
    }
    for(int i=top-1;i>=1;--i)sz[sta[i]]+=sz[sta[i+1]];
}
vector<int>col[mxn];
vector<int>hea;
int ans[404][404];
int id[mxn];
//inline void getans(int x,int fa,int cl){//on ng[]    //原来的暴力统计小对小和大对大的答案
//  if(r[x]==cl)sz[x]=1;
//  else sz[x]=0;
//  for(int y:ng[x])if(y!=fa)getans(y,x,cl),sz[x]+=sz[y];
//}
inline void prep(int x,int fa,int cl,int flg=0){   //预处理小对大和大对小的up和dw数组
    if(r[x]==hea[cl])dw[cl][x]=1,++flg;
    up[cl][x]=flg;
    for(int y:g[x])if(y!=fa){
        prep(y,x,cl,flg);
        dw[cl][x]+=dw[cl][y];
    }
}
inline void solve(){
    memset(id,-1,sizeof(id));
    read(n),read(R),read(q);
    read(r[1]);col[r[1]].push_back(1);
    for(int i=2;i<=n;++i){
        int x;read(x),read(r[i]);
        g[x].push_back(i);
        g[i].push_back(x);
        col[r[i]].push_back(i);
    }
    dfs(1);
    init();
    int bound=500;
    for(int i=1;i<=R;++i)if(col[i].size()>=bound)hea.push_back(i);  //这里我的B定为了500
    for(int i=0;i<hea.size();++i)id[hea[i]]=i;
    for(int i=1;i<=n;++i)sort(col[i].begin(),col[i].end(),cmp);
    for(int i=0;i<hea.size();++i)for(int j=0;j<hea.size();++j)if(i!=j){
        gen(mer(col[hea[i]],col[hea[j]]),hea[j]);
//      getans(1,0,hea[j]);
        ll res=0;
        for(int f:col[hea[i]])res+=sz[f];
        ans[i][j]=res;
    }
    for(int i=0;i<hea.size();++i)prep(1,0,i);
    for(;q--;){
        int x,y;read(x),read(y);
        if(~id[x] and ~id[y])print(ans[id[x]][id[y]]),putchar('\n');
        else if(~id[x] and id[y]==-1){
            ll res=0;
            for(int i:col[y])res+=up[id[x]][i];
            print(res),putchar('\n');
        }else if(id[x]==-1 and ~id[y]){
            ll res=0;
            for(int i:col[x])res+=dw[id[y]][i];
            print(res),putchar('\n');
        }else{
            gen(mer(col[x],col[y]),y);
//          getans(1,0,y);
            ll res=0;
            for(int f:col[x])res+=sz[f];
            print(res),putchar('\n');
        }
    }
}
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    int T=1;
    //cin>>T;
    for(;T--;)solve();
    return 0;
}

4.[ARC052D] 9

题目大意:

给定两个正整数 K,M (1\le K,M \le 10^{10}),你需要求出有多少个正整数 N 满足 1 \le N \le MN \equiv S_N (\mod K) ,其中 S_NN 的各位数字之和。

题解:

这个 10^{10} 的数据范围并不常见,但是可以发现大概是根号的复杂度。

显然无法分块,考虑怎么做到根号分治。我们先对 K 设定一个阈值 T,其中 T\sqrt{M} 级别。

1 \le N \le 10^{10} 时,最大的 S_N 不过 9\times 10=90,所以我们可以去先枚举数字和 S,然后就可以发现,\mod K= SN 的个数不会超过 \lfloor\frac{M}{K}\rfloor +1 个。直接枚举这些数就可以了。复杂度 O(90\times \frac{M}{K})

我们可以考虑把 K 做为一维压到数位 dp 里了。令 dp_{i,j,sm,0/1} 表示考虑到从高到低第 i 位,此时的数 \mod K = j,数字和为 sm,是否已经小于 m 的数的个数。这样就可以 dp 了,复杂度 O(10\times90\times K\times 10)=O(9000K)

均衡一下取 K=10000 最优。

Code:

#include<bits/stdc++.h>
#define ll long long
#define int ll
using namespace std;
ll k,m,n;
int ee[12];
ll dp[12][10000][91][2];
inline void add(ll&x,ll y){x+=y;}
signed main(){
    cin>>k>>m;
    ll ans=0;ll tm=m;
    for(int i=1;;++i){
        ee[i]=tm%10;
        tm/=10;
        if(tm==0){
            n=i;
            break;
        }
    }
    reverse(ee+1,ee+n+1);
    if(k>=10000){
        for(int i=0;i<=90;++i){
            for(ll ee=i;ee<=m;ee+=k){
                ll t=ee,cnt=0;
                while(t){
                    cnt+=t%10;
                    t/=10;
                }
                if(cnt%k==i)++ans;
            }
        }
        cout<<ans-1<<'\n';
        return 0;
    }
    dp[1][0][0][0]=1;
    for(int i=1;i<=n;++i){
        for(int j=0;j<k;++j){
            for(int sm=0;sm<=90;++sm){
                {//f=0
                    for(int t=0;t<ee[i];++t){
                        if(sm+t<=90){
                            add(dp[i+1][(j*10+t)%k][sm+t][1],dp[i][j][sm][0]);
                        }
                    }
                    if(sm+ee[i]<=90)add(dp[i+1][(j*10+ee[i])%k][sm+ee[i]][0],dp[i][j][sm][0]);
                }
                {//f=1
                    for(int t=0;t<10;++t)if(sm+t<=90)add(dp[i+1][(j*10+t)%k][sm+t][1],dp[i][j][sm][1]);
                }
            }
        }
    }
    for(int ee=0;ee<=90;++ee)add(ans,dp[n+1][ee%k][ee][0]),add(ans,dp[n+1][ee%k][ee][1]);
    cout<<ans-1<<'\n';
}

小结:

什么时候会用到根号分治呢?

  1. 题目中明说,或者暗说了某个值的总数量一定。比如图论中所有点的度数之和为 O(m) 的,或者字符串题中询问会输入字符串,且所有字符串的长度之和为 O(n) 的(参考 CF914F Substrings in a String)。
  2. 很明显有两种暴力,一种是 O(b) 的,另一种是 O(\frac{n}{b}) 的,可以利用基本不等式化为 O(n\sqrt{n})
  3. 数据范围长得比较奇怪,比如例4的10^{10},不像是筛法,也不像是什么数位 dp

最后,补充几道例题:

CF914F Substrings in a String

P8349 [SDOI/SXOI2022] 整数序列

CF1039D You Are Given a Tree

CF587F Duff is Mad

CF1666A Admissible Map

ABC259Ex - Yet Another Path Counting

完结撒花。