Bootstrap

线段树合并

前置芝士

  1. 权值线段树
  2. 动态开点线段树

线段树合并的全称为动态开点权值线段树合并

线段树合并

概念

顾名思义,就是建立一棵新的线段树保存原有的两颗线段树的信息。

merge

线段树合并所完成的操作就是:在相同区间内对应位置相加

合并方式主要如下:两个相同区间的节点 p {p} p q {q} q,合并区间 [ l , r ] {[l,r]} [l,r]

  • p {p} p q {q} q 有一者为 0 {0} 0,则另一者为根。
  • l = = r {l == r} l==r,则累加权值,返回 p {p} p
  • 合并 p {p} p 的左儿子和 q {q} q 的左儿子,合并 p {p} p 的右儿子和 q {q} q 的右儿子。
  • 最后将对应位置权值累加, 返回 p {p} p

代码如下:

int merge(int p, int q, int l, int r) {
	if(!p || !q) return p + q;
	if(l == r) { tr[p].sum += tr[q].sum; return p; }
	int mid = l + r >> 1;
	tr[p].l = merge(tr[p].l, tr[q].l, l, mid);
	tr[p].r = merge(tr[p].r, tr[q].r, mid + 1, r);
	pushup(p);
	return p;
}

例题:晋升者计数

题目描述

奶牛们又一次试图创建一家创业公司,还是没有从过去的经验中吸取教训–牛是可怕的管理者!

为了方便,把奶牛从 1 {1} 1 ~ n {n} n 编号,把公司组织成一棵树, 1 {1} 1 号奶牛作为总裁(这棵树的根节点)。除了总裁以外的每头奶牛都有一个单独的上司(它在树上的 “双亲结点”)。

所有的第 i {i} i 头牛都有一个不同的能力指数 ,描述了她对其工作的擅长程度。如果奶牛 i {i} i 是奶牛 j {j} j 的祖先节点,那么我们我们把奶牛 j {j} j 叫做 i {i} i 的下属。

不幸地是,奶牛们发现经常发生一个上司比她的一些下属能力低的情况,在这种情况下,上司应当考虑晋升她的一些下属。你的任务是帮助奶牛弄清楚这是什么时候发生的。简而言之,对于公司的中的每一头奶牛 i {i} i ,请计算其下属 j {j} j 的数量满足 p j > p i {p_j>p_i} pj>pi


输入格式

输入的第一行包括一个整数 n {n} n

接下来的 n {n} n 行包括奶牛们的能力指数 p 1 , p 2 , . . . , p n {p_1,p_2,...,p_n} p1,p2,...,pn。保证所有数互不相同。

接下来的 n − 1 {n - 1} n1 行描述了奶牛 2 {2} 2 ~ n {n} n 的上司的编号。再次提醒, 1 {1} 1 号奶牛作为总裁,没有上司。


输出格式

输出包括 n {n} n 行。输出的第 i {i} i 行应当给出有多少奶牛 i {i} i 的下属比奶牛 i {i} i 能力高。


输入输出样例

输入 #1

5
804289384
846930887
681692778
714636916
957747794
1
1
2
3

输出 #1

2
0
1
0
0


数据范围: 1 < = n < = 1 0 5 , 1 < = p i < = 1 0 9 {1<=n<=10^5 , 1<=p_i<=10^9} 1<=n<=105,1<=pi<=109


