基于redis实现令牌桶

package ratelimit

import (
    "fmt"

    "code.byted.org/gopkg/logs"
    "code.byted.org/ttarch/hermes_go_api/dal/redis"
)

const (
    lua = `redis.replicate_commands()

local tokens_key = KEYS[1]    //令牌桶key
local timestamp_key = KEYS[2]  //最后一次更新令牌桶的时间戳
local rate = tonumber(ARGV[1])  // 令牌产生的速率
local capacity = tonumber(ARGV[2])  //令牌桶容量
local requested = tonumber(ARGV[3]) //所需的令牌数量
local now_time = redis.call('TIME')
local now = now_time[1]*1000000+now_time[2]  //获取当前微秒时间戳

local fill_time = capacity/rate   //生成所有令牌(令牌桶填满)所需要的时间
local ttl = math.floor(fill_time*2)  //缓存时间

local last_tokens = tonumber(redis.call("get", tokensxkey))  //获取令牌桶剩余的令牌数量
if last_tokens == nil then
  last_tokens = capacity  //剩余令牌数初始化为容量大小
end

local last_refreshed = tonumber(redis.call("get", timestamp_key))  // 上次一令牌桶更新的时间戳
if last_refreshed == nil then
  last_refreshed = 0  // 第一次为0
end

local delta = math.floor(math.max(0, now-last_refreshed)*rate/1000000)  // 计算当前时间内新增的令牌数
local filled_tokens = math.min(capacity, last_tokens+delta)  //计算令牌桶里的令牌总数
local allowed = filled_tokens >= requested  //是否有足够的令牌处理请求
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
  new_tokens = filled_tokens - requested  //剩余的令牌数量 = 当前令牌数 - 所需令牌数
  allowed_num = 1
end

if ttl > 0 then
  redis.call("setex", tokens_key, ttl, new_tokens)
  redis.call("setex", timestamp_key, ttl, now-now%math.floor((1000000/rate)))
end

return allowed_num
`
)

func Acquire(uniqueKey string, qps float64) bool {
    if qps == 0 {
        return false
    }
        //1. 令牌桶数量 key 2. 最后一次更新时间key
    keys := []string{fmt.Sprintf("{%s}:tokens_key", uniqueKey), fmt.Sprintf("{%s}:timestamp_key", uniqueKey)}
    rate := qps
    requested := 1.0  // 初始化期望得到的令牌数量,默认为1
    var cap float64
    if qps >= 1 {
                // 令牌桶的容量为请求速率
        cap = rate
    } else {
                // 期望得到的令牌数量 + 请求速率乘以3
        cap = requested + rate*3
    }
    val, err := redis.DefaultRateLimitRedisCli.Eval(lua, keys, rate, cap, requested).Result()
    if err != nil {
        logs.Error("%+v", err)
        return false
    }
    valI, ok := val.(int64)
    if ok {
        return valI == 1
    }
    logs.Error("fail to parse redis result")
    return false
}

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容