dongxiezhi9564 2013-11-30 08:34

Rate-limiting HTTP requests (via http.HandlerFunc middleware)

I'm looking to write a small piece of rate-limiting middleware that:

  1. Allows me to set a sensible rate (say, 10 req/s) per remote IP
  2. Possibly (though it doesn't have to) allows for bursts
  3. Drops (closes?) connections that exceed the rate and returns an HTTP 429

I can then wrap this around authentication routes or other routes that might be vulnerable to brute-force attacks (e.g. password-reset URLs using a token that expires). The chances of someone brute-forcing a 16- or 24-byte token are really low, but it doesn't hurt to go that extra step.

I've had a look at https://code.google.com/p/go-wiki/wiki/RateLimiting but am not sure how to reconcile it with http.Request(s). Further, I'm not sure how we'd "track" requests from a given IP over any period of time.

Ideally I'd end up with something like this, noting that I'm behind a reverse proxy (nginx), so r.RemoteAddr would be the proxy's address; instead I'm checking for a REMOTE_ADDR HTTP header set by the proxy:

// Rate-limiting middleware
func rateLimit(h http.HandlerFunc) http.HandlerFunc {
    return func(w http.ResponseWriter, r *http.Request) {

        remoteIP := r.Header.Get("REMOTE_ADDR")
        for req := range (what here?) {
            // what here?
            // w.WriteHeader(429) and close the request if it exceeds the limit
            // else pass to the next handler in the chain
            h.ServeHTTP(w, r)
        }
    }
}

// Example routes
r.HandleFunc("/login", use(loginForm, rateLimit, csrf))
r.HandleFunc("/form", use(editHandler, rateLimit, csrf))

// Middleware wrapper, for context
func use(h http.HandlerFunc, middleware ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc {
    for _, m := range middleware {
        h = m(h)
    }

    return h
}

I'd appreciate some guidance here.


  • dshgnt2008 2013-11-30 09:29

    The rate-limiting example you've linked to is a general one: it uses range because it receives requests over a channel.

    It's a different story with HTTP requests, but there's nothing really complicated here. You don't iterate over a channel of requests at all: net/http calls your HandlerFunc once for every incoming request, often concurrently from different goroutines.

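    func rateLimit(h http.HandlerFunc) http.HandlerFunc {
        return func(w http.ResponseWriter, r *http.Request) {
            remoteIP := r.Header.Get("REMOTE_ADDR")
            // exceededTheLimit is a placeholder: it should record this request
            // and report whether remoteIP is over its limit.
            if exceededTheLimit(remoteIP) {
                // Reply 429 Too Many Requests and return without
                // passing the request down the chain.
                http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
                return
            }
            h.ServeHTTP(w, r)
        }
    }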
    

    Now, choosing where to store the rate-limit counters is up to you. One solution would be to simply use a global map from IPs to their request counters (don't forget to guard it for safe concurrent access, e.g. with a sync.Mutex). However, a bare counter is not enough: you also need to know when the requests were made, so that old requests stop counting against the limit.

    Sergio suggested using Redis. Its key-value nature is a perfect fit for simple structures like this, and you get key expiration for free.

    This answer was accepted by the asker.
