题解 P5364 【[SNOI2017]礼物】
_rqy
·
·
题解
@Fading 首先我要告诉您您的做法不是最优...
首先,容易发现这个东西可以写成矩阵乘法的形式...(如果你没有发现,请看最早那篇题解)...然后可以惊奇的发现,矩阵是上三角矩阵,并且对角线上有一个 $2$ 和 $k+1$ 个 $1$!
因为矩阵是上三角,所以显然其特征值就是主对角线上的元素,所以根据递推通项公式之类的性质,可以发现答案一定是这种形式:
$$c_k2^n+\sum_{i=0}^ka_{k,i}n^i$$
即一个 $2^n$ 的项加上一个 $k$ 次多项式。如果我们知道了 $c_k$ 的值,那么通过预处理 $[1, k+1]$ 项就可以拉格朗日插值得到 $n$ 的值。(拉格朗日插值可以做到 $O(k)$)
众所周知 $c_k2^n$ 的差分仍然是 $c_k2^n$,但是 $k$ 次多项式的差分是 $k-1$ 阶多项式,所以把前 $k+1$ 项进行 $k$ 阶差分(这个可以直接用组合数 $O(k)$ 得到)就可以得到 $c_k$ 的值。
这样的话,复杂度瓶颈还剩下预处理前 $k+1$ 项。而可以发现只需要求出 $1^k\dots (k+1)^k$,而这可以通过欧拉筛只求其中素数的幂,从而做到 $O(\pi(k)\log k+k)=O(k)$。
于是总复杂度就是 $O(k+\log n)$。($O(\log n)$是计算 $2^n$ ~~和输入n~~)($n$ 更大也不用高精度,只需要对 $10^9+7$ 取模算插值,$10^9+6$ 取模算 $2$ 的幂即可)
代码
```cpp
// luogu-judger-enable-o2
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
typedef long long LL;
const int K = 15;
const int mod = 1000000007;
LL pow_mod(LL a, LL b) {
LL ans = 1;
for (; b; b >>= 1, a = a * a % mod)
if (b & 1) ans = ans * a % mod;
return ans;
}
LL pK[K], f[K], inv[K];
int pr[K], cnt;
void Sieve(int k) {
pK[1] = 1;
for (int i = 2; i <= k + 1; ++i) {
if (!pK[i]) pK[pr[cnt++] = i] = pow_mod(i, k);
for (int j = 0; pr[j] * i <= k + 1; ++j) {
pK[i * pr[j]] = pK[i] * pK[pr[j]] % mod;
if (i % pr[j] == 0) break;
}
}
}
int main() {
int k; LL n;
scanf("%lld%d", &n, &k); --n;
LL s = 0, tn = n % mod;
Sieve(k);
for (int i = 0; i <= k; ++i) {
if ((f[i] = s + pK[i + 1]) >= mod) f[i] -= mod;
if ((s += f[i]) >= mod) s -= mod;
}
if (n <= k) return printf("%lld\n", f[n]) & 0;
LL g = -1; inv[1] = 1;
for (int i = 2; i <= k; ++i) {
inv[i] = -(mod / i) * inv[mod % i] % mod;
g = g * -inv[i] % mod;
}
LL c = 1, p = 0;
for (int i = 0; i <= k; ++i) {
// sum_{i=0}^k (-1)^(k-i) C(k, i) f[i+1]
LL _t = c * f[i] % mod;
c = c * inv[i + 1] % mod * (k - i) % mod;
if ((k - i) & 1) p -= _t;
else p += _t;
}
LL ans = 0, p2 = p, t = 1;
for (int i = 0; i <= k; ++i) {
LL _y = (f[i] - p2) * g % mod;
ans = (ans * (tn - i) + _y * t) % mod;
t = t * (tn - i) % mod;
p2 = p2 * 2 % mod;
g = g * (i - k) % mod * inv[i + 1] % mod;
}
ans = (ans + p * pow_mod(2, n % (mod - 1))) % mod;
printf("%lld\n", (ans + mod) % mod);
}
```