题解 CF1458F 【Range Diameter Sum】

· · 题解

题意 : 给出一棵 n 个点的树,边权均为 1

len(S) 为点集 S 的直径长度。

\sum\limits_{i=1}^n\sum\limits_{j=i}^nlen([i,j])

------------ 很有意思的一个题。 树上邻域(圆)理论。 定义 $f(u,r)=\{v|dis(u,v)\leq r\}=\{\text{距离u不超过r的节点}\}$ ,即邻域。其中 $r$ 称为半径。 - **性质①** 对于一个点集 $S$ ,其所有直径的中点是重合的。(可能在一条边的中间) 若中点不重合则显然可以构造新的直径。 设点集 $S$ 的直径中点为 $mid(S)$ ,直径长度为 $len(S)$。 根据这条性质,定义点集 $S$ 的覆盖邻域 $c(S)=f(mid(S),len(S)/2)$。 显然,这是能覆盖 $S$ 的半径最小的邻域,类似于最小覆盖圆。 - **性质②** $f(u,r)\supseteq S\Rightarrow f(u,r)\supseteq c(S)$。 圆的形式 : $⊙C\supseteq S\Rightarrow ⊙C\supseteq S$ 的最小覆盖圆。 **证明** : 显然有 $dis(mid(S),u)\leq r-len(S)/2$ ,否则 $f(u,r)$ 不可能包含 $S$ 的直径。 也就是说,从 $mid(S)$ 开始还能走 $len(S)/2$ 步,自然就覆盖了 $c(S)$。 - **性质③** 设 $td(S,v)$ 为 $v$ 到 $S$ 内最远点的距离。 有 $td(S,v)=dis(mid(S),v)+len(S)/2$。 **证明** : 显然最远点一定是直径的端点之一。 无论从哪个方向来到中点,一定能继续走 $len(S)/2$ 来到某个端点。显然这是上界。 这个结论应用比较广泛,但是在本题中并不是重点,只用来证明。 计算 $c(S∪T)$ 并不困难,因为点集直径具有封闭性。 但是,合并直径的讨论太复杂,不便分析,考虑仅使用 $c(S)$ 来推导。 - 情况A1 : 若 $c(S)\subseteq c(T)$ 则 $c(S∪T)=c(T)$。 - 情况A2 : 若 $c(T)\subseteq c(S)$ 则 $c(S∪T)=c(S)$。 上述两条是显然的。 判定 $c(T)\subseteq c(S)$ 的方法 : 计算 $d=dis\big(mid(S),mid(T)\big)$ ,查看是否 $len(T)/2+d\leq len(S)/2$。 圆的形式 :即两圆包含的情况。 ![](https://cdn.luogu.com.cn/upload/image_hosting/y9h7mztj.png) **证明** : 若 $len(T)/2+d\leq len(S)/2$ ,则说明从 $mid(S)$ 走到 $mid(T)$ 之后还能走 $len(T)/2$ 步,显然覆盖了 $c(T)$。 否则,无论从哪个方向到达 $mid(T)$ ,总有某个不是来向的方向存在长为 $len(T)/2$ 的半条 $T$ 的直径,则不能完整覆盖。 - 情况B : 若不满足情况A1,A2,有 $c(S∪T)=c(f(u_1,r_1)∪f(u_2,r_2)) =f\Big(t,\big(dis(u_1,u_2)+r_1+r_2\big)/2\Big)

其中 tu_1,u_2 路径上的一个点,距离 u_1\big(dis(u_1,u_2)-r_1+r_2\big)/2 ,距离 u_2\big(dis(u_1,u_2)+r_1-r_2\big)/2

圆的形式 :(两圆相离也类似)

证明 :已经排除了前两种情况,所以,新的直径一定跨越 S,T

考虑 td(S,u_2)=dis(u_1,u_2)+r_1。取这个最远点 p_S ,进一步得到 td(T,p_S)=dis(p_S,u_2)+r_2=dis(u_1,u_2)+r_1+r_2

我们已经证明了直径长度,现在来找中点。

不难发现,前面构造的直径形如 p_S\xleftrightarrow{r_1}u_1\xleftrightarrow{dis(u_1,u_2)}u_2\xleftrightarrow{r_2}p_T

现在中点位置是显然的。完全等价于数轴上画圆的情况。

好了,现在我们已经有一套 c(S) 合并的理论了,来分析一下本题吧。

考虑分治,每次计算跨越区间 [L,R] 中点 t 的区间的贡献。

枚举每个 [l,t] ,批量计算并上各个 (t,r] 的贡献。

\sum\limits_{r=t+1}^Rlen([l,r])

=\sum\limits_{r=t+1}^Rlen\Big(c\big([l,t]∪(t,r]\big)\Big)

现在要分类讨论了。

不难发现,随着 r 的增大,点集 (t,r] 不断扩大。

一开始有 c\big([l,t]\big)\supseteq c\big((t,r]\big) ,然后是情况B,最终是 c\big([l,t]\big)\subseteq c\big((t,r]\big)

我们对这三个区间分别计算贡献。

( 注意,可能有 c\big([l,t]\big)=c\big((t,r]\big)

首先预处理出各个 len,mid

第一部分和第三部分的贡献对 len 求和即可计算。

第二部分的贡献中,\Big(len\big(c([l,t])\big)+len\big(c((t,r])\big)\Big)\Big/2 也容易计算。

剩下的 dis\big(mid([1,t]),mid((t,r])\big) 等价于形如 \sum\limits_{v\in[l,r]}dis(u,v) 的问题。

这是个经典问题,做法见 :P4211 [LNOI2014]LCA。若使用树剖BIT,复杂度为 O(\log^2) ,也可以使用全局平衡二叉树或者点分树做到 O(\log)

接下来考虑如何求出这三部分的分界线。

l 逐渐减小时,点集 [l,t] 逐渐扩大,所以这三条分界线会单调右移,容易维护。

代码实现采用了全局平衡二叉树,复杂度 O(n\log^2n)

#include<algorithm>
#include<cstdio>
#include<vector>
#define ll long long
#define pb push_back
#define MaxN 200500
using namespace std;
vector<int> g[MaxN];
struct totode
{int tf,f,son,c;}b[MaxN];
int dep[MaxN];
void pfs1(int u)
{
  b[u].c=1;
  dep[u]=dep[b[u].f]+1;
  for (int i=0,v;i<g[u].size();i++)
    if ((v=g[u][i])!=b[u].f){
      b[v].f=u;pfs1(v);
      b[u].c+=b[v].c;
      if (b[v].c>b[b[u].son].c)
        b[u].son=v;
    }
}
struct Node{
  int f,l,r,tag,c;ll s;
  inline void ladd(int t)
  {tag+=t;s+=1ll*t*c;}
}a[MaxN<<1];int tn;
int st[MaxN],tc[MaxN],tot;
int build(int l,int r)
{
  if (l==r)return st[l];
  int c=0,mid;
  for (int i=l;i<=r;i++)c+=tc[i];
  for (int i=l,c2=0;i<=r;i++){
    c2+=tc[i];
    if (c2+c2>c){mid=i-1;break;}
  }if (mid<l)mid=l;
  a[c=++tn].c=r-l+1;
  return
  a[a[c].l=build(l,mid)].f=
  a[a[c].r=build(mid+1,r)].f=c;
}
int dfn[MaxN],tp[MaxN],tim;
void pfs2(int u,int top)
{
  tp[dfn[u]=++tim]=u;
  b[u].tf=top;
  if (!b[u].son){
    for (tot=0;b[u].tf==top;u=b[u].f)tc[++tot]=u;
    for (int i=1;i<=tot;i++)st[tot-i+1]=tc[i];
    for (int i=1;i<=tot;i++)
      tc[i]=b[st[i]].c-b[b[st[i]].son].c;
    build(1,tot);
    return ;
  }pfs2(b[u].son,top);
  for (int i=0,v;i<g[u].size();i++)
    if ((v=g[u][i])!=b[u].f&&v!=b[u].son)
      pfs2(v,v);
}
inline void up(int u)
{a[u].s=a[a[u].l].s+a[a[u].r].s;}
inline void ladd(int u){
  if (a[u].tag){
    a[a[u].l].ladd(a[u].tag);
    a[a[u].r].ladd(a[u].tag);
    a[u].tag=0;
  }
}
void grt(int u){
  for (tot=0;u;u=a[u].f)st[++tot]=u;
  for (int i=tot;i>1;i--)ladd(st[i]);
}
void add(int u,int w){
  grt(u);
  a[st[1]].ladd(w);
  for (int i=2,v;i<=tot;i++){
    if ((v=a[st[i]].l)!=st[i-1])
      a[v].ladd(w);
    up(st[i]);
  }
}
ll sum;
void qry(int u){
  grt(u);ll las=sum;
  sum+=a[st[1]].s;
  for (int i=2,v;i<=tot;i++)
    if ((v=a[st[i]].l)!=st[i-1])
      sum+=a[v].s;
}
void padd(int x,int w)
{while(x){add(x,w);x=b[b[x].tf].f;}}
ll pqry(int x){
  sum=0;
  while(x){qry(x);x=b[b[x].tf].f;}
  return sum;
}
int lca(int u,int v)
{
  while(b[u].tf!=b[v].tf){
    if (dep[b[v].tf]>dep[b[u].tf])swap(u,v);
    u=b[b[u].tf].f;
  }return dep[u]<dep[v] ? u:v;
}
int dis(int u,int v)
{return dep[u]+dep[v]-2*dep[lca(u,v)];}
int up(int u,int d){
  while(dep[b[u].tf]>d)u=b[b[u].tf].f;
  return tp[dfn[b[u].tf]+d-dep[b[u].tf]];
}
int lin(int u,int v,int l)
{
  int t=lca(u,v);
  if (dep[u]-dep[t]>=l)return up(u,dep[u]-l);
  return up(v,l-dep[u]+2*dep[t]);
}
struct Data{int u,r;}s[MaxN];
Data merge(const Data &A,int u)
{
  int d=dis(A.u,u);
  if (d<=A.r)return A;
  return (Data){lin(A.u,u,(d-A.r)/2),(d+A.r)/2};
}
bool in(const Data &A,const Data &B)
{return A.r+dis(A.u,B.u)<=B.r;}
ll ans,o[MaxN],o2[MaxN];
Data q[MaxN];int tq;
bool cmpQ(const Data &A,const Data &B)
{return A.r<B.r;}
void solve(int L,int R)
{
  if (L==R)return ;
  int mid=(L+R)>>1;
  solve(L,mid);solve(mid+1,R);
  s[mid]=(Data){mid,0};
  s[mid+1]=(Data){mid+1,0};
  for (int l=mid-1;l>=L;l--)
    s[l]=merge(s[l+1],l);
  for (int r=mid+2;r<=R;r++)
    s[r]=merge(s[r-1],r);
  o[mid]=0;
  for (int r=mid+1;r<=R;r++){
    o[r]=o[r-1]+s[r].r;
    o2[r]=o2[r-1]+dep[s[r].u];
  }
  int p1=mid,p2=mid;
  tq=0;
  for (int l=mid;l>=L;l--){
    while(p1<R&&in(s[p1+1],s[l]))p1++;
    while(p2<R&&!in(s[l],s[p2+1]))p2++;
    p2=max(p2,p1);
    ans+=2ll*(p1-mid)*s[l].r+2*(o[R]-o[p2])
        +1ll*(p2-p1)*s[l].r+(o[p2]-o[p1])
        +1ll*(p2-p1)*dep[s[l].u]+(o2[p2]-o2[p1]);
    if (p1<p2){
      if (p1>mid)q[++tq]=(Data){-s[l].u,p1};
      q[++tq]=(Data){s[l].u,p2};
    }
  }sort(q+1,q+tq+1,cmpQ);
  for (int i=mid+1,p=1;i<=R;i++){
    padd(s[i].u,1);
    while(p<=tq&&q[p].r==i){
      if (q[p].u<0)ans+=2*pqry(-q[p].u);
      else ans-=2*pqry(q[p].u);
      p++;
    }
  }for (int i=mid+1;i<=R;i++)padd(s[i].u,-1);
}
int n;
int main()
{
  scanf("%d",&n);tn=n*2-1;
  for (int i=1;i<=tn;i++)a[i].c=1;
  for (int i=1,u,v;i<n;i++){
    scanf("%d%d",&u,&v);
    g[u].pb(n+i);g[n+i].pb(u);
    g[v].pb(n+i);g[n+i].pb(v);
  }pfs1(1);pfs2(1,1);
  solve(1,n);
  printf("%lld\n",ans>>1);
  return 0;
}