「计算客国庆赛」简单数据结构题

Description

Link

有 $n$ 个魔法物品(编号为 $1, 2, \cdots, n$)和 $m$ 对互斥关系,每对关系形如 $x, y$,表示编号为 $x$ 和 $y$ 的魔法物品会相互排斥。

回答 $q$ 次询问,每次询问给出 $k$ 个区间 $[l_1,r_1], [l_2,r_2],\cdots,[l_k,r_k]$,你需要知道这 $k$ 个区间的并是否包含互斥的魔法物品。

$$
1 \le n, q, \sum k \le 10 ^5
$$

保证询问区间之间互不相交,且左端点递增。部分测试点强制在线。

Solution

实际上部分分给出了一定的提示。考虑两类暴力:

  1. 遍历 $k$ 个区间全部打上标记,再枚举 $m$ 对关系进行 check。单次询问时间复杂度 $O(n + m)$。

  2. 将限制关系放至二维平面,假设限制关系为 $(a, b)$,则向二维平面插入一个坐标为 $(a, b)$ 的点。

    每次枚举两个区间 $[l _1, r _1], [l _2, r _2] (r _1 < l _2)$,若以 $(l1, l2)$ 为左下角点,$(r1, r2)$ 为右上角的矩形内部有点,则表明包含互斥物品。时间复杂度 $O(k ^2 \log n)$(主席树二维数点)

则考虑根号分治。设阈值 $U$,若 $k > U$ 使用第一种暴力,否则使用第二种。两种操作操作次数均不超过 $\frac q U$,故总复杂度为 $O(q (\frac {n + m} U + U \log n))$,易知 $U$ 取 $\sqrt {\frac n {\log n}}$ 得到最优复杂度 $O(n \sqrt {n \log n})$(几个变量复杂度同阶,故省略)。

常数巨小不开 O2 也能过。

还有一个理论复杂度是 $O(n \sqrt n)$ 的做法。题解里面讲的分块看不懂,这里的 $\sqrt n$ 是 K-D Tree 操作矩形的复杂度。

还是将限制关系放在二维平面上考虑,一开始所有点的点权均为 0。

遍历 $k$ 个区间,设当前区间为 $[l, r]$,首先将横坐标在 $[l, r]$ 内的所有点整体加上 1,然后查询纵坐标在 $[l, r]$ 内的点是否有点权大于 1 的点。有则表明包含了互斥物品。

矩形加,矩形查(而且是对矩形内部的某些关键点操作),K-D Tree 是很好的选择。然而实测完全不像个 $\sqrt n$ 的东西,50000 的数据节点访问次数就到了 $10 ^7$ 级别,更别说一大串常数了…

另外这个做法可以预先加好再查询,易知这不会造成影响。如果有用这个思路过了的,求能教一教…

Code

gen

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from random import randint
import os
N = int(100000)
cnt = 0
while True :
T = 5
outs = "%d 0\n" % T
for o in range(T) :
n, m, q = N, N, N
sumk = N
outs += "%d %d %d\n" % (n, m, q)
for i in range(m) :
x, y = randint(1, n), randint(1, n)
while x == y :
y = randint(1, n)
outs += "%d %d\n" % (x, y)
for i in range(q) :
lis = set()
last = 0
while len(lis) < sumk and last != n:
l, r = randint(last + 1, n), randint(last + 1, n)
if l > r :
l, r = r, l
lis.add((l, r))
last = r
a = sorted(list(lis))
outs += "%d " % len(a)
sumk -= len(a)
for j in a :
outs += "%d %d " % j
outs += '\n'
print(outs, file = open("c.in", "w"))
exit()

os.system("./std && ./c")
if os.system("diff c.out c.ans") :
print("WA")
exit()
else :
cnt += 1
print("AC %d times!" % cnt)

std

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
/**********************************************************
* Author : EndSaH
* Email : [email protected]
* Created Time : 2019-10-12 14:19
* FileName : wib.cpp
* Website : https://endsah.cf
* *******************************************************/

#include <cstdio>
#include <cctype>
#include <cmath>
#include <cstring>
#include <algorithm>

typedef std::pair<int, int> pii;

#define fir first
#define sec second
#define debug(...) fprintf(stderr, __VA_ARGS__)
#define Debug(s) debug("The message in line %d, Function %s: %s\n", __LINE__, __FUNCTION__, s)
#define getchar() (ipos == iend and (iend = (ipos = _ibuf) + fread(_ibuf, 1, __bufsize, stdin), ipos == iend) ? EOF : *ipos++)
#define putchar(ch) (opos == oend ? fwrite(_obuf, 1, __bufsize, stdout), opos = _obuf : 0, *opos++ = (ch))
#define __bufsize (1 << 21 | 1)

