LC 710. 黑名单中的随机数

题目描述

这是 LeetCode 上的 710. 黑名单中的随机数 ,难度为 困难

给定一个整数 n 和一个无重复黑名单整数数组 blacklist

设计一种算法,从 $[0, n - 1]$ 范围内的任意整数中选取一个未加入黑名单 blacklist 的整数。任何在上述范围内且不在黑名单 blacklist 中的整数都应该有同等的可能性被返回。

优化你的算法,使它最小化调用语言内置随机函数的次数。

实现 Solution 类:

  • Solution(int n, int[] blacklist) 初始化整数 n 和被加入黑名单 blacklist 的整数
  • int pick() 返回一个范围为 $[0, n - 1]$ 且不在黑名单 blacklist 中的随机整数

示例 1:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
输入
["Solution", "pick", "pick", "pick", "pick", "pick", "pick", "pick"]
[[7, [2, 3, 5]], [], [], [], [], [], [], []]

输出
[null, 0, 4, 1, 6, 1, 0, 4]

解释
Solution solution = new Solution(7, [2, 3, 5]);
solution.pick(); // 返回0,任何[0,1,4,6]的整数都可以。注意,对于每一个pick的调用,
// 0146的返回概率必须相等(即概率为1/4)。
solution.pick(); // 返回 4
solution.pick(); // 返回 1
solution.pick(); // 返回 6
solution.pick(); // 返回 1
solution.pick(); // 返回 0
solution.pick(); // 返回 4

提示:

  • $1 <= n <= 10^9$
  • $0 <= blacklist.length <= \min(1065, n - 1)$
  • $0 <= blacklist[i] < n$
  • blacklist 中所有值都 不同
  • pick 最多被调用 $2 \times 10^4$ 次

前缀和 + 二分

为了方便,我们记 blacklistbs,将其长度记为 m

问题本质是让我们从 $[0, n - 1]$ 范围内随机一个数,这数不能在 bs 里面。

由于 $n$ 的范围是 $1e9$,我们不能在对范围在 $[0, n - 1]$ 且不在 bs 中的点集进行离散化,因为离散化后的点集大小仍很大。

同时 $m$ 的范围是 $1e5$,我们也不能使用普通的拒绝采样做法,这样单次 pick 被拒绝的次数可能很大。

一个简单且绝对正确的做法是:我们不对「点」做离散化,而利用 bs 数据范围为 $1e5$,来对「线段」做离散化

具体的,我们先对 bs 进行排序,然后从前往后处理所有的 $bs[i]$,将相邻 $bs[i]$ 之间的能被选择的「线段」以二元组 $(a, b)$ 的形式进行记录(即一般情况下的 $a = bs[i - 1] + 1$,$b = bs[i] - 1$),存入数组 list 中(注意特殊处理一下两端的线段)。

当处理完所有的 $bs[i]$ 后,我们得到了所有可被选择线段,同时对于每个线段可直接算得其所包含的整数点数。

我们可以对 list 数组做一遍「线段所包含点数」的「前缀和」操作,得到 sum 数组,同时得到所有线段所包含的总点数(前缀和数组的最后一位)。

对于 pick 操作而言,我们先在 $[1, tot]$ 范围进行随机(其中 $tot$ 代表总点数),假设取得的随机值为 $val$,然后在前缀和数组中进行二分,找到第一个满足「值大于等于 $val$」的位置(含义为找到这个点所在的线段),然后再利用该线段的左右端点的值,取出对应的点。

Java 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class Solution {
List<int[]> list = new ArrayList<>();
int[] sum = new int[100010];
int sz;
Random random = new Random();
public Solution(int n, int[] bs) {
Arrays.sort(bs);
int m = bs.length;
if (m == 0) {
list.add(new int[]{0, n - 1});
} else {
if (bs[0] != 0) list.add(new int[]{0, bs[0] - 1});
for (int i = 1; i < m; i++) {
if (bs[i - 1] == bs[i] - 1) continue;
list.add(new int[]{bs[i - 1] + 1, bs[i] - 1});
}
if (bs[m - 1] != n - 1) list.add(new int[]{bs[m - 1] + 1, n - 1});
}
sz = list.size();
for (int i = 1; i <= sz; i++) {
int[] info = list.get(i - 1);
sum[i] = sum[i - 1] + info[1] - info[0] + 1;
}
}
public int pick() {
int val = random.nextInt(sum[sz]) + 1;
int l = 1, r = sz;
while (l < r) {
int mid = l + r >> 1;
if (sum[mid] >= val) r = mid;
else l = mid + 1;
}
int[] info = list.get(r - 1);
int a = info[0], b = info[1], end = sum[r];
return b - (end - val);
}
}

C++ 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Solution {
public:
vector<pair<int, int>> list;
vector<int> sumv;
int sz;
Solution(int n, vector<int>& bs) {
sort(bs.begin(), bs.end());
int m = bs.size();
if (m == 0) {
list.emplace_back(0, n - 1);
} else {
if (bs[0] != 0) list.emplace_back(0, bs[0] - 1);
for (int i = 1; i < m; i++) {
if (bs[i - 1] != bs[i] - 1) list.emplace_back(bs[i - 1] + 1, bs[i] - 1);
}
if (bs[m - 1] != n - 1) list.emplace_back(bs[m - 1] + 1, n - 1);
}
sz = list.size();
sumv.resize(sz + 1, 0);
for (int i = 1; i <= sz; i++) {
sumv[i] = sumv[i - 1] + list[i - 1].second - list[i - 1].first + 1;
}
}
int pick() {
int val = (rand() % sumv[sz]) + 1;
int l = 1, r = sz;
while (l < r) {
int mid = l + r >> 1;
if (sumv[mid] >= val) r = mid;
else l = mid + 1;
}
auto& info = list[r - 1];
int a = info.first, b = info.second, end = sumv[r];
return b - (end - val);
}
};

