自定义限流注解实现接口限流

Lou.Chen2021年5月15日大约 7 分钟

分析以及思路

为了对系统的接口访问频率做限制,从而降低服务器的压力。例如10s内最多请求3次,下次允许的访问必须等到10s后才能访问该接口。

实现思路:

  • 定义限流注解,注解中可自定义接口存在redis中的key的前缀,最大请求次数,时间窗(单位时间内最大访问次数),限制类型(基于IP或者默认的)
  • 对所有加了限流注解的方法进行拦截,使用前置通知即可,拦截后解析方法上的限流注解,先通过反射获取该接口的key(前缀:IP地址-全类名-方法名),然后使用redis执行lua脚本,将key,最大请求次数周期时间传进lua脚本执行
  • 定义lua脚本:拿到key,最大请求次数,周期时间后,先使用key拿到redis中存在值,若该key存在并且值超过最大访问次数,则直接访问。否则代表key不存在,那么直接对key进行自增1,因为这里存在多线程的并发操作,所以还需要再次判断自增1后是否有其它线程对其自增1,所以再次要判断一个是否等于1,若等于1则说明没有其他线程访问,则直接设置key对过期时间,否则直接返回
  • 切面执行完lua脚本后拿到值,然后判断其是否大于最大访问次数,若大于最大范围次数则抛出异常,由全局异常捕获返回异常信息给调用者,否则允许访问。

lua脚本优势:

  • Lua文件中的多个脚本保证在redis服务端执行的原子性
  • 避免多个命令在程序服务端执行时网络对其的性能影响

案例

pom依赖

 <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
        </dependency>

限流类型

public enum LimitType {
    /**
     * 默认的限流策略,针对某一个接口进行限流
     */
    DEFAULT,
    /**
     * 针对某一个IP进行限流
     */
    IP;
}

限流注解

/**
 * 限流注解
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE,ElementType.METHOD})
public @interface RateLimiter {
    /**
     * 限流key的在redis的前缀
     * @return
     */
    String key() default "rate_limit:";

    /**
     * 限流时间窗(时间周期内)
     * @return
     */
    int time() default 60;

    /**
     * 在时间窗内的限流次数
     * @return
     */
    int count() default 100;

    /**
     * 限流类型
     * @return
     */
    LimitType limitType() default LimitType.DEFAULT;

}

IP解析工具

package org.lc.rate_limiter.util;

import org.springframework.util.StringUtils;

import javax.servlet.http.HttpServletRequest;
import java.net.InetAddress;
import java.net.UnknownHostException;

/**
 * 获取IP方法
 *
 * @author tienchin
 */
