Bootstrap

树状数组学习笔记

本文仅作为树状数组入门讲解。

Part 1: 简介 / Introduction \LARGE\text{{Part 1: 简介\ /\ Introduction}} Part 1: 简介 / Introduction

树状数组(Binary Indexed Tree),是一种高级(?数据结构,基本用法是解决单点修改 & 区间查询问题。

Part 2: 原理 / Theory \LARGE{\text{Part 2: 原理\ /\ Theory}} Part 2: 原理 / Theory

树状数组,顾名思义,就是一个数组。那么它如何存储呢?如图所示:

其中, a a a 为原数组, s s s 为树状数组。

规律:假设正整数 i i i 的二进制位的最后一个 1 1 1 所在的位是从低到高第 j j j 位,则有 s i = ∑ k = i − 2 j − 1 + 1 i a k s_i=\displaystyle\sum_{k=i-2^{j-1}+1}^{i}a_k si=k=i2j1+1iak。我们记这里的 2 j − 1 2^{j-1} 2j1 lowbit(i) ⁡ \operatorname*{lowbit(i)} lowbit(i)

  • 如何修改?

考虑对 a 3 a_3 a3 进行 + 2 +2 +2 操作。

首先,显然 s 3 s_3 s3 包含 a 3 a_3 a3,那么让 s 3 + 2 s_3+2 s3+2;接着往右上方走, s 4 s_4 s4 s 8 s_8 s8 也包含 a 3 a_3 a3,也需要 + 2 +2 +2

规律:假设修改的位置是 a x a_x ax,则不断将 x x x 增加 lowbit(x) ⁡ \operatorname{lowbit(x)} lowbit(x),并更新新的 s x s_x sx,直到到达“树”顶。

  • 如何查询?

假设我们需要查询 [ l , r ] [l,r] [l,r] 的区间和。则可以转换为求 [ 1 , r ] − [ 1 , l − 1 ] [1,r] - [1,l-1] [1,r][1,l1]。那么目标就转换到了求 ∑ i = 1 x a i \displaystyle\sum_{i=1}^{x}a_i i=1xai

观察 lowbit(x) ⁡ \operatorname*{lowbit(x)} lowbit(x) 的定义,可以发现你可以通过不断将一个 x x x 减少 lowbit(x) ⁡ \operatorname*{lowbit(x)} lowbit(x),再对 s x s_x sx 求和,这样就可以覆盖所有 [ 1 , x ] [1,x] [1,x] 之间的数。

时间复杂度:因为每次都会从 x x x 的二进制中删掉一个 1 1 1,所以单次操作显然是 O ( log ⁡ n ) O(\log n) O(logn) 的。

Part 3:  题目/ Problems \LARGE{\text{Part 3: \ 题目/\ Problems}} Part 3:  题目/ Problems

P3374 【模板】树状数组 1

对一个序列维护两个操作:

  • a x ← a x + y a_x \gets a_x+y axax+y

  • 查询 ∑ i = l r a i \displaystyle\sum_{i=l}^{r}a_i i=lrai


树状数组板子题。直接根据上面思路写即可。时间复杂度 O ( m log ⁡ n ) O(m \log n) O(mlogn)

#include <iostream>
using namespace std;
int tree[500005],n,m;
int lowbit(int x) { return x & -x; }
void update(int x,int d) {
	while(x <= n) {
		tree[x] += d; x += lowbit(x);
	}
}
int query(int x) {
	int t = 0;
	while(x > 0) {
		t += tree[x]; x -= lowbit(x);
	} return t;
}
int main() {
	cin >> n >> m;
	for(int i = 1,x;i <= n;i++) {
		cin >> x; update(i,x);
	}
	while(m--) {
		int opt,x,y; cin >> opt >> x >> y;
		if(opt == 1) update(x,y);
		else cout << query(y) - query(x-1) << endl;
	}
	return 0;
}

P3368 【模板】树状数组 2

  • 对于所有 l ≤ i ≤ r l \le i \le r lir,令 a i ← a i + x a_i \gets a_i + x aiai+x

  • 求出 a x a_x ax


考虑差分, s i = a i − a i − 1 s_i=a_i-a_{i-1} si=aiai1。那么对于操作 1 1 1,直接让 s l ← s l + x s_l \gets s_l + x slsl+x s r + 1 ← s r + 1 − x s_{r+1} \gets s_{r+1} - x sr+1sr+1x;而操作 2 2 2 就可以转换为 ∑ i = 1 x \displaystyle\sum_{i=1}^{x} i=1x。那么就可以用树状数组完成了。时间复杂度 O ( m log ⁡ n ) O(m \log n) O(mlogn)

#include <bits/stdc++.h>
#define lowbit(x) ((x) & (-x))
using namespace std;
int tree[500005],n,m;
void update(int x,int d) {
	while(x <= n) {
		tree[x] += d;
		x += lowbit(x);
	}
}
int sum(int x) {
	int s = 0;
	while(x > 0) {
		s += tree[x];
		x -= lowbit(x);
	}
	return s;
}
int main() {
	cin >> n >> m;
	int last = 0;
	for(int i = 1;i <= n;i++) {
		int X; cin >> X;
		update(i,X-last); last = X;
	}
	while(m--) {
		int opt,x,y; cin >> opt >> x;
		if(opt == 1) { cin >> y;
			int k; cin >> k;
			update(x,k);
			update(y+1,-k);
		}
		else {
			cout << sum(x) << endl;
		}
	}
	return 0;
}

P5057 [CQOI2006] 简单题

板子… …

考虑区间反转在模 2 2 2 意义下其实等价于给所有数 + 1 +1 +1。那么我们考虑直接转换成树状数组 2,对于查询操作,将差分数组求和之后模 2 2 2 即可。时间复杂度 O ( m log ⁡ n ) O(m \log n) O(mlogn)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
struct Binary_Indexed_Tree {
	ll tree[500005], n;
	ll lowbit(int x) { return x & (-x); }
	void update(ll x,ll k) {
		while(x <= n) tree[x] += k, x += lowbit(x);
	}
	ll query(ll x) {
		ll ans = 0; while(x > 0) ans += tree[x], x -= lowbit(x); return ans;
	}
} BIT;
int main() {
	int m; cin >> BIT.n >> m;
	while(m--) {
		int opt; cin >> opt;
		if(opt == 1) {
			int l,r; cin >> l >> r; BIT.update(l,1); BIT.update(r+1,-1);
		}
		else {
			int x; cin >> x; cout << BIT.query(x) % 2 << endl;
		}
	}
	return 0;
}

Part 4:  更多应用/ Problems \LARGE{\text{Part 4: \ 更多应用/\ Problems}} Part 4:  更多应用/ Problems

树状数组应用:权值树状数组。

例题:

P1908 逆序对

首先,因为本题只要求数之间的相对顺序,所以我们可以使用离散化,将所有数映射到 1 ∼ n 1 \sim n 1n 的范围内。然后可以得到一个暴力做法:建一个数组 cnt \text{cnt} cnt,然后对于所有 a i a_i ai cnt a i ← cnt a i + 1 \text{cnt}_{a_i} \gets \text{cnt}_{a_i}+1 cntaicntai+1;然后再枚举一遍,对于每个 a i a_i ai,将所有 cnt 1 ∼ cnt a i − 1 \text{cnt}_1 \sim \text{cnt}_{a_i-1} cnt1cntai1 加起来就是 a i a_i ai 的贡献。这样时间复杂度 O ( n 2 ) O(n^2) O(n2)。然后发现这是一个单点加,区间和的东西,用树状数组维护一下即可。时间复杂度 O ( n log ⁡ n ) O(n \log n) O(nlogn)

#include <bits/stdc++.h>
#define int long long
using namespace std;
int lowbit(int x) { return x & (-x); }
struct node {
	int id,x;
	bool operator < (const node& y) {
		return (x != y.x ? x < y.x : id < y.id);
	}
} a[500005];
int rk[500005];
int tree[500005],n;
void update(int x,int d) {
	while(x <= n) {
		tree[x] += d;
		x += lowbit(x);
	}
}
int query(int x) {
	int sum = 0;
	while(x >= 1) {
		sum += tree[x];
		x -= lowbit(x);
	}
	return sum;
}
signed main() {
	cin >> n;
	for(int i = 1;i <= n;i++) cin >> a[i].x;
	for(int i = 1;i <= n;i++) a[i].id = (a[i].x == a[i-1].x ? a[i-1].id : i);
	sort(a+1,a+n+1);
	for(int i = 1;i <= n;i++) {
		rk[a[i].id] = i;
	}
	int ans = 0;
	for(int i = n;i >= 1;i--) {
		ans += query(rk[i]); update(rk[i],1);
	}
	cout << ans << endl;
	return 0;
}

P10589 楼兰图腾

假如数据范围很小,那么这就是一道 DP 题。我们不难写出暴力 DP 代码:

#include <bits/stdc++.h>
using namespace std;
int a[300005];
int f[300005][4];
int main() {
	int n; cin >> n;
	for(int i = 1;i <= n;i++) cin >> a[i], f[i][1] = 1;
	for(int i = 1;i <= n;i++) {
		for(int j = 1;j < i;j++) f[i][2] += (a[j] > a[i]);
	}
	for(int i = 1;i <= n;i++) {
		for(int j = 1;j < i;j++) f[i][3] += (a[j] < a[i]) * f[j][2];
	} long long ans = 0;
	for(int i = 1;i <= n;i++) ans += f[i][3]; cout << ans << ' ';
	for(int i = 1;i <= n;i++) f[i][1] = 1, f[i][2] = 0, f[i][3] = 0;
	for(int i = 1;i <= n;i++) {
		for(int j = 1;j < i;j++) f[i][2] += (a[j] < a[i]);
	}
	for(int i = 1;i <= n;i++) {
		for(int j = 1;j < i;j++) f[i][3] += (a[j] > a[i]) * f[j][2];
	} ans = 0;
	for(int i = 1;i <= n;i++) ans += f[i][3]; cout << ans;
	return 0;
}

于是不难发现瓶颈在于计算有多少个比当前数小的数。那么直接拿权值树状数组优化,时间复杂度 O ( n log ⁡ n ) O(n \log n) O(nlogn)

#include <bits/stdc++.h>
#define int long long
using namespace std;
int a[300005], f[300005][4];
struct BIT {
	int tree[300005], n;
	int lowbit(int x) { return x & (-x); }
	void init() {
		for(int i = 1;i <= n;i++) tree[i] = 0;
	} void update(int x, int d) {
		for(; x <= n; x += lowbit(x)) tree[x] += d;
	} int query(int x) {
		int sum = 0; for(; x; x -= lowbit(x)) sum += tree[x]; 
		return sum;
	}
} t;
signed main() {
	cin >> t.n; t.init();
	for(int i = 1;i <= t.n;i++) cin >> a[i], f[i][1] = 1;
	for(int i = 1;i <= t.n;i++) { 
		f[i][2] = t.query(t.n - a[i] + 1); t.update(t.n - a[i] + 1, 1);
	} t.init(); for(int i = 1;i <= t.n;i++) {
		f[i][3] = t.query(a[i]); t.update(a[i], f[i][2]);
	} long long ans = 0; t.init();
	for(int i = 1;i <= t.n;i++) ans += f[i][3]; cout << ans << ' ';
	for(int i = 1;i <= t.n;i++) f[i][1] = 1, f[i][2] = 0, f[i][3] = 0;
	for(int i = 1;i <= t.n;i++) {
		f[i][2] = t.query(a[i]); t.update(a[i], 1);
	} t.init(); for(int i = 1;i <= t.n;i++) {
		f[i][3] = t.query(t.n - a[i] + 1); t.update(t.n - a[i] + 1, f[i][2]);
	} ans = 0;
	for(int i = 1;i <= t.n;i++) ans += f[i][3]; cout << ans;
	return 0;
}
;