Bootstrap

算法提高第二章 线段树基础

我们以下面一道题为例题开始分析:

[TJOI2009] 开关

题目描述

现有 n n n 盏灯排成一排,从左到右依次编号为: 1 1 1 2 2 2,……, n n n。然后依次执行 m m m 项操作。

操作分为两种:

  1. 指定一个区间 [ a , b ] [a,b] [a,b],然后改变编号在这个区间内的灯的状态(把开着的灯关上,关着的灯打开);
  2. 指定一个区间 [ a , b ] [a,b] [a,b],要求你输出这个区间内有多少盏灯是打开的。

灯在初始时都是关着的。

输入格式

第一行有两个整数 n n n m m m,分别表示灯的数目和操作的数目。

接下来有 m m m 行,每行有三个整数,依次为: c c c a a a b b b。其中 c c c 表示操作的种类。

  • c c c 的值为 0 0 0 时,表示是第一种操作。
  • c c c 的值为 1 1 1 时,表示是第二种操作。

a a a b b b 则分别表示了操作区间的左右边界。

输出格式

每当遇到第二种操作时,输出一行,包含一个整数,表示此时在查询的区间中打开的灯的数目。

样例 #1

样例输入 #1

4 5
0 1 2
0 2 4
1 2 3
0 2 4
1 1 4

样例输出 #1

1
2

提示

数据规模与约定

对于全部的测试点,保证 2 ≤ n ≤ 1 0 5 2\le n\le 10^5 2n105 1 ≤ m ≤ 1 0 5 1\le m\le 10^5 1m105 1 ≤ a , b ≤ n 1\le a,b\le n 1a,bn c ∈ { 0 , 1 } c\in\{0,1\} c{0,1}

线段树的基本结构与建树模板:

struct Node
{
	int l, r;
	ll sum, add;
}tr[N * 4];
void build(int u, int l, int r)
{
	if(l == r) tr[u] = {l, r, 0, 0};
	else {
		tr[u] = {l, r};
		int mid = l + r >> 1;
		build(u << 1, l, mid);
		build(u << 1 | 1, mid + 1, r);
		pushup(u);
	}
}

Pushup操作:

void pushup(int u)
{
	tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;//区间数量等于左区间加上右区间
}

Pushdown操作

void pushdown(int u)
{
	if(tr[u].add == 1)
	{
		int mid = tr[u].l + tr[u].r >> 1;
		tr[u << 1].add ^= 1;
		tr[u << 1 | 1].add ^= 1;//懒标记
		tr[u << 1].sum = (mid - tr[u << 1].l + 1 - tr[u << 1].sum);
		tr[u << 1 | 1].sum = (tr[u].r - mid - tr[u << 1 | 1] .sum);
		tr[u].add = 0;
	}
}

区间查询:

ll query(int u, int l, int r)
{
	if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
	else 
	{
		pushdown(u);
		ll sum = 0;
		int mid = tr[u].l + tr[u].r >> 1;
		if(l <= mid) sum += query(u << 1, l, r); 
		if(r > mid) sum += query(u << 1 | 1, l, r);
		return sum;
	}
}

区间修改:

void modify(int u, int l, int r, int d)
{
	if(tr[u].l >= l && tr[u].r <= r){
		tr[u].sum = (tr[u].r - tr[u].l + 1 - tr[u].sum);
		tr[u].add ^= 1;
	}
	else {
		pushdown(u);
		int mid = tr[u].l + tr[u].r >> 1;
		if(l <= mid) modify(u << 1, l, r, d);
		if(r > mid) modify(u << 1 | 1, l, r, d);
		pushup(u);
	}
}

AC代码:

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int N = 1e6 + 10;

ll w[N];

struct Node
{
	int l, r;
	ll sum, add;
}tr[N * 4];

void pushup(int u)
{
	tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;//区间数量等于左区间加上右区间
}

void pushdown(int u)
{
	if(tr[u].add == 1)
	{
		int mid = tr[u].l + tr[u].r >> 1;
		tr[u << 1].add ^= 1;
		tr[u << 1 | 1].add ^= 1;//懒标记
		tr[u << 1].sum = (mid - tr[u << 1].l + 1 - tr[u << 1].sum);
		tr[u << 1 | 1].sum = (tr[u].r - mid - tr[u << 1 | 1] .sum);
		tr[u].add = 0;
	}
}

void build(int u, int l, int r)
{
	if(l == r) tr[u] = {l, r, 0, 0};
	else {
		tr[u] = {l, r};
		int mid = l + r >> 1;
		build(u << 1, l, mid);
		build(u << 1 | 1, mid + 1, r);
		pushup(u);
	}
}

void modify(int u, int l, int r, int d)
{
	if(tr[u].l >= l && tr[u].r <= r){
		tr[u].sum = (tr[u].r - tr[u].l + 1 - tr[u].sum);
		tr[u].add ^= 1;
	}
	else {
		pushdown(u);
		int mid = tr[u].l + tr[u].r >> 1;
		if(l <= mid) modify(u << 1, l, r, d);
		if(r > mid) modify(u << 1 | 1, l, r, d);
		pushup(u);
	}
}

ll query(int u, int l, int r)
{
	if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
	else 
	{
		pushdown(u);
		ll sum = 0;
		int mid = tr[u].l + tr[u].r >> 1;
		if(l <= mid) sum += query(u << 1, l, r); 
		if(r > mid) sum += query(u << 1 | 1, l, r);
		return sum;
	}
}
int main()
{
    ios::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
	int n,m ;
	cin >> n >> m;
//	for(int i = 1; i <= n; i ++) cin >> w[i];
	build(1, 1, n);
	while(m --){
		int op, l, r;
		cin >> op >> l >> r;
		if(op == 0)
		{
			modify(1, l, r, 1);
		}
		else {
			cout << query(1, l, r) << endl;
		}
	}
	return 0;
}
;