public class IpUtils {
    /**
     * 获取客户端IP
     *
     * @param request 请求对象
     * @return IP地址
     */
    public static String getIpAddr(HttpServletRequest request) {
        if (request == null) {
            return "unknown";
        }
        String ip = request.getHeader("x-forwarded-for");
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("X-Forwarded-For");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("X-Real-IP");
        }

        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
        }

        return "0:0:0:0:0:0:0:1".equals(ip) ? "127.0.0.1" : getMultistageReverseProxyIp(ip);
    }

    /**
     * 检查是否为内部IP地址
     *
     * @param ip IP地址
     * @return 结果
     */
    public static boolean internalIp(String ip) {
        byte[] addr = textToNumericFormatV4(ip);
        return internalIp(addr) || "127.0.0.1".equals(ip);
    }

    /**
     * 检查是否为内部IP地址
     *
     * @param addr byte地址
     * @return 结果
     */
    private static boolean internalIp(byte[] addr) {
        if (addr == null || addr.length < 2) {
            return true;
        }
        final byte b0 = addr[0];
        final byte b1 = addr[1];
        // 10.x.x.x/8
        final byte SECTION_1 = 0x0A;
        // 172.16.x.x/12
        final byte SECTION_2 = (byte) 0xAC;
        final byte SECTION_3 = (byte) 0x10;
        final byte SECTION_4 = (byte) 0x1F;
        // 192.168.x.x/16
        final byte SECTION_5 = (byte) 0xC0;
        final byte SECTION_6 = (byte) 0xA8;
        switch (b0) {
            case SECTION_1:
                return true;
            case SECTION_2:
                if (b1 >= SECTION_3 && b1 <= SECTION_4) {
                    return true;
                }
            case SECTION_5:
                switch (b1) {
                    case SECTION_6:
                        return true;
                }
            default:
                return false;
        }
    }

    /**
     * 将IPv4地址转换成字节
     *
     * @param text IPv4地址
     * @return byte 字节
     */
    public static byte[] textToNumericFormatV4(String text) {
        if (text.length() == 0) {
            return null;
        }

        byte[] bytes = new byte[4];
        String[] elements = text.split("\\.", -1);
        try {
            long l;
            int i;
            switch (elements.length) {
                case 1:
                    l = Long.parseLong(elements[0]);
                    if ((l < 0L) || (l > 4294967295L)) {
                        return null;
                    }
                    bytes[0] = (byte) (int) (l >> 24 & 0xFF);
                    bytes[1] = (byte) (int) ((l & 0xFFFFFF) >> 16 & 0xFF);
                    bytes[2] = (byte) (int) ((l & 0xFFFF) >> 8 & 0xFF);
                    bytes[3] = (byte) (int) (l & 0xFF);
                    break;
                case 2:
                    l = Integer.parseInt(elements[0]);
                    if ((l < 0L) || (l > 255L)) {
                        return null;
                    }
                    bytes[0] = (byte) (int) (l & 0xFF);
                    l = Integer.parseInt(elements[1]);
                    if ((l < 0L) || (l > 16777215L)) {
                        return null;
                    }
                    bytes[1] = (byte) (int) (l >> 16 & 0xFF);
                    bytes[2] = (byte) (int) ((l & 0xFFFF) >> 8 & 0xFF);
                    bytes[3] = (byte) (int) (l & 0xFF);
                    break;
                case 3:
                    for (i = 0; i < 2; ++i) {
                        l = Integer.parseInt(elements[i]);
                        if ((l < 0L) || (l > 255L)) {
                            return null;
                        }
                        bytes[i] = (byte) (int) (l & 0xFF);
                    }
                    l = Integer.parseInt(elements[2]);
                    if ((l < 0L) || (l > 65535L)) {
                        return null;
                    }
                    bytes[2] = (byte) (int) (l >> 8 & 0xFF);
                    bytes[3] = (byte) (int) (l & 0xFF);
                    break;
                case 4:
                    for (i = 0; i < 4; ++i) {
                        l = Integer.parseInt(elements[i]);
                        if ((l < 0L) || (l > 255L)) {
                            return null;
                        }
                        bytes[i] = (byte) (int) (l & 0xFF);
                    }
                    break;
                default:
                    return null;
            }
        } catch (NumberFormatException e) {
            return null;
        }
        return bytes;
    }

    /**
     * 获取IP地址
     *
     * @return 本地IP地址
     */
    public static String getHostIp() {
        try {
            return InetAddress.getLocalHost().getHostAddress();
        } catch (UnknownHostException e) {
        }
        return "127.0.0.1";
    }

    /**
     * 获取主机名
     *
     * @return 本地主机名
     */
    public static String getHostName() {
        try {
            return InetAddress.getLocalHost().getHostName();
        } catch (UnknownHostException e) {
        }
        return "未知";
    }

    /**
     * 从多级反向代理中获得第一个非unknown IP地址
     *
     * @param ip 获得的IP地址
     * @return 第一个非unknown IP地址
     */
    public static String getMultistageReverseProxyIp(String ip) {
        // 多级反向代理检测
        if (ip != null && ip.indexOf(",") > 0) {
            final String[] ips = ip.trim().split(",");
            for (String subIp : ips) {
                if (false == isUnknown(subIp)) {
                    ip = subIp;
                    break;
                }
            }
        }
        return ip;
    }

    /**
     * 检测给定字符串是否为未知,多用于检测HTTP请求相关
     *
     * @param checkString 被检测的字符串
     * @return 是否未知
     */
    public static boolean isUnknown(String checkString) {
        return StringUtils.isEmpty(checkString) || "unknown".equalsIgnoreCase(checkString);
    }
}

lua脚本

在Resource目录新建lua文件夹并新建limit.lua脚本

-- 保证服务端执行多个命令的原子性
-- 减少执行多个命令带来的网络性能影响

-- 获取第一个key参数为限流的key
local key = KEY[1]
-- 获取第一个参数为周期时间
local time = tonumber(ARG[1])
-- 获取第二个参数为最大限流次数
local count = tonumber(ARG[2])
-- 拿到当前redis执行的get操作的值
local current = redis.call('get', key)
-- 如果获取到访问key存在,并且值大于指定的限流次数的值
if current and tonumber(current) > count then
    -- 直接返回
    return tonumber(current)
