SpringBoot使用Redis对用户IP进行接口限流的项目实践
作者:莫轻言舞
一、思路
使用接口限流的主要目的在于提高系统的稳定性,防止接口被恶意打击(短时间内大量请求)。
比如要求某接口在1分钟内请求次数不超过1000次,那么应该如何设计代码呢?
下面讲两种思路,如果想看代码可直接翻到后面的代码部分。
1.1 固定时间段(旧思路)
1.1.1 思路描述
该方案的思路是:使用Redis记录固定时间段内某用户IP访问某接口的次数,其中:
- Redis的key:用户IP + 接口方法名
- Redis的value:当前接口访问次数。
当用户在近期内第一次访问该接口时,向Redis中设置一个包含了用户IP和接口方法名的key,value的值初始化为1(表示第一次访问当前接口)。同时,设置该key的过期时间(比如为60秒)。
之后,只要这个key还未过期,用户每次访问该接口都会导致value自增1次。
用户每次访问接口前,先从Redis中拿到当前接口访问次数,如果发现访问次数大于规定的次数(如超过1000次),则向用户返回接口访问失败的标识。
1.1.2 思路缺陷
该方案的缺点在于,限流时间段是固定的。
比如要求某接口在1分钟内请求次数不超过1000次,观察以下流程:
可以发现,00:59和01:01之间仅仅间隔了2秒,但接口却被访问了1000+999=1999次,是限流次数(1000次)的2倍!
所以在该方案中,限流次数的设置可能不起作用,仍然可能在短时间内造成大量访问。
1.2 滑动窗口(新思路)
1.2.1 思路描述
为了避免出现方案1中由于键过期导致的短期访问量增大的情况,我们可以改变一下思路,也就是把固定的时间段改成动态的:
假设某个接口在10秒内只允许访问5次。用户每次访问接口时,记录当前用户访问的时间点(时间戳),并计算前10秒内用户访问该接口的总次数。如果总次数大于限流次数,则不允许用户访问该接口。这样就能保证在任意时刻用户的访问次数不会超过1000次。
如下图,假设用户在0:19时间点访问接口,经检查其前10秒内访问次数为5次,则允许本次访问。
假设用户0:20时间点访问接口,经检查其前10秒内访问次数为6次(超出限流次数5次),则不允许本次访问。
1.2.2 Redis部分的实现
1)选用何种 Redis 数据结构
首先是需要确定使用哪个Redis数据结构。用户每次访问时,需要用一个key记录用户访问的时间点,而且还需要利用这些时间点进行范围检查。
2)为何选择 zSet 数据结构
为了能够实现范围检查,可以考虑使用Redis中的zSet有序集合。
添加一个zSet元素的命令如下:
ZADD [key] [score] [member]
它有一个关键的属性score,通过它可以记录当前member的优先级。
于是我们可以把score设置成用户访问接口的时间戳,以便于通过score进行范围检查。key则记录用户IP和接口方法名,至于member设置成什么没有影响,一个member记录了用户访问接口的时间点。因此member也可以设置成时间戳。
3)zSet 如何进行范围检查(检查前几秒的访问次数)
思路是,把特定时间间隔之前的member都删掉,留下的member就是时间间隔之内的总访问次数。然后统计当前key中的member有多少个即可。
① 把特定时间间隔之前的member都删掉。
zSet有如下命令,用于删除score范围在[min~max]
之间的member:
Zremrangebyscore [key] [min] [max]
假设限流时间设置为5秒,当前用户访问接口时,获取当前系统时间戳为currentTimeMill
,那么删除的score范围可以设置为:
min = 0 max = currentTimeMill - 5 * 1000
相当于把5秒之前的所有member都删除了,只留下前5秒内的key。
② 统计特定key中已存在的member有多少个。
zSet有如下命令,用于统计某个key的member总数:
ZCARD [key]
统计的key的member总数,就是当前接口已经访问的次数。如果该数目大于限流次数,则说明当前的访问应被限流。
二、代码实现
主要是使用注解 + AOP的形式实现。
2.1 固定时间段思路
使用了lua脚本。
2.1.1 限流注解
@Retention(RetentionPolicy.RUNTIME) @Target(ElementType.METHOD) public @interface RateLimiter { /** * 限流时间,单位秒 */ int time() default 5; /** * 限流次数 */ int count() default 10; }
2.1.2 定义lua脚本
在resources/lua
下新建limit.lua
:
-- 获取redis键 local key = KEYS[1] -- 获取第一个参数(次数) local count = tonumber(ARGV[1]) -- 获取第二个参数(时间) local time = tonumber(ARGV[2]) -- 获取当前流量 local current = redis.call('get', key); -- 如果current值存在,且值大于规定的次数,则拒绝放行(直接返回当前流量) if current and tonumber(current) > count then return tonumber(current) end -- 如果值小于规定次数,或值不存在,则允许放行,当前流量数+1 (值不存在情况下,可以自增变为1) current = redis.call('incr', key); -- 如果是第一次进来,那么开始设置键的过期时间。 if tonumber(current) == 1 then redis.call('expire', key, time); end -- 返回当前流量 return tonumber(current)
2.1.3 注入Lua执行脚本
关键代码是limitScript()
方法
@Configuration public class RedisConfig { @Bean public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory connectionFactory) { RedisTemplate<Object, Object> redisTemplate = new RedisTemplate<>(); redisTemplate.setConnectionFactory(connectionFactory); // 使用Jackson2JsonRedisSerialize 替换默认序列化(默认采用的是JDK序列化) Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class); ObjectMapper om = new ObjectMapper(); om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY); om.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL); jackson2JsonRedisSerializer.setObjectMapper(om); redisTemplate.setKeySerializer(jackson2JsonRedisSerializer); redisTemplate.setValueSerializer(jackson2JsonRedisSerializer); redisTemplate.setHashKeySerializer(jackson2JsonRedisSerializer); redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer); return redisTemplate; } /** * 解析lua脚本的bean */ @Bean("limitScript") public DefaultRedisScript<Long> limitScript() { DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>(); redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua"))); redisScript.setResultType(Long.class); return redisScript; } }
2.1.3 定义Aop切面类
@Slf4j @Aspect @Component public class RateLimiterAspect { @Autowired private RedisTemplate redisTemplate; @Autowired private RedisScript<Long> limitScript; @Before("@annotation(rateLimiter)") public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable { int time = rateLimiter.time(); int count = rateLimiter.count(); String combineKey = getCombineKey(rateLimiter.type(), point); List<String> keys = Collections.singletonList(combineKey); try { Long number = (Long) redisTemplate.execute(limitScript, keys, count, time); // 当前流量number已超过限制,则抛出异常 if (number == null || number.intValue() > count) { throw new RuntimeException("访问过于频繁,请稍后再试"); } log.info("[limit] 限制请求数'{}',当前请求数'{}',缓存key'{}'", count, number.intValue(), combineKey); } catch (Exception ex) { ex.printStackTrace(); throw new RuntimeException("服务器限流异常,请稍候再试"); } } /** * 把用户IP和接口方法名拼接成 redis 的 key * @param point 切入点 * @return 组合key */ private String getCombineKey(JoinPoint point) { StringBuilder sb = new StringBuilder("rate_limit:"); ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); HttpServletRequest request = attributes.getRequest(); sb.append( Utils.getIpAddress(request) ); MethodSignature signature = (MethodSignature) point.getSignature(); Method method = signature.getMethod(); Class<?> targetClass = method.getDeclaringClass(); // keyPrefix + "-" + class + "-" + method return sb.append("-").append( targetClass.getName() ) .append("-").append(method.getName()).toString(); } }
2.2 滑动窗口思路
2.2.1 限流注解
@Retention(RetentionPolicy.RUNTIME) @Target(ElementType.METHOD) public @interface RateLimiter { /** * 限流时间,单位秒 */ int time() default 5; /** * 限流次数 */ int count() default 10; }
2.2.2 定义Aop切面类
@Slf4j @Aspect @Component public class RateLimiterAspect { @Autowired private RedisTemplate redisTemplate; /** * 实现限流(新思路) * @param point * @param rateLimiter * @throws Throwable */ @SuppressWarnings("unchecked") @Before("@annotation(rateLimiter)") public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable { // 在 {time} 秒内仅允许访问 {count} 次。 int time = rateLimiter.time(); int count = rateLimiter.count(); // 根据用户IP(可选)和接口方法,构造key String combineKey = getCombineKey(rateLimiter.type(), point); // 限流逻辑实现 ZSetOperations zSetOperations = redisTemplate.opsForZSet(); // 记录本次访问的时间结点 long currentMs = System.currentTimeMillis(); zSetOperations.add(combineKey, currentMs, currentMs); // 这一步是为了防止member一直存在于内存中 redisTemplate.expire(combineKey, time, TimeUnit.SECONDS); // 移除{time}秒之前的访问记录(滑动窗口思想) zSetOperations.removeRangeByScore(combineKey, 0, currentMs - time * 1000); // 获得当前窗口内的访问记录数 Long currCount = zSetOperations.zCard(combineKey); // 限流判断 if (currCount > count) { log.error("[limit] 限制请求数'{}',当前请求数'{}',缓存key'{}'", count, currCount, combineKey); throw new RuntimeException("访问过于频繁,请稍后再试!"); } } /** * 把用户IP和接口方法名拼接成 redis 的 key * @param point 切入点 * @return 组合key */ private String getCombineKey(JoinPoint point) { StringBuilder sb = new StringBuilder("rate_limit:"); ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); HttpServletRequest request = attributes.getRequest(); sb.append( Utils.getIpAddress(request) ); MethodSignature signature = (MethodSignature) point.getSignature(); Method method = signature.getMethod(); Class<?> targetClass = method.getDeclaringClass(); // keyPrefix + "-" + class + "-" + method return sb.append("-").append( targetClass.getName() ) .append("-").append(method.getName()).toString(); } }
到此这篇关于SpringBoot使用Redis对用户IP进行接口限流的文章就介绍到这了,更多相关SpringBoot Redis接口限流内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!