问题描述:
“给出一个数据流,这个数据流的长度很大或者未知。并且对该数据流中数据只能访问一次。请写出一个随机选择算法,使得数据流中所有数据被选中的概率相等。”
通过数学归纳法进行分析,找出规律:
数据流只有一个数据。接收一个数据,发现数据流结束了,直接返回该数据,其概率为1。
数据流中有两个数据。接收第一个数据,此时不能立即返回数据,因为流还没有结束。继续读取第二个数据,发现数据流结束了。我们生成一个随机整数(各个整数概率相等),取值范围在[0,1],如果=0就返回第一个数据,如果=1就返回第二个元素。
数据流中有三个数据,假定为1、2、3。和上边一样,我们会陆续接收到1、2,此时我们只能保留一个数据,我们以二分之一的概率进行取舍。假如我们淘汰了2。继续读取数据流得到3,发现数据流结束了。此时返回3的概率应该为1/3时,才能保证选择的正确性。也就是说,此时我们手中有两个数据1、3,通过一次随机选择,以1/3的概率留下3,以2/3的概率留下数据1。那么数据1最终被留下的概率是:
- 数据1被留下:(1/2)*(2/3)= 1/3
- 数据2被留下:(1/2)*(2/3)= 1/3
- 数据3被留下:1/3
这个方法满足题目要求,所有数据被留下返回的概率一样。
因此,我们做一下推论:假设当前正在读取第n个数据,则我们以1/n的概率留下该数据,否则留下前n-1个数据中的一个。以这种方法选择,所有数据流中的数据被选择的概率一样。简短证明:假设n-1时候成立,即前n-1个数据被放回的概率都是1/n-1,当前正在读取的第n个数据,以1/n的概率返回它。那么前n-1个数据中数据被返回的概率为:(1/n-1)*(n-1)/n=1/n,假设成立。
以上最终选择的数据个数为1,这个可以改为k,其中k <= n。
- Java代码实现:
import javax.validation.constraints.NotNull;
import java.util.Random;
import java.util.stream.IntStream;
public class ReservoirSampling {
// default k = 1;
int k;
// pick result
Object result[];
// random
Random r;
public ReservoirSampling() {
k = 1;
result = new Object[k];
r = new Random();
}
public ReservoirSampling(int k) {
this.k = k;
result = new Object[k];
r = new Random();
}
public void pick(@NotNull Object[] data) {
if (k > data.length) {
result = data;
}
for (int i = 0; i < k; i++) {
result[i] = data[i];
}
for (int i = k; i < data.length; i++) {
int t = r.nextInt(i + 1);
// picked
if (t <= k - 1) {
int j = r.nextInt(k);
result[j] = data[i];
}
}
}
public String show() {
StringBuilder sb = new StringBuilder("");
for(Object o : result) {
sb.append(o + " ");
}
String ts = sb.toString().trim();
//System.out.println(ts);
return ts;
}
public static void main(String[] args) {
ReservoirSampling rs = new ReservoirSampling(3);
int MAX = Integer.MAX_VALUE;
Object[] data = new Object[]{1,2,3,4,5,6,7,8,9,10};
int[] ratio = new int[data.length];
IntStream.range(0, MAX).forEach(i -> {
rs.pick(data);
String[] indexs = rs.show().split(" ");
for (int k = 0; k < indexs.length; k++) {
int index = Integer.valueOf(indexs[k]);
ratio[index - 1]++;
}
});
for (int i = 0; i < ratio.length; i++) {
System.out.println("picked " + data[i] + ", ratio=" + new Double(1.0 * ratio[i] / MAX));
}
}
}
输出结果:
picked 1, ratio=0.29998814188874706
picked 2, ratio=0.3000000958796591
picked 3, ratio=0.30000969828106916
picked 4, ratio=0.3000079115387089
picked 5, ratio=0.30000100159086335
picked 6, ratio=0.3000035445671545
picked 7, ratio=0.29999905140139116
picked 8, ratio=0.29999870355240943
picked 9, ratio=0.3000037606339919
picked 10, ratio=0.2999880906660054