P5296 [北京省选集训2019] 生成树计数 题解
分析
题目就是要求出所有:
我们看到生成树计数相关问题就可以考虑使用矩阵树定理,但是矩阵树定理只能求出若边权为
并且
那么我们对于这道题可以每条边设置一个长度为
我们把
假如说我们知道了当前边集和的
这就变成了一个类似于卷积的形式,众所周知卷积是乘法运算,于是我们把上面的向量拿来做一个卷积就可以处理乘法问题了。
最后的答案就是第
加法减法我们可以对应位置暴力相减,除法我们可以倒推求出逆元,然后乘上逆元即可,注意特别判断第
每次操作时间复杂度为
代码
代码如下,其中多项式相关技巧均为暴力实现:
#include<bits/stdc++.h>
#define ll long long
#define mod 998244353
#define N 35
#define M 200005
using namespace std;
struct node{ll x,y,z;};
struct vect{ll p[N];}a[N][N];
ll n,K,i,j,k,m,x,y,z,ans,res,vis[N],f[M],fath[N],C[N][N],maps[N][N];
inline ll qmi(ll a,ll b,ll p){
ll res = 1%p,t=a;
while(b){
if(b&1) res=res*t%p;
t=t*t%p;
b>>=1;
}
return res;
}
vect operator+(vect a,vect b){
vect c;
for(ll i=0;i<=K;i++) c.p[i]=(a.p[i]+b.p[i])%mod;
return c;
}
vect operator-(vect a,vect b){
vect c;
for(ll i=0;i<=K;i++) c.p[i]=(a.p[i]-b.p[i]+mod)%mod;
return c;
}
vect operator*(vect a,vect b){
vect c;
for(ll i=0;i<=K;i++){
c.p[i] = 0;
for(ll j=0;j<=i;j++) c.p[i]=(c.p[i]+a.p[j]*b.p[i-j]%mod*C[i][j])%mod;
}
return c;
}
vect inv(vect a){
vect c;
c.p[0] = qmi(a.p[0],mod-2,mod);
for(ll i=1;i<=K;i++){
c.p[i] = 0;
for(ll j=0;j<i;j++) c.p[i] = (c.p[i]-c.p[j]*a.p[i-j]%mod*C[i][j]%mod+mod)%mod;
c.p[i] = c.p[i]*c.p[0]%mod;
}
return c;
}
bool operator==(vect a,vect b){
for(ll i=0;i<=K;i++) if(a.p[i]!=b.p[i]) return 0;
return 1;
}
bool operator!(vect a){
for(ll i=0;i<=K;i++) if(a.p[i]) return 0;
return 1;
}
inline ll solve(){
ll i,j,k,has=0;
for(i=1;i<n;i++) for(j=1;j<n;j++) for(k=0;k<=K;k++) a[i][j].p[k]=(a[i][j].p[k]%mod+mod)%mod;
for(i=1;i<n;i++){
for(j=i+1;j<n;j++){
if(!a[i][i]) swap(a[i],a[j]),has^=1;
vect res = a[j][i]*inv(a[i][i]);
assert(res*a[i][i]==a[j][i]);
for(k=i;k<n;k++) a[j][k]=a[j][k]-a[i][k]*res;
}
}
vect ans;
for(i=1;i<=K;i++) ans.p[i]=0;
ans.p[0]=1;
for(i=1;i<n;i++) ans=ans*a[i][i];
return ans.p[K];
}
int main(){
ios::sync_with_stdio(false);
cin>>n>>K;
C[0][0] = 1;
for(i=0;i<=K;i++){
for(j=0;j<=K;j++){
if(i==0&&j==0) continue;
if(i) C[i][j]+=C[i-1][j];
if(i&&j) C[i][j]+=C[i-1][j-1];
C[i][j] %= mod;
}
}
for(i=1;i<=n;i++){
for(j=1;j<=n;j++){
cin>>maps[i][j];
if(i>=j) continue;
for(k=0,res=1;k<=K;k++,res=res*maps[i][j]%mod) a[i][i].p[k]+=res,a[j][j].p[k]+=res,a[i][j].p[k]-=res,a[j][i].p[k]-=res;
}
}
cout<<solve()<<endl;
return 0;
}
/*
Input:
4 5
1 2 12
1 3 9
2 4 6
3 4 8
1 4 4
Output:
15
*/