agc033_e

myee

2023-04-25 13:44:56

题解

前言

怎么题解里全是 dp。

这里有一个比较套路的 GF 做法。

思路

不妨总是令第一项为 \tt R。则,说明弧上不存在相邻的 \tt B

如果串 S 仅由 \tt R 构成,则方案数即为不存在相邻 \tt B 的方案数。这个答案显然就是 \sum_k\binom{n-k}{k}+\sum_k\binom{n-k-2}{k}

否则,环上既有 \tt B 又有 \tt R,不妨假设有 k\tt B,而 \tt R 构成的联通块大小依次为 a_1,a_2,\dots,a_k,使得 a_i\ge1\land\sum a=n-k,且该序列若合法,则对答案的贡献为恰 \frac{n}{k} 份。

考虑怎样的序列是合法的。

假设 S 串中有 t\tt B,其把 \tt R 序列分成了 c_0\sim c_t 各段,其中 c\ge0

首先,显然有 a\le c_0+1\land2\nmid a,而 c_t 对答案没有影响。

然后,如果有 2\nmid c_j(1\le j\le t-1),则说明该步必然走到了下一个(相对于上一个)\tt R 联通块,且当前 a\le c;如果 2|c_j(1\le j\le t-1),则该步必然是往回折返的,而没有其余限制。

由于对于每个节点 a,第一步向左 / 向右走都是要做到可行的,所以每种路线带来的限制都是要被计入的;且容易发现没有多余的限制。

所以其实总的限制就是

2\nmid a 1\le a\le\min\{c_0+[2|c_0],\min_{0<j<t,2\nmid c_j}c_j\} \sum a=n-k

这样的方案有几种呢?设

a=1+2b

0\le b\le\frac{\min\{c_0+[2|c_0],\min_{0<j<t,2\nmid c_j}c_j\}-1}{2} \sum b=\frac{n}{2}-k

d=\frac{\min\{c_0+[2|c_0],\min_{0<j<t,2\nmid c_j}c_j\}-1}{2},则方案数即

[z^{n/2-k}] (\frac{1-z^{d+1}}{1-z})^k

从而答案即为

n[z^{n/2}]\sum_k\frac{(z\frac{1-z^{d+1}}{1-z})^k}{k} \\=n[z^{n/2}]\ln\frac{1}{1-z\frac{1-z^{d+1}}{1-z}}

使用 MTT 即可做到 O(n\log n)

MTT 太混乱邪恶了,考虑阳间一点的推导。

n[z^{n/2}]\ln\frac{1}{1-z\frac{1-z^{d+1}}{1-z}} \\=2[z^{n/2-1}] (\ln\frac{1}{1-z\frac{1-z^{d+1}}{1-z}})' \\=2[z^{n/2-1}]\frac{1-(d+2)z^{d+1}+(d+1)z^{d+2}}{1-3z+2z^2+z^{d+2}-z^{d+3}}

由于分母的非 0 项很少,直接递推即可。

总复杂度 O(n+m)

Code

代码很好写。

const ullt Mod=1e9+7;
typedef ConstMod::mod_ullt<Mod>modint;
typedef std::vector<modint>modvec;
int main()
{
#ifdef MYEE
    freopen("QAQ.in","r",stdin);
    freopen("QAQ.out","w",stdout);
#endif
    std::vector<uint>V{0};
    uint n;scanf("%u",&n);
    {
        chr c;
        do c=getchar();while(c!='R'&&c!='B');
        chr o=c;
        do
        {
            if(o==c)V.back()++;else V.push_back(0);
            c=getchar();
        }
        while(c=='R'||c=='B');
        V.pop_back();
    }
    if(V.empty())
    {
        modint a(2),b(1);
        while(n--)std::swap(a+=b,b);
        a.println();return 0;
    }
    if(n&1)return puts("0"),0;
    uint d=V[0]/2;for(auto s:V)if(s&1)_min(d,s/2);
    static modint P[100005];
    P[0]=1,P[d+1]=-modint(d+2),P[d+2]=d+1;
    for(uint i=1;i<n/2;i++)
    {
        P[i]+=3*P[i-1];
        if(i>=2)P[i]-=2*P[i-2];
        if(i>=d+2)P[i]-=P[i-d-2];
        if(i>=d+3)P[i]+=P[i-d-3];
    }
    (P[n/2-1]*2).println();
    return 0;
}