LC 528. 按权重随机选择
题目描述
这是 LeetCode 上的 528. 按权重随机选择 ,难度为 中等。
给定一个正整数数组 $w$ ,其中 $w[i]$ 代表下标 $i$ 的权重(下标从 $0$ 开始),请写一个函数 pickIndex ,它可以随机地获取下标 $i$,选取下标 $i$ 的概率与 $w[i]$ 成正比。
例如,对于 $w = [1, 3]$,挑选下标 $0$ 的概率为 $1 / (1 + 3) = 0.25$ (即,$25$%),而选取下标 $1$ 的概率为 $3 / (1 + 3) = 0.75$(即,$75$%)。
也就是说,选取下标 $i$ 的概率为 $w[i] / sum(w)$ 。
示例 1:1
2
3
4
5
6
7
8
9
10输入:
["Solution","pickIndex"]
[[[1]],[]]
输出:
[null,0]
解释:
Solution solution = new Solution([1]);
solution.pickIndex(); // 返回 0,因为数组中只有一个元素,所以唯一的选择是返回下标 0。
示例 2:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23输入:
["Solution","pickIndex","pickIndex","pickIndex","pickIndex","pickIndex"]
[[[1,3]],[],[],[],[],[]]
输出:
[null,1,1,1,1,0]
解释:
Solution solution = new Solution([1, 3]);
solution.pickIndex(); // 返回 1,返回下标 1,返回该下标概率为 3/4 。
solution.pickIndex(); // 返回 1
solution.pickIndex(); // 返回 1
solution.pickIndex(); // 返回 1
solution.pickIndex(); // 返回 0,返回下标 0,返回该下标概率为 1/4 。
由于这是一个随机问题,允许多个答案,因此下列输出都可以被认为是正确的:
[null,1,1,1,1,0]
[null,1,1,1,1,1]
[null,1,1,1,0,0]
[null,1,1,1,0,1]
[null,1,0,1,0,0]
......
诸若此类。
提示:
- $1 <= w.length <= 10000$
- $1 <= w[i] <= 10^5$
- pickIndex将被调用不超过 $10000$ 次
前缀和 + 二分
根据题意,权重值 $w[i]$ 可以作为 pickIndex 调用总次数为 $\sum_{i = 0}^{w.length - 1} w[i]$ 时,下标 $i$ 的返回次数。
随机数的产生可以直接使用语言自带的 API,剩下的我们需要构造一个分布符合权重的序列。
由于 $1 <= w[i] <= 10^5$,且 $w$ 长度为 $10^4$,因此直接使用构造一个有 $w[i]$ 个的 $i$ 的数字会 MLE。
我们可以使用「前缀和」数组来作为权重分布序列,权重序列的基本单位为 $1$。
一个长度为 $n$ 的构造好的「前缀和」数组可以看是一个基本单位为 $1$ 的 $[1, sum[n - 1]]$ 数轴。
使用随机函数参数产生 $[1, sum[n - 1]]$ 范围内的随机数,通过「二分」前缀和数组即可找到分布位置对应的原始下标值。
Java 代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20class Solution {
    int[] sum;
    public Solution(int[] w) {
        int n = w.length;
        sum = new int[n + 1];
        for (int i = 1; i <= n; i++) sum[i] = sum[i - 1] + w[i - 1];
    }
    
    public int pickIndex() {
        int n = sum.length;
        int t = (int) (Math.random() * sum[n - 1]) + 1;
        int l = 1, r = n - 1;
        while (l < r) {
            int mid = l + r >> 1;
            if (sum[mid] >= t) r = mid;
            else l = mid + 1;
        }
        return r - 1;
    }
}
C++ 代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20class Solution {
public:
    vector<int> sumv;
    Solution(vector<int>& w) {
        int n = w.size();
        sumv.resize(n + 1);
        for (int i = 1; i <= n; i++) sumv[i] = sumv[i - 1] + w[i - 1];
    }
    int pickIndex() {
        int n = sumv.size();
        int t = (rand() % sumv.back()) + 1; // [1, sum.back()]
        int l = 1, r = n - 1;
        while (l < r) {
            int mid = (l + r) >> 1;
            if (sumv[mid] >= t) r = mid;
            else l = mid + 1;
        }
        return r - 1;
    }
};
Python 代码:1
2
3
4
5
6
7
8
9
10
11class Solution:
    def __init__(self, w: List[int]):
        n = len(w)
        self.sumv = [0] * (n + 1)
        for i in range(1, n + 1):
            self.sumv[i] = self.sumv[i - 1] + w[i - 1]
    def pickIndex(self) -> int:
        target = random.randint(1, self.sumv[-1])
        return bisect.bisect_left(self.sumv, target) - 1
