前言
在刷Leetcode的过程种,遇到过不少类似的问题:给出一个链表,如何从中随机获取一个节点?
直观的解法是把链表转换为List,或者获取其长度,再用Random解决。那么假如不能使用额外空间以及不允许事先获取其长度呢?一边扫描一边随机采样,这就是Reservoir Sampling能做到的。
事实上,Reservoir Sampling可以用来解决n个元素里面随机抽取k个,乃至于支持不平均的随机权重,不过先让我们看看一个最简单的例子吧。
举例分析
[1->2->3]包含3个元素,够简单了吧。每个元素应该有1/3的几率被抽中。当然了事先我们不知道有3个元素。
首先指向1。假如后面没有了,那这自然没别的选择。而我们看到后面还有一个2.
假如只考虑1->2,那么也就是有1/2的几率往后走一步:
a) 留在1。那么问题在于1只能看到2,而2后面还可能有东西,所以需要两个指针,一个遍历链表,一个指向我们所取值的位置。这样,我们就能看到3,而到了3,每个元素概率变为1/3,也就是说假如继续留在1的概率是x,那么1/2 * x = 1/3,x=2/3,也就是说1/3的概率这时候进到2.
b)进到2,这样3来了以后,留在2的几率也应该是1/3。这有两种可能:一是第一步1进2,然后留在2;二是第一步留在1,然后再进到2。假如留在2的概率是x,那么有1/2*x+1/2*(1-2/3)=1/3,x=1/3.
而1和2的概率搞定了,那3的概率自然也搞定了。
推广I
现在我们尝试将其推广到一般情况。假如我们已经有办法从n个里面随机挑了,也就是说每个元素几率是1/n,现在来了一个新元素,要想办法让每个元素几率成为1/(n+1)。
根据上面的分析,假如现在还留在1,那么本身概率是1/n,然后要继续留下来的概率要为n/(n+1),这样最终留下了的概率才是1/n * n/(n+1) = 1/(n+1)。
对于2,受到1的影响,同样假如留下的概率是x,有x/n + 1/(n(n+1)) = 1/(n+1)得x=(n-1)/(n+1)。
假如对于i,其留下的概率是(n+1-i)/(n+1),没留下即往前一步的概率则是i/(n+1)。
可以验证:对于i+1,那么取值为它的概率P=(1/n)*P(留在i+1)+(1/n)*P(没留在i)=(1/n)*P(留在i+1)+i/(n(n+1))=1/(n+1)可得x=(n-i)/(n+1),符合之前的公式。
算法
题目
整理一下之前的内容,我们需要一个指针指向当前选择的节点i,一个指针来遍历并记录当前长度n,然后每当长度增加时进行一个判断,即(n+1-i)/(n+1)概率不动,否则i指向下一个节点。当遍历完成时,返回指向的节点即可。时间复杂度为O(n)。代码如下:
class Solution {
private final ListNode head;
private final Random random = new Random();
public Solution(ListNode head) {
this.head = head;
}
public int getRandom() {
int count = 1, i = 1;
ListNode res = head, cur = head;
while (cur.next != null) {
count++;
boolean stay = hit((double) (count - i) / count);
if (!stay) {
res = res.next;
i++;
}
cur = cur.next;
}
return res.val;
}
private boolean hit(double chance) {
return random.nextDouble() <= chance;
}
}
推广II
上面说的是从n个抽1个元素,现在尝试推广到k个元素(1<=k<=n)。
显然每个元素应该有k/n的几率被抽中。操作如下:
- 首先抽出1~k;
- 对于k+1,以k/(k+1)的概率选择它,然后再与前k个中随机的一个元素置换;
- 对于k+i,以k/(k+i)的概率选择它,然后再与前k个中随机的一个元素置换;
- 持续进行直至k+i=n。前k个元素即为所求。
k=1正是我们之前分析的情况,当然这里具体操作还是不一样,不过本质还是一致的。
假设我们已经处理好了k+i-1,现在来到k+i。预期结果是每个元素有k/(k+i)的几率被选择。
那么对于之前的元素x,其在这一轮后被选择的概率为
P=P(x之前就被选择)*P(x这一轮没有被换出去)
=[k/(k+i-1)] * [1-P(x这轮被换出去)]
=[k/(k+i-1)] * [1-P(选中k+i)*P(与x置换)]
=[k/(k+i-1)] *[1-k/(k+i)*(1/k)]=k/(k+i)
使用这种方法得到的k=1的代码:
public int getRandom() {
int count = 1;
ListNode res = head, cur = head;
while (cur.next != null) {
count++;
cur = cur.next;
boolean chosen = hit((double) 1 / count);
if (chosen) res = cur;
}
return res.val;
}