-
Notifications
You must be signed in to change notification settings - Fork 16
/
ratelimit_var.go
170 lines (145 loc) · 4.54 KB
/
ratelimit_var.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
package goflags
import (
	"errors"
	"fmt"
	"sort"
	"strconv"
	"strings"
	"time"
	"unicode"

	stringsutil "github.com/projectdiscovery/utils/strings"
	timeutil "github.com/projectdiscovery/utils/time"
)
var (
	// MaxRateLimitTime is the longest duration parseRateLimit accepts for a
	// rate limit; anything longer is rejected with an error. It is a variable,
	// so it can be reassigned.
	MaxRateLimitTime = time.Minute // anything above time.Minute is not practical (for our use case)
	// rateLimitOptionMap associates each registered RateLimitMap with the
	// Options it was registered with (written by RateLimitMapVarP, read by Set).
	// It is allocated in init.
	rateLimitOptionMap map[*RateLimitMap]Options
)
func init() {
rateLimitOptionMap = make(map[*RateLimitMap]Options)
}
// RateLimit is a single parsed rate limit of the form "<count>/<duration>",
// i.e. MaxCount per Duration.
type RateLimit struct {
	// MaxCount is the count part of the rate limit.
	MaxCount uint
	// Duration is the time window the count applies to.
	Duration time.Duration
}
// RateLimitMap is a flag value that collects rate limits by key.
type RateLimitMap struct {
	// kv maps a key to its rate limit; it is nil until Set allocates it.
	kv map[string]RateLimit
}
// Set parses a flag value and stores every rate limit it contains.
//
// The value is first split with ToStringSlice, using the Options registered
// for this map or StringSliceOptions when none were registered. Each resulting
// entry is expected to look like key=count/duration:
//   - entries without a separator, or with nothing before it, are skipped;
//   - a key that is set again replaces the earlier rate limit;
//   - the text after the first separator is handed to parseRateLimit, and its
//     error (for example on an empty value) aborts Set and is returned.
func (rateLimitMap *RateLimitMap) Set(value string) error {
	if rateLimitMap.kv == nil {
		rateLimitMap.kv = make(map[string]RateLimit)
	}

	splitOptions := StringSliceOptions
	if registered, found := rateLimitOptionMap[rateLimitMap]; found {
		splitOptions = registered
	}

	entries, err := ToStringSlice(value, splitOptions)
	if err != nil {
		return err
	}

	for _, entry := range entries {
		sepAt := strings.Index(entry, kvSep)
		if sepAt <= 0 {
			// no separator, or an empty key in front of it: nothing to store under
			continue
		}

		parsed, err := parseRateLimit(entry[sepAt+1:])
		if err != nil {
			return err
		}
		rateLimitMap.kv[entry[:sepAt]] = parsed
	}

	return nil
}
// Del drops key from the map. Deleting a key that is not present is a no-op
// and still returns nil; an error is returned only when the underlying map has
// not been allocated yet.
func (rateLimitMap *RateLimitMap) Del(key string) error {
	if rateLimitMap.kv != nil {
		delete(rateLimitMap.kv, key)
		return nil
	}
	return errors.New("empty runtime map")
}
// IsEmpty reports whether no rate limit is stored, which includes a map that
// has never been written to.
func (rateLimitMap *RateLimitMap) IsEmpty() bool {
	// len of a nil map is 0, so the unallocated case needs no separate check
	return len(rateLimitMap.kv) == 0
}
// AsMap returns the internal map itself, not a copy, so changes made through
// the returned map are visible to the RateLimitMap. The result is nil until
// the internal map has been allocated (Set does this on first use).
func (rateLimitMap *RateLimitMap) AsMap() map[string]RateLimit {
	return rateLimitMap.kv
}
// String renders the map as a JSON-like object of "key":"count/duration"
// pairs, e.g. {"github":"10/1s","shodan":"5/1m0s"}. An empty or unallocated
// map renders as "{}".
//
// Entries are written in ascending key order: Go randomizes map iteration, so
// without sorting the same map could print differently on every call.
// Keys are written as-is, without escaping.
func (rateLimitMap RateLimitMap) String() string {
	keys := make([]string, 0, len(rateLimitMap.kv))
	for k := range rateLimitMap.kv {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var items strings.Builder
	for _, k := range keys {
		v := rateLimitMap.kv[k]
		fmt.Fprintf(&items, "\"%s\":\"%d/%s\",", k, v.MaxCount, v.Duration.String())
	}

	// every pair is written with a trailing comma; TrimSuffixAny is used here
	// to drop the one left after the last pair
	return "{" + stringsutil.TrimSuffixAny(items.String(), ",", ":") + "}"
}
// RateLimitMapVar adds a ratelimit flag with a longname.
// It delegates to RateLimitMapVarP with an empty short name, so the same
// handling of field, defaultValue and options applies.
func (flagSet *FlagSet) RateLimitMapVar(field *RateLimitMap, long string, defaultValue []string, usage string, options Options) *FlagData {
	return flagSet.RateLimitMapVarP(field, long, "", defaultValue, usage, options)
}
// RateLimitMapVarP adds a ratelimit flag with a short name and long name.
// It is equivalent to RateLimitMapVar, with an additional short name.
//
// options is recorded for field so that later Set calls on it split their
// input the same way. Every entry of defaultValue is split with options and
// applied to field through Set; values take the form key=count/duration
// (e.g., "hackertarget=10/s"). Durations above MaxRateLimitTime are rejected
// when parsing, so a per-day limit such as "hackertarget=2/d" is only
// accepted if MaxRateLimitTime has been raised accordingly.
//
// The flag is registered on flagSet.CommandLine under long and, when short is
// not empty, under short as well. The returned FlagData is shared by both.
//
// It panics when field is nil or when a default value cannot be split or
// parsed.
func (flagSet *FlagSet) RateLimitMapVarP(field *RateLimitMap, long, short string, defaultValue StringSlice, usage string, options Options) *FlagData {
	if field == nil {
		panic(fmt.Errorf("field cannot be nil for flag -%v", long))
	}

	rateLimitOptionMap[field] = options
	for _, defaultItem := range defaultValue {
		values, err := ToStringSlice(defaultItem, options)
		if err != nil {
			// a default that cannot be split is as broken as one that cannot
			// be parsed; fail loudly instead of silently dropping it
			panic(fmt.Errorf("failed to set default value for flag -%v: %v", long, err))
		}
		for _, value := range values {
			if err := field.Set(value); err != nil {
				panic(fmt.Errorf("failed to set default value for flag -%v: %v", long, err))
			}
		}
	}

	flagData := &FlagData{
		usage:        usage,
		long:         long,
		defaultValue: defaultValue,
		skipMarshal:  true,
	}
	if short != "" {
		flagData.short = short
		flagSet.CommandLine.Var(field, short, usage)
		flagSet.flagKeys.Set(short, flagData)
	}
	flagSet.CommandLine.Var(field, long, usage)
	flagSet.flagKeys.Set(long, flagData)
	return flagData
}
// parseRateLimit converts the value part of a rate limit flag,
// "<count>/<duration>", into a RateLimit.
//
// The count must be a base-10 unsigned integer that fits in uint. The duration
// is parsed with timeutil.ParseDuration; when it does not start with a digit a
// leading "1" is assumed, so "10/s" is read as "10/1s".
//
// An error is returned when s does not consist of exactly two "/"-separated
// parts, when either part fails to parse (the underlying error is wrapped), or
// when the duration is longer than MaxRateLimitTime.
func parseRateLimit(s string) (RateLimit, error) {
	parts := strings.Split(s, "/")
	if len(parts) != 2 {
		// extra segments are rejected rather than silently ignored
		return RateLimit{}, fmt.Errorf("parse error: expected format k=v/d (e.g., scanme.sh=10/s) got %q", s)
	}

	// parse with the platform's uint width so the conversion below cannot truncate
	count, err := strconv.ParseUint(parts[0], 10, strconv.IntSize)
	if err != nil {
		return RateLimit{}, fmt.Errorf("parse error: %w", err)
	}

	window := parts[1]
	// a bare unit such as "s" means one of that unit
	if window != "" && !unicode.IsDigit(rune(window[0])) {
		window = "1" + window
	}

	duration, err := timeutil.ParseDuration(window)
	if err != nil {
		return RateLimit{}, fmt.Errorf("parse error: %w", err)
	}
	if MaxRateLimitTime < duration {
		return RateLimit{}, fmt.Errorf("duration cannot be more than %v but got %v", MaxRateLimitTime, duration)
	}

	return RateLimit{MaxCount: uint(count), Duration: duration}, nil
}