Bootstrap

codeforces gym-101736 Farmer Faul 平衡树+并查集

题目

题目链接

题意

给出 n n 个整数,其中1n106
给出三种操作:

  • GROW x y,表示给 x x 位置的数增加y
  • MAGIC x,表示给所有的数增加 x x
  • CUT x,表示把所有大于x的数切割到x,并输出本次切割割了多少。

题解

乍一看,似乎没有很符合这个题目情形的数据结构,我们需要把多种数据结构结合起来。

首先我们把所有的数值相同的元素都归并到一起(采用并查集的方法),并在这一个集合中找出一个关键的点(并查集的根节点)扔到平衡树里面去,平衡树的第一关键字是该并查集所具有的值,第二关键字是该并查集的根节点。

在这个基础下,GROW x y操作就相当于把x处的元素从它所在并查集中拆出来,拆成一个独立的点,然后给这个点的值加y,再把修改后的元素放入起所在的并查集中去。
时间复杂度:O(log(n))

MAGIC x操作就直接记录一个累加 h h 就好了,在查询的时候用到。
时间复杂度:O(1)

CUT x操作就相当于在平衡树中,找到所有的值大于等于x-h的并查集,计算好贡献以后,把所有的这些找到的并查集合并成为一个并查集,并且该并查集的值为x-h。

时间复杂度: O(CUT) O ( 两 次 C U T 之 间 的 操 作 数 )

复杂度计算如果不对,请评论告知我。

注意

  1. 在这道题中并查集还应该记录一个属性,就是并查集的大小。
  2. 该并查集支持元素从并查集中剥离,因此需要为每个元素设置一个盒子,即 ida i d a 数组,当一个元素被剥离的时候,给元素一个新的盒子,原来的并查集结构保持不变,但是要求原来并查集的大小-1。

代码

#include <iostream>
#include <cstdio>
#include <set>
#define pr(x) cout<<#x<<":"<<x<<endl
#define int long long
using namespace std;
typedef pair<int,int> pii;
const int maxn = 2e6;
int pa[maxn],sz[maxn];
void init(){ for(int i = 1;i < maxn;++i) pa[i] = i,sz[i] = 1;}
int find(int x){ return x == pa[x]?x:pa[x] = find(pa[x]);}
void join(int x,int y){
    int px = find(x),py = find(y);
    if(px != py) { pa[px] = py; sz[py] += sz[px];}
}
set<pii> st;
int id,ida[maxn],n,q,h,tmp,a[maxn];
char op[6];
void ins(int pid){
    auto it = st.lower_bound(make_pair(a[pid],0));
    if(it == st.end() || it->first != a[pid]) st.insert(make_pair(a[pid],pid));
    else join(pid,it->second);
}
int split(int pos){
    int pid = find(ida[pos]);  
    if(sz[pid] == 1) st.erase(st.find(make_pair(a[pid],pid)));
    else {sz[pid] --;ida[pos] = ++id;a[id] = a[pid];}
    return pid;
}
main()
{
    init();
    id = 0;
    scanf("%lld%lld",&n,&q);
    for(int i = 1;i <= n;++i) {
        scanf("%lld",&tmp);
        ida[i] = ++id;
        a[id] = tmp;
        ins(id);
    }
    while(q--){
        scanf("%s",op);
        if(*op == 'G'){
            int pos,x;
            scanf("%lld%lld",&pos,&x);
            int nid = split(pos);
            a[nid] += x;
            ins(nid);
        }
        else if(*op == 'M'){
            int x;scanf("%lld",&x);
            h += x;
        }
        else if(*op == 'C'){
            int x;scanf("%lld",&x);
            int ans = 0;
            auto it = st.lower_bound(make_pair(x-h,0));
            if(it == st.end()) {
                printf("0\n");
                continue;
            }
            pii p = make_pair(x-h,it->second);
            ans += (it->first+h-x)*sz[it->second];
            a[it->second] = x-h;
            int pid = it->second;
            it = st.erase(it);
            while(it != st.end()){
                ans += (it->first-x+h)*sz[it->second];
                join(it->second,pid);
                it = st.erase(it);
            }
            st.insert(p);
            printf("%lld\n",ans);
        }
    }
    return 0;
}
;