Python 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Solution:
def __init__(self, n: int, bs: List[int]):
self.arr = []
self.sumv = [0] * (len(bs) + 10)
self.sz = 0
bs.sort()
m = len(bs)
if m == 0:
self.arr.append([0, n - 1])
else:
if bs[0] != 0:
self.arr.append([0, bs[0] - 1])
for i in range(1, m):
if bs[i - 1] != bs[i] - 1:
self.arr.append([bs[i - 1] + 1, bs[i] - 1])
if bs[m - 1] != n - 1:
self.arr.append([bs[m - 1] + 1, n - 1])
self.sz = len(self.arr)
for i in range(1, self.sz + 1):
self.sumv[i] = self.sumv[i - 1] + self.arr[i - 1][1] - self.arr[i - 1][0] + 1

def pick(self) -> int:
val = random.randint(1, self.sumv[self.sz])
l, r = 1, self.sz
while l < r:
mid = l + r >> 1
if self.sumv[mid] >= val:
r = mid
else:
l = mid + 1
a, b = self.arr[r - 1]
end = self.sumv[r]
return b - (end - val)

  • 时间复杂度:在初始化操作中:对 bs 进行排序复杂度为 $O(m\log{m})$;统计所有线段复杂度为 $O(m)$;对所有线段求前缀和复杂度为 $O(m)$。在 pick 操作中:随机后会对所有线段做二分,复杂度为 $O(\log{m})$
  • 空间复杂度:$O(m)$

哈希表

总共有 $n$ 个数,其中 $m$ 个数不可被选,即真实可选个数为 $n - m$ 个。

为了能够在 $[0, n - m)$ 连续段内进行随机,我们可以将 $[0, n - m)$ 内不可被选的数映射到 $[n - m, n - 1]$ 内可选的数上。

具体的,我们可以使用两个 Set 结构 s1s2 分别记录在 $[0, n - m)$ 和 $[n - m, n - 1]$ 范围内的黑名单数字。

pick 操作时,我们在 $[0, n - m)$ 范围内进行随机,根据随机值 $val$ 是否为黑名单内数字(是否在 s1 中)进行分情况讨论:

  • 随机值 val 不在 s1 中,说明 val 是真实可选的数值,直接返回;
  • 随机值 vals1 中,说明 val 是黑名单内的数值,我们先查询是否已存在 val 的映射记录,若已存在直接返回其映射值;否则需要在 $[n - m, n - 1]$ 内找到一个可替代它的数值,我们可以使用一个变量 idx 在 $[n- m, n - 1]$ 范围内从前往后扫描,找到第一个未被使用的,同时不在 s2 中(不在黑名单内)的数字,假设找到的值为 x,将 x 进行返回(同时将 valx 的映射关系,用哈希表进行记录)。

Java 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Solution {
int n, m, idx;
Random random = new Random();
Set<Integer> s1 = new HashSet<>(), s2 = new HashSet<>();
Map<Integer, Integer> map = new HashMap<>();
public Solution(int _n, int[] bs) {
n = _n; m = bs.length;
int max = n - m;
for (int x : bs) {
if (x < max) s1.add(x);
else s2.add(x);
}
idx = n - m;
}
public int pick() {
int val = random.nextInt(n - m);
if (!s1.contains(val)) return val;
if (!map.containsKey(val)) {
while (s2.contains(idx)) idx++;
map.put(val, idx++);
}
return map.get(val);
}
}

C++ 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Solution {
public:
int n, m, idx;
unordered_set<int> s1, s2;
unordered_map<int, int> map;
Solution(int _n, vector<int>& bs) {
n = _n; m = bs.size();
int max = n - m;
for (int x : bs) {
if (x < max) s1.insert(x);
else s2.insert(x);
}
idx = n - m;
}
int pick() {
int val = rand() % (n - m);
if (s1.find(val) == s1.end()) return val;
if (map.find(val) == map.end()) {
while (s2.find(idx) != s2.end()) idx++;
map[val] = idx++;
}
return map[val];
}
};

Python 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Solution:
def __init__(self, n: int, bs: List[int]):
self.n, self.m = n, len(bs)
self.s1, self.s2 = set(), set()
self.mapping = {}
maxv = self.n - self.m
for x in bs:
if x < maxv:
self.s1.add(x)
else:
self.s2.add(x)
self.idx = self.n - self.m

def pick(self) -> int:
global idx
val = random.randint(0, self.n - self.m - 1)
if val not in self.s1:
return val
if val not in self.mapping:
while self.idx in self.s2:
self.idx += 1
self.mapping[val] = self.idx
self.idx += 1
return self.mapping[val]

  • 时间复杂度:初始化操作复杂度为 $O(m)$,pick 操作复杂度为 $O(1)$
  • 空间复杂度:$O(m)$

最后

这是我们「刷穿 LeetCode」系列文章的第 No.710 篇,系列开始于 2021/01/01,截止于起始日 LeetCode 上共有 1916 道题目,部分是有锁题,我们将先把所有不带锁的题目刷完。

在这个系列文章里面,除了讲解解题思路以外,还会尽可能给出最为简洁的代码。如果涉及通解还会相应的代码模板。

为了方便各位同学能够电脑上进行调试和提交代码,我建立了相关的仓库:https://github.com/SharingSource/LogicStack-LeetCode

在仓库地址里,你可以看到系列文章的题解链接、系列文章的相应代码、LeetCode 原题链接和其他优选题解。