Bootstrap

【详解】树链剖分之重链剖分

终于搞懂了树链剖分的一些皮毛了……

树链剖分

“树链剖分”,顾名思义,就是把一棵树剖分成一条条的链……

重链剖分

重链剖分的基本概念

重链剖分是树链剖分的一种,它会把树剖分成一条条重链……

什么是重链呢?

重链就是连接每一个树的重儿子所形成的链。

重儿子就是其儿子重以儿子为根的子树大小最大的儿子。

画一个图来理解一下:

对于这样一棵树,它剖成重链应该是这样的(红色是重链,绿色包裹的是重链上的点):

重链剖分的过程

首先,如果我们需要知道重儿子,那得知道子树的大小在进行处理,所以要分两次 dfs 来实现。

第一次 dfs

这一次 dfs 主要是算出每一个点的深度、父亲、子树大小。

就是一个简单的 dfs,很好理解。

局部代码
void dfs1(int x,int father,int deep)
{
	dep[x]=deep;
	fa[x]=father;
	siz[x]=1;
	int mxson=-1;
	for(auto y:g[x])
	{
		if(y==father)continue;
		dfs1(y,x,deep+1);
		siz[x]+=siz[y];
		if(siz[y]>mxson)
		{
			mxson=siz[y];
			son[x]=y;
		}
	}
}
第二次 dfs

算出每个点的链头、重儿子,以及其 dfs 序的编号(这个编号后续会有用)。

这一个比较简单的 dfs,应该也比较好理解。

局部代码
void dfs2(int x,int tf)
{
	id[x]=++cnt;
	top[x]=tf;
	if(!son[x])return;
	dfs2(son[x],tf);
	for(auto y:g[x])
	{
		if(y==fa[x]||y==son[x])continue;
		dfs2(y,y);
	}
}

以上就是重链剖分的基本步骤,时间复杂度 O(n),也是十分的高级好吧。

重链剖分的应用

学会了重链剖分的基本步骤,那还是得学会怎么用对吧……

先来看一道题目

洛谷 P3384 【模板】重链剖分/树链剖分icon-default.png?t=O83Ahttps://www.luogu.com.cn/problem/P3384

题目大意

解题思路

当然,一看题目的名称,就知道是重链剖分……

回顾重链剖分(这次加上每个点的 dfs 序编号):

也许你会发现,对于在同一个重链里面的点,他们的编号都是连续的

由此,我们可以把树看成一个个序列,从树上问题转换为序列问题

那题目中的每个操作都可以看做是在序列上进行的……

那么,很自然地就可以想到线段树可以解决它。

对于第 1 个操作

和倍增法求 LCA 类似地,我们可以依次往上跳,直到 x,y 都在同一条链上,每一次跳,都可以看做是在链上这个区间内加上了 z,直接套上线段树即可。

但是,最后也要记得处理最终的 x,y 之间的这个区间。

局部代码:

void update(int x,int y,int z)
{
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])
			swap(x,y);
		change(1,id[top[x]],id[x],z);
		x=fa[top[x]];
	}
	if(dep[x]>dep[y])swap(x,y);
	change(1,id[x],id[y],z);
}
对于第 2 个操作

和第一个操作类似地,也是往上跳,直到 x,y 在同一条链上,只不过每次跳时都是对链这个区间进行一次求和查询,累加即可。

局部代码:

int que(int x,int y)
{
	int ans=0;
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])
			swap(x,y);
		ans+=query(1,id[top[x]],id[x]);
		x=fa[top[x]];
	}
	if(dep[x]>dep[y])swap(x,y);
	ans+=query(1,id[x],id[y]);
	return ans;
}
对于第 3 个操作

可以再次发现,同一颗子树内的点的编号是连续的,所以可以一次修改操作就行了。

修改的区间是 [id_x,id_x+size_x-1]id_x 是 x 的编号,size_x 是以 x 为子树的大小。

局部代码(好像没啥好展示的):

change(1,id[x],id[x]+siz[x]-1,z);
对于第 4 个操作

修改的区间一样,只是查询操作而已。

局部代码:

cout<<query(1,id[x],id[x]+siz[x]-1)<<"\n";

完整代码

记得取模!!!

