python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > Python限流器

使用Python手搓一个生产级限流器

作者:铭渊老黄

这篇文章主要为大家详细介绍了三种生产级限流算法的Python实现,即令牌桶,漏桶和滑动窗口,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下

“限流不是拒绝服务,而是保护服务——让系统在压力下优雅地活着,而不是突然崩塌。”

一、为什么你的服务需要限流

2023 年某知名 AI 平台在新品发布当晚,因未做好限流防护,短短 15 分钟内涌入的请求量超过正常峰值的 80 倍,数据库连接池耗尽,整个服务雪崩。事后复盘,技术负责人苦涩地说:“如果当时有一个限流器,最多损失部分用户体验,而不是全站瘫痪。”

限流(Rate Limiting),是系统稳定性设计中最重要的防线之一。它的本质是:在时间维度上控制资源的消费速率,保护下游系统不被压垮,同时对用户提供公平的服务保障。

三种经典限流算法——令牌桶(Token Bucket)漏桶(Leaky Bucket)滑动窗口(Sliding Window)——各有侧重,适用场景不同。本文将带你从原理到代码,逐一实现它们,并在文末给出选型建议。

二、固定窗口:最朴素的限流(以及它的致命缺陷)

在深入三大算法之前,先了解最简单的固定窗口计数器,理解它的不足,才能体会后续算法的价值。

import time
import threading
from collections import defaultdict

class FixedWindowRateLimiter:
    """
    固定窗口限流器
    缺陷:窗口边界处可能出现 2 倍流量突刺
    """
    
    def __init__(self, max_requests: int, window_seconds: int):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self._counts = defaultdict(int)
        self._window_starts = defaultdict(float)
        self._lock = threading.Lock()
    
    def is_allowed(self, key: str = "default") -> bool:
        now = time.time()
        
        with self._lock:
            window_start = self._window_starts[key]
            
            # 当前时间超过窗口,重置计数
            if now - window_start >= self.window_seconds:
                self._window_starts[key] = now
                self._counts[key] = 0
            
            if self._counts[key] < self.max_requests:
                self._counts[key] += 1
                return True
            
            return False


# 演示固定窗口的边界问题
limiter = FixedWindowRateLimiter(max_requests=10, window_seconds=60)

# 假设窗口为 [0s, 60s),在第 59s 发送 10 个请求(全部通过)
# 然后在第 61s(新窗口开始)再发送 10 个请求(全部通过)
# 实际上在 [59s, 61s] 这 2 秒内通过了 20 个请求——是上限的 2 倍!
print("固定窗口在边界处存在 2 倍流量突刺风险")

这个"边界突刺"问题在高并发场景下可能让下游系统瞬间承受双倍压力。这正是滑动窗口要解决的问题。

三、令牌桶算法:允许突发,平滑限速

原理图解

令牌桶示意图:

                   ┌─────────────────┐
  以固定速率        │  🪙 🪙 🪙 🪙 🪙  │  桶容量上限
  补充令牌 →        │  🪙 🪙 🪙 🪙    │  (burst capacity)
                   └────────┬────────┘
                            │
                    请求到来时取令牌
                            ↓
                   有令牌 → 放行请求
                   无令牌 → 拒绝/等待

令牌桶的核心特性:允许一定程度的流量突发(桶里积累的令牌可以被一次性消费),但长期平均速率不超过令牌补充速率。这非常适合 API 调用:用户可能偶尔有突发需求,但不能持续高频。

Python 实现

import time
import threading
from typing import Optional

