受邀写题解,讲的墨迹点……
相当麻烦的换根$dp$。
思路
换根$dp$,尝试两次转移来求解答案——对于树上的所有点对$(u,v)$,$f(u,v)$的总和。
第一次转移
尝试得到以$x$为根时,节点$x$到$x$子树上所有点的步数和。在草纸上画一画就可以发现如果我们得到节点$v$所在子树的信息,在向其父节点$x$传递贡献时,$v$子树上距离$v$为$k$的倍数的点的贡献会整体$+1$(也就是有多少个点到$v$的距离模$k=0$,贡献就新增多少),所以肯定还要把距离和$k$的关系的点的数目记录下来。
我们可以发现$k\le 5$,那么设$dp[x]$表示$x$子树上所有点到$x$的贡献(一维就可以,$dp$不必记录关于$k$的信息)。
设$sum[x][y]$表示在$x$子树上到点$x$的距离模$k$后为$y$的点的数量,用来辅助转移。
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| void dfs(int x,int fa) { sum[x][0]=1; for(int &v:G[x]) { if(v==fa) continue; dfs(v,x); dp[x]+=dp[v]; for(int i=0;i<k;i++) sum[x][i]+=sum[v][(i-1+k)%k]; dp[x]+=sum[v][0]; } }
|
第二次转移
换根$dp$第二部尝试换根,设$dp2[x]$表示以$x$为根节点时,整棵树上所有点到达$x$所需步数和。
当我们尝试将根节点从$x$转移到$v$时,贡献$dp2[v]$相对于$dp2[x]$会发生什么变化?会发现$v$子树上所有到$v$的距离模$k$为$0$的点的贡献相对于$dp2[x]$会整体减$1$(这部分点经过$v$到$x$要多走一步,现在可以不走了);而在整棵树关于$v$子树的补树上(原来的树去掉$v$子树),所有到$x$的距离模$k=0$的点贡献会整体加$1$。
那么此时我们需要寻找方案来方便地求解补树,因为之前我们得到了$sum$数组,所以我们只要得到某个点作为根节点的信息,即可通过减去子树$sum$信息来方便地得到补树上特定距离模数的点的数量。
设$ret[x][y]$表示以$x$为根时树上到点$x$的距离模$k$后等于$y$的点的数量,用来辅助第二次转移。对于原来的根(默认为$1$),$ret[1]$与$sum[1]$的信息含义是一样的。
那么如何通过父节点$x$的$ret$信息来推导出子节点$v$的$ret$呢?使用类似容斥的方法,先假设树上所有点到$v$的距离相对于原来到$x$的距离全部增加$1$,我们发现$v$的子树上的距离会被错误地增加了,$v$子树上到$v$的距离相对于到$x$点不应增加反而应该减少,那么我们删掉这一部分的数量并且使用$sum$里的信息来修正。
1 2 3
| ret[v][i]=ret[x][(i-1+k)%k]; ret[v][i]-=sum[v][(i-2+k)%k]; ret[v][i]+=sum[v][i];
|
我们得到了$ret$数组,接下来是换根转移过程,和上面类似,跳回到该部分第二段。
观察根节点转移到相邻节点时答案贡献的变化,以新的节点$x$为根的子树上距离模$k=0$的点贡献会整体减$1$,而补树上到$fa$距离模$k=0$的点贡献会整体加$1$,如下。
1 2 3
| dp2[x]=dp2[fa]; dp2[x]+=ret[fa][0]-sum[x][(k-1+k)%k]; dp2[x]-=sum[x][0];
|
那么这题我们就做出来了,完结撒花ヾ(≧▽≦*)o。
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
| #include<bits/stdc++.h> using namespace std; typedef long long ll; #define int ll #define debug(x) std:: cerr << #x << " = " << x << std::endl; const int maxn=2e5+10,inf=0x3f3f3f3f,mod=1000000007; void read(){} template<typename T,typename... T2>inline void read(T &x,T2 &... oth) { x=0; int ch=getchar(),f=0; while(ch<'0'||ch>'9'){if (ch=='-') f=1;ch=getchar();} while (ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();} if(f)x=-x; read(oth...); } int sum[maxn][6],dp[maxn],dp2[maxn],k,ans=0; vector<int>G[maxn]; void dfs(int x,int fa) { sum[x][0]=1; for(int &v:G[x]) { if(v==fa) continue; dfs(v,x); dp[x]+=dp[v]; for(int i=0;i<k;i++) sum[x][i]+=sum[v][(i-1+k)%k]; dp[x]+=sum[v][0]; } } int ret[maxn][6]; void dfs2(int x,int fa) { if(x!=1) { dp2[x]=dp2[fa]; dp2[x]+=ret[fa][0]-sum[x][(k-1+k)%k]; dp2[x]-=sum[x][0]; } ans+=dp2[x]; for(int &v:G[x]) { if(v==fa) continue; for(int i=0;i<k;i++) { ret[v][i]=ret[x][(i-1+k)%k]; ret[v][i]-=sum[v][(i-2+k)%k]; ret[v][i]+=sum[v][i]; } dfs2(v,x); } } signed main(signed argc, char const *argv[]) { int n,u,v; cin>>n>>k; for(int i=1;i<n;i++) { cin>>u>>v; G[u].push_back(v); G[v].push_back(u); } dfs(1,0); for(int i=0;i<k;i++) ret[1][i]=sum[1][i]; dp2[1]=dp[1]; dfs2(1,1); cout<<ans/2<<endl; return 0; }
|