珂朵莉树模板CF896C Willem, Chtholly and Seniorious题解

泥土笨笨

2021-06-23 15:45:25

题解

0x00 前言

事情是这样的,有一天,教练群里面讨论这道题,我说这道题排名前 10 的题解里面,有 7 个都是错的,我打算写一个对的,避免同学们被误导。这时候,群里的 lxl 默默的来了一句:

好的,既然是 lxl 大神出的题,我就更有必要写一篇正确的题解,来表达我对他的膜拜了。

0x01 为什么很多题解不对,照着写会RE

因为如果要用 split 操作,截取一段区间的时候,必须要先 split(r+1) ,再 split(l) ,否则会有 RE ,具体原因我后面会细说。请大家参考其他题解或者资料的时候,也注意这一点。

0x02 什么是珂朵莉树

珂朵莉树,还有个名字叫老司机树(Old Driver Tree, ODT),是一个暴力数据结构。甚至都不一定可以将其称之为数据结构了,我们不妨认为它是一类题目的暴力做法,对于随机数据比较有效。

0x03 珂朵莉树可以解决什么问题

有一类问题,对一个序列,进行一个区间推平操作。就是把一个范围内,比如 [l,r] 范围内的数字变成同一个。可能除了推平以外,还夹杂其他操作。如果数据是随机的,就可以用珂朵莉树啦。比如这道题中的操作 2 ,将 [l,r] 区间内的所有数都改成 x ,这就是一个区间推平操作。

0x04 珂朵莉树的基本原理

暴力的地方来喽,刚才不是提到有推平操作么?那么推平操作结束以后,被推平的区间内的每个数字都是相同的。其实经过若干次推平以后,我们可以看成,这个序列上的数字是一段一段的,每一小段里面数字相同,整个区间由若干个小段组成。类似这样:

这个时候,我们定义一个结构体,用一个结构体变量,来表示每个数字相同的段。

struct Node {
    ll l, r;//l和r表示这一段的起点和终点
    mutable ll v;//v表示这一段上所有元素相同的值是多少

    Node(ll l, ll r = 0, ll v = 0) : l(l), r(r), v(v) {}

    bool operator<(const Node &a) const {
        return l < a.l;//规定按照每段的左端点排序
    }
};

相关变量的含义,注释里面已经解释了。这里有个细节是, v 变量前面加个了 mutable 关键字。 mutable 的意思是,即使它是个常量,也允许修改v的值,具体我在下面区间修改的地方解释。

当每个数字相同的区间都用一个结构体变量表示以后,我们把这四段插入到一个 set 里面, set 会按照每段的左端点顺序进行排序,这样这个序列就维护好了,类似下图:

当然,对于本题,一开始的时候,每段都只有一个数,所以我们的set里面维护n个长度为1的段。

0x05 核心操作 split

天下大势,分久必合,合久必分,珂朵莉树也一样。随着推平操作的进行,有一些位置被合并到了一个 Node 里面,但是也有可能一个 Node 要被拆开,其中的一部分要被改变值。

split 操作就是干这个用的,参数是一个位置 pos ,以 pos 去做切割,找到一个包含 pos 的区间,把它分成 [l,pos-1] , [pos,r] 两半。当然,如果 pos 本身就是一个区间的开头,就不用切割了,直接返回这个区间。

先看代码

set<Node>::iterator split(int pos) {
    set<Node>::iterator it = s.lower_bound(Node(pos));
    if (it != s.end() && it->l == pos) {
        return it;
    }
    it--;
    if (it->r < pos) return s.end();
    ll l = it->l;
    ll r = it->r;
    ll v = it->v;
    s.erase(it);
    s.insert(Node(l, pos - 1, v));
    //insert函数返回pair,其中的first是新插入结点的迭代器
    return s.insert(Node(pos, r, v)).first;
}

首先,第一行里面的 s 是一个全局变量,是那个装 nodeset 。大家知道 set 里面有个函数叫 lower_bound ,它的作用是返回跟参数相等的,或者比参数更大的第一个 set 中元素的位置,返回的是一个迭代器。

那么我们按照 pos 创建一个 node ,然后去查询,就找到了 it 这个位置。这个时候有三种情况,一种是我们正好找到了一个区间,它是以 pos 开头的,所以就对应了代码中的第一个 if 判断,这时候直接返回这个区间的迭代器 it

