1. 水塘抽样算法简介

水塘抽样(英语:Reservoir sampling)是一系列的随机算法,其目的在于从包含n个项目的集合S中选取k个样本,其中n为一很大或未知的数量,尤其适用于不能把所有n个项目都存放到内存的情况。最常见例子为Jeffrey Vitter在其论文[1]中所提及的算法R。(摘自维基百科)

因为他可以处理 n 很大甚至未知的情况,所以在大数据中应用广泛,比如 Flink 就用到了这个算法。


2. 水塘抽样思想

假设我们有一个 Steam, 简称为 S,我们需要从S中选取 k 个元素,并且要保证每一个元素取到的可能性相等。

n 已知时,为了保证每一个元素取到的概率均相同,我们可以使用随机数生成的方法来解决这个问题,但是当 n 未知时又如何保证呢?如果我们先去遍历一次获取数据的总长度,在面对大数据的场景下,未免过于费时间,若去估计整体数据的规模,则可能最后采样的结果分布并不是均匀的。

所以,这里,水塘抽样算法就派上用场了,这里用自然语言来描述一下整个过程:

采样过程: 集合中总元素个数为 n,随机选取 k 个元素。

  • step1: 首先将前k个元素全部选取;
  • step2: 对于第ii 个元素i>ki > k,以概率ki\frac{k}{i} 来决定是否保留该元素,如果保留该元素的话,则随机丢弃掉原有的kk个元素中的一个(即原来某个元素被丢掉的概率是1k\frac{1}{k});

结果: 每个元素被最终被选取的概率都是kn\frac{k}{n}

3. 具体例子

下面,我就以一个具体例子来介绍这个算法的思想,需要一点古典概率的知识。

假设我们有一个长度为nn 的数据流SS, 我们将其产生的第ii 个数据编号为ii ,需要从SS 中取 1 个元素。

首先,我们读到了第 1 个数据,此时不能直接返回,因为数据流并没有结束,此时 1 被选中的概率是 1。

接下来,我们读到了第 2 个数据,此时数据 2 被选中的概率是12\frac{1}{2},因为解集合的容量有限,所以数据 1 被留下的概率(即选中数据 1 的概率)也为1/21 / 2

接着, 我们读到了第 3 个数据,此时数据 3 被选中的概率是13\frac{1}{3},此时:

假设前两个数据中留下来的是数据 1 ,则此时我们手里有 2 个数据(1 、3),需要淘汰掉一个数据,又因为第 3 个元素被选中的概率是13\frac{1}{3},所以数据 1 被选择的概率是23\frac{2}{3},即:

数据 1 最终被保留的概率为1223 = 13\frac{1}{2} * \frac{2}{3} \ = \ \frac{1}{3} (两个数中选一个的概率是12\frac{1}{2},再乘以数据 1 被选择的概率23\frac{2}{3} );

同理,数据 2 被保留下来的概率为1223 = 13\frac{1}{2} * \frac{2}{3} \ = \ \frac{1}{3}

数据 3 被保留下来的概率为13\frac{1}{3}

依次类推。

4. 证明(太长不看系列)

前提:假设需要采样的数量为kk,数据流的长度为nn

  1. 对于第ii 个数(iki \le k),在第kk 步之前,被选中的概率均为 1。当读到到第k+1k + 1 个数时,第k+1k + 1 个数被选中的概率为kk+1\frac{k}{k + 1},所以解集合中的数被替换的概率为:第k+1k + 1 个元素被选中的概率乘以解集合中第ii 个元素被选中替换的概率,即kk+11k = 1k+1\frac{k}{k + 1} * \frac{1}{k} \ = \ \frac{1}{k+1}, 不被第k+1k+1 个元素替换的概率为11k+1 = kk+11 - \frac{1}{k + 1} \ = \ \frac{k}{k + 1},不被第k+2k+2 个元素替换的概率为1kk+21k = k+1k+21 - \frac{k}{k + 2}*\frac{1}{k} \ = \ \frac{k+1}{k + 2},读取到第nn 个数时,第ii 个数仍然保留的概率为:1kk+1k+1k+2...n1n = kn1 * \frac{k}{k + 1} *\frac{k + 1}{k + 2}* ... * \frac{n - 1}{n} \ = \ \frac{k}{n}

注:乘以 1/k 的意义是从长度为 k 的解集合中等可能取元素的概率是 1 / k。

  1. 对于第jj 个数(j>kj > k), 在第jj 步被选中的概率为kj\frac{k}{j}

    不被第j+1j + 1 个元素替换的概率为:1kj+11k = jj+11 - \frac{k}{j + 1} * \frac{1}{k} \ = \ \frac{j}{j + 1}

    则读到第nn 个数时,第jj 个数被保留的概率为被选中的概率乘以不被替换的概率,

    即:kjjj+1...n1n = kn\frac{k}{j} * \frac{j}{j + 1} * ... *\frac{n - 1}{n} \ = \ \frac{k}{n}

所以,对于数据流中的每一个元素,被保留的概率均为kn\frac{k}{n}

5. 节选算法题

LeetCode382:链表随机结点

给你一个单链表,随机选择链表的一个节点,并返回相应的节点值。每个节点 被选中的概率一样

这里,我们为了避免计算一次链表的长度,就可以采用水塘抽样算法。

下面的具体的代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
/**
* Definition for singly-linked list.
* struct ListNode {
* int val;
* ListNode *next;
* ListNode() : val(0), next(nullptr) {}
* ListNode(int x) : val(x), next(nullptr) {}
* ListNode(int x, ListNode *next) : val(x), next(next) {}
* };
*/
class Solution {
private:
ListNode *head;
public:
Solution(ListNode *head) {
this->head = head;
}

int getRandom() {
srand(unsigned(time(NULL)));
int i = 1, ans = 0;
for (auto node = head; node; node = node->next) {
if (rand() % i == 0) { // 1/i 的概率选中(替换为答案)
ans = node->val;
}
++i;
}
return ans;
}
};