CF815D

· · 题解

模拟赛遇到的题目。

看各位大佬的做法都不是很懂,于是自己一通乱搞搞出来了。

题意

翻译清楚的不能再清楚了

做法

为了方便叙述,我们将每个给出的三元组表示成 (a_i,b_i,c_i),所选的三元组表示成(x,y,z)

首先我们考虑如果 c_i 相等(设为 c)怎么做:

直接做看起来很麻烦,于是我们考虑正难则反,由算可行解数量改为算不可行解数量,并维护对于每一个 y 可选的 x ,显然对于每一个yx的选择满足单调性。

然后我们分类讨论一下 zc_i 的大小关系

z \le c 时,对于每一个三元组,当 y \in [1,b_i] 时, x>p;当 y \in (b_i,q] 时,x>a_i

z>c 时,对于每一个三元组,只需满足当 y \in [1,b_i]时,x>a_i 即可。

拿个线段树维护一下就可以了。

然后我们将这种做法拓展一下,我们先把原序列按 c_i 从大到小排序,可以轻易发现此时当 z \in (c_i,c_{i+1}] 时,每一个 z 贡献都是一样的

对于每一个区间(c_i,c_{i+1}],我们考虑上面的做法来算贡献,显然对于前 i 个三元组,满足 z>c 的条件,对于后面的三元组,满足 z \le c 的条件。

且显然 z \le c 的条件是比 z > c 完全严格的,所以每次修改的时候把 z > c的条件给当前枚举到的三元组加上去就好了。

然后关于这个线段树,我们需要维护区间最大值,区间求和,这个可以自行百度解决(相信做到这道题的大佬都会)。

#include <bits/stdc++.h>
#define int long long
#define INF 2147483647
#define N 1000005
#define mkp make_pair
#define pii pair<int,int>
#define pb push_back
using namespace std;
int n,A,B,C,ans=0;
struct Y{int a,b,c;}a[N];
int cmp(Y x,Y y){return x.c>y.c;}
struct Ji_seg{
    struct Tr{
        int mn,tag,t,se,sum;
    }tr[N<<2];
    int ls(int nw){return nw*2;}
    int rs(int nw){return nw*2+1;}
    void change(int nw,int l,int r,int k){//对nw节点的区间进行区间取max 
        if(k<=tr[nw].mn)return;
        if(k<tr[nw].se){
            tr[nw].sum+=(k-tr[nw].mn)*tr[nw].t;
            tr[nw].mn=k;tr[nw].t=max(tr[nw].t,1ll);tr[nw].tag=max(tr[nw].tag,k);
            return;
        }
        int mid=(l+r)/2;
        change(ls(nw),l,mid,k);change(rs(nw),mid+1,r,k);push_up(nw);
    }
    void push_down(int nw,int l,int r){
        if(tr[nw].tag){
            int mid=(l+r)/2;
            change(ls(nw),l,mid,tr[nw].tag);
            change(rs(nw),mid+1,r,tr[nw].tag);
            tr[nw].tag=0;
        }
    }
    void push_up(int nw){
        tr[nw].sum=tr[ls(nw)].sum+tr[rs(nw)].sum;
        if(tr[ls(nw)].mn==tr[rs(nw)].mn){
            tr[nw].mn=tr[ls(nw)].mn;
            tr[nw].t=tr[ls(nw)].t+tr[rs(nw)].t;
            tr[nw].se=min(tr[ls(nw)].se,tr[rs(nw)].se);
            return;
        }
        if(tr[ls(nw)].mn<tr[rs(nw)].mn){tr[nw].mn=tr[ls(nw)].mn;tr[nw].t=tr[ls(nw)].t;tr[nw].se=min(tr[ls(nw)].se,tr[rs(nw)].mn);}
            else {tr[nw].mn=tr[rs(nw)].mn;tr[nw].t=tr[rs(nw)].t;tr[nw].se=min(tr[rs(nw)].se,tr[ls(nw)].mn);}
    }
    void build(int nw,int l,int r){
        if(l==r){
            tr[nw].sum=0;tr[nw].mn=0;tr[nw].se=INF;tr[nw].t=1;tr[nw].tag=0;
            return;
        }
        int mid=(l+r)/2;
        build(ls(nw),l,mid);
        build(rs(nw),mid+1,r);
        push_up(nw);
    }
    void update(int nw,int l,int r,int x,int y,int k){
        if(x>y)return;
        if(x<=l&&r<=y){
            change(nw,l,r,k);
            return;
        }
        int mid=(l+r)/2;push_down(nw,l,r);
        if(x<=mid)update(ls(nw),l,mid,x,y,k);
        if(y>mid)update(rs(nw),mid+1,r,x,y,k);
        push_up(nw);
    }
}T;
signed main() {
    scanf("%lld %lld %lld %lld",&n,&A,&B,&C);
    for(int i=1;i<=n;i++)scanf("%lld %lld %lld",&a[i].a,&a[i].b,&a[i].c);
    sort(a+1,a+n+1,cmp);a[0]=Y{0,0,C};a[n+1]=Y{0,0,0};
    T.build(1,1,B);
    for(int i=1;i<=n;i++)T.update(1,1,B,1,a[i].b,a[i].a);
    for(int i=0;i<=n;i++){
        T.update(1,1,B,1,a[i].b,A);T.update(1,1,B,a[i].b+1,B,a[i].a);
        ans+=(a[i].c-a[i+1].c)*(A*B-T.tr[1].sum);
    }
    printf("%lld\n",ans);
    return 0;
}