class TokenBucketRateLimiter:
    """
    令牌桶限流器
    
    特性:
    - 允许短时突发(桶满时可消费全部令牌)
    - 长期速率受 rate 控制
    - 线程安全
    """
    
    def __init__(self, rate: float, capacity: int):
        """
        :param rate: 令牌补充速率(个/秒)
        :param capacity: 桶的最大容量(突发上限)
        """
        self.rate = rate
        self.capacity = capacity
        self._tokens = float(capacity)   # 初始满桶
        self._last_refill = time.monotonic()
        self._lock = threading.Lock()
    
    def _refill(self):
        """根据时间流逝补充令牌(惰性计算,不需要后台线程)"""
        now = time.monotonic()
        elapsed = now - self._last_refill
        new_tokens = elapsed * self.rate
        self._tokens = min(self.capacity, self._tokens + new_tokens)
        self._last_refill = now
    
    def acquire(self, tokens: int = 1) -> bool:
        """
        尝试获取令牌
        :param tokens: 需要的令牌数(支持批量消费)
        :return: True=放行, False=拒绝
        """
        with self._lock:
            self._refill()
            
            if self._tokens >= tokens:
                self._tokens -= tokens
                return True
            return False
    
    def acquire_with_wait(self, tokens: int = 1, 
                          timeout: float = None) -> bool:
        """
        等待直到获取到令牌(阻塞版本)
        :param timeout: 最长等待时间(秒),None 表示永久等待
        """
        start = time.monotonic()
        
        while True:
            with self._lock:
                self._refill()
                if self._tokens >= tokens:
                    self._tokens -= tokens
                    return True
            
            # 计算需要等待多久才能获得足够令牌
            with self._lock:
                deficit = tokens - self._tokens
                wait_time = deficit / self.rate
            
            if timeout is not None:
                elapsed = time.monotonic() - start
                if elapsed + wait_time > timeout:
                    return False
            
            time.sleep(min(wait_time, 0.01))  # 最多睡 10ms 避免过度等待
    
    @property
    def current_tokens(self) -> float:
        """查看当前令牌数(调试用)"""
        with self._lock:
            self._refill()
            return self._tokens


# ── 使用示例 ──────────────────────────────────────────────
def demo_token_bucket():
    # 每秒补充 5 个令牌,桶容量 10
    limiter = TokenBucketRateLimiter(rate=5, capacity=10)
    
    print("=== 令牌桶演示 ===")
    print(f"初始令牌数:{limiter.current_tokens:.1f}")
    
    # 模拟突发请求:一次性消费 8 个令牌
    results = []
    for i in range(12):
        allowed = limiter.acquire()
        results.append("✅" if allowed else "❌")
        print(f"请求 {i+1:2d}:{'放行' if allowed else '拒绝'} "
              f"(剩余令牌:{limiter.current_tokens:.2f})")
    
    # 等待 2 秒后令牌恢复
    print("\n等待 2 秒,令牌恢复中...")
    time.sleep(2)
    print(f"2 秒后令牌数:{limiter.current_tokens:.1f}(补充了约 10 个)")
    print(f"再次请求:{'放行' if limiter.acquire() else '拒绝'}")

demo_token_bucket()

四、漏桶算法:绝对平滑,削峰填谷

原理图解

漏桶示意图:

  请求涌入(任意速率)
   ↓↓↓↓↓↓↓↓
  ┌──────────┐
  │          │  ← 桶满则溢出(拒绝请求)
  │  队列    │
  │  缓冲    │
  │          │
  └────┬─────┘
       │ 以固定速率漏出(处理请求)
       ↓
  恒定速率输出

漏桶与令牌桶的核心区别:漏桶的输出速率是严格恒定的,不允许突发。无论入流量多大,处理速率始终如一。这对下游服务的保护最为彻底,但用户体验上感知更明显——即使桶里有空间,也要排队等待。

Python 实现

import time
import threading
from collections import deque
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Request:
    """封装请求及其元数据"""
    data: Any
    arrive_time: float
    key: str = "default"

