RateLimitInterceptor.java

/*
 * BSD 2-Clause License
 * 
 * Copyright (c) 2022, [Aleksandra Serba, Marcin Czerniak, Bartosz Wawrzyniak, Adrian Antkowiak]
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 * 
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 * 
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package dev.vernite.vernite;

import java.util.ArrayDeque;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.AllArgsConstructor;
import lombok.Getter;

/**
 * Rate limit interceptor using a sliding window of {@link #TIME_LIMIT}
 * milliseconds (one minute).
 * <p>
 * Read requests (GET, HEAD, OPTIONS) are limited to {@link #READ_LIMIT} per
 * window; every other HTTP method is treated as a write and limited to
 * {@link #WRITE_LIMIT} per window. Quotas are tracked separately per client IP
 * and per user ID.
 * <p>
 * Thread-safety: every mutation of a per-key timestamp queue happens inside
 * {@link ConcurrentHashMap#compute} / {@link ConcurrentHashMap#computeIfPresent},
 * which {@code ConcurrentHashMap} performs atomically for a given key. Request
 * threads and the scheduled cleanup therefore never touch the same
 * {@link ArrayDeque} concurrently.
 */
@Component
public class RateLimitInterceptor implements HandlerInterceptor {

    /** Length of the sliding window in milliseconds. */
    private static final long TIME_LIMIT = TimeUnit.MINUTES.toMillis(1);
    /** Maximum number of read requests per key within one window. */
    private static final int READ_LIMIT = 1000;
    /** Maximum number of write requests per key within one window. */
    private static final int WRITE_LIMIT = 100;
    /** HTTP status code "Too Many Requests". */
    private static final int TOO_MANY_REQUESTS = 429;

    // Deliberately typed as ConcurrentHashMap rather than Map: the helpers
    // below mutate the deque in place inside compute(), which is only safe
    // because ConcurrentHashMap applies the remapping function atomically and
    // exactly once per call.
    private static final ConcurrentHashMap<Long, ArrayDeque<Long>> userReadLimit = new ConcurrentHashMap<>();
    private static final ConcurrentHashMap<Long, ArrayDeque<Long>> userWriteLimit = new ConcurrentHashMap<>();
    private static final ConcurrentHashMap<String, ArrayDeque<Long>> ipReadLimit = new ConcurrentHashMap<>();
    private static final ConcurrentHashMap<String, ArrayDeque<Long>> ipWriteLimit = new ConcurrentHashMap<>();

    /**
     * Classifies an HTTP method as read or write for quota purposes.
     *
     * @param method HTTP method name as returned by the request
     * @return {@code false} for GET, HEAD and OPTIONS; {@code true} for every
     *         other method (unknown methods get the stricter write quota)
     */
    private static boolean isWriteMethod(String method) {
        switch (method) {
            case "GET":
            case "HEAD":
            case "OPTIONS":
                return false;
            default:
                return true;
        }
    }

    /**
     * Drops timestamps from the head of the queue that are older than the
     * sliding window. Must only be called while the queue is exclusively owned
     * (i.e. from within a {@code compute} callback).
     *
     * @param window timestamps in insertion order
     * @param now    reference time in milliseconds
     */
    private static void dropExpired(ArrayDeque<Long> window, long now) {
        while (!window.isEmpty() && now - window.getFirst() > TIME_LIMIT) {
            window.removeFirst();
        }
    }

    /**
     * Records one request for {@code key} and reports the remaining quota.
     *
     * @param key     client identifier (IP address or user ID)
     * @param map     per-key queues of request timestamps
     * @param time    timestamp of the current request in milliseconds
     * @param maxSize maximum number of requests allowed per window
     * @return number of requests still allowed in the current window
     * @throws TooManyRequestsException if the quota is already exhausted; the
     *                                  exception carries the number of
     *                                  milliseconds until the oldest counted
     *                                  request leaves the window
     */
    private static <T> int increment(T key, ConcurrentHashMap<T, ArrayDeque<Long>> map, long time, int maxSize)
            throws TooManyRequestsException {
        // Lambdas cannot throw checked exceptions or assign captured locals,
        // so the outcome is passed out through this holder:
        // [0] = remaining quota, or -1 when rejected; [1] = retry-after millis.
        long[] outcome = new long[2];
        map.compute(key, (k, existing) -> {
            ArrayDeque<Long> window = existing != null ? existing : new ArrayDeque<>();
            dropExpired(window, time);
            if (window.size() >= maxSize) {
                outcome[0] = -1;
                outcome[1] = TIME_LIMIT + window.getFirst() - time;
            } else {
                window.addLast(time);
                outcome[0] = maxSize - window.size();
            }
            return window;
        });
        if (outcome[0] < 0) {
            throw new TooManyRequestsException(outcome[1]);
        }
        return (int) outcome[0];
    }