- 时间复杂度:Solution类的构造方法整体复杂度为 $O(n)$;pickIndex的复杂度为 $O(\log{n})$
- 空间复杂度:$O(n)$
模拟(桶轮询)
利用 OJ 不太聪明(对权重分布做近似检查),我们可以构造一个最小轮询序列(权重精度保留到小数点一位),并使用 $(i, cnt)$ 的形式进行存储,代表下标 $i$ 在最小轮询序列中出现次数为 $cnt$。
然后使用两个编号 $bid$ 和 $iid$ 来对桶进行轮询返回(循环重置 & 跳到下一个桶)。
该解法的最大好处是不需要使用 random 函数,同时返回的连续序列满足每一段(长度不短于最小段)都符合近似权重分布。
Java 代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35class Solution {
    // 桶编号 / 桶内编号 / 总数
    int bid, iid, tot;
    List<int[]> list = new ArrayList<>();
    public Solution(int[] w) {
        int n = w.length;
        double sum = 0, min = 1e9;
        for (int i : w) {
            sum += i;
            min = Math.min(min, i);
        }
        double minv = min / sum;
        int k = 1;
        while (minv * k < 1) k *= 10;
        for (int i = 0; i < n; i++) {
            int cnt = (int)(w[i] / sum * k);
            list.add(new int[]{i, cnt});
            tot += cnt;
        }
    }
    
    public int pickIndex() {
        if (bid >= list.size()) {
            bid = 0; iid = 0;
        }
        int[] info = list.get(bid);
        int id = info[0], cnt = info[1];
        if (iid >= cnt) {
            bid++; iid = 0;
            return pickIndex();
        }
        iid++;
        return id;
    }
}
C++ 代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33class Solution {
public:
    int bid = 0, iid = 0, tot = 0;
    vector<pair<int, int>> list;
    Solution(vector<int>& w) {
        double sum = 0, minv = 1e9;
        for (int num : w) {
            sum += num;
            if (num < minv) minv = num;
        }
        double t = minv / sum;
        int k = 1;
        while (t * k < 1) k *= 10;
        for (int i = 0; i < w.size(); i++) {
            int cnt = static_cast<int>(w[i] / sum * k);
            list.emplace_back(i, cnt);
            tot += cnt;
        }
    }
    
    int pickIndex() {
        if (bid >= list.size()) {
            bid = 0; iid = 0;
        }
        auto [id, cnt] = list[bid];
        if (iid >= cnt) {
            bid++; iid = 0;
            return pickIndex();
        }
        iid++;
        return id;
    }
};
Python 代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27class Solution:
    def __init__(self, w: List[int]):
        self.bid = self.iid = self.tot = 0
        self.nums = []
        
        total, minv = sum(w), min(w)
        t = minv / total
        k = 1
        while t * k < 1:
            k *= 10
        for idx, weight in enumerate(w):
            cnt = int(weight / total * k)
            self.nums.append( (idx, cnt) )
            self.tot += cnt
    def pickIndex(self) -> int:
        if self.bid >= len(self.nums):
            self.bid = self.iid = 0
            
        idx, cnt = self.nums[self.bid]
        if self.iid >= cnt:
            self.bid += 1
            self.iid = 0
            return self.pickIndex()
            
        self.iid += 1
        return idx
- 时间复杂度:计算 $k$ 的操作只会发生一次,可以看作是一个均摊到每个下标的常数计算,Solution类的构造方法的整体复杂度可看作 $O(n)$;pickIndex的复杂度为 $O(1)$
- 空间复杂度:$O(n)$
最后
这是我们「刷穿 LeetCode」系列文章的第 No.528 篇,系列开始于 2021/01/01,截止于起始日 LeetCode 上共有 1916 道题目,部分是有锁题,我们将先把所有不带锁的题目刷完。
在这个系列文章里面,除了讲解解题思路以外,还会尽可能给出最为简洁的代码。如果涉及通解还会相应的代码模板。
为了方便各位同学能够电脑上进行调试和提交代码,我建立了相关的仓库:https://github.com/SharingSource/LogicStack-LeetCode 。
在仓库地址里,你可以看到系列文章的题解链接、系列文章的相应代码、LeetCode 原题链接和其他优选题解。
本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!