class LeakyBucketRateLimiter:
    """
    漏桶限流器
    
    特性:
    - 严格恒定的输出速率
    - 超出桶容量的请求直接丢弃
    - 天然实现请求整形(Traffic Shaping)
    """
    
    def __init__(self, rate: float, capacity: int):
        """
        :param rate: 漏出速率(请求数/秒)
        :param capacity: 桶容量(最大排队数)
        """
        self.rate = rate
        self.capacity = capacity
        self._interval = 1.0 / rate  # 两次漏出之间的间隔
        
        self._queue: deque = deque()
        self._lock = threading.Lock()
        self._last_leak_time = time.monotonic()
        
        # 启动后台漏出线程
        self._running = True
        self._worker = threading.Thread(target=self._leak_worker, daemon=True)
        self._worker.start()
    
    def _leak_worker(self):
        """后台工作线程:以固定速率处理请求"""
        while self._running:
            time.sleep(self._interval)
            
            with self._lock:
                if self._queue:
                    request = self._queue.popleft()
                    wait_time = time.monotonic() - request.arrive_time
                    print(f"  🚰 漏出处理:key={request.key}, "
                          f"等待={wait_time*1000:.1f}ms")
    
    def submit(self, data: Any, key: str = "default") -> bool:
        """
        提交请求到漏桶
        :return: True=入队成功, False=桶已满(丢弃)
        """
        with self._lock:
            if len(self._queue) >= self.capacity:
                print(f"  💧 桶已满,丢弃请求:key={key}")
                return False
            
            request = Request(data=data, arrive_time=time.monotonic(), key=key)
            self._queue.append(request)
            print(f"  📥 请求入队:key={key}, 队列深度={len(self._queue)}")
            return True
    
    def stop(self):
        self._running = False


# ── 使用示例 ──────────────────────────────────────────────
def demo_leaky_bucket():
    print("\n=== 漏桶演示(2 req/s,桶容量 5)===")
    limiter = LeakyBucketRateLimiter(rate=2, capacity=5)
    
    # 模拟突发:瞬间提交 8 个请求
    for i in range(8):
        limiter.submit(f"request-{i}", key=f"user_{i % 3}")
    
    # 等待漏桶处理完毕
    time.sleep(4)
    limiter.stop()
    print("漏桶处理完成,输出速率严格为 2 req/s")

demo_leaky_bucket()

五、滑动窗口算法:精准统计,消除边界突刺

原理图解

固定窗口(存在边界问题):
  |← 窗口1(0-60s)→|← 窗口2(60-120s)→|
  |  9 requests     |  9 requests        |
                    ↑
               边界时刻:前后 2 秒内实际有 18 个请求!

滑动窗口(精确统计):
  当前时间 t=75s,窗口大小 60s
  统计 [15s, 75s] 内的请求数 → 精确到每一秒

滑动窗口分为两种实现:滑动日志(精确但内存大)滑动计数器(近似但高效)

Python 实现

import time
import threading
from collections import deque
from typing import Dict

class SlidingWindowLogLimiter:
    """
    滑动窗口日志限流器(精确版)
    记录每个请求的时间戳,精确统计窗口内的请求数
    
    适合:低流量、高精度要求的场景
    缺陷:内存占用随请求量线性增长
    """
    
    def __init__(self, max_requests: int, window_seconds: float):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # key → deque of timestamps
        self._logs: Dict[str, deque] = {}
        self._lock = threading.Lock()
    
    def is_allowed(self, key: str = "default") -> bool:
        now = time.monotonic()
        window_start = now - self.window_seconds
        
        with self._lock:
            if key not in self._logs:
                self._logs[key] = deque()
            
            log = self._logs[key]
            
            # 清除窗口外的旧记录
            while log and log[0] <= window_start:
                log.popleft()
            
            # 判断窗口内请求数
            if len(log) < self.max_requests:
                log.append(now)
                return True
            
            return False
    
    def current_count(self, key: str = "default") -> int:
        """当前窗口内的请求数"""
        now = time.monotonic()
        window_start = now - self.window_seconds
        
        with self._lock:
            if key not in self._logs:
                return 0
            return sum(1 for t in self._logs[key] if t > window_start)


