[Algo Beat Contest 001 G] Great Sub-Expressions 官方题解

· · 题解

题解作者是 orchardist,以下叙述中的第一人称均指作者。

膜拜 Eric998 的矩阵做法,此人(奶龙)锐评为 1900 分。但本人太菜,只会朴素的 dp 和线段树做法。

首先,不难想到一个对于每次修改 O(n) 的 dp 解法:设 dp_i 表示以第 i 位结尾的所有表达式计算结果之和。则对于 2 \le i \le n,有:

初始状态为 dp_1=a_1,答案即为 \sum_{i = 1}^{n} dp_i

考虑优化这个过程。为了便于表达,以下把上面五种转移方程分别记作 tp_i=1,2,3,4,5。不难发现,每次只修改一个 i 的转移方程。

每次修改后,从询问所给的第 x 位开始,若当前对应的 tp 值为 45,则以后位置的 dp 值不受影响。找到第一个 x 后的 i 使 tp_i=45,可以用 set 维护。

对于其他位置,如果从起始位加 1 到当前位出现 tp_i=3 奇数次,则 dp_i 增加起始位的增加量,偶数次则减少。这里需用线段树维护带修改的区间 dp 值的和,区间从第 1 位开始出现 tp_i=3 奇数次的个数与偶数次的个数。另外,计算起始位的增加量,需要单点查询。

修改时不需要全部分类讨论,只要先按 dp_x 的增量修改,然后判断先后是否有 tp_x=3,4,5,逐一更新即可。单次修改 O(\log N),总复杂度 O((N+Q)\log N)

最后就是一些细节问题,详见代码。

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1e6+5;
int n,q,dp[N],tp[N],sum[N];
int mktp(char x,char y){
    if((x=='|'||x=='^')&&y=='0') return 1;
    if(x=='&'&&y=='1') return 2;
    if(x=='^'&&y=='1') return 3;
    if(x=='&'&&y=='0') return 4;
    return 5;
}
set<int> S;
string s;
struct node{
    ll ans;
    int cnt0,cnt1,add;
    bool rev;
    void tgrev(){
        swap(cnt0,cnt1);
        rev^=1;
        add=-add;
    }
    void tgadd(int k){
        ans+=1ll*(cnt0-cnt1)*k;
        add+=k;
    }
}t[N<<2],lst,emp;
node operator+(node a,node b){
    return {a.ans+b.ans,a.cnt0+b.cnt0,a.cnt1+b.cnt1,0,0};   
}
void build(int p,int l,int r){
    if(l==r){
        t[p].ans=dp[l];
        if(sum[l]) t[p].cnt1=1;
        else t[p].cnt0=1;
        return;
    }
    int mid=(l+r)>>1;
    build(p*2,l,mid);
    build(p*2+1,mid+1,r);
    t[p]=t[p*2]+t[p*2+1]; 
}
void down(int p){
    if(t[p].rev){
        t[p*2].tgrev();
        t[p*2+1].tgrev();
        t[p].rev=0;
    }
    if(t[p].add){
        t[p*2].tgadd(t[p].add);
        t[p*2+1].tgadd(t[p].add);
        t[p].add=0;
    }
}
void updrev(int p,int l,int r,int ql,int qr){
    if(l>qr||r<ql) return;
    if(ql<=l&&r<=qr){
        t[p].tgrev();
        return;
    }
    down(p);
    int mid=(l+r)>>1;
    updrev(p*2,l,mid,ql,qr);
    updrev(p*2+1,mid+1,r,ql,qr);
    t[p]=t[p*2]+t[p*2+1];
}
void updadd(int p,int l,int r,int ql,int qr,int k){
    if(l>qr||r<ql) return;
    if(ql<=l&&r<=qr){
        t[p].tgadd(k);
        return;
    }
    down(p);
    int mid=(l+r)>>1;
    updadd(p*2,l,mid,ql,qr,k);
    updadd(p*2+1,mid+1,r,ql,qr,k);
    t[p]=t[p*2]+t[p*2+1];
}
node query(int p,int l,int r,int x){
    if(l==r) return t[p];
    down(p);
    int mid=(l+r)>>1;
    if(x<=mid) return query(p*2,l,mid,x);
    return query(p*2+1,mid+1,r,x);
}
int trans(int ldp,int ntp,int i){
    if(ntp==1) return ldp;
    if(ntp==2) return ldp+1;
    if(ntp==3) return i-ldp;
    if(ntp==4) return 0;
    return i;
}
int main(){
    cin>>n>>q>>s;s=" "+s;
    if(s[1]=='0') s[0]='|';
    else s[0]='&';
    for(int i=1;i<=n;i++){
        tp[i]=mktp(s[i*2-2],s[i*2-1]);
        if(tp[i]>=4) S.insert(i);
    }
    S.insert(n+1);
    dp[1]=s[1]-'0';
    for(int i=2;i<=n;i++)
        dp[i]=trans(dp[i-1],tp[i],i);
    for(int i=1;i<=n;i++)
        if(tp[i]<=3) sum[i]=sum[i-1]^(tp[i]==3);
    build(1,1,n);
    while(q--){
        int x;char y,z;
        scanf("%d %c %c",&x,&y,&z);
        if(x==1){
            if(z=='0') y='|';
            else y='&';
        }
        int ntp=mktp(y,z);
        if(tp[x]==ntp){
            printf("%lld\n",t[1].ans);
            continue;
        }
        lst=emp;
        lst.cnt0=1;
        if(x>1) lst=query(1,1,n,x-1);
        int ldp=trans(lst.ans,tp[x],x);
        int ndp=trans(lst.ans,ntp,x);
        int k=ldp-ndp;
        bool lstst=((lst.cnt0&&tp[x]!=3)||(lst.cnt1&&tp[x]>=3));
        bool nowst=((lst.cnt0&&ntp!=3)||(lst.cnt1&&ntp>=3));
        if(lstst) k=-k;
        auto it=S.upper_bound(x);
        int r=(*it)-1;
        updadd(1,1,n,x,r,k);
        if(lstst^nowst) updrev(1,1,n,x,r);
        if(tp[x]>=4&&ntp<4) S.erase(x);
        else if(tp[x]<4&&ntp>=4) S.insert(x);
        printf("%lld\n",t[1].ans);
        tp[x]=ntp;
    }
    return 0; 
}