文章

蓄水池抽样算法

蓄水池抽样算法解析:问题背景、算法实现、数学证明、工程实践

蓄水池抽样算法

背景问题

对于一个长度未知(假设为n)的数据流,数据流无法全部加载到内存,如何在只遍历一次的情况下随机选取m个元素

算法实现

设 $i$ 为元素在数据流中的编号, $v$ 为具体元素, $m$ 为样本集长度, $n$ 为数据流长度

  1. 当 $i < m$ 时,直接将元素 $v$ 放入样本集的 $i$ 号位置
  2. 当 $i \geq m$ 时,在区间 $\left[ 0, i \right]$ 内取随机数 $d$ ,若 $d < m$ 则将样本集的 $d$ 号位置元素更新为 $v$

时间复杂度为 $O(n)$ ,空间复杂度为 $O(m)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
package algorithm

import "math/rand"

type Iterator[T any] interface {
    Next() T
    HasNext() bool
}

// ReservoirSampling 蓄水池抽样
func ReservoirSampling[T any](it Iterator[T], m int) []T {
    ans := make([]T, m)
    for i := 0; it.HasNext(); i++ {
        // i - 在数据流中的索引号
        // v - 具体数据
        v := it.Next()

        // 对于 i < m 的数据,直接放入样本集的i号位置
        if i < m {
            ans[i] = v
            continue
        }

        // 对于 i >= m 的数据,在[0, i]内取随机数d,若 d < m 则替换样本集的d号位置
        ri := rand.Intn(i + 1)
        if ri < m {
            ans[ri] = v
            continue
        }
    }
    return ans
}

数学证明

样本 $i$ 进入样本集的概率 $P_{enter}$

\[P_{enter} = \begin{cases} 1, & i < m \\ \frac{m}{i+1}, & i \geq m \end{cases}\]
  1. 当 $i < m$ 时,进入样本集的概率为 $1$
  2. 当 $i \geq m$ 时,随机数 $d$ 的随机范围为 $i+1$ ,样本集的长度为 $m$ ,进入样本集的概率为 $\frac{m}{i+1}$

样本 $i$ 不被换出样本集的概率 $P_{not\ replaced}$

\[P_{not\ replaced} = \begin{cases} \prod\limits_{k=m}^{n}\left(1 - \frac{1}{k+1} \right) = \frac{m}{n}, & i < m \\ \prod\limits_{k=i+1}^{n}\left(1 - \frac{1}{k+1} \right) = \frac{i+1}{n}, & i \geq m \end{cases}\]
  1. 当 $i < m$ 时,仅当 $i’ \geq m$ 时有机会将 $i$ 号元素换出,不被换出的概率为 $(1 - \frac{1}{m+1}) \times (1 - \frac{1}{m+2}) \times \cdots \times (1 - \frac{1}{n}) = \frac{m}{n}$
  2. 当 $i \geq m$ 时,仅当 $i’ \geq i+1$ 时有机会将 $i$ 号元素换出,不被换出的概率为 $(1 - \frac{1}{i+2}) \times (1 - \frac{1}{i+2}) \times \cdots \times (1 - \frac{1}{n}) = \frac{i+1}{n}$

样本 $i$ 最终保留在样本集的概率 $P_{stay}$

\[P_{stay} = P_{enter} \times P_{not\ replaced} = \frac{m}{n}\]

结论

数据流每个元素的选取概率相同,均为 $\frac{m}{n}$ ,证毕

工程实践

数据流抽样

  1. Redis内存淘汰样本集抽样
  2. 总人数未知的在线实时抽奖

红包分配

  1. 若红包有n分,则红包有n个线段,n+1个点
  2. 将红包划分为m份,则需要随机选取m-1个点,以将线段划分为m份
  3. 忽略最左端点和最右端点,保证最左线段和最右线段非空
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
package algorithm

import (
    "math/rand"
    "sort"
)

// SplitRedPacket 基于蓄水池抽样算法的线段切割法,随机选取不重复的点
func SplitRedPacket(n int, m int) []int {
    switch m {
        case 1:
            return []int{n}
        case 2:
            x := 1 + rand.Intn(n-1) // x = [1, n-1]
            return []int{x, n - x}
    }

    // 设红包由n个线段和n+1个点(点的编号为[0, n])组成

    // 根据要求随机选取m-1个点
    points := make([]int, m-1)

    // 忽略最左端点0和最右端点n, 从1遍历至n-1
    for point := 1; point < n; point++ {
        // i - 在数据流中的索引号
        // point - 具体数据
        i := point - 1

        if i < len(points) {
            points[i] = point
            continue
        }

        ri := rand.Intn(i + 1)
        if ri < len(points) {
            points[ri] = point
            continue
        }
    }

    // 随机选取了m-1个点,对随机选取的点按位置顺序排序
    sort.Ints(points)

    // 计算m个线段长度
    ans := make([]int, m)
    ans[0] = points[0]
    for i := 1; i < m-1; i++ {
        ans[i] = points[i] - points[i-1]
    }
    ans[m-1] = n - points[m-2]
    return ans
}

权重抽奖

在线实时抽奖进阶版,按用户权重随机选取m个不重复的用户(普通用户权重为1,关键用户权重大于1)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
package algorithm

import (
    "math/rand"
)

func WeightRaffle(uwm map[uint64]int, m int) []uint64 {
    ans := make([]uint64, m)
    i := 0
    for uid, weight := range uwm {
        // 未抽中当前用户时重试
        // 蓄水池抽样算法能够保证数据流的每个元素被选取的概率相同
        for ; weight > 0; i++ {
            // i - 数据流编号
            // uid - 具体数据

            if i < m {
                ans[i] = uid
                break
            }

            ri := rand.Intn(i + 1)
            if ri < m {
                ans[ri] = uid
                break
            }

            weight--
        }
    }
    return ans
}
本文由作者按照 CC BY 4.0 进行授权