还有两种情况是,我们找到的这个区间是正好比包含 pos 的区间大一点点的,或者pos太大了,超过了最后一个区间的右端点。不管怎样先把it往前挪一个格,然后这时候看看it的右端点,如果比pos小,说明是pos太大了,就直接返回send()迭代器。否则的话,现在it就是应该包含了pos的那个区间。这时候,我们要把它一分为二,把原来的那个区间删掉,然后插入两个新区间,分别是[l,pos-1][pos,r]

这里还有个小技巧,insert这个函数是有返回值的,它返回的是一个pairpair的第一个字段正好是新插入的那个node的位置的迭代器,所以return那个东西就行了。

0x06 推平操作assign

刚刚的split作用是分,现在还需要一个相反的操作,就是合并。当出现对区间的推平操作的时候,我们可以把整个set中所有要被合并掉的node都删掉,然后插入一个新区间表示推平以后的结果。

如图,按照上面的例子,set里面有4node,此时我们想进行一次推平操作,把[2,8]区间内的元素都改成666.首先我们发现,[8,10]是一个区间,那么需要先split(9),把[8,8][9,10]拆成两个区间。同理,原来的[1,2]区间,也需要拆成[1,1][2,2]

接下来,我们把要被合并的从28的所有node都从set里面删掉,再重新插入一个[2,8]区间就行了。删除某个范围内的元素可以用seterase函数,这个函数接受两个迭代器st,把[s,t)范围内的东西都删掉。

代码如下:

void assign(ll l, ll r, ll x) {
    set<Node>::iterator itr = split(r + 1), itl = split(l);
    s.erase(itl, itr);
    s.insert(Node(l, r, x));
}

0x07 推平操作里面RE的坑

现在说一下为啥大部分题解都不对,注意刚刚assign函数里面调用的那两次split,我是先split(r+1),计算出itr,然后再split(l),计算itl的。这个顺序不能反

为啥不能反?举个具体例子,比如现在有个区间是[1,4],我们想从里面截取[1,1]出来,那么我们需要调用两次split,分别是split(2)split(1)

假设先调用split(1),如图中间的结果:

现在的itl指向的还是原来的那个node,没有什么变化。但是当我们后续调用itr的时候,出事儿了。因为这时候,我们把原来的[1,4]区间删掉了,拆成了两份,itr指向的是后面那个,但是原来itl指向的那个已经被erase掉了。这时候用itlitr调用s.erase的时候就会出问题,直接RE。

有同学说我顺序反了没RE啊,也AC。恭喜你,你人品好。这东西理论上会RE,但是实际上概率不大,我对拍了一下,大概1%的概率RE吧。但是人品不好的同学,可能上来就RE一片,而且是随机RE,同一个数据,一会儿能过,一会儿过不了。所以,还是别给自己找麻烦了。

0x08 修改操作add

区间内每个数都加上x,这个实现方式和前面的推平差不多,我们还是找到这个区间的首尾,然后循环一遍区间内的每个node,把每个nodev都加上x就行

void add(ll l, ll r, ll x) {
    set<Node>::iterator itr = split(r + 1), itl = split(l);
    for (set<Node>::iterator it = itl; it != itr; ++it) {
        it->v += x;
    }
}

这里是用一个迭代器it遍历每个位置,把每个位置的v都加x。大家发现前面提到的mutable的作用了么?因为这里it是个常量迭代器,它不能修改它指向的那个node,而我们这里要改node里面的v,所以就把v声明为mutable,就可以改了。否则会得到类似这样的编译错误: error: cannot assign to return value because function 'operator->' returns a const value

0x09 其他操作

其他操作都是类似的暴力操作。比如要找区间第k小,那么就把区间内所有的node拿出来,按照v从小到大排序,把每个node里面的区间长度相加,看看啥时候加够为止。这里就不细致展开,有问题可以去看代码。

0x0A 复杂度

因为本题数据是随机的,所以每次assign操作的区间长度大概在vmax/3,所以经过很多次assign以后,区间个数不会太多,大概在log这个数量级上。这样每次暴力操作的复杂度差不多也是这个数量级。详细的分析,可以参考这篇博客:

https://www.luogu.com.cn/blog/blaze/solution-cf896c

0x0B 代码时间

#include <iostream>
#include <set>
#include <algorithm>
#include <vector>
#include <cstdio>