代码

  • 线段树合并

    • 先建好每一颗树,在遍历树时只需进行合并。(开至少 18 N {18N} 18N 空间)
    • 遍历时,边查询,边建树。(开至少 10 N {10N} 10N 倍空间)
    • 上述开空间是针对本题,具体空间大小 (我不会计算
#include <bits/stdc++.h>
using namespace std;

const int N = 100010;

struct node {
	int l, r;
	int sum;
} tr[N * 18];

struct Node {
	int val, id;
	bool operator<(const Node &A) const {
		if(val == A.val) return id < A.id;
		return val < A.val;
	}
};

Node a[N]; 
int b[N], root[N], idx, cnt, n;
// idx 动态开点
// cnt 线段树维护的权值,即离散化完后数的rank最大值,即b数组最大值
int ans[N];

vector<int> e[N];

void pushup(int p) {
	tr[p].sum = tr[tr[p].l].sum + tr[tr[p].r].sum;
}

// 建树+单点修改
void build(int &p, int l, int r, int x, int v) {
	if(!p) p = ++idx;
	if(l == r) { tr[p].sum = v; return ;}
	int mid = l + r >> 1;
	if(x <= mid) build(tr[p].l, l, mid, x, v);
	else build(tr[p].r, mid + 1, r, x, v);
	pushup(p);
}

// 线段树合并
int merge(int p, int q, int l, int r) {
	if(!p || !q) return p + q;
	if(l == r) { tr[p].sum += tr[q].sum; return p; }
	int mid = l + r >> 1;
	tr[p].l = merge(tr[p].l, tr[q].l, l, mid);
	tr[p].r = merge(tr[p].r, tr[q].r, mid + 1, r);
	pushup(p);
	return p;
}

// 区间查询
int query(int p, int l, int r, int ql, int qr) {
	if(!p) return 0;
	if(l >= ql && r <= qr) return tr[p].sum;
	int mid = l + r >> 1;
	int ans = 0;
	if(ql <= mid) ans = query(tr[p].l, l, mid, ql, qr);
	if(qr > mid) ans += query(tr[p].r, mid + 1, r, ql, qr);
	return ans;
}

void dfs(int u) {
	for (auto &j: e[u]) {
		dfs(j);
		ans[u] += query(root[j], 1, cnt, b[u] + 1, cnt);
		root[u] = merge(root[u], root[j], 1, cnt);
	}
}

int main() {
	cin >> n;
	
	for (int i = 1; i <= n; i++) cin >> a[i].val, a[i].id = i;
	sort(a + 1, a + 1 + n);
	
	// 离散化
	for (int i = 1; i <= n; i++) {
		if(i == 1 || a[i].val ^ a[i - 1].val) b[a[i].id] = ++cnt; 
		else b[a[i].id] = b[a[i - 1].id];
	}
	
	for (int i = 2; i <= n; i++) {
		int x; cin >> x;
		e[x].push_back(i);
	}
	
	// 每一点都为根,建一颗权值线段树
	for (int i = 1; i <= n; i++) {
		root[i] = ++idx;
		build(root[i], 1, cnt, b[i], 1);
	}

	dfs(1);
	
	for (int i = 1; i <= n; i++) cout << ans[i] << endl;
	return 0;
}
  • 删去一开始每点先建树操作,以及 d f s {dfs} dfs 变动如下。
struct node {
	int l, r;
	int sum;
} tr[N * 10];

void dfs(int u) {
	for (auto &j: e[u]) {
		dfs(j);
		root[u] = merge(root[u], root[j], 1, cnt);
	}
	ans[u] = query(root[u], 1, cnt, b[u] + 1, cnt);
	build(root[u], 1, cnt, b[u], 1);
}
  • 又一次写:前后debug两遍,全是一个地方的问题,(第一次还debug错了,我真牛

query中:[ql, qr],这里写反了。

#include <bits/stdc++.h>
using namespace std;

const int N = 100010, M = N << 1;

struct node {
	int l, r;
	int v;
} tr[N * 10];

int n, m;
int a[N], b[N], ans[N];
int root[N], idx;

int h[N], e[M], ne[M], cnt;

void add(int a, int b) {
	e[cnt] = b, ne[cnt] = h[a], h[a] = cnt++;
}

void pushup(int p) {
	tr[p].v = tr[tr[p].l].v + tr[tr[p].r].v;
}

int merge(int p, int q, int l, int r) {
	if(!p || !q) return p + q;
	if(l == r) { tr[p].v += tr[q].v; return p; }
	int mid = l + r >> 1;
	tr[p].l = merge(tr[p].l, tr[q].l, l, mid);
	tr[p].r = merge(tr[p].r, tr[q].r, mid + 1, r);
	pushup(p);
	return p;
}

void modify(int &p, int l, int r, int x, int v) {
	if(!p) p = ++idx;
	if(l == r) { tr[p].v = v; return ; }
	int mid = l + r >> 1;
	if(x <= mid) modify(tr[p].l, l, mid, x, v);
	else modify(tr[p].r, mid + 1, r, x, v);
	pushup(p);
}

int query(int p, int l, int r, int ql, int qr) {
	if(!p) return 0;
	if(l >= ql && r <= qr) { return tr[p].v; }
	int mid = l + r >> 1;
	int res = 0;
	if(ql <= mid) res += query(tr[p].l, l, mid, ql, qr);
	if(qr > mid) res += query(tr[p].r, mid + 1, r, ql, qr);
	return res;
}

void dfs(int u) {
	for (int i = h[u]; ~i; i = ne[i]) {
		int j = e[i];
		dfs(j);
		root[u] = merge(root[u], root[j], 1, m);
	}
	ans[u] = query(root[u], 1, m, b[u] + 1, m);
	modify(root[u], 1, m, b[u], 1);
}

int main() {
	ios_base::sync_with_stdio(false), cin.tie(0), cout.tie(0);
	cin >> n;
	for (int i = 1; i <= n; i++) cin >> a[i], b[i] = a[i];
	sort(a + 1, a + 1 + n);
	m = unique(a + 1, a + 1 + n) - a - 1;
	for (int i = 1; i <= n; i++) 
		b[i] = lower_bound(a + 1, a + 1 + m, b[i]) - a;

	memset(h, -1, sizeof h);
	for (int i = 2; i <= n; i++) {
		int x; cin >> x;
		add(x, i);
	}
	
	dfs(1);
	for (int i = 1; i <= n; i++) cout << ans[i] << endl; 
	return 0;
}

  • 加个树状数组题解,不过分吧。

    • 对于点 u {u} u,答案 = 树状数组中加了 u {u} u 下属后比 u {u} u 强的 − {-} 原来就比 u {u} u 强的
#include <bits/stdc++.h>
using namespace std;

const int N = 100010;
int a[N], b[N], n;
int tr[N], ans[N]; 
int h[N], e[N], ne[N], idx;

void Add(int x, int c) {
	for (; x <= n; x += x & -x) tr[x] += c;
}

int Sum(int x) {
	int res = 0;
	for (; x; x -= x & -x) res += tr[x];
	return res;
}

void add(int a, int b) {
	e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

void dfs(int u) {
	ans[u] = -(Sum(n) - Sum(b[u]));
	for (int i = h[u]; ~i; i = ne[i]) dfs(e[i]);
	ans[u] += (Sum(n) - Sum(b[u]));
	Add(b[u], 1);
}

int main() {
	cin >> n;
	memset(h, -1, sizeof h);
	for (int i = 1; i <= n; i++) cin >> a[i], b[i] = a[i];
	sort(a + 1, a + 1 + n);
	for (int i = 1; i <= n; i++)
		b[i] = lower_bound(a + 1, a + 1 + n, b[i]) - a;
	
	for (int i = 2; i <= n; i++) {
		int x; cin >> x;
		add(x, i);
	}
	dfs(1);
	for (int i = 1; i <= n; i++) cout << ans[i] << endl;
	return 0;
}
;