char _ibuf[__bufsize], _obuf[__bufsize], _stk[50];
char *ipos = _ibuf, *iend = _ibuf, *opos = _obuf, *oend = _obuf + __bufsize, *stkpos = _stk;

struct END
{ ~END() { fwrite(_obuf, 1, opos - _obuf, stdout); } }
__________;

inline int read()
{
register int x = 0;
register char ch;
while (!isdigit(ch = getchar()));
while (x = x * 10 + (ch & 15), isdigit(ch = getchar()));
return x;
}

template <typename _INT>
inline void write(_INT x)
{
while (*++stkpos = x % 10 ^ 48, x /= 10, x);
while (stkpos != _stk)
putchar(*stkpos--);
}

template<typename _Tp>
inline bool Chkmax(_Tp& x, const _Tp& y)
{ return x < y ? x = y, true : false; }

template<typename _Tp>
inline bool Chkmin(_Tp& x, const _Tp& y)
{ return x > y ? x = y, true : false; }

const int maxN = 1e5 + 5;

int n, m, q, U, anscnt;
pii pt[maxN], inter[maxN];

namespace SEG
{
const int maxM = maxN * 40;

int ncnt, pre, pos, ql, qr;
int root[maxN];
int size[maxM], ls[maxM], rs[maxM];

void Insert(int l, int r, int& cur)
{
size[cur = ++ncnt] = size[pre] + 1;
if (l == r)
return;
int mid = (l + r) >> 1;
if (pos <= mid)
{
rs[cur] = rs[pre];
pre = ls[pre];
Insert(l, mid, ls[cur]);
}
else
{
ls[cur] = ls[pre];
pre = rs[pre];
Insert(mid + 1, r, rs[cur]);
}
}

void Insert(int t, int _pos)
{
pre = root[t - 1], pos = _pos;
Insert(1, n, root[t]);
}

int _query(int l, int r, int cur) // attention the name
{
if (!cur or ql > r or qr < l)
return 0;
if (ql <= l and r <= qr)
return size[cur];
int mid = (l + r) >> 1;
return _query(l, mid, ls[cur]) + _query(mid + 1, r, rs[cur]);
}

int Query(int t, int l, int r)
{
ql = l, qr = r;
return _query(1, n, root[t]);
}

void Print()
{
printf("%d\n", ncnt);
for (int i = 1; i <= ncnt; ++i)
printf("%d: %d %d %d\n", i, ls[i], rs[i], size[i]);
}

void Init()
{
memset(ls + 1, 0, ncnt * sizeof(int));
memset(rs + 1, 0, ncnt * sizeof(int));
ncnt = 0;
}
}

void Init()
{
n = read(), m = read(), q = read();
U = sqrt(n / log2(n));
for (int i = 1; i <= m; ++i)
{
pt[i].fir = read(), pt[i].sec = read();
if (pt[i].fir > pt[i].sec)
std::swap(pt[i].fir, pt[i].sec);
}
std::sort(pt + 1, pt + m + 1);
SEG::Init();
for (int i = 1; i <= m; ++i)
SEG::Insert(i, pt[i].sec);
}

inline int Id(int x)
{ return std::lower_bound(pt + 1, pt + m + 1, pii(x, 0)) - pt; }

int Solve1(int k)
{
for (int i = 1; i <= k; ++i)
{
int l = Id(inter[i].fir) - 1, r = Id(inter[i].sec + 1) - 1;
for (int j = i; j <= k; ++j)
if (SEG::Query(r, inter[j].fir, inter[j].sec) -
SEG::Query(l, inter[j].fir, inter[j].sec) > 0)
return 1;
}
return 0;
}

int Solve2(int k)
{
static bool vis[maxN];
memset(vis + 1, 0, n * sizeof(bool));
for (int i = 1; i <= k; ++i)
memset(vis + inter[i].fir, 1,
(inter[i].sec - inter[i].fir + 1) * sizeof(bool));
for (int i = 1; i <= m; ++i)
if (vis[pt[i].fir] and vis[pt[i].sec])
return 1;
return 0;
}

int main()
{
#ifndef ONLINE_JUDGE
freopen("c.in", "r", stdin);
freopen("c.out", "w", stdout);
#endif
int T = read(), type = read();
while (T--)
{
Init();
while (q--)
{
int k = read();
for (int i = 1; i <= k; ++i)
{
inter[i].fir = read();
inter[i].sec = read();
if (type)
{
inter[i].fir ^= anscnt;
inter[i].sec ^= anscnt;
}
}
int fg;
if (k <= U)
fg = Solve1(k);
else
fg = Solve2(k);
anscnt += fg;
putchar(char(fg + '0'));
}
putchar('\n');
}
return 0;
}