前置芝士
线段树合并的全称为动态开点权值线段树合并
线段树合并
概念
顾名思义,就是建立一棵新的线段树保存原有的两颗线段树的信息。
merge
线段树合并所完成的操作就是:在相同区间内,对应位置相加。
合并方式主要如下:两个相同区间的节点 p {p} p 和 q {q} q,合并区间 [ l , r ] {[l,r]} [l,r]。
- 若 p {p} p 和 q {q} q 有一者为 0 {0} 0,则另一者为根。
- 若 l = = r {l == r} l==r,则累加权值,返回 p {p} p。
- 合并 p {p} p 的左儿子和 q {q} q 的左儿子,合并 p {p} p 的右儿子和 q {q} q 的右儿子。
- 最后将对应位置权值累加, 返回 p {p} p 。
代码如下:
int merge(int p, int q, int l, int r) {
if(!p || !q) return p + q;
if(l == r) { tr[p].sum += tr[q].sum; return p; }
int mid = l + r >> 1;
tr[p].l = merge(tr[p].l, tr[q].l, l, mid);
tr[p].r = merge(tr[p].r, tr[q].r, mid + 1, r);
pushup(p);
return p;
}
例题:晋升者计数
题目描述
奶牛们又一次试图创建一家创业公司,还是没有从过去的经验中吸取教训–牛是可怕的管理者!
为了方便,把奶牛从 1 {1} 1 ~ n {n} n 编号,把公司组织成一棵树, 1 {1} 1 号奶牛作为总裁(这棵树的根节点)。除了总裁以外的每头奶牛都有一个单独的上司(它在树上的 “双亲结点”)。
所有的第 i {i} i 头牛都有一个不同的能力指数 ,描述了她对其工作的擅长程度。如果奶牛 i {i} i 是奶牛 j {j} j 的祖先节点,那么我们我们把奶牛 j {j} j 叫做 i {i} i 的下属。
不幸地是,奶牛们发现经常发生一个上司比她的一些下属能力低的情况,在这种情况下,上司应当考虑晋升她的一些下属。你的任务是帮助奶牛弄清楚这是什么时候发生的。简而言之,对于公司的中的每一头奶牛 i {i} i ,请计算其下属 j {j} j 的数量满足 p j > p i {p_j>p_i} pj>pi 。
输入格式
输入的第一行包括一个整数 n {n} n。
接下来的 n {n} n 行包括奶牛们的能力指数 p 1 , p 2 , . . . , p n {p_1,p_2,...,p_n} p1,p2,...,pn。保证所有数互不相同。
接下来的 n − 1 {n - 1} n−1 行描述了奶牛 2 {2} 2 ~ n {n} n 的上司的编号。再次提醒, 1 {1} 1 号奶牛作为总裁,没有上司。
输出格式
输出包括 n {n} n 行。输出的第 i {i} i 行应当给出有多少奶牛 i {i} i 的下属比奶牛 i {i} i 能力高。
输入输出样例
输入 #1
5
804289384
846930887
681692778
714636916
957747794
1
1
2
3
输出 #1
2
0
1
0
0
数据范围: 1 < = n < = 1 0 5 , 1 < = p i < = 1 0 9 {1<=n<=10^5 , 1<=p_i<=10^9} 1<=n<=105,1<=pi<=109
代码
-
线段树合并
- 先建好每一颗树,在遍历树时只需进行合并。(开至少 18 N {18N} 18N 空间)
- 遍历时,边查询,边建树。(开至少 10 N {10N} 10N 倍空间)
- 上述开空间是针对本题,具体空间大小
(我不会计算
#include <bits/stdc++.h>
using namespace std;
const int N = 100010;
struct node {
int l, r;
int sum;
} tr[N * 18];
struct Node {
int val, id;
bool operator<(const Node &A) const {
if(val == A.val) return id < A.id;
return val < A.val;
}
};
Node a[N];
int b[N], root[N], idx, cnt, n;
// idx 动态开点
// cnt 线段树维护的权值,即离散化完后数的rank最大值,即b数组最大值
int ans[N];
vector<int> e[N];
void pushup(int p) {
tr[p].sum = tr[tr[p].l].sum + tr[tr[p].r].sum;
}
// 建树+单点修改
void build(int &p, int l, int r, int x, int v) {
if(!p) p = ++idx;
if(l == r) { tr[p].sum = v; return ;}
int mid = l + r >> 1;
if(x <= mid) build(tr[p].l, l, mid, x, v);
else build(tr[p].r, mid + 1, r, x, v);
pushup(p);
}
// 线段树合并
int merge(int p, int q, int l, int r) {
if(!p || !q) return p + q;
if(l == r) { tr[p].sum += tr[q].sum; return p; }
int mid = l + r >> 1;
tr[p].l = merge(tr[p].l, tr[q].l, l, mid);
tr[p].r = merge(tr[p].r, tr[q].r, mid + 1, r);
pushup(p);
return p;
}
// 区间查询
int query(int p, int l, int r, int ql, int qr) {
if(!p) return 0;
if(l >= ql && r <= qr) return tr[p].sum;
int mid = l + r >> 1;
int ans = 0;
if(ql <= mid) ans = query(tr[p].l, l, mid, ql, qr);
if(qr > mid) ans += query(tr[p].r, mid + 1, r, ql, qr);
return ans;
}
void dfs(int u) {
for (auto &j: e[u]) {
dfs(j);
ans[u] += query(root[j], 1, cnt, b[u] + 1, cnt);
root[u] = merge(root[u], root[j], 1, cnt);
}
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++) cin >> a[i].val, a[i].id = i;
sort(a + 1, a + 1 + n);
// 离散化
for (int i = 1; i <= n; i++) {
if(i == 1 || a[i].val ^ a[i - 1].val) b[a[i].id] = ++cnt;
else b[a[i].id] = b[a[i - 1].id];
}
for (int i = 2; i <= n; i++) {
int x; cin >> x;
e[x].push_back(i);
}
// 每一点都为根,建一颗权值线段树
for (int i = 1; i <= n; i++) {
root[i] = ++idx;
build(root[i], 1, cnt, b[i], 1);
}
dfs(1);
for (int i = 1; i <= n; i++) cout << ans[i] << endl;
return 0;
}
- 删去一开始每点先建树操作,以及 d f s {dfs} dfs 变动如下。
struct node {
int l, r;
int sum;
} tr[N * 10];
void dfs(int u) {
for (auto &j: e[u]) {
dfs(j);
root[u] = merge(root[u], root[j], 1, cnt);
}
ans[u] = query(root[u], 1, cnt, b[u] + 1, cnt);
build(root[u], 1, cnt, b[u], 1);
}
- 又一次写:前后debug两遍,全是一个地方的问题,(第一次还debug错了,
我真牛
query
中:[ql, qr],这里写反了。
#include <bits/stdc++.h>
using namespace std;
const int N = 100010, M = N << 1;
struct node {
int l, r;
int v;
} tr[N * 10];
int n, m;
int a[N], b[N], ans[N];
int root[N], idx;
int h[N], e[M], ne[M], cnt;
void add(int a, int b) {
e[cnt] = b, ne[cnt] = h[a], h[a] = cnt++;
}
void pushup(int p) {
tr[p].v = tr[tr[p].l].v + tr[tr[p].r].v;
}
int merge(int p, int q, int l, int r) {
if(!p || !q) return p + q;
if(l == r) { tr[p].v += tr[q].v; return p; }
int mid = l + r >> 1;
tr[p].l = merge(tr[p].l, tr[q].l, l, mid);
tr[p].r = merge(tr[p].r, tr[q].r, mid + 1, r);
pushup(p);
return p;
}
void modify(int &p, int l, int r, int x, int v) {
if(!p) p = ++idx;
if(l == r) { tr[p].v = v; return ; }
int mid = l + r >> 1;
if(x <= mid) modify(tr[p].l, l, mid, x, v);
else modify(tr[p].r, mid + 1, r, x, v);
pushup(p);
}
int query(int p, int l, int r, int ql, int qr) {
if(!p) return 0;
if(l >= ql && r <= qr) { return tr[p].v; }
int mid = l + r >> 1;
int res = 0;
if(ql <= mid) res += query(tr[p].l, l, mid, ql, qr);
if(qr > mid) res += query(tr[p].r, mid + 1, r, ql, qr);
return res;
}
void dfs(int u) {
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
dfs(j);
root[u] = merge(root[u], root[j], 1, m);
}
ans[u] = query(root[u], 1, m, b[u] + 1, m);
modify(root[u], 1, m, b[u], 1);
}
int main() {
ios_base::sync_with_stdio(false), cin.tie(0), cout.tie(0);
cin >> n;
for (int i = 1; i <= n; i++) cin >> a[i], b[i] = a[i];
sort(a + 1, a + 1 + n);
m = unique(a + 1, a + 1 + n) - a - 1;
for (int i = 1; i <= n; i++)
b[i] = lower_bound(a + 1, a + 1 + m, b[i]) - a;
memset(h, -1, sizeof h);
for (int i = 2; i <= n; i++) {
int x; cin >> x;
add(x, i);
}
dfs(1);
for (int i = 1; i <= n; i++) cout << ans[i] << endl;
return 0;
}
-
加个树状数组题解,不过分吧。
- 对于点 u {u} u,答案 = 树状数组中加了 u {u} u 下属后比 u {u} u 强的 − {-} − 原来就比 u {u} u 强的
#include <bits/stdc++.h>
using namespace std;
const int N = 100010;
int a[N], b[N], n;
int tr[N], ans[N];
int h[N], e[N], ne[N], idx;
void Add(int x, int c) {
for (; x <= n; x += x & -x) tr[x] += c;
}
int Sum(int x) {
int res = 0;
for (; x; x -= x & -x) res += tr[x];
return res;
}
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void dfs(int u) {
ans[u] = -(Sum(n) - Sum(b[u]));
for (int i = h[u]; ~i; i = ne[i]) dfs(e[i]);
ans[u] += (Sum(n) - Sum(b[u]));
Add(b[u], 1);
}
int main() {
cin >> n;
memset(h, -1, sizeof h);
for (int i = 1; i <= n; i++) cin >> a[i], b[i] = a[i];
sort(a + 1, a + 1 + n);
for (int i = 1; i <= n; i++)
b[i] = lower_bound(a + 1, a + 1 + n, b[i]) - a;
for (int i = 2; i <= n; i++) {
int x; cin >> x;
add(x, i);
}
dfs(1);
for (int i = 1; i <= n; i++) cout << ans[i] << endl;
return 0;
}