本文仅作为树状数组入门讲解。
Part 1: 简介 / Introduction \LARGE\text{{Part 1: 简介\ /\ Introduction}} Part 1: 简介 / Introduction
树状数组(Binary Indexed Tree),是一种高级(?数据结构,基本用法是解决单点修改 & 区间查询问题。
Part 2: 原理 / Theory \LARGE{\text{Part 2: 原理\ /\ Theory}} Part 2: 原理 / Theory
树状数组,顾名思义,就是一个数组。那么它如何存储呢?如图所示:
其中, a a a 为原数组, s s s 为树状数组。
规律:假设正整数 i i i 的二进制位的最后一个 1 1 1 所在的位是从低到高第 j j j 位,则有 s i = ∑ k = i − 2 j − 1 + 1 i a k s_i=\displaystyle\sum_{k=i-2^{j-1}+1}^{i}a_k si=k=i−2j−1+1∑iak。我们记这里的 2 j − 1 2^{j-1} 2j−1 为 lowbit(i) \operatorname*{lowbit(i)} lowbit(i)。
- 如何修改?
考虑对 a 3 a_3 a3 进行 + 2 +2 +2 操作。
首先,显然 s 3 s_3 s3 包含 a 3 a_3 a3,那么让 s 3 + 2 s_3+2 s3+2;接着往右上方走, s 4 s_4 s4 和 s 8 s_8 s8 也包含 a 3 a_3 a3,也需要 + 2 +2 +2。
规律:假设修改的位置是 a x a_x ax,则不断将 x x x 增加 lowbit(x) \operatorname{lowbit(x)} lowbit(x),并更新新的 s x s_x sx,直到到达“树”顶。
- 如何查询?
假设我们需要查询 [ l , r ] [l,r] [l,r] 的区间和。则可以转换为求 [ 1 , r ] − [ 1 , l − 1 ] [1,r] - [1,l-1] [1,r]−[1,l−1]。那么目标就转换到了求 ∑ i = 1 x a i \displaystyle\sum_{i=1}^{x}a_i i=1∑xai。
观察 lowbit(x) \operatorname*{lowbit(x)} lowbit(x) 的定义,可以发现你可以通过不断将一个 x x x 减少 lowbit(x) \operatorname*{lowbit(x)} lowbit(x),再对 s x s_x sx 求和,这样就可以覆盖所有 [ 1 , x ] [1,x] [1,x] 之间的数。
时间复杂度:因为每次都会从 x x x 的二进制中删掉一个 1 1 1,所以单次操作显然是 O ( log n ) O(\log n) O(logn) 的。
Part 3: 题目/ Problems \LARGE{\text{Part 3: \ 题目/\ Problems}} Part 3: 题目/ Problems
P3374 【模板】树状数组 1
对一个序列维护两个操作:
-
令 a x ← a x + y a_x \gets a_x+y ax←ax+y;
-
查询 ∑ i = l r a i \displaystyle\sum_{i=l}^{r}a_i i=l∑rai。
树状数组板子题。直接根据上面思路写即可。时间复杂度 O ( m log n ) O(m \log n) O(mlogn)。
#include <iostream>
using namespace std;
int tree[500005],n,m;
int lowbit(int x) { return x & -x; }
void update(int x,int d) {
while(x <= n) {
tree[x] += d; x += lowbit(x);
}
}
int query(int x) {
int t = 0;
while(x > 0) {
t += tree[x]; x -= lowbit(x);
} return t;
}
int main() {
cin >> n >> m;
for(int i = 1,x;i <= n;i++) {
cin >> x; update(i,x);
}
while(m--) {
int opt,x,y; cin >> opt >> x >> y;
if(opt == 1) update(x,y);
else cout << query(y) - query(x-1) << endl;
}
return 0;
}
P3368 【模板】树状数组 2
-
对于所有 l ≤ i ≤ r l \le i \le r l≤i≤r,令 a i ← a i + x a_i \gets a_i + x ai←ai+x;
-
求出 a x a_x ax。
考虑差分, s i = a i − a i − 1 s_i=a_i-a_{i-1} si=ai−ai−1。那么对于操作 1 1 1,直接让 s l ← s l + x s_l \gets s_l + x sl←sl+x, s r + 1 ← s r + 1 − x s_{r+1} \gets s_{r+1} - x sr+1←sr+1−x;而操作 2 2 2 就可以转换为 ∑ i = 1 x \displaystyle\sum_{i=1}^{x} i=1∑x。那么就可以用树状数组完成了。时间复杂度 O ( m log n ) O(m \log n) O(mlogn)。
#include <bits/stdc++.h>
#define lowbit(x) ((x) & (-x))
using namespace std;
int tree[500005],n,m;
void update(int x,int d) {
while(x <= n) {
tree[x] += d;
x += lowbit(x);
}
}
int sum(int x) {
int s = 0;
while(x > 0) {
s += tree[x];
x -= lowbit(x);
}
return s;
}
int main() {
cin >> n >> m;
int last = 0;
for(int i = 1;i <= n;i++) {
int X; cin >> X;
update(i,X-last); last = X;
}
while(m--) {
int opt,x,y; cin >> opt >> x;
if(opt == 1) { cin >> y;
int k; cin >> k;
update(x,k);
update(y+1,-k);
}
else {
cout << sum(x) << endl;
}
}
return 0;
}
P5057 [CQOI2006] 简单题
板子… …
考虑区间反转在模 2 2 2 意义下其实等价于给所有数 + 1 +1 +1。那么我们考虑直接转换成树状数组 2,对于查询操作,将差分数组求和之后模 2 2 2 即可。时间复杂度 O ( m log n ) O(m \log n) O(mlogn)。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
struct Binary_Indexed_Tree {
ll tree[500005], n;
ll lowbit(int x) { return x & (-x); }
void update(ll x,ll k) {
while(x <= n) tree[x] += k, x += lowbit(x);
}
ll query(ll x) {
ll ans = 0; while(x > 0) ans += tree[x], x -= lowbit(x); return ans;
}
} BIT;
int main() {
int m; cin >> BIT.n >> m;
while(m--) {
int opt; cin >> opt;
if(opt == 1) {
int l,r; cin >> l >> r; BIT.update(l,1); BIT.update(r+1,-1);
}
else {
int x; cin >> x; cout << BIT.query(x) % 2 << endl;
}
}
return 0;
}
Part 4: 更多应用/ Problems \LARGE{\text{Part 4: \ 更多应用/\ Problems}} Part 4: 更多应用/ Problems
树状数组应用:权值树状数组。
例题:
P1908 逆序对
首先,因为本题只要求数之间的相对顺序,所以我们可以使用离散化,将所有数映射到 1 ∼ n 1 \sim n 1∼n 的范围内。然后可以得到一个暴力做法:建一个数组 cnt \text{cnt} cnt,然后对于所有 a i a_i ai 令 cnt a i ← cnt a i + 1 \text{cnt}_{a_i} \gets \text{cnt}_{a_i}+1 cntai←cntai+1;然后再枚举一遍,对于每个 a i a_i ai,将所有 cnt 1 ∼ cnt a i − 1 \text{cnt}_1 \sim \text{cnt}_{a_i-1} cnt1∼cntai−1 加起来就是 a i a_i ai 的贡献。这样时间复杂度 O ( n 2 ) O(n^2) O(n2)。然后发现这是一个单点加,区间和的东西,用树状数组维护一下即可。时间复杂度 O ( n log n ) O(n \log n) O(nlogn)。
#include <bits/stdc++.h>
#define int long long
using namespace std;
int lowbit(int x) { return x & (-x); }
struct node {
int id,x;
bool operator < (const node& y) {
return (x != y.x ? x < y.x : id < y.id);
}
} a[500005];
int rk[500005];
int tree[500005],n;
void update(int x,int d) {
while(x <= n) {
tree[x] += d;
x += lowbit(x);
}
}
int query(int x) {
int sum = 0;
while(x >= 1) {
sum += tree[x];
x -= lowbit(x);
}
return sum;
}
signed main() {
cin >> n;
for(int i = 1;i <= n;i++) cin >> a[i].x;
for(int i = 1;i <= n;i++) a[i].id = (a[i].x == a[i-1].x ? a[i-1].id : i);
sort(a+1,a+n+1);
for(int i = 1;i <= n;i++) {
rk[a[i].id] = i;
}
int ans = 0;
for(int i = n;i >= 1;i--) {
ans += query(rk[i]); update(rk[i],1);
}
cout << ans << endl;
return 0;
}
P10589 楼兰图腾
假如数据范围很小,那么这就是一道 DP 题。我们不难写出暴力 DP 代码:
#include <bits/stdc++.h>
using namespace std;
int a[300005];
int f[300005][4];
int main() {
int n; cin >> n;
for(int i = 1;i <= n;i++) cin >> a[i], f[i][1] = 1;
for(int i = 1;i <= n;i++) {
for(int j = 1;j < i;j++) f[i][2] += (a[j] > a[i]);
}
for(int i = 1;i <= n;i++) {
for(int j = 1;j < i;j++) f[i][3] += (a[j] < a[i]) * f[j][2];
} long long ans = 0;
for(int i = 1;i <= n;i++) ans += f[i][3]; cout << ans << ' ';
for(int i = 1;i <= n;i++) f[i][1] = 1, f[i][2] = 0, f[i][3] = 0;
for(int i = 1;i <= n;i++) {
for(int j = 1;j < i;j++) f[i][2] += (a[j] < a[i]);
}
for(int i = 1;i <= n;i++) {
for(int j = 1;j < i;j++) f[i][3] += (a[j] > a[i]) * f[j][2];
} ans = 0;
for(int i = 1;i <= n;i++) ans += f[i][3]; cout << ans;
return 0;
}
于是不难发现瓶颈在于计算有多少个比当前数小的数。那么直接拿权值树状数组优化,时间复杂度 O ( n log n ) O(n \log n) O(nlogn)。
#include <bits/stdc++.h>
#define int long long
using namespace std;
int a[300005], f[300005][4];
struct BIT {
int tree[300005], n;
int lowbit(int x) { return x & (-x); }
void init() {
for(int i = 1;i <= n;i++) tree[i] = 0;
} void update(int x, int d) {
for(; x <= n; x += lowbit(x)) tree[x] += d;
} int query(int x) {
int sum = 0; for(; x; x -= lowbit(x)) sum += tree[x];
return sum;
}
} t;
signed main() {
cin >> t.n; t.init();
for(int i = 1;i <= t.n;i++) cin >> a[i], f[i][1] = 1;
for(int i = 1;i <= t.n;i++) {
f[i][2] = t.query(t.n - a[i] + 1); t.update(t.n - a[i] + 1, 1);
} t.init(); for(int i = 1;i <= t.n;i++) {
f[i][3] = t.query(a[i]); t.update(a[i], f[i][2]);
} long long ans = 0; t.init();
for(int i = 1;i <= t.n;i++) ans += f[i][3]; cout << ans << ' ';
for(int i = 1;i <= t.n;i++) f[i][1] = 1, f[i][2] = 0, f[i][3] = 0;
for(int i = 1;i <= t.n;i++) {
f[i][2] = t.query(a[i]); t.update(a[i], 1);
} t.init(); for(int i = 1;i <= t.n;i++) {
f[i][3] = t.query(t.n - a[i] + 1); t.update(t.n - a[i] + 1, f[i][2]);
} ans = 0;
for(int i = 1;i <= t.n;i++) ans += f[i][3]; cout << ans;
return 0;
}