受邀写题解,讲的墨迹点……
相当麻烦的换根$dp$。
思路
换根$dp$,尝试两次转移来求解答案——对于树上的所有点对$(u,v)$,$f(u,v)$的总和。
第一次转移
尝试得到以$x$为根时,节点$x$到$x$子树上所有点的步数和。在草纸上画一画就可以发现如果我们得到节点$v$所在子树的信息,在向其父节点$x$传递贡献时,$v$子树上距离$v$为$k$的倍数的点的贡献会整体$+1$(也就是有多少个点到$v$的距离模$k=0$,贡献就新增多少),所以肯定还要把距离和$k$的关系的点的数目记录下来。
我们可以发现$k\le 5$,那么设$dp[x]$表示$x$子树上所有点到$x$的贡献(一维就可以,$dp$不必记录关于$k$的信息)。
设$sum[x][y]$表示在$x$子树上到点$x$的距离模$k$后为$y$的点的数量,用来辅助转移。
| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 
 | void dfs(int x,int fa){
 sum[x][0]=1;
 for(int &v:G[x])
 {
 if(v==fa)
 continue;
 dfs(v,x);
 dp[x]+=dp[v];
 for(int i=0;i<k;i++)
 sum[x][i]+=sum[v][(i-1+k)%k];
 dp[x]+=sum[v][0];
 }
 }
 
 | 
第二次转移
换根$dp$第二部尝试换根,设$dp2[x]$表示以$x$为根节点时,整棵树上所有点到达$x$所需步数和。
当我们尝试将根节点从$x$转移到$v$时,贡献$dp2[v]$相对于$dp2[x]$会发生什么变化?会发现$v$子树上所有到$v$的距离模$k$为$0$的点的贡献相对于$dp2[x]$会整体减$1$(这部分点经过$v$到$x$要多走一步,现在可以不走了);而在整棵树关于$v$子树的补树上(原来的树去掉$v$子树),所有到$x$的距离模$k=0$的点贡献会整体加$1$。
那么此时我们需要寻找方案来方便地求解补树,因为之前我们得到了$sum$数组,所以我们只要得到某个点作为根节点的信息,即可通过减去子树$sum$信息来方便地得到补树上特定距离模数的点的数量。
设$ret[x][y]$表示以$x$为根时树上到点$x$的距离模$k$后等于$y$的点的数量,用来辅助第二次转移。对于原来的根(默认为$1$),$ret[1]$与$sum[1]$的信息含义是一样的。
那么如何通过父节点$x$的$ret$信息来推导出子节点$v$的$ret$呢?使用类似容斥的方法,先假设树上所有点到$v$的距离相对于原来到$x$的距离全部增加$1$,我们发现$v$的子树上的距离会被错误地增加了,$v$子树上到$v$的距离相对于到$x$点不应增加反而应该减少,那么我们删掉这一部分的数量并且使用$sum$里的信息来修正。
| 12
 3
 
 | ret[v][i]=ret[x][(i-1+k)%k];ret[v][i]-=sum[v][(i-2+k)%k];
 ret[v][i]+=sum[v][i];
 
 | 
我们得到了$ret$数组,接下来是换根转移过程,和上面类似,跳回到该部分第二段。
观察根节点转移到相邻节点时答案贡献的变化,以新的节点$x$为根的子树上距离模$k=0$的点贡献会整体减$1$,而补树上到$fa$距离模$k=0$的点贡献会整体加$1$,如下。
| 12
 3
 
 | dp2[x]=dp2[fa];dp2[x]+=ret[fa][0]-sum[x][(k-1+k)%k];
 dp2[x]-=sum[x][0];
 
 | 
那么这题我们就做出来了,完结撒花ヾ(≧▽≦*)o。
代码
| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 
 | #include<bits/stdc++.h>using namespace std;
 typedef long long ll;
 #define int ll
 #define debug(x) std:: cerr << #x << " = " << x << std::endl;
 const int maxn=2e5+10,inf=0x3f3f3f3f,mod=1000000007;
 void read(){}
 template<typename T,typename... T2>inline void read(T &x,T2 &... oth) {
 x=0; int ch=getchar(),f=0;
 while(ch<'0'||ch>'9'){if (ch=='-') f=1;ch=getchar();}
 while (ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
 if(f)x=-x;
 read(oth...);
 }
 int sum[maxn][6],dp[maxn],dp2[maxn],k,ans=0;
 vector<int>G[maxn];
 void dfs(int x,int fa)
 {
 sum[x][0]=1;
 for(int &v:G[x])
 {
 if(v==fa)
 continue;
 dfs(v,x);
 dp[x]+=dp[v];
 for(int i=0;i<k;i++)
 sum[x][i]+=sum[v][(i-1+k)%k];
 dp[x]+=sum[v][0];
 }
 }
 int ret[maxn][6];
 void dfs2(int x,int fa)
 {
 if(x!=1)
 {
 dp2[x]=dp2[fa];
 dp2[x]+=ret[fa][0]-sum[x][(k-1+k)%k];
 dp2[x]-=sum[x][0];
 }
 ans+=dp2[x];
 for(int &v:G[x])
 {
 if(v==fa)
 continue;
 for(int i=0;i<k;i++)
 {
 ret[v][i]=ret[x][(i-1+k)%k];
 ret[v][i]-=sum[v][(i-2+k)%k];
 ret[v][i]+=sum[v][i];
 }
 dfs2(v,x);
 }
 }
 signed main(signed argc, char const *argv[])
 {
 int n,u,v;
 cin>>n>>k;
 for(int i=1;i<n;i++)
 {
 cin>>u>>v;
 G[u].push_back(v);
 G[v].push_back(u);
 }
 dfs(1,0);
 for(int i=0;i<k;i++)
 ret[1][i]=sum[1][i];
 dp2[1]=dp[1];
 dfs2(1,1);
 cout<<ans/2<<endl;
 return 0;
 }
 
 
 
 
 
 
 
 
 |