Browse Source

update test

master
evanchen333 8 months ago
parent
commit
06eef265dc
  1. 5
      Makefile
  2. 15
      README.md
  3. 12
      example/main.go
  4. 28
      limiter.go
  5. 94
      limter_test.go

5
Makefile

@ -1,2 +1,5 @@
run:
go run example/main.go
go run example/main.go
test:
go test

15
README.md

@ -5,22 +5,15 @@ The counting of the rate is based on the hits within the past **interval**
## Start
Run `make run` and `curl http://localhost:8080/`
Run `go run example/main.go` and `curl http://localhost:8080/hit`.
The
## Limiter Usage
```go
var l *limiter.Limiter
func handler(w http.ResponseWriter, r *http.Request) {
rate, err := l.HitOrError(r)
if err != nil {
fmt.Fprintf(w, "%s", err)
} else {
fmt.Fprintf(w, "%v", rate)
}
}
func main() {
l = limiter.Default()
@ -29,7 +22,7 @@ func main() {
// l.Interval = time.Minute * 10
// l.Limit = 5
http.HandleFunc("/", handler)
http.HandleFunc("/hit", l.Handler)
http.ListenAndServe(":8080", nil)
}

12
example/main.go

@ -1,7 +1,6 @@
package main
import (
"fmt"
"net/http"
limiter "github.com/mutsuki333/rate-limiter"
@ -9,20 +8,11 @@ import (
var l *limiter.Limiter
func handler(w http.ResponseWriter, r *http.Request) {
rate, err := l.HitOrError(r)
if err != nil {
fmt.Fprintln(w, err)
} else {
fmt.Fprintln(w, rate)
}
}
func main() {
l = limiter.Default()
// l.Interval = time.Minute * 10
l.Limit = 5
http.HandleFunc("/hit", handler)
http.HandleFunc("/hit", l.Handler)
http.ListenAndServe(":8080", nil)
}

28
limiter.go

@ -3,6 +3,7 @@ package limiter
import (
"database/sql"
"errors"
"fmt"
"net"
"net/http"
"sync"
@ -14,8 +15,8 @@ import (
//Default setting of the limiter, which limits 60hits/minute and uses in-memory sqlite db
func Default() *Limiter {
var err error
// db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
db, err := sql.Open("sqlite3", "test.db")
db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
// db, err := sql.Open("sqlite3", "test.db")
if err != nil {
panic(err)
}
@ -77,7 +78,17 @@ func (l *Limiter) Hit(ip string) error {
}
//HitOrError hit and return the rate, and error if exceeds limit.
func (l *Limiter) HitOrError(r *http.Request) (rate int, err error) {
func (l *Limiter) HitOrError(ip string) (rate int, err error) {
l.Hit(ip)
rate, _ = l.Rate(ip)
if rate > l.Limit {
err = errors.New("Rate limit exceeded")
}
return
}
//Handler response the rate, and error if exceeds limit.
func (l *Limiter) Handler(w http.ResponseWriter, r *http.Request) {
ip := r.Header.Get("X-Real-Ip")
if ip == "" {
ip = r.Header.Get("X-Forwarded-For")
@ -85,10 +96,11 @@ func (l *Limiter) HitOrError(r *http.Request) (rate int, err error) {
if ip == "" {
ip, _, _ = net.SplitHostPort(r.RemoteAddr)
}
l.Hit(ip)
rate, _ = l.Rate(ip)
if rate > l.Limit {
err = errors.New("Error")
rate, err := l.HitOrError(ip)
if err != nil {
fmt.Fprintln(w, "Error")
} else {
fmt.Fprintln(w, rate)
}
return
}
@ -96,5 +108,5 @@ func (l *Limiter) HitOrError(r *http.Request) (rate int, err error) {
func (l *Limiter) clear() {
l.Mux.Lock()
defer l.Mux.Unlock()
l.Store.Exec("delete from hit where hit_time < ?", time.Now().Add(-l.Interval))
l.Store.Exec("delete from hit where hit_time < ?", time.Now().Add(-l.Interval*2))
}

94
limter_test.go

@ -0,0 +1,94 @@
package limiter
import (
"errors"
"fmt"
"strconv"
"sync"
"testing"
"time"
)
const (
ip = "192.168.0."
)
var testCount = 0
var l = Default()
func examine(limit int, testIP string) error {
for i := 1; i < limit+20; i++ {
rate, err := l.HitOrError(testIP)
if i <= limit {
if rate != i {
return fmt.Errorf("rate should be %v not %v", i, rate)
}
if err != nil {
return fmt.Errorf("hit %v should not be err, but get %s", i, err.Error())
}
} else {
if err == nil {
return errors.New("there should be err, but none get")
}
}
}
return nil
}
func test(limit int, interval time.Duration, testIP string) error {
var err error
err = examine(limit, testIP)
if err != nil {
return err
}
time.Sleep(interval + time.Second)
err = examine(limit, testIP)
if err != nil {
return err
}
return err
}
func Test60in60sec(t *testing.T) {
testCount++
testIP := ip + strconv.Itoa(testCount)
err := test(60, time.Minute, testIP)
if err != nil {
t.Fatal(err)
}
}
func Test5in10sec(t *testing.T) {
testCount++
testIP := ip + strconv.Itoa(testCount)
l.Limit = 5
l.Interval = 10 * time.Second
err := test(5, time.Second*10, testIP)
if err != nil {
t.Fatal(err)
}
}
func TestParallel(t *testing.T) {
l.Limit = 5
l.Interval = 10 * time.Second
var wg sync.WaitGroup
testP := func(testIP string, wg *sync.WaitGroup) {
defer wg.Done()
err := test(5, time.Second*10, testIP)
if err != nil {
t.Fatal(err)
}
}
wg.Add(3)
testCount++
testIP := ip + strconv.Itoa(testCount)
go testP(testIP, &wg)
testCount++
testIP = ip + strconv.Itoa(testCount)
go testP(testIP, &wg)
testCount++
testIP = ip + strconv.Itoa(testCount)
go testP(testIP, &wg)
wg.Wait()
}
Loading…
Cancel
Save