题解 P5319 【[BJOI2019]奥术神杖】

· · 题解

由题意可知,若最后的字符串中包含的集合为 S=\{ \{s_i,v_i\}\},则 ans=\sqrt[|S|]{\prod \limits_{i=1}^{|S|}v_i}

看过题解后发现,我们可以两边取 $\ln$,则 $\ln ans=\dfrac1{|S|}\sum\limits_{i=1}^{|S|}\ln v_i

这是一个经典的 0/1 分数规划问题,我们二分一个答案 mid

\dfrac1{|S|}\sum\limits_{i=1}^{|S|}\ln v_i>mid \sum\limits_{i=1}^{|S|}\ln v_i>|S|\cdot mid \sum\limits_{i=1}^{|S|}(\ln v_i-mid)>0

然后在 AC 自动机上跑 dp 即可,注意记录每个状态是从哪个状态转移来的、从哪个字母转移来的,来输出方案。

const int N = 2000;
int n,m;char ch[N],ans[N];
double f[N][N];int g[N][N][2];
struct AC
{
    struct tree
    {
        int ch[10],fail;
        int& operator [](int x) { return ch[x]; }
    }t[N];int cnt;
    int sum[N];double val[N];
    void insert(char *s,double v)
    {
        int now = 0;
        for(int i = 1;s[i];i++)
        {
            if(!t[now][s[i] - '0']) t[now][s[i] - '0'] = ++cnt;
            now = t[now][s[i] - '0'];
        }sum[now]++,val[now] += v;
    }
    int q[N],head,tail;
    void get_fail()
    {
        head = 1,tail = 0;
        for(int i = 0;i < 10;i++) if(t[0][i]) q[++tail] = t[0][i];
        while(head <= tail)
        {
            int u = q[head++];
            sum[u] += sum[t[u].fail],val[u] += val[t[u].fail];
            for(int i = 0;i < 10;i++)
                if(t[u][i]) t[t[u][i]].fail = t[t[u].fail][i],q[++tail] = t[u][i];
                else t[u][i] = t[t[u].fail][i];
        }
    }
    double DP(double v)
    {
        for(int i = 0;i <= cnt;i++) val[i] -= sum[i] * v;
        for(int i = 0;i <= n;i++) for(int j = 0;j <= cnt;j++) f[i][j] = -1e100;
        f[0][0] = 0;
        for(int i = 0;i < n;i++) for(int j = 0;j <= cnt;j++) if(f[i][j] > -1e99)
            for(int k = 0;k < 10;k++) if(ch[i] == '.' || ch[i] == k + '0')
            {
                int c = t[j][k];
                if(f[i + 1][c] < f[i][j] + val[c])
                {
                    f[i + 1][c] = f[i][j] + val[c];
                    g[i + 1][c][0] = j,g[i + 1][c][1] = k;
                }
            }

        for(int i = 0;i <= cnt;i++) val[i] += sum[i] * v;
        int pos = 0;
        for(int i = 1;i <= cnt;i++) if(f[n][i] > f[n][pos]) pos = i;
        for(int i = n,now = pos;i;i--)
            ans[i] = g[i][now][1] + '0',now = g[i][now][0];
        return f[n][pos];
    }
}T;
int main()
{
    n = read(),m = read();scanf("%s",ch);
    for(int i = 1;i <= m;i++)
    {
        static char s[N];
        scanf("%s",s + 1);
        T.insert(s,log(read()));
    }T.get_fail();
    double l = 0,r = log(INF),res = 0;
    while(r - l > eps)
    {
        double mid = (l + r) / 2.0;
        if(T.DP(mid) > 0) res = mid,l = mid;
        else r = mid;
    }T.DP(res);
    printf("%s",ans + 1);
}