蓄水池采样算法-Lua版本

由于业务需要,所以搜索了一些相关的随机算法
代码是参考维基百科进行编写的:https://en.wikipedia.org/wiki/Reservoir_sampling

注意点:
Chao算法会有缺陷,因为它一开始就把所需要的数据全部扔池子里了。 如果权重存在0的的数据,且数据量较少,可能会出现在最终结果里。(实验中,Chao算法结果不正确
Res算法,我使用了遍历求最小值,所以在处理大量数据时,可能会存在性能瓶颈。

代码:
-- 蓄水池采样算法
function reserviorSampling(tbSequence, dNeed, dSequenceSize, funcRandom)
    dSequenceSize = dSequenceSize or #tbSequence
    funcRandom = funcRandom or math.random
    
    if dNeed > dSequenceSize then
        return tbSequence
    end
    
    local tbSample = {}
    
    for i=1, dNeed do
        table.insert(tbSample, tbSequence[i])
    end
    
    for i=dNeed+1, dSequenceSize do
        local j = funcRandom(1,i)
        if dNeed >= j then
            tbSample[j] = tbSequence[i]
        end
    end
    
    return tbSample
end

-- 加权蓄水池采样算法: Algorithm A-Chao
function weightedReserviorSampling_Chao(tbSequence, dNeed, dSequenceSize, funcRandom)
    dSequenceSize = dSequenceSize or #tbSequence
    funcRandom = funcRandom or math.random
    
    if dNeed > dSequenceSize then
        return tbSequence
    end

    local dWeightSum = 0
    local tbSample = {}
    
    for i=1, dNeed do
        table.insert(tbSample, tbSequence[i])
        dWeightSum = dWeightSum + tbSequence[i].weight
    end
    
    for i=dNeed+1, dSequenceSize do
        dWeightSum = dWeightSum + tbSequence[i].weight
        local p = tbSequence[i].weight / dWeightSum
        local j = funcRandom()
        if j <= p then
            tbSample[funcRandom(1, dNeed)] = tbSequence[i]
        end
    end
    
    return tbSample
end

-- 加权蓄水池采样算法: Algorithm A-Res
function weightedReserviorSampling_Res(tbSequence, dNeed, dSequenceSize, funcRandom)
    dSequenceSize = dSequenceSize or #tbSequence
    funcRandom = funcRandom or math.random
    
    if dNeed > dSequenceSize then
        return tbSequence
    end

    local dWeightSum = 0
    local tbSample = {}
    
    local dMinIndex = nil
    
    for i=1, dSequenceSize do
        local r = funcRandom()^(1/tbSequence[i].weight)
        tbSequence[i].__reservior_r = r
        if i <= dNeed then
            table.insert(tbSample, tbSequence[i])
        else
            if not dMinIndex then
                local dMin = 9999
                for i,v in ipairs(tbSample) do
                    if v.__reservior_r < dMin then
                        dMin = v.__reservior_r
                        dMinIndex = i
                    end
                end
            end
            assert(dMinIndex)
            if r > tbSample[dMinIndex].__reservior_r then
                table.remove(tbSample,dMinIndex)
                dMinIndex = nil
                table.insert(tbSample, tbSequence[i])
            end
        end
    end
    
    return tbSample
end
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 最近有个需求,需要从不固定大小的数据集中取固定数量的数据作为样本,有个同学提到了蓄水池算法,于是了解了一下。 蓄水...
    hatlonely阅读 1,607评论 0 0
  • --- layout: post title: "如果有人问你关系型数据库的原理,叫他看这篇文章(转)" date...
    蓝坠星阅读 864评论 0 3
  • 0. 导语 推荐系统里面有两个经典问题:EE 问题和冷启动问题。前者涉及到平衡准确和多样,后者涉及到产品算法运营等...
    Liam_ml阅读 1,790评论 0 4
  • 一、基础篇 1.1 JVM 1.1.1. Java内存模型,Java内存管理,Java堆和栈,垃圾回收 http:...
    勿以浮沙筑高台阅读 952评论 0 9
  • 从博厄斯(Franz Boas)对北美西北海岸印第安人的夸富宴(potlach)的介绍开始,人类学家不断地从不同视...
    萍儿100081阅读 9,786评论 1 8