Browse Source

update

mysql
Evan 6 months ago
parent
commit
e9da2154ee
  1. 2
      .gitignore
  2. 34
      README.md
  3. 13
      example/main.go
  4. 81
      limiter.go

2
.gitignore

@ -14,4 +14,4 @@
# Dependency directories (remove the comment below to include it)
# vendor/
*.db

34
README.md

@ -1,2 +1,36 @@
# rate-limiter
A go lib for limiting hit rate of a given ip within a given interval.
The counting of the rate is based on the hits within the past **interval**
## Start
Run `make run` and `curl http://localhost:8080/`
## Limiter Usage
```go
var l *limiter.Limiter
func handler(w http.ResponseWriter, r *http.Request) {
rate, err := l.HitOrError(r)
if err != nil {
fmt.Fprintf(w, "%s", err)
} else {
fmt.Fprintf(w, "%v", rate)
}
}
func main() {
l = limiter.Default()
// Limiter settings can be changed
// for example:
// l.Interval = time.Minute * 10
// l.Limit = 5
http.HandleFunc("/", handler)
http.ListenAndServe(":8080", nil)
}
```

13
example/main.go

@ -7,12 +7,21 @@ import (
limiter "github.com/mutsuki333/rate-limiter"
)
var l *limiter.Limiter
func handler(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Hi there, I love %s!", r.URL.Path[1:])
rate, err := l.HitOrError(r)
if err != nil {
fmt.Fprintln(w, err)
} else {
fmt.Fprintln(w, rate)
}
}
func main() {
limiter.Default()
l = limiter.Default()
// l.Interval = time.Minute * 10
l.Limit = 5
http.HandleFunc("/hit", handler)
http.ListenAndServe(":8080", nil)

81
limiter.go

@ -2,27 +2,31 @@ package limiter
import (
"database/sql"
"errors"
"net"
"net/http"
"sync"
"time"
_ "github.com/mattn/go-sqlite3"
)
var ddb *sql.DB
//Default setting of the limiter, which limits 60hits/minute and uses in-memory sqlite db
func Default() *Limiter {
var err error
// ddb, err = sql.Open("sqlite3", "file::memory:?cache=shared")
ddb, err = sql.Open("sqlite3", "test.db")
// db, err := sql.Open("sqlite3", "file::memory:?cache=shared")
db, err := sql.Open("sqlite3", "test.db")
if err != nil {
panic(err)
}
ddb.Exec(`create table hit (id integer, dt datetime);`)
return &Limiter{
l := &Limiter{
Interval: time.Minute,
Limit: 60,
GetRate: getRate,
Hit: hit,
Mux: &sync.Mutex{},
Store: db,
}
l.Init()
return l
}
//Limiter the limiter instance
@ -34,26 +38,63 @@ type Limiter struct {
//Limit of counts within given interval
Limit int
GetRate func(time.Duration, string) (int, error)
Hit func(string) error
Mux *sync.Mutex
Store *sql.DB
}
//Init the db table, and start cleanup goroutine
func (l *Limiter) Init() {
l.Store.Exec(`create table hit (ip string, hit_time datetime);`)
go func() {
for {
time.Sleep(l.Interval)
l.clear()
}
}()
}
//Rate within interval
func (l *Limiter) Rate() int {
var rate int
return rate
func (l *Limiter) Rate(ip string) (rate int, err error) {
row := l.Store.QueryRow(
"select COUNT(*) from hit where ip = ? and hit_time > ?",
ip,
time.Now().Add(-l.Interval),
)
err = row.Scan(&rate)
return
}
//Exceeded the rate limit
func (l *Limiter) Exceeded() bool {
var exceed bool
return exceed
//Hit record the hit from an ip
func (l *Limiter) Hit(ip string) error {
l.Mux.Lock()
defer l.Mux.Unlock()
_, err := l.Store.Exec(
"insert into hit (ip, hit_time) values (?, ?)",
ip,
time.Now(),
)
return err
}
func getRate(i time.Duration, ip string) (rate int, err error) {
//HitOrError hit and return the rate, and error if exceeds limit.
func (l *Limiter) HitOrError(r *http.Request) (rate int, err error) {
ip := r.Header.Get("X-Real-Ip")
if ip == "" {
ip = r.Header.Get("X-Forwarded-For")
}
if ip == "" {
ip, _, _ = net.SplitHostPort(r.RemoteAddr)
}
l.Hit(ip)
rate, _ = l.Rate(ip)
if rate > l.Limit {
err = errors.New("Error")
}
return
}
func hit(ip string) error {
return nil
func (l *Limiter) clear() {
l.Mux.Lock()
defer l.Mux.Unlock()
l.Store.Exec("delete from hit where hit_time < ?", time.Now().Add(-l.Interval))
}
Loading…
Cancel
Save