class SlidingWindowCounterLimiter:
    """
    滑动窗口计数器限流器(高效版)
    将窗口划分为多个小格子,用计数器近似统计
    
    适合:高流量、对精度要求不苛刻的场景
    Redis 分布式限流的主流方案
    """
    
    def __init__(self, max_requests: int, window_seconds: float, 
                 precision: int = 10):
        """
        :param precision: 将窗口划分为多少个子格子(越多越精确,越耗内存)
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.precision = precision
        self.slot_duration = window_seconds / precision
        
        # 每个 key 的滑动计数器:{slot_index: count}
        self._counters: Dict[str, Dict[int, int]] = {}
        self._lock = threading.Lock()
    
    def _current_slot(self) -> int:
        """当前时间对应的格子索引"""
        return int(time.monotonic() / self.slot_duration)
    
    def is_allowed(self, key: str = "default") -> bool:
        current_slot = self._current_slot()
        # 有效的格子范围
        valid_slots = set(range(current_slot - self.precision + 1, 
                                current_slot + 1))
        
        with self._lock:
            if key not in self._counters:
                self._counters[key] = {}
            
            counter = self._counters[key]
            
            # 清除过期格子
            expired = [s for s in counter if s not in valid_slots]
            for s in expired:
                del counter[s]
            
            # 统计窗口内总请求数
            total = sum(counter.values())
            
            if total < self.max_requests:
                counter[current_slot] = counter.get(current_slot, 0) + 1
                return True
            
            return False


# ── 对比演示 ──────────────────────────────────────────────
def demo_sliding_window():
    print("\n=== 滑动窗口演示(10 req/10s)===")
    
    log_limiter = SlidingWindowLogLimiter(max_requests=10, window_seconds=10)
    counter_limiter = SlidingWindowCounterLimiter(max_requests=10, 
                                                   window_seconds=10, 
                                                   precision=10)
    
    # 模拟在 5 秒内发送 15 个请求
    results_log = []
    results_counter = []
    
    for i in range(15):
        results_log.append(log_limiter.is_allowed("user_1"))
        results_counter.append(counter_limiter.is_allowed("user_1"))
        time.sleep(0.2)
    
    print(f"滑动日志结果:{['✅' if r else '❌' for r in results_log]}")
    print(f"计数器结果:  {['✅' if r else '❌' for r in results_counter]}")
    print(f"日志版窗口内请求数:{log_limiter.current_count('user_1')}")

demo_sliding_window()

六、整合实战:带限流的 FastAPI 中间件

将三种限流器整合为一个统一接口,集成到 FastAPI 中:

from abc import ABC, abstractmethod
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
import time

class RateLimiter(ABC):
    """限流器抽象基类"""
    
    @abstractmethod
    def is_allowed(self, key: str) -> bool:
        pass
    
    @abstractmethod
    def get_info(self, key: str) -> dict:
        """返回限流状态信息(用于响应头)"""
        pass


class RateLimitMiddleware:
    """
    FastAPI 限流中间件
    支持按 IP、用户 ID、API Key 等维度限流
    """
    
    def __init__(self, app, limiter: RateLimiter, 
                 key_func=None, exclude_paths: list = None):
        self.app = app
        self.limiter = limiter
        self.key_func = key_func or self._default_key_func
        self.exclude_paths = exclude_paths or ["/health", "/docs"]
    
    @staticmethod
    def _default_key_func(request: Request) -> str:
        """默认按客户端 IP 限流"""
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            return forwarded_for.split(",")[0].strip()
        return request.client.host
    
    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            request = Request(scope, receive)
            
            # 排除不限流的路径
            if request.url.path not in self.exclude_paths:
                key = self.key_func(request)
                
                if not self.limiter.is_allowed(key):
                    response = JSONResponse(
                        status_code=429,
                        content={
                            "error": "Too Many Requests",
                            "message": "请求过于频繁,请稍后再试",
                            "retry_after": 1
                        },
                        headers={
                            "Retry-After": "1",
                            "X-RateLimit-Limit": "100",
                        }
                    )
                    await response(scope, receive, send)
                    return
        
        await self.app(scope, receive, send)


# 构建应用
app = FastAPI(title="限流演示 API")

# 选择限流算法(可热切换)
# 方案1:令牌桶(允许突发)
rate_limiter = TokenBucketRateLimiter(rate=10, capacity=20)
# 方案2:滑动窗口(精确统计)
# rate_limiter = SlidingWindowLogLimiter(max_requests=100, window_seconds=60)

app.add_middleware(RateLimitMiddleware, limiter=rate_limiter)


@app.get("/api/data")
async def get_data(request: Request):
    client_ip = request.client.host
    return {
        "message": "请求成功",
        "client": client_ip,
        "timestamp": time.time()
    }

@app.get("/health")
async def health():
    return {"status": "ok"}

七、Redis 分布式限流:生产环境的标准方案

单机限流器在多实例部署时会失效——每个实例各自维护状态,总请求数可能超出预期。生产环境必须用 Redis 实现分布式限流

import redis
import time

class RedisTokenBucketLimiter:
    """
    基于 Redis 的分布式令牌桶限流器
    使用 Lua 脚本保证原子性
    """
    
    # Lua 脚本:原子化地补充令牌并判断是否放行
    LUA_SCRIPT = """
    local key = KEYS[1]
    local rate = tonumber(ARGV[1])
    local capacity = tonumber(ARGV[2])
    local now = tonumber(ARGV[3])
    local requested = tonumber(ARGV[4])
    
    local last_tokens = tonumber(redis.call('hget', key, 'tokens') or capacity)
    local last_refill = tonumber(redis.call('hget', key, 'last_refill') or now)
    
    -- 补充令牌
    local elapsed = now - last_refill
    local new_tokens = math.min(capacity, last_tokens + elapsed * rate)
    
    local allowed = 0
    if new_tokens >= requested then
        new_tokens = new_tokens - requested
        allowed = 1
    end
    
    redis.call('hset', key, 'tokens', new_tokens)
    redis.call('hset', key, 'last_refill', now)
    redis.call('expire', key, math.ceil(capacity / rate) * 2)
    
    return {allowed, math.floor(new_tokens)}
    """
    
    def __init__(self, redis_url: str, rate: float, capacity: int):
        self.redis = redis.from_url(redis_url)
        self.rate = rate
        self.capacity = capacity
        self._script = self.redis.register_script(self.LUA_SCRIPT)
    
    def is_allowed(self, key: str, tokens: int = 1) -> tuple[bool, int]:
        """
        :return: (是否允许, 剩余令牌数)
        """
        now = time.time()
        result = self._script(
            keys=[f"ratelimit:{key}"],
            args=[self.rate, self.capacity, now, tokens]
        )
        allowed = bool(result[0])
        remaining = int(result[1])
        return allowed, remaining


# 使用示例(需要 Redis 服务)
# redis_limiter = RedisTokenBucketLimiter(
#     redis_url="redis://localhost:6379",
#     rate=10,
#     capacity=50
# )
# allowed, remaining = redis_limiter.is_allowed("user:12345")
# print(f"放行:{allowed},剩余令牌:{remaining}")

八、三大算法横向对比

维度令牌桶漏桶滑动窗口
突发流量✅ 允许(桶容量内)❌ 严格排队⚡ 有限允许
输出平滑度极高
实现复杂度
内存占用极低中(队列)低(计数)
精确度近似(计数器版)
适用场景API 调用限额流量整形精细化统计
典型应用GitHub API网络带宽限速Nginx limit_req

选型建议:

如果你的场景是对外 API 服务,用户可能有合理的突发需求(如批量导出),选令牌桶,既能控制长期速率,又不会让正常突发请求体验太差。

如果你的场景是保护下游服务(如数据库、第三方 API),需要严格控制请求速率避免过载,选漏桶,它能将任意突发流量整形为恒定输出。

如果你需要精确统计时间窗口内的请求次数(如"每分钟最多 100 次"的 SLA 保障),选滑动窗口,配合 Redis 可轻松实现分布式精确限流。

九、总结

限流是系统稳定性设计的基石,三种算法各有侧重:令牌桶平衡突发与限速,漏桶追求绝对平滑,滑动窗口精确感知流量。在 Python 实战中,本地场景用线程安全的纯 Python 实现即可;多实例部署时,务必引入 Redis + Lua 脚本的分布式方案,保证全局一致性。

好的限流器不只是"拒绝请求",它还应该做到:返回清晰的 429 Too Many Requests、在响应头中告知 Retry-After 时间、对不同用户等级设置差异化配额。这些细节,才是用户友好型 API 的体现。

限流的本质是保护——保护你的服务,也保护你的用户。把这道门守好,系统才能在风雨中岿然不动。

以上就是使用Python手搓一个生产级限流器的详细内容,更多关于Python限流器的资料请关注脚本之家其它相关文章!

您可能感兴趣的文章:
阅读全文