题解 P3940 【分组】

· · 题解

前言:看了各位巨佬的题解,思路其实都一样,但是在下觉得自己的写法更适合像我一样的蒟蒻们理解和实现一些,没有需要特判的情况,可写性较强。

题目分析

本题明显要分 K=1 和 K=2 两种情况讨论。整体算法思想是贪心:尽可能使得能放在同一组的两只兔子分在同一组,便将使得总分组最少。这个结论是显然的,证明略。

对于 K=1 的情况,外循环枚举每一只兔子的颜色值 a_i,内循环设 x 从1枚举到512(=\sqrt{131072×2}),看看有没有访问过 (x^2-a_i) 这个值。如果访问过,则第i只兔子与当前组内其它兔子有冲突,需要清空访问标记并且单独新建一个分组;如果没访问过,则令第i只兔子加入当前分组。最后再打上访问标记即可。

对于 K=2 的情况,整体思路和K=1差不多,但是此时需要使用到并查集。与K=1的情况的不同点在于,如果访问过 (x^2-a_i) ,我们仍然可以试图将a_i加入当前分组,因为此时我们的分组里允许有两个小团体。我们此时可以使用并查集来维护敌对关系:fa[] 数组开2倍空间, find(1~n) 代表每只兔子所在小团体, find(n+1~2n) 代表每只的兔子所在小团体的敌对小团体。

在上述情况下当访问过 (x^2-a_i) 时,我们先判断下颜色值为 (x^2-a_i) 的兔子(们)(注意!此时可能不止一只!以上几篇题解都是用特判的方法处理的重复问题,本蒟蒻不会,所以推荐开动态数组 vector 存下每一只当前分组内相同颜色的兔子编号。) 和第 i 只兔子是否已经被规划进入同一小团体中。如果已经被规划入同一小团体,那么便不得不清空访问并新建分组。否则,将第 i 只兔子与颜色值为 (x^2-a_i) 的兔子(们)分别互相加入对方的敌对集合即可。

PS:由于需要字典序最小,我们需要倒序扫描 a 数组。

代码

#include<cstdio>
#include<cctype>
#include<stack>
#include<vector>
using namespace std;
inline long long input(){
    long long a=0,f=1;
    char tmp=getchar();
    while(!isdigit(tmp)){if(tmp=='-')f=-1;tmp=getchar();}
    while(isdigit(tmp)){a=(a<<3)+(a<<1)+(tmp^48);tmp=getchar();}
    return a*f;
}
int n,K,a[131073],fa[300001],times[300001],last;
std::vector<int>vis[300001];
std::stack<int>ans;
int find(int x){
    if(fa[x]&&fa[x]^x)return fa[x]=find(fa[x]);
    return fa[x]=x;
}
void read(){
    n=input(),K=input();
    for(int i=1;i<=n;i++)
        a[i]=input();
}
void clear(int &last,int i){//一定不要直接把vis[1~300000]全部清空,这样会大大增加时间复杂度导致TLE。
    for(int j=last-1;j>i;j--)
        vis[a[j]].clear();
    last=i+1;
    ans.push(last-1);
}
void solve1(){
    for(int i=n;i;i--){
        bool ok=1;
        for(int x=1;x<=512;x++)
            if(x*x-a[i]>=0&&vis[x*x-a[i]].size()){
                ok=0;
                break;
            }
        if(!ok)clear(last,i);
        vis[a[i]].push_back(1);
    }
}
void solve2(){
    for(int i=n;i;i--){
        for(int x=1;x<=512;x++)
            if(x*x-a[i]>=0&&vis[x*x-a[i]].size())
                for(int j=0;j<vis[x*x-a[i]].size();j++){
                    int f=vis[x*x-a[i]][j];
                    if(find(i)==find(f))clear(last,i);
                    else{
                        fa[find(i+n)]=find(f);
                        fa[find(f+n)]=find(i);
                    }
                }
        vis[a[i]].push_back(i);
    }
}
void print(){
    printf("%d\n",ans.size()+1);
    while(ans.size())printf("%d ",ans.top()),ans.pop();
    putchar('\n');
}
int main(){
    read();
    last=n+1;
    if(K==1)solve1();
    else solve2();
    print();
    return 0;
}