using namespace std;

typedef long long ll;
const ll MOD = 1000000007;
const ll MAXN = 100005;

struct Node {
    ll l, r;//l和r表示这一段的起点和终点
    mutable ll v;//v表示这一段上所有元素相同的值是多少

    Node(ll l, ll r = 0, ll v = 0) : l(l), r(r), v(v) {}

    bool operator<(const Node &a) const {
        return l < a.l;//规定按照每段的左端点排序
    }
};

ll n, m, seed, vmax, a[MAXN];
set<Node> s;

//以pos去做切割,找到一个包含pos的区间,把它分成[l,pos-1],[pos,r]两半
set<Node>::iterator split(int pos) {
    set<Node>::iterator it = s.lower_bound(Node(pos));
    if (it != s.end() && it->l == pos) {
        return it;
    }
    it--;
    if (it->r < pos) return s.end();
    ll l = it->l;
    ll r = it->r;
    ll v = it->v;
    s.erase(it);
    s.insert(Node(l, pos - 1, v));
    //insert函数返回pair,其中的first是新插入结点的迭代器
    return s.insert(Node(pos, r, v)).first;
}

/*
 * 这里注意必须先计算itr。
 * 比如现在区间是[1,4],如果要add的是[1,2],如果先split(1)
 * 那么返回的itl是[1,4],但是下一步计算itr的时候会把这个区间删掉拆成[1,2]和[3,4]
 * 那么itl这个指针就被释放了
 * */
void add(ll l, ll r, ll x) {
    set<Node>::iterator itr = split(r + 1), itl = split(l);
    for (set<Node>::iterator it = itl; it != itr; ++it) {
        it->v += x;
    }
}

void assign(ll l, ll r, ll x) {
    set<Node>::iterator itr = split(r + 1), itl = split(l);
    s.erase(itl, itr);
    s.insert(Node(l, r, x));
}

struct Rank {
    ll num, cnt;

    bool operator<(const Rank &a) const {
        return num < a.num;
    }

    Rank(ll num, ll cnt) : num(num), cnt(cnt) {}
};

ll rnk(ll l, ll r, ll x) {
    set<Node>::iterator itr = split(r + 1), itl = split(l);
    vector<Rank> v;
    for (set<Node>::iterator i = itl; i != itr; ++i) {
        v.push_back(Rank(i->v, i->r - i->l + 1));
    }
    sort(v.begin(), v.end());
    int i;
    for (i = 0; i < v.size(); ++i) {
        if (v[i].cnt < x) {
            x -= v[i].cnt;
        } else {
            break;
        }
    }
    return v[i].num;
}

ll ksm(ll x, ll y, ll p) {
    ll r = 1;
    ll base = x % p;
    while (y) {
        if (y & 1) {
            r = r * base % p;
        }
        base = base * base % p;
        y >>= 1;
    }
    return r;
}

ll calP(ll l, ll r, ll x, ll y) {
    set<Node>::iterator itr = split(r + 1), itl = split(l);
    ll ans = 0;
    for (set<Node>::iterator i = itl; i != itr; ++i) {
        ans = (ans + ksm(i->v, x, y) * (i->r - i->l + 1) % y) % y;
    }
    return ans;
}

ll rnd() {
    ll ret = seed;
    seed = (seed * 7 + 13) % MOD;
    return ret;
}

int main() {
    cin >> n >> m >> seed >> vmax;
    for (int i = 1; i <= n; ++i) {
        a[i] = (rnd() % vmax) + 1;
        s.insert(Node(i, i, a[i]));
    }
    for (int i = 1; i <= m; ++i) {
        ll op, l, r, x, y;
        op = (rnd() % 4) + 1;
        l = (rnd() % n) + 1;
        r = (rnd() % n) + 1;
        if (l > r) swap(l, r);
        if (op == 3) {
            x = (rnd() % (r - l + 1)) + 1;
        } else {
            x = (rnd() % vmax) + 1;
        }
        if (op == 4) {
            y = (rnd() % vmax) + 1;
        }
        if (op == 1) {
            add(l, r, x);
        } else if (op == 2) {
            assign(l, r, x);
        } else if (op == 3) {
            cout << rnk(l, r, x) << endl;
        } else {
            cout << calP(l, r, x, y) << endl;
        }
    }
    return 0;
}