end
-- 说明第一次访问,则自增1
current = redis.call('incr',key)
-- 如果还是1(代表没有其他线程对其进行访问加1操作
if tonumber(current) ==1 then
    -- 设置该key时间窗
    redis.call('expire',key,time)
end
-- 最后返回访问的次数
return tonumber(current)

Redis配置

@Configuration
public class RedisConfig {

    /**
     * 重写RedisTemplate 对默认的存储值进行序列化
     * @param redisConnectionFactory
     * @return
     */
    @Bean
    public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        RedisTemplate<Object,Object> redisTemplate=new RedisTemplate<>();
        redisTemplate.setConnectionFactory(redisConnectionFactory);
        Jackson2JsonRedisSerializer<Object> serializer = new Jackson2JsonRedisSerializer<>(Object.class);
        redisTemplate.setKeySerializer(serializer);
        redisTemplate.setValueSerializer(serializer);
        redisTemplate.setHashKeySerializer(serializer);
        redisTemplate.setHashValueSerializer(serializer);
        return redisTemplate;
    }

    /**
     * 加载lua脚本
     * @return
     */
    @Bean
    public DefaultRedisScript<Long> limitScript() {
        DefaultRedisScript<Long> script=new DefaultRedisScript<>();
        script.setResultType(Long.class);
        script.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
        //或者直接设置lua字符串
        // script.setScriptText("");
        return script;
    }
}

全局异常以及自定义异常

/**
 * 自定义限流异常
 */
public class RateLimitException extends Exception{
    public RateLimitException(String message) {
        super(message);
    }
}

@RestControllerAdvice
public class GlobalException {
    @ExceptionHandler(RateLimitException.class)
    public Map<String, Object> handlerRateLimitException(RateLimitException rateLimitException) {
        Map<String,Object> map=new HashMap<>();
        map.put("status", 500);
        map.put("message", rateLimitException.getMessage());
        return map;
    }
}

application配置

# 应用名称
spring.application.name=rate_limiter
# 应用服务 WEB 访问端口
server.port=8080
spring.redis.host=xxx
spring.redis.password=xxxx

切面配置

@Aspect
@Component
public class RateLimitAspect {

    private static final Logger logger = LoggerFactory.getLogger(RateLimitAspect.class);

    @Autowired
    private RedisTemplate<Object, Object> redisTemplate;

    @Autowired
    private RedisScript<Long> redisScript;

    /**
     * 定义前置通知即可
     * @param joinPoint
     * @param rateLimiter 方法上的RateLimiter对象
     */
    @Before("@annotation(rateLimiter)")
    public void before(JoinPoint joinPoint, RateLimiter rateLimiter) throws RateLimitException {
        int time = rateLimiter.time();
        int count = rateLimiter.count();
        //获取需要缓存的key
        String key = getCombineKey(joinPoint, rateLimiter);
        //执行lua脚本
        try {
            Long executeResult = redisTemplate.execute(redisScript, Collections.singletonList(key), time, count);
            //如果结果为空或者
            if (executeResult == null || executeResult.intValue() > count) {
                logger.info("当前接口:{} 超过最大请求限制:{}", key, count);
                throw new RateLimitException("接口请求次数超过限制,请稍后再试");
            }
            logger.info("当前接口:{} 当前窗口时间最大请求限制:{} 当前请求次数:{}", key, count, executeResult);
        } catch (Exception e) {
            throw e;
        }
    }

    /**
     * 获取要限流的key
     * 若限流类型为ip则形式为 rate_limit:127.0.0.1-org.lc.rate_limiter.HelloController-hello
     * 若为默认类型则形式为 rate_limit:127.0.0.1-org.lc.rate_limiter.HelloController-hello
     * @param joinPoint
     * @param rateLimiter
     * @return
     */
    private String getCombineKey(JoinPoint joinPoint, RateLimiter rateLimiter) {
        StringBuilder keyBuffer = new StringBuilder(rateLimiter.key());
        // 如果限流类型为ip
        if (rateLimiter.limitType() == LimitType.IP) {
            keyBuffer.append(IpUtils.getIpAddr(((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest()))
                    .append("-");
        }
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        keyBuffer.append(method.getDeclaringClass().getName())
                .append("-")
                .append(method.getName());
        return keyBuffer.toString();
    }
}

接口控制器

@RestController
public class HelloController {
    @RateLimiter(time = 20, count = 3,limitType = LimitType.IP)
    @GetMapping("/hello")
    public String hello() {
        return "hello";
    }
}

测试

http://localhost:8080/helloopen in new window

redis中的key: "rate_limit:127.0.0.1-org.lc.rate_limiter.controller.HelloController-hello" 值为1

20s内请求3次以上时,则会报一下错误:

{
    "message": "接口请求次数超过限制,请稍后再试",
    "status": 500
}