    /**
     * Determines the client address used as the per-IP rate limit key.
     * <p>
     * Uses the first (left-most) entry of {@code X-Forwarded-For} when the
     * header is present and non-blank, otherwise the remote address of the
     * connection.
     * <p>
     * NOTE(review): {@code X-Forwarded-For} is supplied by the client unless a
     * trusted reverse proxy overwrites it. If the application can be reached
     * directly, a caller can evade the per-IP limit by varying this header —
     * confirm the deployment always sits behind such a proxy.
     *
     * @param request current request
     * @return client address to rate-limit on
     */
    private static String getIP(HttpServletRequest request) {
        String forwarded = request.getHeader("X-Forwarded-For");
        if (forwarded != null) {
            // The header may hold a chain "client, proxy1, proxy2".
            int comma = forwarded.indexOf(',');
            String client = (comma >= 0 ? forwarded.substring(0, comma) : forwarded).trim();
            if (!client.isEmpty()) {
                return client;
            }
        }
        return request.getRemoteAddr();
    }

    /**
     * Applies the rate limit to the current request.
     * <p>
     * This interceptor is called before and after UserResolver. The first call
     * (no {@code "ratelimit"} request attribute yet) counts the request against
     * the client IP and stores the remaining IP quota in that attribute. The
     * second call counts it against the {@code "userID"} request attribute, if
     * one is set, and reports the smaller of the two remaining quotas.
     * <p>
     * On success the {@code X-Rate-Limit-Remaining} response header is set. On
     * rejection the response gets status 429 and the
     * {@code X-Rate-Limit-Retry-After-Seconds} header.
     *
     * @return {@code true} to continue processing, {@code false} when the
     *         request was rejected
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
            throws Exception {
        boolean writeMethod = isWriteMethod(request.getMethod());
        long now = System.currentTimeMillis();
        try {
            int remaining;
            Object ipRemaining = request.getAttribute("ratelimit");

            if (ipRemaining != null) {
                // after user resolver:
                Long userID = (Long) request.getAttribute("userID");
                if (userID == null) {
                    // no user attached to the request: only the IP quota applies
                    return true;
                }
                remaining = writeMethod ? increment(userID, userWriteLimit, now, WRITE_LIMIT)
                        : increment(userID, userReadLimit, now, READ_LIMIT);
                remaining = Math.min((Integer) ipRemaining, remaining);
            } else {
                // before user resolver:
                String ip = getIP(request);
                remaining = writeMethod ? increment(ip, ipWriteLimit, now, WRITE_LIMIT)
                        : increment(ip, ipReadLimit, now, READ_LIMIT);
                request.setAttribute("ratelimit", remaining);
            }
            response.setHeader("X-Rate-Limit-Remaining", Integer.toString(remaining));
            return true;
        } catch (TooManyRequestsException e) {
            // round up to whole seconds
            long seconds = (e.retryAfter + 999L) / 1000L;
            response.setHeader("X-Rate-Limit-Retry-After-Seconds", Long.toString(seconds));
            response.sendError(TOO_MANY_REQUESTS, "You have exhausted your API Request Quota");
            return false;
        }
    }

    /**
     * Removes expired timestamps from every queue in {@code map} and drops
     * entries whose queue becomes empty.
     * <p>
     * Each key is processed inside {@code computeIfPresent}, so pruning and
     * removal are atomic with respect to {@link #increment} on the same key.
     *
     * @param map per-key queues of request timestamps
     * @param now reference time in milliseconds
     */
    private static <T> void clearMap(ConcurrentHashMap<T, ArrayDeque<Long>> map, long now) {
        // The key set view of ConcurrentHashMap is weakly consistent, so
        // entries may be removed while iterating.
        for (T key : map.keySet()) {
            map.computeIfPresent(key, (k, window) -> {
                dropExpired(window, now);
                // returning null removes the mapping
                return window.isEmpty() ? null : window;
            });
        }
    }

    /**
     * Scheduled cleanup, triggered at second 0 of every minute, that frees
     * memory held for clients with no requests inside the current window.
     */
    @Scheduled(cron = "0 * * * * *")
    public void cleanDeques() {
        long now = System.currentTimeMillis();
        clearMap(ipWriteLimit, now);
        clearMap(ipReadLimit, now);
        clearMap(userWriteLimit, now);
        clearMap(userReadLimit, now);
    }

    /**
     * Signals that a key has exhausted its quota for the current window.
     */
    @AllArgsConstructor
    private static class TooManyRequestsException extends Exception {
        /** Milliseconds until the oldest counted request leaves the window. */
        @Getter
        private final long retryAfter;
    }
}