泥土笨笨
2021-06-23 15:45:25
事情是这样的,有一天,教练群里面讨论这道题,我说这道题排名前
好的,既然是 lxl 大神出的题,我就更有必要写一篇正确的题解,来表达我对他的膜拜了。
因为如果要用 split
操作,截取一段区间的时候,必须要先 split(r+1)
,再 split(l)
,否则会有 RE ,具体原因我后面会细说。请大家参考其他题解或者资料的时候,也注意这一点。
珂朵莉树,还有个名字叫老司机树(Old Driver Tree, ODT),是一个暴力数据结构。甚至都不一定可以将其称之为数据结构了,我们不妨认为它是一类题目的暴力做法,对于随机数据比较有效。
有一类问题,对一个序列,进行一个区间推平操作。就是把一个范围内,比如
暴力的地方来喽,刚才不是提到有推平操作么?那么推平操作结束以后,被推平的区间内的每个数字都是相同的。其实经过若干次推平以后,我们可以看成,这个序列上的数字是一段一段的,每一小段里面数字相同,整个区间由若干个小段组成。类似这样:
这个时候,我们定义一个结构体,用一个结构体变量,来表示每个数字相同的段。
struct Node {
ll l, r;//l和r表示这一段的起点和终点
mutable ll v;//v表示这一段上所有元素相同的值是多少
Node(ll l, ll r = 0, ll v = 0) : l(l), r(r), v(v) {}
bool operator<(const Node &a) const {
return l < a.l;//规定按照每段的左端点排序
}
};
相关变量的含义,注释里面已经解释了。这里有个细节是, v
变量前面加个了 mutable
关键字。 mutable
的意思是,即使它是个常量,也允许修改v的值,具体我在下面区间修改的地方解释。
当每个数字相同的区间都用一个结构体变量表示以后,我们把这四段插入到一个 set
里面, set
会按照每段的左端点顺序进行排序,这样这个序列就维护好了,类似下图:
当然,对于本题,一开始的时候,每段都只有一个数,所以我们的set里面维护n个长度为1的段。
split
天下大势,分久必合,合久必分,珂朵莉树也一样。随着推平操作的进行,有一些位置被合并到了一个 Node
里面,但是也有可能一个 Node
要被拆开,其中的一部分要被改变值。
split
操作就是干这个用的,参数是一个位置 pos
,以 pos
去做切割,找到一个包含 pos
的区间,把它分成 [l,pos-1]
, [pos,r]
两半。当然,如果 pos
本身就是一个区间的开头,就不用切割了,直接返回这个区间。
先看代码
set<Node>::iterator split(int pos) {
set<Node>::iterator it = s.lower_bound(Node(pos));
if (it != s.end() && it->l == pos) {
return it;
}
it--;
if (it->r < pos) return s.end();
ll l = it->l;
ll r = it->r;
ll v = it->v;
s.erase(it);
s.insert(Node(l, pos - 1, v));
//insert函数返回pair,其中的first是新插入结点的迭代器
return s.insert(Node(pos, r, v)).first;
}
首先,第一行里面的 s
是一个全局变量,是那个装 node
的 set
。大家知道 set
里面有个函数叫 lower_bound
,它的作用是返回跟参数相等的,或者比参数更大的第一个 set
中元素的位置,返回的是一个迭代器。
那么我们按照 pos
创建一个 node
,然后去查询,就找到了 it
这个位置。这个时候有三种情况,一种是我们正好找到了一个区间,它是以 pos
开头的,所以就对应了代码中的第一个 if
判断,这时候直接返回这个区间的迭代器 it
。
还有两种情况是,我们找到的这个区间是正好比包含 pos
的区间大一点点的,或者pos
太大了,超过了最后一个区间的右端点。不管怎样先把it
往前挪一个格,然后这时候看看it
的右端点,如果比pos
小,说明是pos
太大了,就直接返回s
的end()
迭代器。否则的话,现在it
就是应该包含了pos
的那个区间。这时候,我们要把它一分为二,把原来的那个区间删掉,然后插入两个新区间,分别是[l,pos-1]
和[pos,r]
。
这里还有个小技巧,insert
这个函数是有返回值的,它返回的是一个pair
,pair
的第一个字段正好是新插入的那个node
的位置的迭代器,所以return
那个东西就行了。
assign
刚刚的split
作用是分,现在还需要一个相反的操作,就是合并。当出现对区间的推平操作的时候,我们可以把整个set
中所有要被合并掉的node
都删掉,然后插入一个新区间表示推平以后的结果。
如图,按照上面的例子,set
里面有node
,此时我们想进行一次推平操作,把[2,8]
区间内的元素都改成[8,10]
是一个区间,那么需要先split(9)
,把[8,8]
和[9,10]
拆成两个区间。同理,原来的[1,2]
区间,也需要拆成[1,1]
和[2,2]
。
接下来,我们把要被合并的从node
都从set
里面删掉,再重新插入一个[2,8]
区间就行了。删除某个范围内的元素可以用set
的erase
函数,这个函数接受两个迭代器s
和t
,把[s,t)
范围内的东西都删掉。
代码如下:
void assign(ll l, ll r, ll x) {
set<Node>::iterator itr = split(r + 1), itl = split(l);
s.erase(itl, itr);
s.insert(Node(l, r, x));
}
现在说一下为啥大部分题解都不对,注意刚刚assign
函数里面调用的那两次split
,我是先split(r+1)
,计算出itr
,然后再split(l)
,计算itl
的。这个顺序不能反。
为啥不能反?举个具体例子,比如现在有个区间是[1,4]
,我们想从里面截取[1,1]
出来,那么我们需要调用两次split
,分别是split(2)
和split(1)
。
假设先调用split(1)
,如图中间的结果:
现在的itl
指向的还是原来的那个node
,没有什么变化。但是当我们后续调用itr
的时候,出事儿了。因为这时候,我们把原来的[1,4]
区间删掉了,拆成了两份,itr
指向的是后面那个,但是原来itl
指向的那个已经被erase
掉了。这时候用itl
和itr
调用s.erase
的时候就会出问题,直接RE。
有同学说我顺序反了没RE啊,也AC。恭喜你,你人品好。这东西理论上会RE,但是实际上概率不大,我对拍了一下,大概1%的概率RE吧。但是人品不好的同学,可能上来就RE一片,而且是随机RE,同一个数据,一会儿能过,一会儿过不了。所以,还是别给自己找麻烦了。
add
区间内每个数都加上x
,这个实现方式和前面的推平差不多,我们还是找到这个区间的首尾,然后循环一遍区间内的每个node
,把每个node
的v
都加上x
就行
void add(ll l, ll r, ll x) {
set<Node>::iterator itr = split(r + 1), itl = split(l);
for (set<Node>::iterator it = itl; it != itr; ++it) {
it->v += x;
}
}
这里是用一个迭代器it
遍历每个位置,把每个位置的v
都加x
。大家发现前面提到的mutable
的作用了么?因为这里it
是个常量迭代器,它不能修改它指向的那个node
,而我们这里要改node
里面的v
,所以就把v
声明为mutable
,就可以改了。否则会得到类似这样的编译错误:
error: cannot assign to return value because function 'operator->' returns a const value
其他操作都是类似的暴力操作。比如要找区间第node
拿出来,按照v
从小到大排序,把每个node
里面的区间长度相加,看看啥时候加够为止。这里就不细致展开,有问题可以去看代码。
因为本题数据是随机的,所以每次assign
操作的区间长度大概在assign
以后,区间个数不会太多,大概在log
这个数量级上。这样每次暴力操作的复杂度差不多也是这个数量级。详细的分析,可以参考这篇博客:
https://www.luogu.com.cn/blog/blaze/solution-cf896c
#include <iostream>
#include <set>
#include <algorithm>
#include <vector>
#include <cstdio>
using namespace std;
typedef long long ll;
const ll MOD = 1000000007;
const ll MAXN = 100005;
struct Node {
ll l, r;//l和r表示这一段的起点和终点
mutable ll v;//v表示这一段上所有元素相同的值是多少
Node(ll l, ll r = 0, ll v = 0) : l(l), r(r), v(v) {}
bool operator<(const Node &a) const {
return l < a.l;//规定按照每段的左端点排序
}
};
ll n, m, seed, vmax, a[MAXN];
set<Node> s;
//以pos去做切割,找到一个包含pos的区间,把它分成[l,pos-1],[pos,r]两半
set<Node>::iterator split(int pos) {
set<Node>::iterator it = s.lower_bound(Node(pos));
if (it != s.end() && it->l == pos) {
return it;
}
it--;
if (it->r < pos) return s.end();
ll l = it->l;
ll r = it->r;
ll v = it->v;
s.erase(it);
s.insert(Node(l, pos - 1, v));
//insert函数返回pair,其中的first是新插入结点的迭代器
return s.insert(Node(pos, r, v)).first;
}
/*
* 这里注意必须先计算itr。
* 比如现在区间是[1,4],如果要add的是[1,2],如果先split(1)
* 那么返回的itl是[1,4],但是下一步计算itr的时候会把这个区间删掉拆成[1,2]和[3,4]
* 那么itl这个指针就被释放了
* */
void add(ll l, ll r, ll x) {
set<Node>::iterator itr = split(r + 1), itl = split(l);
for (set<Node>::iterator it = itl; it != itr; ++it) {
it->v += x;
}
}
void assign(ll l, ll r, ll x) {
set<Node>::iterator itr = split(r + 1), itl = split(l);
s.erase(itl, itr);
s.insert(Node(l, r, x));
}
struct Rank {
ll num, cnt;
bool operator<(const Rank &a) const {
return num < a.num;
}
Rank(ll num, ll cnt) : num(num), cnt(cnt) {}
};
ll rnk(ll l, ll r, ll x) {
set<Node>::iterator itr = split(r + 1), itl = split(l);
vector<Rank> v;
for (set<Node>::iterator i = itl; i != itr; ++i) {
v.push_back(Rank(i->v, i->r - i->l + 1));
}
sort(v.begin(), v.end());
int i;
for (i = 0; i < v.size(); ++i) {
if (v[i].cnt < x) {
x -= v[i].cnt;
} else {
break;
}
}
return v[i].num;
}
ll ksm(ll x, ll y, ll p) {
ll r = 1;
ll base = x % p;
while (y) {
if (y & 1) {
r = r * base % p;
}
base = base * base % p;
y >>= 1;
}
return r;
}
ll calP(ll l, ll r, ll x, ll y) {
set<Node>::iterator itr = split(r + 1), itl = split(l);
ll ans = 0;
for (set<Node>::iterator i = itl; i != itr; ++i) {
ans = (ans + ksm(i->v, x, y) * (i->r - i->l + 1) % y) % y;
}
return ans;
}
ll rnd() {
ll ret = seed;
seed = (seed * 7 + 13) % MOD;
return ret;
}
int main() {
cin >> n >> m >> seed >> vmax;
for (int i = 1; i <= n; ++i) {
a[i] = (rnd() % vmax) + 1;
s.insert(Node(i, i, a[i]));
}
for (int i = 1; i <= m; ++i) {
ll op, l, r, x, y;
op = (rnd() % 4) + 1;
l = (rnd() % n) + 1;
r = (rnd() % n) + 1;
if (l > r) swap(l, r);
if (op == 3) {
x = (rnd() % (r - l + 1)) + 1;
} else {
x = (rnd() % vmax) + 1;
}
if (op == 4) {
y = (rnd() % vmax) + 1;
}
if (op == 1) {
add(l, r, x);
} else if (op == 2) {
assign(l, r, x);
} else if (op == 3) {
cout << rnk(l, r, x) << endl;
} else {
cout << calP(l, r, x, y) << endl;
}
}
return 0;
}