蓄水池抽样算法
蓄水池抽样算法解析:问题背景、算法实现、数学证明、工程实践
蓄水池抽样算法
背景问题
对于一个长度未知(假设为n)的数据流,数据流无法全部加载到内存,如何在只遍历一次的情况下随机选取m个元素?
算法实现
设 $i$ 为元素在数据流中的编号, $v$ 为具体元素, $m$ 为样本集长度, $n$ 为数据流长度
- 当 $i < m$ 时,直接将元素 $v$ 放入样本集的 $i$ 号位置
- 当 $i \geq m$ 时,在区间 $\left[ 0, i \right]$ 内取随机数 $d$ ,若 $d < m$ 则将样本集的 $d$ 号位置元素更新为 $v$
时间复杂度为 $O(n)$ ,空间复杂度为 $O(m)$
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
package algorithm
import "math/rand"
type Iterator[T any] interface {
Next() T
HasNext() bool
}
// ReservoirSampling 蓄水池抽样
func ReservoirSampling[T any](it Iterator[T], m int) []T {
ans := make([]T, m)
for i := 0; it.HasNext(); i++ {
// i - 在数据流中的索引号
// v - 具体数据
v := it.Next()
// 对于 i < m 的数据,直接放入样本集的i号位置
if i < m {
ans[i] = v
continue
}
// 对于 i >= m 的数据,在[0, i]内取随机数d,若 d < m 则替换样本集的d号位置
ri := rand.Intn(i + 1)
if ri < m {
ans[ri] = v
continue
}
}
return ans
}
数学证明
样本 $i$ 进入样本集的概率 $P_{enter}$
\[P_{enter} = \begin{cases} 1, & i < m \\ \frac{m}{i+1}, & i \geq m \end{cases}\]- 当 $i < m$ 时,进入样本集的概率为 $1$
- 当 $i \geq m$ 时,随机数 $d$ 的随机范围为 $i+1$ ,样本集的长度为 $m$ ,进入样本集的概率为 $\frac{m}{i+1}$
样本 $i$ 不被换出样本集的概率 $P_{not\ replaced}$
\[P_{not\ replaced} = \begin{cases} \prod\limits_{k=m}^{n}\left(1 - \frac{1}{k+1} \right) = \frac{m}{n}, & i < m \\ \prod\limits_{k=i+1}^{n}\left(1 - \frac{1}{k+1} \right) = \frac{i+1}{n}, & i \geq m \end{cases}\]- 当 $i < m$ 时,仅当 $i’ \geq m$ 时有机会将 $i$ 号元素换出,不被换出的概率为 $(1 - \frac{1}{m+1}) \times (1 - \frac{1}{m+2}) \times \cdots \times (1 - \frac{1}{n}) = \frac{m}{n}$
- 当 $i \geq m$ 时,仅当 $i’ \geq i+1$ 时有机会将 $i$ 号元素换出,不被换出的概率为 $(1 - \frac{1}{i+2}) \times (1 - \frac{1}{i+2}) \times \cdots \times (1 - \frac{1}{n}) = \frac{i+1}{n}$
样本 $i$ 最终保留在样本集的概率 $P_{stay}$
\[P_{stay} = P_{enter} \times P_{not\ replaced} = \frac{m}{n}\]结论
数据流每个元素的选取概率相同,均为 $\frac{m}{n}$ ,证毕
工程实践
数据流抽样
- Redis内存淘汰样本集抽样
- 总人数未知的在线实时抽奖
- …
红包分配
- 若红包有n分,则红包有n个线段,n+1个点
- 将红包划分为m份,则需要随机选取m-1个点,以将线段划分为m份
- 忽略最左端点和最右端点,保证最左线段和最右线段非空
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
package algorithm
import (
"math/rand"
"sort"
)
// SplitRedPacket 基于蓄水池抽样算法的线段切割法,随机选取不重复的点
func SplitRedPacket(n int, m int) []int {
switch m {
case 1:
return []int{n}
case 2:
x := 1 + rand.Intn(n-1) // x = [1, n-1]
return []int{x, n - x}
}
// 设红包由n个线段和n+1个点(点的编号为[0, n])组成
// 根据要求随机选取m-1个点
points := make([]int, m-1)
// 忽略最左端点0和最右端点n, 从1遍历至n-1
for point := 1; point < n; point++ {
// i - 在数据流中的索引号
// point - 具体数据
i := point - 1
if i < len(points) {
points[i] = point
continue
}
ri := rand.Intn(i + 1)
if ri < len(points) {
points[ri] = point
continue
}
}
// 随机选取了m-1个点,对随机选取的点按位置顺序排序
sort.Ints(points)
// 计算m个线段长度
ans := make([]int, m)
ans[0] = points[0]
for i := 1; i < m-1; i++ {
ans[i] = points[i] - points[i-1]
}
ans[m-1] = n - points[m-2]
return ans
}
权重抽奖
在线实时抽奖进阶版,按用户权重随机选取m个不重复的用户(普通用户权重为1,关键用户权重大于1)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
package algorithm
import (
"math/rand"
)
func WeightRaffle(uwm map[uint64]int, m int) []uint64 {
ans := make([]uint64, m)
i := 0
for uid, weight := range uwm {
// 未抽中当前用户时重试
// 蓄水池抽样算法能够保证数据流的每个元素被选取的概率相同
for ; weight > 0; i++ {
// i - 数据流编号
// uid - 具体数据
if i < m {
ans[i] = uid
break
}
ri := rand.Intn(i + 1)
if ri < m {
ans[ri] = uid
break
}
weight--
}
}
return ans
}
本文作者: panyc0217
本文链接: https://panyc0217.github.io/posts/%E8%93%84%E6%B0%B4%E6%B1%A0%E6%8A%BD%E6%A0%B7%E7%AE%97%E6%B3%95/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明出处!
本文由作者按照 CC BY 4.0 进行授权