#include<bits/stdc++.h>
using namespace std;
#define int long long
int n,q,rt,mod;
int a[100001];
vector<int> g[100001];
int dep[100001];
int fa[100001];
int siz[100001];
int son[100001];
int id[100001],top[100001],wt[100001],cnt;
void dfs1(int x,int father,int deep)
{
	dep[x]=deep;
	fa[x]=father;
	siz[x]=1;
	int mxson=-1;
	for(auto y:g[x])
	{
		if(y==father)continue;
		dfs1(y,x,deep+1);
		siz[x]+=siz[y];
		if(siz[y]>mxson)
		{
			mxson=siz[y];
			son[x]=y;
		}
	}
}
void dfs2(int x,int tf)
{
	id[x]=++cnt;
	wt[id[x]]=a[x];
	top[x]=tf;
	if(!son[x])return;
	dfs2(son[x],tf);
	for(auto y:g[x])
	{
		if(y==fa[x]||y==son[x])continue;
		dfs2(y,y);
	}
}
struct tree{
	int sum,l,r,add;
}tr[400001];
void build(int u,int l,int r)
{
	tr[u]={0,l,r,0};
	if(l==r)
	{
		tr[u].sum=wt[l];
		tr[u].sum%mod;
		return;
	}
	int mid=l+r>>1;
	build(u*2,l,mid);
	build(u*2+1,mid+1,r);
	tr[u].sum=tr[u*2].sum+tr[u*2+1].sum;
	tr[u].sum%=mod;
}
void push_down(int u)
{
	if(tr[u].add)
	{
		tr[u*2].sum+=(tr[u*2].r-tr[u*2].l+1)*tr[u].add;
		tr[u*2].add+=tr[u].add;
		tr[u*2].sum%=mod;
	//	tr[u*2].add%=mod;
		tr[u*2+1].sum+=(tr[u*2+1].r-tr[u*2+1].l+1)*tr[u].add;
		tr[u*2+1].add+=tr[u].add;
		tr[u*2+1].sum%=mod;
	//	tr[u*2+1].add%=mod;
		tr[u].add=0;
	}
}
void push_up(int u)
{
	tr[u].sum=tr[u*2].sum+tr[u*2+1].sum;
	tr[u].sum%=mod;
}
void change(int u,int l,int r,int d)
{
	if(l<=tr[u].l&&tr[u].r<=r)
	{
		tr[u].sum+=(tr[u].r-tr[u].l+1)*d;
		tr[u].sum%=mod;
		tr[u].add+=d;
		return;
	}
	push_down(u);
	int mid=tr[u].l+tr[u].r>>1;
	if(l<=mid)
		change(u*2,l,r,d);
	if(r>mid)
		change(u*2+1,l,r,d);
	push_up(u);
}
int query(int u,int l,int r)
{
	if(l<=tr[u].l&&tr[u].r<=r)
	{
		return tr[u].sum;
	}
	push_down(u);
	int mid=tr[u].l+tr[u].r>>1;
	int res=0;
	if(l<=mid)
		res+=query(u*2,l,r);
	res%=mod;
	if(r>mid)
		res+=query(u*2+1,l,r);
	res%=mod;
	return res;
}
void update(int x,int y,int z)
{
	z%=mod;
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])
			swap(x,y);
		change(1,id[top[x]],id[x],z);
		x=fa[top[x]];
	}
	if(dep[x]>dep[y])swap(x,y);
	change(1,id[x],id[y],z);
}
int que(int x,int y)
{
	int ans=0;
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])
			swap(x,y);
		ans+=query(1,id[top[x]],id[x]);
		ans%=mod;
		x=fa[top[x]];
	}
	if(dep[x]>dep[y])swap(x,y);
	ans+=query(1,id[x],id[y]);
	ans%=mod;
	return ans;
}
signed main()
{
	ios::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	
	cin>>n>>q>>rt>>mod;
	
	for(int i=1;i<=n;i++)
	{
		cin>>a[i];
	}
	
	int u,v;
	for(int i=1;i<n;i++)
	{
		cin>>u>>v;
		g[u].push_back(v);
		g[v].push_back(u);
	}
	
	dfs1(rt,0,1);
	dfs2(rt,rt);
	
	build(1,1,n);
	
	int op,x,y,z;
	while(q--)
	{
		cin>>op;
		if(op==1)
		{
			cin>>x>>y>>z;
			update(x,y,z);
		}
		if(op==2)
		{
			cin>>x>>y;
			cout<<que(x,y)<<"\n";
		}
		if(op==3)
		{
			cin>>x>>z;
			change(1,id[x],id[x]+siz[x]-1,z);
		}
		if(op==4)
		{
			cin>>x;
			cout<<query(1,id[x],id[x]+siz[x]-1)%mod<<"\n";
		}
	}
}

最后求赞勿喷。

;