From 94d9e009d87978eb4c58b60d6a176867bf087553 Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Wed, 8 Jan 2020 19:53:02 +0100 Subject: [PATCH 001/446] Tidy up unused imports from go.mod (#1468) --- go.mod | 1 - go.sum | 14 +------------- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/go.mod b/go.mod index c5db2ae1a..eacaf4bee 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.12 require ( github.com/dgrijalva/jwt-go v3.2.0+incompatible - github.com/labstack/echo v3.3.10+incompatible // indirect github.com/labstack/gommon v0.3.0 github.com/mattn/go-colorable v0.1.4 // indirect github.com/mattn/go-isatty v0.0.11 // indirect diff --git a/go.sum b/go.sum index 57c79877e..e329def22 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,6 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= -github.com/labstack/echo v3.3.10+incompatible h1:pGRcYk231ExFAyoAjAfD85kQzRJCRI8bbnE7CX5OEgg= -github.com/labstack/echo v3.3.10+incompatible/go.mod h1:0INS7j/VjnFxD4E2wkz67b8cVwCLbBmJyDaka6Cmk1s= github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0= github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= @@ -13,8 +11,6 @@ github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVc github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.9 h1:d5US/mDsogSGW37IV293h//ZFaeajb69h+EHFsv2xGg= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= -github.com/mattn/go-isatty v0.0.10 h1:qxFzApOv4WsAL965uUPIsXzAKCZxN2p9UqdhFS4ZW10= -github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM= github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -29,16 +25,10 @@ github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPU github.com/valyala/fasttemplate v1.1.0 h1:RZqt0yGBsps8NGvLSGW804QQqCUYYLsaOjTVHy1Ocw4= github.com/valyala/fasttemplate v1.1.0/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 h1:HuIa8hRrWRSrqYzx1qI49NNxhdi2PrY7gxVSq1JjLDc= -golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 h1:sKJQZMuxjOAR/Uo2LBfU90onWEf1dF4C+0hPJCc9Mpc= golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20191021144547-ec77196f6094 h1:5O4U9trLjNpuhpynaDsqwCk+Tw6seqJz1EbqbnzHrc8= -golang.org/x/net v0.0.0-20191021144547-ec77196f6094/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 h1:efeOvDhwQ29Dj3SdAV/MJf8oukgn+8D8WgaCaRMchF8= golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -46,9 +36,6 @@ golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a h1:aYOabOQFp6Vj6W1F80affTUvO9UxmJRx8K0gsfABByQ= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191024172528-b4ff53e7a1cb h1:ZxSglHghKPYD8WDeRUzRJrUJtDF0PxsTUSxyqr9/5BI= -golang.org/x/sys v0.0.0-20191024172528-b4ff53e7a1cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8 h1:JA8d3MPx/IToSyXZG/RhwYEtfrKO1Fxrqe8KrkiLXKM= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -57,6 +44,7 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= From 399da56370c28c1788b9c61e0e55ab82be59fac1 Mon Sep 17 00:00:00 2001 From: Eugene Date: Wed, 8 Jan 2020 23:40:52 +0200 Subject: [PATCH 002/446] Improve bind performance (#1469) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Improve bind performance By some slight optimisations and lesser reflect usage now binding has significantly better performance: name old time/op new time/op delta BindbindData-8 21.2µs ± 2% 13.5µs ± 2% -36.66% (p=0.000 n=16+18) BindbindDataWithTags-8 22.1µs ± 1% 16.4µs ± 2% -26.03% (p=0.000 n=20+20) name old alloc/op new alloc/op delta BindbindData-8 2.40kB ± 0% 1.33kB ± 0% -44.64% (p=0.000 n=20+20) BindbindDataWithTags-8 2.31kB ± 0% 1.54kB ± 0% -33.19% (p=0.000 n=20+20) name old allocs/op new allocs/op delta BindbindData-8 297 ± 0% 122 ± 0% -58.92% (p=0.000 n=20+20) BindbindDataWithTags-8 267 ± 0% 125 ± 0% -53.18% (p=0.000 n=20+20) * Remove creation of new value in unmarshalFieldNonPtr --- bind.go | 42 ++++++--------------------------- bind_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 35 deletions(-) diff --git a/bind.go b/bind.go index c8c88bb20..f89147435 100644 --- a/bind.go +++ b/bind.go @@ -115,7 +115,7 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag if inputFieldName == "" { inputFieldName = typeField.Name // If tag is nil, we inspect if the field is a struct. - if _, ok := bindUnmarshaler(structField); !ok && structFieldKind == reflect.Struct { + if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil { return err } @@ -129,9 +129,8 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag // url params are bound case sensitive which is inconsistent. To // fix this we must check all of the map values in a // case-insensitive search. - inputFieldName = strings.ToLower(inputFieldName) for k, v := range data { - if strings.ToLower(k) == inputFieldName { + if strings.EqualFold(k, inputFieldName) { inputValue = v exists = true break @@ -221,40 +220,13 @@ func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bo } } -// bindUnmarshaler attempts to unmarshal a reflect.Value into a BindUnmarshaler -func bindUnmarshaler(field reflect.Value) (BindUnmarshaler, bool) { - ptr := reflect.New(field.Type()) - if ptr.CanInterface() { - iface := ptr.Interface() - if unmarshaler, ok := iface.(BindUnmarshaler); ok { - return unmarshaler, ok - } - } - return nil, false -} - -// textUnmarshaler attempts to unmarshal a reflect.Value into a TextUnmarshaler -func textUnmarshaler(field reflect.Value) (encoding.TextUnmarshaler, bool) { - ptr := reflect.New(field.Type()) - if ptr.CanInterface() { - iface := ptr.Interface() - if unmarshaler, ok := iface.(encoding.TextUnmarshaler); ok { - return unmarshaler, ok - } - } - return nil, false -} - func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) { - if unmarshaler, ok := bindUnmarshaler(field); ok { - err := unmarshaler.UnmarshalParam(value) - field.Set(reflect.ValueOf(unmarshaler).Elem()) - return true, err + fieldIValue := field.Addr().Interface() + if unmarshaler, ok := fieldIValue.(BindUnmarshaler); ok { + return true, unmarshaler.UnmarshalParam(value) } - if unmarshaler, ok := textUnmarshaler(field); ok { - err := unmarshaler.UnmarshalText([]byte(value)) - field.Set(reflect.ValueOf(unmarshaler).Elem()) - return true, err + if unmarshaler, ok := fieldIValue.(encoding.TextUnmarshaler); ok { + return true, unmarshaler.UnmarshalText([]byte(value)) } return false, nil diff --git a/bind_test.go b/bind_test.go index 84ac8710e..b9fb9de3c 100644 --- a/bind_test.go +++ b/bind_test.go @@ -56,6 +56,43 @@ type ( Tptr *Timestamp SA StringArray } + bindTestStructWithTags struct { + I int `json:"I" form:"I"` + PtrI *int `json:"PtrI" form:"PtrI"` + I8 int8 `json:"I8" form:"I8"` + PtrI8 *int8 `json:"PtrI8" form:"PtrI8"` + I16 int16 `json:"I16" form:"I16"` + PtrI16 *int16 `json:"PtrI16" form:"PtrI16"` + I32 int32 `json:"I32" form:"I32"` + PtrI32 *int32 `json:"PtrI32" form:"PtrI32"` + I64 int64 `json:"I64" form:"I64"` + PtrI64 *int64 `json:"PtrI64" form:"PtrI64"` + UI uint `json:"UI" form:"UI"` + PtrUI *uint `json:"PtrUI" form:"PtrUI"` + UI8 uint8 `json:"UI8" form:"UI8"` + PtrUI8 *uint8 `json:"PtrUI8" form:"PtrUI8"` + UI16 uint16 `json:"UI16" form:"UI16"` + PtrUI16 *uint16 `json:"PtrUI16" form:"PtrUI16"` + UI32 uint32 `json:"UI32" form:"UI32"` + PtrUI32 *uint32 `json:"PtrUI32" form:"PtrUI32"` + UI64 uint64 `json:"UI64" form:"UI64"` + PtrUI64 *uint64 `json:"PtrUI64" form:"PtrUI64"` + B bool `json:"B" form:"B"` + PtrB *bool `json:"PtrB" form:"PtrB"` + F32 float32 `json:"F32" form:"F32"` + PtrF32 *float32 `json:"PtrF32" form:"PtrF32"` + F64 float64 `json:"F64" form:"F64"` + PtrF64 *float64 `json:"PtrF64" form:"PtrF64"` + S string `json:"S" form:"S"` + PtrS *string `json:"PtrS" form:"PtrS"` + cantSet string + DoesntExist string `json:"DoesntExist" form:"DoesntExist"` + GoT time.Time `json:"GoT" form:"GoT"` + GoTptr *time.Time `json:"GoTptr" form:"GoTptr"` + T Timestamp `json:"T" form:"T"` + Tptr *Timestamp `json:"Tptr" form:"Tptr"` + SA StringArray `json:"SA" form:"SA"` + } Timestamp time.Time TA []Timestamp StringArray []string @@ -433,6 +470,34 @@ func TestBindSetFields(t *testing.T) { } } +func BenchmarkBindbindData(b *testing.B) { + b.ReportAllocs() + assert := assert.New(b) + ts := new(bindTestStruct) + binder := new(DefaultBinder) + var err error + b.ResetTimer() + for i := 0; i < b.N; i++ { + err = binder.bindData(ts, values, "form") + } + assert.NoError(err) + assertBindTestStruct(assert, ts) +} + +func BenchmarkBindbindDataWithTags(b *testing.B) { + b.ReportAllocs() + assert := assert.New(b) + ts := new(bindTestStructWithTags) + binder := new(DefaultBinder) + var err error + b.ResetTimer() + for i := 0; i < b.N; i++ { + err = binder.bindData(ts, values, "form") + } + assert.NoError(err) + assertBindTestStruct(assert, (*bindTestStruct)(ts)) +} + func assertBindTestStruct(a *assert.Assertions, ts *bindTestStruct) { a.Equal(0, ts.I) a.Equal(int8(8), ts.I8) From 5bf6888444a4491eb1a46506044bbae0fb0abf39 Mon Sep 17 00:00:00 2001 From: Ajitem Sahasrabuddhe Date: Fri, 24 Jan 2020 05:43:18 +0530 Subject: [PATCH 003/446] Parameterized routes sometimes return 404 (#1480) * url param bug * add comment * add tests * Bump echo version --- echo.go | 2 +- router.go | 4 ++++ router_test.go | 30 ++++++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/echo.go b/echo.go index a6ac0fa80..3869dc6c5 100644 --- a/echo.go +++ b/echo.go @@ -227,7 +227,7 @@ const ( const ( // Version of Echo - Version = "4.1.13" + Version = "4.1.14" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` diff --git a/router.go b/router.go index 08145973a..9db9ea928 100644 --- a/router.go +++ b/router.go @@ -414,6 +414,10 @@ func (r *Router) Find(method, path string, c Context) { if cn = nn.findChildByKind(pkind); cn != nil && strings.IndexByte(ns, '/') == -1 { pvalues[len(cn.pnames)-1] = search break + } else if cn != nil && strings.IndexByte(ns, '/') != 1 { + // If slash is present, it means that this is a parameterized route. + cn = cn.parent + goto Param } for { np = nn.parent diff --git a/router_test.go b/router_test.go index 449af910f..f47225d28 100644 --- a/router_test.go +++ b/router_test.go @@ -1141,6 +1141,16 @@ func TestRouterParam1466(t *testing.T) { r.Add(http.MethodGet, "/skills/:name/users", func(c Context) error { return nil }) + // Additional routes for Issue 1479 + r.Add(http.MethodGet, "/users/:username/likes/projects/ids", func(c Context) error { + return nil + }) + r.Add(http.MethodGet, "/users/:username/profile", func(c Context) error { + return nil + }) + r.Add(http.MethodGet, "/users/:username/uploads/:type", func(c Context) error { + return nil + }) c := e.NewContext(nil, nil).(*context) @@ -1152,6 +1162,26 @@ func TestRouterParam1466(t *testing.T) { r.Find(http.MethodGet, "/users/signup", c) assert.Equal(t, "", c.Param("username")) + // Additional assertions for #1479 + r.Find(http.MethodGet, "/users/sharewithme/likes/projects/ids", c) + assert.Equal(t, "sharewithme", c.Param("username")) + + r.Find(http.MethodGet, "/users/ajitem/likes/projects/ids", c) + assert.Equal(t, "ajitem", c.Param("username")) + + r.Find(http.MethodGet, "/users/sharewithme/profile", c) + assert.Equal(t, "sharewithme", c.Param("username")) + + r.Find(http.MethodGet, "/users/ajitem/profile", c) + assert.Equal(t, "ajitem", c.Param("username")) + + r.Find(http.MethodGet, "/users/sharewithme/uploads/self", c) + assert.Equal(t, "sharewithme", c.Param("username")) + assert.Equal(t, "self", c.Param("type")) + + r.Find(http.MethodGet, "/users/ajitem/uploads/self", c) + assert.Equal(t, "ajitem", c.Param("username")) + assert.Equal(t, "self", c.Param("type")) } func benchmarkRouterRoutes(b *testing.B, routes []*Route) { From 8d7f05e5336fa9a05c6ca5a610a0c5140c01bfc3 Mon Sep 17 00:00:00 2001 From: "J. David Lowe" Date: Thu, 23 Jan 2020 18:32:17 -0800 Subject: [PATCH 004/446] round-trip paramValues without exploding (#1463) --- bind_test.go | 2 ++ context.go | 5 ++++- context_test.go | 22 ++++++++++++++++++++++ middleware/jwt_test.go | 2 ++ 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/bind_test.go b/bind_test.go index b9fb9de3c..943cfd559 100644 --- a/bind_test.go +++ b/bind_test.go @@ -332,6 +332,7 @@ func TestBindbindData(t *testing.T) { func TestBindParam(t *testing.T) { e := New() + *e.maxParam = 2 req := httptest.NewRequest(GET, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) @@ -362,6 +363,7 @@ func TestBindParam(t *testing.T) { // Bind something with param and post data payload body := bytes.NewBufferString(`{ "name": "Jon Snow" }`) e2 := New() + *e2.maxParam = 2 req2 := httptest.NewRequest(POST, "/", body) req2.Header.Set(HeaderContentType, MIMEApplicationJSON) diff --git a/context.go b/context.go index 27da5ffe3..0046e5d4f 100644 --- a/context.go +++ b/context.go @@ -312,7 +312,10 @@ func (c *context) ParamValues() []string { } func (c *context) SetParamValues(values ...string) { - c.pvalues = values + // NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times + for i, val := range values { + c.pvalues[i] = val + } } func (c *context) QueryParam(name string) string { diff --git a/context_test.go b/context_test.go index 47be19cce..bb4a9ed4e 100644 --- a/context_test.go +++ b/context_test.go @@ -93,6 +93,7 @@ func (responseWriterErr) WriteHeader(statusCode int) { func TestContext(t *testing.T) { e := New() + *e.maxParam = 1 req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) @@ -471,6 +472,7 @@ func TestContextPath(t *testing.T) { func TestContextPathParam(t *testing.T) { e := New() + *e.maxParam = 2 req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, nil) @@ -487,6 +489,26 @@ func TestContextPathParam(t *testing.T) { testify.Equal(t, "", c.Param("undefined")) } +func TestContextGetAndSetParam(t *testing.T) { + e := New() + *e.maxParam = 2 + req := httptest.NewRequest(http.MethodGet, "/:foo", nil) + c := e.NewContext(req, nil) + c.SetParamNames("foo") + + // round-trip param values with modification + paramVals := c.ParamValues() + testify.EqualValues(t, []string{""}, c.ParamValues()) + paramVals[0] = "bar" + c.SetParamValues(paramVals...) + testify.EqualValues(t, []string{"bar"}, c.ParamValues()) + + // shouldn't explode during Reset() afterwards! + testify.NotPanics(t, func() { + c.Reset(nil, nil) + }) +} + func TestContextFormValue(t *testing.T) { f := make(url.Values) f.Set("name", "Jon Snow") diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 7f15bd467..f7f089fb4 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -60,6 +60,8 @@ func TestJWTRace(t *testing.T) { func TestJWT(t *testing.T) { e := echo.New() + r := e.Router() + r.Add("GET", "/:jwt", func(echo.Context) error { return nil }) handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } From 712b5e46c539096fd204535e61c568e455f4c604 Mon Sep 17 00:00:00 2001 From: Dmitry Kutakov Date: Sat, 25 Jan 2020 18:48:53 +0100 Subject: [PATCH 005/446] format code (gofmt + trim trailing space) (#1452) --- context_test.go | 2 +- middleware/cors_test.go | 2 +- middleware/jwt.go | 8 ++++---- middleware/jwt_test.go | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/context_test.go b/context_test.go index bb4a9ed4e..866d06431 100644 --- a/context_test.go +++ b/context_test.go @@ -7,7 +7,6 @@ import ( "encoding/xml" "errors" "fmt" - "github.com/labstack/gommon/log" "io" "math" "mime/multipart" @@ -19,6 +18,7 @@ import ( "text/template" "time" + "github.com/labstack/gommon/log" testify "github.com/stretchr/testify/assert" ) diff --git a/middleware/cors_test.go b/middleware/cors_test.go index acfdf47bc..456ec7b3d 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -73,7 +73,7 @@ func TestCORS(t *testing.T) { c = e.NewContext(req, rec) req.Header.Set(echo.HeaderOrigin, "http://aaa.example.com") cors = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"http://*.example.com"}, + AllowOrigins: []string{"http://*.example.com"}, }) h = cors(echo.NotFoundHandler) h(c) diff --git a/middleware/jwt.go b/middleware/jwt.go index 55a986327..3c7c48681 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -25,7 +25,7 @@ type ( // ErrorHandler defines a function which is executed for an invalid token. // It may be used to define a custom JWT error. ErrorHandler JWTErrorHandler - + // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. ErrorHandlerWithContext JWTErrorHandlerWithContext @@ -74,7 +74,7 @@ type ( // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. JWTErrorHandlerWithContext func(error, echo.Context) error - + jwtExtractor func(echo.Context) (string, error) ) @@ -183,7 +183,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.ErrorHandler != nil { return config.ErrorHandler(err) } - + if config.ErrorHandlerWithContext != nil { return config.ErrorHandlerWithContext(err, c) } @@ -210,7 +210,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { return config.ErrorHandler(err) } if config.ErrorHandlerWithContext != nil { - return config.ErrorHandlerWithContext(err, c) + return config.ErrorHandlerWithContext(err, c) } return &echo.HTTPError{ Code: http.StatusUnauthorized, diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index f7f089fb4..1731d90fa 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -205,7 +205,7 @@ func TestJWT(t *testing.T) { req.Header.Set(echo.HeaderCookie, tc.hdrCookie) c := e.NewContext(req, res) - if tc.reqURL == "/" + token { + if tc.reqURL == "/"+token { c.SetParamNames("jwt") c.SetParamValues(token) } From c2f2e8d25849098056676e574c4fca316778c963 Mon Sep 17 00:00:00 2001 From: ochan Date: Wed, 29 Jan 2020 07:46:00 +0900 Subject: [PATCH 006/446] Support HTTP/2 h2c mode (cleartext) (#1489) --- echo.go | 30 ++++++++++++++++++++++++++++++ echo_test.go | 12 ++++++++++++ go.mod | 2 +- go.sum | 4 ++-- 4 files changed, 45 insertions(+), 3 deletions(-) diff --git a/echo.go b/echo.go index 3869dc6c5..ed658884f 100644 --- a/echo.go +++ b/echo.go @@ -59,6 +59,8 @@ import ( "github.com/labstack/gommon/log" "golang.org/x/crypto/acme" "golang.org/x/crypto/acme/autocert" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) type ( @@ -723,6 +725,34 @@ func (e *Echo) StartServer(s *http.Server) (err error) { return s.Serve(e.TLSListener) } +// StartH2CServer starts a custom http/2 server with h2c (HTTP/2 Cleartext). +func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { + // Setup + s := e.Server + s.Addr = address + e.colorer.SetOutput(e.Logger.Output()) + s.ErrorLog = e.StdLogger + s.Handler = h2c.NewHandler(e, h2s) + if e.Debug { + e.Logger.SetLevel(log.DEBUG) + } + + if !e.HideBanner { + e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website)) + } + + if e.Listener == nil { + e.Listener, err = newListener(s.Addr) + if err != nil { + return err + } + } + if !e.HidePort { + e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) + } + return s.Serve(e.Listener) +} + // Close immediately stops the server. // It internally calls `http.Server#Close()`. func (e *Echo) Close() error { diff --git a/echo_test.go b/echo_test.go index 3f2e48e51..68c556f41 100644 --- a/echo_test.go +++ b/echo_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/net/http2" ) type ( @@ -512,6 +513,17 @@ func TestEchoStartAutoTLS(t *testing.T) { } } +func TestEchoStartH2CServer(t *testing.T) { + e := New() + e.Debug = true + h2s := &http2.Server{} + + go func() { + assert.NoError(t, e.StartH2CServer(":0", h2s)) + }() + time.Sleep(200 * time.Millisecond) +} + func testMethod(t *testing.T, method, path string, e *Echo) { p := reflect.ValueOf(path) h := reflect.ValueOf(func(c Context) error { diff --git a/go.mod b/go.mod index eacaf4bee..17e57fea0 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/stretchr/testify v1.4.0 github.com/valyala/fasttemplate v1.1.0 golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 - golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 // indirect + golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8 // indirect golang.org/x/text v0.3.2 // indirect ) diff --git a/go.sum b/go.sum index e329def22..08ef80ee2 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,8 @@ golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 h1:sKJQZMuxjOAR/Uo2LBfU90 golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 h1:efeOvDhwQ29Dj3SdAV/MJf8oukgn+8D8WgaCaRMchF8= -golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= From 7c58856fb433748d9f850cc8d2f6bfe996a7070c Mon Sep 17 00:00:00 2001 From: sai umesh Date: Wed, 29 Jan 2020 08:53:29 +0530 Subject: [PATCH 007/446] added installation command in guide (#1443) * added installation command in guide * fixed lints --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 0da031225..c57d478fb 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,13 @@ Lower is better! ## [Guide](https://echo.labstack.com/guide) +### Installation + +```go +// go get github.com/labstack/echo/{version} +go get github.com/labstack/echo/v4 +``` + ### Example ```go From 75620e676752baf3b159768535312cf36db22bdd Mon Sep 17 00:00:00 2001 From: Ajitem Sahasrabuddhe Date: Wed, 29 Jan 2020 08:54:22 +0530 Subject: [PATCH 008/446] Migrate to GitHub Actions (#1473) * add workflow yml * fix syntax error * update test command --- .github/workflows/echo.yml | 53 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 .github/workflows/echo.yml diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml new file mode 100644 index 000000000..cfa44e683 --- /dev/null +++ b/.github/workflows/echo.yml @@ -0,0 +1,53 @@ +name: Run Tests + +on: + push: + branches: + - master + pull_request: + branches: + - master + +env: + GO111MODULE: on + GOPROXY: https://proxy.golang.org + +jobs: + test: + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + go: [1.11, 1.12, 1.13] + name: ${{ matrix.os }} @ Go ${{ matrix.go }} + runs-on: ${{ matrix.os }} + steps: + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v1 + with: + go-version: ${{ matrix.go }} + + - name: Set GOPATH and PATH + run: | + echo "::set-env name=GOPATH::$(dirname $GITHUB_WORKSPACE)" + echo "::add-path::$(dirname $GITHUB_WORKSPACE)/bin" + shell: bash + + - name: Checkout Code + uses: actions/checkout@v1 + with: + ref: ${{ github.ref }} + + - name: Install Dependencies + run: go get -v golang.org/x/lint/golint + + - name: Run Tests + run: | + golint -set_exit_status ./... + go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... + + - name: Upload coverage to Codecov + if: success() && matrix.go == 1.13 && matrix.os == 'ubuntu-latest' + uses: codecov/codecov-action@v1 + with: + token: + fail_ci_if_error: false From 504f39abaf3230cd8bbc1b090effb31a9c70b130 Mon Sep 17 00:00:00 2001 From: Jur van den Berg Date: Mon, 3 Feb 2020 19:09:27 +0100 Subject: [PATCH 009/446] Fix crash on OpenBSD due to unsupported TCP KeepAlivePeriod (#1456) --- echo.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/echo.go b/echo.go index ed658884f..a759e34b6 100644 --- a/echo.go +++ b/echo.go @@ -856,9 +856,10 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { return } else if err = c.(*net.TCPConn).SetKeepAlive(true); err != nil { return - } else if err = c.(*net.TCPConn).SetKeepAlivePeriod(3 * time.Minute); err != nil { - return } + // Ignore error from setting the KeepAlivePeriod as some systems, such as + // OpenBSD, do not support setting TCP_USER_TIMEOUT on IPPROTO_TCP + _ = c.(*net.TCPConn).SetKeepAlivePeriod(3 * time.Minute) return } From f4b5a90ad3d3d1817d4e87d349d699d5ad272110 Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Wed, 19 Feb 2020 16:10:57 +0100 Subject: [PATCH 010/446] Fix #1493 router loop for param routes (#1501) * Add test to reproduce router loop for #1493 * Simplify and correct router param tests * Fix #1493 to avoid router loop for param nodes --- router.go | 77 ++++++++++++++++++++++++++++---------------------- router_test.go | 17 +++++++++++ 2 files changed, 61 insertions(+), 33 deletions(-) diff --git a/router.go b/router.go index 9db9ea928..cb2cc16c8 100644 --- a/router.go +++ b/router.go @@ -376,8 +376,8 @@ func (r *Router) Find(method, path string, c Context) { continue } - // Param node Param: + // Param node if child = cn.findChildByKind(pkind); child != nil { // Issue #378 if len(pvalues) == n { @@ -401,47 +401,58 @@ func (r *Router) Find(method, path string, c Context) { continue } - // Any node Any: - if cn = cn.findChildByKind(akind); cn == nil { - if nn != nil { - // No next node to go down in routing (issue #954) - // Find nearest "any" route going up the routing tree - search = ns - np := nn.parent - // Consider param route one level up only - // if no slash is remaining in search string - if cn = nn.findChildByKind(pkind); cn != nil && strings.IndexByte(ns, '/') == -1 { + // Any node + if cn = cn.findChildByKind(akind); cn != nil { + // If any node is found, use remaining path for pvalues + pvalues[len(cn.pnames)-1] = search + break + } + + // No node found, continue at stored next node + // or find nearest "any" route + if nn != nil { + // No next node to go down in routing (issue #954) + // Find nearest "any" route going up the routing tree + search = ns + np := nn.parent + // Consider param route one level up only + if cn = nn.findChildByKind(pkind); cn != nil { + pos := strings.IndexByte(ns, '/') + if pos == -1 { + // If no slash is remaining in search string set param value pvalues[len(cn.pnames)-1] = search break - } else if cn != nil && strings.IndexByte(ns, '/') != 1 { - // If slash is present, it means that this is a parameterized route. - cn = cn.parent + } else if pos > 0 { + // Otherwise continue route processing with restored next node + cn = nn + nn = nil + ns = "" goto Param } - for { - np = nn.parent - if cn = nn.findChildByKind(akind); cn != nil { - break - } - if np == nil { - break // no further parent nodes in tree, abort - } - var str strings.Builder - str.WriteString(nn.prefix) - str.WriteString(search) - search = str.String() - nn = np - } - if cn != nil { // use the found "any" route and update path - pvalues[len(cn.pnames)-1] = search + } + // No param route found, try to resolve nearest any route + for { + np = nn.parent + if cn = nn.findChildByKind(akind); cn != nil { break } + if np == nil { + break // no further parent nodes in tree, abort + } + var str strings.Builder + str.WriteString(nn.prefix) + str.WriteString(search) + search = str.String() + nn = np + } + if cn != nil { // use the found "any" route and update path + pvalues[len(cn.pnames)-1] = search + break } - return // Not found } - pvalues[len(cn.pnames)-1] = search - break + return // Not found + } ctx.handler = cn.findHandler(method) diff --git a/router_test.go b/router_test.go index f47225d28..92db9d9d2 100644 --- a/router_test.go +++ b/router_test.go @@ -1031,12 +1031,15 @@ func TestRouterParamBacktraceNotFound(t *testing.T) { r.Find(http.MethodGet, "/a", c) assert.Equal(t, "a", c.Param("param1")) + c = e.NewContext(nil, nil).(*context) r.Find(http.MethodGet, "/a/foo", c) assert.Equal(t, "a", c.Param("param1")) + c = e.NewContext(nil, nil).(*context) r.Find(http.MethodGet, "/a/bar", c) assert.Equal(t, "a", c.Param("param1")) + c = e.NewContext(nil, nil).(*context) r.Find(http.MethodGet, "/a/bar/b", c) assert.Equal(t, "a", c.Param("param1")) assert.Equal(t, "b", c.Param("param2")) @@ -1157,31 +1160,45 @@ func TestRouterParam1466(t *testing.T) { r.Find(http.MethodGet, "/users/ajitem", c) assert.Equal(t, "ajitem", c.Param("username")) + c = e.NewContext(nil, nil).(*context) r.Find(http.MethodGet, "/users/sharewithme", c) assert.Equal(t, "sharewithme", c.Param("username")) + c = e.NewContext(nil, nil).(*context) r.Find(http.MethodGet, "/users/signup", c) assert.Equal(t, "", c.Param("username")) // Additional assertions for #1479 + c = e.NewContext(nil, nil).(*context) r.Find(http.MethodGet, "/users/sharewithme/likes/projects/ids", c) assert.Equal(t, "sharewithme", c.Param("username")) + c = e.NewContext(nil, nil).(*context) r.Find(http.MethodGet, "/users/ajitem/likes/projects/ids", c) assert.Equal(t, "ajitem", c.Param("username")) + c = e.NewContext(nil, nil).(*context) r.Find(http.MethodGet, "/users/sharewithme/profile", c) assert.Equal(t, "sharewithme", c.Param("username")) + c = e.NewContext(nil, nil).(*context) r.Find(http.MethodGet, "/users/ajitem/profile", c) assert.Equal(t, "ajitem", c.Param("username")) + c = e.NewContext(nil, nil).(*context) r.Find(http.MethodGet, "/users/sharewithme/uploads/self", c) assert.Equal(t, "sharewithme", c.Param("username")) assert.Equal(t, "self", c.Param("type")) + c = e.NewContext(nil, nil).(*context) r.Find(http.MethodGet, "/users/ajitem/uploads/self", c) assert.Equal(t, "ajitem", c.Param("username")) assert.Equal(t, "self", c.Param("type")) + + // Issue #1493 - check for routing loop + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/tree/free", c) + assert.Equal(t, "", c.Param("id")) + assert.Equal(t, 0, c.response.Status) } func benchmarkRouterRoutes(b *testing.B, routes []*Route) { From 5ddc3a68ba1678147e3779bc80d3b744125d22fc Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Mon, 24 Feb 2020 17:26:49 +0100 Subject: [PATCH 011/446] Fix routing conflict for dynamic routes and static route with common prefix (#1509) (#1512) * Add test for issue #1509 for dynamic routes and multiple static routes with common prefix * Fix #1509: routing conflict for dynamic routes and static route with common prefix * Improve routing performance for static only route trees --- router.go | 13 ++++++++----- router_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/router.go b/router.go index cb2cc16c8..15a3398f2 100644 --- a/router.go +++ b/router.go @@ -347,7 +347,14 @@ func (r *Router) Find(method, path string, c Context) { if l == pl { // Continue search search = search[l:] - } else { + // Finish routing if no remaining search and we are on an leaf node + if search == "" && (nn == nil || cn.parent == nil || cn.ppath != "") { + break + } + } + + // Attempt to go back up the tree on no matching prefix or no remaining search + if l != pl || search == "" { if nn == nil { // Issue #1348 return // Not found } @@ -360,10 +367,6 @@ func (r *Router) Find(method, path string, c Context) { } } - if search == "" { - break - } - // Static node if child = cn.findChild(search[0], skind); child != nil { // Save next diff --git a/router_test.go b/router_test.go index 92db9d9d2..8c27b9f72 100644 --- a/router_test.go +++ b/router_test.go @@ -567,6 +567,32 @@ func TestRouterParamWithSlash(t *testing.T) { }) } +// Issue #1509 +func TestRouterParamStaticConflict(t *testing.T) { + e := New() + r := e.router + handler := func(c Context) error { + c.Set("path", c.Path()) + return nil + } + + g := e.Group("/g") + g.GET("/skills", handler) + g.GET("/status", handler) + g.GET("/:name", handler) + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/g/s", c) + c.handler(c) + assert.Equal(t, "s", c.Param("name")) + assert.Equal(t, "/g/:name", c.Get("path")) + + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/g/status", c) + c.handler(c) + assert.Equal(t, "/g/status", c.Get("path")) +} + func TestRouterMatchAny(t *testing.T) { e := New() r := e.router From 7c5af01350b73a75fb7c129f2dacdeea7d6189ba Mon Sep 17 00:00:00 2001 From: Shinichi TAMURA Date: Tue, 25 Feb 2020 01:29:34 +0900 Subject: [PATCH 012/446] Safer/trustable extraction of real ip from request (#1478) * Safer/trustable extraction of real ip from request * Fix x-real-ip handling on proxy * fix docs * fix default check --- context.go | 5 + echo.go | 1 + ip.go | 137 +++++++++++++++++++++++ ip_test.go | 235 +++++++++++++++++++++++++++++++++++++++ middleware/proxy.go | 4 +- middleware/proxy_test.go | 44 ++++++++ 6 files changed, 425 insertions(+), 1 deletion(-) create mode 100644 ip.go create mode 100644 ip_test.go diff --git a/context.go b/context.go index 0046e5d4f..86b50646b 100644 --- a/context.go +++ b/context.go @@ -43,6 +43,7 @@ type ( // RealIP returns the client's network address based on `X-Forwarded-For` // or `X-Real-IP` request header. + // The behavior can be configured using `Echo#IPExtractor`. RealIP() string // Path returns the registered path for the handler. @@ -270,6 +271,10 @@ func (c *context) Scheme() string { } func (c *context) RealIP() string { + if c.echo != nil && c.echo.IPExtractor != nil { + return c.echo.IPExtractor(c.request) + } + // Fall back to legacy behavior if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" { return strings.Split(ip, ", ")[0] } diff --git a/echo.go b/echo.go index a759e34b6..1f8abe7a2 100644 --- a/echo.go +++ b/echo.go @@ -90,6 +90,7 @@ type ( Validator Validator Renderer Renderer Logger Logger + IPExtractor IPExtractor } // Route contains a handler and information for matching against requests. diff --git a/ip.go b/ip.go new file mode 100644 index 000000000..39cb421fd --- /dev/null +++ b/ip.go @@ -0,0 +1,137 @@ +package echo + +import ( + "net" + "net/http" + "strings" +) + +type ipChecker struct { + trustLoopback bool + trustLinkLocal bool + trustPrivateNet bool + trustExtraRanges []*net.IPNet +} + +// TrustOption is config for which IP address to trust +type TrustOption func(*ipChecker) + +// TrustLoopback configures if you trust loopback address (default: true). +func TrustLoopback(v bool) TrustOption { + return func(c *ipChecker) { + c.trustLoopback = v + } +} + +// TrustLinkLocal configures if you trust link-local address (default: true). +func TrustLinkLocal(v bool) TrustOption { + return func(c *ipChecker) { + c.trustLinkLocal = v + } +} + +// TrustPrivateNet configures if you trust private network address (default: true). +func TrustPrivateNet(v bool) TrustOption { + return func(c *ipChecker) { + c.trustPrivateNet = v + } +} + +// TrustIPRange add trustable IP ranges using CIDR notation. +func TrustIPRange(ipRange *net.IPNet) TrustOption { + return func(c *ipChecker) { + c.trustExtraRanges = append(c.trustExtraRanges, ipRange) + } +} + +func newIPChecker(configs []TrustOption) *ipChecker { + checker := &ipChecker{trustLoopback: true, trustLinkLocal: true, trustPrivateNet: true} + for _, configure := range configs { + configure(checker) + } + return checker +} + +func isPrivateIPRange(ip net.IP) bool { + if ip4 := ip.To4(); ip4 != nil { + return ip4[0] == 10 || + ip4[0] == 172 && ip4[1]&0xf0 == 16 || + ip4[0] == 192 && ip4[1] == 168 + } + return len(ip) == net.IPv6len && ip[0]&0xfe == 0xfc +} + +func (c *ipChecker) trust(ip net.IP) bool { + if c.trustLoopback && ip.IsLoopback() { + return true + } + if c.trustLinkLocal && ip.IsLinkLocalUnicast() { + return true + } + if c.trustPrivateNet && isPrivateIPRange(ip) { + return true + } + for _, trustedRange := range c.trustExtraRanges { + if trustedRange.Contains(ip) { + return true + } + } + return false +} + +// IPExtractor is a function to extract IP addr from http.Request. +// Set appropriate one to Echo#IPExtractor. +// See https://echo.labstack.com/guide/ip-address for more details. +type IPExtractor func(*http.Request) string + +// ExtractIPDirect extracts IP address using actual IP address. +// Use this if your server faces to internet directory (i.e.: uses no proxy). +func ExtractIPDirect() IPExtractor { + return func(req *http.Request) string { + ra, _, _ := net.SplitHostPort(req.RemoteAddr) + return ra + } +} + +// ExtractIPFromRealIPHeader extracts IP address using x-real-ip header. +// Use this if you put proxy which uses this header. +func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { + checker := newIPChecker(options) + return func(req *http.Request) string { + directIP := ExtractIPDirect()(req) + realIP := req.Header.Get(HeaderXRealIP) + if realIP != "" { + if ip := net.ParseIP(directIP); ip != nil && checker.trust(ip) { + return realIP + } + } + return directIP + } +} + +// ExtractIPFromXFFHeader extracts IP address using x-forwarded-for header. +// Use this if you put proxy which uses this header. +// This returns nearest untrustable IP. If all IPs are trustable, returns furthest one (i.e.: XFF[0]). +func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor { + checker := newIPChecker(options) + return func(req *http.Request) string { + directIP := ExtractIPDirect()(req) + xffs := req.Header[HeaderXForwardedFor] + if len(xffs) == 0 { + return directIP + } + ips := append(strings.Split(strings.Join(xffs, ","), ","), directIP) + for i := len(ips) - 1; i >= 0; i-- { + ip := net.ParseIP(strings.TrimSpace(ips[i])) + if ip == nil { + // Unable to parse IP; cannot trust entire records + return directIP + } + if !checker.trust(ip) { + return ip.String() + } + } + // All of the IPs are trusted; return first element because it is furthest from server (best effort strategy). + return strings.TrimSpace(ips[0]) + } +} diff --git a/ip_test.go b/ip_test.go new file mode 100644 index 000000000..5acc11798 --- /dev/null +++ b/ip_test.go @@ -0,0 +1,235 @@ +package echo + +import ( + "net" + "net/http" + "strings" + "testing" + + testify "github.com/stretchr/testify/assert" +) + +const ( + // For RemoteAddr + ipForRemoteAddrLoopback = "127.0.0.1" // From 127.0.0.0/8 + sampleRemoteAddrLoopback = ipForRemoteAddrLoopback + ":8080" + ipForRemoteAddrExternal = "203.0.113.1" + sampleRemoteAddrExternal = ipForRemoteAddrExternal + ":8080" + // For x-real-ip + ipForRealIP = "203.0.113.10" + // For XFF + ipForXFF1LinkLocal = "169.254.0.101" // From 169.254.0.0/16 + ipForXFF2Private = "192.168.0.102" // From 192.168.0.0/16 + ipForXFF3External = "2001:db8::103" + ipForXFF4Private = "fc00::104" // From fc00::/7 + ipForXFF5External = "198.51.100.105" + ipForXFF6External = "192.0.2.106" + ipForXFFBroken = "this.is.broken.lol" + // keys for test cases + ipTestReqKeyNoHeader = "no header" + ipTestReqKeyRealIPExternal = "x-real-ip; remote addr external" + ipTestReqKeyRealIPInternal = "x-real-ip; remote addr internal" + ipTestReqKeyRealIPAndXFFExternal = "x-real-ip and xff; remote addr external" + ipTestReqKeyRealIPAndXFFInternal = "x-real-ip and xff; remote addr internal" + ipTestReqKeyXFFExternal = "xff; remote addr external" + ipTestReqKeyXFFInternal = "xff; remote addr internal" + ipTestReqKeyBrokenXFF = "broken xff" +) + +var ( + sampleXFF = strings.Join([]string{ + ipForXFF6External, ipForXFF5External, ipForXFF4Private, ipForXFF3External, ipForXFF2Private, ipForXFF1LinkLocal, + }, ", ") + + requests = map[string]*http.Request{ + ipTestReqKeyNoHeader: &http.Request{ + RemoteAddr: sampleRemoteAddrExternal, + }, + ipTestReqKeyRealIPExternal: &http.Request{ + Header: http.Header{ + "X-Real-Ip": []string{ipForRealIP}, + }, + RemoteAddr: sampleRemoteAddrExternal, + }, + ipTestReqKeyRealIPInternal: &http.Request{ + Header: http.Header{ + "X-Real-Ip": []string{ipForRealIP}, + }, + RemoteAddr: sampleRemoteAddrLoopback, + }, + ipTestReqKeyRealIPAndXFFExternal: &http.Request{ + Header: http.Header{ + "X-Real-Ip": []string{ipForRealIP}, + HeaderXForwardedFor: []string{sampleXFF}, + }, + RemoteAddr: sampleRemoteAddrExternal, + }, + ipTestReqKeyRealIPAndXFFInternal: &http.Request{ + Header: http.Header{ + "X-Real-Ip": []string{ipForRealIP}, + HeaderXForwardedFor: []string{sampleXFF}, + }, + RemoteAddr: sampleRemoteAddrLoopback, + }, + ipTestReqKeyXFFExternal: &http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{sampleXFF}, + }, + RemoteAddr: sampleRemoteAddrExternal, + }, + ipTestReqKeyXFFInternal: &http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{sampleXFF}, + }, + RemoteAddr: sampleRemoteAddrLoopback, + }, + ipTestReqKeyBrokenXFF: &http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{ipForXFFBroken + ", " + ipForXFF1LinkLocal}, + }, + RemoteAddr: sampleRemoteAddrLoopback, + }, + } +) + +func TestExtractIP(t *testing.T) { + _, ipv4AllRange, _ := net.ParseCIDR("0.0.0.0/0") + _, ipv6AllRange, _ := net.ParseCIDR("::/0") + _, ipForXFF3ExternalRange, _ := net.ParseCIDR(ipForXFF3External + "/48") + _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR(ipForRemoteAddrExternal + "/24") + + tests := map[string]*struct { + extractor IPExtractor + expectedIPs map[string]string + }{ + "ExtractIPDirect": { + ExtractIPDirect(), + map[string]string{ + ipTestReqKeyNoHeader: ipForRemoteAddrExternal, + ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, + ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, + ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, + ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback, + ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, + ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, + ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, + }, + }, + "ExtractIPFromRealIPHeader(default)": { + ExtractIPFromRealIPHeader(), + map[string]string{ + ipTestReqKeyNoHeader: ipForRemoteAddrExternal, + ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, + ipTestReqKeyRealIPInternal: ipForRealIP, + ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, + ipTestReqKeyRealIPAndXFFInternal: ipForRealIP, + ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, + ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, + ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, + }, + }, + "ExtractIPFromRealIPHeader(trust only direct-facing proxy)": { + ExtractIPFromRealIPHeader(TrustLoopback(false), TrustLinkLocal(false), TrustPrivateNet(false), TrustIPRange(ipForRemoteAddrExternalRange)), + map[string]string{ + ipTestReqKeyNoHeader: ipForRemoteAddrExternal, + ipTestReqKeyRealIPExternal: ipForRealIP, + ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, + ipTestReqKeyRealIPAndXFFExternal: ipForRealIP, + ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback, + ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, + ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, + ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, + }, + }, + "ExtractIPFromRealIPHeader(trust direct-facing proxy)": { + ExtractIPFromRealIPHeader(TrustIPRange(ipForRemoteAddrExternalRange)), + map[string]string{ + ipTestReqKeyNoHeader: ipForRemoteAddrExternal, + ipTestReqKeyRealIPExternal: ipForRealIP, + ipTestReqKeyRealIPInternal: ipForRealIP, + ipTestReqKeyRealIPAndXFFExternal: ipForRealIP, + ipTestReqKeyRealIPAndXFFInternal: ipForRealIP, + ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, + ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, + ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, + }, + }, + "ExtractIPFromXFFHeader(default)": { + ExtractIPFromXFFHeader(), + map[string]string{ + ipTestReqKeyNoHeader: ipForRemoteAddrExternal, + ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, + ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, + ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, + ipTestReqKeyRealIPAndXFFInternal: ipForXFF3External, + ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, + ipTestReqKeyXFFInternal: ipForXFF3External, + ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, + }, + }, + "ExtractIPFromXFFHeader(trust only direct-facing proxy)": { + ExtractIPFromXFFHeader(TrustLoopback(false), TrustLinkLocal(false), TrustPrivateNet(false), TrustIPRange(ipForRemoteAddrExternalRange)), + map[string]string{ + ipTestReqKeyNoHeader: ipForRemoteAddrExternal, + ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, + ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, + ipTestReqKeyRealIPAndXFFExternal: ipForXFF1LinkLocal, + ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback, + ipTestReqKeyXFFExternal: ipForXFF1LinkLocal, + ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, + ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, + }, + }, + "ExtractIPFromXFFHeader(trust direct-facing proxy)": { + ExtractIPFromXFFHeader(TrustIPRange(ipForRemoteAddrExternalRange)), + map[string]string{ + ipTestReqKeyNoHeader: ipForRemoteAddrExternal, + ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, + ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, + ipTestReqKeyRealIPAndXFFExternal: ipForXFF3External, + ipTestReqKeyRealIPAndXFFInternal: ipForXFF3External, + ipTestReqKeyXFFExternal: ipForXFF3External, + ipTestReqKeyXFFInternal: ipForXFF3External, + ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, + }, + }, + "ExtractIPFromXFFHeader(trust everything)": { + // This is similar to legacy behavior, but ignores x-real-ip header. + ExtractIPFromXFFHeader(TrustIPRange(ipv4AllRange), TrustIPRange(ipv6AllRange)), + map[string]string{ + ipTestReqKeyNoHeader: ipForRemoteAddrExternal, + ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, + ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, + ipTestReqKeyRealIPAndXFFExternal: ipForXFF6External, + ipTestReqKeyRealIPAndXFFInternal: ipForXFF6External, + ipTestReqKeyXFFExternal: ipForXFF6External, + ipTestReqKeyXFFInternal: ipForXFF6External, + ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, + }, + }, + "ExtractIPFromXFFHeader(trust ipForXFF3External)": { + // This trusts private network also after "additional" trust ranges unlike `TrustNProxies(1)` doesn't + ExtractIPFromXFFHeader(TrustIPRange(ipForXFF3ExternalRange)), + map[string]string{ + ipTestReqKeyNoHeader: ipForRemoteAddrExternal, + ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, + ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, + ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, + ipTestReqKeyRealIPAndXFFInternal: ipForXFF5External, + ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, + ipTestReqKeyXFFInternal: ipForXFF5External, + ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, + }, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + assert := testify.New(t) + for key, req := range requests { + actual := test.extractor(req) + expected := test.expectedIPs[key] + assert.Equal(expected, actual, "Request: %s", key) + } + }) + } +} diff --git a/middleware/proxy.go b/middleware/proxy.go index ef5602bd6..1da370dbf 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -231,7 +231,9 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { } // Fix header - if req.Header.Get(echo.HeaderXRealIP) == "" { + // Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream. + // However, for backward compatibility, legacy behavior is preserved unless you configure Echo#IPExtractor. + if req.Header.Get(echo.HeaderXRealIP) == "" || c.Echo().IPExtractor != nil { req.Header.Set(echo.HeaderXRealIP, c.RealIP()) } if req.Header.Get(echo.HeaderXForwardedProto) == "" { diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 1a375db86..40d150cff 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -2,6 +2,7 @@ package middleware import ( "fmt" + "net" "net/http" "net/http/httptest" "net/url" @@ -119,3 +120,46 @@ func TestProxy(t *testing.T) { rec = httptest.NewRecorder() e.ServeHTTP(rec, req) } + +func TestProxyRealIPHeader(t *testing.T) { + // Setup + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + defer upstream.Close() + url, _ := url.Parse(upstream.URL) + rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) + e := echo.New() + e.Use(Proxy(rrb)) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + remoteAddrIP, _, _ := net.SplitHostPort(req.RemoteAddr) + realIPHeaderIP := "203.0.113.1" + extractedRealIP := "203.0.113.10" + tests := []*struct { + hasRealIPheader bool + hasIPExtractor bool + extectedXRealIP string + }{ + {false, false, remoteAddrIP}, + {false, true, extractedRealIP}, + {true, false, realIPHeaderIP}, + {true, true, extractedRealIP}, + } + + for _, tt := range tests { + if tt.hasRealIPheader { + req.Header.Set(echo.HeaderXRealIP, realIPHeaderIP) + } else { + req.Header.Del(echo.HeaderXRealIP) + } + if tt.hasIPExtractor { + e.IPExtractor = func(*http.Request) string { + return extractedRealIP + } + } else { + e.IPExtractor = nil + } + e.ServeHTTP(rec, req) + assert.Equal(t, tt.extectedXRealIP, req.Header.Get(echo.HeaderXRealIP), "hasRealIPheader: %t / hasIPExtractor: %t", tt.hasRealIPheader, tt.hasIPExtractor) + } +} From 91b853a6f2a3748f385f3672de1bbf35f547ca2f Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Sat, 29 Feb 2020 07:25:32 -0800 Subject: [PATCH 013/446] Updated go.mod Signed-off-by: Vishal Rana --- go.mod | 14 +++++--------- go.sum | 30 +++++++++++------------------- 2 files changed, 16 insertions(+), 28 deletions(-) diff --git a/go.mod b/go.mod index 17e57fea0..0d0e79bda 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,12 @@ module github.com/labstack/echo/v4 -go 1.12 +go 1.14 require ( - github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/labstack/gommon v0.3.0 - github.com/mattn/go-colorable v0.1.4 // indirect - github.com/mattn/go-isatty v0.0.11 // indirect - github.com/stretchr/testify v1.4.0 - github.com/valyala/fasttemplate v1.1.0 - golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 - golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa - golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8 // indirect + github.com/mattn/go-colorable v0.1.6 // indirect + github.com/valyala/fasttemplate v1.1.0 // indirect + golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d + golang.org/x/net v0.0.0-20200226121028-0de0cce0169b golang.org/x/text v0.3.2 // indirect ) diff --git a/go.sum b/go.sum index 08ef80ee2..d0b1144ef 100644 --- a/go.sum +++ b/go.sum @@ -1,22 +1,17 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= -github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0= github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= -github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE= +github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.9 h1:d5US/mDsogSGW37IV293h//ZFaeajb69h+EHFsv2xGg= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= -github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM= -github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= @@ -25,26 +20,23 @@ github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPU github.com/valyala/fasttemplate v1.1.0 h1:RZqt0yGBsps8NGvLSGW804QQqCUYYLsaOjTVHy1Ocw4= github.com/valyala/fasttemplate v1.1.0/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 h1:sKJQZMuxjOAR/Uo2LBfU90onWEf1dF4C+0hPJCc9Mpc= -golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= +golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d h1:1ZiEyfaQIg3Qh0EoqpwAakHVhecoE5wlSg5GjnafJGw= +golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA= -golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b h1:0mm1VjtFUOIlE1SbDlwjYaDxZVDP2S5ou6y0gSgXHu8= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a h1:aYOabOQFp6Vj6W1F80affTUvO9UxmJRx8K0gsfABByQ= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8 h1:JA8d3MPx/IToSyXZG/RhwYEtfrKO1Fxrqe8KrkiLXKM= -golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae h1:/WDfKMnPU+m5M4xB+6x4kaepxRw6jWvR5iDRdvjHgy8= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= From 84b8aaf24f9aa6ff6aa3d8a0a59418460762dea7 Mon Sep 17 00:00:00 2001 From: lukesolo Date: Sat, 29 Feb 2020 17:46:25 +0200 Subject: [PATCH 014/446] Fix panic in FormFile if file not found (#1515) --- context.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/context.go b/context.go index 86b50646b..dfcbe16cd 100644 --- a/context.go +++ b/context.go @@ -360,8 +360,11 @@ func (c *context) FormParams() (url.Values, error) { func (c *context) FormFile(name string) (*multipart.FileHeader, error) { f, fh, err := c.request.FormFile(name) + if err != nil { + return nil, err + } defer f.Close() - return fh, err + return fh, nil } func (c *context) MultipartForm() (*multipart.Form, error) { From 3e8a797db0fed880ba5399b6e0eb2e49185d40ee Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Sat, 29 Feb 2020 07:49:22 -0800 Subject: [PATCH 015/446] Updated version Signed-off-by: Vishal Rana --- echo.go | 2 +- go.mod | 1 + go.sum | 4 ++++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/echo.go b/echo.go index 1f8abe7a2..fa1c93ec7 100644 --- a/echo.go +++ b/echo.go @@ -230,7 +230,7 @@ const ( const ( // Version of Echo - Version = "4.1.14" + Version = "4.1.15" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` diff --git a/go.mod b/go.mod index 0d0e79bda..f981ba481 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.14 require ( github.com/labstack/gommon v0.3.0 github.com/mattn/go-colorable v0.1.6 // indirect + github.com/stretchr/testify v1.4.0 github.com/valyala/fasttemplate v1.1.0 // indirect golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d golang.org/x/net v0.0.0-20200226121028-0de0cce0169b diff --git a/go.sum b/go.sum index d0b1144ef..2f6d74d07 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,4 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0= github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= @@ -10,8 +11,10 @@ github.com/mattn/go-isatty v0.0.9 h1:d5US/mDsogSGW37IV293h//ZFaeajb69h+EHFsv2xGg github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= @@ -39,4 +42,5 @@ golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= From 8b2c77b1079c17fc9d7b1b420b2c3102c4069d6f Mon Sep 17 00:00:00 2001 From: Vadim Sabirov Date: Wed, 4 Mar 2020 18:14:23 +0300 Subject: [PATCH 016/446] Fix #1523 by adding SameSite mode for CSRF settings --- middleware/csrf.go | 23 +++++++++++++++++------ middleware/csrf_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/middleware/csrf.go b/middleware/csrf.go index 09a66bb64..ec348ce1b 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -57,6 +57,10 @@ type ( // Indicates if CSRF cookie is HTTP only. // Optional. Default value false. CookieHTTPOnly bool `yaml:"cookie_http_only"` + + // Indicates SameSite mode of the CSRF cookie. + // Optional. Default value SameSiteDefaultMode. + CookieSameSite http.SameSite `yaml:"cookie_same_site"` } // csrfTokenExtractor defines a function that takes `echo.Context` and returns @@ -67,12 +71,13 @@ type ( var ( // DefaultCSRFConfig is the default CSRF middleware config. DefaultCSRFConfig = CSRFConfig{ - Skipper: DefaultSkipper, - TokenLength: 32, - TokenLookup: "header:" + echo.HeaderXCSRFToken, - ContextKey: "csrf", - CookieName: "_csrf", - CookieMaxAge: 86400, + Skipper: DefaultSkipper, + TokenLength: 32, + TokenLookup: "header:" + echo.HeaderXCSRFToken, + ContextKey: "csrf", + CookieName: "_csrf", + CookieMaxAge: 86400, + CookieSameSite: http.SameSiteDefaultMode, } ) @@ -105,6 +110,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.CookieMaxAge == 0 { config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge } + if config.CookieSameSite == 0 { + config.CookieSameSite = http.SameSiteDefaultMode + } // Initialize parts := strings.Split(config.TokenLookup, ":") @@ -157,6 +165,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.CookieDomain != "" { cookie.Domain = config.CookieDomain } + if config.CookieSameSite != http.SameSiteDefaultMode { + cookie.SameSite = config.CookieSameSite + } cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second) cookie.Secure = config.CookieSecure cookie.HttpOnly = config.CookieHTTPOnly diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index efb4dd1d2..5a3b49b7e 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -81,3 +81,39 @@ func TestCSRFTokenFromQuery(t *testing.T) { assert.Error(t, err) csrfTokenFromQuery("csrf") } + +func TestCSRFSetSameSiteMode(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + csrf := CSRFWithConfig(CSRFConfig{ + CookieSameSite: http.SameSiteStrictMode, + }) + + h := csrf(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + r := h(c) + assert.NoError(t, r) + assert.Regexp(t, "SameSite=Strict", rec.Header()["Set-Cookie"]) +} + +func TestCSRFWithoutSameSiteMode(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + csrf := CSRFWithConfig(CSRFConfig{}) + + h := csrf(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + r := h(c) + assert.NoError(t, r) + assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"]) +} From fc4b1c0a830f194bf6334d2b3e61d92cccebb2be Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Thu, 5 Mar 2020 06:36:43 -0800 Subject: [PATCH 017/446] Omit `internal=` in error strings (#1525) --- echo.go | 3 +++ echo_test.go | 17 ++++++++++++++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/echo.go b/echo.go index fa1c93ec7..86e7b2aae 100644 --- a/echo.go +++ b/echo.go @@ -783,6 +783,9 @@ func NewHTTPError(code int, message ...interface{}) *HTTPError { // Error makes it compatible with `error` interface. func (he *HTTPError) Error() string { + if he.Internal == nil { + return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message) + } return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal) } diff --git a/echo_test.go b/echo_test.go index 68c556f41..ddbc56f27 100644 --- a/echo_test.go +++ b/echo_test.go @@ -543,10 +543,21 @@ func request(method, path string, e *Echo) (int, string) { } func TestHTTPError(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, + t.Run("non-internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + + assert.Equal(t, "code=400, message=map[code:12]", err.Error()) + + }) + t.Run("internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + err.SetInternal(errors.New("internal error")) + assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error()) }) - assert.Equal(t, "code=400, message=map[code:12], internal=", err.Error()) } func TestEchoClose(t *testing.T) { From 542835808e41723e5ecf5864b189a0ad36b8f3f6 Mon Sep 17 00:00:00 2001 From: Leaf Date: Sun, 29 Mar 2020 00:12:39 +0000 Subject: [PATCH 018/446] Quote regex meta characters in Rewrite (#1541) Currently there is a half and half situation where the user can't use regex (fully) because * will be replaced with (.*), yet they also can't just enter any old string, because meta chars like . would need escaping. e.g. currently *.html wouldn't work as intended, and instead *\.html should be used. Work around this by using regexp's QuoteMeta function to sanitise the input before handling it. --- middleware/rewrite.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/middleware/rewrite.go b/middleware/rewrite.go index a64e10bb3..d1387af0f 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -57,7 +57,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { // Initialize for k, v := range config.Rules { - k = strings.Replace(k, "*", "(.*)", -1) + k = regexp.QuoteMeta(k) + k = strings.Replace(k, `\*`, "(.*)", -1) k = k + "$" config.rulesRegex[regexp.MustCompile(k)] = v } From 269dfcc9dd8c339383fd2a2d4b5101520e4225f6 Mon Sep 17 00:00:00 2001 From: 178inaba <178inaba.git@gmail.com> Date: Tue, 31 Mar 2020 04:28:07 +0900 Subject: [PATCH 019/446] Set maxParam with SetParamNames (#1535) * Set maxParam with SetParamNames Fixes #1492 * Revert go.mod --- bind_test.go | 2 -- context.go | 6 ++---- context_test.go | 5 ++--- middleware/jwt_test.go | 2 -- 4 files changed, 4 insertions(+), 11 deletions(-) diff --git a/bind_test.go b/bind_test.go index 943cfd559..b9fb9de3c 100644 --- a/bind_test.go +++ b/bind_test.go @@ -332,7 +332,6 @@ func TestBindbindData(t *testing.T) { func TestBindParam(t *testing.T) { e := New() - *e.maxParam = 2 req := httptest.NewRequest(GET, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) @@ -363,7 +362,6 @@ func TestBindParam(t *testing.T) { // Bind something with param and post data payload body := bytes.NewBufferString(`{ "name": "Jon Snow" }`) e2 := New() - *e2.maxParam = 2 req2 := httptest.NewRequest(POST, "/", body) req2.Header.Set(HeaderContentType, MIMEApplicationJSON) diff --git a/context.go b/context.go index dfcbe16cd..99ef03bcb 100644 --- a/context.go +++ b/context.go @@ -310,6 +310,7 @@ func (c *context) ParamNames() []string { func (c *context) SetParamNames(names ...string) { c.pnames = names + *c.echo.maxParam = len(names) } func (c *context) ParamValues() []string { @@ -317,10 +318,7 @@ func (c *context) ParamValues() []string { } func (c *context) SetParamValues(values ...string) { - // NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times - for i, val := range values { - c.pvalues[i] = val - } + c.pvalues = values } func (c *context) QueryParam(name string) string { diff --git a/context_test.go b/context_test.go index 866d06431..73e5dcb62 100644 --- a/context_test.go +++ b/context_test.go @@ -93,7 +93,6 @@ func (responseWriterErr) WriteHeader(statusCode int) { func TestContext(t *testing.T) { e := New() - *e.maxParam = 1 req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) @@ -472,7 +471,6 @@ func TestContextPath(t *testing.T) { func TestContextPathParam(t *testing.T) { e := New() - *e.maxParam = 2 req := httptest.NewRequest(http.MethodGet, "/", nil) c := e.NewContext(req, nil) @@ -491,7 +489,8 @@ func TestContextPathParam(t *testing.T) { func TestContextGetAndSetParam(t *testing.T) { e := New() - *e.maxParam = 2 + r := e.Router() + r.Add(http.MethodGet, "/:foo", func(Context) error { return nil }) req := httptest.NewRequest(http.MethodGet, "/:foo", nil) c := e.NewContext(req, nil) c.SetParamNames("foo") diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 1731d90fa..ce44f9c9c 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -60,8 +60,6 @@ func TestJWTRace(t *testing.T) { func TestJWT(t *testing.T) { e := echo.New() - r := e.Router() - r.Add("GET", "/:jwt", func(echo.Context) error { return nil }) handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } From 6e7c7cea03c4ac9d551c98d92fc4168888b8e2d5 Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Mon, 30 Mar 2020 12:32:58 -0700 Subject: [PATCH 020/446] Bumped version Signed-off-by: Vishal Rana --- echo.go | 2 +- go.mod | 3 ++- go.sum | 3 +++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/echo.go b/echo.go index 86e7b2aae..511eb43fb 100644 --- a/echo.go +++ b/echo.go @@ -230,7 +230,7 @@ const ( const ( // Version of Echo - Version = "4.1.15" + Version = "4.1.16" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` diff --git a/go.mod b/go.mod index f981ba481..b3ac0800e 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,11 @@ module github.com/labstack/echo/v4 go 1.14 require ( + github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/labstack/gommon v0.3.0 github.com/mattn/go-colorable v0.1.6 // indirect github.com/stretchr/testify v1.4.0 - github.com/valyala/fasttemplate v1.1.0 // indirect + github.com/valyala/fasttemplate v1.1.0 golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d golang.org/x/net v0.0.0-20200226121028-0de0cce0169b golang.org/x/text v0.3.2 // indirect diff --git a/go.sum b/go.sum index 2f6d74d07..8e7e54ce7 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,8 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgrijalva/jwt-go v1.0.2 h1:KPldsxuKGsS2FPWsNeg9ZO18aCrGKujPoWXn2yo+KQM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0= github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= From 2207c37bf875b40f2f7b14c587c1c3de7581340e Mon Sep 17 00:00:00 2001 From: Arun Gopalpuri Date: Wed, 8 Apr 2020 08:19:22 -0700 Subject: [PATCH 021/446] use echo.GetPath for rewrite in proxy (#1548) Co-authored-by: Arun Gopalpuri --- echo.go | 7 ++++--- middleware/proxy.go | 2 +- middleware/proxy_test.go | 4 ++++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/echo.go b/echo.go index 511eb43fb..af93354a7 100644 --- a/echo.go +++ b/echo.go @@ -606,12 +606,12 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { h := NotFoundHandler if e.premiddleware == nil { - e.findRouter(r.Host).Find(r.Method, getPath(r), c) + e.findRouter(r.Host).Find(r.Method, GetPath(r), c) h = c.Handler() h = applyMiddleware(h, e.middleware...) } else { h = func(c Context) error { - e.findRouter(r.Host).Find(r.Method, getPath(r), c) + e.findRouter(r.Host).Find(r.Method, GetPath(r), c) h := c.Handler() h = applyMiddleware(h, e.middleware...) return h(c) @@ -817,7 +817,8 @@ func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { } } -func getPath(r *http.Request) string { +// GetPath returns RawPath, if it's empty returns Path from URL +func GetPath(r *http.Request) string { path := r.URL.RawPath if path == "" { path = r.URL.Path diff --git a/middleware/proxy.go b/middleware/proxy.go index 1da370dbf..1956e91ee 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -224,7 +224,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { // Rewrite for k, v := range config.rewriteRegex { - replacer := captureTokens(k, req.URL.Path) + replacer := captureTokens(k, echo.GetPath(req)) if replacer != nil { req.URL.Path = replacer.Replace(v) } diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 40d150cff..5ef11bc89 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -104,6 +104,10 @@ func TestProxy(t *testing.T) { e.ServeHTTP(rec, req) assert.Equal(t, "/user/jack/order/1", req.URL.Path) assert.Equal(t, http.StatusOK, rec.Code) + req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F" + e.ServeHTTP(rec, req) + assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.Path) + assert.Equal(t, http.StatusOK, rec.Code) // ProxyTarget is set in context contextObserver := func(next echo.HandlerFunc) echo.HandlerFunc { From c29904d81c75a47f41520c205907337c7136d7b7 Mon Sep 17 00:00:00 2001 From: Ori Shoshan Date: Sat, 25 Apr 2020 20:58:16 +0300 Subject: [PATCH 022/446] Fixed double padding in Group.File, Group.Add (#1534) Group.File was padding with g.prefix even though it would later call Group.Add which padded with prefix again - for a total of two times --- group.go | 2 +- group_test.go | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/group.go b/group.go index 5d9582535..426bef9eb 100644 --- a/group.go +++ b/group.go @@ -109,7 +109,7 @@ func (g *Group) Static(prefix, root string) { // File implements `Echo#File()` for sub-routes within the Group. func (g *Group) File(path, file string) { - g.file(g.prefix+path, file, g.GET) + g.file(path, file, g.GET) } // Add implements `Echo#Add()` for sub-routes within the Group. diff --git a/group_test.go b/group_test.go index 342cd29e2..c51fd91eb 100644 --- a/group_test.go +++ b/group_test.go @@ -1,7 +1,9 @@ package echo import ( + "io/ioutil" "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -26,6 +28,19 @@ func TestGroup(t *testing.T) { g.File("/walle", "_fixture/images//walle.png") } +func TestGroupFile(t *testing.T) { + e := New() + g := e.Group("/group") + g.File("/walle", "_fixture/images/walle.png") + expectedData, err := ioutil.ReadFile("_fixture/images/walle.png") + assert.Nil(t, err) + req := httptest.NewRequest(http.MethodGet, "/group/walle", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, expectedData, rec.Body.Bytes()) +} + func TestGroupRouteMiddleware(t *testing.T) { // Ensure middleware slices are not re-used e := New() From a8b5de4286ca93f04e13a3ed00bbc2639fce5863 Mon Sep 17 00:00:00 2001 From: Takashi Iwamoto Date: Sun, 26 Apr 2020 03:01:03 +0900 Subject: [PATCH 023/446] Add test case for Response (#1557) --- response_test.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/response_test.go b/response_test.go index bc570a502..07776cae4 100644 --- a/response_test.go +++ b/response_test.go @@ -41,3 +41,13 @@ func TestResponse_Write_UsesSetResponseCode(t *testing.T) { res.Write([]byte("test")) assert.Equal(t, http.StatusBadRequest, rec.Code) } + +func TestResponse_Flush(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + res := &Response{echo: e, Writer: rec} + + res.Write([]byte("test")) + res.Flush() + assert.True(t, rec.Flushed) +} From c08f30359b790f06aecdcf9ec68de39e5b3815f6 Mon Sep 17 00:00:00 2001 From: Lars Lehtonen Date: Sat, 25 Apr 2020 11:01:54 -0700 Subject: [PATCH 024/446] test matrix add go1.14 (#1551) test matrix deprecate go1.11 --- .github/workflows/echo.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index cfa44e683..1d7508b97 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -17,7 +17,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - go: [1.11, 1.12, 1.13] + go: [1.12, 1.13, 1.14] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: From 4aebe68f37d8ba62cac9cc99c6a198eaff1cc6c8 Mon Sep 17 00:00:00 2001 From: Takashi Iwamoto Date: Sun, 26 Apr 2020 22:01:24 +0900 Subject: [PATCH 025/446] Add test for 'func (*Response) After' --- response_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/response_test.go b/response_test.go index 07776cae4..7a9c51c66 100644 --- a/response_test.go +++ b/response_test.go @@ -19,8 +19,13 @@ func TestResponse(t *testing.T) { res.Before(func() { c.Response().Header().Set(HeaderServer, "echo") }) + // After + res.After(func() { + c.Response().Header().Set(HeaderXFrameOptions, "DENY") + }) res.Write([]byte("test")) assert.Equal(t, "echo", rec.Header().Get(HeaderServer)) + assert.Equal(t, "DENY", rec.Header().Get(HeaderXFrameOptions)) } func TestResponse_Write_FallsBackToDefaultStatus(t *testing.T) { From 803c4f673b8fc4ad96efb78b80842f2d185cffe0 Mon Sep 17 00:00:00 2001 From: Jonathan Hall Date: Wed, 29 Apr 2020 16:13:30 +0200 Subject: [PATCH 026/446] Extend HTTPError to satisfy the Go 1.13 error wrapper interface --- echo.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/echo.go b/echo.go index af93354a7..c7e4d6dbc 100644 --- a/echo.go +++ b/echo.go @@ -795,6 +795,11 @@ func (he *HTTPError) SetInternal(err error) *HTTPError { return he } +// Unwrap satisfies the Go 1.13 error wrapper interface. +func (he *HTTPError) Unwrap() error { + return he.Internal +} + // WrapHandler wraps `http.Handler` into `echo.HandlerFunc`. func WrapHandler(h http.Handler) HandlerFunc { return func(c Context) error { From ea34bf944168568af6cdfb5decb62dbd727246d7 Mon Sep 17 00:00:00 2001 From: Jonathan Hall Date: Tue, 5 May 2020 16:18:16 +0200 Subject: [PATCH 027/446] Add tests for HTTPError.Unwrap --- echo_go1.13_test.go | 28 ++++++++++++++++++++++++++++ echo_test.go | 1 - 2 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 echo_go1.13_test.go diff --git a/echo_go1.13_test.go b/echo_go1.13_test.go new file mode 100644 index 000000000..3c488bc63 --- /dev/null +++ b/echo_go1.13_test.go @@ -0,0 +1,28 @@ +// +build go1.13 + +package echo + +import ( + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHTTPError_Unwrap(t *testing.T) { + t.Run("non-internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + + assert.Nil(t, errors.Unwrap(err)) + }) + t.Run("internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + err.SetInternal(errors.New("internal error")) + assert.Equal(t, "internal error", errors.Unwrap(err).Error()) + }) +} diff --git a/echo_test.go b/echo_test.go index ddbc56f27..b9c177844 100644 --- a/echo_test.go +++ b/echo_test.go @@ -549,7 +549,6 @@ func TestHTTPError(t *testing.T) { }) assert.Equal(t, "code=400, message=map[code:12]", err.Error()) - }) t.Run("internal", func(t *testing.T) { err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ From 43e32ba83d638c73a415609f26d513eda30033ee Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Wed, 6 May 2020 23:01:28 +0200 Subject: [PATCH 028/446] Fix #1526 trailing slash to any route (#1563) * refs #1526: Add tests for trailing slash requests with nested any routes * refs #1526: Handle specual router case with trailing slash for non-root any route * refs #1526: Fix accidential lookup for any route without trailing slash in request --- router.go | 4 +++ router_test.go | 73 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/router.go b/router.go index 15a3398f2..ed728d6a2 100644 --- a/router.go +++ b/router.go @@ -355,6 +355,10 @@ func (r *Router) Find(method, path string, c Context) { // Attempt to go back up the tree on no matching prefix or no remaining search if l != pl || search == "" { + // Handle special case of trailing slash route with existing any route (see #1526) + if path[len(path)-1] == '/' && cn.findChildByKind(akind) != nil { + goto Any + } if nn == nil { // Issue #1348 return // Not found } diff --git a/router_test.go b/router_test.go index 8c27b9f72..0e883233b 100644 --- a/router_test.go +++ b/router_test.go @@ -608,7 +608,6 @@ func TestRouterMatchAny(t *testing.T) { return nil }) c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/", c) assert.Equal(t, "", c.Param("*")) @@ -619,6 +618,78 @@ func TestRouterMatchAny(t *testing.T) { assert.Equal(t, "joe", c.Param("*")) } +// TestRouterMatchAnySlash shall verify finding the best route +// for any routes with trailing slash requests +func TestRouterMatchAnySlash(t *testing.T) { + e := New() + r := e.router + + handler := func(c Context) error { + c.Set("path", c.Path()) + return nil + } + + // Routes + r.Add(http.MethodGet, "/users", handler) + r.Add(http.MethodGet, "/users/*", handler) + r.Add(http.MethodGet, "/img/*", handler) + r.Add(http.MethodGet, "/img/load", handler) + r.Add(http.MethodGet, "/img/load/*", handler) + r.Add(http.MethodGet, "/assets/*", handler) + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/", c) + assert.Equal(t, "", c.Param("*")) + + // Test trailing slash request for simple any route (see #1526) + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/", c) + c.handler(c) + assert.Equal(t, "/users/*", c.Get("path")) + assert.Equal(t, "", c.Param("*")) + + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/joe", c) + c.handler(c) + assert.Equal(t, "/users/*", c.Get("path")) + assert.Equal(t, "joe", c.Param("*")) + + // Test trailing slash request for nested any route (see #1526) + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/img/load", c) + c.handler(c) + assert.Equal(t, "/img/load", c.Get("path")) + assert.Equal(t, "", c.Param("*")) + + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/img/load/", c) + c.handler(c) + assert.Equal(t, "/img/load/*", c.Get("path")) + assert.Equal(t, "", c.Param("*")) + + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/img/load/ben", c) + c.handler(c) + assert.Equal(t, "/img/load/*", c.Get("path")) + assert.Equal(t, "ben", c.Param("*")) + + // Test /assets/* any route + // ... without trailing slash must not match + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/assets", c) + c.handler(c) + assert.Equal(t, nil, c.Get("path")) + assert.Equal(t, "", c.Param("*")) + + // ... with trailing slash must match + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/assets/", c) + c.handler(c) + assert.Equal(t, "/assets/*", c.Get("path")) + assert.Equal(t, "", c.Param("*")) + +} + func TestRouterMatchAnyMultiLevel(t *testing.T) { e := New() r := e.router From de3a2d4df33def85548cc1ff2931898e4efc9902 Mon Sep 17 00:00:00 2001 From: roz3x <52892437+roz3x@users.noreply.github.com> Date: Wed, 1 Jul 2020 09:38:30 +0530 Subject: [PATCH 029/446] changed guide highlighting to shell (#1593) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c57d478fb..769c9bbbf 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ Lower is better! ### Installation -```go +```sh // go get github.com/labstack/echo/{version} go get github.com/labstack/echo/v4 ``` From e125b2cf8473c30a42d17757a39f62474545a933 Mon Sep 17 00:00:00 2001 From: Masahiro Furudate <178inaba.git@gmail.com> Date: Mon, 6 Jul 2020 23:59:42 +0900 Subject: [PATCH 030/446] Fix recover print stack trace log level (#1604) * Fix recover print stack trace log level * Add recover log level test * Add default LogLevel to DefaultRecoverConfig --- middleware/recover.go | 22 ++++++++++++++- middleware/recover_test.go | 57 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/middleware/recover.go b/middleware/recover.go index e87aaf321..0dbe740da 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -5,6 +5,7 @@ import ( "runtime" "github.com/labstack/echo/v4" + "github.com/labstack/gommon/log" ) type ( @@ -25,6 +26,10 @@ type ( // DisablePrintStack disables printing stack trace. // Optional. Default value as false. DisablePrintStack bool `yaml:"disable_print_stack"` + + // LogLevel is log level to printing stack trace. + // Optional. Default value 0 (Print). + LogLevel log.Lvl } ) @@ -35,6 +40,7 @@ var ( StackSize: 4 << 10, // 4 KB DisableStackAll: false, DisablePrintStack: false, + LogLevel: 0, } ) @@ -70,7 +76,21 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { stack := make([]byte, config.StackSize) length := runtime.Stack(stack, !config.DisableStackAll) if !config.DisablePrintStack { - c.Logger().Printf("[PANIC RECOVER] %v %s\n", err, stack[:length]) + msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length]) + switch config.LogLevel { + case log.DEBUG: + c.Logger().Debug(msg) + case log.INFO: + c.Logger().Info(msg) + case log.WARN: + c.Logger().Warn(msg) + case log.ERROR: + c.Logger().Error(msg) + case log.OFF: + // None. + default: + c.Logger().Print(msg) + } } c.Error(err) } diff --git a/middleware/recover_test.go b/middleware/recover_test.go index 37707c5c1..644332972 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -2,11 +2,13 @@ package middleware import ( "bytes" + "fmt" "net/http" "net/http/httptest" "testing" "github.com/labstack/echo/v4" + "github.com/labstack/gommon/log" "github.com/stretchr/testify/assert" ) @@ -24,3 +26,58 @@ func TestRecover(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, rec.Code) assert.Contains(t, buf.String(), "PANIC RECOVER") } + +func TestRecoverWithConfig_LogLevel(t *testing.T) { + tests := []struct { + logLevel log.Lvl + levelName string + }{{ + logLevel: log.DEBUG, + levelName: "DEBUG", + }, { + logLevel: log.INFO, + levelName: "INFO", + }, { + logLevel: log.WARN, + levelName: "WARN", + }, { + logLevel: log.ERROR, + levelName: "ERROR", + }, { + logLevel: log.OFF, + levelName: "OFF", + }} + + for _, tt := range tests { + tt := tt + t.Run(tt.levelName, func(t *testing.T) { + e := echo.New() + e.Logger.SetLevel(log.DEBUG) + + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + config := DefaultRecoverConfig + config.LogLevel = tt.logLevel + h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { + panic("test") + })) + + h(c) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + + output := buf.String() + if tt.logLevel == log.OFF { + assert.Empty(t, output) + } else { + assert.Contains(t, output, "PANIC RECOVER") + assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName)) + } + }) + } +} From 546639c8d083f028bad93959ab6f58177604eea1 Mon Sep 17 00:00:00 2001 From: buglan <1831353087@qq.com> Date: Wed, 8 Jul 2020 20:17:34 +0800 Subject: [PATCH 031/446] Fix duplicate code --- echo.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/echo.go b/echo.go index af93354a7..e90e45342 100644 --- a/echo.go +++ b/echo.go @@ -504,11 +504,7 @@ func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware .. name := handlerName(handler) router := e.findRouter(host) router.Add(method, path, func(c Context) error { - h := handler - // Chain middleware - for i := len(middleware) - 1; i >= 0; i-- { - h = middleware[i](h) - } + h := applyMiddleware(handler, middleware...) return h(c) }) r := &Route{ From c4118c049ec59d6179dede346f3e69953263bc09 Mon Sep 17 00:00:00 2001 From: cathy zhang Date: Mon, 20 Jul 2020 14:01:21 +0800 Subject: [PATCH 032/446] dependency package golang.org/x/text v0.3.2 has high security vulnerabiliy, upgrade it to v0.3.3 --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index b3ac0800e..ed220b707 100644 --- a/go.mod +++ b/go.mod @@ -10,5 +10,5 @@ require ( github.com/valyala/fasttemplate v1.1.0 golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d golang.org/x/net v0.0.0-20200226121028-0de0cce0169b - golang.org/x/text v0.3.2 // indirect + golang.org/x/text v0.3.3 // indirect ) diff --git a/go.sum b/go.sum index 8e7e54ce7..11ef341a8 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= From 84da507a2ebd52e616bb9f76cbfbcaeb369ecf68 Mon Sep 17 00:00:00 2001 From: chotow Date: Wed, 17 Jun 2020 12:44:24 +0800 Subject: [PATCH 033/446] Fixes the uses of caret(^) in rewrite regex --- middleware/rewrite.go | 4 ++++ middleware/rewrite_test.go | 16 ++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/middleware/rewrite.go b/middleware/rewrite.go index d1387af0f..241bde9b8 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -59,7 +59,11 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { for k, v := range config.Rules { k = regexp.QuoteMeta(k) k = strings.Replace(k, `\*`, "(.*)", -1) + k = strings.Replace(k, `\^`, "^", -1) k = k + "$" + if strings.HasPrefix(k, "/") { + k = "^" + k + } config.rulesRegex[regexp.MustCompile(k)] = v } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index eb5a46d89..6d89c13cf 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -18,6 +18,10 @@ func TestRewrite(t *testing.T) { "/api/*": "/$1", "/js/*": "/public/javascripts/$1", "/users/*/orders/*": "/user/$1/order/$2", + "/foo/*": "/v1/foo/$1", + "/v1/foo/*": "/v1/foo/$1", + "/v2/foo/*": "/v2/foo/$1", + "^/bar/*": "/foobar/$1", }, })) req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -37,6 +41,18 @@ func TestRewrite(t *testing.T) { req.URL.Path = "/api/new users" e.ServeHTTP(rec, req) assert.Equal(t, "/new users", req.URL.Path) + req.URL.Path = "/foo/bar" + e.ServeHTTP(rec, req) + assert.Equal(t, "/v1/foo/bar", req.URL.Path) + req.URL.Path = "/v1/foo/bar" + e.ServeHTTP(rec, req) + assert.Equal(t, "/v1/foo/bar", req.URL.Path) + req.URL.Path = "/v2/foo/bar" + e.ServeHTTP(rec, req) + assert.Equal(t, "/v2/foo/bar", req.URL.Path) + req.URL.Path = "/bar/baz" + e.ServeHTTP(rec, req) + assert.Equal(t, "/foobar/baz", req.URL.Path) } // Issue #1086 From 68e8bce6450e66aa0da1f70e0fb7119a6d34030b Mon Sep 17 00:00:00 2001 From: chotow Date: Fri, 24 Jul 2020 21:55:27 +0800 Subject: [PATCH 034/446] Revert "Fixes the uses of caret(^) in rewrite regex" This reverts commit 1f51469436e3612e8e121413df905dc9f4ffed0b. --- middleware/rewrite.go | 4 ---- middleware/rewrite_test.go | 16 ---------------- 2 files changed, 20 deletions(-) diff --git a/middleware/rewrite.go b/middleware/rewrite.go index 241bde9b8..d1387af0f 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -59,11 +59,7 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { for k, v := range config.Rules { k = regexp.QuoteMeta(k) k = strings.Replace(k, `\*`, "(.*)", -1) - k = strings.Replace(k, `\^`, "^", -1) k = k + "$" - if strings.HasPrefix(k, "/") { - k = "^" + k - } config.rulesRegex[regexp.MustCompile(k)] = v } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index 6d89c13cf..eb5a46d89 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -18,10 +18,6 @@ func TestRewrite(t *testing.T) { "/api/*": "/$1", "/js/*": "/public/javascripts/$1", "/users/*/orders/*": "/user/$1/order/$2", - "/foo/*": "/v1/foo/$1", - "/v1/foo/*": "/v1/foo/$1", - "/v2/foo/*": "/v2/foo/$1", - "^/bar/*": "/foobar/$1", }, })) req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -41,18 +37,6 @@ func TestRewrite(t *testing.T) { req.URL.Path = "/api/new users" e.ServeHTTP(rec, req) assert.Equal(t, "/new users", req.URL.Path) - req.URL.Path = "/foo/bar" - e.ServeHTTP(rec, req) - assert.Equal(t, "/v1/foo/bar", req.URL.Path) - req.URL.Path = "/v1/foo/bar" - e.ServeHTTP(rec, req) - assert.Equal(t, "/v1/foo/bar", req.URL.Path) - req.URL.Path = "/v2/foo/bar" - e.ServeHTTP(rec, req) - assert.Equal(t, "/v2/foo/bar", req.URL.Path) - req.URL.Path = "/bar/baz" - e.ServeHTTP(rec, req) - assert.Equal(t, "/foobar/baz", req.URL.Path) } // Issue #1086 From 3dbd5dcf6e134d8b12875700dfc1c8d3f19dccb2 Mon Sep 17 00:00:00 2001 From: chotow Date: Fri, 24 Jul 2020 22:01:19 +0800 Subject: [PATCH 035/446] Fixes the uses of caret(^) at the beginning of the rewrite regex --- middleware/rewrite.go | 3 +++ middleware/rewrite_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/middleware/rewrite.go b/middleware/rewrite.go index d1387af0f..476023f6d 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -59,6 +59,9 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { for k, v := range config.Rules { k = regexp.QuoteMeta(k) k = strings.Replace(k, `\*`, "(.*)", -1) + if strings.HasPrefix(k, `\^`) { + k = strings.Replace(k, `\^`, "^", -1) + } k = k + "$" config.rulesRegex[regexp.MustCompile(k)] = v } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index eb5a46d89..848f0029f 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -94,3 +94,30 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { assert.Equal(t, "hosts", string(bodyBytes)) } } + +// Issue #1573 +func TestEchoRewriteWithCaret(t *testing.T) { + e := echo.New() + + e.Pre(RewriteWithConfig(RewriteConfig{ + Rules: map[string]string{ + "^/abc/*": "/v1/abc/$1", + }, + })) + + rec := httptest.NewRecorder() + + var req *http.Request + + req = httptest.NewRequest(http.MethodGet, "/abc/test", nil) + e.ServeHTTP(rec, req) + assert.Equal(t, "/v1/abc/test", req.URL.Path) + + req = httptest.NewRequest(http.MethodGet, "/v1/abc/test", nil) + e.ServeHTTP(rec, req) + assert.Equal(t, "/v1/abc/test", req.URL.Path) + + req = httptest.NewRequest(http.MethodGet, "/v2/abc/test", nil) + e.ServeHTTP(rec, req) + assert.Equal(t, "/v2/abc/test", req.URL.Path) +} From 8dd25c39ced4fafbe63e361c281f2275d7d9fa41 Mon Sep 17 00:00:00 2001 From: Shinnosuke Sawada <6warashi9@gmail.com> Date: Tue, 4 Aug 2020 09:58:08 +0900 Subject: [PATCH 036/446] make gzipResponseWriter implement http.Pusher (#1615) --- middleware/compress.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/middleware/compress.go b/middleware/compress.go index 89da16efe..dd97d983d 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -119,3 +119,10 @@ func (w *gzipResponseWriter) Flush() { func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return w.ResponseWriter.(http.Hijacker).Hijack() } + +func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { + if p, ok := w.ResponseWriter.(http.Pusher); ok { + return p.Push(target, opts) + } + return http.ErrNotSupported +} From 9a28fb8608b341f03a5e68c0595f40de11478e61 Mon Sep 17 00:00:00 2001 From: Andrew Klotz Date: Tue, 18 Aug 2020 01:39:54 +0000 Subject: [PATCH 037/446] cors allow regex pattern enable cors to use regex pattern for allowed origins implementation is similar to another popular cors middleware: https://github.com/astaxie/beego/blob/master/plugins/cors/cors.go#L196-L201 --- middleware/cors.go | 30 +++++++++ middleware/cors_test.go | 138 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 168 insertions(+) diff --git a/middleware/cors.go b/middleware/cors.go index 5dfe31f95..c263f7319 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -2,6 +2,7 @@ package middleware import ( "net/http" + "regexp" "strconv" "strings" @@ -76,6 +77,15 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { config.AllowMethods = DefaultCORSConfig.AllowMethods } + allowOriginPatterns := []string{} + for _, origin := range config.AllowOrigins { + pattern := regexp.QuoteMeta(origin) + pattern = strings.Replace(pattern, "\\*", ".*", -1) + pattern = strings.Replace(pattern, "\\?", ".", -1) + pattern = "^" + pattern + "$" + allowOriginPatterns = append(allowOriginPatterns, pattern) + } + allowMethods := strings.Join(config.AllowMethods, ",") allowHeaders := strings.Join(config.AllowHeaders, ",") exposeHeaders := strings.Join(config.ExposeHeaders, ",") @@ -108,6 +118,26 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } } + // Check allowed origin patterns + for _, re := range allowOriginPatterns { + if allowOrigin == "" { + didx := strings.Index(origin, "://") + if didx == -1 { + continue + } + domAuth := origin[didx+3:] + // to avoid regex cost by invalid long domain + if len(domAuth) > 253 { + break + } + + if match, _ := regexp.MatchString(re, origin); match { + allowOrigin = origin + break + } + } + } + // Simple request if req.Method != http.MethodOptions { res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) diff --git a/middleware/cors_test.go b/middleware/cors_test.go index 456ec7b3d..ca922321c 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -83,3 +83,141 @@ func TestCORS(t *testing.T) { h(c) assert.Equal(t, "http://bbb.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } + +func Test_allowOriginScheme(t *testing.T) { + tests := []struct { + domain, pattern string + expected bool + }{ + { + domain: "http://example.com", + pattern: "http://example.com", + expected: true, + }, + { + domain: "https://example.com", + pattern: "https://example.com", + expected: true, + }, + { + domain: "http://example.com", + pattern: "https://example.com", + expected: false, + }, + { + domain: "https://example.com", + pattern: "http://example.com", + expected: false, + }, + } + + e := echo.New() + for _, tt := range tests { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + req.Header.Set(echo.HeaderOrigin, tt.domain) + cors := CORSWithConfig(CORSConfig{ + AllowOrigins: []string{tt.pattern}, + }) + h := cors(echo.NotFoundHandler) + h(c) + + if tt.expected { + assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } else { + assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } + } +} + +func Test_allowOriginSubdomain(t *testing.T) { + tests := []struct { + domain, pattern string + expected bool + }{ + { + domain: "http://aaa.example.com", + pattern: "http://*.example.com", + expected: true, + }, + { + domain: "http://bbb.aaa.example.com", + pattern: "http://*.example.com", + expected: true, + }, + { + domain: "http://bbb.aaa.example.com", + pattern: "http://*.aaa.example.com", + expected: true, + }, + { + domain: "http://aaa.example.com:8080", + pattern: "http://*.example.com:8080", + expected: true, + }, + + { + domain: "http://fuga.hoge.com", + pattern: "http://*.example.com", + expected: false, + }, + { + domain: "http://ccc.bbb.example.com", + pattern: "http://*.aaa.example.com", + expected: false, + }, + { + domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ + .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ + .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ + .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`, + pattern: "http://*.example.com", + expected: false, + }, + { + domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`, + pattern: "http://*.example.com", + expected: false, + }, + { + domain: "http://ccc.bbb.example.com", + pattern: "http://example.com", + expected: false, + }, + { + domain: "https://prod-preview--aaa.bbb.com", + pattern: "https://*--aaa.bbb.com", + expected: true, + }, + { + domain: "http://ccc.bbb.example.com", + pattern: "http://*.example.com", + expected: true, + }, + { + domain: "http://ccc.bbb.example.com", + pattern: "http://foo.[a-z]*.example.com", + expected: false, + }, + } + + e := echo.New() + for _, tt := range tests { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + req.Header.Set(echo.HeaderOrigin, tt.domain) + cors := CORSWithConfig(CORSConfig{ + AllowOrigins: []string{tt.pattern}, + }) + h := cors(echo.NotFoundHandler) + h(c) + + if tt.expected { + assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } else { + assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } + } +} From 09f36b585d5b1c0d4ff9d38f32c49d8764729658 Mon Sep 17 00:00:00 2001 From: Juan Belieni Date: Thu, 27 Aug 2020 19:35:45 -0300 Subject: [PATCH 038/446] Create ErrJWTInvalid variable --- middleware/jwt.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/middleware/jwt.go b/middleware/jwt.go index 3c7c48681..bab00c9f8 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -86,6 +86,7 @@ const ( // Errors var ( ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt") + ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") ) var ( @@ -213,8 +214,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { return config.ErrorHandlerWithContext(err, c) } return &echo.HTTPError{ - Code: http.StatusUnauthorized, - Message: "invalid or expired jwt", + Code: ErrJWTInvalid.Code, + Message: ErrJWTInvalid.Message, Internal: err, } } From 6463bcb190302063a2375c3268c78de5861e8d8c Mon Sep 17 00:00:00 2001 From: Peter C <63091190+petoc@users.noreply.github.com> Date: Fri, 28 Aug 2020 02:51:27 +0200 Subject: [PATCH 039/446] added ModifyResponse option to ProxyConfig (#1622) Co-authored-by: Peter C --- middleware/proxy.go | 3 +++ middleware/proxy_1_11.go | 1 + middleware/proxy_test.go | 19 ++++++++++++++++++- 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/middleware/proxy.go b/middleware/proxy.go index 1956e91ee..a9b91f6ce 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -45,6 +45,9 @@ type ( // Examples: If custom TLS certificates are required. Transport http.RoundTripper + // ModifyResponse defines function to modify response from ProxyTarget. + ModifyResponse func(*http.Response) error + rewriteRegex map[*regexp.Regexp]string } diff --git a/middleware/proxy_1_11.go b/middleware/proxy_1_11.go index 12b7568bf..a43927817 100644 --- a/middleware/proxy_1_11.go +++ b/middleware/proxy_1_11.go @@ -20,5 +20,6 @@ func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handle c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err))) } proxy.Transport = config.Transport + proxy.ModifyResponse = config.ModifyResponse return proxy } diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 5ef11bc89..b19bf4f2e 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -1,7 +1,9 @@ package middleware import ( + "bytes" "fmt" + "io/ioutil" "net" "net/http" "net/http/httptest" @@ -104,11 +106,26 @@ func TestProxy(t *testing.T) { e.ServeHTTP(rec, req) assert.Equal(t, "/user/jack/order/1", req.URL.Path) assert.Equal(t, http.StatusOK, rec.Code) - req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F" + req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F" e.ServeHTTP(rec, req) assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.Path) assert.Equal(t, http.StatusOK, rec.Code) + // ModifyResponse + e = echo.New() + e.Use(ProxyWithConfig(ProxyConfig{ + Balancer: rrb, + ModifyResponse: func(res *http.Response) error { + res.Body = ioutil.NopCloser(bytes.NewBuffer([]byte("modified"))) + res.Header.Set("X-Modified", "1") + return nil + }, + })) + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "modified", rec.Body.String()) + assert.Equal(t, "1", rec.Header().Get("X-Modified")) + // ProxyTarget is set in context contextObserver := func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { From 28020c2a47a8d56c3357239af284c477684ddeb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A3=8E=E5=90=B9=E8=BF=87?= Date: Fri, 28 Aug 2020 08:53:48 +0800 Subject: [PATCH 040/446] The directory path does not end with '/', it needs to be redirected (#1572) * The directory path does not end with '/', it needs to be redirected * changed guide highlighting to shell (#1593) * Fix recover print stack trace log level (#1604) * Fix recover print stack trace log level * Add recover log level test * Add default LogLevel to DefaultRecoverConfig Co-authored-by: solym Co-authored-by: roz3x <52892437+roz3x@users.noreply.github.com> Co-authored-by: Masahiro Furudate <178inaba.git@gmail.com> --- echo.go | 14 ++++++++++++++ echo_test.go | 13 +++++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/echo.go b/echo.go index c6881c9cf..888d4c0cd 100644 --- a/echo.go +++ b/echo.go @@ -48,6 +48,7 @@ import ( "net" "net/http" "net/url" + "os" "path" "path/filepath" "reflect" @@ -479,7 +480,20 @@ func (common) static(prefix, root string, get func(string, HandlerFunc, ...Middl if err != nil { return err } + name := filepath.Join(root, path.Clean("/"+p)) // "/"+ for security + fi, err := os.Stat(name) + if err != nil { + // The access path does not exist + return NotFoundHandler(c) + } + + // If the request is for a directory and does not end with "/" + p = c.Request().URL.Path // path must not be empty. + if fi.IsDir() && p[len(p)-1] != '/' { + // Redirect to ends with "/" + return c.Redirect(http.StatusMovedPermanently, p+"/") + } return c.File(name) } if prefix == "/" { diff --git a/echo_test.go b/echo_test.go index b9c177844..e1706eff7 100644 --- a/echo_test.go +++ b/echo_test.go @@ -76,9 +76,17 @@ func TestEchoStatic(t *testing.T) { // Directory e.Static("/images", "_fixture/images") - c, _ = request(http.MethodGet, "/images", e) + c, _ = request(http.MethodGet, "/images/", e) assert.Equal(http.StatusNotFound, c) + // Directory Redirect + e.Static("/", "_fixture") + req := httptest.NewRequest(http.MethodGet, "/folder", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(http.StatusMovedPermanently, rec.Code) + assert.Equal("/folder/", rec.HeaderMap["Location"][0]) + // Directory with index.html e.Static("/", "_fixture") c, r := request(http.MethodGet, "/", e) @@ -86,9 +94,10 @@ func TestEchoStatic(t *testing.T) { assert.Equal(true, strings.HasPrefix(r, "")) // Sub-directory with index.html - c, r = request(http.MethodGet, "/folder", e) + c, r = request(http.MethodGet, "/folder/", e) assert.Equal(http.StatusOK, c) assert.Equal(true, strings.HasPrefix(r, "")) + } func TestEchoFile(t *testing.T) { From cb84205219f6b0962ab5bb39ab6c0911f5f1652d Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Thu, 27 Aug 2020 20:04:53 -0700 Subject: [PATCH 041/446] Bumped version Signed-off-by: Vishal Rana --- .travis.yml | 4 ++-- README.md | 2 +- echo.go | 2 +- go.mod | 11 ++++++----- go.sum | 24 +++++++++++++----------- 5 files changed, 23 insertions(+), 20 deletions(-) diff --git a/.travis.yml b/.travis.yml index a1fc87684..ef826e952 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ language: go go: - - 1.12.x - - 1.13.x + - 1.14.x + - 1.15.x - tip env: - GO111MODULE=on diff --git a/README.md b/README.md index 769c9bbbf..d9d96139d 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ Therefore a Go version capable of understanding /vN suffixed imports is required - 1.9.7+ - 1.10.3+ -- 1.11+ +- 1.14+ Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended way of using Echo going forward. diff --git a/echo.go b/echo.go index 888d4c0cd..18c110166 100644 --- a/echo.go +++ b/echo.go @@ -231,7 +231,7 @@ const ( const ( // Version of Echo - Version = "4.1.16" + Version = "4.1.17" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` diff --git a/go.mod b/go.mod index ed220b707..74c6a9abe 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1,15 @@ module github.com/labstack/echo/v4 -go 1.14 +go 1.15 require ( github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/labstack/gommon v0.3.0 - github.com/mattn/go-colorable v0.1.6 // indirect + github.com/mattn/go-colorable v0.1.7 // indirect github.com/stretchr/testify v1.4.0 - github.com/valyala/fasttemplate v1.1.0 - golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d - golang.org/x/net v0.0.0-20200226121028-0de0cce0169b + github.com/valyala/fasttemplate v1.2.1 + golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a + golang.org/x/net v0.0.0-20200822124328-c89045814202 + golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6 // indirect golang.org/x/text v0.3.3 // indirect ) diff --git a/go.sum b/go.sum index 11ef341a8..58c80c831 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,13 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgrijalva/jwt-go v1.0.2 h1:KPldsxuKGsS2FPWsNeg9ZO18aCrGKujPoWXn2yo+KQM= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0= github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-colorable v0.1.6 h1:6Su7aK7lXmJ/U79bYtBjLNaha4Fs1Rg9plHpcH+vvnE= -github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.7 h1:bQGKb3vps/j0E9GfJQ03JyhRuxsvdAanXlT9BTw3mdw= +github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.9 h1:d5US/mDsogSGW37IV293h//ZFaeajb69h+EHFsv2xGg= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= @@ -23,14 +22,15 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.0.1 h1:tY9CJiPnMXf1ERmG2EyK7gNUd+c6RKGD0IfU8WdUSz8= github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= -github.com/valyala/fasttemplate v1.1.0 h1:RZqt0yGBsps8NGvLSGW804QQqCUYYLsaOjTVHy1Ocw4= -github.com/valyala/fasttemplate v1.1.0/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= +github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= +github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d h1:1ZiEyfaQIg3Qh0EoqpwAakHVhecoE5wlSg5GjnafJGw= -golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM= +golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20200226121028-0de0cce0169b h1:0mm1VjtFUOIlE1SbDlwjYaDxZVDP2S5ou6y0gSgXHu8= -golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -39,13 +39,15 @@ golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae h1:/WDfKMnPU+m5M4xB+6x4kaepxRw6jWvR5iDRdvjHgy8= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6 h1:DvY3Zkh7KabQE/kfzMvYvKirSiguP9Q/veMtkYyf0o8= +golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= From 1a6ec73e57633b3b64755c07b13080d3507f9cf5 Mon Sep 17 00:00:00 2001 From: Arun Gopalpuri Date: Fri, 28 Aug 2020 12:47:02 -0700 Subject: [PATCH 042/446] using url.EscapedPath instead of custom GetPath, rewritePath func added to middleware - used by proxy and rewrite --- echo.go | 14 ++------------ middleware/middleware.go | 13 +++++++++++++ middleware/proxy.go | 7 +++++-- middleware/proxy_test.go | 23 ++++++++++++++++++----- middleware/rewrite.go | 7 +++++-- middleware/rewrite_test.go | 27 ++++++++++++++++++++------- 6 files changed, 63 insertions(+), 28 deletions(-) diff --git a/echo.go b/echo.go index 18c110166..128f84fd2 100644 --- a/echo.go +++ b/echo.go @@ -612,16 +612,15 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Acquire context c := e.pool.Get().(*context) c.Reset(r, w) - h := NotFoundHandler if e.premiddleware == nil { - e.findRouter(r.Host).Find(r.Method, GetPath(r), c) + e.findRouter(r.Host).Find(r.Method, r.URL.EscapedPath(), c) h = c.Handler() h = applyMiddleware(h, e.middleware...) } else { h = func(c Context) error { - e.findRouter(r.Host).Find(r.Method, GetPath(r), c) + e.findRouter(r.Host).Find(r.Method, r.URL.EscapedPath(), c) h := c.Handler() h = applyMiddleware(h, e.middleware...) return h(c) @@ -832,15 +831,6 @@ func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { } } -// GetPath returns RawPath, if it's empty returns Path from URL -func GetPath(r *http.Request) string { - path := r.URL.RawPath - if path == "" { - path = r.URL.Path - } - return path -} - func (e *Echo) findRouter(host string) *Router { if len(e.routers) > 0 { if r, ok := e.routers[host]; ok { diff --git a/middleware/middleware.go b/middleware/middleware.go index d0b7153cb..12260ddb2 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -1,6 +1,8 @@ package middleware import ( + "net/http" + "net/url" "regexp" "strconv" "strings" @@ -32,6 +34,17 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { return strings.NewReplacer(replace...) } +//rewritePath sets request url path and raw path +func rewritePath(replacer *strings.Replacer, target string, req *http.Request) error { + replacerRawPath := replacer.Replace(target) + replacerPath, err := url.PathUnescape(replacerRawPath) + if err != nil { + return err + } + req.URL.Path, req.URL.RawPath = replacerPath, replacerRawPath + return nil +} + // DefaultSkipper returns false which processes the middleware. func DefaultSkipper(echo.Context) bool { return false diff --git a/middleware/proxy.go b/middleware/proxy.go index a9b91f6ce..cd50b76a1 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -227,9 +227,12 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { // Rewrite for k, v := range config.rewriteRegex { - replacer := captureTokens(k, echo.GetPath(req)) + //use req.URL.Path here or else we will have double escaping + replacer := captureTokens(k, req.URL.Path) if replacer != nil { - req.URL.Path = replacer.Replace(v) + if err := rewritePath(replacer, v, req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "invalid url") + } } } diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index b19bf4f2e..4bb74648c 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" ) +//Assert expected with url.EscapedPath method to obtain the path. func TestProxy(t *testing.T) { // Setup t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -94,22 +95,34 @@ func TestProxy(t *testing.T) { }, })) req.URL.Path = "/api/users" + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/users", req.URL.Path) + assert.Equal(t, "/users", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) req.URL.Path = "/js/main.js" + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/public/javascripts/main.js", req.URL.Path) + assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) req.URL.Path = "/old" + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/new", req.URL.Path) + assert.Equal(t, "/new", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) req.URL.Path = "/users/jack/orders/1" + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/user/jack/order/1", req.URL.Path) + assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F" + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.Path) + assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) + req.URL.Path = "/users/jill/orders/%%%%" + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusBadRequest, rec.Code) // ModifyResponse e = echo.New() diff --git a/middleware/rewrite.go b/middleware/rewrite.go index d1387af0f..a4d851cd0 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -1,6 +1,7 @@ package middleware import ( + "net/http" "regexp" "strings" @@ -70,12 +71,14 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { } req := c.Request() - // Rewrite for k, v := range config.rulesRegex { + //use req.URL.Path here or else we will have double escaping replacer := captureTokens(k, req.URL.Path) if replacer != nil { - req.URL.Path = replacer.Replace(v) + if err := rewritePath(replacer, v, req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "invalid url") + } break } } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index eb5a46d89..dbbcb20d8 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" ) +//Assert expected with url.EscapedPath method to obtain the path. func TestRewrite(t *testing.T) { e := echo.New() e.Use(RewriteWithConfig(RewriteConfig{ @@ -24,19 +25,31 @@ func TestRewrite(t *testing.T) { rec := httptest.NewRecorder() req.URL.Path = "/api/users" e.ServeHTTP(rec, req) - assert.Equal(t, "/users", req.URL.Path) + assert.Equal(t, "/users", req.URL.EscapedPath()) req.URL.Path = "/js/main.js" + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/public/javascripts/main.js", req.URL.Path) + assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) req.URL.Path = "/old" + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/new", req.URL.Path) + assert.Equal(t, "/new", req.URL.EscapedPath()) req.URL.Path = "/users/jack/orders/1" + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/user/jack/order/1", req.URL.Path) + assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) req.URL.Path = "/api/new users" + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/new users", req.URL.Path) + assert.Equal(t, "/new%20users", req.URL.EscapedPath()) + req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F" + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) + req.URL.Path = "/users/jill/orders/%%%%" + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusBadRequest, rec.Code) } // Issue #1086 @@ -59,7 +72,7 @@ func TestEchoRewritePreMiddleware(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/old", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/new", req.URL.Path) + assert.Equal(t, "/new", req.URL.EscapedPath()) assert.Equal(t, 200, rec.Code) } @@ -86,7 +99,7 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/api/v1/mgmt/proj/test/agt", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/api/v1/hosts/test", req.URL.Path) + assert.Equal(t, "/api/v1/hosts/test", req.URL.EscapedPath()) assert.Equal(t, 200, rec.Code) defer rec.Result().Body.Close() From bfbab25044df148b0bbf27b34be03995e814bfd8 Mon Sep 17 00:00:00 2001 From: Florian Polster Date: Sat, 12 Sep 2020 10:47:03 +0200 Subject: [PATCH 043/446] Update godoc link in README to /v4 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d9d96139d..03ad4dca3 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge) -[![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](http://godoc.org/github.com/labstack/echo) +[![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4) [![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo) [![Build Status](http://img.shields.io/travis/labstack/echo.svg?style=flat-square)](https://travis-ci.org/labstack/echo) [![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo) From a7b6d444a49ff1bab342765387710d31f08c13af Mon Sep 17 00:00:00 2001 From: Florian Polster Date: Sat, 12 Sep 2020 10:57:58 +0200 Subject: [PATCH 044/446] Run Test Workflow only if Go code was changed --- .github/workflows/echo.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 1d7508b97..128940cac 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -4,9 +4,13 @@ on: push: branches: - master + paths: + - '**.go' pull_request: branches: - master + paths: + - '**.go' env: GO111MODULE: on From 5c5c83d290fc37a218f653a45af1e4a9436a96a9 Mon Sep 17 00:00:00 2001 From: yonbiaoxiao Date: Sun, 23 Aug 2020 11:18:31 +0800 Subject: [PATCH 045/446] change the hardcode for http constant --- middleware/rewrite_test.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index 34c5e29c4..a9b3437ce 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -42,11 +42,11 @@ func TestRewrite(t *testing.T) { rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/new%20users", req.URL.EscapedPath()) - req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F" + req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F" rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) - req.URL.Path = "/users/jill/orders/%%%%" + req.URL.Path = "/users/jill/orders/%%%%" rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, http.StatusBadRequest, rec.Code) @@ -66,14 +66,14 @@ func TestEchoRewritePreMiddleware(t *testing.T) { // Route r.Add(http.MethodGet, "/new", func(c echo.Context) error { - return c.NoContent(200) + return c.NoContent(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/old", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/new", req.URL.EscapedPath()) - assert.Equal(t, 200, rec.Code) + assert.Equal(t, http.StatusOK, rec.Code) } // Issue #1143 @@ -89,10 +89,10 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { })) r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { - return c.String(200, "hosts") + return c.String(http.StatusOK, "hosts") }) r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { - return c.String(200, "eng") + return c.String(http.StatusOK, "eng") }) for i := 0; i < 100; i++ { @@ -100,7 +100,7 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { rec := httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/api/v1/hosts/test", req.URL.EscapedPath()) - assert.Equal(t, 200, rec.Code) + assert.Equal(t, http.StatusOK, rec.Code) defer rec.Result().Body.Close() bodyBytes, _ := ioutil.ReadAll(rec.Result().Body) From 622f5e33d4ec32aee7c74d0ecd5339ebe1445f7b Mon Sep 17 00:00:00 2001 From: yonbiaoxiao Date: Tue, 15 Sep 2020 16:58:05 +0800 Subject: [PATCH 046/446] Use IndexByte instead of Split to reduce memory allocation and improve performance --- context.go | 6 +++++- context_test.go | 9 +++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/context.go b/context.go index 99ef03bcb..0507f1390 100644 --- a/context.go +++ b/context.go @@ -276,7 +276,11 @@ func (c *context) RealIP() string { } // Fall back to legacy behavior if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" { - return strings.Split(ip, ", ")[0] + i := strings.IndexAny(ip, ", ") + if i > 0 { + return ip[:i] + } + return ip } if ip := c.request.Header.Get(HeaderXRealIP); ip != "" { return ip diff --git a/context_test.go b/context_test.go index 73e5dcb62..56ac4bebf 100644 --- a/context_test.go +++ b/context_test.go @@ -871,3 +871,12 @@ func TestContext_RealIP(t *testing.T) { testify.Equal(t, tt.s, tt.c.RealIP()) } } + +func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) { + c := context{request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, + }} + for i := 0; i < b.N; i++ { + c.RealIP() + } +} From 64c49509963da87bb45f32ec239995933798e5a6 Mon Sep 17 00:00:00 2001 From: yonbiaoxiao Date: Wed, 16 Sep 2020 10:36:43 +0800 Subject: [PATCH 047/446] improve the test coverage for context.go --- context_test.go | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/context_test.go b/context_test.go index 56ac4bebf..0044bf870 100644 --- a/context_test.go +++ b/context_test.go @@ -72,6 +72,15 @@ func BenchmarkAllocXML(b *testing.B) { } } +func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) { + c := context{request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, + }} + for i := 0; i < b.N; i++ { + c.RealIP() + } +} + func (t *Template) Render(w io.Writer, name string, data interface{}, c Context) error { return t.templates.ExecuteTemplate(w, name, data) } @@ -847,6 +856,14 @@ func TestContext_RealIP(t *testing.T) { }, "127.0.0.1", }, + { + &context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1"}}, + }, + }, + "127.0.0.1", + }, { &context{ request: &http.Request{ @@ -871,12 +888,3 @@ func TestContext_RealIP(t *testing.T) { testify.Equal(t, tt.s, tt.c.RealIP()) } } - -func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) { - c := context{request: &http.Request{ - Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, - }} - for i := 0; i < b.N; i++ { - c.RealIP() - } -} From f6dfcbe774b43b9ad12e9b477bb34b4412c23452 Mon Sep 17 00:00:00 2001 From: Arun Gopalpuri Date: Thu, 3 Sep 2020 00:39:57 -0700 Subject: [PATCH 048/446] bugfix proxy and rewrite, updated test with actual call settings --- middleware/middleware.go | 31 ++++++++++++++++++++++--------- middleware/proxy.go | 22 +++++----------------- middleware/proxy_test.go | 15 +++++++-------- middleware/rewrite.go | 30 ++++-------------------------- middleware/rewrite_test.go | 25 ++++++++++--------------- 5 files changed, 48 insertions(+), 75 deletions(-) diff --git a/middleware/middleware.go b/middleware/middleware.go index 12260ddb2..60834b505 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -2,7 +2,6 @@ package middleware import ( "net/http" - "net/url" "regexp" "strconv" "strings" @@ -34,15 +33,29 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { return strings.NewReplacer(replace...) } -//rewritePath sets request url path and raw path -func rewritePath(replacer *strings.Replacer, target string, req *http.Request) error { - replacerRawPath := replacer.Replace(target) - replacerPath, err := url.PathUnescape(replacerRawPath) - if err != nil { - return err +func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string { + // Initialize + rulesRegex := map[*regexp.Regexp]string{} + for k, v := range rewrite { + k = regexp.QuoteMeta(k) + k = strings.Replace(k, `\*`, "(.*)", -1) + if strings.HasPrefix(k, `\^`) { + k = strings.Replace(k, `\^`, "^", -1) + } + k = k + "$" + rulesRegex[regexp.MustCompile(k)] = v + } + return rulesRegex +} + +func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) { + for k, v := range rewriteRegex { + replacerRawPath := captureTokens(k, req.URL.EscapedPath()) + if replacerRawPath != nil { + replacerPath := captureTokens(k, req.URL.Path) + req.URL.RawPath, req.URL.Path = replacerRawPath.Replace(v), replacerPath.Replace(v) + } } - req.URL.Path, req.URL.RawPath = replacerPath, replacerRawPath - return nil } // DefaultSkipper returns false which processes the middleware. diff --git a/middleware/proxy.go b/middleware/proxy.go index cd50b76a1..1b972eb16 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -8,7 +8,6 @@ import ( "net/http" "net/url" "regexp" - "strings" "sync" "sync/atomic" "time" @@ -206,13 +205,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { if config.Balancer == nil { panic("echo: proxy middleware requires balancer") } - config.rewriteRegex = map[*regexp.Regexp]string{} - // Initialize - for k, v := range config.Rewrite { - k = strings.Replace(k, "*", "(\\S*)", -1) - config.rewriteRegex[regexp.MustCompile(k)] = v - } + config.rewriteRegex = rewriteRulesRegex(config.Rewrite) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { @@ -225,16 +219,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { tgt := config.Balancer.Next(c) c.Set(config.ContextKey, tgt) - // Rewrite - for k, v := range config.rewriteRegex { - //use req.URL.Path here or else we will have double escaping - replacer := captureTokens(k, req.URL.Path) - if replacer != nil { - if err := rewritePath(replacer, v, req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "invalid url") - } - } - } + // Set rewrite path and raw path + rewritePath(config.rewriteRegex, req) // Fix header // Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream. @@ -265,3 +251,5 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { } } } + + diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 4bb74648c..534e45f44 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -94,36 +94,35 @@ func TestProxy(t *testing.T) { "/users/*/orders/*": "/user/$1/order/$2", }, })) - req.URL.Path = "/api/users" + req.URL, _ = url.Parse("/api/users") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/users", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) - req.URL.Path = "/js/main.js" + req.URL, _ = url.Parse( "/js/main.js") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) - req.URL.Path = "/old" + req.URL, _ = url.Parse("/old") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/new", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) - req.URL.Path = "/users/jack/orders/1" + req.URL, _ = url.Parse( "/users/jack/orders/1") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) - req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F" + req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) assert.Equal(t, http.StatusOK, rec.Code) - req.URL.Path = "/users/jill/orders/%%%%" + req.URL, _ = url.Parse("/api/new users") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusBadRequest, rec.Code) - + assert.Equal(t, "/new%20users", req.URL.EscapedPath()) // ModifyResponse e = echo.New() e.Use(ProxyWithConfig(ProxyConfig{ diff --git a/middleware/rewrite.go b/middleware/rewrite.go index 855c8633a..0965e313f 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -1,11 +1,8 @@ package middleware import ( - "net/http" - "regexp" - "strings" - "github.com/labstack/echo/v4" + "regexp" ) type ( @@ -54,18 +51,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { if config.Skipper == nil { config.Skipper = DefaultBodyDumpConfig.Skipper } - config.rulesRegex = map[*regexp.Regexp]string{} - // Initialize - for k, v := range config.Rules { - k = regexp.QuoteMeta(k) - k = strings.Replace(k, `\*`, "(.*)", -1) - if strings.HasPrefix(k, `\^`) { - k = strings.Replace(k, `\^`, "^", -1) - } - k = k + "$" - config.rulesRegex[regexp.MustCompile(k)] = v - } + config.rulesRegex = rewriteRulesRegex(config.Rules) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { @@ -74,17 +61,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { } req := c.Request() - // Rewrite - for k, v := range config.rulesRegex { - //use req.URL.Path here or else we will have double escaping - replacer := captureTokens(k, req.URL.Path) - if replacer != nil { - if err := rewritePath(replacer, v, req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "invalid url") - } - break - } - } + // Set rewrite path and raw path + rewritePath(config.rulesRegex, req) return next(c) } } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index a9b3437ce..abf11b2f7 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -4,6 +4,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/labstack/echo/v4" @@ -23,33 +24,28 @@ func TestRewrite(t *testing.T) { })) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - req.URL.Path = "/api/users" + req.URL, _ = url.Parse("/api/users") e.ServeHTTP(rec, req) assert.Equal(t, "/users", req.URL.EscapedPath()) - req.URL.Path = "/js/main.js" + req.URL, _ = url.Parse("/js/main.js") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) - req.URL.Path = "/old" + req.URL, _ = url.Parse("/old") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/new", req.URL.EscapedPath()) - req.URL.Path = "/users/jack/orders/1" + req.URL, _ = url.Parse("/users/jack/orders/1") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) - req.URL.Path = "/api/new users" - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/new%20users", req.URL.EscapedPath()) - req.URL.Path = "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F" + req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F") rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) - req.URL.Path = "/users/jill/orders/%%%%" - rec = httptest.NewRecorder() + req.URL, _ = url.Parse("/api/new users") e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Equal(t, "/new%20users", req.URL.EscapedPath()) } // Issue #1086 @@ -58,11 +54,10 @@ func TestEchoRewritePreMiddleware(t *testing.T) { r := e.Router() // Rewrite old url to new one - e.Pre(RewriteWithConfig(RewriteConfig{ - Rules: map[string]string{ + e.Pre(Rewrite(map[string]string{ "/old": "/new", }, - })) + )) // Route r.Add(http.MethodGet, "/new", func(c echo.Context) error { From 42271822e43f390db71e7df1574e858f5cd1a921 Mon Sep 17 00:00:00 2001 From: yonbiaoxiao Date: Tue, 20 Oct 2020 11:54:40 +0800 Subject: [PATCH 049/446] remove unless defer --- context.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/context.go b/context.go index 99ef03bcb..3bc0bbd9c 100644 --- a/context.go +++ b/context.go @@ -361,7 +361,7 @@ func (c *context) FormFile(name string) (*multipart.FileHeader, error) { if err != nil { return nil, err } - defer f.Close() + f.Close() return fh, nil } From 85e521f3848b272a5b4fefd496c4846ef32e8476 Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Sun, 25 Oct 2020 04:03:07 +0000 Subject: [PATCH 050/446] Fixed panic when Router#Find fails on Param paths Fixed panic when Router#Find fails to find a route that could match a Param route that only have children routes and no root route. e.g /create /:id/edit /:id/active Finding /creates results in panic because the router tree node that belongs to the param route :id don't have pnames on it. The childrens of :id (:id/edit and :id/active) have the pnames properly set, but those are not processed because /creates don't match on those paths. --- go.sum | 2 ++ router.go | 4 +++- router_test.go | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/go.sum b/go.sum index 58c80c831..187f309f9 100644 --- a/go.sum +++ b/go.sum @@ -15,6 +15,7 @@ github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHX github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= @@ -46,6 +47,7 @@ golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e h1:FDhOuMEY4JVRztM/gsbk+IKUQ8kj74bxZrgw87eMMVc= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/router.go b/router.go index ed728d6a2..4c3898c46 100644 --- a/router.go +++ b/router.go @@ -428,7 +428,9 @@ func (r *Router) Find(method, path string, c Context) { pos := strings.IndexByte(ns, '/') if pos == -1 { // If no slash is remaining in search string set param value - pvalues[len(cn.pnames)-1] = search + if len(cn.pnames) > 0 { + pvalues[len(cn.pnames)-1] = search + } break } else if pos > 0 { // Otherwise continue route processing with restored next node diff --git a/router_test.go b/router_test.go index 0e883233b..fca3a79bb 100644 --- a/router_test.go +++ b/router_test.go @@ -1298,6 +1298,40 @@ func TestRouterParam1466(t *testing.T) { assert.Equal(t, 0, c.response.Status) } +// Issue #1653 +func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/users/create", handlerHelper("create", 1)) + r.Add(http.MethodGet, "/users/:id/edit", func(c Context) error { + return nil + }) + r.Add(http.MethodGet, "/users/:id/active", func(c Context) error { + return nil + }) + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/alice/edit", c) + assert.Equal(t, "alice", c.Param("id")) + + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/bob/active", c) + assert.Equal(t, "bob", c.Param("id")) + + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/create", c) + c.Handler()(c) + assert.Equal(t, 1, c.Get("create")) + assert.Equal(t, "/users/create", c.Get("path")) + + //This panic before the fix for Issue #1653 + c = e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/users/createNotFound", c) + he := c.Handler()(c).(*HTTPError) + assert.Equal(t, http.StatusNotFound, he.Code) +} + func benchmarkRouterRoutes(b *testing.B, routes []*Route) { e := New() r := e.router From 23c21871b7f9f3a3b0cb913161f55806746a5d6a Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Wed, 28 Oct 2020 04:30:41 +0000 Subject: [PATCH 051/446] Fixed Router#Find panic an infinite loop Before this fix, Router#Find panics or enters in an infinite loop when the context params values were set to a number less than the max number of params supported by the Router. --- context.go | 24 ++++++++++++++++++++++-- context_test.go | 34 ++++++++++++++++++++++++++++++++++ router_test.go | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 2 deletions(-) diff --git a/context.go b/context.go index 99ef03bcb..fb686c930 100644 --- a/context.go +++ b/context.go @@ -310,7 +310,19 @@ func (c *context) ParamNames() []string { func (c *context) SetParamNames(names ...string) { c.pnames = names - *c.echo.maxParam = len(names) + + l := len(names) + if *c.echo.maxParam < l { + *c.echo.maxParam = l + } + + if len(c.pvalues) < l { + // Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them, + // probably those values will be overriden in a Context#SetParamValues + newPvalues := make([]string, l) + copy(newPvalues, c.pvalues) + c.pvalues = newPvalues + } } func (c *context) ParamValues() []string { @@ -318,7 +330,15 @@ func (c *context) ParamValues() []string { } func (c *context) SetParamValues(values ...string) { - c.pvalues = values + // NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam at all times + // It will brake the Router#Find code + limit := len(values) + if limit > *c.echo.maxParam { + limit = *c.echo.maxParam + } + for i := 0; i < limit; i++ { + c.pvalues[i] = values[i] + } } func (c *context) QueryParam(name string) string { diff --git a/context_test.go b/context_test.go index 73e5dcb62..f17574583 100644 --- a/context_test.go +++ b/context_test.go @@ -508,6 +508,40 @@ func TestContextGetAndSetParam(t *testing.T) { }) } +// Issue #1655 +func TestContextSetParamNamesShouldUpdateEchoMaxParam(t *testing.T) { + assert := testify.New(t) + + e := New() + assert.Equal(0, *e.maxParam) + + expectedOneParam := []string{"one"} + expectedTwoParams := []string{"one", "two"} + expectedThreeParams := []string{"one", "two", ""} + expectedABCParams := []string{"A", "B", "C"} + + c := e.NewContext(nil, nil) + c.SetParamNames("1", "2") + c.SetParamValues(expectedTwoParams...) + assert.Equal(2, *e.maxParam) + assert.EqualValues(expectedTwoParams, c.ParamValues()) + + c.SetParamNames("1") + assert.Equal(2, *e.maxParam) + // Here for backward compatibility the ParamValues remains as they are + assert.EqualValues(expectedOneParam, c.ParamValues()) + + c.SetParamNames("1", "2", "3") + assert.Equal(3, *e.maxParam) + // Here for backward compatibility the ParamValues remains as they are, but the len is extended to e.maxParam + assert.EqualValues(expectedThreeParams, c.ParamValues()) + + c.SetParamValues("A", "B", "C", "D") + assert.Equal(3, *e.maxParam) + // Here D shouldn't be returned + assert.EqualValues(expectedABCParams, c.ParamValues()) +} + func TestContextFormValue(t *testing.T) { f := make(url.Values) f.Set("name", "Jon Snow") diff --git a/router_test.go b/router_test.go index 0e883233b..d0972720b 100644 --- a/router_test.go +++ b/router_test.go @@ -1298,6 +1298,43 @@ func TestRouterParam1466(t *testing.T) { assert.Equal(t, 0, c.response.Status) } +// Issue #1655 +func TestRouterFindNotPanicOrLoopsWhenContextSetParamValuesIsCalledWithLessValuesThanEchoMaxParam(t *testing.T) { + e := New() + r := e.router + + v0 := e.Group("/:version") + v0.GET("/admin", func(c Context) error { + c.SetParamNames("version") + c.SetParamValues("v1") + return nil + }) + + v0.GET("/images/view/:id", handlerHelper("iv", 1)) + v0.GET("/images/:id", handlerHelper("i", 1)) + v0.GET("/view/*", handlerHelper("v", 1)) + + //If this API is called before the next two one panic the other loops ( of course without my fix ;) ) + c := e.NewContext(nil, nil) + r.Find(http.MethodGet, "/v1/admin", c) + c.Handler()(c) + assert.Equal(t, "v1", c.Param("version")) + + //panic + c = e.NewContext(nil, nil) + r.Find(http.MethodGet, "/v1/view/same-data", c) + c.Handler()(c) + assert.Equal(t, "same-data", c.Param("*")) + assert.Equal(t, 1, c.Get("v")) + + //looping + c = e.NewContext(nil, nil) + r.Find(http.MethodGet, "/v1/images/view", c) + c.Handler()(c) + assert.Equal(t, "view", c.Param("id")) + assert.Equal(t, 1, c.Get("i")) +} + func benchmarkRouterRoutes(b *testing.B, routes []*Route) { e := New() r := e.router From 44b4054b9e2f8d0fe613fc921505860290506a4e Mon Sep 17 00:00:00 2001 From: santosh653 <70637961+santosh653@users.noreply.github.com> Date: Mon, 2 Nov 2020 10:10:16 -0500 Subject: [PATCH 052/446] Update .travis.yml adding power support --- .travis.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.travis.yml b/.travis.yml index ef826e952..67d45ad78 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,3 +1,7 @@ +arch: + - amd64 + - ppc64le + language: go go: - 1.14.x From cdd946aaa0aecbba17ac64c6924cc15359ca81cb Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Thu, 5 Nov 2020 03:37:15 +0100 Subject: [PATCH 053/446] Fix DefaultHTTPErrorHandler with Debug=true (#1477) --- echo.go | 10 ++++++---- echo_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/echo.go b/echo.go index 128f84fd2..29b88b706 100644 --- a/echo.go +++ b/echo.go @@ -362,10 +362,12 @@ func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { // Issue #1426 code := he.Code message := he.Message - if e.Debug { - message = err.Error() - } else if m, ok := message.(string); ok { - message = Map{"message": m} + if m, ok := he.Message.(string); ok { + if e.Debug { + message = Map{"message": m, "error": err.Error()} + } else { + message = Map{"message": m} + } } // Send response diff --git a/echo_test.go b/echo_test.go index e1706eff7..0368dbd7a 100644 --- a/echo_test.go +++ b/echo_test.go @@ -568,6 +568,49 @@ func TestHTTPError(t *testing.T) { }) } +func TestDefaultHTTPErrorHandler(t *testing.T) { + e := New() + e.Debug = true + e.Any("/plain", func(c Context) error { + return errors.New("An error occurred") + }) + e.Any("/badrequest", func(c Context) error { + return NewHTTPError(http.StatusBadRequest, "Invalid request") + }) + e.Any("/servererror", func(c Context) error { + return NewHTTPError(http.StatusInternalServerError, map[string]interface{}{ + "code": 33, + "message": "Something bad happened", + "error": "stackinfo", + }) + }) + // With Debug=true plain response contains error message + c, b := request(http.MethodGet, "/plain", e) + assert.Equal(t, http.StatusInternalServerError, c) + assert.Equal(t, "{\n \"error\": \"An error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", b) + // and special handling for HTTPError + c, b = request(http.MethodGet, "/badrequest", e) + assert.Equal(t, http.StatusBadRequest, c) + assert.Equal(t, "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n", b) + // complex errors are serialized to pretty JSON + c, b = request(http.MethodGet, "/servererror", e) + assert.Equal(t, http.StatusInternalServerError, c) + assert.Equal(t, "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", b) + + e.Debug = false + // With Debug=false the error response is shortened + c, b = request(http.MethodGet, "/plain", e) + assert.Equal(t, http.StatusInternalServerError, c) + assert.Equal(t, "{\"message\":\"Internal Server Error\"}\n", b) + c, b = request(http.MethodGet, "/badrequest", e) + assert.Equal(t, http.StatusBadRequest, c) + assert.Equal(t, "{\"message\":\"Invalid request\"}\n", b) + // No difference for error response with non plain string errors + c, b = request(http.MethodGet, "/servererror", e) + assert.Equal(t, http.StatusInternalServerError, c) + assert.Equal(t, "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n", b) +} + func TestEchoClose(t *testing.T) { e := New() errCh := make(chan error) From 4727bc6e997162b72631eb427873456118452c8e Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Thu, 5 Nov 2020 04:15:32 +0000 Subject: [PATCH 054/446] Adding Echo#ListenerNetwork as configuration Now Echo could be configured to Listen on tcp supported networks of net.Listen Go standard library (tcp, tcp4, tcp6) --- echo.go | 22 +++++++++++------- echo_test.go | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 8 deletions(-) diff --git a/echo.go b/echo.go index 128f84fd2..524bd9d92 100644 --- a/echo.go +++ b/echo.go @@ -92,6 +92,7 @@ type ( Renderer Renderer Logger Logger IPExtractor IPExtractor + ListenerNetwork string } // Route contains a handler and information for matching against requests. @@ -281,6 +282,7 @@ var ( ErrInvalidRedirectCode = errors.New("invalid redirect status code") ErrCookieNotFound = errors.New("cookie not found") ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") + ErrInvalidListenerNetwork = errors.New("invalid listener network") ) // Error handlers @@ -302,9 +304,10 @@ func New() (e *Echo) { AutoTLSManager: autocert.Manager{ Prompt: autocert.AcceptTOS, }, - Logger: log.New("echo"), - colorer: color.New(), - maxParam: new(int), + Logger: log.New("echo"), + colorer: color.New(), + maxParam: new(int), + ListenerNetwork: "tcp", } e.Server.Handler = e e.TLSServer.Handler = e @@ -712,7 +715,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) { if s.TLSConfig == nil { if e.Listener == nil { - e.Listener, err = newListener(s.Addr) + e.Listener, err = newListener(s.Addr, e.ListenerNetwork) if err != nil { return err } @@ -723,7 +726,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) { return s.Serve(e.Listener) } if e.TLSListener == nil { - l, err := newListener(s.Addr) + l, err := newListener(s.Addr, e.ListenerNetwork) if err != nil { return err } @@ -752,7 +755,7 @@ func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { } if e.Listener == nil { - e.Listener, err = newListener(s.Addr) + e.Listener, err = newListener(s.Addr, e.ListenerNetwork) if err != nil { return err } @@ -873,8 +876,11 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { return } -func newListener(address string) (*tcpKeepAliveListener, error) { - l, err := net.Listen("tcp", address) +func newListener(address, network string) (*tcpKeepAliveListener, error) { + if network != "tcp" && network != "tcp4" && network != "tcp6" { + return nil, ErrInvalidListenerNetwork + } + l, err := net.Listen(network, address) if err != nil { return nil, err } diff --git a/echo_test.go b/echo_test.go index e1706eff7..6f8130091 100644 --- a/echo_test.go +++ b/echo_test.go @@ -4,6 +4,7 @@ import ( "bytes" stdContext "context" "errors" + "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -609,3 +610,66 @@ func TestEchoShutdown(t *testing.T) { err := <-errCh assert.Equal(t, err.Error(), "http: Server closed") } + +var listenerNetworkTests = []struct { + test string + network string + address string +}{ + {"tcp ipv4 address", "tcp", "127.0.0.1:1323"}, + {"tcp ipv6 address", "tcp", "[::1]:1323"}, + {"tcp4 ipv4 address", "tcp4", "127.0.0.1:1323"}, + {"tcp6 ipv6 address", "tcp6", "[::1]:1323"}, +} + +func TestEchoListenerNetwork(t *testing.T) { + for _, tt := range listenerNetworkTests { + t.Run(tt.test, func(t *testing.T) { + e := New() + e.ListenerNetwork = tt.network + + // HandlerFunc + e.GET("/ok", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + + errCh := make(chan error) + + go func() { + errCh <- e.Start(tt.address) + }() + + time.Sleep(200 * time.Millisecond) + + if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil { + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + if body, err := ioutil.ReadAll(resp.Body); err == nil { + assert.Equal(t, "OK", string(body)) + } else { + assert.Fail(t, err.Error()) + } + + } else { + assert.Fail(t, err.Error()) + } + + if err := e.Close(); err != nil { + t.Fatal(err) + } + }) + } +} + +func TestEchoListenerNetworkInvalid(t *testing.T) { + e := New() + e.ListenerNetwork = "unix" + + // HandlerFunc + e.GET("/ok", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + + assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323")) +} From 871ed9c68d47e89a992679d4e70aa6ca62b0ad41 Mon Sep 17 00:00:00 2001 From: Ulas Akdeniz Date: Fri, 6 Nov 2020 01:15:40 +0100 Subject: [PATCH 055/446] Fix incorrect CORS headers - Fix empty Access-Control-Allow-Origin - Set CORS headers only if request Origin is existing and allowed - Increase middleware test coverage --- middleware/cors.go | 23 ++++++- middleware/cors_test.go | 145 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 162 insertions(+), 6 deletions(-) diff --git a/middleware/cors.go b/middleware/cors.go index c263f7319..07df0e57e 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -102,6 +102,17 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { origin := req.Header.Get(echo.HeaderOrigin) allowOrigin := "" + preflight := req.Method == http.MethodOptions + res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) + + // No Origin provided + if origin == "" { + if !preflight { + return next(c) + } + return c.NoContent(http.StatusNoContent) + } + // Check allowed origins for _, o := range config.AllowOrigins { if o == "*" && config.AllowCredentials { @@ -138,9 +149,16 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } } + // Origin not allowed + if allowOrigin == "" { + if !preflight { + return next(c) + } + return c.NoContent(http.StatusNoContent) + } + // Simple request - if req.Method != http.MethodOptions { - res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) + if !preflight { res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) if config.AllowCredentials { res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") @@ -152,7 +170,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } // Preflight request - res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) diff --git a/middleware/cors_test.go b/middleware/cors_test.go index ca922321c..fc34694db 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -17,19 +17,31 @@ func TestCORS(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := CORS()(echo.NotFoundHandler) + req.Header.Set(echo.HeaderOrigin, "localhost") h(c) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + // Wildcard AllowedOrigin with no Origin header in request + req = httptest.NewRequest(http.MethodGet, "/", nil) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h = CORS()(echo.NotFoundHandler) + h(c) + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + // Allow origins req = httptest.NewRequest(http.MethodGet, "/", nil) rec = httptest.NewRecorder() c = e.NewContext(req, rec) h = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"localhost"}, + AllowOrigins: []string{"localhost"}, + AllowCredentials: true, + MaxAge: 3600, })(echo.NotFoundHandler) req.Header.Set(echo.HeaderOrigin, "localhost") h(c) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) // Preflight request req = httptest.NewRequest(http.MethodOptions, "/", nil) @@ -67,6 +79,22 @@ func TestCORS(t *testing.T) { assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge)) + // Preflight request with Access-Control-Request-Headers + req = httptest.NewRequest(http.MethodOptions, "/", nil) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + req.Header.Set(echo.HeaderOrigin, "localhost") + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header") + cors = CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"*"}, + }) + h = cors(echo.NotFoundHandler) + h(c) + assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders)) + assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) + // Preflight request with `AllowOrigins` which allow all subdomains with * req = httptest.NewRequest(http.MethodOptions, "/", nil) rec = httptest.NewRecorder() @@ -126,7 +154,7 @@ func Test_allowOriginScheme(t *testing.T) { if tt.expected { assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } else { - assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) } } } @@ -217,7 +245,118 @@ func Test_allowOriginSubdomain(t *testing.T) { if tt.expected { assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } else { - assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + } + } +} + +func TestCorsHeaders(t *testing.T) { + tests := []struct { + domain, allowedOrigin, method string + expected bool + }{ + { + domain: "", // Request does not have Origin header + allowedOrigin: "*", + method: http.MethodGet, + expected: false, + }, + { + domain: "http://example.com", + allowedOrigin: "*", + method: http.MethodGet, + expected: true, + }, + { + domain: "", // Request does not have Origin header + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: false, + }, + { + domain: "http://bar.com", + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: false, + }, + { + domain: "http://example.com", + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: true, + }, + { + domain: "", // Request does not have Origin header + allowedOrigin: "*", + method: http.MethodOptions, + expected: false, + }, + { + domain: "http://example.com", + allowedOrigin: "*", + method: http.MethodOptions, + expected: true, + }, + { + domain: "", // Request does not have Origin header + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: false, + }, + { + domain: "http://bar.com", + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: false, + }, + { + domain: "http://example.com", + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: true, + }, + } + + e := echo.New() + for _, tt := range tests { + req := httptest.NewRequest(tt.method, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if tt.domain != "" { + req.Header.Set(echo.HeaderOrigin, tt.domain) + } + cors := CORSWithConfig(CORSConfig{ + AllowOrigins: []string{tt.allowedOrigin}, + //AllowCredentials: true, + //MaxAge: 3600, + }) + h := cors(echo.NotFoundHandler) + h(c) + + assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary)) + + expectedAllowOrigin := "" + if tt.allowedOrigin == "*" { + expectedAllowOrigin = "*" + } else { + expectedAllowOrigin = tt.domain + } + + switch { + case tt.expected && tt.method == http.MethodOptions: + assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods) + assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary])) + case tt.expected && tt.method == http.MethodGet: + assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin + default: + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin + } + + if tt.method == http.MethodOptions { + assert.Equal(t, http.StatusNoContent, rec.Code) } } } From d2b8a7fb450c6b9150fbad5034bb965d262fd757 Mon Sep 17 00:00:00 2001 From: pwli Date: Fri, 6 Nov 2020 22:21:05 +0800 Subject: [PATCH 056/446] Fix Static files route not working --- echo.go | 1 + 1 file changed, 1 insertion(+) diff --git a/echo.go b/echo.go index 29b88b706..db64e1c02 100644 --- a/echo.go +++ b/echo.go @@ -498,6 +498,7 @@ func (common) static(prefix, root string, get func(string, HandlerFunc, ...Middl } return c.File(name) } + get(prefix, h) if prefix == "/" { return get(prefix+"*", h) } From 13374d1daa4b551556fff632fd32a5188ddee6a2 Mon Sep 17 00:00:00 2001 From: pwli Date: Fri, 6 Nov 2020 22:25:00 +0800 Subject: [PATCH 057/446] add tests for Echo.Static() --- echo_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/echo_test.go b/echo_test.go index 0368dbd7a..ac4001222 100644 --- a/echo_test.go +++ b/echo_test.go @@ -86,6 +86,14 @@ func TestEchoStatic(t *testing.T) { e.ServeHTTP(rec, req) assert.Equal(http.StatusMovedPermanently, rec.Code) assert.Equal("/folder/", rec.HeaderMap["Location"][0]) + + // Directory Redirect with non-root path + e.Static("/static", "_fixture") + req = httptest.NewRequest(http.MethodGet, "/static", nil) + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(http.StatusMovedPermanently, rec.Code) + assert.Equal("/static/", rec.HeaderMap["Location"][0]) // Directory with index.html e.Static("/", "_fixture") @@ -100,6 +108,40 @@ func TestEchoStatic(t *testing.T) { } +func TestEchoStaticRedirectIndex(t *testing.T) { + assert := assert.New(t) + e := New() + + // HandlerFunc + e.Static("/static", "_fixture") + + errCh := make(chan error) + + go func() { + errCh <- e.Start("127.0.0.1:1323") + }() + + time.Sleep(200 * time.Millisecond) + + if resp, err := http.Get("http://127.0.0.1:1323/static"); err == nil { + defer resp.Body.Close() + assert.Equal(http.StatusOK, resp.StatusCode) + + if body, err := ioutil.ReadAll(resp.Body); err == nil { + assert.Equal(true, strings.HasPrefix(string(body), "")) + } else { + assert.Fail(err.Error()) + } + + } else { + assert.Fail(err.Error()) + } + + if err := e.Close(); err != nil { + t.Fatal(err) + } +} + func TestEchoFile(t *testing.T) { e := New() e.File("/walle", "_fixture/images/walle.png") From ac54e132e409e86ff3a6ec61f7ae3a04a20a8243 Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Sat, 7 Nov 2020 03:52:35 +0000 Subject: [PATCH 058/446] Adding sync.Pool to Compress Middleware Adding a sync.Pool for the *gzip.Writer reduces the allocations of the Compress middleware in 50% and gives an increase on execution speed of a 85% This fix #1643 --- middleware/compress.go | 26 +++++++++++++++++++++---- middleware/compress_test.go | 38 +++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 4 deletions(-) diff --git a/middleware/compress.go b/middleware/compress.go index dd97d983d..e4f9fc514 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "strings" + "sync" "github.com/labstack/echo/v4" ) @@ -58,6 +59,8 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { config.Level = DefaultGzipConfig.Level } + pool := gzipPool(config) + return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if config.Skipper(c) { @@ -68,11 +71,13 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) { res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 - rw := res.Writer - w, err := gzip.NewWriterLevel(rw, config.Level) - if err != nil { - return err + i := pool.Get() + w, ok := i.(*gzip.Writer) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) } + rw := res.Writer + w.Reset(rw) defer func() { if res.Size == 0 { if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { @@ -85,6 +90,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { w.Reset(ioutil.Discard) } w.Close() + pool.Put(w) }() grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw} res.Writer = grw @@ -126,3 +132,15 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { } return http.ErrNotSupported } + +func gzipPool(config GzipConfig) sync.Pool { + return sync.Pool{ + New: func() interface{} { + w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level) + if err != nil { + return err + } + return w + }, + } +} diff --git a/middleware/compress_test.go b/middleware/compress_test.go index ac5b6c3bb..d16ffca43 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -120,6 +120,22 @@ func TestGzipErrorReturned(t *testing.T) { assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) } +func TestGzipErrorReturnedInvalidConfig(t *testing.T) { + e := echo.New() + // Invalid level + e.Use(GzipWithConfig(GzipConfig{Level: 12})) + e.GET("/", func(c echo.Context) error { + c.Response().Write([]byte("test")) + return nil + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Contains(t, rec.Body.String(), "gzip") +} + // Issue #806 func TestGzipWithStatic(t *testing.T) { e := echo.New() @@ -146,3 +162,25 @@ func TestGzipWithStatic(t *testing.T) { } } } + +func BenchmarkGzip(b *testing.B) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + + h := Gzip()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // Gzip + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h(c) + } +} From f72eaa4253ad6a8cd6b124cb9b53d537dbb2d91a Mon Sep 17 00:00:00 2001 From: Segev Finer Date: Sun, 8 Nov 2020 16:33:35 +0200 Subject: [PATCH 059/446] Remove group.Use registering Any routes that break other routes Fixes #1657 --- group.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/group.go b/group.go index 426bef9eb..d239fb581 100644 --- a/group.go +++ b/group.go @@ -23,10 +23,6 @@ func (g *Group) Use(middleware ...MiddlewareFunc) { if len(g.middleware) == 0 { return } - // Allow all requests to reach the group as they might get dropped if router - // doesn't find a match, making none of the group middleware process. - g.Any("", NotFoundHandler) - g.Any("/*", NotFoundHandler) } // CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. From 7a1126fb16449f3142e918ff546a2d611c48f4e3 Mon Sep 17 00:00:00 2001 From: Segev Finer Date: Tue, 10 Nov 2020 19:50:32 +0200 Subject: [PATCH 060/446] Add a test --- group_test.go | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/group_test.go b/group_test.go index c51fd91eb..d4a6846f5 100644 --- a/group_test.go +++ b/group_test.go @@ -119,3 +119,37 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { assert.Equal(t, "/*", m) } + +func TestMultipleGroupSamePathMiddleware(t *testing.T) { + // Ensure multiple groups with the same path do not clobber previous routes or mixup middlewares + e := New() + m1 := func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + c.Set("middleware", "m1") + return next(c) + } + } + m2 := func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + c.Set("middleware", "m2") + return next(c) + } + } + h := func(c Context) error { + return c.String(http.StatusOK, c.Get("middleware").(string)) + } + + g1 := e.Group("/group", m1) + { + g1.GET("", h) + } + g2 := e.Group("/group", m2) + { + g2.GET("/other", h) + } + + _, m := request(http.MethodGet, "/group", e) + assert.Equal(t, "m1", m) + _, m = request(http.MethodGet, "/group/other", e) + assert.Equal(t, "m2", m) +} From 31599cf1f49b12288bf82302f7907ddfc8c9b50e Mon Sep 17 00:00:00 2001 From: Florian Polster Date: Wed, 11 Nov 2020 10:19:05 +0100 Subject: [PATCH 061/446] Workflow also run on changes to go.mod, _fixture, .github changes --- .github/workflows/echo.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 128940cac..38596ab7d 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -6,11 +6,17 @@ on: - master paths: - '**.go' + - 'go.*' + - '_fixture/**' + - '.github/**' pull_request: branches: - master paths: - '**.go' + - 'go.*' + - '_fixture/**' + - '.github/**' env: GO111MODULE: on From ce646ae65ef37e82cbad4135a19c834b1534a034 Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Thu, 12 Nov 2020 00:03:58 -0300 Subject: [PATCH 062/446] Update README.md with an updated Benchmark There is also a related [PR](https://github.com/vishr/web-framework-benchmark/pull/3) to update the benchmark code --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 03ad4dca3..06f8d9d17 100644 --- a/README.md +++ b/README.md @@ -42,11 +42,12 @@ For older versions, please use the latest v3 tag. ## Benchmarks -Date: 2018/03/15
+Date: 2020/11/11
Source: https://github.com/vishr/web-framework-benchmark
Lower is better! - + + ## [Guide](https://echo.labstack.com/guide) From d385a92e513e8514d09d15c759c537e29d7f0d06 Mon Sep 17 00:00:00 2001 From: Ajitem Sahasrabuddhe Date: Fri, 13 Nov 2020 17:30:47 +0530 Subject: [PATCH 063/446] add support for Go 1.15 & drop support for Go 1.12 --- .github/workflows/echo.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 38596ab7d..1bcb5cf96 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -27,7 +27,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - go: [1.12, 1.13, 1.14] + go: [1.13, 1.14, 1.15] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: From b47042f385c522d60c5bead297959d472677c526 Mon Sep 17 00:00:00 2001 From: Arun Gopalpuri Date: Thu, 19 Nov 2020 20:44:00 -0800 Subject: [PATCH 064/446] adding decompress gzipped request middleware --- .github/workflows/echo.yml | 13 +-- middleware/decompress.go | 58 +++++++++++++ middleware/decompress_test.go | 148 ++++++++++++++++++++++++++++++++++ 3 files changed, 213 insertions(+), 6 deletions(-) create mode 100644 middleware/decompress.go create mode 100644 middleware/decompress_test.go diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 38596ab7d..c4fae7735 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -18,10 +18,6 @@ on: - '_fixture/**' - '.github/**' -env: - GO111MODULE: on - GOPROXY: https://proxy.golang.org - jobs: test: strategy: @@ -38,10 +34,15 @@ jobs: - name: Set GOPATH and PATH run: | - echo "::set-env name=GOPATH::$(dirname $GITHUB_WORKSPACE)" - echo "::add-path::$(dirname $GITHUB_WORKSPACE)/bin" + echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV + echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH shell: bash + - name: Set build variables + run: | + echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV + echo "GO111MODULE=on" >> $GITHUB_ENV + - name: Checkout Code uses: actions/checkout@v1 with: diff --git a/middleware/decompress.go b/middleware/decompress.go new file mode 100644 index 000000000..99eaf066d --- /dev/null +++ b/middleware/decompress.go @@ -0,0 +1,58 @@ +package middleware + +import ( + "bytes" + "compress/gzip" + "github.com/labstack/echo/v4" + "io" + "io/ioutil" +) + +type ( + // DecompressConfig defines the config for Decompress middleware. + DecompressConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + } +) + +//GZIPEncoding content-encoding header if set to "gzip", decompress body contents. +const GZIPEncoding string = "gzip" + +var ( + //DefaultDecompressConfig defines the config for decompress middleware + DefaultDecompressConfig = DecompressConfig{Skipper: DefaultSkipper} +) + +//Decompress decompresses request body based if content encoding type is set to "gzip" with default config +func Decompress() echo.MiddlewareFunc { + return DecompressWithConfig(DefaultDecompressConfig) +} + +//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config +func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + switch c.Request().Header.Get(echo.HeaderContentEncoding) { + case GZIPEncoding: + gr, err := gzip.NewReader(c.Request().Body) + if err != nil { + if err == io.EOF { //ignore if body is empty + return next(c) + } + return err + } + defer gr.Close() + var buf bytes.Buffer + io.Copy(&buf, gr) + r := ioutil.NopCloser(&buf) + defer r.Close() + c.Request().Body = r + } + return next(c) + } + } +} diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go new file mode 100644 index 000000000..772c14f6d --- /dev/null +++ b/middleware/decompress_test.go @@ -0,0 +1,148 @@ +package middleware + +import ( + "bytes" + "compress/gzip" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestDecompress(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Skip if no Content-Encoding header + h := Decompress()(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + h(c) + + assert := assert.New(t) + assert.Equal("test", rec.Body.String()) + + // Decompress + body := `{"name": "echo"}` + gz, _ := gzipString(body) + req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h(c) + assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := ioutil.ReadAll(req.Body) + assert.NoError(err) + assert.Equal(body, string(b)) +} + +func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { + e := echo.New() + body := `{"name":"echo"}` + gz, _ := gzipString(body) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := ioutil.ReadAll(req.Body) + assert.NoError(t, err) + assert.NotEqual(t, b, body) + assert.Equal(t, b, gz) +} + +func TestDecompressNoContent(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := Decompress()(func(c echo.Context) error { + return c.NoContent(http.StatusNoContent) + }) + if assert.NoError(t, h(c)) { + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) + assert.Equal(t, 0, len(rec.Body.Bytes())) + } +} + +func TestDecompressErrorReturned(t *testing.T) { + e := echo.New() + e.Use(Decompress()) + e.GET("/", func(c echo.Context) error { + return echo.ErrNotFound + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotFound, rec.Code) + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) +} + +func TestDecompressSkipper(t *testing.T) { + e := echo.New() + e.Use(DecompressWithConfig(DecompressConfig{ + Skipper: func(c echo.Context) bool { + return c.Request().URL.Path == "/skip" + }, + })) + body := `{"name": "echo"}` + req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8) + reqBody, err := ioutil.ReadAll(c.Request().Body) + assert.NoError(t, err) + assert.Equal(t, body, string(reqBody)) +} + +func BenchmarkDecompress(b *testing.B) { + e := echo.New() + body := `{"name": "echo"}` + gz, _ := gzipString(body) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + + h := Decompress()(func(c echo.Context) error { + c.Response().Write([]byte(body)) // For Content-Type sniffing + return nil + }) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // Decompress + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h(c) + } +} + +func gzipString(body string) ([]byte, error) { + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + + _, err := gz.Write([]byte(body)) + if err != nil { + return nil, err + } + + if err := gz.Close(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} From fdbcc2f94e632b7b54e88f86e879d401bcbbc1a7 Mon Sep 17 00:00:00 2001 From: Ajitem Sahasrabuddhe Date: Fri, 20 Nov 2020 17:11:08 +0530 Subject: [PATCH 065/446] add support for go 1.12 --- .github/workflows/echo.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 1bcb5cf96..13b53db6e 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -27,7 +27,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - go: [1.13, 1.14, 1.15] + go: [1.12, 1.13, 1.14, 1.15] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: From 3a6100bebebc7760fb434947daf4e89b47523420 Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Sat, 21 Nov 2020 02:48:16 +0000 Subject: [PATCH 066/446] Improving routing performance and benchmark suite Before this commit, all the node types were added to the same list of children nodes. Taking in consideration that only one Param and Any type of node could exist per node, two new node struct field were added to hold the references to those kind of nodes. This avoid the need to iterate through all the Static type nodes just to find one Param or Any type node. Those iterations could be performed multiple times in the same iteration of Router#Find. Removing the route comments of the Router benchmark tests. Updating the Router benchmarks tests to find the routes defined to each particular benchmark. Before, all the benchmarks tried to find only the GitHub API. Adding new router benchmarks to measure when the Router try to find routes that are not registered. --- router.go | 113 +++++++++++++---------- router_test.go | 245 ++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 271 insertions(+), 87 deletions(-) diff --git a/router.go b/router.go index ed728d6a2..d045d5cdd 100644 --- a/router.go +++ b/router.go @@ -14,14 +14,16 @@ type ( echo *Echo } node struct { - kind kind - label byte - prefix string - parent *node - children children - ppath string - pnames []string - methodHandler *methodHandler + kind kind + label byte + prefix string + parent *node + staticChildrens children + ppath string + pnames []string + methodHandler *methodHandler + paramChildren *node + anyChildren *node } kind uint8 children []*node @@ -44,6 +46,9 @@ const ( skind kind = iota pkind akind + + paramLabel = byte(':') + anyLabel = byte('*') ) // NewRouter returns a new Router instance. @@ -134,23 +139,32 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string } } else if l < pl { // Split node - n := newNode(cn.kind, cn.prefix[l:], cn, cn.children, cn.methodHandler, cn.ppath, cn.pnames) + n := newNode(cn.kind, cn.prefix[l:], cn, cn.staticChildrens, cn.methodHandler, cn.ppath, cn.pnames, cn.paramChildren, cn.anyChildren) // Update parent path for all children to new node - for _, child := range cn.children { + for _, child := range cn.staticChildrens { child.parent = n } + if cn.paramChildren != nil { + cn.paramChildren.parent = n + } + if cn.anyChildren != nil { + cn.anyChildren.parent = n + } // Reset parent node cn.kind = skind cn.label = cn.prefix[0] cn.prefix = cn.prefix[:l] - cn.children = nil + cn.staticChildrens = nil cn.methodHandler = new(methodHandler) cn.ppath = "" cn.pnames = nil + cn.paramChildren = nil + cn.anyChildren = nil - cn.addChild(n) + // Only Static children could reach here + cn.addStaticChild(n) if l == sl { // At parent node @@ -160,9 +174,10 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string cn.pnames = pnames } else { // Create child node - n = newNode(t, search[l:], cn, nil, new(methodHandler), ppath, pnames) + n = newNode(t, search[l:], cn, nil, new(methodHandler), ppath, pnames, nil, nil) n.addHandler(method, h) - cn.addChild(n) + // Only Static children could reach here + cn.addStaticChild(n) } } else if l < sl { search = search[l:] @@ -173,9 +188,16 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string continue } // Create child node - n := newNode(t, search, cn, nil, new(methodHandler), ppath, pnames) + n := newNode(t, search, cn, nil, new(methodHandler), ppath, pnames, nil, nil) n.addHandler(method, h) - cn.addChild(n) + switch t { + case skind: + cn.addStaticChild(n) + case pkind: + cn.paramChildren = n + case akind: + cn.anyChildren = n + } } else { // Node already exists if h != nil { @@ -190,26 +212,28 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string } } -func newNode(t kind, pre string, p *node, c children, mh *methodHandler, ppath string, pnames []string) *node { +func newNode(t kind, pre string, p *node, sc children, mh *methodHandler, ppath string, pnames []string, paramChildren, anyChildren *node) *node { return &node{ - kind: t, - label: pre[0], - prefix: pre, - parent: p, - children: c, - ppath: ppath, - pnames: pnames, - methodHandler: mh, + kind: t, + label: pre[0], + prefix: pre, + parent: p, + staticChildrens: sc, + ppath: ppath, + pnames: pnames, + methodHandler: mh, + paramChildren: paramChildren, + anyChildren: anyChildren, } } -func (n *node) addChild(c *node) { - n.children = append(n.children, c) +func (n *node) addStaticChild(c *node) { + n.staticChildrens = append(n.staticChildrens, c) } -func (n *node) findChild(l byte, t kind) *node { - for _, c := range n.children { - if c.label == l && c.kind == t { +func (n *node) findStaticChild(l byte) *node { + for _, c := range n.staticChildrens { + if c.label == l { return c } } @@ -217,19 +241,16 @@ func (n *node) findChild(l byte, t kind) *node { } func (n *node) findChildWithLabel(l byte) *node { - for _, c := range n.children { + for _, c := range n.staticChildrens { if c.label == l { return c } } - return nil -} - -func (n *node) findChildByKind(t kind) *node { - for _, c := range n.children { - if c.kind == t { - return c - } + if l == paramLabel { + return n.paramChildren + } + if l == anyLabel { + return n.anyChildren } return nil } @@ -356,7 +377,7 @@ func (r *Router) Find(method, path string, c Context) { // Attempt to go back up the tree on no matching prefix or no remaining search if l != pl || search == "" { // Handle special case of trailing slash route with existing any route (see #1526) - if path[len(path)-1] == '/' && cn.findChildByKind(akind) != nil { + if path[len(path)-1] == '/' && cn.anyChildren != nil { goto Any } if nn == nil { // Issue #1348 @@ -372,7 +393,7 @@ func (r *Router) Find(method, path string, c Context) { } // Static node - if child = cn.findChild(search[0], skind); child != nil { + if child = cn.findStaticChild(search[0]); child != nil { // Save next if cn.prefix[len(cn.prefix)-1] == '/' { // Issue #623 nk = pkind @@ -385,7 +406,7 @@ func (r *Router) Find(method, path string, c Context) { Param: // Param node - if child = cn.findChildByKind(pkind); child != nil { + if child = cn.paramChildren; child != nil { // Issue #378 if len(pvalues) == n { continue @@ -410,7 +431,7 @@ func (r *Router) Find(method, path string, c Context) { Any: // Any node - if cn = cn.findChildByKind(akind); cn != nil { + if cn = cn.anyChildren; cn != nil { // If any node is found, use remaining path for pvalues pvalues[len(cn.pnames)-1] = search break @@ -424,7 +445,7 @@ func (r *Router) Find(method, path string, c Context) { search = ns np := nn.parent // Consider param route one level up only - if cn = nn.findChildByKind(pkind); cn != nil { + if cn = nn.paramChildren; cn != nil { pos := strings.IndexByte(ns, '/') if pos == -1 { // If no slash is remaining in search string set param value @@ -441,7 +462,7 @@ func (r *Router) Find(method, path string, c Context) { // No param route found, try to resolve nearest any route for { np = nn.parent - if cn = nn.findChildByKind(akind); cn != nil { + if cn = nn.anyChildren; cn != nil { break } if np == nil { @@ -472,7 +493,7 @@ func (r *Router) Find(method, path string, c Context) { // Dig further for any, might have an empty value for *, e.g. // serving a directory. Issue #207. - if cn = cn.findChildByKind(akind); cn == nil { + if cn = cn.anyChildren; cn == nil { return } if h := cn.findHandler(method); h != nil { diff --git a/router_test.go b/router_test.go index 0e883233b..ac34547fb 100644 --- a/router_test.go +++ b/router_test.go @@ -175,8 +175,10 @@ var ( {"GET", "/authorizations", ""}, {"GET", "/authorizations/:id", ""}, {"POST", "/authorizations", ""}, - //{"PUT", "/authorizations/clients/:client_id", ""}, - //{"PATCH", "/authorizations/:id", ""}, + + {"PUT", "/authorizations/clients/:client_id", ""}, + {"PATCH", "/authorizations/:id", ""}, + {"DELETE", "/authorizations/:id", ""}, {"GET", "/applications/:client_id/tokens/:access_token", ""}, {"DELETE", "/applications/:client_id/tokens", ""}, @@ -198,7 +200,9 @@ var ( {"PUT", "/notifications", ""}, {"PUT", "/repos/:owner/:repo/notifications", ""}, {"GET", "/notifications/threads/:id", ""}, - //{"PATCH", "/notifications/threads/:id", ""}, + + {"PATCH", "/notifications/threads/:id", ""}, + {"GET", "/notifications/threads/:id/subscription", ""}, {"PUT", "/notifications/threads/:id/subscription", ""}, {"DELETE", "/notifications/threads/:id/subscription", ""}, @@ -221,11 +225,15 @@ var ( // Gists {"GET", "/users/:user/gists", ""}, {"GET", "/gists", ""}, - //{"GET", "/gists/public", ""}, - //{"GET", "/gists/starred", ""}, + + {"GET", "/gists/public", ""}, + {"GET", "/gists/starred", ""}, + {"GET", "/gists/:id", ""}, {"POST", "/gists", ""}, - //{"PATCH", "/gists/:id", ""}, + + {"PATCH", "/gists/:id", ""}, + {"PUT", "/gists/:id/star", ""}, {"DELETE", "/gists/:id/star", ""}, {"GET", "/gists/:id/star", ""}, @@ -237,11 +245,15 @@ var ( {"POST", "/repos/:owner/:repo/git/blobs", ""}, {"GET", "/repos/:owner/:repo/git/commits/:sha", ""}, {"POST", "/repos/:owner/:repo/git/commits", ""}, - //{"GET", "/repos/:owner/:repo/git/refs/*ref", ""}, + + {"GET", "/repos/:owner/:repo/git/refs/*ref", ""}, + {"GET", "/repos/:owner/:repo/git/refs", ""}, {"POST", "/repos/:owner/:repo/git/refs", ""}, - //{"PATCH", "/repos/:owner/:repo/git/refs/*ref", ""}, - //{"DELETE", "/repos/:owner/:repo/git/refs/*ref", ""}, + + {"PATCH", "/repos/:owner/:repo/git/refs/*ref", ""}, + {"DELETE", "/repos/:owner/:repo/git/refs/*ref", ""}, + {"GET", "/repos/:owner/:repo/git/tags/:sha", ""}, {"POST", "/repos/:owner/:repo/git/tags", ""}, {"GET", "/repos/:owner/:repo/git/trees/:sha", ""}, @@ -254,22 +266,32 @@ var ( {"GET", "/repos/:owner/:repo/issues", ""}, {"GET", "/repos/:owner/:repo/issues/:number", ""}, {"POST", "/repos/:owner/:repo/issues", ""}, - //{"PATCH", "/repos/:owner/:repo/issues/:number", ""}, + + {"PATCH", "/repos/:owner/:repo/issues/:number", ""}, + {"GET", "/repos/:owner/:repo/assignees", ""}, {"GET", "/repos/:owner/:repo/assignees/:assignee", ""}, {"GET", "/repos/:owner/:repo/issues/:number/comments", ""}, - //{"GET", "/repos/:owner/:repo/issues/comments", ""}, - //{"GET", "/repos/:owner/:repo/issues/comments/:id", ""}, + + {"GET", "/repos/:owner/:repo/issues/comments", ""}, + {"GET", "/repos/:owner/:repo/issues/comments/:id", ""}, + {"POST", "/repos/:owner/:repo/issues/:number/comments", ""}, - //{"PATCH", "/repos/:owner/:repo/issues/comments/:id", ""}, - //{"DELETE", "/repos/:owner/:repo/issues/comments/:id", ""}, + + {"PATCH", "/repos/:owner/:repo/issues/comments/:id", ""}, + {"DELETE", "/repos/:owner/:repo/issues/comments/:id", ""}, + {"GET", "/repos/:owner/:repo/issues/:number/events", ""}, - //{"GET", "/repos/:owner/:repo/issues/events", ""}, - //{"GET", "/repos/:owner/:repo/issues/events/:id", ""}, + + {"GET", "/repos/:owner/:repo/issues/events", ""}, + {"GET", "/repos/:owner/:repo/issues/events/:id", ""}, + {"GET", "/repos/:owner/:repo/labels", ""}, {"GET", "/repos/:owner/:repo/labels/:name", ""}, {"POST", "/repos/:owner/:repo/labels", ""}, - //{"PATCH", "/repos/:owner/:repo/labels/:name", ""}, + + {"PATCH", "/repos/:owner/:repo/labels/:name", ""}, + {"DELETE", "/repos/:owner/:repo/labels/:name", ""}, {"GET", "/repos/:owner/:repo/issues/:number/labels", ""}, {"POST", "/repos/:owner/:repo/issues/:number/labels", ""}, @@ -280,7 +302,9 @@ var ( {"GET", "/repos/:owner/:repo/milestones", ""}, {"GET", "/repos/:owner/:repo/milestones/:number", ""}, {"POST", "/repos/:owner/:repo/milestones", ""}, - //{"PATCH", "/repos/:owner/:repo/milestones/:number", ""}, + + {"PATCH", "/repos/:owner/:repo/milestones/:number", ""}, + {"DELETE", "/repos/:owner/:repo/milestones/:number", ""}, // Miscellaneous @@ -296,7 +320,9 @@ var ( {"GET", "/users/:user/orgs", ""}, {"GET", "/user/orgs", ""}, {"GET", "/orgs/:org", ""}, - //{"PATCH", "/orgs/:org", ""}, + + {"PATCH", "/orgs/:org", ""}, + {"GET", "/orgs/:org/members", ""}, {"GET", "/orgs/:org/members/:user", ""}, {"DELETE", "/orgs/:org/members/:user", ""}, @@ -307,7 +333,9 @@ var ( {"GET", "/orgs/:org/teams", ""}, {"GET", "/teams/:id", ""}, {"POST", "/orgs/:org/teams", ""}, - //{"PATCH", "/teams/:id", ""}, + + {"PATCH", "/teams/:id", ""}, + {"DELETE", "/teams/:id", ""}, {"GET", "/teams/:id/members", ""}, {"GET", "/teams/:id/members/:user", ""}, @@ -323,17 +351,22 @@ var ( {"GET", "/repos/:owner/:repo/pulls", ""}, {"GET", "/repos/:owner/:repo/pulls/:number", ""}, {"POST", "/repos/:owner/:repo/pulls", ""}, - //{"PATCH", "/repos/:owner/:repo/pulls/:number", ""}, + + {"PATCH", "/repos/:owner/:repo/pulls/:number", ""}, + {"GET", "/repos/:owner/:repo/pulls/:number/commits", ""}, {"GET", "/repos/:owner/:repo/pulls/:number/files", ""}, {"GET", "/repos/:owner/:repo/pulls/:number/merge", ""}, {"PUT", "/repos/:owner/:repo/pulls/:number/merge", ""}, {"GET", "/repos/:owner/:repo/pulls/:number/comments", ""}, - //{"GET", "/repos/:owner/:repo/pulls/comments", ""}, - //{"GET", "/repos/:owner/:repo/pulls/comments/:number", ""}, + + {"GET", "/repos/:owner/:repo/pulls/comments", ""}, + {"GET", "/repos/:owner/:repo/pulls/comments/:number", ""}, + {"PUT", "/repos/:owner/:repo/pulls/:number/comments", ""}, - //{"PATCH", "/repos/:owner/:repo/pulls/comments/:number", ""}, - //{"DELETE", "/repos/:owner/:repo/pulls/comments/:number", ""}, + + {"PATCH", "/repos/:owner/:repo/pulls/comments/:number", ""}, + {"DELETE", "/repos/:owner/:repo/pulls/comments/:number", ""}, // Repositories {"GET", "/user/repos", ""}, @@ -343,7 +376,9 @@ var ( {"POST", "/user/repos", ""}, {"POST", "/orgs/:org/repos", ""}, {"GET", "/repos/:owner/:repo", ""}, - //{"PATCH", "/repos/:owner/:repo", ""}, + + {"PATCH", "/repos/:owner/:repo", ""}, + {"GET", "/repos/:owner/:repo/contributors", ""}, {"GET", "/repos/:owner/:repo/languages", ""}, {"GET", "/repos/:owner/:repo/teams", ""}, @@ -359,19 +394,26 @@ var ( {"GET", "/repos/:owner/:repo/commits/:sha/comments", ""}, {"POST", "/repos/:owner/:repo/commits/:sha/comments", ""}, {"GET", "/repos/:owner/:repo/comments/:id", ""}, - //{"PATCH", "/repos/:owner/:repo/comments/:id", ""}, + + {"PATCH", "/repos/:owner/:repo/comments/:id", ""}, + {"DELETE", "/repos/:owner/:repo/comments/:id", ""}, {"GET", "/repos/:owner/:repo/commits", ""}, {"GET", "/repos/:owner/:repo/commits/:sha", ""}, {"GET", "/repos/:owner/:repo/readme", ""}, + //{"GET", "/repos/:owner/:repo/contents/*path", ""}, //{"PUT", "/repos/:owner/:repo/contents/*path", ""}, //{"DELETE", "/repos/:owner/:repo/contents/*path", ""}, - //{"GET", "/repos/:owner/:repo/:archive_format/:ref", ""}, + + {"GET", "/repos/:owner/:repo/:archive_format/:ref", ""}, + {"GET", "/repos/:owner/:repo/keys", ""}, {"GET", "/repos/:owner/:repo/keys/:id", ""}, {"POST", "/repos/:owner/:repo/keys", ""}, - //{"PATCH", "/repos/:owner/:repo/keys/:id", ""}, + + {"PATCH", "/repos/:owner/:repo/keys/:id", ""}, + {"DELETE", "/repos/:owner/:repo/keys/:id", ""}, {"GET", "/repos/:owner/:repo/downloads", ""}, {"GET", "/repos/:owner/:repo/downloads/:id", ""}, @@ -381,14 +423,18 @@ var ( {"GET", "/repos/:owner/:repo/hooks", ""}, {"GET", "/repos/:owner/:repo/hooks/:id", ""}, {"POST", "/repos/:owner/:repo/hooks", ""}, - //{"PATCH", "/repos/:owner/:repo/hooks/:id", ""}, + + {"PATCH", "/repos/:owner/:repo/hooks/:id", ""}, + {"POST", "/repos/:owner/:repo/hooks/:id/tests", ""}, {"DELETE", "/repos/:owner/:repo/hooks/:id", ""}, {"POST", "/repos/:owner/:repo/merges", ""}, {"GET", "/repos/:owner/:repo/releases", ""}, {"GET", "/repos/:owner/:repo/releases/:id", ""}, {"POST", "/repos/:owner/:repo/releases", ""}, - //{"PATCH", "/repos/:owner/:repo/releases/:id", ""}, + + {"PATCH", "/repos/:owner/:repo/releases/:id", ""}, + {"DELETE", "/repos/:owner/:repo/releases/:id", ""}, {"GET", "/repos/:owner/:repo/releases/:id/assets", ""}, {"GET", "/repos/:owner/:repo/stats/contributors", ""}, @@ -412,7 +458,9 @@ var ( // Users {"GET", "/users/:user", ""}, {"GET", "/user", ""}, - //{"PATCH", "/user", ""}, + + {"PATCH", "/user", ""}, + {"GET", "/users", ""}, {"GET", "/user/emails", ""}, {"POST", "/user/emails", ""}, @@ -429,7 +477,9 @@ var ( {"GET", "/user/keys", ""}, {"GET", "/user/keys/:id", ""}, {"POST", "/user/keys", ""}, - //{"PATCH", "/user/keys/:id", ""}, + + {"PATCH", "/user/keys/:id", ""}, + {"DELETE", "/user/keys/:id", ""}, } @@ -500,6 +550,88 @@ var ( {"DELETE", "/moments/:id", ""}, } + paramAndAnyAPI = []*Route{ + {"GET", "/root/:first/foo/*", ""}, + {"GET", "/root/:first/:second/*", ""}, + {"GET", "/root/:first/bar/:second/*", ""}, + {"GET", "/root/:first/qux/:second/:third/:fourth", ""}, + {"GET", "/root/:first/qux/:second/:third/:fourth/*", ""}, + {"GET", "/root/*", ""}, + + {"POST", "/root/:first/foo/*", ""}, + {"POST", "/root/:first/:second/*", ""}, + {"POST", "/root/:first/bar/:second/*", ""}, + {"POST", "/root/:first/qux/:second/:third/:fourth", ""}, + {"POST", "/root/:first/qux/:second/:third/:fourth/*", ""}, + {"POST", "/root/*", ""}, + + {"PUT", "/root/:first/foo/*", ""}, + {"PUT", "/root/:first/:second/*", ""}, + {"PUT", "/root/:first/bar/:second/*", ""}, + {"PUT", "/root/:first/qux/:second/:third/:fourth", ""}, + {"PUT", "/root/:first/qux/:second/:third/:fourth/*", ""}, + {"PUT", "/root/*", ""}, + + {"DELETE", "/root/:first/foo/*", ""}, + {"DELETE", "/root/:first/:second/*", ""}, + {"DELETE", "/root/:first/bar/:second/*", ""}, + {"DELETE", "/root/:first/qux/:second/:third/:fourth", ""}, + {"DELETE", "/root/:first/qux/:second/:third/:fourth/*", ""}, + {"DELETE", "/root/*", ""}, + } + + paramAndAnyAPIToFind = []*Route{ + {"GET", "/root/one/foo/after/the/asterisk", ""}, + {"GET", "/root/one/foo/path/after/the/asterisk", ""}, + {"GET", "/root/one/two/path/after/the/asterisk", ""}, + {"GET", "/root/one/bar/two/after/the/asterisk", ""}, + {"GET", "/root/one/qux/two/three/four", ""}, + {"GET", "/root/one/qux/two/three/four/after/the/asterisk", ""}, + + {"POST", "/root/one/foo/after/the/asterisk", ""}, + {"POST", "/root/one/foo/path/after/the/asterisk", ""}, + {"POST", "/root/one/two/path/after/the/asterisk", ""}, + {"POST", "/root/one/bar/two/after/the/asterisk", ""}, + {"POST", "/root/one/qux/two/three/four", ""}, + {"POST", "/root/one/qux/two/three/four/after/the/asterisk", ""}, + + {"PUT", "/root/one/foo/after/the/asterisk", ""}, + {"PUT", "/root/one/foo/path/after/the/asterisk", ""}, + {"PUT", "/root/one/two/path/after/the/asterisk", ""}, + {"PUT", "/root/one/bar/two/after/the/asterisk", ""}, + {"PUT", "/root/one/qux/two/three/four", ""}, + {"PUT", "/root/one/qux/two/three/four/after/the/asterisk", ""}, + + {"DELETE", "/root/one/foo/after/the/asterisk", ""}, + {"DELETE", "/root/one/foo/path/after/the/asterisk", ""}, + {"DELETE", "/root/one/two/path/after/the/asterisk", ""}, + {"DELETE", "/root/one/bar/two/after/the/asterisk", ""}, + {"DELETE", "/root/one/qux/two/three/four", ""}, + {"DELETE", "/root/one/qux/two/three/four/after/the/asterisk", ""}, + } + + missesAPI = []*Route{ + {"GET", "/missOne", ""}, + {"GET", "/miss/two", ""}, + {"GET", "/miss/three/levels", ""}, + {"GET", "/miss/four/levels/nooo", ""}, + + {"POST", "/missOne", ""}, + {"POST", "/miss/two", ""}, + {"POST", "/miss/three/levels", ""}, + {"POST", "/miss/four/levels/nooo", ""}, + + {"PUT", "/missOne", ""}, + {"PUT", "/miss/two", ""}, + {"PUT", "/miss/three/levels", ""}, + {"PUT", "/miss/four/levels/nooo", ""}, + + {"DELETE", "/missOne", ""}, + {"DELETE", "/miss/two", ""}, + {"DELETE", "/miss/three/levels", ""}, + {"DELETE", "/miss/four/levels/nooo", ""}, + } + // handlerHelper created a function that will set a context key for assertion handlerHelper = func(key string, value int) func(c Context) error { return func(c Context) error { @@ -1298,7 +1430,7 @@ func TestRouterParam1466(t *testing.T) { assert.Equal(t, 0, c.response.Status) } -func benchmarkRouterRoutes(b *testing.B, routes []*Route) { +func benchmarkRouterRoutes(b *testing.B, routes []*Route, routesToFind []*Route) { e := New() r := e.router b.ReportAllocs() @@ -1310,9 +1442,12 @@ func benchmarkRouterRoutes(b *testing.B, routes []*Route) { }) } + // Routes adding are performed just once, so it doesn't make sense to see that in the benchmark + b.ResetTimer() + // Find routes for i := 0; i < b.N; i++ { - for _, route := range gitHubAPI { + for _, route := range routesToFind { c := e.pool.Get().(*context) r.Find(route.Method, route.Path, c) e.pool.Put(c) @@ -1321,28 +1456,56 @@ func benchmarkRouterRoutes(b *testing.B, routes []*Route) { } func BenchmarkRouterStaticRoutes(b *testing.B) { - benchmarkRouterRoutes(b, staticRoutes) + benchmarkRouterRoutes(b, staticRoutes, staticRoutes) +} + +func BenchmarkRouterStaticRoutesMisses(b *testing.B) { + benchmarkRouterRoutes(b, staticRoutes, missesAPI) } func BenchmarkRouterGitHubAPI(b *testing.B) { - benchmarkRouterRoutes(b, gitHubAPI) + benchmarkRouterRoutes(b, gitHubAPI, gitHubAPI) +} + +func BenchmarkRouterGitHubAPIMisses(b *testing.B) { + benchmarkRouterRoutes(b, gitHubAPI, missesAPI) } func BenchmarkRouterParseAPI(b *testing.B) { - benchmarkRouterRoutes(b, parseAPI) + benchmarkRouterRoutes(b, parseAPI, parseAPI) +} + +func BenchmarkRouterParseAPIMisses(b *testing.B) { + benchmarkRouterRoutes(b, parseAPI, missesAPI) } func BenchmarkRouterGooglePlusAPI(b *testing.B) { - benchmarkRouterRoutes(b, googlePlusAPI) + benchmarkRouterRoutes(b, googlePlusAPI, googlePlusAPI) +} + +func BenchmarkRouterGooglePlusAPIMisses(b *testing.B) { + benchmarkRouterRoutes(b, googlePlusAPI, missesAPI) +} + +func BenchmarkRouterParamsAndAnyAPI(b *testing.B) { + benchmarkRouterRoutes(b, paramAndAnyAPI, paramAndAnyAPIToFind) } func (n *node) printTree(pfx string, tail bool) { p := prefix(tail, pfx, "└── ", "├── ") fmt.Printf("%s%s, %p: type=%d, parent=%p, handler=%v, pnames=%v\n", p, n.prefix, n, n.kind, n.parent, n.methodHandler, n.pnames) - children := n.children - l := len(children) p = prefix(tail, pfx, " ", "│ ") + + children := n.staticChildrens + l := len(children) + + if n.paramChildren != nil { + n.paramChildren.printTree(p, n.anyChildren == nil && l == 0) + } + if n.anyChildren != nil { + n.anyChildren.printTree(p, l == 0) + } for i := 0; i < l-1; i++ { children[i].printTree(p, false) } From f1a4cb42e486dc3d91b8c4ec3cce52e979c28a29 Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Mon, 23 Nov 2020 00:31:00 -0300 Subject: [PATCH 067/446] Update README.me --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 06f8d9d17..deba54f40 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,8 @@ Lower is better! +The benchmarks above were run on an Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz + ## [Guide](https://echo.labstack.com/guide) ### Installation From 5b9bbbd356e44b2f2d3608ae0669491c38f31897 Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Mon, 23 Nov 2020 05:11:17 +0000 Subject: [PATCH 068/446] Adding GitHub action to compare benchmarks The GitHub action runs all the benchmarks for the target branch, and the compares those values with the benchmarks results for the PR new code. --- .github/workflows/echo.yml | 52 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 1a0f549cf..839530e28 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -62,3 +62,55 @@ jobs: with: token: fail_ci_if_error: false + benchmark: + needs: test + strategy: + matrix: + os: [ubuntu-latest] + go: [1.15] + name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }} + runs-on: ${{ matrix.os }} + steps: + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v1 + with: + go-version: ${{ matrix.go }} + + - name: Set GOPATH and PATH + run: | + echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV + echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH + shell: bash + + - name: Set build variables + run: | + echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV + echo "GO111MODULE=on" >> $GITHUB_ENV + + - name: Checkout Code (Previous) + uses: actions/checkout@v2 + with: + ref: ${{ github.base_ref }} + path: previous + + - name: Checkout Code (New) + uses: actions/checkout@v2 + with: + path: new + + - name: Install Dependencies + run: go get -v golang.org/x/perf/cmd/benchstat + + - name: Run Benchmark (Previous) + run: | + cd previous + go test -run="-" -bench=".*" -count=5 ./... > benchmark.txt + + - name: Run Benchmark (New) + run: | + cd new + go test -run="-" -bench=".*" -count=5 ./... > benchmark.txt + + - name: Run Benchstat + run: | + benchstat previous/benchmark.txt new/benchmark.txt From 5f1aa1bc0730107300de6f6eee0bb71177baa1f3 Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Tue, 24 Nov 2020 03:24:27 +0000 Subject: [PATCH 069/446] Fixing Echo#Reverse for Any type routes Fixes #1690 --- echo.go | 2 +- echo_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/echo.go b/echo.go index 29b88b706..c43998016 100644 --- a/echo.go +++ b/echo.go @@ -572,7 +572,7 @@ func (e *Echo) Reverse(name string, params ...interface{}) string { for _, r := range e.router.routes { if r.Name == name { for i, l := 0, len(r.Path); i < l; i++ { - if r.Path[i] == ':' && n < ln { + if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln { for ; i < l && r.Path[i] != '/'; i++ { } uri.WriteString(fmt.Sprintf("%v", params[n])) diff --git a/echo_test.go b/echo_test.go index 0368dbd7a..71dc1ac07 100644 --- a/echo_test.go +++ b/echo_test.go @@ -277,10 +277,12 @@ func TestEchoURL(t *testing.T) { e := New() static := func(Context) error { return nil } getUser := func(Context) error { return nil } + getAny := func(Context) error { return nil } getFile := func(Context) error { return nil } e.GET("/static/file", static) e.GET("/users/:id", getUser) + e.GET("/documents/*", getAny) g := e.Group("/group") g.GET("/users/:uid/files/:fid", getFile) @@ -289,6 +291,9 @@ func TestEchoURL(t *testing.T) { assert.Equal("/static/file", e.URL(static)) assert.Equal("/users/:id", e.URL(getUser)) assert.Equal("/users/1", e.URL(getUser, "1")) + assert.Equal("/users/1", e.URL(getUser, "1")) + assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt")) + assert.Equal("/documents/*", e.URL(getAny)) assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1")) assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1")) } @@ -652,3 +657,28 @@ func TestEchoShutdown(t *testing.T) { err := <-errCh assert.Equal(t, err.Error(), "http: Server closed") } + +func TestEchoReverse(t *testing.T) { + assert := assert.New(t) + + e := New() + dummyHandler := func(Context) error { return nil } + + e.GET("/static", dummyHandler).Name = "/static" + e.GET("/static/*", dummyHandler).Name = "/static/*" + e.GET("/params/:foo", dummyHandler).Name = "/params/:foo" + e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux" + e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*" + + assert.Equal("/static", e.Reverse("/static")) + assert.Equal("/static", e.Reverse("/static", "missing param")) + assert.Equal("/static/*", e.Reverse("/static/*")) + assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt")) + + assert.Equal("/params/:foo", e.Reverse("/params/:foo")) + assert.Equal("/params/one", e.Reverse("/params/:foo", "one")) + assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux")) + assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one")) + assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two")) + assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) +} From 26ab188922a69f822833a460f364a06f56404513 Mon Sep 17 00:00:00 2001 From: Pierre Rousset Date: Fri, 9 Oct 2020 18:07:29 +0900 Subject: [PATCH 070/446] CORS: add an optional custom function to validate the origin --- middleware/cors.go | 71 +++++++++++++++++++++++++---------------- middleware/cors_test.go | 47 +++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 27 deletions(-) diff --git a/middleware/cors.go b/middleware/cors.go index 07df0e57e..c1e22e4e6 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -19,6 +19,13 @@ type ( // Optional. Default value []string{"*"}. AllowOrigins []string `yaml:"allow_origins"` + // AllowOriginFunc is a custom function to validate the origin. It takes the + // origin as an argument and returns true if allowed or false otherwise. If + // an error is returned, it is returned by the handler. If this option is + // set, AllowOrigins is ignored. + // Optional. + AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` + // AllowMethods defines a list methods allowed when accessing the resource. // This is used in response to a preflight request. // Optional. Default value DefaultCORSConfig.AllowMethods. @@ -113,40 +120,50 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { return c.NoContent(http.StatusNoContent) } - // Check allowed origins - for _, o := range config.AllowOrigins { - if o == "*" && config.AllowCredentials { - allowOrigin = origin - break - } - if o == "*" || o == origin { - allowOrigin = o - break - } - if matchSubdomain(origin, o) { - allowOrigin = origin - break - } - } - - // Check allowed origin patterns - for _, re := range allowOriginPatterns { - if allowOrigin == "" { - didx := strings.Index(origin, "://") - if didx == -1 { - continue + if config.AllowOriginFunc == nil { + // Check allowed origins + for _, o := range config.AllowOrigins { + if o == "*" && config.AllowCredentials { + allowOrigin = origin + break } - domAuth := origin[didx+3:] - // to avoid regex cost by invalid long domain - if len(domAuth) > 253 { + if o == "*" || o == origin { + allowOrigin = o break } - - if match, _ := regexp.MatchString(re, origin); match { + if matchSubdomain(origin, o) { allowOrigin = origin break } } + + // Check allowed origin patterns + for _, re := range allowOriginPatterns { + if allowOrigin == "" { + didx := strings.Index(origin, "://") + if didx == -1 { + continue + } + domAuth := origin[didx+3:] + // to avoid regex cost by invalid long domain + if len(domAuth) > 253 { + break + } + + if match, _ := regexp.MatchString(re, origin); match { + allowOrigin = origin + break + } + } + } + } else { + allowed, err := config.AllowOriginFunc(origin) + if err != nil { + return err + } + if allowed { + allowOrigin = origin + } } // Origin not allowed diff --git a/middleware/cors_test.go b/middleware/cors_test.go index fc34694db..717abe498 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "net/http" "net/http/httptest" "testing" @@ -360,3 +361,49 @@ func TestCorsHeaders(t *testing.T) { } } } + +func Test_allowOriginFunc(t *testing.T) { + returnTrue := func(origin string) (bool, error) { + return true, nil + } + returnFalse := func(origin string) (bool, error) { + return false, nil + } + returnError := func(origin string) (bool, error) { + return true, errors.New("this is a test error") + } + + allowOriginFuncs := []func(origin string) (bool, error){ + returnTrue, + returnFalse, + returnError, + } + + const origin = "http://example.com" + + e := echo.New() + for _, allowOriginFunc := range allowOriginFuncs { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + req.Header.Set(echo.HeaderOrigin, origin) + cors := CORSWithConfig(CORSConfig{ + AllowOriginFunc: allowOriginFunc, + }) + h := cors(echo.NotFoundHandler) + err := h(c) + + expected, expectedErr := allowOriginFunc(origin) + if expectedErr != nil { + assert.Equal(t, expectedErr, err) + assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + continue + } + + if expected { + assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } else { + assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } + } +} From e6f24aa8b1cb1263d462d4d4a1126827c3d8e7f7 Mon Sep 17 00:00:00 2001 From: Pierre Rousset Date: Mon, 16 Nov 2020 12:53:49 +0900 Subject: [PATCH 071/446] Addressed PR feedback --- middleware/cors.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/middleware/cors.go b/middleware/cors.go index c1e22e4e6..d6ef89644 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -120,7 +120,15 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { return c.NoContent(http.StatusNoContent) } - if config.AllowOriginFunc == nil { + if config.AllowOriginFunc != nil { + allowed, err := config.AllowOriginFunc(origin) + if err != nil { + return err + } + if allowed { + allowOrigin = origin + } + } else { // Check allowed origins for _, o := range config.AllowOrigins { if o == "*" && config.AllowCredentials { @@ -156,14 +164,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } } } - } else { - allowed, err := config.AllowOriginFunc(origin) - if err != nil { - return err - } - if allowed { - allowOrigin = origin - } } // Origin not allowed From 14e020bc07c84ed217fe18a07de5821b0e2152af Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Fri, 27 Nov 2020 03:01:04 +0000 Subject: [PATCH 072/446] Adding sync.Pool to Decompress middleware Fixing a http.Request.Body leak on the decompress middleware that were not properly Close Removing the defer on the call to gzip.Reader, because that reader is already exausted after the call to io.Copy --- middleware/compress.go | 4 ++-- middleware/decompress.go | 49 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/middleware/compress.go b/middleware/compress.go index e4f9fc514..6ae197453 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -59,7 +59,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { config.Level = DefaultGzipConfig.Level } - pool := gzipPool(config) + pool := gzipCompressPool(config) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -133,7 +133,7 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { return http.ErrNotSupported } -func gzipPool(config GzipConfig) sync.Pool { +func gzipCompressPool(config GzipConfig) sync.Pool { return sync.Pool{ New: func() interface{} { w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level) diff --git a/middleware/decompress.go b/middleware/decompress.go index 99eaf066d..3785ab0f2 100644 --- a/middleware/decompress.go +++ b/middleware/decompress.go @@ -3,9 +3,12 @@ package middleware import ( "bytes" "compress/gzip" - "github.com/labstack/echo/v4" "io" "io/ioutil" + "net/http" + "sync" + + "github.com/labstack/echo/v4" ) type ( @@ -32,27 +35,63 @@ func Decompress() echo.MiddlewareFunc { //DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { + pool := gzipDecompressPool() return func(c echo.Context) error { if config.Skipper(c) { return next(c) } switch c.Request().Header.Get(echo.HeaderContentEncoding) { case GZIPEncoding: - gr, err := gzip.NewReader(c.Request().Body) - if err != nil { + b := c.Request().Body + + i := pool.Get() + gr, ok := i.(*gzip.Reader) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) + } + + if err := gr.Reset(b); err != nil { + pool.Put(gr) if err == io.EOF { //ignore if body is empty return next(c) } return err } - defer gr.Close() var buf bytes.Buffer io.Copy(&buf, gr) + + gr.Close() + pool.Put(gr) + + b.Close() // http.Request.Body is closed by the Server, but because we are replacing it, it must be closed here + r := ioutil.NopCloser(&buf) - defer r.Close() c.Request().Body = r } return next(c) } } } + +func gzipDecompressPool() sync.Pool { + return sync.Pool{ + New: func() interface{} { + // create with an empty reader (but with GZIP header) + w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed) + if err != nil { + return err + } + + b := new(bytes.Buffer) + w.Reset(b) + w.Flush() + w.Close() + + r, err := gzip.NewReader(bytes.NewReader(b.Bytes())) + if err != nil { + return err + } + return r + }, + } +} From 2386e17b21ae319c0f8044326b2c65c15f0cccd2 Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Sat, 28 Nov 2020 02:03:54 +0000 Subject: [PATCH 073/446] Increasing Decompress Middleware coverage --- middleware/decompress.go | 73 +++++++++++++++++++++++------------ middleware/decompress_test.go | 61 +++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 25 deletions(-) diff --git a/middleware/decompress.go b/middleware/decompress.go index 3785ab0f2..c046359a2 100644 --- a/middleware/decompress.go +++ b/middleware/decompress.go @@ -16,17 +16,55 @@ type ( DecompressConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper + + // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers + GzipDecompressPool Decompressor } ) //GZIPEncoding content-encoding header if set to "gzip", decompress body contents. const GZIPEncoding string = "gzip" +// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers +type Decompressor interface { + gzipDecompressPool() sync.Pool +} + var ( //DefaultDecompressConfig defines the config for decompress middleware - DefaultDecompressConfig = DecompressConfig{Skipper: DefaultSkipper} + DefaultDecompressConfig = DecompressConfig{ + Skipper: DefaultSkipper, + GzipDecompressPool: &DefaultGzipDecompressPool{}, + } ) +// DefaultGzipDecompressPool is the default implementation of Decompressor interface +type DefaultGzipDecompressPool struct { +} + +func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { + return sync.Pool{ + New: func() interface{} { + // create with an empty reader (but with GZIP header) + w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed) + if err != nil { + return err + } + + b := new(bytes.Buffer) + w.Reset(b) + w.Flush() + w.Close() + + r, err := gzip.NewReader(bytes.NewReader(b.Bytes())) + if err != nil { + return err + } + return r + }, + } +} + //Decompress decompresses request body based if content encoding type is set to "gzip" with default config func Decompress() echo.MiddlewareFunc { return DecompressWithConfig(DefaultDecompressConfig) @@ -34,8 +72,16 @@ func Decompress() echo.MiddlewareFunc { //DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultGzipConfig.Skipper + } + if config.GzipDecompressPool == nil { + config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool + } + return func(next echo.HandlerFunc) echo.HandlerFunc { - pool := gzipDecompressPool() + pool := config.GzipDecompressPool.gzipDecompressPool() return func(c echo.Context) error { if config.Skipper(c) { return next(c) @@ -72,26 +118,3 @@ func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { } } } - -func gzipDecompressPool() sync.Pool { - return sync.Pool{ - New: func() interface{} { - // create with an empty reader (but with GZIP header) - w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed) - if err != nil { - return err - } - - b := new(bytes.Buffer) - w.Reset(b) - w.Flush() - w.Close() - - r, err := gzip.NewReader(bytes.NewReader(b.Bytes())) - if err != nil { - return err - } - return r - }, - } -} diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 772c14f6d..51fa6b0f1 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -3,10 +3,12 @@ package middleware import ( "bytes" "compress/gzip" + "errors" "io/ioutil" "net/http" "net/http/httptest" "strings" + "sync" "testing" "github.com/labstack/echo/v4" @@ -43,6 +45,35 @@ func TestDecompress(t *testing.T) { assert.Equal(body, string(b)) } +func TestDecompressDefaultConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + h(c) + + assert := assert.New(t) + assert.Equal("test", rec.Body.String()) + + // Decompress + body := `{"name": "echo"}` + gz, _ := gzipString(body) + req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + h(c) + assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := ioutil.ReadAll(req.Body) + assert.NoError(err) + assert.Equal(body, string(b)) +} + func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { e := echo.New() body := `{"name":"echo"}` @@ -108,6 +139,36 @@ func TestDecompressSkipper(t *testing.T) { assert.Equal(t, body, string(reqBody)) } +type TestDecompressPoolWithError struct { +} + +func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool { + return sync.Pool{ + New: func() interface{} { + return errors.New("pool error") + }, + } +} + +func TestDecompressPoolError(t *testing.T) { + e := echo.New() + e.Use(DecompressWithConfig(DecompressConfig{ + Skipper: DefaultSkipper, + GzipDecompressPool: &TestDecompressPoolWithError{}, + })) + body := `{"name": "echo"}` + req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + reqBody, err := ioutil.ReadAll(c.Request().Body) + assert.NoError(t, err) + assert.Equal(t, body, string(reqBody)) + assert.Equal(t, rec.Code, http.StatusInternalServerError) +} + func BenchmarkDecompress(b *testing.B) { e := echo.New() body := `{"name": "echo"}` From 3206527cfecb5de721f7ab18146052e4d841c736 Mon Sep 17 00:00:00 2001 From: Nenad Lukic Date: Mon, 30 Nov 2020 19:06:00 +0100 Subject: [PATCH 074/446] Adds IgnoreBase parameter to static middleware Adds IgnoreBase parameter to static middleware to support the use case of nested route groups --- _fixture/_fixture/README.md | 1 + middleware/static.go | 15 +++++++++++++++ middleware/static_test.go | 23 +++++++++++++++++++++++ 3 files changed, 39 insertions(+) create mode 100644 _fixture/_fixture/README.md diff --git a/_fixture/_fixture/README.md b/_fixture/_fixture/README.md new file mode 100644 index 000000000..21a785851 --- /dev/null +++ b/_fixture/_fixture/README.md @@ -0,0 +1 @@ +This directory is used for the static middleware test \ No newline at end of file diff --git a/middleware/static.go b/middleware/static.go index bc2087a77..58b7890a4 100644 --- a/middleware/static.go +++ b/middleware/static.go @@ -36,6 +36,12 @@ type ( // Enable directory browsing. // Optional. Default value false. Browse bool `yaml:"browse"` + + // Enable ignoring of the base of the URL path. + // Example: when assigning a static middleware to a non root path group, + // the filesystem path is not doubled + // Optional. Default value false. + IgnoreBase bool `yaml:"ignoreBase"` } ) @@ -163,6 +169,15 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { } name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security + if config.IgnoreBase { + routePath := path.Base(strings.TrimRight(c.Path(), "/*")) + baseURLPath := path.Base(p) + if baseURLPath == routePath { + i := strings.LastIndex(name, routePath) + name = name[:i] + strings.Replace(name[i:], routePath, "", 1) + } + } + fi, err := os.Stat(name) if err != nil { if os.IsNotExist(err) { diff --git a/middleware/static_test.go b/middleware/static_test.go index 0d695d3db..56e93958e 100644 --- a/middleware/static_test.go +++ b/middleware/static_test.go @@ -67,4 +67,27 @@ func TestStatic(t *testing.T) { assert.Equal(http.StatusOK, rec.Code) assert.Contains(rec.Body.String(), "cert.pem") } + + // IgnoreBase + req = httptest.NewRequest(http.MethodGet, "/_fixture", nil) + rec = httptest.NewRecorder() + config.Root = "../_fixture" + config.IgnoreBase = true + static = StaticWithConfig(config) + c.Echo().Group("_fixture", static) + e.ServeHTTP(rec, req) + + assert.Equal(http.StatusOK, rec.Code) + assert.Equal(rec.Header().Get(echo.HeaderContentLength), "122") + + req = httptest.NewRequest(http.MethodGet, "/_fixture", nil) + rec = httptest.NewRecorder() + config.Root = "../_fixture" + config.IgnoreBase = false + static = StaticWithConfig(config) + c.Echo().Group("_fixture", static) + e.ServeHTTP(rec, req) + + assert.Equal(http.StatusOK, rec.Code) + assert.Contains(rec.Body.String(), "..\\_fixture\\_fixture") } From 364b7e6eca5483f127440a215c42936086a12bd3 Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Tue, 1 Dec 2020 04:39:06 +0000 Subject: [PATCH 075/446] Increasing number of benchmarks on GitHub action Now the number of times that the benchmarks are run before being compared is 8 on the GitHub action. --- .github/workflows/echo.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 839530e28..df3a1d70c 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -104,12 +104,12 @@ jobs: - name: Run Benchmark (Previous) run: | cd previous - go test -run="-" -bench=".*" -count=5 ./... > benchmark.txt + go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt - name: Run Benchmark (New) run: | cd new - go test -run="-" -bench=".*" -count=5 ./... > benchmark.txt + go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt - name: Run Benchstat run: | From 99d5a070979d1902c9ccd62f075ed5e04ce225ef Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Tue, 1 Dec 2020 05:00:19 +0000 Subject: [PATCH 076/446] Adding Codecov configuration Adding a 1% of threshold for coverage diffs --- codecov.yml | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 codecov.yml diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 000000000..0fa3a3f18 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,11 @@ +coverage: + status: + project: + default: + threshold: 1% + patch: + default: + threshold: 1% + +comment: + require_changes: true \ No newline at end of file From 61514f1c847d8fe466e2bf7edf4db413e1c884e3 Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Tue, 1 Dec 2020 05:06:31 +0000 Subject: [PATCH 077/446] Changes on codecov.yml will trigger GitHub Actions --- .github/workflows/echo.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 839530e28..b852fae4f 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -9,6 +9,7 @@ on: - 'go.*' - '_fixture/**' - '.github/**' + - 'codecov.yml' pull_request: branches: - master @@ -17,6 +18,7 @@ on: - 'go.*' - '_fixture/**' - '.github/**' + - 'codecov.yml' jobs: test: From 2152e4e87205ec2d403379758525c64692993da4 Mon Sep 17 00:00:00 2001 From: rkfg Date: Tue, 1 Dec 2020 09:51:20 +0300 Subject: [PATCH 078/446] Support form fields in jwt middleware --- middleware/jwt.go | 14 ++++++++++++++ middleware/jwt_test.go | 41 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/middleware/jwt.go b/middleware/jwt.go index bab00c9f8..da00ea56b 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -57,6 +57,7 @@ type ( // - "query:" // - "param:" // - "cookie:" + // - "form:" TokenLookup string // AuthScheme to be used in the Authorization header. @@ -167,6 +168,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { extractor = jwtFromParam(parts[1]) case "cookie": extractor = jwtFromCookie(parts[1]) + case "form": + extractor = jwtFromForm(parts[1]) } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -266,3 +269,14 @@ func jwtFromCookie(name string) jwtExtractor { return cookie.Value, nil } } + +// jwtFromForm returns a `jwtExtractor` that extracts token from the form field. +func jwtFromForm(name string) jwtExtractor { + return func(c echo.Context) (string, error) { + field := c.FormValue(name) + if field == "" { + return "", ErrJWTMissing + } + return field, nil + } +} diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index ce44f9c9c..205721aec 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -3,6 +3,8 @@ package middleware import ( "net/http" "net/http/httptest" + "net/url" + "strings" "testing" "github.com/dgrijalva/jwt-go" @@ -75,6 +77,7 @@ func TestJWT(t *testing.T) { reqURL string // "/" if empty hdrAuth string hdrCookie string // test.Request doesn't provide SetCookie(); use name=val + formValues map[string]string info string }{ { @@ -192,12 +195,48 @@ func TestJWT(t *testing.T) { expErrCode: http.StatusBadRequest, info: "Empty cookie", }, + { + config: JWTConfig{ + SigningKey: validKey, + TokenLookup: "form:jwt", + }, + formValues: map[string]string{"jwt": token}, + info: "Valid form method", + }, + { + config: JWTConfig{ + SigningKey: validKey, + TokenLookup: "form:jwt", + }, + expErrCode: http.StatusUnauthorized, + formValues: map[string]string{"jwt": "invalid"}, + info: "Invalid token with form method", + }, + { + config: JWTConfig{ + SigningKey: validKey, + TokenLookup: "form:jwt", + }, + expErrCode: http.StatusBadRequest, + info: "Empty form field", + }, } { if tc.reqURL == "" { tc.reqURL = "/" } - req := httptest.NewRequest(http.MethodGet, tc.reqURL, nil) + var req *http.Request + if len(tc.formValues) > 0 { + form := url.Values{} + for k, v := range tc.formValues { + form.Set(k, v) + } + req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode())) + req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded") + req.ParseForm() + } else { + req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil) + } res := httptest.NewRecorder() req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) req.Header.Set(echo.HeaderCookie, tc.hdrCookie) From 571661692f99e4e77daeb85d4166b481e3ddf8cb Mon Sep 17 00:00:00 2001 From: Nenad Lukic Date: Tue, 1 Dec 2020 09:03:00 +0100 Subject: [PATCH 079/446] Uses filepath.Join instead of hardcoded separator for static middleware test --- middleware/static_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/middleware/static_test.go b/middleware/static_test.go index 56e93958e..407dd15ce 100644 --- a/middleware/static_test.go +++ b/middleware/static_test.go @@ -3,6 +3,7 @@ package middleware import ( "net/http" "net/http/httptest" + "path/filepath" "testing" "github.com/labstack/echo/v4" @@ -89,5 +90,5 @@ func TestStatic(t *testing.T) { e.ServeHTTP(rec, req) assert.Equal(http.StatusOK, rec.Code) - assert.Contains(rec.Body.String(), "..\\_fixture\\_fixture") + assert.Contains(rec.Body.String(), filepath.Join("..", "_fixture", "_fixture")) } From dc147d9b974021cd79f98caf2f1159a66819533a Mon Sep 17 00:00:00 2001 From: Vadim Sabirov Date: Thu, 3 Dec 2020 10:21:31 +0300 Subject: [PATCH 080/446] Fix #1523 by adding secure cookie if SameSite mode is None --- middleware/csrf.go | 4 ++-- middleware/csrf_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/middleware/csrf.go b/middleware/csrf.go index ec348ce1b..7804997d4 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -110,8 +110,8 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.CookieMaxAge == 0 { config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge } - if config.CookieSameSite == 0 { - config.CookieSameSite = http.SameSiteDefaultMode + if config.CookieSameSite == http.SameSiteNoneMode { + config.CookieSecure = true } // Initialize diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 5a3b49b7e..af1d26394 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -1,6 +1,7 @@ package middleware import ( + "fmt" "net/http" "net/http/httptest" "net/url" @@ -117,3 +118,43 @@ func TestCSRFWithoutSameSiteMode(t *testing.T) { assert.NoError(t, r) assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"]) } + +func TestCSRFWithSameSiteDefaultMode(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + csrf := CSRFWithConfig(CSRFConfig{ + CookieSameSite: http.SameSiteDefaultMode, + }) + + h := csrf(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + r := h(c) + assert.NoError(t, r) + fmt.Println(rec.Header()["Set-Cookie"]) + assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"]) +} + +func TestCSRFWithSameSiteModeNone(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + csrf := CSRFWithConfig(CSRFConfig{ + CookieSameSite: http.SameSiteNoneMode, + }) + + h := csrf(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + r := h(c) + assert.NoError(t, r) + assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"]) + assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"]) +} From 0406abe066d756b85766e3d7eb110a2cbe157c67 Mon Sep 17 00:00:00 2001 From: Rashad Ansari Date: Wed, 2 Dec 2020 15:43:42 +0330 Subject: [PATCH 081/446] Add the ability to change the status code using Response beforeFuncs --- response.go | 4 ++-- response_test.go | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/response.go b/response.go index ca7405c5d..84f7c9e7e 100644 --- a/response.go +++ b/response.go @@ -56,11 +56,11 @@ func (r *Response) WriteHeader(code int) { r.echo.Logger.Warn("response already committed") return } + r.Status = code for _, fn := range r.beforeFuncs { fn() } - r.Status = code - r.Writer.WriteHeader(code) + r.Writer.WriteHeader(r.Status) r.Committed = true } diff --git a/response_test.go b/response_test.go index 7a9c51c66..d95e079f9 100644 --- a/response_test.go +++ b/response_test.go @@ -56,3 +56,19 @@ func TestResponse_Flush(t *testing.T) { res.Flush() assert.True(t, rec.Flushed) } + +func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + res := &Response{echo: e, Writer: rec} + + res.Before(func() { + if 200 < res.Status && res.Status < 300 { + res.Status = 200 + } + }) + + res.WriteHeader(209) + + assert.Equal(t, http.StatusOK, rec.Code) +} From 6caec3032be056516e30440d1f5bdb5514ad8ad5 Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Mon, 7 Dec 2020 11:23:29 +0100 Subject: [PATCH 082/446] Make our stalebot more relaxed * Use newly added `stale` label for marking for auto-closing * Ignore issues marked as bug or enhancement for stale marking * Give more time for auto-closing of stale issue (30d instead of 7d) --- .github/stale.yml | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/.github/stale.yml b/.github/stale.yml index d9f656321..04dd169cd 100644 --- a/.github/stale.yml +++ b/.github/stale.yml @@ -1,17 +1,19 @@ # Number of days of inactivity before an issue becomes stale daysUntilStale: 60 # Number of days of inactivity before a stale issue is closed -daysUntilClose: 7 +daysUntilClose: 30 # Issues with these labels will never be considered stale exemptLabels: - pinned - security + - bug + - enhancement # Label to use when marking an issue as stale -staleLabel: wontfix +staleLabel: stale # Comment to post when marking an issue as stale. Set to `false` to disable markComment: > This issue has been automatically marked as stale because it has not had - recent activity. It will be closed if no further activity occurs. Thank you - for your contributions. + recent activity. It will be closed within a month if no further activity occurs. + Thank you for your contributions. # Comment to post when closing a stale issue. Set to `false` to disable -closeComment: false \ No newline at end of file +closeComment: false From c171855555477efb011521a37594ef3d81721a21 Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Thu, 10 Dec 2020 03:41:25 +0000 Subject: [PATCH 083/446] Reverting changes on go.sum --- go.sum | 2 -- 1 file changed, 2 deletions(-) diff --git a/go.sum b/go.sum index 187f309f9..58c80c831 100644 --- a/go.sum +++ b/go.sum @@ -15,7 +15,6 @@ github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHX github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= @@ -47,7 +46,6 @@ golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e h1:FDhOuMEY4JVRztM/gsbk+IKUQ8kj74bxZrgw87eMMVc= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From bd5810f5b515f115f97413311d405fbe97752dfe Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Thu, 12 Nov 2020 12:28:45 +0200 Subject: [PATCH 084/446] separate methods to bind only query params, path params, request body --- bind.go | 37 ++++++- bind_test.go | 302 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 333 insertions(+), 6 deletions(-) diff --git a/bind.go b/bind.go index f89147435..c7be242b1 100644 --- a/bind.go +++ b/bind.go @@ -30,10 +30,8 @@ type ( } ) -// Bind implements the `Binder#Bind` function. -func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { - req := c.Request() - +// BindPathParams binds path params to bindable object +func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { names := c.ParamNames() values := c.ParamValues() params := map[string][]string{} @@ -43,12 +41,28 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { if err := b.bindData(i, params, "param"); err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } - if err = b.bindData(i, c.QueryParams(), "query"); err != nil { + return nil +} + +// BindQueryParams binds query params to bindable object +func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error { + if err := b.bindData(i, c.QueryParams(), "query"); err != nil { return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } + return nil +} + +// BindBody binds request body contents to bindable object +// NB: then binding forms take note that this implementation uses standard library form parsing +// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm +// See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm +// See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm +func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { + req := c.Request() if req.ContentLength == 0 { return } + ctype := req.Header.Get(HeaderContentType) switch { case strings.HasPrefix(ctype, MIMEApplicationJSON): @@ -80,7 +94,18 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { default: return ErrUnsupportedMediaType } - return + return nil +} + +// Bind implements the `Binder#Bind` function. +func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { + if err := b.BindPathParams(c, i); err != nil { + return err + } + if err = b.BindQueryParams(c, i); err != nil { + return err + } + return b.BindBody(c, i) } func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag string) error { diff --git a/bind_test.go b/bind_test.go index b9fb9de3c..60c2f9e0a 100644 --- a/bind_test.go +++ b/bind_test.go @@ -553,3 +553,305 @@ func testBindError(assert *assert.Assertions, r io.Reader, ctype string, expecte } } } + +func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { + // tests to check binding behaviour when multiple sources path params, query params and request body are in use + // binding is done in steps and one source could overwrite previous source binded data + // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed + + type Node struct { + ID int `json:"id"` + Node string `json:"node"` + } + + var testCases = []struct { + name string + givenURL string + givenContent io.Reader + givenMethod string + whenBindTarget interface{} + whenNoPathParams bool + expect interface{} + expectError string + }{ + { + name: "ok, POST bind to struct with: path param + query param + empty body", + givenMethod: http.MethodPost, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{"id": 1}`), + expect: &Node{ID: 1, Node: "xxx"}, // in current implementation query params has higher priority than path params + }, + { + name: "ok, POST bind to struct with: path param + empty body", + givenMethod: http.MethodPost, + givenURL: "/api/real_node/endpoint", + givenContent: strings.NewReader(`{"id": 1}`), + expect: &Node{ID: 1, Node: "real_node"}, + }, + { + name: "ok, POST bind to struct with path + query + body = body has priority", + givenMethod: http.MethodPost, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), + expect: &Node{ID: 1, Node: "zzz"}, // field value from content has higher priority + }, + { + name: "nok, POST body bind failure", + givenMethod: http.MethodPost, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{`), + expect: &Node{ID: 0, Node: "xxx"}, // query binding has already modified bind target + expectError: "code=400, message=unexpected EOF, internal=unexpected EOF", + }, + { + name: "nok, GET body bind failure - trying to bind json array to struct", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`[{"id": 1}]`), + expect: &Node{ID: 0, Node: "xxx"}, // query binding has already modified bind target + expectError: "code=400, message=Unmarshal type error: expected=echo.Node, got=array, field=, offset=1, internal=json: cannot unmarshal array into Go value of type echo.Node", + }, + { // binding query params interferes with body. b.BindBody() should be used to bind only body to slice + name: "nok, GET query params bind failure - trying to bind json array to slice", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`[{"id": 1}]`), + whenNoPathParams: true, + whenBindTarget: &[]Node{}, + expect: &[]Node{}, + expectError: "code=400, message=binding element must be a struct, internal=binding element must be a struct", + }, + { // binding path params interferes with body. b.BindBody() should be used to bind only body to slice + name: "nok, GET path params bind failure - trying to bind json array to slice", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`[{"id": 1}]`), + whenBindTarget: &[]Node{}, + expect: &[]Node{}, + expectError: "code=400, message=binding element must be a struct, internal=binding element must be a struct", + }, + { + name: "ok, GET body bind json array to slice", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint", + givenContent: strings.NewReader(`[{"id": 1}]`), + whenNoPathParams: true, + whenBindTarget: &[]Node{}, + expect: &[]Node{{ID: 1, Node: ""}}, + expectError: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + // assume route we are testing is "/api/:node/endpoint?some_query_params=here" + req := httptest.NewRequest(tc.givenMethod, tc.givenURL, tc.givenContent) + req.Header.Set(HeaderContentType, MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + if !tc.whenNoPathParams { + c.SetParamNames("node") + c.SetParamValues("real_node") + } + + var bindTarget interface{} + if tc.whenBindTarget != nil { + bindTarget = tc.whenBindTarget + } else { + bindTarget = &Node{} + } + b := new(DefaultBinder) + + err := b.Bind(bindTarget, c) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, bindTarget) + }) + } +} + +func TestDefaultBinder_BindBody(t *testing.T) { + // tests to check binding behaviour when multiple sources path params, query params and request body are in use + // generally when binding from request body - URL and path params are ignored - unless form is being binded. + // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed + + type Node struct { + ID int `json:"id" xml:"id"` + Node string `json:"node" xml:"node"` + } + type Nodes struct { + Nodes []Node `xml:"node" form:"node"` + } + + var testCases = []struct { + name string + givenURL string + givenContent io.Reader + givenMethod string + givenContentType string + whenNoPathParams bool + whenBindTarget interface{} + expect interface{} + expectError string + }{ + { + name: "ok, JSON POST bind to struct with: path + query + empty field in body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationJSON, + givenContent: strings.NewReader(`{"id": 1}`), + expect: &Node{ID: 1, Node: ""}, // path params or query params should not interfere with body + }, + { + name: "ok, JSON POST bind to struct with: path + query + body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationJSON, + givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), + expect: &Node{ID: 1, Node: "zzz"}, // field value from content has higher priority + }, + { + name: "ok, JSON POST body bind json array to slice (has matching path/query params)", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationJSON, + givenContent: strings.NewReader(`[{"id": 1}]`), + whenNoPathParams: true, + whenBindTarget: &[]Node{}, + expect: &[]Node{{ID: 1, Node: ""}}, + expectError: "", + }, + { // rare case as GET is not usually used to send request body + name: "ok, JSON GET bind to struct with: path + query + empty field in body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodGet, + givenContentType: MIMEApplicationJSON, + givenContent: strings.NewReader(`{"id": 1}`), + expect: &Node{ID: 1, Node: ""}, // path params or query params should not interfere with body + }, + { // rare case as GET is not usually used to send request body + name: "ok, JSON GET bind to struct with: path + query + body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodGet, + givenContentType: MIMEApplicationJSON, + givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), + expect: &Node{ID: 1, Node: "zzz"}, // field value from content has higher priority + }, + { + name: "nok, JSON POST body bind failure", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationJSON, + givenContent: strings.NewReader(`{`), + expect: &Node{ID: 0, Node: ""}, + expectError: "code=400, message=unexpected EOF, internal=unexpected EOF", + }, + { + name: "ok, XML POST bind to struct with: path + query + empty body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationXML, + givenContent: strings.NewReader(`1yyy`), + expect: &Node{ID: 1, Node: "yyy"}, + }, + { + name: "ok, XML POST bind array to slice with: path + query + body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationXML, + givenContent: strings.NewReader(`1yyy`), + whenBindTarget: &Nodes{}, + expect: &Nodes{Nodes: []Node{{ID: 1, Node: "yyy"}}}, + }, + { + name: "nok, XML POST bind failure", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationXML, + givenContent: strings.NewReader(`<`), + expect: &Node{ID: 0, Node: ""}, + expectError: "code=400, message=Syntax error: line=1, error=XML syntax error on line 1: unexpected EOF, internal=XML syntax error on line 1: unexpected EOF", + }, + { + name: "ok, FORM POST bind to struct with: path + query + empty body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationForm, + givenContent: strings.NewReader(`id=1&node=yyy`), + expect: &Node{ID: 1, Node: "yyy"}, + }, + { + // NB: form values are taken from BOTH body and query for POST/PUT/PATCH by standard library implementation + // See: https://golang.org/pkg/net/http/#Request.ParseForm + name: "ok, FORM POST bind to struct with: path + query + empty field in body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationForm, + givenContent: strings.NewReader(`id=1`), + expect: &Node{ID: 1, Node: "xxx"}, + }, + { + // NB: form values are taken from query by standard library implementation + // See: https://golang.org/pkg/net/http/#Request.ParseForm + name: "ok, FORM GET bind to struct with: path + query + empty field in body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodGet, + givenContentType: MIMEApplicationForm, + givenContent: strings.NewReader(`id=1`), + expect: &Node{ID: 0, Node: "xxx"}, // 'xxx' is taken from URL and body is not used with GET by implementation + }, + { + name: "nok, unsupported content type", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMETextPlain, + givenContent: strings.NewReader(``), + expect: &Node{ID: 0, Node: ""}, + expectError: "code=415, message=Unsupported Media Type", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + // assume route we are testing is "/api/:node/endpoint?some_query_params=here" + req := httptest.NewRequest(tc.givenMethod, tc.givenURL, tc.givenContent) + switch tc.givenContentType { + case MIMEApplicationXML: + req.Header.Set(HeaderContentType, MIMEApplicationXML) + case MIMEApplicationForm: + req.Header.Set(HeaderContentType, MIMEApplicationForm) + case MIMEApplicationJSON: + req.Header.Set(HeaderContentType, MIMEApplicationJSON) + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + if !tc.whenNoPathParams { + c.SetParamNames("node") + c.SetParamValues("real_node") + } + + var bindTarget interface{} + if tc.whenBindTarget != nil { + bindTarget = tc.whenBindTarget + } else { + bindTarget = &Node{} + } + b := new(DefaultBinder) + + err := b.BindBody(c, bindTarget) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, bindTarget) + }) + } +} From 1c720597bbad6cab3b64c4a6915f0857258ef67e Mon Sep 17 00:00:00 2001 From: iambenkay Date: Mon, 14 Dec 2020 19:06:25 +0100 Subject: [PATCH 085/446] adds test for request id - remain unchanged if provided --- middleware/request_id_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go index 30eecdef9..86eec8c3b 100644 --- a/middleware/request_id_test.go +++ b/middleware/request_id_test.go @@ -31,3 +31,20 @@ func TestRequestID(t *testing.T) { h(c) assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") } + +func TestRequestID_IDNotAltered(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRequestID, "") + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid := RequestIDWithConfig(RequestIDConfig{}) + h := rid(handler) + _ = h(c) + assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "") +} From 1beaf09740e7e3ed4f6bfdd753121c03946a1142 Mon Sep 17 00:00:00 2001 From: little-cui Date: Sun, 13 Dec 2020 21:49:11 +0800 Subject: [PATCH 086/446] Bug Fix: Directory Traversal --- echo.go | 3 +- echo_test.go | 134 +++++++++++++++++++++++++++++++------------ middleware/static.go | 2 +- 3 files changed, 99 insertions(+), 40 deletions(-) diff --git a/echo.go b/echo.go index 381604180..d284ff396 100644 --- a/echo.go +++ b/echo.go @@ -49,7 +49,6 @@ import ( "net/http" "net/url" "os" - "path" "path/filepath" "reflect" "runtime" @@ -486,7 +485,7 @@ func (common) static(prefix, root string, get func(string, HandlerFunc, ...Middl return err } - name := filepath.Join(root, path.Clean("/"+p)) // "/"+ for security + name := filepath.Join(root, filepath.Clean("/"+p)) // "/"+ for security fi, err := os.Stat(name) if err != nil { // The access path does not exist diff --git a/echo_test.go b/echo_test.go index bf3ecfe91..a6071e12a 100644 --- a/echo_test.go +++ b/echo_test.go @@ -60,45 +60,105 @@ func TestEcho(t *testing.T) { } func TestEchoStatic(t *testing.T) { - e := New() - - assert := assert.New(t) - - // OK - e.Static("/images", "_fixture/images") - c, b := request(http.MethodGet, "/images/walle.png", e) - assert.Equal(http.StatusOK, c) - assert.NotEmpty(b) - - // No file - e.Static("/images", "_fixture/scripts") - c, _ = request(http.MethodGet, "/images/bolt.png", e) - assert.Equal(http.StatusNotFound, c) - - // Directory - e.Static("/images", "_fixture/images") - c, _ = request(http.MethodGet, "/images/", e) - assert.Equal(http.StatusNotFound, c) - - // Directory Redirect - e.Static("/", "_fixture") - req := httptest.NewRequest(http.MethodGet, "/folder", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(http.StatusMovedPermanently, rec.Code) - assert.Equal("/folder/", rec.HeaderMap["Location"][0]) - - // Directory with index.html - e.Static("/", "_fixture") - c, r := request(http.MethodGet, "/", e) - assert.Equal(http.StatusOK, c) - assert.Equal(true, strings.HasPrefix(r, "")) + var testCases = []struct { + name string + givenPrefix string + givenRoot string + whenURL string + expectStatus int + expectHeaderLocation string + expectBodyStartsWith string + }{ + { + name: "ok", + givenPrefix: "/images", + givenRoot: "_fixture/images", + whenURL: "/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "No file", + givenPrefix: "/images", + givenRoot: "_fixture/scripts", + whenURL: "/images/bolt.png", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory", + givenPrefix: "/images", + givenRoot: "_fixture/images", + whenURL: "/images/", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory Redirect", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/folder", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory with index.html", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Sub-directory with index.html", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/folder/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "do not allow directory traversal (backslash - windows separator)", + givenPrefix: "/", + givenRoot: "_fixture/", + whenURL: `/..\\middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "do not allow directory traversal (slash - unix separator)", + givenPrefix: "/", + givenRoot: "_fixture/", + whenURL: `/../middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } - // Sub-directory with index.html - c, r = request(http.MethodGet, "/folder/", e) - assert.Equal(http.StatusOK, c) - assert.Equal(true, strings.HasPrefix(r, "")) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Static(tc.givenPrefix, tc.givenRoot) + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expectStatus, rec.Code) + body := rec.Body.String() + if tc.expectBodyStartsWith != "" { + assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) + } else { + assert.Equal(t, "", body) + } + if tc.expectHeaderLocation != "" { + assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) + } else { + _, ok := rec.Result().Header["Location"] + assert.False(t, ok) + } + }) + } } func TestEchoFile(t *testing.T) { diff --git a/middleware/static.go b/middleware/static.go index 58b7890a4..ae79cb5fa 100644 --- a/middleware/static.go +++ b/middleware/static.go @@ -167,7 +167,7 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { if err != nil { return } - name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security + name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security if config.IgnoreBase { routePath := path.Base(strings.TrimRight(c.Path(), "/*")) From 2374af470cba277b8e66bb9bc65c6816f9e9f6c5 Mon Sep 17 00:00:00 2001 From: pwli Date: Wed, 16 Dec 2020 09:37:26 +0800 Subject: [PATCH 087/446] Update echo_test.go fix typo --- echo_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/echo_test.go b/echo_test.go index 82ccad0ce..895ea1825 100644 --- a/echo_test.go +++ b/echo_test.go @@ -99,16 +99,16 @@ func TestEchoStatic(t *testing.T) { givenRoot: "_fixture", whenURL: "/folder", expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/static/", + expectHeaderLocation: "/folder/", expectBodyStartsWith: "", }, { name: "Directory Redirect with non-root path", givenPrefix: "/static", givenRoot: "_fixture", - whenURL: "/folder", + whenURL: "/static", expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/folder/", + expectHeaderLocation: "/static/", expectBodyStartsWith: "", }, { From 628a2df08cc5bf6b8e9036b4c08ccd02bdf53da7 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Thu, 17 Dec 2020 02:01:57 +0200 Subject: [PATCH 088/446] Revert "Add a test" This reverts commit 7a1126fb --- group_test.go | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/group_test.go b/group_test.go index d4a6846f5..c51fd91eb 100644 --- a/group_test.go +++ b/group_test.go @@ -119,37 +119,3 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { assert.Equal(t, "/*", m) } - -func TestMultipleGroupSamePathMiddleware(t *testing.T) { - // Ensure multiple groups with the same path do not clobber previous routes or mixup middlewares - e := New() - m1 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - c.Set("middleware", "m1") - return next(c) - } - } - m2 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - c.Set("middleware", "m2") - return next(c) - } - } - h := func(c Context) error { - return c.String(http.StatusOK, c.Get("middleware").(string)) - } - - g1 := e.Group("/group", m1) - { - g1.GET("", h) - } - g2 := e.Group("/group", m2) - { - g2.GET("/other", h) - } - - _, m := request(http.MethodGet, "/group", e) - assert.Equal(t, "m1", m) - _, m = request(http.MethodGet, "/group/other", e) - assert.Equal(t, "m2", m) -} From 655596b1b9312274c18b9fbbeb61fa51a2c5c39f Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Thu, 17 Dec 2020 02:01:59 +0200 Subject: [PATCH 089/446] Revert "Remove group.Use registering Any routes that break other routes" This reverts commit f72eaa42 --- group.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/group.go b/group.go index d239fb581..426bef9eb 100644 --- a/group.go +++ b/group.go @@ -23,6 +23,10 @@ func (g *Group) Use(middleware ...MiddlewareFunc) { if len(g.middleware) == 0 { return } + // Allow all requests to reach the group as they might get dropped if router + // doesn't find a match, making none of the group middleware process. + g.Any("", NotFoundHandler) + g.Any("/*", NotFoundHandler) } // CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. From 547ca5ca1e49e1dc24f73ab446fc738c2a636301 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Thu, 17 Dec 2020 02:20:26 +0200 Subject: [PATCH 090/446] reverts #1671 changes --- echo.go | 1 - echo_test.go | 43 ------------------------------------------- 2 files changed, 44 deletions(-) diff --git a/echo.go b/echo.go index 4b0c785a5..d284ff396 100644 --- a/echo.go +++ b/echo.go @@ -500,7 +500,6 @@ func (common) static(prefix, root string, get func(string, HandlerFunc, ...Middl } return c.File(name) } - get(prefix, h) if prefix == "/" { return get(prefix+"*", h) } diff --git a/echo_test.go b/echo_test.go index 895ea1825..a6071e12a 100644 --- a/echo_test.go +++ b/echo_test.go @@ -102,15 +102,6 @@ func TestEchoStatic(t *testing.T) { expectHeaderLocation: "/folder/", expectBodyStartsWith: "", }, - { - name: "Directory Redirect with non-root path", - givenPrefix: "/static", - givenRoot: "_fixture", - whenURL: "/static", - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/static/", - expectBodyStartsWith: "", - }, { name: "Directory with index.html", givenPrefix: "/", @@ -170,40 +161,6 @@ func TestEchoStatic(t *testing.T) { } } -func TestEchoStaticRedirectIndex(t *testing.T) { - assert := assert.New(t) - e := New() - - // HandlerFunc - e.Static("/static", "_fixture") - - errCh := make(chan error) - - go func() { - errCh <- e.Start("127.0.0.1:1323") - }() - - time.Sleep(200 * time.Millisecond) - - if resp, err := http.Get("http://127.0.0.1:1323/static"); err == nil { - defer resp.Body.Close() - assert.Equal(http.StatusOK, resp.StatusCode) - - if body, err := ioutil.ReadAll(resp.Body); err == nil { - assert.Equal(true, strings.HasPrefix(string(body), "")) - } else { - assert.Fail(err.Error()) - } - - } else { - assert.Fail(err.Error()) - } - - if err := e.Close(); err != nil { - t.Fatal(err) - } -} - func TestEchoFile(t *testing.T) { e := New() e.File("/walle", "_fixture/images/walle.png") From e4fe8c836751596f048afb85ee6cd60a0ccf0811 Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Thu, 17 Dec 2020 02:07:41 +0100 Subject: [PATCH 091/446] Fix failing tests on systems not supporting IPv6 --- echo_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/echo_test.go b/echo_test.go index a6071e12a..29edca107 100644 --- a/echo_test.go +++ b/echo_test.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io/ioutil" + "net" "net/http" "net/http/httptest" "reflect" @@ -730,8 +731,24 @@ var listenerNetworkTests = []struct { {"tcp6 ipv6 address", "tcp6", "[::1]:1323"}, } +func supportsIPv6() bool { + addrs, _ := net.InterfaceAddrs() + for _, addr := range addrs { + // Check if any interface has local IPv6 assigned + if strings.Contains(addr.String(), "::1") { + return true + } + } + return false +} + func TestEchoListenerNetwork(t *testing.T) { + hasIPv6 := supportsIPv6() for _, tt := range listenerNetworkTests { + if !hasIPv6 && strings.Contains(tt.address, "::") { + t.Skip("Skipping testing IPv6 for " + tt.address + ", not available") + continue + } t.Run(tt.test, func(t *testing.T) { e := New() e.ListenerNetwork = tt.network From 4d626c210d3946814a30d545adf9b8f2296686a7 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 20 Dec 2020 11:05:42 +0200 Subject: [PATCH 092/446] c.Bind() uses query params only for GET or DELETE methods. This restores pre v.4.1.11 behavior. --- bind.go | 12 ++++++-- bind_test.go | 80 +++++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 73 insertions(+), 19 deletions(-) diff --git a/bind.go b/bind.go index c7be242b1..acd2beda2 100644 --- a/bind.go +++ b/bind.go @@ -98,12 +98,20 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { } // Bind implements the `Binder#Bind` function. +// Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous +// step binded values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams. func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { if err := b.BindPathParams(c, i); err != nil { return err } - if err = b.BindQueryParams(c, i); err != nil { - return err + // Issue #1670 - Query params are binded only for GET/DELETE and NOT for usual request with body (POST/PUT/PATCH) + // Reasoning here is that parameters in query and bind destination struct could have UNEXPECTED matches and results due that. + // i.e. is `&id=1&lang=en` from URL same as `{"id":100,"lang":"de"}` request body and which one should have priority when binding. + // This HTTP method check restores pre v4.1.11 behavior and avoids different problems when query is mixed with body + if c.Request().Method == http.MethodGet || c.Request().Method == http.MethodDelete { + if err = b.BindQueryParams(c, i); err != nil { + return err + } } return b.BindBody(c, i) } diff --git a/bind_test.go b/bind_test.go index 60c2f9e0a..345fbdf10 100644 --- a/bind_test.go +++ b/bind_test.go @@ -559,7 +559,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { // binding is done in steps and one source could overwrite previous source binded data // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed - type Node struct { + type Opts struct { ID int `json:"id"` Node string `json:"node"` } @@ -575,41 +575,77 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { expectError string }{ { - name: "ok, POST bind to struct with: path param + query param + empty body", + name: "ok, POST bind to struct with: path param + query param + body", givenMethod: http.MethodPost, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{"id": 1}`), - expect: &Node{ID: 1, Node: "xxx"}, // in current implementation query params has higher priority than path params + expect: &Opts{ID: 1, Node: "node_from_path"}, // query params are not used, node is filled from path + }, + { + name: "ok, PUT bind to struct with: path param + query param + body", + givenMethod: http.MethodPut, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{"id": 1}`), + expect: &Opts{ID: 1, Node: "node_from_path"}, // query params are not used + }, + { + name: "ok, GET bind to struct with: path param + query param + body", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{"id": 1}`), + expect: &Opts{ID: 1, Node: "xxx"}, // query overwrites previous path value + }, + { + name: "ok, GET bind to struct with: path param + query param + body", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), + expect: &Opts{ID: 1, Node: "zzz"}, // body is binded last and overwrites previous (path,query) values + }, + { + name: "ok, DELETE bind to struct with: path param + query param + body", + givenMethod: http.MethodDelete, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), + expect: &Opts{ID: 1, Node: "zzz"}, // for DELETE body is binded after query params }, { - name: "ok, POST bind to struct with: path param + empty body", + name: "ok, POST bind to struct with: path param + body", givenMethod: http.MethodPost, givenURL: "/api/real_node/endpoint", givenContent: strings.NewReader(`{"id": 1}`), - expect: &Node{ID: 1, Node: "real_node"}, + expect: &Opts{ID: 1, Node: "node_from_path"}, }, { name: "ok, POST bind to struct with path + query + body = body has priority", givenMethod: http.MethodPost, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), - expect: &Node{ID: 1, Node: "zzz"}, // field value from content has higher priority + expect: &Opts{ID: 1, Node: "zzz"}, // field value from content has higher priority }, { name: "nok, POST body bind failure", givenMethod: http.MethodPost, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{`), - expect: &Node{ID: 0, Node: "xxx"}, // query binding has already modified bind target + expect: &Opts{ID: 0, Node: "node_from_path"}, // query binding has already modified bind target expectError: "code=400, message=unexpected EOF, internal=unexpected EOF", }, + { + name: "nok, GET with body bind failure when types are not convertible", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint?id=nope", + givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), + expect: &Opts{ID: 0, Node: "node_from_path"}, // path params binding has already modified bind target + expectError: "code=400, message=strconv.ParseInt: parsing \"nope\": invalid syntax, internal=strconv.ParseInt: parsing \"nope\": invalid syntax", + }, { name: "nok, GET body bind failure - trying to bind json array to struct", givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), - expect: &Node{ID: 0, Node: "xxx"}, // query binding has already modified bind target - expectError: "code=400, message=Unmarshal type error: expected=echo.Node, got=array, field=, offset=1, internal=json: cannot unmarshal array into Go value of type echo.Node", + expect: &Opts{ID: 0, Node: "xxx"}, // query binding has already modified bind target + expectError: "code=400, message=Unmarshal type error: expected=echo.Opts, got=array, field=, offset=1, internal=json: cannot unmarshal array into Go value of type echo.Opts", }, { // binding query params interferes with body. b.BindBody() should be used to bind only body to slice name: "nok, GET query params bind failure - trying to bind json array to slice", @@ -617,17 +653,27 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), whenNoPathParams: true, - whenBindTarget: &[]Node{}, - expect: &[]Node{}, + whenBindTarget: &[]Opts{}, + expect: &[]Opts{}, expectError: "code=400, message=binding element must be a struct, internal=binding element must be a struct", }, + { // binding query params interferes with body. b.BindBody() should be used to bind only body to slice + name: "ok, POST binding to slice should not be affected query params types", + givenMethod: http.MethodPost, + givenURL: "/api/real_node/endpoint?id=nope&node=xxx", + givenContent: strings.NewReader(`[{"id": 1}]`), + whenNoPathParams: true, + whenBindTarget: &[]Opts{}, + expect: &[]Opts{{ID: 1}}, + expectError: "", + }, { // binding path params interferes with body. b.BindBody() should be used to bind only body to slice name: "nok, GET path params bind failure - trying to bind json array to slice", givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), - whenBindTarget: &[]Node{}, - expect: &[]Node{}, + whenBindTarget: &[]Opts{}, + expect: &[]Opts{}, expectError: "code=400, message=binding element must be a struct, internal=binding element must be a struct", }, { @@ -636,8 +682,8 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { givenURL: "/api/real_node/endpoint", givenContent: strings.NewReader(`[{"id": 1}]`), whenNoPathParams: true, - whenBindTarget: &[]Node{}, - expect: &[]Node{{ID: 1, Node: ""}}, + whenBindTarget: &[]Opts{}, + expect: &[]Opts{{ID: 1, Node: ""}}, expectError: "", }, } @@ -653,14 +699,14 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { if !tc.whenNoPathParams { c.SetParamNames("node") - c.SetParamValues("real_node") + c.SetParamValues("node_from_path") } var bindTarget interface{} if tc.whenBindTarget != nil { bindTarget = tc.whenBindTarget } else { - bindTarget = &Node{} + bindTarget = &Opts{} } b := new(DefaultBinder) From 65ea019530a59f93f470f43d735fbc285a06db1e Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 20 Dec 2020 11:06:39 +0200 Subject: [PATCH 093/446] makefile targets to help local development/testing --- Makefile | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/Makefile b/Makefile index dfcb6c02b..c369913a6 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,27 @@ +PKG := "github.com/labstack/echo" +PKG_LIST := $(shell go list ${PKG}/...) + tag: @git tag `grep -P '^\tversion = ' echo.go|cut -f2 -d'"'` @git tag|grep -v ^v + +.DEFAULT_GOAL := check +check: lint vet race ## Check project + +init: + @go get -u golang.org/x/lint/golint + +lint: ## Lint the files + @golint -set_exit_status ${PKG_LIST} + +vet: ## Vet the files + @go vet ${PKG_LIST} + +test: ## Run tests + @go test -short ${PKG_LIST} + +race: ## Run tests with data race detector + @go test -race ${PKG_LIST} + +help: ## Display this help screen + @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' From 734e313f711bd06067759bcfcfb2ba73c3a4dde5 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Tue, 29 Dec 2020 11:46:09 +0200 Subject: [PATCH 094/446] refactor Echo server startup to allow data race free access to listener address --- echo.go | 84 +++++++++++-- echo_test.go | 350 ++++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 381 insertions(+), 53 deletions(-) diff --git a/echo.go b/echo.go index d284ff396..4c4e7d478 100644 --- a/echo.go +++ b/echo.go @@ -67,6 +67,9 @@ type ( // Echo is the top-level framework instance. Echo struct { common + // startupMu is mutex to lock Echo instance access during server configuration and startup. Useful for to get + // listener address info (on which interface/port was listener binded) without having data races. + startupMu sync.RWMutex StdLogger *stdLog.Logger colorer *color.Color premiddleware []MiddlewareFunc @@ -643,21 +646,30 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Start starts an HTTP server. func (e *Echo) Start(address string) error { + e.startupMu.Lock() e.Server.Addr = address - return e.StartServer(e.Server) + if err := e.configureServer(e.Server); err != nil { + e.startupMu.Unlock() + return err + } + e.startupMu.Unlock() + return e.serve() } // StartTLS starts an HTTPS server. // If `certFile` or `keyFile` is `string` the values are treated as file paths. // If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is. func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err error) { + e.startupMu.Lock() var cert []byte if cert, err = filepathOrContent(certFile); err != nil { + e.startupMu.Unlock() return } var key []byte if key, err = filepathOrContent(keyFile); err != nil { + e.startupMu.Unlock() return } @@ -665,10 +677,17 @@ func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err erro s.TLSConfig = new(tls.Config) s.TLSConfig.Certificates = make([]tls.Certificate, 1) if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil { + e.startupMu.Unlock() return } - return e.startTLS(address) + e.configureTLS(address) + if err := e.configureServer(s); err != nil { + e.startupMu.Unlock() + return err + } + e.startupMu.Unlock() + return s.Serve(e.TLSListener) } func filepathOrContent(fileOrContent interface{}) (content []byte, err error) { @@ -684,24 +703,41 @@ func filepathOrContent(fileOrContent interface{}) (content []byte, err error) { // StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org. func (e *Echo) StartAutoTLS(address string) error { + e.startupMu.Lock() s := e.TLSServer s.TLSConfig = new(tls.Config) s.TLSConfig.GetCertificate = e.AutoTLSManager.GetCertificate s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, acme.ALPNProto) - return e.startTLS(address) + + e.configureTLS(address) + if err := e.configureServer(s); err != nil { + e.startupMu.Unlock() + return err + } + e.startupMu.Unlock() + return s.Serve(e.TLSListener) } -func (e *Echo) startTLS(address string) error { +func (e *Echo) configureTLS(address string) { s := e.TLSServer s.Addr = address if !e.DisableHTTP2 { s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2") } - return e.StartServer(e.TLSServer) } // StartServer starts a custom http server. func (e *Echo) StartServer(s *http.Server) (err error) { + e.startupMu.Lock() + if err := e.configureServer(s); err != nil { + e.startupMu.Unlock() + return err + } + e.startupMu.Unlock() + return e.serve() +} + +func (e *Echo) configureServer(s *http.Server) (err error) { // Setup e.colorer.SetOutput(e.Logger.Output()) s.ErrorLog = e.StdLogger @@ -724,7 +760,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) { if !e.HidePort { e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) } - return s.Serve(e.Listener) + return nil } if e.TLSListener == nil { l, err := newListener(s.Addr, e.ListenerNetwork) @@ -736,11 +772,39 @@ func (e *Echo) StartServer(s *http.Server) (err error) { if !e.HidePort { e.colorer.Printf("⇨ https server started on %s\n", e.colorer.Green(e.TLSListener.Addr())) } - return s.Serve(e.TLSListener) + return nil +} + +func (e *Echo) serve() error { + if e.TLSListener != nil { + return e.Server.Serve(e.TLSListener) + } + return e.Server.Serve(e.Listener) +} + +// ListenerAddr returns net.Addr for Listener +func (e *Echo) ListenerAddr() net.Addr { + e.startupMu.RLock() + defer e.startupMu.RUnlock() + if e.Listener == nil { + return nil + } + return e.Listener.Addr() +} + +// TLSListenerAddr returns net.Addr for TLSListener +func (e *Echo) TLSListenerAddr() net.Addr { + e.startupMu.RLock() + defer e.startupMu.RUnlock() + if e.TLSListener == nil { + return nil + } + return e.TLSListener.Addr() } // StartH2CServer starts a custom http/2 server with h2c (HTTP/2 Cleartext). func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { + e.startupMu.Lock() // Setup s := e.Server s.Addr = address @@ -758,18 +822,22 @@ func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { if e.Listener == nil { e.Listener, err = newListener(s.Addr, e.ListenerNetwork) if err != nil { + e.startupMu.Unlock() return err } } if !e.HidePort { e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) } + e.startupMu.Unlock() return s.Serve(e.Listener) } // Close immediately stops the server. // It internally calls `http.Server#Close()`. func (e *Echo) Close() error { + e.startupMu.Lock() + defer e.startupMu.Unlock() if err := e.TLSServer.Close(); err != nil { return err } @@ -779,6 +847,8 @@ func (e *Echo) Close() error { // Shutdown stops the server gracefully. // It internally calls `http.Server#Shutdown()`. func (e *Echo) Shutdown(ctx stdContext.Context) error { + e.startupMu.Lock() + defer e.startupMu.Unlock() if err := e.TLSServer.Shutdown(ctx); err != nil { return err } diff --git a/echo_test.go b/echo_test.go index 29edca107..7f3597428 100644 --- a/echo_test.go +++ b/echo_test.go @@ -3,12 +3,14 @@ package echo import ( "bytes" stdContext "context" + "crypto/tls" "errors" "fmt" "io/ioutil" "net" "net/http" "net/http/httptest" + "os" "reflect" "strings" "testing" @@ -485,26 +487,125 @@ func TestEchoContext(t *testing.T) { e.ReleaseContext(c) } -func TestEchoStart(t *testing.T) { - e := New() - go func() { - assert.NoError(t, e.Start(":0")) - }() - time.Sleep(200 * time.Millisecond) +func waitForServerStart(e *Echo, errChan <-chan error, isTLS bool) error { + ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) + defer cancel() + + ticker := time.NewTicker(5 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + var addr net.Addr + if isTLS { + addr = e.TLSListenerAddr() + } else { + addr = e.ListenerAddr() + } + if addr != nil && strings.Contains(addr.String(), ":") { + return nil // was started + } + case err := <-errChan: + if err == http.ErrServerClosed { + return nil + } + return err + } + } } -func TestEchoStartTLS(t *testing.T) { +func TestEchoStart(t *testing.T) { e := New() + errChan := make(chan error) + go func() { - err := e.StartTLS(":0", "_fixture/certs/cert.pem", "_fixture/certs/key.pem") - // Prevent the test to fail after closing the servers - if err != http.ErrServerClosed { - assert.NoError(t, err) + err := e.Start(":0") + if err != nil { + errChan <- err } }() - time.Sleep(200 * time.Millisecond) - e.Close() + err := waitForServerStart(e, errChan, false) + assert.NoError(t, err) + + assert.NoError(t, e.Close()) +} + +func TestEcho_StartTLS(t *testing.T) { + var testCases = []struct { + name string + addr string + certFile string + keyFile string + expectError string + }{ + { + name: "ok", + addr: ":0", + }, + { + name: "nok, invalid certFile", + addr: ":0", + certFile: "not existing", + expectError: "open not existing: no such file or directory", + }, + { + name: "nok, invalid keyFile", + addr: ":0", + keyFile: "not existing", + expectError: "open not existing: no such file or directory", + }, + { + name: "nok, failed to create cert out of certFile and keyFile", + addr: ":0", + keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key + expectError: "tls: found a certificate rather than a key in the PEM for the private key", + }, + { + name: "nok, invalid tls address", + addr: "nope", + expectError: "listen tcp: address nope: missing port in address", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + errChan := make(chan error) + + go func() { + certFile := "_fixture/certs/cert.pem" + if tc.certFile != "" { + certFile = tc.certFile + } + keyFile := "_fixture/certs/key.pem" + if tc.keyFile != "" { + keyFile = tc.keyFile + } + + err := e.StartTLS(tc.addr, certFile, keyFile) + if err != nil { + errChan <- err + } + }() + + err := waitForServerStart(e, errChan, true) + if tc.expectError != "" { + if _, ok := err.(*os.PathError); ok { + assert.Error(t, err) // error messages for unix and windows are different. so test only error type here + } else { + assert.EqualError(t, err, tc.expectError) + } + } else { + assert.NoError(t, err) + } + + assert.NoError(t, e.Close()) + }) + } } func TestEchoStartTLSByteString(t *testing.T) { @@ -557,47 +658,103 @@ func TestEchoStartTLSByteString(t *testing.T) { e := New() e.HideBanner = true + errChan := make(chan error, 0) + go func() { - err := e.StartTLS(":0", test.cert, test.key) - if test.expectedErr != nil { - require.EqualError(t, err, test.expectedErr.Error()) - } else if err != http.ErrServerClosed { // Prevent the test to fail after closing the servers - require.NoError(t, err) - } + errChan <- e.StartTLS(":0", test.cert, test.key) }() - time.Sleep(200 * time.Millisecond) - require.NoError(t, e.Close()) + err := waitForServerStart(e, errChan, true) + if test.expectedErr != nil { + assert.EqualError(t, err, test.expectedErr.Error()) + } else { + assert.NoError(t, err) + } + + assert.NoError(t, e.Close()) }) } } -func TestEchoStartAutoTLS(t *testing.T) { - e := New() - errChan := make(chan error, 0) +func TestEcho_StartAutoTLS(t *testing.T) { + var testCases = []struct { + name string + addr string + expectError string + }{ + { + name: "ok", + addr: ":0", + }, + { + name: "nok, invalid address", + addr: "nope", + expectError: "listen tcp: address nope: missing port in address", + }, + } - go func() { - errChan <- e.StartAutoTLS(":0") - }() - time.Sleep(200 * time.Millisecond) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + errChan := make(chan error, 0) + + go func() { + errChan <- e.StartAutoTLS(tc.addr) + }() + + err := waitForServerStart(e, errChan, true) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } - select { - case err := <-errChan: - assert.NoError(t, err) - default: - assert.NoError(t, e.Close()) + assert.NoError(t, e.Close()) + }) } } -func TestEchoStartH2CServer(t *testing.T) { - e := New() - e.Debug = true - h2s := &http2.Server{} +func TestEcho_StartH2CServer(t *testing.T) { + var testCases = []struct { + name string + addr string + expectError string + }{ + { + name: "ok", + addr: ":0", + }, + { + name: "nok, invalid address", + addr: "nope", + expectError: "listen tcp: address nope: missing port in address", + }, + } - go func() { - assert.NoError(t, e.StartH2CServer(":0", h2s)) - }() - time.Sleep(200 * time.Millisecond) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Debug = true + h2s := &http2.Server{} + + errChan := make(chan error) + go func() { + err := e.StartH2CServer(tc.addr, h2s) + if err != nil { + errChan <- err + } + }() + + err := waitForServerStart(e, errChan, false) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + assert.NoError(t, e.Close()) + }) + } } func testMethod(t *testing.T, method, path string, e *Echo) { @@ -686,7 +843,8 @@ func TestEchoClose(t *testing.T) { errCh <- e.Start(":0") }() - time.Sleep(200 * time.Millisecond) + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) if err := e.Close(); err != nil { t.Fatal(err) @@ -694,7 +852,7 @@ func TestEchoClose(t *testing.T) { assert.NoError(t, e.Close()) - err := <-errCh + err = <-errCh assert.Equal(t, err.Error(), "http: Server closed") } @@ -706,7 +864,8 @@ func TestEchoShutdown(t *testing.T) { errCh <- e.Start(":0") }() - time.Sleep(200 * time.Millisecond) + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) if err := e.Close(); err != nil { t.Fatal(err) @@ -716,7 +875,7 @@ func TestEchoShutdown(t *testing.T) { defer cancel() assert.NoError(t, e.Shutdown(ctx)) - err := <-errCh + err = <-errCh assert.Equal(t, err.Error(), "http: Server closed") } @@ -764,7 +923,8 @@ func TestEchoListenerNetwork(t *testing.T) { errCh <- e.Start(tt.address) }() - time.Sleep(200 * time.Millisecond) + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil { defer resp.Body.Close() @@ -823,3 +983,101 @@ func TestEchoReverse(t *testing.T) { assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two")) assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) } + +func TestEcho_ListenerAddr(t *testing.T) { + e := New() + + addr := e.ListenerAddr() + assert.Nil(t, addr) + + errCh := make(chan error) + go func() { + errCh <- e.Start(":0") + }() + + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) +} + +func TestEcho_TLSListenerAddr(t *testing.T) { + cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") + require.NoError(t, err) + key, err := ioutil.ReadFile("_fixture/certs/key.pem") + require.NoError(t, err) + + e := New() + + addr := e.TLSListenerAddr() + assert.Nil(t, addr) + + errCh := make(chan error) + go func() { + errCh <- e.StartTLS(":0", cert, key) + }() + + err = waitForServerStart(e, errCh, true) + assert.NoError(t, err) +} + +func TestEcho_StartServer(t *testing.T) { + cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") + require.NoError(t, err) + key, err := ioutil.ReadFile("_fixture/certs/key.pem") + require.NoError(t, err) + certs, err := tls.X509KeyPair(cert, key) + require.NoError(t, err) + + var testCases = []struct { + name string + addr string + TLSConfig *tls.Config + expectError string + }{ + { + name: "ok", + addr: ":0", + }, + { + name: "ok, start with TLS", + addr: ":0", + TLSConfig: &tls.Config{Certificates: []tls.Certificate{certs}}, + }, + { + name: "nok, invalid address", + addr: "nope", + expectError: "listen tcp: address nope: missing port in address", + }, + { + name: "nok, invalid tls address", + addr: "nope", + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + expectError: "listen tcp: address nope: missing port in address", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Debug = true + + server := new(http.Server) + server.Addr = tc.addr + if tc.TLSConfig != nil { + server.TLSConfig = tc.TLSConfig + } + + errCh := make(chan error) + go func() { + errCh <- e.StartServer(server) + }() + + err := waitForServerStart(e, errCh, tc.TLSConfig != nil) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + assert.NoError(t, e.Close()) + }) + } +} From d18c0409378bf531fafe7e3a94c8da4a424f4e20 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Tue, 29 Dec 2020 13:25:43 +0200 Subject: [PATCH 095/446] rename mutex --- echo.go | 56 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/echo.go b/echo.go index 4c4e7d478..64e64c2c3 100644 --- a/echo.go +++ b/echo.go @@ -67,9 +67,9 @@ type ( // Echo is the top-level framework instance. Echo struct { common - // startupMu is mutex to lock Echo instance access during server configuration and startup. Useful for to get + // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get // listener address info (on which interface/port was listener binded) without having data races. - startupMu sync.RWMutex + startupMutex sync.RWMutex StdLogger *stdLog.Logger colorer *color.Color premiddleware []MiddlewareFunc @@ -646,13 +646,13 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Start starts an HTTP server. func (e *Echo) Start(address string) error { - e.startupMu.Lock() + e.startupMutex.Lock() e.Server.Addr = address if err := e.configureServer(e.Server); err != nil { - e.startupMu.Unlock() + e.startupMutex.Unlock() return err } - e.startupMu.Unlock() + e.startupMutex.Unlock() return e.serve() } @@ -660,16 +660,16 @@ func (e *Echo) Start(address string) error { // If `certFile` or `keyFile` is `string` the values are treated as file paths. // If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is. func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err error) { - e.startupMu.Lock() + e.startupMutex.Lock() var cert []byte if cert, err = filepathOrContent(certFile); err != nil { - e.startupMu.Unlock() + e.startupMutex.Unlock() return } var key []byte if key, err = filepathOrContent(keyFile); err != nil { - e.startupMu.Unlock() + e.startupMutex.Unlock() return } @@ -677,16 +677,16 @@ func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err erro s.TLSConfig = new(tls.Config) s.TLSConfig.Certificates = make([]tls.Certificate, 1) if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil { - e.startupMu.Unlock() + e.startupMutex.Unlock() return } e.configureTLS(address) if err := e.configureServer(s); err != nil { - e.startupMu.Unlock() + e.startupMutex.Unlock() return err } - e.startupMu.Unlock() + e.startupMutex.Unlock() return s.Serve(e.TLSListener) } @@ -703,7 +703,7 @@ func filepathOrContent(fileOrContent interface{}) (content []byte, err error) { // StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org. func (e *Echo) StartAutoTLS(address string) error { - e.startupMu.Lock() + e.startupMutex.Lock() s := e.TLSServer s.TLSConfig = new(tls.Config) s.TLSConfig.GetCertificate = e.AutoTLSManager.GetCertificate @@ -711,10 +711,10 @@ func (e *Echo) StartAutoTLS(address string) error { e.configureTLS(address) if err := e.configureServer(s); err != nil { - e.startupMu.Unlock() + e.startupMutex.Unlock() return err } - e.startupMu.Unlock() + e.startupMutex.Unlock() return s.Serve(e.TLSListener) } @@ -728,12 +728,12 @@ func (e *Echo) configureTLS(address string) { // StartServer starts a custom http server. func (e *Echo) StartServer(s *http.Server) (err error) { - e.startupMu.Lock() + e.startupMutex.Lock() if err := e.configureServer(s); err != nil { - e.startupMu.Unlock() + e.startupMutex.Unlock() return err } - e.startupMu.Unlock() + e.startupMutex.Unlock() return e.serve() } @@ -784,8 +784,8 @@ func (e *Echo) serve() error { // ListenerAddr returns net.Addr for Listener func (e *Echo) ListenerAddr() net.Addr { - e.startupMu.RLock() - defer e.startupMu.RUnlock() + e.startupMutex.RLock() + defer e.startupMutex.RUnlock() if e.Listener == nil { return nil } @@ -794,8 +794,8 @@ func (e *Echo) ListenerAddr() net.Addr { // TLSListenerAddr returns net.Addr for TLSListener func (e *Echo) TLSListenerAddr() net.Addr { - e.startupMu.RLock() - defer e.startupMu.RUnlock() + e.startupMutex.RLock() + defer e.startupMutex.RUnlock() if e.TLSListener == nil { return nil } @@ -804,7 +804,7 @@ func (e *Echo) TLSListenerAddr() net.Addr { // StartH2CServer starts a custom http/2 server with h2c (HTTP/2 Cleartext). func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { - e.startupMu.Lock() + e.startupMutex.Lock() // Setup s := e.Server s.Addr = address @@ -822,22 +822,22 @@ func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { if e.Listener == nil { e.Listener, err = newListener(s.Addr, e.ListenerNetwork) if err != nil { - e.startupMu.Unlock() + e.startupMutex.Unlock() return err } } if !e.HidePort { e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) } - e.startupMu.Unlock() + e.startupMutex.Unlock() return s.Serve(e.Listener) } // Close immediately stops the server. // It internally calls `http.Server#Close()`. func (e *Echo) Close() error { - e.startupMu.Lock() - defer e.startupMu.Unlock() + e.startupMutex.Lock() + defer e.startupMutex.Unlock() if err := e.TLSServer.Close(); err != nil { return err } @@ -847,8 +847,8 @@ func (e *Echo) Close() error { // Shutdown stops the server gracefully. // It internally calls `http.Server#Shutdown()`. func (e *Echo) Shutdown(ctx stdContext.Context) error { - e.startupMu.Lock() - defer e.startupMu.Unlock() + e.startupMutex.Lock() + defer e.startupMutex.Unlock() if err := e.TLSServer.Shutdown(ctx); err != nil { return err } From 21f77872028d51bbca05017ed830aa2624a922b6 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Tue, 29 Dec 2020 23:45:27 +0200 Subject: [PATCH 096/446] refactor static middleware tests not to use previous case state --- middleware/static_test.go | 193 +++++++++++++++++++++++--------------- 1 file changed, 116 insertions(+), 77 deletions(-) diff --git a/middleware/static_test.go b/middleware/static_test.go index 407dd15ce..3e6ca5601 100644 --- a/middleware/static_test.go +++ b/middleware/static_test.go @@ -1,94 +1,133 @@ package middleware import ( + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" - "path/filepath" "testing" - - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" ) func TestStatic(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - config := StaticConfig{ - Root: "../_fixture", + var testCases = []struct { + name string + givenConfig *StaticConfig + givenAttachedToGroup string + whenURL string + expectContains string + expectLength string + expectCode int + }{ + { + name: "ok, serve index with Echo message", + whenURL: "/", + expectCode: http.StatusOK, + expectContains: "Echo", + }, + { + name: "ok, serve file from subdirectory", + whenURL: "/images/walle.png", + expectCode: http.StatusOK, + expectLength: "219885", + }, + { + name: "ok, when html5 mode serve index for any static file that does not exist", + givenConfig: &StaticConfig{ + Root: "../_fixture", + HTML5: true, + }, + whenURL: "/random", + expectCode: http.StatusOK, + expectContains: "Echo", + }, + { + name: "ok, serve index as directory index listing files directory", + givenConfig: &StaticConfig{ + Root: "../_fixture/certs", + Browse: true, + }, + whenURL: "/", + expectCode: http.StatusOK, + expectContains: "cert.pem", + }, + { + name: "ok, serve directory index with IgnoreBase and browse", + givenConfig: &StaticConfig{ + Root: "../_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored + IgnoreBase: true, + Browse: true, + }, + givenAttachedToGroup: "/_fixture", + whenURL: "/_fixture/", + expectCode: http.StatusOK, + expectContains: `README.md`, + }, + { + name: "ok, serve file with IgnoreBase", + givenConfig: &StaticConfig{ + Root: "../_fixture/_fixture/", // <-- last `_fixture/` is overlapping with group path and needs to be ignored + IgnoreBase: true, + Browse: true, + }, + givenAttachedToGroup: "/_fixture", + whenURL: "/_fixture/README.md", + expectCode: http.StatusOK, + expectContains: "This directory is used for the static middleware test", + }, + { + name: "nok, file not found", + whenURL: "/none", + expectCode: http.StatusNotFound, + expectContains: "{\"message\":\"Not Found\"}\n", + }, + { + name: "nok, do not allow directory traversal (backslash - windows separator)", + whenURL: `/..\\middleware/basic_auth.go`, + expectCode: http.StatusNotFound, + expectContains: "{\"message\":\"Not Found\"}\n", + }, + { + name: "nok,do not allow directory traversal (slash - unix separator)", + whenURL: `/../middleware/basic_auth.go`, + expectCode: http.StatusNotFound, + expectContains: "{\"message\":\"Not Found\"}\n", + }, } - // Directory - h := StaticWithConfig(config)(echo.NotFoundHandler) - - assert := assert.New(t) - - if assert.NoError(h(c)) { - assert.Contains(rec.Body.String(), "Echo") - } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() - // File found - req = httptest.NewRequest(http.MethodGet, "/images/walle.png", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - if assert.NoError(h(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(rec.Header().Get(echo.HeaderContentLength), "219885") - } + config := StaticConfig{Root: "../_fixture"} + if tc.givenConfig != nil { + config = *tc.givenConfig + } + middlewareFunc := StaticWithConfig(config) + if tc.givenAttachedToGroup != "" { + // middleware is attached to group + subGroup := e.Group(tc.givenAttachedToGroup, middlewareFunc) + // group without http handlers (routes) does not do anything. + // Request is matched against http handlers (routes) that have group middleware attached to them + subGroup.GET("", echo.NotFoundHandler) + subGroup.GET("/*", echo.NotFoundHandler) + } else { + // middleware is on root level + e.Use(middlewareFunc) + } - // File not found - req = httptest.NewRequest(http.MethodGet, "/none", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - he := h(c).(*echo.HTTPError) - assert.Equal(http.StatusNotFound, he.Code) + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() - // HTML5 - req = httptest.NewRequest(http.MethodGet, "/random", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - config.HTML5 = true - static := StaticWithConfig(config) - h = static(echo.NotFoundHandler) - if assert.NoError(h(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Contains(rec.Body.String(), "Echo") - } + e.ServeHTTP(rec, req) - // Browse - req = httptest.NewRequest(http.MethodGet, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - config.Root = "../_fixture/certs" - config.Browse = true - static = StaticWithConfig(config) - h = static(echo.NotFoundHandler) - if assert.NoError(h(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Contains(rec.Body.String(), "cert.pem") + assert.Equal(t, tc.expectCode, rec.Code) + if tc.expectContains != "" { + responseBody := rec.Body.String() + assert.Contains(t, responseBody, tc.expectContains) + } + if tc.expectLength != "" { + assert.Equal(t, rec.Header().Get(echo.HeaderContentLength), tc.expectLength) + } + }) } - - // IgnoreBase - req = httptest.NewRequest(http.MethodGet, "/_fixture", nil) - rec = httptest.NewRecorder() - config.Root = "../_fixture" - config.IgnoreBase = true - static = StaticWithConfig(config) - c.Echo().Group("_fixture", static) - e.ServeHTTP(rec, req) - - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(rec.Header().Get(echo.HeaderContentLength), "122") - - req = httptest.NewRequest(http.MethodGet, "/_fixture", nil) - rec = httptest.NewRecorder() - config.Root = "../_fixture" - config.IgnoreBase = false - static = StaticWithConfig(config) - c.Echo().Group("_fixture", static) - e.ServeHTTP(rec, req) - - assert.Equal(http.StatusOK, rec.Code) - assert.Contains(rec.Body.String(), filepath.Join("..", "_fixture", "_fixture")) } From 716eb183296d83a1c90afb1adc961aba0492f4e0 Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Sun, 3 Jan 2021 00:25:29 +0100 Subject: [PATCH 097/446] Handle static routes with trailing slash (#1747) - Fix Static file route not working without trailing slash - Add tests for static middleware with/without trailing slash - Add tests for static middleware under group Co-authored-by: pwli --- echo.go | 11 ++- echo_test.go | 76 +++++++++++++++++++ middleware/static_test.go | 150 +++++++++++++++++++++++++++++++++++++- 3 files changed, 233 insertions(+), 4 deletions(-) diff --git a/echo.go b/echo.go index 64e64c2c3..6db485d10 100644 --- a/echo.go +++ b/echo.go @@ -503,8 +503,15 @@ func (common) static(prefix, root string, get func(string, HandlerFunc, ...Middl } return c.File(name) } - if prefix == "/" { - return get(prefix+"*", h) + // Handle added routes based on trailing slash: + // /prefix => exact route "/prefix" + any route "/prefix/*" + // /prefix/ => only any route "/prefix/*" + if prefix != "" { + if prefix[len(prefix)-1] == '/' { + // Only add any route for intentional trailing slash + return get(prefix+"*", h) + } + get(prefix, h) } return get(prefix+"/*", h) } diff --git a/echo_test.go b/echo_test.go index 7f3597428..781b901fa 100644 --- a/echo_test.go +++ b/echo_test.go @@ -105,6 +105,32 @@ func TestEchoStatic(t *testing.T) { expectHeaderLocation: "/folder/", expectBodyStartsWith: "", }, + { + name: "Directory Redirect with non-root path", + givenPrefix: "/static", + givenRoot: "_fixture", + whenURL: "/static", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/static/", + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory 404 (request URL without slash)", + givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" + givenRoot: "_fixture", + whenURL: "/folder", // no trailing slash + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Prefixed directory redirect (without slash redirect to slash)", + givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* + givenRoot: "_fixture", + whenURL: "/folder", // no trailing slash + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/folder/", + expectBodyStartsWith: "", + }, { name: "Directory with index.html", givenPrefix: "/", @@ -113,6 +139,22 @@ func TestEchoStatic(t *testing.T) { expectStatus: http.StatusOK, expectBodyStartsWith: "", }, + { + name: "Prefixed directory with index.html (prefix ending with slash)", + givenPrefix: "/assets/", + givenRoot: "_fixture", + whenURL: "/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending without slash)", + givenPrefix: "/assets", + givenRoot: "_fixture", + whenURL: "/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, { name: "Sub-directory with index.html", givenPrefix: "/", @@ -164,6 +206,40 @@ func TestEchoStatic(t *testing.T) { } } +func TestEchoStaticRedirectIndex(t *testing.T) { + assert := assert.New(t) + e := New() + + // HandlerFunc + e.Static("/static", "_fixture") + + errCh := make(chan error) + + go func() { + errCh <- e.Start("127.0.0.1:1323") + }() + + time.Sleep(200 * time.Millisecond) + + if resp, err := http.Get("http://127.0.0.1:1323/static"); err == nil { + defer resp.Body.Close() + assert.Equal(http.StatusOK, resp.StatusCode) + + if body, err := ioutil.ReadAll(resp.Body); err == nil { + assert.Equal(true, strings.HasPrefix(string(body), "")) + } else { + assert.Fail(err.Error()) + } + + } else { + assert.Fail(err.Error()) + } + + if err := e.Close(); err != nil { + t.Fatal(err) + } +} + func TestEchoFile(t *testing.T) { e := New() e.File("/walle", "_fixture/images/walle.png") diff --git a/middleware/static_test.go b/middleware/static_test.go index 3e6ca5601..8c0c97ded 100644 --- a/middleware/static_test.go +++ b/middleware/static_test.go @@ -1,11 +1,13 @@ package middleware import ( - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" + "strings" "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" ) func TestStatic(t *testing.T) { @@ -131,3 +133,147 @@ func TestStatic(t *testing.T) { }) } } + +func TestStatic_GroupWithStatic(t *testing.T) { + var testCases = []struct { + name string + givenGroup string + givenPrefix string + givenRoot string + whenURL string + expectStatus int + expectHeaderLocation string + expectBodyStartsWith string + }{ + { + name: "ok", + givenPrefix: "/images", + givenRoot: "../_fixture/images", + whenURL: "/group/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "No file", + givenPrefix: "/images", + givenRoot: "../_fixture/scripts", + whenURL: "/group/images/bolt.png", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory not found (no trailing slash)", + givenPrefix: "/images", + givenRoot: "../_fixture/images", + whenURL: "/group/images/", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory redirect", + givenPrefix: "/", + givenRoot: "../_fixture", + whenURL: "/group/folder", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/group/folder/", + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory 404 (request URL without slash)", + givenGroup: "_fixture", + givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" + givenRoot: "../_fixture", + whenURL: "/_fixture/folder", // no trailing slash + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Prefixed directory redirect (without slash redirect to slash)", + givenGroup: "_fixture", + givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* + givenRoot: "../_fixture", + whenURL: "/_fixture/folder", // no trailing slash + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/_fixture/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory with index.html", + givenPrefix: "/", + givenRoot: "../_fixture", + whenURL: "/group/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending with slash)", + givenPrefix: "/assets/", + givenRoot: "../_fixture", + whenURL: "/group/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending without slash)", + givenPrefix: "/assets", + givenRoot: "../_fixture", + whenURL: "/group/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Sub-directory with index.html", + givenPrefix: "/", + givenRoot: "../_fixture", + whenURL: "/group/folder/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "do not allow directory traversal (backslash - windows separator)", + givenPrefix: "/", + givenRoot: "../_fixture/", + whenURL: `/group/..\\middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "do not allow directory traversal (slash - unix separator)", + givenPrefix: "/", + givenRoot: "../_fixture/", + whenURL: `/group/../middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + group := "/group" + if tc.givenGroup != "" { + group = tc.givenGroup + } + g := e.Group(group) + g.Static(tc.givenPrefix, tc.givenRoot) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expectStatus, rec.Code) + body := rec.Body.String() + if tc.expectBodyStartsWith != "" { + assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) + } else { + assert.Equal(t, "", body) + } + + if tc.expectHeaderLocation != "" { + assert.Equal(t, tc.expectHeaderLocation, rec.Header().Get(echo.HeaderLocation)) + } else { + _, ok := rec.Result().Header[echo.HeaderLocation] + assert.False(t, ok) + } + }) + } +} From 4310e90d588a77833f6a0251b3dccffe8ba20d99 Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Sun, 3 Jan 2021 01:09:18 +0100 Subject: [PATCH 098/446] Support Go 1.12 for http.SameSiteNoneMode --- middleware/csrf_samesite.go | 12 ++++++++++++ middleware/csrf_samesite_1.12.go | 12 ++++++++++++ middleware/csrf_test.go | 2 +- 3 files changed, 25 insertions(+), 1 deletion(-) create mode 100644 middleware/csrf_samesite.go create mode 100644 middleware/csrf_samesite_1.12.go diff --git a/middleware/csrf_samesite.go b/middleware/csrf_samesite.go new file mode 100644 index 000000000..507f9c35f --- /dev/null +++ b/middleware/csrf_samesite.go @@ -0,0 +1,12 @@ +// +build !go1.12 + +package middleware + +import ( + "net/http" +) + +const ( + // SameSiteNoneMode required to be redefined for Go 1.12 support (see #1524) + SameSiteNoneMode http.SameSite = http.SameSiteNoneMode +) diff --git a/middleware/csrf_samesite_1.12.go b/middleware/csrf_samesite_1.12.go new file mode 100644 index 000000000..5e76113a4 --- /dev/null +++ b/middleware/csrf_samesite_1.12.go @@ -0,0 +1,12 @@ +// +build go1.12 + +package middleware + +import ( + "net/http" +) + +const ( + // SameSiteNoneMode required to be redefined for Go 1.12 support (see #1524) + SameSiteNoneMode http.SameSite = 4 +) diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index af1d26394..51fc66e1e 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -146,7 +146,7 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) { c := e.NewContext(req, rec) csrf := CSRFWithConfig(CSRFConfig{ - CookieSameSite: http.SameSiteNoneMode, + CookieSameSite: SameSiteNoneMode, }) h := csrf(func(c echo.Context) error { From c7c792d3bde296e0e4b15a72639d2968ed6b1d57 Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Sun, 3 Jan 2021 01:45:58 +0100 Subject: [PATCH 099/446] Fix CSRF tests for Go 1.12 --- middleware/csrf.go | 2 +- middleware/csrf_samesite.go | 2 +- middleware/csrf_samesite_1.12.go | 2 +- middleware/csrf_samesite_test.go | 33 ++++++++++++++++++++++++++++++++ middleware/csrf_test.go | 20 ------------------- 5 files changed, 36 insertions(+), 23 deletions(-) create mode 100644 middleware/csrf_samesite_test.go diff --git a/middleware/csrf.go b/middleware/csrf.go index 7804997d4..60f809a04 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -110,7 +110,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.CookieMaxAge == 0 { config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge } - if config.CookieSameSite == http.SameSiteNoneMode { + if config.CookieSameSite == SameSiteNoneMode { config.CookieSecure = true } diff --git a/middleware/csrf_samesite.go b/middleware/csrf_samesite.go index 507f9c35f..9a27dc431 100644 --- a/middleware/csrf_samesite.go +++ b/middleware/csrf_samesite.go @@ -1,4 +1,4 @@ -// +build !go1.12 +// +build go1.13 package middleware diff --git a/middleware/csrf_samesite_1.12.go b/middleware/csrf_samesite_1.12.go index 5e76113a4..22076dd6a 100644 --- a/middleware/csrf_samesite_1.12.go +++ b/middleware/csrf_samesite_1.12.go @@ -1,4 +1,4 @@ -// +build go1.12 +// +build !go1.13 package middleware diff --git a/middleware/csrf_samesite_test.go b/middleware/csrf_samesite_test.go new file mode 100644 index 000000000..26c5bc455 --- /dev/null +++ b/middleware/csrf_samesite_test.go @@ -0,0 +1,33 @@ +// +build go1.13 + +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +// Test for SameSiteModeNone moved to separate file for Go 1.12 support +func TestCSRFWithSameSiteModeNone(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + csrf := CSRFWithConfig(CSRFConfig{ + CookieSameSite: SameSiteNoneMode, + }) + + h := csrf(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + r := h(c) + assert.NoError(t, r) + assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"]) + assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"]) +} diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 51fc66e1e..ebe4dbcde 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -138,23 +138,3 @@ func TestCSRFWithSameSiteDefaultMode(t *testing.T) { fmt.Println(rec.Header()["Set-Cookie"]) assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"]) } - -func TestCSRFWithSameSiteModeNone(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - csrf := CSRFWithConfig(CSRFConfig{ - CookieSameSite: SameSiteNoneMode, - }) - - h := csrf(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - - r := h(c) - assert.NoError(t, r) - assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"]) - assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"]) -} From a9df83037d950bf9ea442806501047eaa2b0f00e Mon Sep 17 00:00:00 2001 From: stffabi Date: Sun, 3 Jan 2021 19:35:00 +0100 Subject: [PATCH 100/446] Do not handle special trailing slash case for partial prefix (#1741) * Add tests for issue #1739 * Handle special trailing slash case only for a matching prefix Only handle the special trailing slash case if the whole prefix matches to avoid matching a wrong route for overlapping prefixes, e.g. /users/* for the path /users_prefix/ where the route is only a partial prefix of the requested path. --- router.go | 8 ++++---- router_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/router.go b/router.go index 749dbf4f6..5010659a6 100644 --- a/router.go +++ b/router.go @@ -372,14 +372,14 @@ func (r *Router) Find(method, path string, c Context) { if search == "" && (nn == nil || cn.parent == nil || cn.ppath != "") { break } + // Handle special case of trailing slash route with existing any route (see #1526) + if search == "" && path[len(path)-1] == '/' && cn.anyChildren != nil { + goto Any + } } // Attempt to go back up the tree on no matching prefix or no remaining search if l != pl || search == "" { - // Handle special case of trailing slash route with existing any route (see #1526) - if path[len(path)-1] == '/' && cn.anyChildren != nil { - goto Any - } if nn == nil { // Issue #1348 return // Not found } diff --git a/router_test.go b/router_test.go index aafc622cb..a5e53c05b 100644 --- a/router_test.go +++ b/router_test.go @@ -750,6 +750,47 @@ func TestRouterMatchAny(t *testing.T) { assert.Equal(t, "joe", c.Param("*")) } +// Issue #1739 +func TestRouterMatchAnyPrefixIssue(t *testing.T) { + e := New() + r := e.router + + // Routes + r.Add(http.MethodGet, "/*", func(c Context) error { + c.Set("path", c.Path()) + return nil + }) + r.Add(http.MethodGet, "/users/*", func(c Context) error { + c.Set("path", c.Path()) + return nil + }) + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/", c) + c.handler(c) + assert.Equal(t, "/*", c.Get("path")) + assert.Equal(t, "", c.Param("*")) + + r.Find(http.MethodGet, "/users", c) + c.handler(c) + assert.Equal(t, "/*", c.Get("path")) + assert.Equal(t, "users", c.Param("*")) + + r.Find(http.MethodGet, "/users/", c) + c.handler(c) + assert.Equal(t, "/users/*", c.Get("path")) + assert.Equal(t, "", c.Param("*")) + + r.Find(http.MethodGet, "/users_prefix", c) + c.handler(c) + assert.Equal(t, "/*", c.Get("path")) + assert.Equal(t, "users_prefix", c.Param("*")) + + r.Find(http.MethodGet, "/users_prefix/", c) + c.handler(c) + assert.Equal(t, "/*", c.Get("path")) + assert.Equal(t, "users_prefix/", c.Param("*")) +} + // TestRouterMatchAnySlash shall verify finding the best route // for any routes with trailing slash requests func TestRouterMatchAnySlash(t *testing.T) { From f7180796583e230f720933e106cba692ffdc2d1b Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Sun, 3 Jan 2021 11:09:17 -0800 Subject: [PATCH 101/446] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index deba54f40..4dec531a2 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ [![Build Status](http://img.shields.io/travis/labstack/echo.svg?style=flat-square)](https://travis-ci.org/labstack/echo) [![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo) [![Join the chat at https://gitter.im/labstack/echo](https://img.shields.io/badge/gitter-join%20chat-brightgreen.svg?style=flat-square)](https://gitter.im/labstack/echo) -[![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://forum.labstack.com) +[![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://github.com/labstack/echo/discussions) [![Twitter](https://img.shields.io/badge/twitter-@labstack-55acee.svg?style=flat-square)](https://twitter.com/labstack) [![License](http://img.shields.io/badge/license-mit-blue.svg?style=flat-square)](https://raw.githubusercontent.com/labstack/echo/master/LICENSE) @@ -94,7 +94,7 @@ func hello(c echo.Context) error { ## Help -- [Forum](https://forum.labstack.com) +- [Forum](https://github.com/labstack/echo/discussions) - [Chat](https://gitter.im/labstack/echo) ## Contribute From 02ed3f3126cf847e77764d9c75835709dc60d677 Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 5 Jan 2021 12:04:24 +0200 Subject: [PATCH 102/446] Fix #1729 Binding query/path params and form fields to struct only works for explicit tags (#1734) * Binding query/path params and form fields to struct only works for fields that have explicit TAG defined on struct * remove unnecessary benchmark after change because it is not valid test anymore --- bind.go | 16 +++++----- bind_test.go | 88 +++++++++++++++++++++++++++++++++++----------------- 2 files changed, 69 insertions(+), 35 deletions(-) diff --git a/bind.go b/bind.go index acd2beda2..16c3b7adf 100644 --- a/bind.go +++ b/bind.go @@ -116,12 +116,13 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { return b.BindBody(c, i) } -func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag string) error { - if ptr == nil || len(data) == 0 { +// bindData will bind data ONLY fields in destination struct that have EXPLICIT tag +func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string) error { + if destination == nil || len(data) == 0 { return nil } - typ := reflect.TypeOf(ptr).Elem() - val := reflect.ValueOf(ptr).Elem() + typ := reflect.TypeOf(destination).Elem() + val := reflect.ValueOf(destination).Elem() // Map if typ.Kind() == reflect.Map { @@ -146,14 +147,15 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag inputFieldName := typeField.Tag.Get(tag) if inputFieldName == "" { - inputFieldName = typeField.Name - // If tag is nil, we inspect if the field is a struct. + // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags). + // structs that implement BindUnmarshaler are binded only when they have explicit tag if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil { return err } - continue } + // does not have explicit tag and is not an ordinary struct - so move to next field + continue } inputValue, exists := data[inputFieldName] diff --git a/bind_test.go b/bind_test.go index 345fbdf10..e8868b35b 100644 --- a/bind_test.go +++ b/bind_test.go @@ -160,6 +160,31 @@ var values = map[string][]string{ "ST": {"bar"}, } +func TestToMultipleFields(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?id=1&ID=2", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + type Root struct { + ID int64 `query:"id"` + Child2 struct { + ID int64 + } + Child1 struct { + ID int64 `query:"id"` + } + } + + u := new(Root) + err := c.Bind(u) + if assert.NoError(t, err) { + assert.Equal(t, int64(1), u.ID) // perfectly reasonable + assert.Equal(t, int64(1), u.Child1.ID) // untagged struct containing tagged field gets filled (by tag) + assert.Equal(t, int64(0), u.Child2.ID) // untagged struct containing untagged field should not be bind + } +} + func TestBindJSON(t *testing.T) { assert := assert.New(t) testBindOkay(assert, strings.NewReader(userJSON), MIMEApplicationJSON) @@ -238,10 +263,13 @@ func TestBindUnmarshalParam(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { - T Timestamp `query:"ts"` - TA []Timestamp `query:"ta"` - SA StringArray `query:"sa"` - ST Struct + T Timestamp `query:"ts"` + TA []Timestamp `query:"ta"` + SA StringArray `query:"sa"` + ST Struct + StWithTag struct { + Foo string `query:"st"` + } }{} err := c.Bind(&result) ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)) @@ -252,7 +280,8 @@ func TestBindUnmarshalParam(t *testing.T) { assert.Equal(ts, result.T) assert.Equal(StringArray([]string{"one", "two", "three"}), result.SA) assert.Equal([]Timestamp{ts, ts}, result.TA) - assert.Equal(Struct{"baz"}, result.ST) + assert.Equal(Struct{""}, result.ST) // child struct does not have a field with matching tag + assert.Equal("baz", result.StWithTag.Foo) // child struct has field with matching tag } } @@ -274,7 +303,7 @@ func TestBindUnmarshalText(t *testing.T) { assert.Equal(t, ts, result.T) assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA) assert.Equal(t, []time.Time{ts, ts}, result.TA) - assert.Equal(t, Struct{"baz"}, result.ST) + assert.Equal(t, Struct{""}, result.ST) // field in child struct does not have tag } } @@ -323,11 +352,27 @@ func TestBindUnsupportedMediaType(t *testing.T) { } func TestBindbindData(t *testing.T) { - assert := assert.New(t) + a := assert.New(t) ts := new(bindTestStruct) b := new(DefaultBinder) - b.bindData(ts, values, "form") - assertBindTestStruct(assert, ts) + err := b.bindData(ts, values, "form") + a.NoError(err) + + a.Equal(0, ts.I) + a.Equal(int8(0), ts.I8) + a.Equal(int16(0), ts.I16) + a.Equal(int32(0), ts.I32) + a.Equal(int64(0), ts.I64) + a.Equal(uint(0), ts.UI) + a.Equal(uint8(0), ts.UI8) + a.Equal(uint16(0), ts.UI16) + a.Equal(uint32(0), ts.UI32) + a.Equal(uint64(0), ts.UI64) + a.Equal(false, ts.B) + a.Equal(float32(0), ts.F32) + a.Equal(float64(0), ts.F64) + a.Equal("", ts.S) + a.Equal("", ts.cantSet) } func TestBindParam(t *testing.T) { @@ -470,20 +515,6 @@ func TestBindSetFields(t *testing.T) { } } -func BenchmarkBindbindData(b *testing.B) { - b.ReportAllocs() - assert := assert.New(b) - ts := new(bindTestStruct) - binder := new(DefaultBinder) - var err error - b.ResetTimer() - for i := 0; i < b.N; i++ { - err = binder.bindData(ts, values, "form") - } - assert.NoError(err) - assertBindTestStruct(assert, ts) -} - func BenchmarkBindbindDataWithTags(b *testing.B) { b.ReportAllocs() assert := assert.New(b) @@ -560,8 +591,9 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed type Opts struct { - ID int `json:"id"` - Node string `json:"node"` + ID int `json:"id" form:"id" query:"id"` + Node string `json:"node" form:"node" query:"node" param:"node"` + Lang string } var testCases = []struct { @@ -727,8 +759,8 @@ func TestDefaultBinder_BindBody(t *testing.T) { // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed type Node struct { - ID int `json:"id" xml:"id"` - Node string `json:"node" xml:"node"` + ID int `json:"id" xml:"id" form:"id" query:"id"` + Node string `json:"node" xml:"node" form:"node" query:"node" param:"node"` } type Nodes struct { Nodes []Node `xml:"node" form:"node"` @@ -824,7 +856,7 @@ func TestDefaultBinder_BindBody(t *testing.T) { expectError: "code=400, message=Syntax error: line=1, error=XML syntax error on line 1: unexpected EOF, internal=XML syntax error on line 1: unexpected EOF", }, { - name: "ok, FORM POST bind to struct with: path + query + empty body", + name: "ok, FORM POST bind to struct with: path + query + body", givenURL: "/api/real_node/endpoint?node=xxx", givenMethod: http.MethodPost, givenContentType: MIMEApplicationForm, From 67263b5e456480224b2e4bb1e96b75d6a8ef54a5 Mon Sep 17 00:00:00 2001 From: Ilija Matoski Date: Tue, 5 Jan 2021 11:14:51 +0100 Subject: [PATCH 103/446] Timeout middleware implementation for go1.13+ (#1743) Co-authored-by: Ilija Matoski --- middleware/timeout.go | 81 +++++++++++++++++ middleware/timeout_test.go | 177 +++++++++++++++++++++++++++++++++++++ 2 files changed, 258 insertions(+) create mode 100644 middleware/timeout.go create mode 100644 middleware/timeout_test.go diff --git a/middleware/timeout.go b/middleware/timeout.go new file mode 100644 index 000000000..d146541e6 --- /dev/null +++ b/middleware/timeout.go @@ -0,0 +1,81 @@ +// +build go1.13 + +package middleware + +import ( + "context" + "github.com/labstack/echo/v4" + "time" +) + +type ( + // TimeoutConfig defines the config for Timeout middleware. + TimeoutConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + // ErrorHandler defines a function which is executed for a timeout + // It can be used to define a custom timeout error + ErrorHandler TimeoutErrorHandlerWithContext + // Timeout configures a timeout for the middleware, defaults to 0 for no timeout + Timeout time.Duration + } + + // TimeoutErrorHandlerWithContext is an error handler that is used with the timeout middleware so we can + // handle the error as we see fit + TimeoutErrorHandlerWithContext func(error, echo.Context) error +) + +var ( + // DefaultTimeoutConfig is the default Timeout middleware config. + DefaultTimeoutConfig = TimeoutConfig{ + Skipper: DefaultSkipper, + Timeout: 0, + ErrorHandler: nil, + } +) + +// Timeout returns a middleware which recovers from panics anywhere in the chain +// and handles the control to the centralized HTTPErrorHandler. +func Timeout() echo.MiddlewareFunc { + return TimeoutWithConfig(DefaultTimeoutConfig) +} + +// TimeoutWithConfig returns a Timeout middleware with config. +// See: `Timeout()`. +func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultTimeoutConfig.Skipper + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) || config.Timeout == 0 { + return next(c) + } + + ctx, cancel := context.WithTimeout(c.Request().Context(), config.Timeout) + defer cancel() + + // this does a deep clone of the context, wondering if there is a better way to do this? + c.SetRequest(c.Request().Clone(ctx)) + + done := make(chan error, 1) + go func() { + // This goroutine will keep running even if this middleware times out and + // will be stopped when ctx.Done() is called down the next(c) call chain + done <- next(c) + }() + + select { + case <-ctx.Done(): + if config.ErrorHandler != nil { + return config.ErrorHandler(ctx.Err(), c) + } + return ctx.Err() + case err := <-done: + return err + } + } + } +} diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go new file mode 100644 index 000000000..c0e945933 --- /dev/null +++ b/middleware/timeout_test.go @@ -0,0 +1,177 @@ +// +build go1.13 + +package middleware + +import ( + "context" + "errors" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "strings" + "testing" + "time" +) + +func TestTimeoutSkipper(t *testing.T) { + t.Parallel() + m := TimeoutWithConfig(TimeoutConfig{ + Skipper: func(context echo.Context) bool { + return true + }, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String()) + return nil + })(c) + + assert.NoError(t, err) +} + +func TestTimeoutWithTimeout0(t *testing.T) { + t.Parallel() + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: 0, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String()) + return nil + })(c) + + assert.NoError(t, err) +} + +func TestTimeoutIsCancelable(t *testing.T) { + t.Parallel() + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: time.Minute, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + assert.EqualValues(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String()) + return nil + })(c) + + assert.NoError(t, err) +} + +func TestTimeoutErrorOutInHandler(t *testing.T) { + t.Parallel() + m := Timeout() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + return errors.New("err") + })(c) + + assert.Error(t, err) +} + +func TestTimeoutTimesOutAfterPredefinedTimeoutWithErrorHandler(t *testing.T) { + t.Parallel() + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: time.Second, + ErrorHandler: func(err error, e echo.Context) error { + assert.EqualError(t, err, context.DeadlineExceeded.Error()) + return errors.New("err") + }, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + time.Sleep(time.Minute) + return nil + })(c) + + assert.EqualError(t, err, errors.New("err").Error()) +} + +func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) { + t.Parallel() + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: time.Second, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + time.Sleep(time.Minute) + return nil + })(c) + + assert.EqualError(t, err, context.DeadlineExceeded.Error()) +} + +func TestTimeoutTestRequestClone(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode())) + req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"}) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + + m := TimeoutWithConfig(TimeoutConfig{ + // Timeout has to be defined or the whole flow for timeout middleware will be skipped + Timeout: time.Second, + }) + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + // Cookie test + cookie, err := c.Request().Cookie("cookie") + if assert.NoError(t, err) { + assert.EqualValues(t, "cookie", cookie.Name) + assert.EqualValues(t, "value", cookie.Value) + } + + // Form values + if assert.NoError(t, c.Request().ParseForm()) { + assert.EqualValues(t, "value", c.Request().FormValue("form")) + } + + // Query string + assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0]) + return nil + })(c) + + assert.NoError(t, err) + +} From 9b0e63046b1341c54145293225f614047967091f Mon Sep 17 00:00:00 2001 From: Martti T Date: Fri, 8 Jan 2021 01:43:38 +0200 Subject: [PATCH 104/446] Fluent Binder for Query/Path/Form binding (#1717) (#1736) * Fluent Binder for Query/Path/Form binding. * CI: report coverage for latest go (1.15) version * improve docs, remove uncommented code * separate unixtime with sec and nanosec precision binding --- .github/workflows/echo.yml | 2 +- Makefile | 7 + binder.go | 1234 ++++++++++++++++ binder_external_test.go | 130 ++ binder_go1.15_test.go | 265 ++++ binder_test.go | 2757 ++++++++++++++++++++++++++++++++++++ 6 files changed, 4394 insertions(+), 1 deletion(-) create mode 100644 binder.go create mode 100644 binder_external_test.go create mode 100644 binder_go1.15_test.go create mode 100644 binder_test.go diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 2aec272d5..fb8c50205 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -59,7 +59,7 @@ jobs: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... - name: Upload coverage to Codecov - if: success() && matrix.go == 1.13 && matrix.os == 'ubuntu-latest' + if: success() && matrix.go == 1.15 && matrix.os == 'ubuntu-latest' uses: codecov/codecov-action@v1 with: token: diff --git a/Makefile b/Makefile index c369913a6..bedb8bd25 100644 --- a/Makefile +++ b/Makefile @@ -23,5 +23,12 @@ test: ## Run tests race: ## Run tests with data race detector @go test -race ${PKG_LIST} +benchmark: ## Run benchmarks + @go test -run="-" -bench=".*" ${PKG_LIST} + help: ## Display this help screen @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +goversion ?= "1.12" +test_version: ## Run tests inside Docker with given version (defaults to 1.12 oldest supported). Example: make test_version goversion=1.13 + @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make check" diff --git a/binder.go b/binder.go new file mode 100644 index 000000000..9f0ca654e --- /dev/null +++ b/binder.go @@ -0,0 +1,1234 @@ +package echo + +import ( + "fmt" + "net/http" + "strconv" + "strings" + "time" +) + +/** + Following functions provide handful of methods for binding to Go native types from request query or path parameters. + * QueryParamsBinder(c) - binds query parameters (source URL) + * PathParamsBinder(c) - binds path parameters (source URL) + * FormFieldBinder(c) - binds form fields (source URL + body) + + Example: + ```go + var length int64 + err := echo.QueryParamsBinder(c).Int64("length", &length).BindError() + ``` + + For every supported type there are following methods: + * ("param", &destination) - if parameter value exists then binds it to given destination of that type i.e Int64(...). + * Must("param", &destination) - parameter value is required to exist, binds it to given destination of that type i.e MustInt64(...). + * s("param", &destination) - (for slices) if parameter values exists then binds it to given destination of that type i.e Int64s(...). + * Musts("param", &destination) - (for slices) parameter value is required to exist, binds it to given destination of that type i.e MustInt64s(...). + + for some slice types `BindWithDelimiter("param", &dest, ",")` supports splitting parameter values before type conversion is done + i.e. URL `/api/search?id=1,2,3&id=1` can be bind to `[]int64{1,2,3,1}` + + `FailFast` flags binder to stop binding after first bind error during binder call chain. Enabled by default. + `BindError()` returns first bind error from binder and resets errors in binder. Useful along with `FailFast()` method + to do binding and returns on first problem + `BindErrors()` returns all bind errors from binder and resets errors in binder. + + Types that are supported: + * bool + * float32 + * float64 + * int + * int8 + * int16 + * int32 + * int64 + * uint + * uint8/byte (does not support `bytes()`. Use BindUnmarshaler/CustomFunc to convert value from base64 etc to []byte{}) + * uint16 + * uint32 + * uint64 + * string + * time + * duration + * BindUnmarshaler() interface + * UnixTime() - converts unix time (integer) to time.Time + * UnixTimeNano() - converts unix time with nano second precision (integer) to time.Time + * CustomFunc() - callback function for your custom conversion logic. Signature `func(values []string) []error` +*/ + +// BindingError represents an error that occurred while binding request data. +type BindingError struct { + // Field is the field name where value binding failed + Field string `json:"field"` + // Values of parameter that failed to bind. + Values []string `json:"-"` + *HTTPError +} + +// NewBindingError creates new instance of binding error +func NewBindingError(sourceParam string, values []string, message interface{}, internalError error) error { + return &BindingError{ + Field: sourceParam, + Values: values, + HTTPError: &HTTPError{ + Code: http.StatusBadRequest, + Message: message, + Internal: internalError, + }, + } +} + +// Error returns error message +func (be *BindingError) Error() string { + return fmt.Sprintf("%s, field=%s", be.HTTPError.Error(), be.Field) +} + +// ValueBinder provides utility methods for binding query or path parameter to various Go built-in types +type ValueBinder struct { + // failFast is flag for binding methods to return without attempting to bind when previous binding already failed + failFast bool + errors []error + + // ValueFunc is used to get single parameter (first) value from request + ValueFunc func(sourceParam string) string + // ValuesFunc is used to get all values for parameter from request. i.e. `/api/search?ids=1&ids=2` + ValuesFunc func(sourceParam string) []string + // ErrorFunc is used to create errors. Allows you to use your own error type, that for example marshals to your specific json response + ErrorFunc func(sourceParam string, values []string, message interface{}, internalError error) error +} + +// QueryParamsBinder creates query parameter value binder +func QueryParamsBinder(c Context) *ValueBinder { + return &ValueBinder{ + failFast: true, + ValueFunc: func(sourceParam string) string { + return c.QueryParam(sourceParam) + }, + ValuesFunc: func(sourceParam string) []string { + values, ok := c.QueryParams()[sourceParam] + if !ok { + return nil + } + return values + }, + ErrorFunc: NewBindingError, + } +} + +// PathParamsBinder creates path parameter value binder +func PathParamsBinder(c Context) *ValueBinder { + return &ValueBinder{ + failFast: true, + ValueFunc: func(sourceParam string) string { + return c.Param(sourceParam) + }, + ValuesFunc: func(sourceParam string) []string { + // path parameter should not have multiple values so getting values does not make sense but lets not error out here + value := c.Param(sourceParam) + if value == "" { + return nil + } + return []string{value} + }, + ErrorFunc: NewBindingError, + } +} + +// FormFieldBinder creates form field value binder +// For all requests, FormFieldBinder parses the raw query from the URL and uses query params as form fields +// +// For POST, PUT, and PATCH requests, it also reads the request body, parses it +// as a form and uses query params as form fields. Request body parameters take precedence over URL query +// string values in r.Form. +// +// NB: when binding forms take note that this implementation uses standard library form parsing +// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm +// See https://golang.org/pkg/net/http/#Request.ParseForm +func FormFieldBinder(c Context) *ValueBinder { + vb := &ValueBinder{ + failFast: true, + ValueFunc: func(sourceParam string) string { + return c.Request().FormValue(sourceParam) + }, + ErrorFunc: NewBindingError, + } + vb.ValuesFunc = func(sourceParam string) []string { + if c.Request().Form == nil { + // this is same as `Request().FormValue()` does internally + _ = c.Request().ParseMultipartForm(32 << 20) + } + values, ok := c.Request().Form[sourceParam] + if !ok { + return nil + } + return values + } + + return vb +} + +// FailFast set internal flag to indicate if binding methods will return early (without binding) when previous bind failed +// NB: call this method before any other binding methods as it modifies binding methods behaviour +func (b *ValueBinder) FailFast(value bool) *ValueBinder { + b.failFast = value + return b +} + +func (b *ValueBinder) setError(err error) { + if b.errors == nil { + b.errors = []error{err} + return + } + b.errors = append(b.errors, err) +} + +// BindError returns first seen bind error and resets/empties binder errors for further calls +func (b *ValueBinder) BindError() error { + if b.errors == nil { + return nil + } + err := b.errors[0] + b.errors = nil // reset errors so next chain will start from zero + return err +} + +// BindErrors returns all bind errors and resets/empties binder errors for further calls +func (b *ValueBinder) BindErrors() []error { + if b.errors == nil { + return nil + } + errors := b.errors + b.errors = nil // reset errors so next chain will start from zero + return errors +} + +// CustomFunc binds parameter values with Func. Func is called only when parameter values exist. +func (b *ValueBinder) CustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder { + return b.customFunc(sourceParam, customFunc, false) +} + +// MustCustomFunc requires parameter values to exist to be bind with Func. Returns error when value does not exist. +func (b *ValueBinder) MustCustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder { + return b.customFunc(sourceParam, customFunc, true) +} + +func (b *ValueBinder) customFunc(sourceParam string, customFunc func(values []string) []error, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + values := b.ValuesFunc(sourceParam) + if len(values) == 0 { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + if errs := customFunc(values); errs != nil { + b.errors = append(b.errors, errs...) + } + return b +} + +// String binds parameter to string variable +func (b *ValueBinder) String(sourceParam string, dest *string) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + return b + } + *dest = value + return b +} + +// MustString requires parameter value to exist to be bind to string variable. Returns error when value does not exist +func (b *ValueBinder) MustString(sourceParam string, dest *string) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil)) + return b + } + *dest = value + return b +} + +// Strings binds parameter values to slice of string +func (b *ValueBinder) Strings(sourceParam string, dest *[]string) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValuesFunc(sourceParam) + if value == nil { + return b + } + *dest = value + return b +} + +// MustStrings requires parameter values to exist to be bind to slice of string variables. Returns error when value does not exist +func (b *ValueBinder) MustStrings(sourceParam string, dest *[]string) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValuesFunc(sourceParam) + if value == nil { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + return b + } + *dest = value + return b +} + +// BindUnmarshaler binds parameter to destination implementing BindUnmarshaler interface +func (b *ValueBinder) BindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + return b + } + + if err := dest.UnmarshalParam(tmp); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to BindUnmarshaler interface", err)) + } + return b +} + +// MustBindUnmarshaler requires parameter value to exist to be bind to destination implementing BindUnmarshaler interface. +// Returns error when value does not exist +func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil)) + return b + } + + if err := dest.UnmarshalParam(value); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to BindUnmarshaler interface", err)) + } + return b +} + +// BindWithDelimiter binds parameter to destination by suitable conversion function. +// Delimiter is used before conversion to split parameter value to separate values +func (b *ValueBinder) BindWithDelimiter(sourceParam string, dest interface{}, delimiter string) *ValueBinder { + return b.bindWithDelimiter(sourceParam, dest, delimiter, false) +} + +// MustBindWithDelimiter requires parameter value to exist to be bind destination by suitable conversion function. +// Delimiter is used before conversion to split parameter value to separate values +func (b *ValueBinder) MustBindWithDelimiter(sourceParam string, dest interface{}, delimiter string) *ValueBinder { + return b.bindWithDelimiter(sourceParam, dest, delimiter, true) +} + +func (b *ValueBinder) bindWithDelimiter(sourceParam string, dest interface{}, delimiter string, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + values := b.ValuesFunc(sourceParam) + if len(values) == 0 { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + tmpValues := make([]string, 0, len(values)) + for _, v := range values { + tmpValues = append(tmpValues, strings.Split(v, delimiter)...) + } + + switch d := dest.(type) { + case *[]string: + *d = tmpValues + return b + case *[]bool: + return b.bools(sourceParam, tmpValues, d) + case *[]int64, *[]int32, *[]int16, *[]int8, *[]int: + return b.ints(sourceParam, tmpValues, d) + case *[]uint64, *[]uint32, *[]uint16, *[]uint8, *[]uint: // *[]byte is same as *[]uint8 + return b.uints(sourceParam, tmpValues, d) + case *[]float64, *[]float32: + return b.floats(sourceParam, tmpValues, d) + case *[]time.Duration: + return b.durations(sourceParam, tmpValues, d) + default: + // support only cases when destination is slice + // does not support time.Time as it needs argument (layout) for parsing or BindUnmarshaler + b.setError(b.ErrorFunc(sourceParam, []string{}, "unsupported bind type", nil)) + return b + } +} + +// Int64 binds parameter to int64 variable +func (b *ValueBinder) Int64(sourceParam string, dest *int64) *ValueBinder { + return b.intValue(sourceParam, dest, 64, false) +} + +// MustInt64 requires parameter value to exist to be bind to int64 variable. Returns error when value does not exist +func (b *ValueBinder) MustInt64(sourceParam string, dest *int64) *ValueBinder { + return b.intValue(sourceParam, dest, 64, true) +} + +// Int32 binds parameter to int32 variable +func (b *ValueBinder) Int32(sourceParam string, dest *int32) *ValueBinder { + return b.intValue(sourceParam, dest, 32, false) +} + +// MustInt32 requires parameter value to exist to be bind to int32 variable. Returns error when value does not exist +func (b *ValueBinder) MustInt32(sourceParam string, dest *int32) *ValueBinder { + return b.intValue(sourceParam, dest, 32, true) +} + +// Int16 binds parameter to int16 variable +func (b *ValueBinder) Int16(sourceParam string, dest *int16) *ValueBinder { + return b.intValue(sourceParam, dest, 16, false) +} + +// MustInt16 requires parameter value to exist to be bind to int16 variable. Returns error when value does not exist +func (b *ValueBinder) MustInt16(sourceParam string, dest *int16) *ValueBinder { + return b.intValue(sourceParam, dest, 16, true) +} + +// Int8 binds parameter to int8 variable +func (b *ValueBinder) Int8(sourceParam string, dest *int8) *ValueBinder { + return b.intValue(sourceParam, dest, 8, false) +} + +// MustInt8 requires parameter value to exist to be bind to int8 variable. Returns error when value does not exist +func (b *ValueBinder) MustInt8(sourceParam string, dest *int8) *ValueBinder { + return b.intValue(sourceParam, dest, 8, true) +} + +// Int binds parameter to int variable +func (b *ValueBinder) Int(sourceParam string, dest *int) *ValueBinder { + return b.intValue(sourceParam, dest, 0, false) +} + +// MustInt requires parameter value to exist to be bind to int variable. Returns error when value does not exist +func (b *ValueBinder) MustInt(sourceParam string, dest *int) *ValueBinder { + return b.intValue(sourceParam, dest, 0, true) +} + +func (b *ValueBinder) intValue(sourceParam string, dest interface{}, bitSize int, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + + return b.int(sourceParam, value, dest, bitSize) +} + +func (b *ValueBinder) int(sourceParam string, value string, dest interface{}, bitSize int) *ValueBinder { + n, err := strconv.ParseInt(value, 10, bitSize) + if err != nil { + if bitSize == 0 { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to int", err)) + } else { + b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to int%v", bitSize), err)) + } + return b + } + + switch d := dest.(type) { + case *int64: + *d = n + case *int32: + *d = int32(n) + case *int16: + *d = int16(n) + case *int8: + *d = int8(n) + case *int: + *d = int(n) + } + return b +} + +func (b *ValueBinder) intsValue(sourceParam string, dest interface{}, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + values := b.ValuesFunc(sourceParam) + if len(values) == 0 { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, values, "required field value is empty", nil)) + } + return b + } + return b.ints(sourceParam, values, dest) +} + +func (b *ValueBinder) ints(sourceParam string, values []string, dest interface{}) *ValueBinder { + switch d := dest.(type) { + case *[]int64: + tmp := make([]int64, len(values)) + for i, v := range values { + b.int(sourceParam, v, &tmp[i], 64) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]int32: + tmp := make([]int32, len(values)) + for i, v := range values { + b.int(sourceParam, v, &tmp[i], 32) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]int16: + tmp := make([]int16, len(values)) + for i, v := range values { + b.int(sourceParam, v, &tmp[i], 16) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]int8: + tmp := make([]int8, len(values)) + for i, v := range values { + b.int(sourceParam, v, &tmp[i], 8) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]int: + tmp := make([]int, len(values)) + for i, v := range values { + b.int(sourceParam, v, &tmp[i], 0) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + } + return b +} + +// Int64s binds parameter to slice of int64 +func (b *ValueBinder) Int64s(sourceParam string, dest *[]int64) *ValueBinder { + return b.intsValue(sourceParam, dest, false) +} + +// MustInt64s requires parameter value to exist to be bind to int64 slice variable. Returns error when value does not exist +func (b *ValueBinder) MustInt64s(sourceParam string, dest *[]int64) *ValueBinder { + return b.intsValue(sourceParam, dest, true) +} + +// Int32s binds parameter to slice of int32 +func (b *ValueBinder) Int32s(sourceParam string, dest *[]int32) *ValueBinder { + return b.intsValue(sourceParam, dest, false) +} + +// MustInt32s requires parameter value to exist to be bind to int32 slice variable. Returns error when value does not exist +func (b *ValueBinder) MustInt32s(sourceParam string, dest *[]int32) *ValueBinder { + return b.intsValue(sourceParam, dest, true) +} + +// Int16s binds parameter to slice of int16 +func (b *ValueBinder) Int16s(sourceParam string, dest *[]int16) *ValueBinder { + return b.intsValue(sourceParam, dest, false) +} + +// MustInt16s requires parameter value to exist to be bind to int16 slice variable. Returns error when value does not exist +func (b *ValueBinder) MustInt16s(sourceParam string, dest *[]int16) *ValueBinder { + return b.intsValue(sourceParam, dest, true) +} + +// Int8s binds parameter to slice of int8 +func (b *ValueBinder) Int8s(sourceParam string, dest *[]int8) *ValueBinder { + return b.intsValue(sourceParam, dest, false) +} + +// MustInt8s requires parameter value to exist to be bind to int8 slice variable. Returns error when value does not exist +func (b *ValueBinder) MustInt8s(sourceParam string, dest *[]int8) *ValueBinder { + return b.intsValue(sourceParam, dest, true) +} + +// Ints binds parameter to slice of int +func (b *ValueBinder) Ints(sourceParam string, dest *[]int) *ValueBinder { + return b.intsValue(sourceParam, dest, false) +} + +// MustInts requires parameter value to exist to be bind to int slice variable. Returns error when value does not exist +func (b *ValueBinder) MustInts(sourceParam string, dest *[]int) *ValueBinder { + return b.intsValue(sourceParam, dest, true) +} + +// Uint64 binds parameter to uint64 variable +func (b *ValueBinder) Uint64(sourceParam string, dest *uint64) *ValueBinder { + return b.uintValue(sourceParam, dest, 64, false) +} + +// MustUint64 requires parameter value to exist to be bind to uint64 variable. Returns error when value does not exist +func (b *ValueBinder) MustUint64(sourceParam string, dest *uint64) *ValueBinder { + return b.uintValue(sourceParam, dest, 64, true) +} + +// Uint32 binds parameter to uint32 variable +func (b *ValueBinder) Uint32(sourceParam string, dest *uint32) *ValueBinder { + return b.uintValue(sourceParam, dest, 32, false) +} + +// MustUint32 requires parameter value to exist to be bind to uint32 variable. Returns error when value does not exist +func (b *ValueBinder) MustUint32(sourceParam string, dest *uint32) *ValueBinder { + return b.uintValue(sourceParam, dest, 32, true) +} + +// Uint16 binds parameter to uint16 variable +func (b *ValueBinder) Uint16(sourceParam string, dest *uint16) *ValueBinder { + return b.uintValue(sourceParam, dest, 16, false) +} + +// MustUint16 requires parameter value to exist to be bind to uint16 variable. Returns error when value does not exist +func (b *ValueBinder) MustUint16(sourceParam string, dest *uint16) *ValueBinder { + return b.uintValue(sourceParam, dest, 16, true) +} + +// Uint8 binds parameter to uint8 variable +func (b *ValueBinder) Uint8(sourceParam string, dest *uint8) *ValueBinder { + return b.uintValue(sourceParam, dest, 8, false) +} + +// MustUint8 requires parameter value to exist to be bind to uint8 variable. Returns error when value does not exist +func (b *ValueBinder) MustUint8(sourceParam string, dest *uint8) *ValueBinder { + return b.uintValue(sourceParam, dest, 8, true) +} + +// Byte binds parameter to byte variable +func (b *ValueBinder) Byte(sourceParam string, dest *byte) *ValueBinder { + return b.uintValue(sourceParam, dest, 8, false) +} + +// MustByte requires parameter value to exist to be bind to byte variable. Returns error when value does not exist +func (b *ValueBinder) MustByte(sourceParam string, dest *byte) *ValueBinder { + return b.uintValue(sourceParam, dest, 8, true) +} + +// Uint binds parameter to uint variable +func (b *ValueBinder) Uint(sourceParam string, dest *uint) *ValueBinder { + return b.uintValue(sourceParam, dest, 0, false) +} + +// MustUint requires parameter value to exist to be bind to uint variable. Returns error when value does not exist +func (b *ValueBinder) MustUint(sourceParam string, dest *uint) *ValueBinder { + return b.uintValue(sourceParam, dest, 0, true) +} + +func (b *ValueBinder) uintValue(sourceParam string, dest interface{}, bitSize int, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + + return b.uint(sourceParam, value, dest, bitSize) +} + +func (b *ValueBinder) uint(sourceParam string, value string, dest interface{}, bitSize int) *ValueBinder { + n, err := strconv.ParseUint(value, 10, bitSize) + if err != nil { + if bitSize == 0 { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to uint", err)) + } else { + b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to uint%v", bitSize), err)) + } + return b + } + + switch d := dest.(type) { + case *uint64: + *d = n + case *uint32: + *d = uint32(n) + case *uint16: + *d = uint16(n) + case *uint8: // byte is alias to uint8 + *d = uint8(n) + case *uint: + *d = uint(n) + } + return b +} + +func (b *ValueBinder) uintsValue(sourceParam string, dest interface{}, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + values := b.ValuesFunc(sourceParam) + if len(values) == 0 { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, values, "required field value is empty", nil)) + } + return b + } + return b.uints(sourceParam, values, dest) +} + +func (b *ValueBinder) uints(sourceParam string, values []string, dest interface{}) *ValueBinder { + switch d := dest.(type) { + case *[]uint64: + tmp := make([]uint64, len(values)) + for i, v := range values { + b.uint(sourceParam, v, &tmp[i], 64) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]uint32: + tmp := make([]uint32, len(values)) + for i, v := range values { + b.uint(sourceParam, v, &tmp[i], 32) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]uint16: + tmp := make([]uint16, len(values)) + for i, v := range values { + b.uint(sourceParam, v, &tmp[i], 16) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]uint8: // byte is alias to uint8 + tmp := make([]uint8, len(values)) + for i, v := range values { + b.uint(sourceParam, v, &tmp[i], 8) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]uint: + tmp := make([]uint, len(values)) + for i, v := range values { + b.uint(sourceParam, v, &tmp[i], 0) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + } + return b +} + +// Uint64s binds parameter to slice of uint64 +func (b *ValueBinder) Uint64s(sourceParam string, dest *[]uint64) *ValueBinder { + return b.uintsValue(sourceParam, dest, false) +} + +// MustUint64s requires parameter value to exist to be bind to uint64 slice variable. Returns error when value does not exist +func (b *ValueBinder) MustUint64s(sourceParam string, dest *[]uint64) *ValueBinder { + return b.uintsValue(sourceParam, dest, true) +} + +// Uint32s binds parameter to slice of uint32 +func (b *ValueBinder) Uint32s(sourceParam string, dest *[]uint32) *ValueBinder { + return b.uintsValue(sourceParam, dest, false) +} + +// MustUint32s requires parameter value to exist to be bind to uint32 slice variable. Returns error when value does not exist +func (b *ValueBinder) MustUint32s(sourceParam string, dest *[]uint32) *ValueBinder { + return b.uintsValue(sourceParam, dest, true) +} + +// Uint16s binds parameter to slice of uint16 +func (b *ValueBinder) Uint16s(sourceParam string, dest *[]uint16) *ValueBinder { + return b.uintsValue(sourceParam, dest, false) +} + +// MustUint16s requires parameter value to exist to be bind to uint16 slice variable. Returns error when value does not exist +func (b *ValueBinder) MustUint16s(sourceParam string, dest *[]uint16) *ValueBinder { + return b.uintsValue(sourceParam, dest, true) +} + +// Uint8s binds parameter to slice of uint8 +func (b *ValueBinder) Uint8s(sourceParam string, dest *[]uint8) *ValueBinder { + return b.uintsValue(sourceParam, dest, false) +} + +// MustUint8s requires parameter value to exist to be bind to uint8 slice variable. Returns error when value does not exist +func (b *ValueBinder) MustUint8s(sourceParam string, dest *[]uint8) *ValueBinder { + return b.uintsValue(sourceParam, dest, true) +} + +// Uints binds parameter to slice of uint +func (b *ValueBinder) Uints(sourceParam string, dest *[]uint) *ValueBinder { + return b.uintsValue(sourceParam, dest, false) +} + +// MustUints requires parameter value to exist to be bind to uint slice variable. Returns error when value does not exist +func (b *ValueBinder) MustUints(sourceParam string, dest *[]uint) *ValueBinder { + return b.uintsValue(sourceParam, dest, true) +} + +// Bool binds parameter to bool variable +func (b *ValueBinder) Bool(sourceParam string, dest *bool) *ValueBinder { + return b.boolValue(sourceParam, dest, false) +} + +// MustBool requires parameter value to exist to be bind to bool variable. Returns error when value does not exist +func (b *ValueBinder) MustBool(sourceParam string, dest *bool) *ValueBinder { + return b.boolValue(sourceParam, dest, true) +} + +func (b *ValueBinder) boolValue(sourceParam string, dest *bool, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + return b.bool(sourceParam, value, dest) +} + +func (b *ValueBinder) bool(sourceParam string, value string, dest *bool) *ValueBinder { + n, err := strconv.ParseBool(value) + if err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to bool", err)) + return b + } + + *dest = n + return b +} + +func (b *ValueBinder) boolsValue(sourceParam string, dest *[]bool, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + values := b.ValuesFunc(sourceParam) + if len(values) == 0 { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + return b.bools(sourceParam, values, dest) +} + +func (b *ValueBinder) bools(sourceParam string, values []string, dest *[]bool) *ValueBinder { + tmp := make([]bool, len(values)) + for i, v := range values { + b.bool(sourceParam, v, &tmp[i]) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *dest = tmp + } + return b +} + +// Bools binds parameter values to slice of bool variables +func (b *ValueBinder) Bools(sourceParam string, dest *[]bool) *ValueBinder { + return b.boolsValue(sourceParam, dest, false) +} + +// MustBools requires parameter values to exist to be bind to slice of bool variables. Returns error when values does not exist +func (b *ValueBinder) MustBools(sourceParam string, dest *[]bool) *ValueBinder { + return b.boolsValue(sourceParam, dest, true) +} + +// Float64 binds parameter to float64 variable +func (b *ValueBinder) Float64(sourceParam string, dest *float64) *ValueBinder { + return b.floatValue(sourceParam, dest, 64, false) +} + +// MustFloat64 requires parameter value to exist to be bind to float64 variable. Returns error when value does not exist +func (b *ValueBinder) MustFloat64(sourceParam string, dest *float64) *ValueBinder { + return b.floatValue(sourceParam, dest, 64, true) +} + +// Float32 binds parameter to float32 variable +func (b *ValueBinder) Float32(sourceParam string, dest *float32) *ValueBinder { + return b.floatValue(sourceParam, dest, 32, false) +} + +// MustFloat32 requires parameter value to exist to be bind to float32 variable. Returns error when value does not exist +func (b *ValueBinder) MustFloat32(sourceParam string, dest *float32) *ValueBinder { + return b.floatValue(sourceParam, dest, 32, true) +} + +func (b *ValueBinder) floatValue(sourceParam string, dest interface{}, bitSize int, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + + return b.float(sourceParam, value, dest, bitSize) +} + +func (b *ValueBinder) float(sourceParam string, value string, dest interface{}, bitSize int) *ValueBinder { + n, err := strconv.ParseFloat(value, bitSize) + if err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to float%v", bitSize), err)) + return b + } + + switch d := dest.(type) { + case *float64: + *d = n + case *float32: + *d = float32(n) + } + return b +} + +func (b *ValueBinder) floatsValue(sourceParam string, dest interface{}, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + values := b.ValuesFunc(sourceParam) + if len(values) == 0 { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + return b.floats(sourceParam, values, dest) +} + +func (b *ValueBinder) floats(sourceParam string, values []string, dest interface{}) *ValueBinder { + switch d := dest.(type) { + case *[]float64: + tmp := make([]float64, len(values)) + for i, v := range values { + b.float(sourceParam, v, &tmp[i], 64) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]float32: + tmp := make([]float32, len(values)) + for i, v := range values { + b.float(sourceParam, v, &tmp[i], 32) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + } + return b +} + +// Float64s binds parameter values to slice of float64 variables +func (b *ValueBinder) Float64s(sourceParam string, dest *[]float64) *ValueBinder { + return b.floatsValue(sourceParam, dest, false) +} + +// MustFloat64s requires parameter values to exist to be bind to slice of float64 variables. Returns error when values does not exist +func (b *ValueBinder) MustFloat64s(sourceParam string, dest *[]float64) *ValueBinder { + return b.floatsValue(sourceParam, dest, true) +} + +// Float32s binds parameter values to slice of float32 variables +func (b *ValueBinder) Float32s(sourceParam string, dest *[]float32) *ValueBinder { + return b.floatsValue(sourceParam, dest, false) +} + +// MustFloat32s requires parameter values to exist to be bind to slice of float32 variables. Returns error when values does not exist +func (b *ValueBinder) MustFloat32s(sourceParam string, dest *[]float32) *ValueBinder { + return b.floatsValue(sourceParam, dest, true) +} + +// Time binds parameter to time.Time variable +func (b *ValueBinder) Time(sourceParam string, dest *time.Time, layout string) *ValueBinder { + return b.time(sourceParam, dest, layout, false) +} + +// MustTime requires parameter value to exist to be bind to time.Time variable. Returns error when value does not exist +func (b *ValueBinder) MustTime(sourceParam string, dest *time.Time, layout string) *ValueBinder { + return b.time(sourceParam, dest, layout, true) +} + +func (b *ValueBinder) time(sourceParam string, dest *time.Time, layout string, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil)) + } + return b + } + t, err := time.Parse(layout, value) + if err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Time", err)) + return b + } + *dest = t + return b +} + +// Times binds parameter values to slice of time.Time variables +func (b *ValueBinder) Times(sourceParam string, dest *[]time.Time, layout string) *ValueBinder { + return b.times(sourceParam, dest, layout, false) +} + +// MustTimes requires parameter values to exist to be bind to slice of time.Time variables. Returns error when values does not exist +func (b *ValueBinder) MustTimes(sourceParam string, dest *[]time.Time, layout string) *ValueBinder { + return b.times(sourceParam, dest, layout, true) +} + +func (b *ValueBinder) times(sourceParam string, dest *[]time.Time, layout string, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + values := b.ValuesFunc(sourceParam) + if len(values) == 0 { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + + tmp := make([]time.Time, len(values)) + for i, v := range values { + t, err := time.Parse(layout, v) + if err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{v}, "failed to bind field value to Time", err)) + if b.failFast { + return b + } + continue + } + tmp[i] = t + } + if b.errors == nil { + *dest = tmp + } + return b +} + +// Duration binds parameter to time.Duration variable +func (b *ValueBinder) Duration(sourceParam string, dest *time.Duration) *ValueBinder { + return b.duration(sourceParam, dest, false) +} + +// MustDuration requires parameter value to exist to be bind to time.Duration variable. Returns error when value does not exist +func (b *ValueBinder) MustDuration(sourceParam string, dest *time.Duration) *ValueBinder { + return b.duration(sourceParam, dest, true) +} + +func (b *ValueBinder) duration(sourceParam string, dest *time.Duration, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil)) + } + return b + } + t, err := time.ParseDuration(value) + if err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Duration", err)) + return b + } + *dest = t + return b +} + +// Durations binds parameter values to slice of time.Duration variables +func (b *ValueBinder) Durations(sourceParam string, dest *[]time.Duration) *ValueBinder { + return b.durationsValue(sourceParam, dest, false) +} + +// MustDurations requires parameter values to exist to be bind to slice of time.Duration variables. Returns error when values does not exist +func (b *ValueBinder) MustDurations(sourceParam string, dest *[]time.Duration) *ValueBinder { + return b.durationsValue(sourceParam, dest, true) +} + +func (b *ValueBinder) durationsValue(sourceParam string, dest *[]time.Duration, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + values := b.ValuesFunc(sourceParam) + if len(values) == 0 { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + return b.durations(sourceParam, values, dest) +} + +func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]time.Duration) *ValueBinder { + tmp := make([]time.Duration, len(values)) + for i, v := range values { + t, err := time.ParseDuration(v) + if err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{v}, "failed to bind field value to Duration", err)) + if b.failFast { + return b + } + continue + } + tmp[i] = t + } + if b.errors == nil { + *dest = tmp + } + return b +} + +// UnixTime binds parameter to time.Time variable (in local Time corresponding to the given Unix time). +// +// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 +// +// Note: +// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, false, false) +} + +// MustUnixTime requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding +// to the given Unix time). Returns error when value does not exist. +// +// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 +// +// Note: +// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, true, false) +} + +// UnixTimeNano binds parameter to time.Time variable (in local Time corresponding to the given Unix time in nano second precision). +// +// Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 +// Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00 +// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00 +// +// Note: +// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. +func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, false, true) +} + +// MustUnixTimeNano requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding +// to the given Unix time value in nano second precision). Returns error when value does not exist. +// +// Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 +// Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00 +// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00 +// +// Note: +// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. +func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, true, true) +} + +func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, isNano bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil)) + } + return b + } + + n, err := strconv.ParseInt(value, 10, 64) + if err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Time", err)) + return b + } + + if isNano { + *dest = time.Unix(0, n) + } else { + *dest = time.Unix(n, 0) + } + return b +} diff --git a/binder_external_test.go b/binder_external_test.go new file mode 100644 index 000000000..f1aecb52b --- /dev/null +++ b/binder_external_test.go @@ -0,0 +1,130 @@ +// run tests as external package to get real feel for API +package echo_test + +import ( + "encoding/base64" + "fmt" + "github.com/labstack/echo/v4" + "log" + "net/http" + "net/http/httptest" +) + +func ExampleValueBinder_BindErrors() { + // example route function that binds query params to different destinations and returns all bind errors in one go + routeFunc := func(c echo.Context) error { + var opts struct { + Active bool + IDs []int64 + } + length := int64(50) // default length is 50 + + b := echo.QueryParamsBinder(c) + + errs := b.Int64("length", &length). + Int64s("ids", &opts.IDs). + Bool("active", &opts.Active). + BindErrors() // returns all errors + if errs != nil { + for _, err := range errs { + bErr := err.(*echo.BindingError) + log.Printf("in case you want to access what field: %s values: %v failed", bErr.Field, bErr.Values) + } + return fmt.Errorf("%v fields failed to bind", len(errs)) + } + fmt.Printf("active = %v, length = %v, ids = %v", opts.Active, length, opts.IDs) + + return c.JSON(http.StatusOK, opts) + } + + e := echo.New() + c := e.NewContext( + httptest.NewRequest(http.MethodGet, "/api/endpoint?active=true&length=25&ids=1&ids=2&ids=3", nil), + httptest.NewRecorder(), + ) + + _ = routeFunc(c) + + // Output: active = true, length = 25, ids = [1 2 3] +} + +func ExampleValueBinder_BindError() { + // example route function that binds query params to different destinations and stops binding on first bind error + failFastRouteFunc := func(c echo.Context) error { + var opts struct { + Active bool + IDs []int64 + } + length := int64(50) // default length is 50 + + // create binder that stops binding at first error + b := echo.QueryParamsBinder(c) + + err := b.Int64("length", &length). + Int64s("ids", &opts.IDs). + Bool("active", &opts.Active). + BindError() // returns first binding error + if err != nil { + bErr := err.(*echo.BindingError) + return fmt.Errorf("my own custom error for field: %s values: %v", bErr.Field, bErr.Values) + } + fmt.Printf("active = %v, length = %v, ids = %v\n", opts.Active, length, opts.IDs) + + return c.JSON(http.StatusOK, opts) + } + + e := echo.New() + c := e.NewContext( + httptest.NewRequest(http.MethodGet, "/api/endpoint?active=true&length=25&ids=1&ids=2&ids=3", nil), + httptest.NewRecorder(), + ) + + _ = failFastRouteFunc(c) + + // Output: active = true, length = 25, ids = [1 2 3] +} + +func ExampleValueBinder_CustomFunc() { + // example route function that binds query params using custom function closure + routeFunc := func(c echo.Context) error { + length := int64(50) // default length is 50 + var binary []byte + + b := echo.QueryParamsBinder(c) + errs := b.Int64("length", &length). + CustomFunc("base64", func(values []string) []error { + if len(values) == 0 { + return nil + } + decoded, err := base64.URLEncoding.DecodeString(values[0]) + if err != nil { + // in this example we use only first param value but url could contain multiple params in reality and + // therefore in theory produce multiple binding errors + return []error{echo.NewBindingError("base64", values[0:1], "failed to decode base64", err)} + } + binary = decoded + return nil + }). + BindErrors() // returns all errors + + if errs != nil { + for _, err := range errs { + bErr := err.(*echo.BindingError) + log.Printf("in case you want to access what field: %s values: %v failed", bErr.Field, bErr.Values) + } + return fmt.Errorf("%v fields failed to bind", len(errs)) + } + fmt.Printf("length = %v, base64 = %s", length, binary) + + return c.JSON(http.StatusOK, "ok") + } + + e := echo.New() + c := e.NewContext( + httptest.NewRequest(http.MethodGet, "/api/endpoint?length=25&base64=SGVsbG8gV29ybGQ%3D", nil), + httptest.NewRecorder(), + ) + _ = routeFunc(c) + + // Output: length = 25, base64 = Hello World +} diff --git a/binder_go1.15_test.go b/binder_go1.15_test.go new file mode 100644 index 000000000..018628c3a --- /dev/null +++ b/binder_go1.15_test.go @@ -0,0 +1,265 @@ +// +build go1.15 + +package echo + +/** + Since version 1.15 time.Time and time.Duration error message pattern has changed (values are wrapped now in \"\") + So pre 1.15 these tests fail with similar error: + + expected: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param" + actual : "code=400, message=failed to bind field value to Duration, internal=time: invalid duration nope, field=param" +*/ + +import ( + "errors" + "github.com/stretchr/testify/assert" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func createTestContext15(URL string, body io.Reader, pathParams map[string]string) Context { + e := New() + req := httptest.NewRequest(http.MethodGet, URL, body) + if body != nil { + req.Header.Set(HeaderContentType, MIMEApplicationJSON) + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + if len(pathParams) > 0 { + names := make([]string, 0) + values := make([]string, 0) + for name, value := range pathParams { + names = append(names, name) + values = append(values, value) + } + c.SetParamNames(names...) + c.SetParamValues(values...) + } + + return c +} + +func TestValueBinder_TimeError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + whenLayout string + expectValue time.Time + expectError string + }{ + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext15(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustTime("param", &dest, tc.whenLayout).BindError() + } else { + err = b.Time("param", &dest, tc.whenLayout).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_TimesError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + whenLayout string + expectValue []time.Time + expectError string + }{ + { + name: "nok, fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext15(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + layout := time.RFC3339 + if tc.whenLayout != "" { + layout = tc.whenLayout + } + + var dest []time.Time + var err error + if tc.whenMust { + err = b.MustTimes("param", &dest, layout).BindError() + } else { + err = b.Times("param", &dest, layout).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_DurationError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue time.Duration + expectError string + }{ + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: 0, + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: 0, + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext15(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest time.Duration + var err error + if tc.whenMust { + err = b.MustDuration("param", &dest).BindError() + } else { + err = b.Duration("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_DurationsError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue []time.Duration + expectError string + }{ + { + name: "nok, fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, internal=time: missing unit in duration \"1\", field=param", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext15(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + var dest []time.Duration + var err error + if tc.whenMust { + err = b.MustDurations("param", &dest).BindError() + } else { + err = b.Durations("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/binder_test.go b/binder_test.go new file mode 100644 index 000000000..946906a96 --- /dev/null +++ b/binder_test.go @@ -0,0 +1,2757 @@ +// run tests as external package to get real feel for API +package echo + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/stretchr/testify/assert" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + "time" +) + +func createTestContext(URL string, body io.Reader, pathParams map[string]string) Context { + e := New() + req := httptest.NewRequest(http.MethodGet, URL, body) + if body != nil { + req.Header.Set(HeaderContentType, MIMEApplicationJSON) + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + if len(pathParams) > 0 { + names := make([]string, 0) + values := make([]string, 0) + for name, value := range pathParams { + names = append(names, name) + values = append(values, value) + } + c.SetParamNames(names...) + c.SetParamValues(values...) + } + + return c +} + +func TestBindingError_Error(t *testing.T) { + err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error")) + assert.EqualError(t, err, `code=400, message=bind failed, internal=internal error, field=id`) + + bErr := err.(*BindingError) + assert.Equal(t, 400, bErr.Code) + assert.Equal(t, "bind failed", bErr.Message) + assert.Equal(t, errors.New("internal error"), bErr.Internal) + + assert.Equal(t, "id", bErr.Field) + assert.Equal(t, []string{"1", "nope"}, bErr.Values) +} + +func TestBindingError_ErrorJSON(t *testing.T) { + err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error")) + + resp, err := json.Marshal(err) + + assert.Equal(t, `{"field":"id","message":"bind failed"}`, string(resp)) +} + +func TestPathParamsBinder(t *testing.T) { + c := createTestContext("/api/user/999", nil, map[string]string{ + "id": "1", + "nr": "2", + "slice": "3", + }) + b := PathParamsBinder(c) + + id := int64(99) + nr := int64(88) + var slice = make([]int64, 0) + var notExisting = make([]int64, 0) + err := b.Int64("id", &id). + Int64("nr", &nr). + Int64s("slice", &slice). + Int64s("not_existing", ¬Existing). + BindError() + + assert.NoError(t, err) + assert.Equal(t, int64(1), id) + assert.Equal(t, int64(2), nr) + assert.Equal(t, []int64{3}, slice) // binding params to slice does not make sense but it should not panic either + assert.Equal(t, []int64{}, notExisting) // binding params to slice does not make sense but it should not panic either +} + +func TestQueryParamsBinder_FailFast(t *testing.T) { + var testCases = []struct { + name string + whenURL string + givenFailFast bool + expectError []string + }{ + { + name: "ok, FailFast=true stops at first error", + whenURL: "/api/user/999?nr=en&id=nope", + givenFailFast: true, + expectError: []string{ + `code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing "nope": invalid syntax, field=id`, + }, + }, + { + name: "ok, FailFast=false encounters all errors", + whenURL: "/api/user/999?nr=en&id=nope", + givenFailFast: false, + expectError: []string{ + `code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing "nope": invalid syntax, field=id`, + `code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing "en": invalid syntax, field=nr`, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, map[string]string{"id": "999"}) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + id := int64(99) + nr := int64(88) + errs := b.Int64("id", &id). + Int64("nr", &nr). + BindErrors() + + assert.Len(t, errs, len(tc.expectError)) + for _, err := range errs { + assert.Contains(t, tc.expectError, err.Error()) + } + }) + } +} + +func TestFormFieldBinder(t *testing.T) { + e := New() + body := `texta=foo&slice=5` + req := httptest.NewRequest(http.MethodPost, "/api/search?id=1&nr=2&slice=3&slice=4", strings.NewReader(body)) + req.Header.Set(HeaderContentLength, strconv.Itoa(len(body))) + req.Header.Set(HeaderContentType, MIMEApplicationForm) + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + b := FormFieldBinder(c) + + var texta string + id := int64(99) + nr := int64(88) + var slice = make([]int64, 0) + var notExisting = make([]int64, 0) + err := b. + Int64s("slice", &slice). + Int64("id", &id). + Int64("nr", &nr). + String("texta", &texta). + Int64s("notExisting", ¬Existing). + BindError() + + assert.NoError(t, err) + assert.Equal(t, "foo", texta) + assert.Equal(t, int64(1), id) + assert.Equal(t, int64(2), nr) + assert.Equal(t, []int64{5, 3, 4}, slice) + assert.Equal(t, []int64{}, notExisting) +} + +func TestValueBinder_errorStopsBinding(t *testing.T) { + // this test documents "feature" that binding multiple params can change destination if it was binded before + // failing parameter binding + + c := createTestContext("/api/user/999?id=1&nr=nope", nil, nil) + b := QueryParamsBinder(c) + + id := int64(99) // will be changed before nr binding fails + nr := int64(88) // will not be changed + err := b.Int64("id", &id). + Int64("nr", &nr). + BindError() + + assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=nr") + assert.Equal(t, int64(1), id) + assert.Equal(t, int64(88), nr) +} + +func TestValueBinder_BindError(t *testing.T) { + c := createTestContext("/api/user/999?nr=en&id=nope", nil, nil) + b := QueryParamsBinder(c) + + id := int64(99) + nr := int64(88) + err := b.Int64("id", &id). + Int64("nr", &nr). + BindError() + + assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=id") + assert.Nil(t, b.errors) + assert.Nil(t, b.BindError()) +} + +func TestValueBinder_GetValues(t *testing.T) { + var testCases = []struct { + name string + whenValuesFunc func(sourceParam string) []string + expect []int64 + expectError string + }{ + { + name: "ok, default implementation", + expect: []int64{1, 101}, + }, + { + name: "ok, values returns nil", + whenValuesFunc: func(sourceParam string) []string { + return nil + }, + expect: []int64(nil), + }, + { + name: "ok, values returns empty slice", + whenValuesFunc: func(sourceParam string) []string { + return []string{} + }, + expect: []int64(nil), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext("/search?nr=en&id=1&id=101", nil, nil) + b := QueryParamsBinder(c) + if tc.whenValuesFunc != nil { + b.ValuesFunc = tc.whenValuesFunc + } + + var IDs []int64 + err := b.Int64s("id", &IDs).BindError() + + assert.Equal(t, tc.expect, IDs) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_CustomFuncWithError(t *testing.T) { + c := createTestContext("/search?nr=en&id=1&id=101", nil, nil) + b := QueryParamsBinder(c) + + id := int64(99) + givenCustomFunc := func(values []string) []error { + assert.Equal(t, []string{"1", "101"}, values) + + return []error{ + errors.New("first error"), + errors.New("second error"), + } + } + err := b.CustomFunc("id", givenCustomFunc).BindError() + + assert.Equal(t, int64(99), id) + assert.EqualError(t, err, "first error") +} + +func TestValueBinder_CustomFunc(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenFuncErrors []error + whenURL string + expectParamValues []string + expectValue interface{} + expectErrors []string + }{ + { + name: "ok, binds value", + whenURL: "/search?nr=en&id=1&id=100", + expectParamValues: []string{"1", "100"}, + expectValue: int64(1000), + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nr=en", + expectParamValues: []string{}, + expectValue: int64(99), + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?nr=en&id=1&id=100", + expectParamValues: []string{"1", "100"}, + expectValue: int64(99), + expectErrors: []string{"previous error"}, + }, + { + name: "nok, func returns errors", + givenFuncErrors: []error{ + errors.New("first error"), + errors.New("second error"), + }, + whenURL: "/search?nr=en&id=1&id=100", + expectParamValues: []string{"1", "100"}, + expectValue: int64(99), + expectErrors: []string{"first error", "second error"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + id := int64(99) + givenCustomFunc := func(values []string) []error { + assert.Equal(t, tc.expectParamValues, values) + if tc.givenFuncErrors == nil { + id = 1000 // emulated conversion and setting value + return nil + } + return tc.givenFuncErrors + } + errs := b.CustomFunc("id", givenCustomFunc).BindErrors() + + assert.Equal(t, tc.expectValue, id) + if tc.expectErrors != nil { + assert.Len(t, errs, len(tc.expectErrors)) + for _, err := range errs { + assert.Contains(t, tc.expectErrors, err.Error()) + } + } else { + assert.Nil(t, errs) + } + }) + } +} + +func TestValueBinder_MustCustomFunc(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenFuncErrors []error + whenURL string + expectParamValues []string + expectValue interface{} + expectErrors []string + }{ + { + name: "ok, binds value", + whenURL: "/search?nr=en&id=1&id=100", + expectParamValues: []string{"1", "100"}, + expectValue: int64(1000), + }, + { + name: "nok, params values empty, returns error, value is not changed", + whenURL: "/search?nr=en", + expectParamValues: []string{}, + expectValue: int64(99), + expectErrors: []string{"code=400, message=required field value is empty, field=id"}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?nr=en&id=1&id=100", + expectParamValues: []string{"1", "100"}, + expectValue: int64(99), + expectErrors: []string{"previous error"}, + }, + { + name: "nok, func returns errors", + givenFuncErrors: []error{ + errors.New("first error"), + errors.New("second error"), + }, + whenURL: "/search?nr=en&id=1&id=100", + expectParamValues: []string{"1", "100"}, + expectValue: int64(99), + expectErrors: []string{"first error", "second error"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + id := int64(99) + givenCustomFunc := func(values []string) []error { + assert.Equal(t, tc.expectParamValues, values) + if tc.givenFuncErrors == nil { + id = 1000 // emulated conversion and setting value + return nil + } + return tc.givenFuncErrors + } + errs := b.MustCustomFunc("id", givenCustomFunc).BindErrors() + + assert.Equal(t, tc.expectValue, id) + if tc.expectErrors != nil { + assert.Len(t, errs, len(tc.expectErrors)) + for _, err := range errs { + assert.Contains(t, tc.expectErrors, err.Error()) + } + } else { + assert.Nil(t, errs) + } + }) + } +} + +func TestValueBinder_String(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue string + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=en¶m=de", + expectValue: "en", + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nr=en", + expectValue: "default", + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?nr=en&id=1&id=100", + expectValue: "default", + expectError: "previous error", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=en¶m=de", + expectValue: "en", + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nr=en", + expectValue: "default", + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?nr=en&id=1&id=100", + expectValue: "default", + expectError: "previous error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := "default" + var err error + if tc.whenMust { + err = b.MustString("param", &dest).BindError() + } else { + err = b.String("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Strings(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue []string + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=en¶m=de", + expectValue: []string{"en", "de"}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nr=en", + expectValue: []string{"default"}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?nr=en&id=1&id=100", + expectValue: []string{"default"}, + expectError: "previous error", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=en¶m=de", + expectValue: []string{"en", "de"}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nr=en", + expectValue: []string{"default"}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?nr=en&id=1&id=100", + expectValue: []string{"default"}, + expectError: "previous error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := []string{"default"} + var err error + if tc.whenMust { + err = b.MustStrings("param", &dest).BindError() + } else { + err = b.Strings("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Int64_intValue(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue int64 + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=1¶m=100", + expectValue: 1, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: 99, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: 99, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: 99, + expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: 1, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: 99, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: 99, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: 99, + expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := int64(99) + var err error + if tc.whenMust { + err = b.MustInt64("param", &dest).BindError() + } else { + err = b.Int64("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Int_errorMessage(t *testing.T) { + // int/uint (without byte size) has a little bit different error message so test these separately + c := createTestContext("/search?param=nope", nil, nil) + b := QueryParamsBinder(c).FailFast(false) + + destInt := 99 + destUint := uint(98) + errs := b.Int("param", &destInt).Uint("param", &destUint).BindErrors() + + assert.Equal(t, 99, destInt) + assert.Equal(t, uint(98), destUint) + assert.EqualError(t, errs[0], `code=400, message=failed to bind field value to int, internal=strconv.ParseInt: parsing "nope": invalid syntax, field=param`) + assert.EqualError(t, errs[1], `code=400, message=failed to bind field value to uint, internal=strconv.ParseUint: parsing "nope": invalid syntax, field=param`) +} + +func TestValueBinder_Uint64_uintValue(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue uint64 + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=1¶m=100", + expectValue: 1, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: 99, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: 99, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: 99, + expectError: "code=400, message=failed to bind field value to uint64, internal=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: 1, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: 99, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: 99, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: 99, + expectError: "code=400, message=failed to bind field value to uint64, internal=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := uint64(99) + var err error + if tc.whenMust { + err = b.MustUint64("param", &dest).BindError() + } else { + err = b.Uint64("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Int_Types(t *testing.T) { + type target struct { + int64 int64 + mustInt64 int64 + uint64 uint64 + mustUint64 uint64 + + int32 int32 + mustInt32 int32 + uint32 uint32 + mustUint32 uint32 + + int16 int16 + mustInt16 int16 + uint16 uint16 + mustUint16 uint16 + + int8 int8 + mustInt8 int8 + uint8 uint8 + mustUint8 uint8 + + byte byte + mustByte byte + + int int + mustInt int + uint uint + mustUint uint + } + types := []string{ + "int64=1", + "mustInt64=2", + "uint64=3", + "mustUint64=4", + + "int32=5", + "mustInt32=6", + "uint32=7", + "mustUint32=8", + + "int16=9", + "mustInt16=10", + "uint16=11", + "mustUint16=12", + + "int8=13", + "mustInt8=14", + "uint8=15", + "mustUint8=16", + + "byte=17", + "mustByte=18", + + "int=19", + "mustInt=20", + "uint=21", + "mustUint=22", + } + c := createTestContext("/search?"+strings.Join(types, "&"), nil, nil) + b := QueryParamsBinder(c) + + dest := target{} + err := b. + Int64("int64", &dest.int64). + MustInt64("mustInt64", &dest.mustInt64). + Uint64("uint64", &dest.uint64). + MustUint64("mustUint64", &dest.mustUint64). + Int32("int32", &dest.int32). + MustInt32("mustInt32", &dest.mustInt32). + Uint32("uint32", &dest.uint32). + MustUint32("mustUint32", &dest.mustUint32). + Int16("int16", &dest.int16). + MustInt16("mustInt16", &dest.mustInt16). + Uint16("uint16", &dest.uint16). + MustUint16("mustUint16", &dest.mustUint16). + Int8("int8", &dest.int8). + MustInt8("mustInt8", &dest.mustInt8). + Uint8("uint8", &dest.uint8). + MustUint8("mustUint8", &dest.mustUint8). + Byte("byte", &dest.byte). + MustByte("mustByte", &dest.mustByte). + Int("int", &dest.int). + MustInt("mustInt", &dest.mustInt). + Uint("uint", &dest.uint). + MustUint("mustUint", &dest.mustUint). + BindError() + + assert.NoError(t, err) + assert.Equal(t, int64(1), dest.int64) + assert.Equal(t, int64(2), dest.mustInt64) + assert.Equal(t, uint64(3), dest.uint64) + assert.Equal(t, uint64(4), dest.mustUint64) + + assert.Equal(t, int32(5), dest.int32) + assert.Equal(t, int32(6), dest.mustInt32) + assert.Equal(t, uint32(7), dest.uint32) + assert.Equal(t, uint32(8), dest.mustUint32) + + assert.Equal(t, int16(9), dest.int16) + assert.Equal(t, int16(10), dest.mustInt16) + assert.Equal(t, uint16(11), dest.uint16) + assert.Equal(t, uint16(12), dest.mustUint16) + + assert.Equal(t, int8(13), dest.int8) + assert.Equal(t, int8(14), dest.mustInt8) + assert.Equal(t, uint8(15), dest.uint8) + assert.Equal(t, uint8(16), dest.mustUint8) + + assert.Equal(t, uint8(17), dest.byte) + assert.Equal(t, uint8(18), dest.mustByte) + + assert.Equal(t, 19, dest.int) + assert.Equal(t, 20, dest.mustInt) + assert.Equal(t, uint(21), dest.uint) + assert.Equal(t, uint(22), dest.mustUint) +} + +func TestValueBinder_Int64s_intsValue(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue []int64 + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=1¶m=2¶m=1", + expectValue: []int64{1, 2, 1}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: []int64{99}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []int64{99}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []int64{99}, + expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1¶m=2¶m=1", + expectValue: []int64{1, 2, 1}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: []int64{99}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: []int64{99}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []int64{99}, + expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := []int64{99} // when values are set with bind - contents before bind is gone + var err error + if tc.whenMust { + err = b.MustInt64s("param", &dest).BindError() + } else { + err = b.Int64s("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Uint64s_uintsValue(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue []uint64 + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=1¶m=2¶m=1", + expectValue: []uint64{1, 2, 1}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: []uint64{99}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []uint64{99}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []uint64{99}, + expectError: "code=400, message=failed to bind field value to uint64, internal=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1¶m=2¶m=1", + expectValue: []uint64{1, 2, 1}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: []uint64{99}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: []uint64{99}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []uint64{99}, + expectError: "code=400, message=failed to bind field value to uint64, internal=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := []uint64{99} // when values are set with bind - contents before bind is gone + var err error + if tc.whenMust { + err = b.MustUint64s("param", &dest).BindError() + } else { + err = b.Uint64s("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Ints_Types(t *testing.T) { + type target struct { + int64 []int64 + mustInt64 []int64 + uint64 []uint64 + mustUint64 []uint64 + + int32 []int32 + mustInt32 []int32 + uint32 []uint32 + mustUint32 []uint32 + + int16 []int16 + mustInt16 []int16 + uint16 []uint16 + mustUint16 []uint16 + + int8 []int8 + mustInt8 []int8 + uint8 []uint8 + mustUint8 []uint8 + + int []int + mustInt []int + uint []uint + mustUint []uint + } + types := []string{ + "int64=1", + "mustInt64=2", + "uint64=3", + "mustUint64=4", + + "int32=5", + "mustInt32=6", + "uint32=7", + "mustUint32=8", + + "int16=9", + "mustInt16=10", + "uint16=11", + "mustUint16=12", + + "int8=13", + "mustInt8=14", + "uint8=15", + "mustUint8=16", + + "int=19", + "mustInt=20", + "uint=21", + "mustUint=22", + } + url := "/search?" + for _, v := range types { + url = url + "&" + v + "&" + v + } + c := createTestContext(url, nil, nil) + b := QueryParamsBinder(c) + + dest := target{} + err := b. + Int64s("int64", &dest.int64). + MustInt64s("mustInt64", &dest.mustInt64). + Uint64s("uint64", &dest.uint64). + MustUint64s("mustUint64", &dest.mustUint64). + Int32s("int32", &dest.int32). + MustInt32s("mustInt32", &dest.mustInt32). + Uint32s("uint32", &dest.uint32). + MustUint32s("mustUint32", &dest.mustUint32). + Int16s("int16", &dest.int16). + MustInt16s("mustInt16", &dest.mustInt16). + Uint16s("uint16", &dest.uint16). + MustUint16s("mustUint16", &dest.mustUint16). + Int8s("int8", &dest.int8). + MustInt8s("mustInt8", &dest.mustInt8). + Uint8s("uint8", &dest.uint8). + MustUint8s("mustUint8", &dest.mustUint8). + Ints("int", &dest.int). + MustInts("mustInt", &dest.mustInt). + Uints("uint", &dest.uint). + MustUints("mustUint", &dest.mustUint). + BindError() + + assert.NoError(t, err) + assert.Equal(t, []int64{1, 1}, dest.int64) + assert.Equal(t, []int64{2, 2}, dest.mustInt64) + assert.Equal(t, []uint64{3, 3}, dest.uint64) + assert.Equal(t, []uint64{4, 4}, dest.mustUint64) + + assert.Equal(t, []int32{5, 5}, dest.int32) + assert.Equal(t, []int32{6, 6}, dest.mustInt32) + assert.Equal(t, []uint32{7, 7}, dest.uint32) + assert.Equal(t, []uint32{8, 8}, dest.mustUint32) + + assert.Equal(t, []int16{9, 9}, dest.int16) + assert.Equal(t, []int16{10, 10}, dest.mustInt16) + assert.Equal(t, []uint16{11, 11}, dest.uint16) + assert.Equal(t, []uint16{12, 12}, dest.mustUint16) + + assert.Equal(t, []int8{13, 13}, dest.int8) + assert.Equal(t, []int8{14, 14}, dest.mustInt8) + assert.Equal(t, []uint8{15, 15}, dest.uint8) + assert.Equal(t, []uint8{16, 16}, dest.mustUint8) + + assert.Equal(t, []int{19, 19}, dest.int) + assert.Equal(t, []int{20, 20}, dest.mustInt) + assert.Equal(t, []uint{21, 21}, dest.uint) + assert.Equal(t, []uint{22, 22}, dest.mustUint) +} + +func TestValueBinder_Ints_Types_FailFast(t *testing.T) { + // FailFast() should stop parsing and return early + errTmpl := "code=400, message=failed to bind field value to %v, internal=strconv.Parse%v: parsing \"nope\": invalid syntax, field=param" + c := createTestContext("/search?param=1¶m=nope¶m=2", nil, nil) + + var dest64 []int64 + err := QueryParamsBinder(c).FailFast(true).Int64s("param", &dest64).BindError() + assert.Equal(t, []int64(nil), dest64) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int64", "Int")) + + var dest32 []int32 + err = QueryParamsBinder(c).FailFast(true).Int32s("param", &dest32).BindError() + assert.Equal(t, []int32(nil), dest32) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int32", "Int")) + + var dest16 []int16 + err = QueryParamsBinder(c).FailFast(true).Int16s("param", &dest16).BindError() + assert.Equal(t, []int16(nil), dest16) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int16", "Int")) + + var dest8 []int8 + err = QueryParamsBinder(c).FailFast(true).Int8s("param", &dest8).BindError() + assert.Equal(t, []int8(nil), dest8) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int8", "Int")) + + var dest []int + err = QueryParamsBinder(c).FailFast(true).Ints("param", &dest).BindError() + assert.Equal(t, []int(nil), dest) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int", "Int")) + + var destu64 []uint64 + err = QueryParamsBinder(c).FailFast(true).Uint64s("param", &destu64).BindError() + assert.Equal(t, []uint64(nil), destu64) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint64", "Uint")) + + var destu32 []uint32 + err = QueryParamsBinder(c).FailFast(true).Uint32s("param", &destu32).BindError() + assert.Equal(t, []uint32(nil), destu32) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint32", "Uint")) + + var destu16 []uint16 + err = QueryParamsBinder(c).FailFast(true).Uint16s("param", &destu16).BindError() + assert.Equal(t, []uint16(nil), destu16) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint16", "Uint")) + + var destu8 []uint8 + err = QueryParamsBinder(c).FailFast(true).Uint8s("param", &destu8).BindError() + assert.Equal(t, []uint8(nil), destu8) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint8", "Uint")) + + var destu []uint + err = QueryParamsBinder(c).FailFast(true).Uints("param", &destu).BindError() + assert.Equal(t, []uint(nil), destu) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint", "Uint")) +} + +func TestValueBinder_Bool(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue bool + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=true¶m=1", + expectValue: true, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: false, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: false, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: false, + expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: true, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: false, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: false, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: false, + expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := false + var err error + if tc.whenMust { + err = b.MustBool("param", &dest).BindError() + } else { + err = b.Bool("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Bools(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue []bool + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=true¶m=false¶m=1¶m=0", + expectValue: []bool{true, false, true, false}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: []bool(nil), + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenURL: "/search?param=1¶m=100", + expectValue: []bool(nil), + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=true¶m=nope¶m=100", + expectValue: []bool(nil), + expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "nok, conversion fails fast, value is not changed", + givenFailFast: true, + whenURL: "/search?param=true¶m=nope¶m=100", + expectValue: []bool(nil), + expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=true¶m=false¶m=1¶m=0", + expectValue: []bool{true, false, true, false}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: []bool(nil), + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: []bool(nil), + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []bool(nil), + expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + var dest []bool + var err error + if tc.whenMust { + err = b.MustBools("param", &dest).BindError() + } else { + err = b.Bools("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Float64(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue float64 + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=4.3¶m=1", + expectValue: 4.3, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: 1.123, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: 1.123, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: 1.123, + expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=4.3¶m=100", + expectValue: 4.3, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: 1.123, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: 1.123, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: 1.123, + expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := 1.123 + var err error + if tc.whenMust { + err = b.MustFloat64("param", &dest).BindError() + } else { + err = b.Float64("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Float64s(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue []float64 + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=4.3¶m=0", + expectValue: []float64{4.3, 0}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: []float64(nil), + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenURL: "/search?param=1¶m=100", + expectValue: []float64(nil), + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []float64(nil), + expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "nok, conversion fails fast, value is not changed", + givenFailFast: true, + whenURL: "/search?param=0¶m=nope¶m=100", + expectValue: []float64(nil), + expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=4.3¶m=0", + expectValue: []float64{4.3, 0}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: []float64(nil), + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: []float64(nil), + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []float64(nil), + expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + var dest []float64 + var err error + if tc.whenMust { + err = b.MustFloat64s("param", &dest).BindError() + } else { + err = b.Float64s("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Float32(t *testing.T) { + var testCases = []struct { + name string + givenNoFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue float32 + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=4.3¶m=1", + expectValue: 4.3, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: 1.123, + }, + { + name: "nok, previous errors fail fast without binding value", + givenNoFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: 1.123, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: 1.123, + expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=4.3¶m=100", + expectValue: 4.3, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: 1.123, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenNoFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: 1.123, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: 1.123, + expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenNoFailFast) + if tc.givenNoFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := float32(1.123) + var err error + if tc.whenMust { + err = b.MustFloat32("param", &dest).BindError() + } else { + err = b.Float32("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Float32s(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue []float32 + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=4.3¶m=0", + expectValue: []float32{4.3, 0}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: []float32(nil), + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenURL: "/search?param=1¶m=100", + expectValue: []float32(nil), + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []float32(nil), + expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "nok, conversion fails fast, value is not changed", + givenFailFast: true, + whenURL: "/search?param=0¶m=nope¶m=100", + expectValue: []float32(nil), + expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=4.3¶m=0", + expectValue: []float32{4.3, 0}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: []float32(nil), + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: []float32(nil), + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []float32(nil), + expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + var dest []float32 + var err error + if tc.whenMust { + err = b.MustFloat32s("param", &dest).BindError() + } else { + err = b.Float32s("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Time(t *testing.T) { + exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00") + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + whenLayout string + expectValue time.Time + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", + whenLayout: time.RFC3339, + expectValue: exampleTime, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: time.Time{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", + whenLayout: time.RFC3339, + expectValue: exampleTime, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: time.Time{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustTime("param", &dest, tc.whenLayout).BindError() + } else { + err = b.Time("param", &dest, tc.whenLayout).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Times(t *testing.T) { + exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00") + exampleTime2, _ := time.Parse(time.RFC3339, "2000-01-02T09:45:31+00:00") + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + whenLayout string + expectValue []time.Time + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", + whenLayout: time.RFC3339, + expectValue: []time.Time{exampleTime, exampleTime2}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: []time.Time(nil), + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Time(nil), + expectError: "previous error", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", + whenLayout: time.RFC3339, + expectValue: []time.Time{exampleTime, exampleTime2}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: []time.Time(nil), + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Time(nil), + expectError: "previous error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + layout := time.RFC3339 + if tc.whenLayout != "" { + layout = tc.whenLayout + } + + var dest []time.Time + var err error + if tc.whenMust { + err = b.MustTimes("param", &dest, layout).BindError() + } else { + err = b.Times("param", &dest, layout).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Duration(t *testing.T) { + example := 42 * time.Second + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue time.Duration + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=42s¶m=1ms", + expectValue: example, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: 0, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: 0, + expectError: "previous error", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=42s¶m=1ms", + expectValue: example, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: 0, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: 0, + expectError: "previous error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest time.Duration + var err error + if tc.whenMust { + err = b.MustDuration("param", &dest).BindError() + } else { + err = b.Duration("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Durations(t *testing.T) { + exampleDuration := 42 * time.Second + exampleDuration2 := 1 * time.Millisecond + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue []time.Duration + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=42s¶m=1ms", + expectValue: []time.Duration{exampleDuration, exampleDuration2}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: []time.Duration(nil), + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Duration(nil), + expectError: "previous error", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=42s¶m=1ms", + expectValue: []time.Duration{exampleDuration, exampleDuration2}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: []time.Duration(nil), + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Duration(nil), + expectError: "previous error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + var dest []time.Duration + var err error + if tc.whenMust { + err = b.MustDurations("param", &dest).BindError() + } else { + err = b.Durations("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_BindUnmarshaler(t *testing.T) { + exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00") + + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue Timestamp + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", + expectValue: Timestamp(exampleTime), + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: Timestamp{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: Timestamp{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: Timestamp{}, + expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", + expectValue: Timestamp(exampleTime), + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: Timestamp{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: Timestamp{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: Timestamp{}, + expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest Timestamp + var err error + if tc.whenMust { + err = b.MustBindUnmarshaler("param", &dest).BindError() + } else { + err = b.BindUnmarshaler("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_BindWithDelimiter_types(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expect interface{} + }{ + { + name: "ok, strings", + expect: []string{"1", "2", "1"}, + }, + { + name: "ok, int64", + expect: []int64{1, 2, 1}, + }, + { + name: "ok, int32", + expect: []int32{1, 2, 1}, + }, + { + name: "ok, int16", + expect: []int16{1, 2, 1}, + }, + { + name: "ok, int8", + expect: []int8{1, 2, 1}, + }, + { + name: "ok, int", + expect: []int{1, 2, 1}, + }, + { + name: "ok, uint64", + expect: []uint64{1, 2, 1}, + }, + { + name: "ok, uint32", + expect: []uint32{1, 2, 1}, + }, + { + name: "ok, uint16", + expect: []uint16{1, 2, 1}, + }, + { + name: "ok, uint8", + expect: []uint8{1, 2, 1}, + }, + { + name: "ok, uint", + expect: []uint{1, 2, 1}, + }, + { + name: "ok, float64", + expect: []float64{1, 2, 1}, + }, + { + name: "ok, float32", + expect: []float32{1, 2, 1}, + }, + { + name: "ok, bool", + whenURL: "/search?param=1,false¶m=true", + expect: []bool{true, false, true}, + }, + { + name: "ok, Duration", + whenURL: "/search?param=1s,42s¶m=1ms", + expect: []time.Duration{1 * time.Second, 42 * time.Second, 1 * time.Millisecond}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + URL := "/search?param=1,2¶m=1" + if tc.whenURL != "" { + URL = tc.whenURL + } + c := createTestContext(URL, nil, nil) + b := QueryParamsBinder(c) + + switch tc.expect.(type) { + case []string: + var dest []string + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []int64: + var dest []int64 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []int32: + var dest []int32 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []int16: + var dest []int16 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []int8: + var dest []int8 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []int: + var dest []int + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []uint64: + var dest []uint64 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []uint32: + var dest []uint32 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []uint16: + var dest []uint16 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []uint8: + var dest []uint8 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []uint: + var dest []uint + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []float64: + var dest []float64 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []float32: + var dest []float32 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []bool: + var dest []bool + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []time.Duration: + var dest []time.Duration + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + default: + assert.Fail(t, "invalid type") + } + }) + } +} + +func TestValueBinder_BindWithDelimiter(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue []int64 + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=1,2¶m=1", + expectValue: []int64{1, 2, 1}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: []int64(nil), + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []int64(nil), + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []int64(nil), + expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1,2¶m=1", + expectValue: []int64{1, 2, 1}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: []int64(nil), + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: []int64(nil), + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []int64(nil), + expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest []int64 + var err error + if tc.whenMust { + err = b.MustBindWithDelimiter("param", &dest, ",").BindError() + } else { + err = b.BindWithDelimiter("param", &dest, ",").BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestBindWithDelimiter_invalidType(t *testing.T) { + c := createTestContext("/search?param=1¶m=100", nil, nil) + b := QueryParamsBinder(c) + + var dest []BindUnmarshaler + err := b.BindWithDelimiter("param", &dest, ",").BindError() + assert.Equal(t, []BindUnmarshaler(nil), dest) + assert.EqualError(t, err, "code=400, message=unsupported bind type, field=param") +} + +func TestValueBinder_UnixTime(t *testing.T) { + exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43+00:00") // => 1609180603 + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue time.Time + expectError string + }{ + { + name: "ok, binds value, unix time in seconds", + whenURL: "/search?param=1609180603¶m=1609180604", + expectValue: exampleTime, + }, + { + name: "ok, binds value, unix time over int32 value", + whenURL: "/search?param=2147483648¶m=1609180604", + expectValue: time.Unix(2147483648, 0), + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: time.Time{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1609180603¶m=1609180604", + expectValue: exampleTime, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: time.Time{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustUnixTime("param", &dest).BindError() + } else { + err = b.UnixTime("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano()) + assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC)) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_UnixTimeNano(t *testing.T) { + exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43.000000000+00:00") // => 1609180603 + exampleTimeNano, _ := time.Parse(time.RFC3339Nano, "2020-12-28T18:36:43.123456789+00:00") // => 1609180603123456789 + exampleTimeNanoBelowSec, _ := time.Parse(time.RFC3339Nano, "1970-01-01T00:00:00.999999999+00:00") + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue time.Time + expectError string + }{ + { + name: "ok, binds value, unix time in nano seconds (sec precision)", + whenURL: "/search?param=1609180603000000000¶m=1609180604", + expectValue: exampleTime, + }, + { + name: "ok, binds value, unix time in nano seconds", + whenURL: "/search?param=1609180603123456789¶m=1609180604", + expectValue: exampleTimeNano, + }, + { + name: "ok, binds value, unix time in nano seconds (below 1 sec)", + whenURL: "/search?param=999999999¶m=1609180604", + expectValue: exampleTimeNanoBelowSec, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: time.Time{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1609180603000000000¶m=1609180604", + expectValue: exampleTime, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: time.Time{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustUnixTimeNano("param", &dest).BindError() + } else { + err = b.UnixTimeNano("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano()) + assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC)) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) { + type Opts struct { + Param int64 `query:"param"` + } + c := createTestContext("/search?param=1¶m=100", nil, nil) + + b.ReportAllocs() + b.ResetTimer() + binder := new(DefaultBinder) + for i := 0; i < b.N; i++ { + var dest Opts + _ = binder.Bind(&dest, c) + } +} + +func BenchmarkValueBinder_BindInt64_single(b *testing.B) { + c := createTestContext("/search?param=1¶m=100", nil, nil) + + b.ReportAllocs() + b.ResetTimer() + type Opts struct { + Param int64 + } + binder := QueryParamsBinder(c) + for i := 0; i < b.N; i++ { + var dest Opts + _ = binder.Int64("param", &dest.Param).BindError() + } +} + +func BenchmarkRawFunc_Int64_single(b *testing.B) { + c := createTestContext("/search?param=1¶m=100", nil, nil) + + rawFunc := func(input string, defaultValue int64) (int64, bool) { + if input == "" { + return defaultValue, true + } + n, err := strconv.Atoi(input) + if err != nil { + return 0, false + } + return int64(n), true + } + + b.ReportAllocs() + b.ResetTimer() + type Opts struct { + Param int64 + } + for i := 0; i < b.N; i++ { + var dest Opts + if n, ok := rawFunc(c.QueryParam("param"), 1); ok { + dest.Param = n + } + } +} + +func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) { + type Opts struct { + Int64 int64 `query:"int64"` + Int32 int32 `query:"int32"` + Int16 int16 `query:"int16"` + Int8 int8 `query:"int8"` + String string `query:"string"` + + Uint64 uint64 `query:"uint64"` + Uint32 uint32 `query:"uint32"` + Uint16 uint16 `query:"uint16"` + Uint8 uint8 `query:"uint8"` + Strings []string `query:"strings"` + } + c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil) + + b.ReportAllocs() + b.ResetTimer() + binder := new(DefaultBinder) + for i := 0; i < b.N; i++ { + var dest Opts + _ = binder.Bind(&dest, c) + if dest.Int64 != 1 { + b.Fatalf("int64!=1") + } + } +} + +func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) { + type Opts struct { + Int64 int64 `query:"int64"` + Int32 int32 `query:"int32"` + Int16 int16 `query:"int16"` + Int8 int8 `query:"int8"` + String string `query:"string"` + + Uint64 uint64 `query:"uint64"` + Uint32 uint32 `query:"uint32"` + Uint16 uint16 `query:"uint16"` + Uint8 uint8 `query:"uint8"` + Strings []string `query:"strings"` + } + c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil) + + b.ReportAllocs() + b.ResetTimer() + binder := QueryParamsBinder(c) + for i := 0; i < b.N; i++ { + var dest Opts + _ = binder. + Int64("int64", &dest.Int64). + Int32("int32", &dest.Int32). + Int16("int16", &dest.Int16). + Int8("int8", &dest.Int8). + String("string", &dest.String). + Uint64("int64", &dest.Uint64). + Uint32("int32", &dest.Uint32). + Uint16("int16", &dest.Uint16). + Uint8("int8", &dest.Uint8). + Strings("strings", &dest.Strings). + BindError() + if dest.Int64 != 1 { + b.Fatalf("int64!=1") + } + } +} From 7c8592a7e05370b21cc02cc2edaee6b2355fd83a Mon Sep 17 00:00:00 2001 From: Benjamin Chibuzor-Orie Date: Fri, 15 Jan 2021 21:53:15 +0100 Subject: [PATCH 105/446] adds middleware for rate limiting (#1724) * adds middleware for rate limiting * added comment for InMemoryStore ShouldAllow * removed redundant mutex declaration * fixed lint issues * removed sleep from tests * improved coverage * refactor: renames Identifiers, includes default SourceFunc * Added last seen stats for visitor * uses http Constants for improved readdability adds default error handler * used other handler apart from default handler to mark custom error handler for rate limiting * split tests into separate blocks added an error pair to IdentifierExtractor Includes deny handler for explicitly denying requests * adds comments for exported members Extractor and ErrorHandler * makes cleanup implementation inhouse * Avoid race for cleanup due to non-atomic access to store.expiresIn * Use a dedicated producer for rate testing * tidy commit * refactors tests, implicitly tests lastSeen property on visitor switches NewRateLimiterMemoryStore constructor to Referential Functions style (Advised by @pafuent) * switches to mock of time module for time based tests tests are now fully deterministic * improved coverage * replaces Rob Pike referential options with more conventional struct configs makes cleanup asynchronous * blocks racy access to lastCleanup * Add benchmark tests for rate limiter * Add rate limiter with sharded memory store * Racy access to store.lastCleanup eliminated Merges in shiny sharded map implementation by @lammel * Remove RateLimiterShradedMemoryStore for now * Make fields for RateLimiterStoreConfig public for external configuration * Improve docs for RateLimiter usage * Fix ErrorHandler vs. DenyHandler usage for rate limiter * Simplify NewRateLimiterMemoryStore * improved coverage * updated errorHandler and denyHandler to use echo.HTTPError * Improve wording for error and comments * Remove duplicate lastSeen marking for Allow * Improve wording for comments * Add disclaimer on perf characteristics of memory store * changes Allow signature on rate limiter to return err too Co-authored-by: Roland Lammel --- .gitignore | 1 + go.mod | 1 + go.sum | 2 + middleware/rate_limiter.go | 268 ++++++++++++++++++ middleware/rate_limiter_test.go | 462 ++++++++++++++++++++++++++++++++ 5 files changed, 734 insertions(+) create mode 100644 middleware/rate_limiter.go create mode 100644 middleware/rate_limiter_test.go diff --git a/.gitignore b/.gitignore index dd74acca4..dbadf3bd0 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ vendor .idea *.iml *.out +.vscode diff --git a/go.mod b/go.mod index 74c6a9abe..877117075 100644 --- a/go.mod +++ b/go.mod @@ -12,4 +12,5 @@ require ( golang.org/x/net v0.0.0-20200822124328-c89045814202 golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6 // indirect golang.org/x/text v0.3.3 // indirect + golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 ) diff --git a/go.sum b/go.sum index 58c80c831..54ba24e67 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,8 @@ golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= +golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go new file mode 100644 index 000000000..7d1abfcb9 --- /dev/null +++ b/middleware/rate_limiter.go @@ -0,0 +1,268 @@ +package middleware + +import ( + "net/http" + "sync" + "time" + + "github.com/labstack/echo/v4" + "golang.org/x/time/rate" +) + +type ( + // RateLimiterStore is the interface to be implemented by custom stores. + RateLimiterStore interface { + // Stores for the rate limiter have to implement the Allow method + Allow(identifier string) (bool, error) + } +) + +type ( + // RateLimiterConfig defines the configuration for the rate limiter + RateLimiterConfig struct { + Skipper Skipper + BeforeFunc BeforeFunc + // IdentifierExtractor uses echo.Context to extract the identifier for a visitor + IdentifierExtractor Extractor + // Store defines a store for the rate limiter + Store RateLimiterStore + // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error + ErrorHandler func(context echo.Context, err error) error + // DenyHandler provides a handler to be called when RateLimiter denies access + DenyHandler func(context echo.Context, identifier string, err error) error + } + // Extractor is used to extract data from echo.Context + Extractor func(context echo.Context) (string, error) +) + +// errors +var ( + // ErrRateLimitExceeded denotes an error raised when rate limit is exceeded + ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") + // ErrExtractorError denotes an error raised when extractor function is unsuccessful + ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") +) + +// DefaultRateLimiterConfig defines default values for RateLimiterConfig +var DefaultRateLimiterConfig = RateLimiterConfig{ + Skipper: DefaultSkipper, + IdentifierExtractor: func(ctx echo.Context) (string, error) { + id := ctx.RealIP() + return id, nil + }, + ErrorHandler: func(context echo.Context, err error) error { + return &echo.HTTPError{ + Code: ErrExtractorError.Code, + Message: ErrExtractorError.Message, + Internal: err, + } + }, + DenyHandler: func(context echo.Context, identifier string, err error) error { + return &echo.HTTPError{ + Code: ErrRateLimitExceeded.Code, + Message: ErrRateLimitExceeded.Message, + Internal: err, + } + }, +} + +/* +RateLimiter returns a rate limiting middleware + + e := echo.New() + + limiterStore := middleware.NewRateLimiterMemoryStore(20) + + e.GET("/rate-limited", func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }, RateLimiter(limiterStore)) +*/ +func RateLimiter(store RateLimiterStore) echo.MiddlewareFunc { + config := DefaultRateLimiterConfig + config.Store = store + + return RateLimiterWithConfig(config) +} + +/* +RateLimiterWithConfig returns a rate limiting middleware + + e := echo.New() + + config := middleware.RateLimiterConfig{ + Skipper: DefaultSkipper, + Store: middleware.NewRateLimiterMemoryStore( + middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute} + ) + IdentifierExtractor: func(ctx echo.Context) (string, error) { + id := ctx.RealIP() + return id, nil + }, + ErrorHandler: func(context echo.Context, err error) error { + return context.JSON(http.StatusTooManyRequests, nil) + }, + DenyHandler: func(context echo.Context, identifier string) error { + return context.JSON(http.StatusForbidden, nil) + }, + } + + e.GET("/rate-limited", func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }, middleware.RateLimiterWithConfig(config)) +*/ +func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { + if config.Skipper == nil { + config.Skipper = DefaultRateLimiterConfig.Skipper + } + if config.IdentifierExtractor == nil { + config.IdentifierExtractor = DefaultRateLimiterConfig.IdentifierExtractor + } + if config.ErrorHandler == nil { + config.ErrorHandler = DefaultRateLimiterConfig.ErrorHandler + } + if config.DenyHandler == nil { + config.DenyHandler = DefaultRateLimiterConfig.DenyHandler + } + if config.Store == nil { + panic("Store configuration must be provided") + } + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + if config.BeforeFunc != nil { + config.BeforeFunc(c) + } + + identifier, err := config.IdentifierExtractor(c) + if err != nil { + c.Error(config.ErrorHandler(c, err)) + return nil + } + + if allow, err := config.Store.Allow(identifier); !allow { + c.Error(config.DenyHandler(c, identifier, err)) + return nil + } + return next(c) + } + } +} + +type ( + // RateLimiterMemoryStore is the built-in store implementation for RateLimiter + RateLimiterMemoryStore struct { + visitors map[string]*Visitor + mutex sync.Mutex + rate rate.Limit + burst int + expiresIn time.Duration + lastCleanup time.Time + } + // Visitor signifies a unique user's limiter details + Visitor struct { + *rate.Limiter + lastSeen time.Time + } +) + +/* +NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with +the provided rate (as req/s). Burst and ExpiresIn will be set to default values. + +Example (with 20 requests/sec): + + limiterStore := middleware.NewRateLimiterMemoryStore(20) + +*/ +func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) { + return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ + Rate: rate, + }) +} + +/* +NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore +with the provided configuration. Rate must be provided. Burst will be set to the value of +the configured rate if not provided or set to 0. + +The build-in memory store is usually capable for modest loads. For higher loads other +store implementations should be considered. + +Characteristics: +* Concurrency above 100 parallel requests may causes measurable lock contention +* A high number of different IP addresses (above 16000) may be impacted by the internally used Go map +* A high number of requests from a single IP address may cause lock contention + +Example: + + limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig( + middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minutes}, + ) +*/ +func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) { + store = &RateLimiterMemoryStore{} + + store.rate = config.Rate + store.burst = config.Burst + store.expiresIn = config.ExpiresIn + if config.ExpiresIn == 0 { + store.expiresIn = DefaultRateLimiterMemoryStoreConfig.ExpiresIn + } + if config.Burst == 0 { + store.burst = int(config.Rate) + } + store.visitors = make(map[string]*Visitor) + store.lastCleanup = now() + return +} + +// RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore +type RateLimiterMemoryStoreConfig struct { + Rate rate.Limit // Rate of requests allowed to pass as req/s + Burst int // Burst additionally allows a number of requests to pass when rate limit is reached + ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up +} + +// DefaultRateLimiterMemoryStoreConfig provides default configuration values for RateLimiterMemoryStore +var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{ + ExpiresIn: 3 * time.Minute, +} + +// Allow implements RateLimiterStore.Allow +func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { + store.mutex.Lock() + limiter, exists := store.visitors[identifier] + if !exists { + limiter = new(Visitor) + limiter.Limiter = rate.NewLimiter(store.rate, store.burst) + store.visitors[identifier] = limiter + } + limiter.lastSeen = now() + if now().Sub(store.lastCleanup) > store.expiresIn { + store.cleanupStaleVisitors() + } + store.mutex.Unlock() + return limiter.AllowN(now(), 1), nil +} + +/* +cleanupStaleVisitors helps manage the size of the visitors map by removing stale records +of users who haven't visited again after the configured expiry time has elapsed +*/ +func (store *RateLimiterMemoryStore) cleanupStaleVisitors() { + for id, visitor := range store.visitors { + if now().Sub(visitor.lastSeen) > store.expiresIn { + delete(store.visitors, id) + } + } + store.lastCleanup = now() +} + +/* +actual time method which is mocked in test file +*/ +var now = func() time.Time { + return time.Now() +} diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go new file mode 100644 index 000000000..2e57bf175 --- /dev/null +++ b/middleware/rate_limiter_test.go @@ -0,0 +1,462 @@ +package middleware + +import ( + "errors" + "fmt" + "math/rand" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/labstack/echo/v4" + "github.com/labstack/gommon/random" + "github.com/stretchr/testify/assert" + "golang.org/x/time/rate" +) + +func TestRateLimiter(t *testing.T) { + e := echo.New() + + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + mw := RateLimiter(inMemoryStore) + + testCases := []struct { + id string + code int + }{ + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + } + + for _, tc := range testCases { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, tc.id) + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + _ = mw(handler)(c) + assert.Equal(t, tc.code, rec.Code) + } +} + +func TestRateLimiter_panicBehaviour(t *testing.T) { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + assert.Panics(t, func() { + RateLimiter(nil) + }) + + assert.NotPanics(t, func() { + RateLimiter(inMemoryStore) + }) +} + +func TestRateLimiterWithConfig(t *testing.T) { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + e := echo.New() + + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + mw := RateLimiterWithConfig(RateLimiterConfig{ + IdentifierExtractor: func(c echo.Context) (string, error) { + id := c.Request().Header.Get(echo.HeaderXRealIP) + if id == "" { + return "", errors.New("invalid identifier") + } + return id, nil + }, + DenyHandler: func(ctx echo.Context, identifier string, err error) error { + return ctx.JSON(http.StatusForbidden, nil) + }, + ErrorHandler: func(ctx echo.Context, err error) error { + return ctx.JSON(http.StatusBadRequest, nil) + }, + Store: inMemoryStore, + }) + + testCases := []struct { + id string + code int + }{ + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusForbidden}, + {"", http.StatusBadRequest}, + {"127.0.0.1", http.StatusForbidden}, + {"127.0.0.1", http.StatusForbidden}, + } + + for _, tc := range testCases { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, tc.id) + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + _ = mw(handler)(c) + + assert.Equal(t, tc.code, rec.Code) + } +} + +func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + e := echo.New() + + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + mw := RateLimiterWithConfig(RateLimiterConfig{ + IdentifierExtractor: func(c echo.Context) (string, error) { + id := c.Request().Header.Get(echo.HeaderXRealIP) + if id == "" { + return "", errors.New("invalid identifier") + } + return id, nil + }, + Store: inMemoryStore, + }) + + testCases := []struct { + id string + code int + }{ + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"", http.StatusForbidden}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + } + + for _, tc := range testCases { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, tc.id) + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + _ = mw(handler)(c) + + assert.Equal(t, tc.code, rec.Code) + } +} + +func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { + { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + e := echo.New() + + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + mw := RateLimiterWithConfig(RateLimiterConfig{ + Store: inMemoryStore, + }) + + testCases := []struct { + id string + code int + }{ + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + {"127.0.0.1", http.StatusTooManyRequests}, + } + + for _, tc := range testCases { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, tc.id) + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + _ = mw(handler)(c) + + assert.Equal(t, tc.code, rec.Code) + } + } +} + +func TestRateLimiterWithConfig_skipper(t *testing.T) { + e := echo.New() + + var beforeFuncRan bool + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + var inMemoryStore = NewRateLimiterMemoryStore(5) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + mw := RateLimiterWithConfig(RateLimiterConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + BeforeFunc: func(c echo.Context) { + beforeFuncRan = true + }, + Store: inMemoryStore, + IdentifierExtractor: func(ctx echo.Context) (string, error) { + return "127.0.0.1", nil + }, + }) + + _ = mw(handler)(c) + + assert.Equal(t, false, beforeFuncRan) +} + +func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { + e := echo.New() + + var beforeFuncRan bool + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + var inMemoryStore = NewRateLimiterMemoryStore(5) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + mw := RateLimiterWithConfig(RateLimiterConfig{ + Skipper: func(c echo.Context) bool { + return false + }, + BeforeFunc: func(c echo.Context) { + beforeFuncRan = true + }, + Store: inMemoryStore, + IdentifierExtractor: func(ctx echo.Context) (string, error) { + return "127.0.0.1", nil + }, + }) + + _ = mw(handler)(c) + + assert.Equal(t, true, beforeFuncRan) +} + +func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { + e := echo.New() + + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + var beforeRan bool + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + mw := RateLimiterWithConfig(RateLimiterConfig{ + BeforeFunc: func(c echo.Context) { + beforeRan = true + }, + Store: inMemoryStore, + IdentifierExtractor: func(ctx echo.Context) (string, error) { + return "127.0.0.1", nil + }, + }) + + _ = mw(handler)(c) + + assert.Equal(t, true, beforeRan) +} + +func TestRateLimiterMemoryStore_Allow(t *testing.T) { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3, ExpiresIn: 2 * time.Second}) + testCases := []struct { + id string + allowed bool + }{ + {"127.0.0.1", true}, // 0 ms + {"127.0.0.1", true}, // 220 ms burst #2 + {"127.0.0.1", true}, // 440 ms burst #3 + {"127.0.0.1", false}, // 660 ms block + {"127.0.0.1", false}, // 880 ms block + {"127.0.0.1", true}, // 1100 ms next second #1 + {"127.0.0.2", true}, // 1320 ms allow other ip + {"127.0.0.1", false}, // 1540 ms no burst + {"127.0.0.1", false}, // 1760 ms no burst + {"127.0.0.1", false}, // 1980 ms no burst + {"127.0.0.1", true}, // 2200 ms no burst + {"127.0.0.1", false}, // 2420 ms no burst + {"127.0.0.1", false}, // 2640 ms no burst + {"127.0.0.1", false}, // 2860 ms no burst + {"127.0.0.1", true}, // 3080 ms no burst + {"127.0.0.1", false}, // 3300 ms no burst + {"127.0.0.1", false}, // 3520 ms no burst + {"127.0.0.1", false}, // 3740 ms no burst + {"127.0.0.1", false}, // 3960 ms no burst + {"127.0.0.1", true}, // 4180 ms no burst + {"127.0.0.1", false}, // 4400 ms no burst + {"127.0.0.1", false}, // 4620 ms no burst + {"127.0.0.1", false}, // 4840 ms no burst + {"127.0.0.1", true}, // 5060 ms no burst + } + + for i, tc := range testCases { + t.Logf("Running testcase #%d => %v", i, time.Duration(i)*220*time.Millisecond) + now = func() time.Time { + return time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).Add(time.Duration(i) * 220 * time.Millisecond) + } + allowed, _ := inMemoryStore.Allow(tc.id) + assert.Equal(t, tc.allowed, allowed) + } +} + +func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + now = func() time.Time { + return time.Now() + } + fmt.Println(now()) + inMemoryStore.visitors = map[string]*Visitor{ + "A": { + Limiter: rate.NewLimiter(1, 3), + lastSeen: now(), + }, + "B": { + Limiter: rate.NewLimiter(1, 3), + lastSeen: now().Add(-1 * time.Minute), + }, + "C": { + Limiter: rate.NewLimiter(1, 3), + lastSeen: now().Add(-5 * time.Minute), + }, + "D": { + Limiter: rate.NewLimiter(1, 3), + lastSeen: now().Add(-10 * time.Minute), + }, + } + + inMemoryStore.Allow("D") + inMemoryStore.cleanupStaleVisitors() + + var exists bool + + _, exists = inMemoryStore.visitors["A"] + assert.Equal(t, true, exists) + + _, exists = inMemoryStore.visitors["B"] + assert.Equal(t, true, exists) + + _, exists = inMemoryStore.visitors["C"] + assert.Equal(t, false, exists) + + _, exists = inMemoryStore.visitors["D"] + assert.Equal(t, true, exists) +} + +func TestNewRateLimiterMemoryStore(t *testing.T) { + testCases := []struct { + rate rate.Limit + burst int + expiresIn time.Duration + expectedExpiresIn time.Duration + }{ + {1, 3, 5 * time.Second, 5 * time.Second}, + {2, 4, 0, 3 * time.Minute}, + {1, 5, 10 * time.Minute, 10 * time.Minute}, + {3, 7, 0, 3 * time.Minute}, + } + + for _, tc := range testCases { + store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: tc.rate, Burst: tc.burst, ExpiresIn: tc.expiresIn}) + assert.Equal(t, tc.rate, store.rate) + assert.Equal(t, tc.burst, store.burst) + assert.Equal(t, tc.expectedExpiresIn, store.expiresIn) + } +} + +func generateAddressList(count int) []string { + addrs := make([]string, count) + for i := 0; i < count; i++ { + addrs[i] = random.String(15) + } + return addrs +} + +func run(wg *sync.WaitGroup, store RateLimiterStore, addrs []string, max int, b *testing.B) { + for i := 0; i < b.N; i++ { + store.Allow(addrs[rand.Intn(max)]) + } + wg.Done() +} + +func benchmarkStore(store RateLimiterStore, parallel int, max int, b *testing.B) { + addrs := generateAddressList(max) + wg := &sync.WaitGroup{} + for i := 0; i < parallel; i++ { + wg.Add(1) + go run(wg, store, addrs, max, b) + } + wg.Wait() +} + +const ( + testExpiresIn = 1000 * time.Millisecond +) + +func BenchmarkRateLimiterMemoryStore_1000(b *testing.B) { + var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) + benchmarkStore(store, 10, 1000, b) +} + +func BenchmarkRateLimiterMemoryStore_10000(b *testing.B) { + var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) + benchmarkStore(store, 10, 10000, b) +} + +func BenchmarkRateLimiterMemoryStore_100000(b *testing.B) { + var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) + benchmarkStore(store, 10, 100000, b) +} + +func BenchmarkRateLimiterMemoryStore_conc100_10000(b *testing.B) { + var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) + benchmarkStore(store, 100, 10000, b) +} From 932976ded6e22eb0d80f73c540aef0849b2047af Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Mon, 8 Feb 2021 15:58:55 +0100 Subject: [PATCH 106/446] Support real regex rules for rewrite and proxy middleware (#1767) Support real regex rules for rewrite and proxy middleware (use non-greedy matching by default) Co-authored-by: pwli --- middleware/middleware.go | 2 +- middleware/proxy.go | 22 ++++-- middleware/proxy_test.go | 142 ++++++++++++++++++++++++++----------- middleware/rewrite.go | 24 +++++-- middleware/rewrite_test.go | 47 +++++++++++- 5 files changed, 182 insertions(+), 55 deletions(-) diff --git a/middleware/middleware.go b/middleware/middleware.go index 60834b505..8381e3a5d 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -38,7 +38,7 @@ func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string { rulesRegex := map[*regexp.Regexp]string{} for k, v := range rewrite { k = regexp.QuoteMeta(k) - k = strings.Replace(k, `\*`, "(.*)", -1) + k = strings.Replace(k, `\*`, "(.*?)", -1) if strings.HasPrefix(k, `\^`) { k = strings.Replace(k, `\^`, "^", -1) } diff --git a/middleware/proxy.go b/middleware/proxy.go index 1b972eb16..63eec5a20 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -36,6 +36,13 @@ type ( // "/users/*/orders/*": "/user/$1/order/$2", Rewrite map[string]string + // RegexRewrite defines rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRewrite map[*regexp.Regexp]string + // Context key to store selected ProxyTarget into context. // Optional. Default value "target". ContextKey string @@ -46,8 +53,6 @@ type ( // ModifyResponse defines function to modify response from ProxyTarget. ModifyResponse func(*http.Response) error - - rewriteRegex map[*regexp.Regexp]string } // ProxyTarget defines the upstream target. @@ -206,7 +211,14 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { panic("echo: proxy middleware requires balancer") } - config.rewriteRegex = rewriteRulesRegex(config.Rewrite) + if config.Rewrite != nil { + if config.RegexRewrite == nil { + config.RegexRewrite = make(map[*regexp.Regexp]string) + } + for k, v := range rewriteRulesRegex(config.Rewrite) { + config.RegexRewrite[k] = v + } + } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { @@ -220,7 +232,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { c.Set(config.ContextKey, tgt) // Set rewrite path and raw path - rewritePath(config.rewriteRegex, req) + rewritePath(config.RegexRewrite, req) // Fix header // Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream. @@ -251,5 +263,3 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { } } } - - diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 534e45f44..ec6f1925b 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "regexp" "testing" "github.com/labstack/echo/v4" @@ -83,46 +84,6 @@ func TestProxy(t *testing.T) { body = rec.Body.String() assert.Equal(t, "target 2", body) - // Rewrite - e = echo.New() - e.Use(ProxyWithConfig(ProxyConfig{ - Balancer: rrb, - Rewrite: map[string]string{ - "/old": "/new", - "/api/*": "/$1", - "/js/*": "/public/javascripts/$1", - "/users/*/orders/*": "/user/$1/order/$2", - }, - })) - req.URL, _ = url.Parse("/api/users") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/users", req.URL.EscapedPath()) - assert.Equal(t, http.StatusOK, rec.Code) - req.URL, _ = url.Parse( "/js/main.js") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) - assert.Equal(t, http.StatusOK, rec.Code) - req.URL, _ = url.Parse("/old") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/new", req.URL.EscapedPath()) - assert.Equal(t, http.StatusOK, rec.Code) - req.URL, _ = url.Parse( "/users/jack/orders/1") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) - assert.Equal(t, http.StatusOK, rec.Code) - req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) - assert.Equal(t, http.StatusOK, rec.Code) - req.URL, _ = url.Parse("/api/new users") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/new%20users", req.URL.EscapedPath()) // ModifyResponse e = echo.New() e.Use(ProxyWithConfig(ProxyConfig{ @@ -196,3 +157,104 @@ func TestProxyRealIPHeader(t *testing.T) { assert.Equal(t, tt.extectedXRealIP, req.Header.Get(echo.HeaderXRealIP), "hasRealIPheader: %t / hasIPExtractor: %t", tt.hasRealIPheader, tt.hasIPExtractor) } } + +func TestProxyRewrite(t *testing.T) { + // Setup + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + defer upstream.Close() + url, _ := url.Parse(upstream.URL) + rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + // Rewrite + e := echo.New() + e.Use(ProxyWithConfig(ProxyConfig{ + Balancer: rrb, + Rewrite: map[string]string{ + "/old": "/new", + "/api/*": "/$1", + "/js/*": "/public/javascripts/$1", + "/users/*/orders/*": "/user/$1/order/$2", + }, + })) + req.URL, _ = url.Parse("/api/users") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/users", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) + req.URL, _ = url.Parse("/js/main.js") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) + req.URL, _ = url.Parse("/old") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/new", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) + req.URL, _ = url.Parse("/users/jack/orders/1") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) + req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) + req.URL, _ = url.Parse("/api/new users") + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "/new%20users", req.URL.EscapedPath()) +} + +func TestProxyRewriteRegex(t *testing.T) { + // Setup + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + defer upstream.Close() + url, _ := url.Parse(upstream.URL) + rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + // Rewrite + e := echo.New() + e.Use(ProxyWithConfig(ProxyConfig{ + Balancer: rrb, + Rewrite: map[string]string{ + "^/a/*": "/v1/$1", + "^/b/*/c/*": "/v2/$2/$1", + "^/c/*/*": "/v3/$2", + }, + RegexRewrite: map[*regexp.Regexp]string{ + regexp.MustCompile("^/x/.+?/(.*)"): "/v4/$1", + regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1", + }, + })) + + testCases := []struct { + requestPath string + statusCode int + expectPath string + }{ + {"/unmatched", http.StatusOK, "/unmatched"}, + {"/a/test", http.StatusOK, "/v1/test"}, + {"/b/foo/c/bar/baz", http.StatusOK, "/v2/bar/baz/foo"}, + {"/c/ignore/test", http.StatusOK, "/v3/test"}, + {"/c/ignore1/test/this", http.StatusOK, "/v3/test/this"}, + {"/x/ignore/test", http.StatusOK, "/v4/test"}, + {"/y/foo/bar", http.StatusOK, "/v5/bar/foo"}, + } + + + for _, tc := range testCases { + t.Run(tc.requestPath, func(t *testing.T) { + req.URL, _ = url.Parse(tc.requestPath) + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expectPath, req.URL.EscapedPath()) + assert.Equal(t, tc.statusCode, rec.Code) + }) + } +} diff --git a/middleware/rewrite.go b/middleware/rewrite.go index 0965e313f..c05d5d84f 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -1,8 +1,9 @@ package middleware import ( - "github.com/labstack/echo/v4" "regexp" + + "github.com/labstack/echo/v4" ) type ( @@ -21,7 +22,12 @@ type ( // Required. Rules map[string]string `yaml:"rules"` - rulesRegex map[*regexp.Regexp]string + // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"` } ) @@ -45,14 +51,20 @@ func Rewrite(rules map[string]string) echo.MiddlewareFunc { // See: `Rewrite()`. func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { // Defaults - if config.Rules == nil { - panic("echo: rewrite middleware requires url path rewrite rules") + if config.Rules == nil && config.RegexRules == nil { + panic("echo: rewrite middleware requires url path rewrite rules or regex rules") } + if config.Skipper == nil { config.Skipper = DefaultBodyDumpConfig.Skipper } - config.rulesRegex = rewriteRulesRegex(config.Rules) + if config.RegexRules == nil { + config.RegexRules = make(map[*regexp.Regexp]string) + } + for k, v := range rewriteRulesRegex(config.Rules) { + config.RegexRules[k] = v + } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { @@ -62,7 +74,7 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { req := c.Request() // Set rewrite path and raw path - rewritePath(config.rulesRegex, req) + rewritePath(config.RegexRules, req) return next(c) } } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index abf11b2f7..351b7313c 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "regexp" "testing" "github.com/labstack/echo/v4" @@ -55,8 +56,8 @@ func TestEchoRewritePreMiddleware(t *testing.T) { // Rewrite old url to new one e.Pre(Rewrite(map[string]string{ - "/old": "/new", - }, + "/old": "/new", + }, )) // Route @@ -129,3 +130,45 @@ func TestEchoRewriteWithCaret(t *testing.T) { e.ServeHTTP(rec, req) assert.Equal(t, "/v2/abc/test", req.URL.Path) } + +// Verify regex used with rewrite +func TestEchoRewriteWithRegexRules(t *testing.T) { + e := echo.New() + + e.Pre(RewriteWithConfig(RewriteConfig{ + Rules: map[string]string{ + "^/a/*": "/v1/$1", + "^/b/*/c/*": "/v2/$2/$1", + "^/c/*/*": "/v3/$2", + }, + RegexRules: map[*regexp.Regexp]string{ + regexp.MustCompile("^/x/.+?/(.*)"): "/v4/$1", + regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1", + }, + })) + + var rec *httptest.ResponseRecorder + var req *http.Request + + testCases := []struct { + requestPath string + expectPath string + }{ + {"/unmatched", "/unmatched"}, + {"/a/test", "/v1/test"}, + {"/b/foo/c/bar/baz", "/v2/bar/baz/foo"}, + {"/c/ignore/test", "/v3/test"}, + {"/c/ignore1/test/this", "/v3/test/this"}, + {"/x/ignore/test", "/v4/test"}, + {"/y/foo/bar", "/v5/bar/foo"}, + } + + for _, tc := range testCases { + t.Run(tc.requestPath, func(t *testing.T) { + req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil) + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expectPath, req.URL.EscapedPath()) + }) + } +} From f09f2bd14e59a8b4aed7bcd1977305ad2bfb57a2 Mon Sep 17 00:00:00 2001 From: Martti T Date: Thu, 11 Feb 2021 15:53:22 +0200 Subject: [PATCH 107/446] Fix open redirect vulnerability with AddTrailingSlashWithConfig and RemoveTrailingSlashWithConfig (#1775,#1771) * fix open redirect vulnerability with AddTrailingSlashWithConfig and RemoveTrailingSlashWithConfig (fix #1771) * rename trimMultipleSlashes to sanitizeURI --- middleware/slash.go | 13 +- middleware/slash_test.go | 342 ++++++++++++++++++++++++++++++--------- 2 files changed, 273 insertions(+), 82 deletions(-) diff --git a/middleware/slash.go b/middleware/slash.go index 0492b334b..4188675b0 100644 --- a/middleware/slash.go +++ b/middleware/slash.go @@ -60,7 +60,7 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc // Redirect if config.RedirectCode != 0 { - return c.Redirect(config.RedirectCode, uri) + return c.Redirect(config.RedirectCode, sanitizeURI(uri)) } // Forward @@ -108,7 +108,7 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu // Redirect if config.RedirectCode != 0 { - return c.Redirect(config.RedirectCode, uri) + return c.Redirect(config.RedirectCode, sanitizeURI(uri)) } // Forward @@ -119,3 +119,12 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu } } } + +func sanitizeURI(uri string) string { + // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri + // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash + if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') { + uri = "/" + strings.TrimLeft(uri, `/\`) + } + return uri +} diff --git a/middleware/slash_test.go b/middleware/slash_test.go index 2a8e9eeaa..ddb071045 100644 --- a/middleware/slash_test.go +++ b/middleware/slash_test.go @@ -9,88 +9,270 @@ import ( "github.com/stretchr/testify/assert" ) +func TestAddTrailingSlashWithConfig(t *testing.T) { + var testCases = []struct { + whenURL string + whenMethod string + expectPath string + expectLocation []string + expectStatus int + }{ + { + whenURL: "/add-slash", + whenMethod: http.MethodGet, + expectPath: "/add-slash", + expectLocation: []string{`/add-slash/`}, + }, + { + whenURL: "/add-slash?key=value", + whenMethod: http.MethodGet, + expectPath: "/add-slash", + expectLocation: []string{`/add-slash/?key=value`}, + }, + { + whenURL: "/", + whenMethod: http.MethodConnect, + expectPath: "/", + expectLocation: nil, + expectStatus: http.StatusOK, + }, + // cases for open redirect vulnerability + { + whenURL: "http://localhost:1323/%5Cexample.com", + expectPath: `/\example.com`, + expectLocation: []string{`/example.com/`}, + }, + { + whenURL: `http://localhost:1323/\example.com`, + expectPath: `/\example.com`, + expectLocation: []string{`/example.com/`}, + }, + { + whenURL: `http://localhost:1323/\\%5C////%5C\\\example.com`, + expectPath: `/\\\////\\\\example.com`, + expectLocation: []string{`/example.com/`}, + }, + { + whenURL: "http://localhost:1323//example.com", + expectPath: `//example.com`, + expectLocation: []string{`/example.com/`}, + }, + { + whenURL: "http://localhost:1323/%5C%5C", + expectPath: `/\\`, + expectLocation: []string{`/`}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + e := echo.New() + + mw := AddTrailingSlashWithConfig(TrailingSlashConfig{ + RedirectCode: http.StatusMovedPermanently, + }) + h := mw(func(c echo.Context) error { + return nil + }) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil) + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, tc.expectPath, req.URL.Path) + assert.Equal(t, tc.expectLocation, rec.Header()[echo.HeaderLocation]) + if tc.expectStatus == 0 { + assert.Equal(t, http.StatusMovedPermanently, rec.Code) + } else { + assert.Equal(t, tc.expectStatus, rec.Code) + } + }) + } +} + func TestAddTrailingSlash(t *testing.T) { - is := assert.New(t) - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/add-slash", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := AddTrailingSlash()(func(c echo.Context) error { - return nil - }) - is.NoError(h(c)) - is.Equal("/add-slash/", req.URL.Path) - is.Equal("/add-slash/", req.RequestURI) - - // Method Connect must not fail: - req = httptest.NewRequest(http.MethodConnect, "", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = AddTrailingSlash()(func(c echo.Context) error { - return nil - }) - is.NoError(h(c)) - is.Equal("/", req.URL.Path) - is.Equal("/", req.RequestURI) - - // With config - req = httptest.NewRequest(http.MethodGet, "/add-slash?key=value", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = AddTrailingSlashWithConfig(TrailingSlashConfig{ - RedirectCode: http.StatusMovedPermanently, - })(func(c echo.Context) error { - return nil - }) - is.NoError(h(c)) - is.Equal(http.StatusMovedPermanently, rec.Code) - is.Equal("/add-slash/?key=value", rec.Header().Get(echo.HeaderLocation)) + var testCases = []struct { + whenURL string + whenMethod string + expectPath string + expectLocation []string + }{ + { + whenURL: "/add-slash", + whenMethod: http.MethodGet, + expectPath: "/add-slash/", + }, + { + whenURL: "/add-slash?key=value", + whenMethod: http.MethodGet, + expectPath: "/add-slash/", + }, + { + whenURL: "/", + whenMethod: http.MethodConnect, + expectPath: "/", + expectLocation: nil, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + e := echo.New() + + h := AddTrailingSlash()(func(c echo.Context) error { + return nil + }) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil) + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, tc.expectPath, req.URL.Path) + assert.Equal(t, []string(nil), rec.Header()[echo.HeaderLocation]) + assert.Equal(t, http.StatusOK, rec.Code) + }) + } +} + +func TestRemoveTrailingSlashWithConfig(t *testing.T) { + var testCases = []struct { + whenURL string + whenMethod string + expectPath string + expectLocation []string + expectStatus int + }{ + { + whenURL: "/remove-slash/", + whenMethod: http.MethodGet, + expectPath: "/remove-slash/", + expectLocation: []string{`/remove-slash`}, + }, + { + whenURL: "/remove-slash/?key=value", + whenMethod: http.MethodGet, + expectPath: "/remove-slash/", + expectLocation: []string{`/remove-slash?key=value`}, + }, + { + whenURL: "/", + whenMethod: http.MethodConnect, + expectPath: "/", + expectLocation: nil, + expectStatus: http.StatusOK, + }, + { + whenURL: "http://localhost", + whenMethod: http.MethodGet, + expectPath: "", + expectLocation: nil, + expectStatus: http.StatusOK, + }, + // cases for open redirect vulnerability + { + whenURL: "http://localhost:1323/%5Cexample.com/", + expectPath: `/\example.com/`, + expectLocation: []string{`/example.com`}, + }, + { + whenURL: `http://localhost:1323/\example.com/`, + expectPath: `/\example.com/`, + expectLocation: []string{`/example.com`}, + }, + { + whenURL: `http://localhost:1323/\\%5C////%5C\\\example.com/`, + expectPath: `/\\\////\\\\example.com/`, + expectLocation: []string{`/example.com`}, + }, + { + whenURL: "http://localhost:1323//example.com/", + expectPath: `//example.com/`, + expectLocation: []string{`/example.com`}, + }, + { + whenURL: "http://localhost:1323/%5C%5C/", + expectPath: `/\\/`, + expectLocation: []string{`/`}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + e := echo.New() + + mw := RemoveTrailingSlashWithConfig(TrailingSlashConfig{ + RedirectCode: http.StatusMovedPermanently, + }) + h := mw(func(c echo.Context) error { + return nil + }) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil) + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, tc.expectPath, req.URL.Path) + assert.Equal(t, tc.expectLocation, rec.Header()[echo.HeaderLocation]) + if tc.expectStatus == 0 { + assert.Equal(t, http.StatusMovedPermanently, rec.Code) + } else { + assert.Equal(t, tc.expectStatus, rec.Code) + } + }) + } } func TestRemoveTrailingSlash(t *testing.T) { - is := assert.New(t) - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/remove-slash/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := RemoveTrailingSlash()(func(c echo.Context) error { - return nil - }) - is.NoError(h(c)) - is.Equal("/remove-slash", req.URL.Path) - is.Equal("/remove-slash", req.RequestURI) - - // Method Connect must not fail: - req = httptest.NewRequest(http.MethodConnect, "", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = RemoveTrailingSlash()(func(c echo.Context) error { - return nil - }) - is.NoError(h(c)) - is.Equal("", req.URL.Path) - is.Equal("", req.RequestURI) - - // With config - req = httptest.NewRequest(http.MethodGet, "/remove-slash/?key=value", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = RemoveTrailingSlashWithConfig(TrailingSlashConfig{ - RedirectCode: http.StatusMovedPermanently, - })(func(c echo.Context) error { - return nil - }) - is.NoError(h(c)) - is.Equal(http.StatusMovedPermanently, rec.Code) - is.Equal("/remove-slash?key=value", rec.Header().Get(echo.HeaderLocation)) - - // With bare URL - req = httptest.NewRequest(http.MethodGet, "http://localhost", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = RemoveTrailingSlash()(func(c echo.Context) error { - return nil - }) - is.NoError(h(c)) - is.Equal("", req.URL.Path) + var testCases = []struct { + whenURL string + whenMethod string + expectPath string + }{ + { + whenURL: "/remove-slash/", + whenMethod: http.MethodGet, + expectPath: "/remove-slash", + }, + { + whenURL: "/remove-slash/?key=value", + whenMethod: http.MethodGet, + expectPath: "/remove-slash", + }, + { + whenURL: "/", + whenMethod: http.MethodConnect, + expectPath: "/", + }, + { + whenURL: "http://localhost", + whenMethod: http.MethodGet, + expectPath: "", + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + e := echo.New() + + h := RemoveTrailingSlash()(func(c echo.Context) error { + return nil + }) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil) + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, tc.expectPath, req.URL.Path) + assert.Equal(t, []string(nil), rec.Header()[echo.HeaderLocation]) + assert.Equal(t, http.StatusOK, rec.Code) + }) + } } From a170896c423a9dab7f3b8f38a6e25a1ef12c6076 Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Thu, 11 Feb 2021 14:54:06 +0100 Subject: [PATCH 108/446] Add CHANGELOG.md for historic tracking of changes (#1764) --- CHANGELOG.md | 73 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..33f5587f8 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,73 @@ +# Changelog + +## v4.2.0 - 2020-02-11 + +**Important notes** + +The behaviour for binding data has been reworked for compatibility with echo before v4.1.11 by +enforcing `explicit tagging` for processing parameters. This **may break** your code if you +expect combined handling of query/path/form params. +Please see the updated documentation for [request](https://echo.labstack.com/guide/request) and [binding](https://echo.labstack.com/guide/request) + +The handling for rewrite rules has been slightly adjusted to expand `*` to a non-greedy `(.*?)` capture group. This is only relevant if multiple asterisks are used in your rules. +Please see [rewrite](https://echo.labstack.com/middleware/rewrite) and [proxy](https://echo.labstack.com/middleware/proxy) for details. + +**Security** + +* Fix directory traversal vulnerability for Windows (#1718, little-cui) +* Fix open redirect vulnerability with trailing slash (#1771,#1775 aldas,GeoffreyFrogeye) + +**Enhancements** + +* Add Echo#ListenerNetwork as configuration (#1667, pafuent) +* Add ability to change the status code using response beforeFuncs (#1706, RashadAnsari) +* Echo server startup to allow data race free access to listener address +* Binder: Restore pre v4.1.11 behaviour for c.Bind() to use query params only for GET or DELETE methods (#1727, aldas) +* Binder: Add separate methods to bind only query params, path params or request body (#1681, aldas) +* Binder: New fluent binder for query/path/form parameter binding (#1717, #1736, aldas) +* Router: Performance improvements for missed routes (#1689, pafuent) +* Router: Improve performance for Real-IP detection using IndexByte instead of Split (#1640, imxyb) +* Middleware: Support real regex rules for rewrite and proxy middleware (#1767) +* Middleware: New rate limiting middleware (#1724, iambenkay) +* Middleware: New timeout middleware implementation for go1.13+ (#1743, ) +* Middleware: Allow regex pattern for CORS middleware (#1623, KlotzAndrew) +* Middleware: Add IgnoreBase parameter to static middleware (#1701, lnenad, iambenkay) +* Middleware: Add an optional custom function to CORS middleware to validate origin (#1651, curvegrid) +* Middleware: Support form fields in JWT middleware (#1704, rkfg) +* Middleware: Use sync.Pool for (de)compress middleware to improve performance (#1699, #1672, pafuent) +* Middleware: Add decompress middleware to support gzip compressed requests (#1687, arun0009) +* Middleware: Add ErrJWTInvalid for JWT middleware (#1627, juanbelieni) +* Middleware: Add SameSite mode for CSRF cookies to support iframes (#1524, pr0head) + +**Fixes** + +* Fix handling of special trailing slash case for partial prefix (#1741, stffabi) +* Fix handling of static routes with trailing slash (#1747) +* Fix Static files route not working (#1671, pwli0755, lammel) +* Fix use of caret(^) in regex for rewrite middleware (#1588, chotow) +* Fix Echo#Reverse for Any type routes (#1695, pafuent) +* Fix Router#Find panic with infinite loop (#1661, pafuent) +* Fix Router#Find panic fails on Param paths (#1659, pafuent) +* Fix DefaultHTTPErrorHandler with Debug=true (#1477, lammel) +* Fix incorrect CORS headers (#1669, ulasakdeniz) +* Fix proxy middleware rewritePath to use url with updated tests (#1630, arun0009) +* Fix rewritePath for proxy middleware to use escaped path in (#1628, arun0009) +* Remove unless defer (#1656, imxyb) + +**General** + +* New maintainers for Echo: Roland Lammel (@lammel) and Pablo Andres Fuente (@pafuent) +* Add GitHub action to compare benchmarks (#1702, pafuent) +* Binding query/path params and form fields to struct only works for explicit tags (#1729,#1734, aldas) +* Add support for Go 1.15 in CI (#1683, asahasrabuddhe) +* Add test for request id to remain unchanged if provided (#1719, iambenkay) +* Refactor echo instance listener access and startup to speed up testing (#1735, aldas) +* Refactor and improve various tests for binding and routing +* Run test workflow only for relevant changes (#1637, #1636, pofl) +* Update .travis.yml (#1662, santosh653) +* Update README.md with an recents framework benchmark (#1679, pafuent) + +This release was made possible by **over 100 commits** from more than **20 contributors**: +asahasrabuddhe, aldas, AndrewKlotz, arun0009, chotow, curvegrid, iambenkay, imxyb, +juanbelieni, lammel, little-cui, lnenad, pafuent, pofl, pr0head, pwli, RashadAnsari, +rkfg, santosh653, segfiner, stffabi, ulasakdeniz From b0f56eaf969bf5f8a6e4b6727f3f5fc7f44f36a0 Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Thu, 11 Feb 2021 19:35:16 +0100 Subject: [PATCH 109/446] Update version to v4.2.0 --- echo.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/echo.go b/echo.go index 6db485d10..7f1c83998 100644 --- a/echo.go +++ b/echo.go @@ -234,7 +234,7 @@ const ( const ( // Version of Echo - Version = "4.1.17" + Version = "4.2.0" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 6a666acd5ce75bfa79b029bab9b6e13ae0e5573b Mon Sep 17 00:00:00 2001 From: Shubhendra Singh Chauhan Date: Fri, 26 Feb 2021 15:34:34 +0530 Subject: [PATCH 110/446] improve code quality (#1792) * Merge variable declaration with assignment * Fix unnecessary typecasting on `bytes.Buffer` * Remove unnecessary wrapping of function call --- binder.go | 12 ++++-------- context_test.go | 12 ++++-------- middleware/logger_test.go | 2 +- middleware/rate_limiter.go | 4 +--- middleware/rate_limiter_test.go | 4 +--- 5 files changed, 11 insertions(+), 23 deletions(-) diff --git a/binder.go b/binder.go index 9f0ca654e..0900ce8dc 100644 --- a/binder.go +++ b/binder.go @@ -101,10 +101,8 @@ type ValueBinder struct { // QueryParamsBinder creates query parameter value binder func QueryParamsBinder(c Context) *ValueBinder { return &ValueBinder{ - failFast: true, - ValueFunc: func(sourceParam string) string { - return c.QueryParam(sourceParam) - }, + failFast: true, + ValueFunc: c.QueryParam, ValuesFunc: func(sourceParam string) []string { values, ok := c.QueryParams()[sourceParam] if !ok { @@ -119,10 +117,8 @@ func QueryParamsBinder(c Context) *ValueBinder { // PathParamsBinder creates path parameter value binder func PathParamsBinder(c Context) *ValueBinder { return &ValueBinder{ - failFast: true, - ValueFunc: func(sourceParam string) string { - return c.Param(sourceParam) - }, + failFast: true, + ValueFunc: c.Param, ValuesFunc: func(sourceParam string) []string { // path parameter should not have multiple values so getting values does not make sense but lets not error out here value := c.Param(sourceParam) diff --git a/context_test.go b/context_test.go index 417d4a749..963c91e04 100644 --- a/context_test.go +++ b/context_test.go @@ -649,8 +649,7 @@ func TestContextRedirect(t *testing.T) { } func TestContextStore(t *testing.T) { - var c Context - c = new(context) + var c Context = new(context) c.Set("name", "Jon Snow") testify.Equal(t, "Jon Snow", c.Get("name")) } @@ -687,8 +686,7 @@ func TestContextHandler(t *testing.T) { } func TestContext_SetHandler(t *testing.T) { - var c Context - c = new(context) + var c Context = new(context) testify.Nil(t, c.Handler()) @@ -701,8 +699,7 @@ func TestContext_SetHandler(t *testing.T) { func TestContext_Path(t *testing.T) { path := "/pa/th" - var c Context - c = new(context) + var c Context = new(context) c.SetPath(path) testify.Equal(t, path, c.Path()) @@ -736,8 +733,7 @@ func TestContext_QueryString(t *testing.T) { } func TestContext_Request(t *testing.T) { - var c Context - c = new(context) + var c Context = new(context) testify.Nil(t, c.Request()) diff --git a/middleware/logger_test.go b/middleware/logger_test.go index b196bc6c8..4d4515b19 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -164,7 +164,7 @@ func TestLoggerCustomTimestamp(t *testing.T) { e.ServeHTTP(rec, req) var objs map[string]*json.RawMessage - if err := json.Unmarshal([]byte(buf.String()), &objs); err != nil { + if err := json.Unmarshal(buf.Bytes(), &objs); err != nil { panic(err) } loggedTime := *(*string)(unsafe.Pointer(objs["time"])) diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index 7d1abfcb9..46a310d96 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -263,6 +263,4 @@ func (store *RateLimiterMemoryStore) cleanupStaleVisitors() { /* actual time method which is mocked in test file */ -var now = func() time.Time { - return time.Now() -} +var now = time.Now diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go index 2e57bf175..89d9a6edc 100644 --- a/middleware/rate_limiter_test.go +++ b/middleware/rate_limiter_test.go @@ -350,9 +350,7 @@ func TestRateLimiterMemoryStore_Allow(t *testing.T) { func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) - now = func() time.Time { - return time.Now() - } + now = time.Now fmt.Println(now()) inMemoryStore.visitors = map[string]*Visitor{ "A": { From 45870c75c3b3fb20edec25cb09dc14ab17df59b9 Mon Sep 17 00:00:00 2001 From: RaviKiran K Date: Sat, 27 Feb 2021 03:22:32 +0530 Subject: [PATCH 111/446] Uses strings.Equalfold (#1790) Changes case insensitive string comparisons to string.EqualFold which performs better than strings.Lower(str) == str comparison --- context.go | 2 +- middleware/basic_auth.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/context.go b/context.go index 380a52e1f..0cee48ce0 100644 --- a/context.go +++ b/context.go @@ -246,7 +246,7 @@ func (c *context) IsTLS() bool { func (c *context) IsWebSocket() bool { upgrade := c.request.Header.Get(HeaderUpgrade) - return strings.ToLower(upgrade) == "websocket" + return strings.EqualFold(upgrade, "websocket") } func (c *context) Scheme() string { diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 76ba24206..8cf1ed9fc 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -73,7 +73,7 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { auth := c.Request().Header.Get(echo.HeaderAuthorization) l := len(basic) - if len(auth) > l+1 && strings.ToLower(auth[:l]) == basic { + if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { b, err := base64.StdEncoding.DecodeString(auth[l+1:]) if err != nil { return err From d9e235416ddaa359a16c84f2f685c863e10f859c Mon Sep 17 00:00:00 2001 From: Leo Takaoka <62293842+Le0tk0k@users.noreply.github.com> Date: Sat, 27 Feb 2021 06:55:00 +0900 Subject: [PATCH 112/446] apply go fmt (#1788) --- middleware/proxy_test.go | 1 - middleware/rewrite_test.go | 16 ++++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index ec6f1925b..eb72f16ee 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -247,7 +247,6 @@ func TestProxyRewriteRegex(t *testing.T) { {"/y/foo/bar", http.StatusOK, "/v5/bar/foo"}, } - for _, tc := range testCases { t.Run(tc.requestPath, func(t *testing.T) { req.URL, _ = url.Parse(tc.requestPath) diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index 351b7313c..84006e32e 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -163,12 +163,12 @@ func TestEchoRewriteWithRegexRules(t *testing.T) { {"/y/foo/bar", "/v5/bar/foo"}, } - for _, tc := range testCases { - t.Run(tc.requestPath, func(t *testing.T) { - req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil) - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, tc.expectPath, req.URL.EscapedPath()) - }) - } + for _, tc := range testCases { + t.Run(tc.requestPath, func(t *testing.T) { + req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil) + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expectPath, req.URL.EscapedPath()) + }) + } } From c79ffed7cee29a66cdca963c5b681db503d8e18e Mon Sep 17 00:00:00 2001 From: Martti T Date: Sun, 28 Feb 2021 20:13:04 +0200 Subject: [PATCH 113/446] Fix Echo.Serve() will not serve on HTTP port correctly when there is already TLSListener set to Echo instance. (#1785) (#1793) --- Makefile | 6 +-- _fixture/certs/README.md | 13 +++++++ _fixture/certs/cert.pem | 44 ++++++++++++++-------- _fixture/certs/key.pem | 79 ++++++++++++++++++++++++++-------------- echo.go | 15 +++----- echo_test.go | 54 +++++++++++++++++++++++++++ 6 files changed, 156 insertions(+), 55 deletions(-) create mode 100644 _fixture/certs/README.md diff --git a/Makefile b/Makefile index bedb8bd25..48061f7e2 100644 --- a/Makefile +++ b/Makefile @@ -29,6 +29,6 @@ benchmark: ## Run benchmarks help: ## Display this help screen @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -goversion ?= "1.12" -test_version: ## Run tests inside Docker with given version (defaults to 1.12 oldest supported). Example: make test_version goversion=1.13 - @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make check" +goversion ?= "1.15" +test_version: ## Run tests inside Docker with given version (defaults to 1.15 oldest supported). Example: make test_version goversion=1.15 + @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" diff --git a/_fixture/certs/README.md b/_fixture/certs/README.md new file mode 100644 index 000000000..e27d4b139 --- /dev/null +++ b/_fixture/certs/README.md @@ -0,0 +1,13 @@ +To generate a valid certificate and private key use the following command: + +```bash +# In OpenSSL ≥ 1.1.1 +openssl req -x509 -newkey rsa:4096 -sha256 -days 9999 -nodes \ + -keyout key.pem -out cert.pem -subj "/CN=localhost" \ + -addext "subjectAltName=DNS:localhost,IP:127.0.0.1,IP:::1" +``` + +To check a certificate use the following command: +```bash +openssl x509 -in cert.pem -text +``` diff --git a/_fixture/certs/cert.pem b/_fixture/certs/cert.pem index c58f13fa6..d88cf3fec 100644 --- a/_fixture/certs/cert.pem +++ b/_fixture/certs/cert.pem @@ -1,18 +1,30 @@ -----BEGIN CERTIFICATE----- -MIIC+TCCAeGgAwIBAgIQe/dw9alKTWAPhsHoLdkn+TANBgkqhkiG9w0BAQsFADAS -MRAwDgYDVQQKEwdBY21lIENvMB4XDTE2MDkyNTAwNDcxN1oXDTE3MDkyNTAwNDcx -N1owEjEQMA4GA1UEChMHQWNtZSBDbzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC -AQoCggEBAL8WwhLGbK8HkiEDKV0JbjtWp3/EWKhKFW3YtKtPfPOgoZejdNn9VE0B -IlQ4rwa1wmsM9NDKC0m60oiNOYeyugx9PoFI3RXzuKVX2x7E5LTW0sv0LC9PCggZ -MZTih1AiYtwJIZl+aK6s4dTb/PUOLDdcRTZTF2egkdAicbUlQT4Kn+A3jHiE+ATC -h3MlV2BHarhAhWb0FrOg2bEtFrMyFDaLbHI7xbj+vB9CkGB9L5tObP2M9lQCxH8d -ElWkJjxg7vdkhJ5+sWNaY80utNipUdVO845tIERwRXRRviFYpOcuNfnJYC9kwRjv -CRanh3epWhG0cFQVV5d45sHf6t5F+jsCAwEAAaNLMEkwDgYDVR0PAQH/BAQDAgWg -MBMGA1UdJQQMMAoGCCsGAQUFBwMBMAwGA1UdEwEB/wQCMAAwFAYDVR0RBA0wC4IJ -bG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IBAQAdd3ZW6R4cImmxIzfoz7Ttq862 -oOiyzFnisCxgNdA78epit49zg0CgF7q9guTEArXJLI+/qnjPPObPOlTlsEyomb2F -UOS+2hn/ZyU5/tUxhkeOBYqdEaryk6zF6vPLUJ5IphJgOg00uIQGL0UvupBLEyIG -Rsa/lKEtW5Z9PbIi9GeVn51U+9VMCYft/T7SDziKl7OcE/qoVh1G0/tTRkAqOqpZ -bzc8ssEhJVNZ/DO+uYHNYf/waB6NjfXQuTegU/SyxnawvQ4oBHIzyuWplGCcTlfT -IXsOQdJo2xuu8807d+rO1FpN8yWi5OF/0sif0RrocSskLAIL/PI1qfWuuPck +MIIFODCCAyCgAwIBAgIUaTvDluaMf+VJgYHQ0HFTS3yuCHYwDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIxMDIyNzIxMzQ0MVoXDTQ4MDcx +NDIxMzQ0MVowFDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEF +AAOCAg8AMIICCgKCAgEAnqyyAAnWFH2TH7Epj5yfZxYrBvizydZe1Wo/1WpGR2IK +QT+qIul5sEKX/ERqEOXsawSrL3fw9cuSM8Z2vD/57ZZdoSR7XIdVaMDEQenJ968a +HObu4D27uBQwIwrM5ELgnd+fC4gis64nIu+2GSfHumZXi7lLW7DbNm8oWkMqI6tY +2s2wx2hwGYNVJrwSn4WGnkzhQ5U5mkcsLELMx7GR0Qnv6P7sNGZVeqMU7awkcSpR +crKR1OUP7XCJkEq83WLHSx50+QZv7LiyDmGnujHevRbdSHlcFfHZtaufYat+qICe +S3XADwRQe/0VSsmja6u3DAHy7VmL8PNisAdkopQZrhiI9OvGrpGZffs9zn+s/jeX +N1bqVDihCMiEjqXMlHx2oj3AXrZTFxb7y7Ap9C07nf70lpxQWW9SjMYRF98JBiHF +eJbQkNVkmz6T8ielQbX0l46F2SGK98oyFCGNIAZBUdj5CcS1E6w/lk4t58/em0k7 +3wFC5qg0g0wfIbNSmxljBNxnaBYUqyaaAJJhpaEoOebm4RYV58hQ0FbMfpnLnSh4 +dYStsk6i1PumWoa7D45DTtxF3kH7TB3YOB5aWaNGAPQC1m4Qcd23YB5Rd/ABirSp +ux6/cFGosjSfJ/G+G0RhNUpmcbDJvFSOhD2WCuieVhCTAzp+VPIA9bSqD+InlT0C +AwEAAaOBgTB/MB0GA1UdDgQWBBQZyM//SvzYKokQZI/0MVGb6PkH+zAfBgNVHSME +GDAWgBQZyM//SvzYKokQZI/0MVGb6PkH+zAPBgNVHRMBAf8EBTADAQH/MCwGA1Ud +EQQlMCOCCWxvY2FsaG9zdIcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG +9w0BAQsFAAOCAgEAKGAJQmQ/KLw8iMb5QsyxxAonVjJ1eDAhNM3GWdHpM0/GFamO +vVtATLQQldwDiZJvrsCQPEc8ctZ2Utvg/StLQ3+rZpsvt0+gcUlLJK61qguwYqb2 ++T7VK5s7V/OyI/tsuboOW50Pka9vQHV+Z0aM06Yu+HNDAq/UTpEOb/3MQvZd6Ooy +PTpZtFb/+5jIQa1dIsfFWmpBxF0+wUd9GEkX3j7nekwoZfJ8Ze4GWYERZbOFpDAQ +rIHdthH5VJztnpQJmaKqzgIOF+Rurwlp5ecSC33xNNjDaYtuf/fiWnoKGhHVSBhT +61+0yxn3rTgh/Dsm95xY00rSX6lmcvI+kRNTUc8GGPz0ajBH6xyY7bNhfMjmnSW/ +C/XTEDbTAhT7ndWC5vvzp7ZU0TvN+WY6A0f2kxSnnrEk6QRUvRtKkjAkmAFz8exi +ttBBW0I3E5HNIC5CYRimq/9z+3clM/P1KbNblwuC65bL+PZ+nzFnn5hFaK9eLPol +OwZQXv7IvAw8GfgLTrEUT7eBCQwe1IqesA7NTxF1BVwmNUb2XamvQZ7ly67QybRw +0uJq80XjpVjBWYTTQy1dsnC2OTKdqGsV9TVIDR+UGfIG9cxL70pEbiSH2AX+IDCy +i3kNIvpXgBliAyOjW6Hj1fv6dNfAat/hqEfnquWkfvcs3HNrG/InwpwNAUs= -----END CERTIFICATE----- diff --git a/_fixture/certs/key.pem b/_fixture/certs/key.pem index 9c75e7ca8..0276c224e 100644 --- a/_fixture/certs/key.pem +++ b/_fixture/certs/key.pem @@ -1,27 +1,52 @@ ------BEGIN RSA PRIVATE KEY----- -MIIEpAIBAAKCAQEAvxbCEsZsrweSIQMpXQluO1anf8RYqEoVbdi0q09886Chl6N0 -2f1UTQEiVDivBrXCawz00MoLSbrSiI05h7K6DH0+gUjdFfO4pVfbHsTktNbSy/Qs -L08KCBkxlOKHUCJi3AkhmX5orqzh1Nv89Q4sN1xFNlMXZ6CR0CJxtSVBPgqf4DeM -eIT4BMKHcyVXYEdquECFZvQWs6DZsS0WszIUNotscjvFuP68H0KQYH0vm05s/Yz2 -VALEfx0SVaQmPGDu92SEnn6xY1pjzS602KlR1U7zjm0gRHBFdFG+IVik5y41+clg -L2TBGO8JFqeHd6laEbRwVBVXl3jmwd/q3kX6OwIDAQABAoIBAQCR69EcAUZxinh+ -mSl3EIKK8atLGCcTrC8dCQU+ZJ7odFuxrnLHHHrJqvoKEpclqprioKw63G8uSGoJ -OL8b7tHAQ8v9ciTSZKE2Mhb0MirsJbgnYzhykAr7EDIanbny6a9Qk/CChFNwQDjc -EXnjsIT3aZC44U7YJXfz1rm6OM7Pjn6z8H4vYGRDOsYkhXvPfnPW8C2LFJVr9nvE -0gIAOVoGejEJrsJVK3Uj/nPcqSQYXmwEmtjtzOw7u6yp1b2VZEK7tR47HwJt6ltG -Z9zhpwhpvdOuXNMqMOYRf9bLBWnSqIlTHOO0UlAnyRCY1HxluZB7ZSg9VnoJDrD7 -w+JqAGnBAoGBAO5qyIzjldwR004YjepmZfuX3PnGLZhzhmTTC7Pl9gqv1TvxfxvD -6yBFL2GrN1IcnrX9Qk2xncUAbpM989MF+EC7I4++1t1I6akUKFEDkfvQwQjCXfPS -Jv2rkwIVSkt8F0X/tOb13OeIiHuFVI/Bb9VoJSP/k4DfPV+/HnwBxvzLAoGBAM0u -b/rYfm5rb20/PKClUs154s0eKSokVogqiJkf+5qLsV+TD50JVZBVw8s4XM79iwQI -PyGY9nI1AvqG7yIzxSy5/Qk1+ZVdVYpmWIO5PnJ8TVraDVhCQ3fVz1uWtcyaqPVr -3QzdyvsEgFUGFItmRdhSvA8RGrpVCHTBzrDj3jpRAoGBAKNaSLS3jkstb3D3w+yR -YliisYX1cfIdXTyhmUgWTKD/3oLmsSdt8iC3JoKt1AaPk3Kv5ojjJG0BIcIC1ZeF -ZJW9Yt0vbXpKZcYyCHmRj6lQW6JLwiG3oH133A62VaQojq2oSONiG4wL8S9oqAqj -B6PZanEiwIaw7hU3FoTylstHAoGAFYvE0pCdZjb98njrgusZcN5VxLhgFj7On2no -AjxrjWUR8TleMF1kkM2Qy+xVQp85U+kRyBNp/cA3WduFjQ/mqrW1LpxuYxL0Ap6Q -uPRg7GDFNr8jG5uJvjHDnpiK6rtq9qqnAczgnc9xMnx699B7kSXO/b4MEnkPdENN -0yF6mqECgYA88UELxbhqMSdG24DX0zHXvkXLIml2JNVb54glFByIIem+acff9oG9 -X5GajlBroPoKk7FgA9ouqcQMH66UnFi6qh07l0J2xb0aXP8yzLAGauVGTTNIQCR4 -VpqyDpjlc1ZqfZWOrvwSrUH1mEkxbeVvQsOUja2Jvu+lc3Zo099ILw== ------END RSA PRIVATE KEY----- +-----BEGIN PRIVATE KEY----- +MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQCerLIACdYUfZMf +sSmPnJ9nFisG+LPJ1l7Vaj/VakZHYgpBP6oi6XmwQpf8RGoQ5exrBKsvd/D1y5Iz +xna8P/ntll2hJHtch1VowMRB6cn3rxoc5u7gPbu4FDAjCszkQuCd358LiCKzrici +77YZJ8e6ZleLuUtbsNs2byhaQyojq1jazbDHaHAZg1UmvBKfhYaeTOFDlTmaRyws +QszHsZHRCe/o/uw0ZlV6oxTtrCRxKlFyspHU5Q/tcImQSrzdYsdLHnT5Bm/suLIO +Yae6Md69Ft1IeVwV8dm1q59hq36ogJ5LdcAPBFB7/RVKyaNrq7cMAfLtWYvw82Kw +B2SilBmuGIj068aukZl9+z3Of6z+N5c3VupUOKEIyISOpcyUfHaiPcBetlMXFvvL +sCn0LTud/vSWnFBZb1KMxhEX3wkGIcV4ltCQ1WSbPpPyJ6VBtfSXjoXZIYr3yjIU +IY0gBkFR2PkJxLUTrD+WTi3nz96bSTvfAULmqDSDTB8hs1KbGWME3GdoFhSrJpoA +kmGloSg55ubhFhXnyFDQVsx+mcudKHh1hK2yTqLU+6ZahrsPjkNO3EXeQftMHdg4 +HlpZo0YA9ALWbhBx3bdgHlF38AGKtKm7Hr9wUaiyNJ8n8b4bRGE1SmZxsMm8VI6E +PZYK6J5WEJMDOn5U8gD1tKoP4ieVPQIDAQABAoICAEHF2CsH6MOpofi7GT08cR7s +I33KTcxWngzc9ATk/qjMTO/rEf1Sxmx3zkR1n3nNtQhPcR5GG43nin0HwWQbKOCB +OeJ4GuKp/o9jiHbCEEQpQyvD1jUBofSV+bYs3e2ogy8t6OGA1tGgWPy0XMlkoff0 +QEnczw3864FO5m0z9h2/Ax//r02ZTw5kUEG0KAwT709jEuVO0AfRhM/8CKKmSola +EyaDtSmrWbdyLlSuzJRUNFrVBno3UTjdM0iqkks6jN3ojBhFwNNhY/1uIXafAXNk +LOnD1JYMIHCb6X809VWnqvYgozIWWb5rlA3iM2mITmId1LLqMYX5fWj2R5LUzSek +H+XG+F9FIouTaL1ACoXr0zyeY5N5YJdyXYa1tThdW+axX9ZrnPgeiQrmxzKPIyb7 +LLlVtNBQUg/t5tX80KyYjkNUu4j3oq/uBYPi0m//ovwMyi9bSbbyPT+cDXuXX5Bc +oY7wyn3evXX0c1R7vdJLZLkLu+ctVex/9hvMjeW/mMasDjLnqY7pF3Skct1SX5N2 +U8YVU9bGvFpLEwM9lmi/T7bcv+zbmGPlfTsZiFrCsixPLn7sX7y5M4L8au8O0jh0 +nHm/8rWVg1Qw0Hobg3tA8FjeMa8Sr2fYmkNLVKFzhuJLxknTJLaUbX5CymNqWP4H +OctvfSY0nSZ1eQpBkQaJAoIBAQDTb/NhYCfaJBLXHVMy/VYd7kWGZ+I87artcE/l +8u0pJ8XOP4kp0otFIumpHUFodysAeP6HrI79MuJB40fy91HzWZC+NrPufFFFuZ0z +Ld1o3Y5nAeoZmMlf1F12Oe3OQZy7nm9eNNkfeoVtKqDv4FhAqk+aoMor86HscKsR +C6HlZFdGc7kX0ylrQAXPq9KLhcvUU9oAUpbqTbhYK83IebRJgFDG45HkVo9SUHpF +dmCFSb91eZpRGpdfNLCuLiSu52TebayaUCnceeAt8SyeiChJ/TwWmRRDJS0QUv6h +s3Wdp+cx9ANoujA4XzAs8Fld5IZ4bcG5jjwD62/tJyWrCC5DAoIBAQDAHfHjrYCK +GHBrMj+MA7cK7fCJUn/iJLSLGgo2ANYF5oq9gaCwHCtKIyB9DN/KiY0JpJ6PWg+Q +9Difq23YXiJjNEBS5EFTu9UwWAr1RhSAegrfHxm0sDbcAx31NtDYvBsADCWQYmzc +KPfBshf5K4g/VCIj2VzC2CE6kNtdhqLU6AV2Pi1Tl1S82xWoAjHy91tDmlFQNWCj +B2ZnZ7tY9zuwDfeBBOVCPHICgl5Q4PrY1KEWEXiNxgbtkNmOPAsY9WSqgOsP9pWK +J924gdCCvovINzZtgRisxKth6Fkhra+VCsheg9SWvgR09Deo6CCoSwYxOSb0cjh2 +oyX5Rb1kJ7Z/AoIBAQCX2iNVoBV/GcFeNXV3fXLH9ESCj0FwuNC1zp/TanDhyerK +gd8k5k2Xzcc66gP73vpHUJ6dGlVni4/r+ivGV9HHkF/f/LGlaiuEhBZel2YY1mZb +nIhg8dZOuNqW+mvMYlsKdHNPmW0GqpwBF0iWfu1jI+4gA7Kvdj6o7RIvH8eaVEJK +GvqoHcP1fvmteJ2yDtmhGMfMy4QPqtnmmS8l+CJ/V2SsMuyorXIpkBsAoFAZ6ilT +WY53CT4F5nWt4v39j7pl9SatfT1TV0SmOjvtb6Rf3zu0jyR6RMzkmHa/839ZRylI +OxPntzDCi7qxy7yjLmlVPJ6RgZGgzwqHrEHlX+65AoIBAQCEzu6d3x5B2N02LZli +eFr8MjqbI64GLiulEY5HgNJzZ8k3cjocJI0Ehj36VIEMaYRXSzbVkIO8SCgwsPiR +n5mUDNX+t441jV62Odbxcc3Qdw226rABieOSupDmKEu92GOt57e8FV5939BOVYhf +FunsJYQoViXbCEAIVYVgJSfBmNfVwuvgonfQyn8xErtm4/pyRGa71PqGGSKAj2Qi +/16CuVUFGtZFsLV76JW8wZqHdI4bTF6TW3cEmaLbwcRGL7W0bMSS13rO8/pBh3QW +PhUxhoGYt6rQHHEBkPa04nXDyZ10QRwgTSGVnBIyMK4KyTpxorm8OI2x7dzdcomX +iCCPAoIBAETwfr2JKPb/AzrKhhbZgU+sLVn3WH/nb68VheNEmGOzsqXaSHCR2NOq +/ow7bawjc8yUIhBRzokR4F/7jGolOmfdq0MYFb6/YokssKfv1ugxBhmvOxpZ6F6E +cERJ8Ex/ffQU053gLR/0ammddVuS1GR5I/jEdP0lJVh0xapoZNUlT5dWYCgo20hY +ZAmKpU+veyUn+5Li0pmm959vnLK5LJzEA5mpz3w1QPPtVwQs05dwmEV3CRAcCeeh +8sXp49WNCSW4I3BxuTZzRV845SGIFhZwgVV42PTp2LPKl2p6E7Bk8xpUCCvBpALp +QmA5yIMx+u2Jpr7fUsXEXEPTEhvjff0= +-----END PRIVATE KEY----- diff --git a/echo.go b/echo.go index 7f1c83998..1074ba492 100644 --- a/echo.go +++ b/echo.go @@ -660,7 +660,7 @@ func (e *Echo) Start(address string) error { return err } e.startupMutex.Unlock() - return e.serve() + return e.Server.Serve(e.Listener) } // StartTLS starts an HTTPS server. @@ -740,8 +740,12 @@ func (e *Echo) StartServer(s *http.Server) (err error) { e.startupMutex.Unlock() return err } + if s.TLSConfig != nil { + e.startupMutex.Unlock() + return s.Serve(e.TLSListener) + } e.startupMutex.Unlock() - return e.serve() + return s.Serve(e.Listener) } func (e *Echo) configureServer(s *http.Server) (err error) { @@ -782,13 +786,6 @@ func (e *Echo) configureServer(s *http.Server) (err error) { return nil } -func (e *Echo) serve() error { - if e.TLSListener != nil { - return e.Server.Serve(e.TLSListener) - } - return e.Server.Serve(e.Listener) -} - // ListenerAddr returns net.Addr for Listener func (e *Echo) ListenerAddr() net.Addr { e.startupMutex.RLock() diff --git a/echo_test.go b/echo_test.go index 781b901fa..07661b9f8 100644 --- a/echo_test.go +++ b/echo_test.go @@ -684,6 +684,60 @@ func TestEcho_StartTLS(t *testing.T) { } } +func TestEchoStartTLSAndStart(t *testing.T) { + // We test if Echo and listeners work correctly when Echo is simultaneously attached to HTTP and HTTPS server + e := New() + e.GET("/", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + + errTLSChan := make(chan error) + go func() { + certFile := "_fixture/certs/cert.pem" + keyFile := "_fixture/certs/key.pem" + err := e.StartTLS("localhost:", certFile, keyFile) + if err != nil { + errTLSChan <- err + } + }() + + err := waitForServerStart(e, errTLSChan, true) + assert.NoError(t, err) + defer func() { + if err := e.Shutdown(stdContext.Background()); err != nil { + t.Error(err) + } + }() + + // check if HTTPS works (note: we are using self signed certs so InsecureSkipVerify=true) + client := &http.Client{Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }} + res, err := client.Get("https://" + e.TLSListenerAddr().String()) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) + + errChan := make(chan error) + go func() { + err := e.Start("localhost:") + if err != nil { + errChan <- err + } + }() + err = waitForServerStart(e, errChan, false) + assert.NoError(t, err) + + // now we are serving both HTTPS and HTTP listeners. see if HTTP works in addition to HTTPS + res, err = http.Get("http://" + e.ListenerAddr().String()) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) + + // see if HTTPS works after HTTP listener is also added + res, err = client.Get("https://" + e.TLSListenerAddr().String()) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) +} + func TestEchoStartTLSByteString(t *testing.T) { cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") require.NoError(t, err) From b2444d8399562930188e3e16c3f3937d75a78e42 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Mon, 1 Mar 2021 23:45:18 +0200 Subject: [PATCH 114/446] Fix #1794: panics in timeout middleware are not recovered and cause application to crash --- middleware/timeout.go | 12 ++++++++++++ middleware/timeout_test.go | 19 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/middleware/timeout.go b/middleware/timeout.go index d146541e6..4be557f76 100644 --- a/middleware/timeout.go +++ b/middleware/timeout.go @@ -4,6 +4,7 @@ package middleware import ( "context" + "fmt" "github.com/labstack/echo/v4" "time" ) @@ -62,6 +63,17 @@ func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc { done := make(chan error, 1) go func() { + defer func() { + if r := recover(); r != nil { + err, ok := r.(error) + if !ok { + err = fmt.Errorf("panic recovered in timeout middleware: %v", r) + } + c.Logger().Error(err) + done <- err + } + }() + // This goroutine will keep running even if this middleware times out and // will be stopped when ctx.Done() is called down the next(c) call chain done <- next(c) diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index c0e945933..faecc4c53 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -175,3 +175,22 @@ func TestTimeoutTestRequestClone(t *testing.T) { assert.NoError(t, err) } + +func TestTimeoutRecoversPanic(t *testing.T) { + t.Parallel() + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: 25 * time.Millisecond, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + panic("panic in handler") + })(c) + + assert.Error(t, err, "panic recovered in timeout middleware: panic in handler") +} From 6f9b71cd6fcbe51943e123298e1579c13e77b898 Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 2 Mar 2021 20:56:40 +0200 Subject: [PATCH 115/446] Poc router stack backtracking (#1791) Router: PoC stack based backtracking Co-authored-by: stffabi --- router.go | 194 ++++++++++++++++++++----------------------------- router_test.go | 176 +++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 243 insertions(+), 127 deletions(-) diff --git a/router.go b/router.go index 5010659a6..f0e9e51f4 100644 --- a/router.go +++ b/router.go @@ -2,7 +2,6 @@ package echo import ( "net/http" - "strings" ) type ( @@ -334,21 +333,48 @@ func (r *Router) Find(method, path string, c Context) { cn := r.tree // Current node as root var ( - search = path - child *node // Child node - n int // Param counter - nk kind // Next kind - nn *node // Next node - ns string // Next search - pvalues = ctx.pvalues // Use the internal slice so the interface can keep the illusion of a dynamic slice + search = path + searchIndex = 0 + n int // Param counter + pvalues = ctx.pvalues // Use the internal slice so the interface can keep the illusion of a dynamic slice ) - // Search order static > param > any - for { - if search == "" { - break + // Backtracking is needed when a dead end (leaf node) is reached in the router tree. + // To backtrack the current node will be changed to the parent node and the next kind for the + // router logic will be returned based on fromKind or kind of the dead end node (static > param > any). + // For example if there is no static node match we should check parent next sibling by kind (param). + // Backtracking itself does not check if there is a next sibling, this is done by the router logic. + backtrackToNextNodeKind := func(fromKind kind) (nextNodeKind kind, valid bool) { + previous := cn + cn = previous.parent + valid = cn != nil + + // Next node type by priority + // NOTE: With the current implementation we never backtrack from an `any` route, so `previous.kind` is + // always `static` or `any` + // If this is changed then for any route next kind would be `static` and this statement should be changed + nextNodeKind = previous.kind + 1 + + if fromKind == skind { + // when backtracking is done from static kind block we did not change search so nothing to restore + return } + // restore search to value it was before we move to current node we are backtracking from. + if previous.kind == skind { + searchIndex -= len(previous.prefix) + } else { + n-- + // for param/any node.prefix value is always `:` so we can not deduce searchIndex from that and must use pValue + // for that index as it would also contain part of path we cut off before moving into node we are backtracking from + searchIndex -= len(pvalues[n]) + } + search = path[searchIndex:] + return + } + + // Search order static > param > any + for { pl := 0 // Prefix length l := 0 // LCP length @@ -365,60 +391,42 @@ func (r *Router) Find(method, path string, c Context) { } } - if l == pl { - // Continue search - search = search[l:] - // Finish routing if no remaining search and we are on an leaf node - if search == "" && (nn == nil || cn.parent == nil || cn.ppath != "") { - break - } - // Handle special case of trailing slash route with existing any route (see #1526) - if search == "" && path[len(path)-1] == '/' && cn.anyChildren != nil { - goto Any + if l != pl { + // No matching prefix, let's backtrack to the first possible alternative node of the decision path + nk, ok := backtrackToNextNodeKind(skind) + if !ok { + return // No other possibilities on the decision path + } else if nk == pkind { + goto Param + // NOTE: this case (backtracking from static node to previous any node) can not happen by current any matching logic. Any node is end of search currently + //} else if nk == akind { + // goto Any + } else { + // Not found (this should never be possible for static node we are looking currently) + return } } - // Attempt to go back up the tree on no matching prefix or no remaining search - if l != pl || search == "" { - if nn == nil { // Issue #1348 - return // Not found - } - cn = nn - search = ns - if nk == pkind { - goto Param - } else if nk == akind { - goto Any - } + // The full prefix has matched, remove the prefix from the remaining search + search = search[l:] + searchIndex = searchIndex + l + + // Finish routing if no remaining search and we are on an leaf node + if search == "" && cn.ppath != "" { + break } // Static node - if child = cn.findStaticChild(search[0]); child != nil { - // Save next - if cn.prefix[len(cn.prefix)-1] == '/' { // Issue #623 - nk = pkind - nn = cn - ns = search + if search != "" { + if child := cn.findStaticChild(search[0]); child != nil { + cn = child + continue } - cn = child - continue } Param: // Param node - if child = cn.paramChildren; child != nil { - // Issue #378 - if len(pvalues) == n { - continue - } - - // Save next - if cn.prefix[len(cn.prefix)-1] == '/' { // Issue #623 - nk = akind - nn = cn - ns = search - } - + if child := cn.paramChildren; search != "" && child != nil { cn = child i, l := 0, len(search) for ; i < l && search[i] != '/'; i++ { @@ -426,87 +434,39 @@ func (r *Router) Find(method, path string, c Context) { pvalues[n] = search[:i] n++ search = search[i:] + searchIndex = searchIndex + i continue } Any: // Any node - if cn = cn.anyChildren; cn != nil { + if child := cn.anyChildren; child != nil { // If any node is found, use remaining path for pvalues + cn = child pvalues[len(cn.pnames)-1] = search break } - // No node found, continue at stored next node - // or find nearest "any" route - if nn != nil { - // No next node to go down in routing (issue #954) - // Find nearest "any" route going up the routing tree - search = ns - np := nn.parent - // Consider param route one level up only - if cn = nn.paramChildren; cn != nil { - pos := strings.IndexByte(ns, '/') - if pos == -1 { - // If no slash is remaining in search string set param value - if len(cn.pnames) > 0 { - pvalues[len(cn.pnames)-1] = search - } - break - } else if pos > 0 { - // Otherwise continue route processing with restored next node - cn = nn - nn = nil - ns = "" - goto Param - } - } - // No param route found, try to resolve nearest any route - for { - np = nn.parent - if cn = nn.anyChildren; cn != nil { - break - } - if np == nil { - break // no further parent nodes in tree, abort - } - var str strings.Builder - str.WriteString(nn.prefix) - str.WriteString(search) - search = str.String() - nn = np - } - if cn != nil { // use the found "any" route and update path - pvalues[len(cn.pnames)-1] = search - break - } + // Let's backtrack to the first possible alternative node of the decision path + nk, ok := backtrackToNextNodeKind(akind) + if !ok { + return // No other possibilities on the decision path + } else if nk == pkind { + goto Param + } else if nk == akind { + goto Any + } else { + // Not found + return } - return // Not found - } ctx.handler = cn.findHandler(method) ctx.path = cn.ppath ctx.pnames = cn.pnames - // NOTE: Slow zone... if ctx.handler == nil { ctx.handler = cn.checkMethodNotAllowed() - - // Dig further for any, might have an empty value for *, e.g. - // serving a directory. Issue #207. - if cn = cn.anyChildren; cn == nil { - return - } - if h := cn.findHandler(method); h != nil { - ctx.handler = h - } else { - ctx.handler = cn.checkMethodNotAllowed() - } - ctx.path = cn.ppath - ctx.pnames = cn.pnames - pvalues[len(cn.pnames)-1] = "" } - return } diff --git a/router_test.go b/router_test.go index a5e53c05b..ba1890bd1 100644 --- a/router_test.go +++ b/router_test.go @@ -730,26 +730,58 @@ func TestRouterMatchAny(t *testing.T) { r := e.router // Routes - r.Add(http.MethodGet, "/", func(Context) error { - return nil - }) - r.Add(http.MethodGet, "/*", func(Context) error { - return nil - }) - r.Add(http.MethodGet, "/users/*", func(Context) error { - return nil - }) + r.Add(http.MethodGet, "/", handlerHelper("case", 1)) + r.Add(http.MethodGet, "/*", handlerHelper("case", 2)) + r.Add(http.MethodGet, "/users/*", handlerHelper("case", 3)) + c := e.NewContext(nil, nil).(*context) r.Find(http.MethodGet, "/", c) - assert.Equal(t, "", c.Param("*")) + c.handler(c) + + assert.Equal(t, 1, c.Get("case")) + assert.Equal(t, "/", c.Get("path")) r.Find(http.MethodGet, "/download", c) + c.handler(c) + assert.Equal(t, 2, c.Get("case")) + assert.Equal(t, "/*", c.Get("path")) assert.Equal(t, "download", c.Param("*")) r.Find(http.MethodGet, "/users/joe", c) + c.handler(c) + assert.Equal(t, 3, c.Get("case")) + assert.Equal(t, "/users/*", c.Get("path")) assert.Equal(t, "joe", c.Param("*")) } +// NOTE: this is to document current implementation. Last added route with `*` asterisk is always the match and no +// backtracking or more precise matching is done to find more suitable match. +// +// Current behaviour might not be correct or expected. +// But this is where we are without well defined requirements/rules how (multiple) asterisks work in route +func TestRouterAnyMatchesLastAddedAnyRoute(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/users/*", handlerHelper("case", 1)) + r.Add(http.MethodGet, "/users/*/action*", handlerHelper("case", 2)) + + c := e.NewContext(nil, nil).(*context) + + r.Find(http.MethodGet, "/users/xxx/action/sea", c) + c.handler(c) + assert.Equal(t, "/users/*/action*", c.Get("path")) + assert.Equal(t, "xxx/action/sea", c.Param("*")) + + // if we add another route then it is the last added and so it is matched + r.Add(http.MethodGet, "/users/*/action/search", handlerHelper("case", 3)) + + r.Find(http.MethodGet, "/users/xxx/action/sea", c) + c.handler(c) + assert.Equal(t, "/users/*/action/search", c.Get("path")) + assert.Equal(t, "xxx/action/sea", c.Param("*")) +} + // Issue #1739 func TestRouterMatchAnyPrefixIssue(t *testing.T) { e := New() @@ -791,6 +823,130 @@ func TestRouterMatchAnyPrefixIssue(t *testing.T) { assert.Equal(t, "users_prefix/", c.Param("*")) } +func TestRouteMultiLevelBacktracking(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/a/:b/c", handlerHelper("case", 1)) + r.Add(http.MethodGet, "/a/c/d", handlerHelper("case", 2)) + r.Add(http.MethodGet, "/:e/c/f", handlerHelper("case", 3)) + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/a/c/f", c) + + c.handler(c) + assert.Equal(t, 3, c.Get("case")) + assert.Equal(t, "/:e/c/f", c.Get("path")) +} + +// Issue # +func TestRouterBacktrackingFromParam(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/*", handlerHelper("case", 1)) + r.Add(http.MethodGet, "/users/:name/", handlerHelper("case", 2)) + + c := e.NewContext(nil, nil).(*context) + + r.Find(http.MethodGet, "/users/firstname/no-match", c) + c.handler(c) + assert.Equal(t, 1, c.Get("case")) + assert.Equal(t, "/*", c.Get("path")) + assert.Equal(t, "users/firstname/no-match", c.Param("*")) + + r.Find(http.MethodGet, "/users/firstname/", c) + c.handler(c) + assert.Equal(t, 2, c.Get("case")) + assert.Equal(t, "/users/:name/", c.Get("path")) + assert.Equal(t, "firstname", c.Param("name")) +} + +func TestRouterBacktrackingFromParamAny(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/*", handlerHelper("case", 1)) + r.Add(http.MethodGet, "/:name/lastname", handlerHelper("case", 2)) + + c := e.NewContext(nil, nil).(*context) + + r.Find(http.MethodGet, "/firstname/test", c) + c.handler(c) + assert.Equal(t, 1, c.Get("case")) + assert.Equal(t, "/*", c.Get("path")) + assert.Equal(t, "firstname/test", c.Param("*")) + + r.Find(http.MethodGet, "/firstname", c) + c.handler(c) + assert.Equal(t, 1, c.Get("case")) + assert.Equal(t, "/*", c.Get("path")) + assert.Equal(t, "firstname", c.Param("*")) + + r.Find(http.MethodGet, "/firstname/lastname", c) + c.handler(c) + assert.Equal(t, 2, c.Get("case")) + assert.Equal(t, "/:name/lastname", c.Get("path")) + assert.Equal(t, "firstname", c.Param("name")) +} + +func TestRouterBacktrackingFromParamAny2(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/*", handlerHelper("case", 1)) + r.Add(http.MethodGet, "/:name", handlerHelper("case", 2)) + r.Add(http.MethodGet, "/:name/lastname", handlerHelper("case", 3)) + + c := e.NewContext(nil, nil).(*context) + + r.Find(http.MethodGet, "/firstname/test", c) + c.handler(c) + assert.Equal(t, 1, c.Get("case")) + assert.Equal(t, "/*", c.Get("path")) + assert.Equal(t, "firstname/test", c.Param("*")) + + r.Find(http.MethodGet, "/firstname", c) + c.handler(c) + assert.Equal(t, 2, c.Get("case")) + assert.Equal(t, "/:name", c.Get("path")) + assert.Equal(t, "firstname", c.Param("name")) + + r.Find(http.MethodGet, "/firstname/lastname", c) + c.handler(c) + assert.Equal(t, 3, c.Get("case")) + assert.Equal(t, "/:name/lastname", c.Get("path")) + assert.Equal(t, "firstname", c.Param("name")) +} + +func TestRouterAnyCommonPath(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/ab*", handlerHelper("case", 1)) + r.Add(http.MethodGet, "/abcd", handlerHelper("case", 2)) + r.Add(http.MethodGet, "/abcd*", handlerHelper("case", 3)) + + c := e.NewContext(nil, nil).(*context) + + r.Find(http.MethodGet, "/abee", c) + c.handler(c) + assert.Equal(t, 1, c.Get("case")) + assert.Equal(t, "/ab*", c.Get("path")) + assert.Equal(t, "ee", c.Param("*")) + + r.Find(http.MethodGet, "/abcd", c) + c.handler(c) + assert.Equal(t, "/abcd", c.Get("path")) + assert.Equal(t, 2, c.Get("case")) + + r.Find(http.MethodGet, "/abcde", c) + c.handler(c) + assert.Equal(t, 3, c.Get("case")) + assert.Equal(t, "/abcd*", c.Get("path")) + assert.Equal(t, "e", c.Param("*")) +} + // TestRouterMatchAnySlash shall verify finding the best route // for any routes with trailing slash requests func TestRouterMatchAnySlash(t *testing.T) { From 664cf8c1060bb9fe97de6e7f13bd65fbaf292402 Mon Sep 17 00:00:00 2001 From: Martti T Date: Sat, 6 Mar 2021 01:43:59 +0200 Subject: [PATCH 116/446] Refactor router for readability (#1796) * refactor router tests to table driven (this way it is easier to debug test cases with breakpoints) * refactor router variables to be more readable --- router.go | 305 ++++---- router_test.go | 1869 ++++++++++++++++++++++++++++++------------------ 2 files changed, 1328 insertions(+), 846 deletions(-) diff --git a/router.go b/router.go index f0e9e51f4..2dd09fae2 100644 --- a/router.go +++ b/router.go @@ -13,16 +13,16 @@ type ( echo *Echo } node struct { - kind kind - label byte - prefix string - parent *node - staticChildrens children - ppath string - pnames []string - methodHandler *methodHandler - paramChildren *node - anyChildren *node + kind kind + label byte + prefix string + parent *node + staticChildren children + ppath string + pnames []string + methodHandler *methodHandler + paramChild *node + anyChild *node } kind uint8 children []*node @@ -42,9 +42,9 @@ type ( ) const ( - skind kind = iota - pkind - akind + staticKind kind = iota + paramKind + anyKind paramLabel = byte(':') anyLabel = byte('*') @@ -73,137 +73,147 @@ func (r *Router) Add(method, path string, h HandlerFunc) { pnames := []string{} // Param names ppath := path // Pristine path - for i, l := 0, len(path); i < l; i++ { + for i, lcpIndex := 0, len(path); i < lcpIndex; i++ { if path[i] == ':' { j := i + 1 - r.insert(method, path[:i], nil, skind, "", nil) - for ; i < l && path[i] != '/'; i++ { + r.insert(method, path[:i], nil, staticKind, "", nil) + for ; i < lcpIndex && path[i] != '/'; i++ { } pnames = append(pnames, path[j:i]) path = path[:j] + path[i:] - i, l = j, len(path) + i, lcpIndex = j, len(path) - if i == l { - r.insert(method, path[:i], h, pkind, ppath, pnames) + if i == lcpIndex { + r.insert(method, path[:i], h, paramKind, ppath, pnames) } else { - r.insert(method, path[:i], nil, pkind, "", nil) + r.insert(method, path[:i], nil, paramKind, "", nil) } } else if path[i] == '*' { - r.insert(method, path[:i], nil, skind, "", nil) + r.insert(method, path[:i], nil, staticKind, "", nil) pnames = append(pnames, "*") - r.insert(method, path[:i+1], h, akind, ppath, pnames) + r.insert(method, path[:i+1], h, anyKind, ppath, pnames) } } - r.insert(method, path, h, skind, ppath, pnames) + r.insert(method, path, h, staticKind, ppath, pnames) } func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string, pnames []string) { // Adjust max param - l := len(pnames) - if *r.echo.maxParam < l { - *r.echo.maxParam = l + paramLen := len(pnames) + if *r.echo.maxParam < paramLen { + *r.echo.maxParam = paramLen } - cn := r.tree // Current node as root - if cn == nil { + currentNode := r.tree // Current node as root + if currentNode == nil { panic("echo: invalid method") } search := path for { - sl := len(search) - pl := len(cn.prefix) - l := 0 - - // LCP - max := pl - if sl < max { - max = sl + searchLen := len(search) + prefixLen := len(currentNode.prefix) + lcpLen := 0 + + // LCP - Longest Common Prefix (https://en.wikipedia.org/wiki/LCP_array) + max := prefixLen + if searchLen < max { + max = searchLen } - for ; l < max && search[l] == cn.prefix[l]; l++ { + for ; lcpLen < max && search[lcpLen] == currentNode.prefix[lcpLen]; lcpLen++ { } - if l == 0 { + if lcpLen == 0 { // At root node - cn.label = search[0] - cn.prefix = search + currentNode.label = search[0] + currentNode.prefix = search if h != nil { - cn.kind = t - cn.addHandler(method, h) - cn.ppath = ppath - cn.pnames = pnames + currentNode.kind = t + currentNode.addHandler(method, h) + currentNode.ppath = ppath + currentNode.pnames = pnames } - } else if l < pl { + } else if lcpLen < prefixLen { // Split node - n := newNode(cn.kind, cn.prefix[l:], cn, cn.staticChildrens, cn.methodHandler, cn.ppath, cn.pnames, cn.paramChildren, cn.anyChildren) + n := newNode( + currentNode.kind, + currentNode.prefix[lcpLen:], + currentNode, + currentNode.staticChildren, + currentNode.methodHandler, + currentNode.ppath, + currentNode.pnames, + currentNode.paramChild, + currentNode.anyChild, + ) // Update parent path for all children to new node - for _, child := range cn.staticChildrens { + for _, child := range currentNode.staticChildren { child.parent = n } - if cn.paramChildren != nil { - cn.paramChildren.parent = n + if currentNode.paramChild != nil { + currentNode.paramChild.parent = n } - if cn.anyChildren != nil { - cn.anyChildren.parent = n + if currentNode.anyChild != nil { + currentNode.anyChild.parent = n } // Reset parent node - cn.kind = skind - cn.label = cn.prefix[0] - cn.prefix = cn.prefix[:l] - cn.staticChildrens = nil - cn.methodHandler = new(methodHandler) - cn.ppath = "" - cn.pnames = nil - cn.paramChildren = nil - cn.anyChildren = nil + currentNode.kind = staticKind + currentNode.label = currentNode.prefix[0] + currentNode.prefix = currentNode.prefix[:lcpLen] + currentNode.staticChildren = nil + currentNode.methodHandler = new(methodHandler) + currentNode.ppath = "" + currentNode.pnames = nil + currentNode.paramChild = nil + currentNode.anyChild = nil // Only Static children could reach here - cn.addStaticChild(n) + currentNode.addStaticChild(n) - if l == sl { + if lcpLen == searchLen { // At parent node - cn.kind = t - cn.addHandler(method, h) - cn.ppath = ppath - cn.pnames = pnames + currentNode.kind = t + currentNode.addHandler(method, h) + currentNode.ppath = ppath + currentNode.pnames = pnames } else { // Create child node - n = newNode(t, search[l:], cn, nil, new(methodHandler), ppath, pnames, nil, nil) + n = newNode(t, search[lcpLen:], currentNode, nil, new(methodHandler), ppath, pnames, nil, nil) n.addHandler(method, h) // Only Static children could reach here - cn.addStaticChild(n) + currentNode.addStaticChild(n) } - } else if l < sl { - search = search[l:] - c := cn.findChildWithLabel(search[0]) + } else if lcpLen < searchLen { + search = search[lcpLen:] + c := currentNode.findChildWithLabel(search[0]) if c != nil { // Go deeper - cn = c + currentNode = c continue } // Create child node - n := newNode(t, search, cn, nil, new(methodHandler), ppath, pnames, nil, nil) + n := newNode(t, search, currentNode, nil, new(methodHandler), ppath, pnames, nil, nil) n.addHandler(method, h) switch t { - case skind: - cn.addStaticChild(n) - case pkind: - cn.paramChildren = n - case akind: - cn.anyChildren = n + case staticKind: + currentNode.addStaticChild(n) + case paramKind: + currentNode.paramChild = n + case anyKind: + currentNode.anyChild = n } } else { // Node already exists if h != nil { - cn.addHandler(method, h) - cn.ppath = ppath - if len(cn.pnames) == 0 { // Issue #729 - cn.pnames = pnames + currentNode.addHandler(method, h) + currentNode.ppath = ppath + if len(currentNode.pnames) == 0 { // Issue #729 + currentNode.pnames = pnames } } } @@ -213,25 +223,25 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string func newNode(t kind, pre string, p *node, sc children, mh *methodHandler, ppath string, pnames []string, paramChildren, anyChildren *node) *node { return &node{ - kind: t, - label: pre[0], - prefix: pre, - parent: p, - staticChildrens: sc, - ppath: ppath, - pnames: pnames, - methodHandler: mh, - paramChildren: paramChildren, - anyChildren: anyChildren, + kind: t, + label: pre[0], + prefix: pre, + parent: p, + staticChildren: sc, + ppath: ppath, + pnames: pnames, + methodHandler: mh, + paramChild: paramChildren, + anyChild: anyChildren, } } func (n *node) addStaticChild(c *node) { - n.staticChildrens = append(n.staticChildrens, c) + n.staticChildren = append(n.staticChildren, c) } func (n *node) findStaticChild(l byte) *node { - for _, c := range n.staticChildrens { + for _, c := range n.staticChildren { if c.label == l { return c } @@ -240,16 +250,16 @@ func (n *node) findStaticChild(l byte) *node { } func (n *node) findChildWithLabel(l byte) *node { - for _, c := range n.staticChildrens { + for _, c := range n.staticChildren { if c.label == l { return c } } if l == paramLabel { - return n.paramChildren + return n.paramChild } if l == anyLabel { - return n.anyChildren + return n.anyChild } return nil } @@ -330,13 +340,15 @@ func (n *node) checkMethodNotAllowed() HandlerFunc { func (r *Router) Find(method, path string, c Context) { ctx := c.(*context) ctx.path = path - cn := r.tree // Current node as root + currentNode := r.tree // Current node as root var ( + // search stores the remaining path to check for match. By each iteration we move from start of path to end of the path + // and search value gets shorter and shorter. search = path searchIndex = 0 - n int // Param counter - pvalues = ctx.pvalues // Use the internal slice so the interface can keep the illusion of a dynamic slice + paramIndex int // Param counter + paramValues = ctx.pvalues // Use the internal slice so the interface can keep the illusion of a dynamic slice ) // Backtracking is needed when a dead end (leaf node) is reached in the router tree. @@ -345,9 +357,9 @@ func (r *Router) Find(method, path string, c Context) { // For example if there is no static node match we should check parent next sibling by kind (param). // Backtracking itself does not check if there is a next sibling, this is done by the router logic. backtrackToNextNodeKind := func(fromKind kind) (nextNodeKind kind, valid bool) { - previous := cn - cn = previous.parent - valid = cn != nil + previous := currentNode + currentNode = previous.parent + valid = currentNode != nil // Next node type by priority // NOTE: With the current implementation we never backtrack from an `any` route, so `previous.kind` is @@ -355,51 +367,57 @@ func (r *Router) Find(method, path string, c Context) { // If this is changed then for any route next kind would be `static` and this statement should be changed nextNodeKind = previous.kind + 1 - if fromKind == skind { + if fromKind == staticKind { // when backtracking is done from static kind block we did not change search so nothing to restore return } // restore search to value it was before we move to current node we are backtracking from. - if previous.kind == skind { + if previous.kind == staticKind { searchIndex -= len(previous.prefix) } else { - n-- + paramIndex-- // for param/any node.prefix value is always `:` so we can not deduce searchIndex from that and must use pValue // for that index as it would also contain part of path we cut off before moving into node we are backtracking from - searchIndex -= len(pvalues[n]) + searchIndex -= len(paramValues[paramIndex]) } search = path[searchIndex:] return } - // Search order static > param > any + // Router tree is implemented by longest common prefix array (LCP array) https://en.wikipedia.org/wiki/LCP_array + // Tree search is implemented as for loop where one loop iteration is divided into 3 separate blocks + // Each of these blocks checks specific kind of node (static/param/any). Order of blocks reflex their priority in routing. + // Search order/priority is: static > param > any. + // + // Note: backtracking in tree is implemented by replacing/switching currentNode to previous node + // and hoping to (goto statement) next block by priority to check if it is the match. for { - pl := 0 // Prefix length - l := 0 // LCP length + prefixLen := 0 // Prefix length + lcpLen := 0 // LCP (longest common prefix) length - if cn.label != ':' { - sl := len(search) - pl = len(cn.prefix) + if currentNode.kind == staticKind { + searchLen := len(search) + prefixLen = len(currentNode.prefix) - // LCP - max := pl - if sl < max { - max = sl + // LCP - Longest Common Prefix (https://en.wikipedia.org/wiki/LCP_array) + max := prefixLen + if searchLen < max { + max = searchLen } - for ; l < max && search[l] == cn.prefix[l]; l++ { + for ; lcpLen < max && search[lcpLen] == currentNode.prefix[lcpLen]; lcpLen++ { } } - if l != pl { + if lcpLen != prefixLen { // No matching prefix, let's backtrack to the first possible alternative node of the decision path - nk, ok := backtrackToNextNodeKind(skind) + nk, ok := backtrackToNextNodeKind(staticKind) if !ok { return // No other possibilities on the decision path - } else if nk == pkind { + } else if nk == paramKind { goto Param // NOTE: this case (backtracking from static node to previous any node) can not happen by current any matching logic. Any node is end of search currently - //} else if nk == akind { + //} else if nk == anyKind { // goto Any } else { // Not found (this should never be possible for static node we are looking currently) @@ -408,31 +426,32 @@ func (r *Router) Find(method, path string, c Context) { } // The full prefix has matched, remove the prefix from the remaining search - search = search[l:] - searchIndex = searchIndex + l + search = search[lcpLen:] + searchIndex = searchIndex + lcpLen // Finish routing if no remaining search and we are on an leaf node - if search == "" && cn.ppath != "" { + if search == "" && currentNode.ppath != "" { break } // Static node if search != "" { - if child := cn.findStaticChild(search[0]); child != nil { - cn = child + if child := currentNode.findStaticChild(search[0]); child != nil { + currentNode = child continue } } Param: // Param node - if child := cn.paramChildren; search != "" && child != nil { - cn = child + if child := currentNode.paramChild; search != "" && child != nil { + currentNode = child + // FIXME: when param node does not have any children then param node should act similarly to any node - consider all remaining search as match i, l := 0, len(search) for ; i < l && search[i] != '/'; i++ { } - pvalues[n] = search[:i] - n++ + paramValues[paramIndex] = search[:i] + paramIndex++ search = search[i:] searchIndex = searchIndex + i continue @@ -440,20 +459,20 @@ func (r *Router) Find(method, path string, c Context) { Any: // Any node - if child := cn.anyChildren; child != nil { - // If any node is found, use remaining path for pvalues - cn = child - pvalues[len(cn.pnames)-1] = search + if child := currentNode.anyChild; child != nil { + // If any node is found, use remaining path for paramValues + currentNode = child + paramValues[len(currentNode.pnames)-1] = search break } // Let's backtrack to the first possible alternative node of the decision path - nk, ok := backtrackToNextNodeKind(akind) + nk, ok := backtrackToNextNodeKind(anyKind) if !ok { return // No other possibilities on the decision path - } else if nk == pkind { + } else if nk == paramKind { goto Param - } else if nk == akind { + } else if nk == anyKind { goto Any } else { // Not found @@ -461,12 +480,12 @@ func (r *Router) Find(method, path string, c Context) { } } - ctx.handler = cn.findHandler(method) - ctx.path = cn.ppath - ctx.pnames = cn.pnames + ctx.handler = currentNode.findHandler(method) + ctx.path = currentNode.ppath + ctx.pnames = currentNode.pnames if ctx.handler == nil { - ctx.handler = cn.checkMethodNotAllowed() + ctx.handler = currentNode.checkMethodNotAllowed() } return } diff --git a/router_test.go b/router_test.go index ba1890bd1..47e499402 100644 --- a/router_test.go +++ b/router_test.go @@ -640,42 +640,90 @@ var ( return nil } } + handlerFunc = func(c Context) error { + c.Set("path", c.Path()) + return nil + } ) +func checkUnusedParamValues(t *testing.T, c *context, expectParam map[string]string) { + for i, p := range c.pnames { + value := c.pvalues[i] + if value != "" { + if expectParam == nil { + t.Errorf("pValue '%v' is set for param name '%v' but we are not expecting it with expectParam", value, p) + } else { + if _, ok := expectParam[p]; !ok { + t.Errorf("pValue '%v' is set for param name '%v' but we are not expecting it with expectParam", value, p) + } + } + } + } +} + func TestRouterStatic(t *testing.T) { e := New() r := e.router path := "/folders/a/files/echo.gif" - r.Add(http.MethodGet, path, func(c Context) error { - c.Set("path", path) - return nil - }) + r.Add(http.MethodGet, path, handlerFunc) c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, path, c) c.handler(c) + assert.Equal(t, path, c.Get("path")) } func TestRouterParam(t *testing.T) { e := New() r := e.router - r.Add(http.MethodGet, "/users/:id", func(c Context) error { - return nil - }) - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/1", c) - assert.Equal(t, "1", c.Param("id")) + + r.Add(http.MethodGet, "/users/:id", handlerFunc) + + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectParam map[string]string + }{ + { + name: "route /users/1 to /users/:id", + whenURL: "/users/1", + expectRoute: "/users/:id", + expectParam: map[string]string{"id": "1"}, + }, + { // FIXME: this documents current implementation (slash at end is problematic) + name: "route /users/1/ to /users/:id", + whenURL: "/users/1/", + expectRoute: nil, // FIXME: should be "/users/:id", + expectParam: nil, // FIXME: should be map[string]string{"id": "1/"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, tc.whenURL, c) + + c.handler(c) + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func TestRouterTwoParam(t *testing.T) { e := New() r := e.router - r.Add(http.MethodGet, "/users/:uid/files/:fid", func(Context) error { - return nil - }) + r.Add(http.MethodGet, "/users/:uid/files/:fid", handlerFunc) c := e.NewContext(nil, nil).(*context) r.Find(http.MethodGet, "/users/1/files/1", c) + assert.Equal(t, "1", c.Param("uid")) assert.Equal(t, "1", c.Param("fid")) } @@ -685,18 +733,279 @@ func TestRouterParamWithSlash(t *testing.T) { e := New() r := e.router - r.Add(http.MethodGet, "/a/:b/c/d/:e", func(c Context) error { - return nil - }) + r.Add(http.MethodGet, "/a/:b/c/d/:e", handlerFunc) + r.Add(http.MethodGet, "/a/:b/c/:d/:f", handlerFunc) - r.Add(http.MethodGet, "/a/:b/c/:d/:f", func(c Context) error { - return nil - }) + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, "/a/1/c/d/2/3", c) // `2/3` should mapped to path `/a/:b/c/d/:e` and into `:e` + + err := c.handler(c) + assert.Equal(t, nil, c.Get("path")) // FIXME: should be "/a/:b/c/d/:e" + assert.EqualError(t, err, "code=404, message=Not Found") // FIXME: should be .NoError() +} + +// Issue #1754 - router needs to backtrack multiple levels upwards in tree to find the matching route +// route evaluation order +// +// Routes: +// 1) /a/:b/c +// 2) /a/c/d +// 3) /a/c/df +// +// 4) /a/*/f +// 5) /:e/c/f +// +// 6) /* +// +// Searching route for "/a/c/f" should match "/a/*/f" +// When route `4) /a/*/f` is not added then request for "/a/c/f" should match "/:e/c/f" +// +// +----------+ +// +-----+ "/" root +--------------------+--------------------------+ +// | +----------+ | | +// | | | +// +-------v-------+ +---v---------+ +-------v---+ +// | "a/" (static) +---------------+ | ":" (param) | | "*" (any) | +// +-+----------+--+ | +-----------+-+ +-----------+ +// | | | | +// +---------------v+ +-- ---v------+ +------v----+ +-----v-----------+ +// | "c/d" (static) | | ":" (param) | | "*" (any) | | "/c/f" (static) | +// +---------+------+ +--------+----+ +----------++ +-----------------+ +// | | | +// | | | +// +---------v----+ +------v--------+ +------v--------+ +// | "f" (static) | | "/c" (static) | | "/f" (static) | +// +--------------+ +---------------+ +---------------+ +func TestRouteMultiLevelBacktracking(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectParam map[string]string + }{ + { + name: "route /a/c/df to /a/c/df", + whenURL: "/a/c/df", + expectRoute: "/a/c/df", + }, + { + name: "route /a/x/df to /a/:b/c", + whenURL: "/a/x/c", + expectRoute: "/a/:b/c", + expectParam: map[string]string{"b": "x"}, + }, + { + name: "route /a/x/f to /a/*/f", + whenURL: "/a/x/f", + expectRoute: "/a/*/f", + expectParam: map[string]string{"*": "x/f"}, // NOTE: `x` would be probably more suitable + }, + { + name: "route /b/c/f to /:e/c/f", + whenURL: "/b/c/f", + expectRoute: "/:e/c/f", + expectParam: map[string]string{"e": "b"}, + }, + { + name: "route /b/c/c to /*", + whenURL: "/b/c/c", + expectRoute: "/*", + expectParam: map[string]string{"*": "b/c/c"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/a/:b/c", handlerHelper("case", 1)) + r.Add(http.MethodGet, "/a/c/d", handlerHelper("case", 2)) + r.Add(http.MethodGet, "/a/c/df", handlerHelper("case", 3)) + r.Add(http.MethodGet, "/a/*/f", handlerHelper("case", 4)) + r.Add(http.MethodGet, "/:e/c/f", handlerHelper("case", 5)) + r.Add(http.MethodGet, "/*", handlerHelper("case", 6)) + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, tc.whenURL, c) + + c.handler(c) + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } +} + +// Issue #1754 - router needs to backtrack multiple levels upwards in tree to find the matching route +// route evaluation order +// +// Request for "/a/c/f" should match "/:e/c/f" +// +// +-0,7--------+ +// | "/" (root) |----------------------------------+ +// +------------+ | +// | | | +// | | | +// +-1,6-----------+ | | +-8-----------+ +------v----+ +// | "a/" (static) +<--+ +--------->+ ":" (param) | | "*" (any) | +// +---------------+ +-------------+ +-----------+ +// | | | +// +-2--------v-----+ +v-3,5--------+ +-9------v--------+ +// | "c/d" (static) | | ":" (param) | | "/c/f" (static) | +// +----------------+ +-------------+ +-----------------+ +// | +// +-4--v----------+ +// | "/c" (static) | +// +---------------+ +func TestRouteMultiLevelBacktracking2(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/a/:b/c", handlerFunc) + r.Add(http.MethodGet, "/a/c/d", handlerFunc) + r.Add(http.MethodGet, "/a/c/df", handlerFunc) + r.Add(http.MethodGet, "/:e/c/f", handlerFunc) + r.Add(http.MethodGet, "/*", handlerFunc) + + var testCases = []struct { + name string + whenURL string + expectRoute string + expectParam map[string]string + }{ + { + name: "route /a/c/df to /a/c/df", + whenURL: "/a/c/df", + expectRoute: "/a/c/df", + }, + { + name: "route /a/x/df to /a/:b/c", + whenURL: "/a/x/c", + expectRoute: "/a/:b/c", + expectParam: map[string]string{"b": "x"}, + }, + { + name: "route /a/c/f to /:e/c/f", + whenURL: "/a/c/f", + expectRoute: "/:e/c/f", + expectParam: map[string]string{"e": "a"}, + }, + { + name: "route /b/c/f to /:e/c/f", + whenURL: "/b/c/f", + expectRoute: "/:e/c/f", + expectParam: map[string]string{"e": "b"}, + }, + { + name: "route /b/c/c to /*", + whenURL: "/b/c/c", + expectRoute: "/*", + expectParam: map[string]string{"*": "b/c/c"}, + }, + { // this traverses `/a/:b/c` and `/:e/c/f` branches and eventually backtracks to `/*` + name: "route /a/c/cf to /*", + whenURL: "/a/c/cf", + expectRoute: "/*", + expectParam: map[string]string{"*": "a/c/cf"}, + }, + { + name: "route /anyMatch to /*", + whenURL: "/anyMatch", + expectRoute: "/*", + expectParam: map[string]string{"*": "anyMatch"}, + }, + { + name: "route /anyMatch/withSlash to /*", + whenURL: "/anyMatch/withSlash", + expectRoute: "/*", + expectParam: map[string]string{"*": "anyMatch/withSlash"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) + + r.Find(http.MethodGet, tc.whenURL, c) + + c.handler(c) + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } +} + +func TestRouterBacktrackingFromMultipleParamKinds(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/*", handlerFunc) // this can match only path that does not have slash in it + r.Add(http.MethodGet, "/:1/second", handlerFunc) + r.Add(http.MethodGet, "/:1/:2", handlerFunc) // this acts as match ANY for all routes that have at least one slash + r.Add(http.MethodGet, "/:1/:2/third", handlerFunc) + r.Add(http.MethodGet, "/:1/:2/:3/fourth", handlerFunc) + r.Add(http.MethodGet, "/:1/:2/:3/:4/fifth", handlerFunc) c := e.NewContext(nil, nil).(*context) - assert.NotPanics(t, func() { - r.Find(http.MethodGet, "/a/1/c/d/2/3", c) - }) + var testCases = []struct { + name string + whenURL string + expectRoute string + expectParam map[string]string + }{ + { + name: "route /first to /*", + whenURL: "/first", + expectRoute: "/*", + expectParam: map[string]string{"*": "first"}, + }, + { + name: "route /first/second to /:1/second", + whenURL: "/first/second", + expectRoute: "/:1/second", + expectParam: map[string]string{"1": "first"}, + }, + { + name: "route /first/second-new to /:1/:2", + whenURL: "/first/second-new", + expectRoute: "/:1/:2", + expectParam: map[string]string{ + "1": "first", + "2": "second-new", + }, + }, + { // FIXME: should match `/:1/:2` when backtracking in tree. this 1 level backtracking fails even with old implementation + name: "route /first/second/ to /:1/:2", + whenURL: "/first/second/", + expectRoute: "/*", // "/:1/:2", + expectParam: map[string]string{"*": "first/second/"}, // map[string]string{"1": "first", "2": "second/"}, + }, + { // FIXME: should match `/:1/:2`. same backtracking problem. when backtracking is at `/:1/:2` during backtracking this node should be match as it has executable handler + name: "route /first/second/third/fourth/fifth/nope to /:1/:2", + whenURL: "/first/second/third/fourth/fifth/nope", + expectRoute: "/*", // "/:1/:2", + expectParam: map[string]string{"*": "first/second/third/fourth/fifth/nope"}, // map[string]string{"1": "first", "2": "second/third/fourth/fifth/nope"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r.Find(http.MethodGet, tc.whenURL, c) + + c.handler(c) + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } // Issue #1509 @@ -713,16 +1022,37 @@ func TestRouterParamStaticConflict(t *testing.T) { g.GET("/status", handler) g.GET("/:name", handler) - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/g/s", c) - c.handler(c) - assert.Equal(t, "s", c.Param("name")) - assert.Equal(t, "/g/:name", c.Get("path")) + var testCases = []struct { + whenURL string + expectRoute interface{} + expectParam map[string]string + }{ + { + whenURL: "/g/s", + expectRoute: "/g/:name", + expectParam: map[string]string{"name": "s"}, + }, + { + whenURL: "/g/status", + expectRoute: "/g/status", + expectParam: map[string]string{"name": ""}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/g/status", c) - c.handler(c) - assert.Equal(t, "/g/status", c.Get("path")) + r.Find(http.MethodGet, tc.whenURL, c) + err := c.handler(c) + + assert.NoError(t, err) + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func TestRouterMatchAny(t *testing.T) { @@ -730,28 +1060,46 @@ func TestRouterMatchAny(t *testing.T) { r := e.router // Routes - r.Add(http.MethodGet, "/", handlerHelper("case", 1)) - r.Add(http.MethodGet, "/*", handlerHelper("case", 2)) - r.Add(http.MethodGet, "/users/*", handlerHelper("case", 3)) - - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/", c) - c.handler(c) - - assert.Equal(t, 1, c.Get("case")) - assert.Equal(t, "/", c.Get("path")) + r.Add(http.MethodGet, "/", handlerFunc) + r.Add(http.MethodGet, "/*", handlerFunc) + r.Add(http.MethodGet, "/users/*", handlerFunc) + + var testCases = []struct { + whenURL string + expectRoute interface{} + expectParam map[string]string + }{ + { + whenURL: "/", + expectRoute: "/", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/download", + expectRoute: "/*", + expectParam: map[string]string{"*": "download"}, + }, + { + whenURL: "/users/joe", + expectRoute: "/users/*", + expectParam: map[string]string{"*": "joe"}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/download", c) - c.handler(c) - assert.Equal(t, 2, c.Get("case")) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "download", c.Param("*")) + r.Find(http.MethodGet, tc.whenURL, c) + err := c.handler(c) - r.Find(http.MethodGet, "/users/joe", c) - c.handler(c) - assert.Equal(t, 3, c.Get("case")) - assert.Equal(t, "/users/*", c.Get("path")) - assert.Equal(t, "joe", c.Param("*")) + assert.NoError(t, err) + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } // NOTE: this is to document current implementation. Last added route with `*` asterisk is always the match and no @@ -796,155 +1144,53 @@ func TestRouterMatchAnyPrefixIssue(t *testing.T) { c.Set("path", c.Path()) return nil }) - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/", c) - c.handler(c) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "", c.Param("*")) - - r.Find(http.MethodGet, "/users", c) - c.handler(c) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "users", c.Param("*")) - - r.Find(http.MethodGet, "/users/", c) - c.handler(c) - assert.Equal(t, "/users/*", c.Get("path")) - assert.Equal(t, "", c.Param("*")) - - r.Find(http.MethodGet, "/users_prefix", c) - c.handler(c) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "users_prefix", c.Param("*")) - r.Find(http.MethodGet, "/users_prefix/", c) - c.handler(c) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "users_prefix/", c.Param("*")) -} - -func TestRouteMultiLevelBacktracking(t *testing.T) { - e := New() - r := e.router - - r.Add(http.MethodGet, "/a/:b/c", handlerHelper("case", 1)) - r.Add(http.MethodGet, "/a/c/d", handlerHelper("case", 2)) - r.Add(http.MethodGet, "/:e/c/f", handlerHelper("case", 3)) - - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/a/c/f", c) - - c.handler(c) - assert.Equal(t, 3, c.Get("case")) - assert.Equal(t, "/:e/c/f", c.Get("path")) -} - -// Issue # -func TestRouterBacktrackingFromParam(t *testing.T) { - e := New() - r := e.router - - r.Add(http.MethodGet, "/*", handlerHelper("case", 1)) - r.Add(http.MethodGet, "/users/:name/", handlerHelper("case", 2)) - - c := e.NewContext(nil, nil).(*context) - - r.Find(http.MethodGet, "/users/firstname/no-match", c) - c.handler(c) - assert.Equal(t, 1, c.Get("case")) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "users/firstname/no-match", c.Param("*")) - - r.Find(http.MethodGet, "/users/firstname/", c) - c.handler(c) - assert.Equal(t, 2, c.Get("case")) - assert.Equal(t, "/users/:name/", c.Get("path")) - assert.Equal(t, "firstname", c.Param("name")) -} - -func TestRouterBacktrackingFromParamAny(t *testing.T) { - e := New() - r := e.router - - r.Add(http.MethodGet, "/*", handlerHelper("case", 1)) - r.Add(http.MethodGet, "/:name/lastname", handlerHelper("case", 2)) - - c := e.NewContext(nil, nil).(*context) - - r.Find(http.MethodGet, "/firstname/test", c) - c.handler(c) - assert.Equal(t, 1, c.Get("case")) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "firstname/test", c.Param("*")) - - r.Find(http.MethodGet, "/firstname", c) - c.handler(c) - assert.Equal(t, 1, c.Get("case")) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "firstname", c.Param("*")) - - r.Find(http.MethodGet, "/firstname/lastname", c) - c.handler(c) - assert.Equal(t, 2, c.Get("case")) - assert.Equal(t, "/:name/lastname", c.Get("path")) - assert.Equal(t, "firstname", c.Param("name")) -} - -func TestRouterBacktrackingFromParamAny2(t *testing.T) { - e := New() - r := e.router - - r.Add(http.MethodGet, "/*", handlerHelper("case", 1)) - r.Add(http.MethodGet, "/:name", handlerHelper("case", 2)) - r.Add(http.MethodGet, "/:name/lastname", handlerHelper("case", 3)) - - c := e.NewContext(nil, nil).(*context) - - r.Find(http.MethodGet, "/firstname/test", c) - c.handler(c) - assert.Equal(t, 1, c.Get("case")) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "firstname/test", c.Param("*")) - - r.Find(http.MethodGet, "/firstname", c) - c.handler(c) - assert.Equal(t, 2, c.Get("case")) - assert.Equal(t, "/:name", c.Get("path")) - assert.Equal(t, "firstname", c.Param("name")) - - r.Find(http.MethodGet, "/firstname/lastname", c) - c.handler(c) - assert.Equal(t, 3, c.Get("case")) - assert.Equal(t, "/:name/lastname", c.Get("path")) - assert.Equal(t, "firstname", c.Param("name")) -} - -func TestRouterAnyCommonPath(t *testing.T) { - e := New() - r := e.router - - r.Add(http.MethodGet, "/ab*", handlerHelper("case", 1)) - r.Add(http.MethodGet, "/abcd", handlerHelper("case", 2)) - r.Add(http.MethodGet, "/abcd*", handlerHelper("case", 3)) - - c := e.NewContext(nil, nil).(*context) - - r.Find(http.MethodGet, "/abee", c) - c.handler(c) - assert.Equal(t, 1, c.Get("case")) - assert.Equal(t, "/ab*", c.Get("path")) - assert.Equal(t, "ee", c.Param("*")) + var testCases = []struct { + whenURL string + expectRoute interface{} + expectParam map[string]string + }{ + { + whenURL: "/", + expectRoute: "/*", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/users", + expectRoute: "/*", + expectParam: map[string]string{"*": "users"}, + }, + { + whenURL: "/users/", + expectRoute: "/users/*", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/users_prefix", + expectRoute: "/*", + expectParam: map[string]string{"*": "users_prefix"}, + }, + { + whenURL: "/users_prefix/", + expectRoute: "/*", + expectParam: map[string]string{"*": "users_prefix/"}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/abcd", c) - c.handler(c) - assert.Equal(t, "/abcd", c.Get("path")) - assert.Equal(t, 2, c.Get("case")) + r.Find(http.MethodGet, tc.whenURL, c) + err := c.handler(c) - r.Find(http.MethodGet, "/abcde", c) - c.handler(c) - assert.Equal(t, 3, c.Get("case")) - assert.Equal(t, "/abcd*", c.Get("path")) - assert.Equal(t, "e", c.Param("*")) + assert.NoError(t, err) + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } // TestRouterMatchAnySlash shall verify finding the best route @@ -953,168 +1199,226 @@ func TestRouterMatchAnySlash(t *testing.T) { e := New() r := e.router - handler := func(c Context) error { - c.Set("path", c.Path()) - return nil - } - // Routes - r.Add(http.MethodGet, "/users", handler) - r.Add(http.MethodGet, "/users/*", handler) - r.Add(http.MethodGet, "/img/*", handler) - r.Add(http.MethodGet, "/img/load", handler) - r.Add(http.MethodGet, "/img/load/*", handler) - r.Add(http.MethodGet, "/assets/*", handler) - - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/", c) - assert.Equal(t, "", c.Param("*")) - - // Test trailing slash request for simple any route (see #1526) - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/", c) - c.handler(c) - assert.Equal(t, "/users/*", c.Get("path")) - assert.Equal(t, "", c.Param("*")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/joe", c) - c.handler(c) - assert.Equal(t, "/users/*", c.Get("path")) - assert.Equal(t, "joe", c.Param("*")) - - // Test trailing slash request for nested any route (see #1526) - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/img/load", c) - c.handler(c) - assert.Equal(t, "/img/load", c.Get("path")) - assert.Equal(t, "", c.Param("*")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/img/load/", c) - c.handler(c) - assert.Equal(t, "/img/load/*", c.Get("path")) - assert.Equal(t, "", c.Param("*")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/img/load/ben", c) - c.handler(c) - assert.Equal(t, "/img/load/*", c.Get("path")) - assert.Equal(t, "ben", c.Param("*")) - - // Test /assets/* any route - // ... without trailing slash must not match - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/assets", c) - c.handler(c) - assert.Equal(t, nil, c.Get("path")) - assert.Equal(t, "", c.Param("*")) + r.Add(http.MethodGet, "/users", handlerFunc) + r.Add(http.MethodGet, "/users/*", handlerFunc) + r.Add(http.MethodGet, "/img/*", handlerFunc) + r.Add(http.MethodGet, "/img/load", handlerFunc) + r.Add(http.MethodGet, "/img/load/*", handlerFunc) + r.Add(http.MethodGet, "/assets/*", handlerFunc) + + var testCases = []struct { + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + whenURL: "/", + expectRoute: nil, + expectParam: map[string]string{"*": ""}, + expectError: ErrNotFound, + }, + { // Test trailing slash request for simple any route (see #1526) + whenURL: "/users/", + expectRoute: "/users/*", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/users/joe", + expectRoute: "/users/*", + expectParam: map[string]string{"*": "joe"}, + }, + // Test trailing slash request for nested any route (see #1526) + { + whenURL: "/img/load", + expectRoute: "/img/load", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/img/load/", + expectRoute: "/img/load/*", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/img/load/ben", + expectRoute: "/img/load/*", + expectParam: map[string]string{"*": "ben"}, + }, + // Test /assets/* any route + { // ... without trailing slash must not match + whenURL: "/assets", + expectRoute: nil, + expectParam: map[string]string{"*": ""}, + expectError: ErrNotFound, + }, + + { // ... with trailing slash must match + whenURL: "/assets/", + expectRoute: "/assets/*", + expectParam: map[string]string{"*": ""}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - // ... with trailing slash must match - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/assets/", c) - c.handler(c) - assert.Equal(t, "/assets/*", c.Get("path")) - assert.Equal(t, "", c.Param("*")) + r.Find(http.MethodGet, tc.whenURL, c) + err := c.handler(c) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func TestRouterMatchAnyMultiLevel(t *testing.T) { e := New() r := e.router - handler := func(c Context) error { - c.Set("path", c.Path()) - return nil - } // Routes - r.Add(http.MethodGet, "/api/users/jack", handler) - r.Add(http.MethodGet, "/api/users/jill", handler) - r.Add(http.MethodGet, "/api/users/*", handler) - r.Add(http.MethodGet, "/api/*", handler) - r.Add(http.MethodGet, "/other/*", handler) - r.Add(http.MethodGet, "/*", handler) - - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/api/users/jack", c) - c.handler(c) - assert.Equal(t, "/api/users/jack", c.Get("path")) - assert.Equal(t, "", c.Param("*")) - - r.Find(http.MethodGet, "/api/users/jill", c) - c.handler(c) - assert.Equal(t, "/api/users/jill", c.Get("path")) - assert.Equal(t, "", c.Param("*")) - - r.Find(http.MethodGet, "/api/users/joe", c) - c.handler(c) - assert.Equal(t, "/api/users/*", c.Get("path")) - assert.Equal(t, "joe", c.Param("*")) - - r.Find(http.MethodGet, "/api/nousers/joe", c) - c.handler(c) - assert.Equal(t, "/api/*", c.Get("path")) - assert.Equal(t, "nousers/joe", c.Param("*")) + r.Add(http.MethodGet, "/api/users/jack", handlerFunc) + r.Add(http.MethodGet, "/api/users/jill", handlerFunc) + r.Add(http.MethodGet, "/api/users/*", handlerFunc) + r.Add(http.MethodGet, "/api/*", handlerFunc) + r.Add(http.MethodGet, "/other/*", handlerFunc) + r.Add(http.MethodGet, "/*", handlerFunc) + + var testCases = []struct { + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + whenURL: "/api/users/jack", + expectRoute: "/api/users/jack", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/api/users/jill", + expectRoute: "/api/users/jill", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/api/users/joe", + expectRoute: "/api/users/*", + expectParam: map[string]string{"*": "joe"}, + }, + { + whenURL: "/api/nousers/joe", + expectRoute: "/api/*", + expectParam: map[string]string{"*": "nousers/joe"}, + }, + { + whenURL: "/api/none", + expectRoute: "/api/*", + expectParam: map[string]string{"*": "none"}, + }, + { + whenURL: "/api/none", + expectRoute: "/api/*", + expectParam: map[string]string{"*": "none"}, + }, + { + whenURL: "/noapi/users/jim", + expectRoute: "/*", + expectParam: map[string]string{"*": "noapi/users/jim"}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/api/none", c) - c.handler(c) - assert.Equal(t, "/api/*", c.Get("path")) - assert.Equal(t, "none", c.Param("*")) + r.Find(http.MethodGet, tc.whenURL, c) + err := c.handler(c) - r.Find(http.MethodGet, "/noapi/users/jim", c) - c.handler(c) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "noapi/users/jim", c.Param("*")) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func TestRouterMatchAnyMultiLevelWithPost(t *testing.T) { e := New() r := e.router - handler := func(c Context) error { - c.Set("path", c.Path()) - return nil - } // Routes - e.POST("/api/auth/login", handler) - e.POST("/api/auth/forgotPassword", handler) - e.Any("/api/*", handler) - e.Any("/*", handler) - - // POST /api/auth/login shall choose login method - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodPost, "/api/auth/login", c) - c.handler(c) - assert.Equal(t, "/api/auth/login", c.Get("path")) - assert.Equal(t, "", c.Param("*")) - - // GET /api/auth/login shall choose any route - // c = e.NewContext(nil, nil).(*context) - // r.Find(http.MethodGet, "/api/auth/login", c) - // c.handler(c) - // assert.Equal(t, "/api/*", c.Get("path")) - // assert.Equal(t, "auth/login", c.Param("*")) - - // POST /api/auth/logout shall choose nearest any route - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodPost, "/api/auth/logout", c) - c.handler(c) - assert.Equal(t, "/api/*", c.Get("path")) - assert.Equal(t, "auth/logout", c.Param("*")) - - // POST to /api/other/test shall choose nearest any route - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodPost, "/api/other/test", c) - c.handler(c) - assert.Equal(t, "/api/*", c.Get("path")) - assert.Equal(t, "other/test", c.Param("*")) + e.POST("/api/auth/login", handlerFunc) + e.POST("/api/auth/forgotPassword", handlerFunc) + e.Any("/api/*", handlerFunc) + e.Any("/*", handlerFunc) + + var testCases = []struct { + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { // POST /api/auth/login shall choose login method + whenURL: "/api/auth/login", + whenMethod: http.MethodPost, + expectRoute: "/api/auth/login", + expectParam: map[string]string{"*": ""}, + }, + { // POST /api/auth/logout shall choose nearest any route + whenURL: "/api/auth/logout", + whenMethod: http.MethodPost, + expectRoute: "/api/*", + expectParam: map[string]string{"*": "auth/logout"}, + }, + { // POST to /api/other/test shall choose nearest any route + whenURL: "/api/other/test", + whenMethod: http.MethodPost, + expectRoute: "/api/*", + expectParam: map[string]string{"*": "other/test"}, + }, + { // GET to /api/other/test shall choose nearest any route + whenURL: "/api/other/test", + whenMethod: http.MethodGet, + expectRoute: "/api/*", + expectParam: map[string]string{"*": "other/test"}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - // GET to /api/other/test shall choose nearest any route - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/api/other/test", c) - c.handler(c) - assert.Equal(t, "/api/*", c.Get("path")) - assert.Equal(t, "other/test", c.Param("*")) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + r.Find(method, tc.whenURL, c) + err := c.handler(c) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func TestRouterMicroParam(t *testing.T) { @@ -1150,29 +1454,56 @@ func TestRouterMultiRoute(t *testing.T) { r := e.router // Routes - r.Add(http.MethodGet, "/users", func(c Context) error { - c.Set("path", "/users") - return nil - }) - r.Add(http.MethodGet, "/users/:id", func(c Context) error { - return nil - }) - c := e.NewContext(nil, nil).(*context) - - // Route > /users - r.Find(http.MethodGet, "/users", c) - c.handler(c) - assert.Equal(t, "/users", c.Get("path")) + r.Add(http.MethodGet, "/users", handlerFunc) + r.Add(http.MethodGet, "/users/:id", handlerFunc) + + var testCases = []struct { + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + whenURL: "/users", + expectRoute: "/users", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/users/1", + expectRoute: "/users/:id", + expectParam: map[string]string{"id": "1"}, + }, + { + whenURL: "/user", + expectRoute: nil, + expectParam: map[string]string{"*": ""}, + expectError: ErrNotFound, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - // Route > /users/:id - r.Find(http.MethodGet, "/users/1", c) - assert.Equal(t, "1", c.Param("id")) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + r.Find(method, tc.whenURL, c) + err := c.handler(c) - // Route > /user - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/user", c) - he := c.handler(c).(*HTTPError) - assert.Equal(t, http.StatusNotFound, he.Code) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func TestRouterPriority(t *testing.T) { @@ -1180,123 +1511,112 @@ func TestRouterPriority(t *testing.T) { r := e.router // Routes - r.Add(http.MethodGet, "/users", handlerHelper("a", 1)) - r.Add(http.MethodGet, "/users/new", handlerHelper("b", 2)) - r.Add(http.MethodGet, "/users/:id", handlerHelper("c", 3)) - r.Add(http.MethodGet, "/users/dew", handlerHelper("d", 4)) - r.Add(http.MethodGet, "/users/:id/files", handlerHelper("e", 5)) - r.Add(http.MethodGet, "/users/newsee", handlerHelper("f", 6)) - r.Add(http.MethodGet, "/users/*", handlerHelper("g", 7)) - r.Add(http.MethodGet, "/users/new/*", handlerHelper("h", 8)) - r.Add(http.MethodGet, "/*", handlerHelper("i", 9)) - - // Route > /users - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users", c) - c.handler(c) - assert.Equal(t, 1, c.Get("a")) - assert.Equal(t, "/users", c.Get("path")) - - // Route > /users/new - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/new", c) - c.handler(c) - assert.Equal(t, 2, c.Get("b")) - assert.Equal(t, "/users/new", c.Get("path")) - - // Route > /users/:id - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/1", c) - c.handler(c) - assert.Equal(t, 3, c.Get("c")) - assert.Equal(t, "/users/:id", c.Get("path")) - - // Route > /users/dew - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/dew", c) - c.handler(c) - assert.Equal(t, 4, c.Get("d")) - assert.Equal(t, "/users/dew", c.Get("path")) - - // Route > /users/:id/files - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/1/files", c) - c.handler(c) - assert.Equal(t, 5, c.Get("e")) - assert.Equal(t, "/users/:id/files", c.Get("path")) - - // Route > /users/:id - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/news", c) - c.handler(c) - assert.Equal(t, 3, c.Get("c")) - assert.Equal(t, "/users/:id", c.Get("path")) - - // Route > /users/newsee - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/newsee", c) - c.handler(c) - assert.Equal(t, 6, c.Get("f")) - assert.Equal(t, "/users/newsee", c.Get("path")) - - // Route > /users/newsee - r.Find(http.MethodGet, "/users/newsee", c) - c.handler(c) - assert.Equal(t, 6, c.Get("f")) - - // Route > /users/newsee - r.Find(http.MethodGet, "/users/newsee", c) - c.handler(c) - assert.Equal(t, 6, c.Get("f")) - - // Route > /users/* - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/joe/books", c) - c.handler(c) - assert.Equal(t, 7, c.Get("g")) - assert.Equal(t, "/users/*", c.Get("path")) - assert.Equal(t, "joe/books", c.Param("*")) - - // Route > /users/new/* should be matched - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/new/someone", c) - c.handler(c) - assert.Equal(t, 8, c.Get("h")) - assert.Equal(t, "/users/new/*", c.Get("path")) - assert.Equal(t, "someone", c.Param("*")) - - // Route > /users/* should be matched although /users/dew exists - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/dew/someone", c) - c.handler(c) - assert.Equal(t, 7, c.Get("g")) - assert.Equal(t, "/users/*", c.Get("path")) - - assert.Equal(t, "dew/someone", c.Param("*")) - - // Route > /users/* should be matched although /users/dew exists - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/notexists/someone", c) - c.handler(c) - assert.Equal(t, 7, c.Get("g")) - assert.Equal(t, "/users/*", c.Get("path")) - assert.Equal(t, "notexists/someone", c.Param("*")) + r.Add(http.MethodGet, "/users", handlerFunc) + r.Add(http.MethodGet, "/users/new", handlerFunc) + r.Add(http.MethodGet, "/users/:id", handlerFunc) + r.Add(http.MethodGet, "/users/dew", handlerFunc) + r.Add(http.MethodGet, "/users/:id/files", handlerFunc) + r.Add(http.MethodGet, "/users/newsee", handlerFunc) + r.Add(http.MethodGet, "/users/*", handlerFunc) + r.Add(http.MethodGet, "/users/new/*", handlerFunc) + r.Add(http.MethodGet, "/*", handlerFunc) + + var testCases = []struct { + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + whenURL: "/users", + expectRoute: "/users", + }, + { + whenURL: "/users/new", + expectRoute: "/users/new", + }, + { + whenURL: "/users/1", + expectRoute: "/users/:id", + expectParam: map[string]string{"id": "1"}, + }, + { + whenURL: "/users/dew", + expectRoute: "/users/dew", + }, + { + whenURL: "/users/1/files", + expectRoute: "/users/:id/files", + expectParam: map[string]string{"id": "1"}, + }, + { + whenURL: "/users/new", + expectRoute: "/users/new", + }, + { + whenURL: "/users/news", + expectRoute: "/users/:id", + expectParam: map[string]string{"id": "news"}, + }, + { + whenURL: "/users/newsee", + expectRoute: "/users/newsee", + }, + { + whenURL: "/users/joe/books", + expectRoute: "/users/*", + expectParam: map[string]string{"*": "joe/books"}, + }, + { + whenURL: "/users/new/someone", + expectRoute: "/users/new/*", + expectParam: map[string]string{"*": "someone"}, + }, + { + whenURL: "/users/dew/someone", + expectRoute: "/users/*", + expectParam: map[string]string{"*": "dew/someone"}, + }, + { // Route > /users/* should be matched although /users/dew exists + whenURL: "/users/notexists/someone", + expectRoute: "/users/*", + expectParam: map[string]string{"*": "notexists/someone"}, + }, + { + whenURL: "/nousers", + expectRoute: "/*", + expectParam: map[string]string{"*": "nousers"}, + }, + { + whenURL: "/nousers/new", + expectRoute: "/*", + expectParam: map[string]string{"*": "nousers/new"}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - // Route > * - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/nousers", c) - c.handler(c) - assert.Equal(t, 9, c.Get("i")) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "nousers", c.Param("*")) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + r.Find(method, tc.whenURL, c) + err := c.handler(c) - // Route > * - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/nousers/new", c) - c.handler(c) - assert.Equal(t, 9, c.Get("i")) - assert.Equal(t, "/*", c.Get("path")) - assert.Equal(t, "nousers/new", c.Param("*")) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func TestRouterIssue1348(t *testing.T) { @@ -1315,31 +1635,55 @@ func TestRouterIssue1348(t *testing.T) { func TestRouterPriorityNotFound(t *testing.T) { e := New() r := e.router - c := e.NewContext(nil, nil).(*context) // Add - r.Add(http.MethodGet, "/a/foo", func(c Context) error { - c.Set("a", 1) - return nil - }) - r.Add(http.MethodGet, "/a/bar", func(c Context) error { - c.Set("b", 2) - return nil - }) - - // Find - r.Find(http.MethodGet, "/a/foo", c) - c.handler(c) - assert.Equal(t, 1, c.Get("a")) + r.Add(http.MethodGet, "/a/foo", handlerFunc) + r.Add(http.MethodGet, "/a/bar", handlerFunc) + + var testCases = []struct { + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + whenURL: "/a/foo", + expectRoute: "/a/foo", + }, + { + whenURL: "/a/bar", + expectRoute: "/a/bar", + }, + { + whenURL: "/abc/def", + expectRoute: nil, + expectError: ErrNotFound, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/a/bar", c) - c.handler(c) - assert.Equal(t, 2, c.Get("b")) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + r.Find(method, tc.whenURL, c) + err := c.handler(c) - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/abc/def", c) - he := c.handler(c).(*HTTPError) - assert.Equal(t, http.StatusNotFound, he.Code) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func TestRouterParamNames(t *testing.T) { @@ -1347,34 +1691,58 @@ func TestRouterParamNames(t *testing.T) { r := e.router // Routes - r.Add(http.MethodGet, "/users", func(c Context) error { - c.Set("path", "/users") - return nil - }) - r.Add(http.MethodGet, "/users/:id", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/users/:uid/files/:fid", func(c Context) error { - return nil - }) - c := e.NewContext(nil, nil).(*context) - - // Route > /users - r.Find(http.MethodGet, "/users", c) - c.handler(c) - assert.Equal(t, "/users", c.Get("path")) + r.Add(http.MethodGet, "/users", handlerFunc) + r.Add(http.MethodGet, "/users/:id", handlerFunc) + r.Add(http.MethodGet, "/users/:uid/files/:fid", handlerFunc) + + var testCases = []struct { + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + whenURL: "/users", + expectRoute: "/users", + }, + { + whenURL: "/users/1", + expectRoute: "/users/:id", + expectParam: map[string]string{"id": "1"}, + }, + { + whenURL: "/users/1/files/1", + expectRoute: "/users/:uid/files/:fid", + expectParam: map[string]string{ + "uid": "1", + "fid": "1", + }, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - // Route > /users/:id - r.Find(http.MethodGet, "/users/1", c) - assert.Equal(t, "id", c.pnames[0]) - assert.Equal(t, "1", c.Param("id")) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + r.Find(method, tc.whenURL, c) + err := c.handler(c) - // Route > /users/:uid/files/:fid - r.Find(http.MethodGet, "/users/1/files/1", c) - assert.Equal(t, "uid", c.pnames[0]) - assert.Equal(t, "1", c.Param("uid")) - assert.Equal(t, "fid", c.pnames[1]) - assert.Equal(t, "1", c.Param("fid")) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } // Issue #623 and #1406 @@ -1389,47 +1757,69 @@ func TestRouterStaticDynamicConflict(t *testing.T) { r.Add(http.MethodGet, "/server", handlerHelper("c", 3)) r.Add(http.MethodGet, "/", handlerHelper("f", 6)) - c := e.NewContext(nil, nil) - r.Find(http.MethodGet, "/dictionary/skills", c) - c.Handler()(c) - assert.Equal(t, 1, c.Get("a")) - assert.Equal(t, "/dictionary/skills", c.Get("path")) - - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/dictionary/skillsnot", c) - c.Handler()(c) - assert.Equal(t, 2, c.Get("b")) - assert.Equal(t, "/dictionary/:name", c.Get("path")) - - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/dictionary/type", c) - c.Handler()(c) - assert.Equal(t, 2, c.Get("b")) - assert.Equal(t, "/dictionary/:name", c.Get("path")) - - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/server", c) - c.Handler()(c) - assert.Equal(t, 3, c.Get("c")) - assert.Equal(t, "/server", c.Get("path")) - - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/users/new", c) - c.Handler()(c) - assert.Equal(t, 4, c.Get("d")) - assert.Equal(t, "/users/new", c.Get("path")) + var testCases = []struct { + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + whenURL: "/dictionary/skills", + expectRoute: "/dictionary/skills", + expectParam: map[string]string{"*": ""}, + }, + { + whenURL: "/dictionary/skillsnot", + expectRoute: "/dictionary/:name", + expectParam: map[string]string{"name": "skillsnot"}, + }, + { + whenURL: "/dictionary/type", + expectRoute: "/dictionary/:name", + expectParam: map[string]string{"name": "type"}, + }, + { + whenURL: "/server", + expectRoute: "/server", + }, + { + whenURL: "/users/new", + expectRoute: "/users/new", + }, + { + whenURL: "/users/new2", + expectRoute: "/users/:name", + expectParam: map[string]string{"name": "new2"}, + }, + { + whenURL: "/", + expectRoute: "/", + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/users/new2", c) - c.Handler()(c) - assert.Equal(t, 5, c.Get("e")) - assert.Equal(t, "/users/:name", c.Get("path")) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + r.Find(method, tc.whenURL, c) + err := c.handler(c) - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/", c) - c.Handler()(c) - assert.Equal(t, 6, c.Get("f")) - assert.Equal(t, "/", c.Get("path")) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } // Issue #1348 @@ -1438,42 +1828,76 @@ func TestRouterParamBacktraceNotFound(t *testing.T) { r := e.router // Add - r.Add(http.MethodGet, "/:param1", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/:param1/foo", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/:param1/bar", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/:param1/bar/:param2", func(c Context) error { - return nil - }) - - c := e.NewContext(nil, nil).(*context) - - //Find - r.Find(http.MethodGet, "/a", c) - assert.Equal(t, "a", c.Param("param1")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/a/foo", c) - assert.Equal(t, "a", c.Param("param1")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/a/bar", c) - assert.Equal(t, "a", c.Param("param1")) + r.Add(http.MethodGet, "/:param1", handlerFunc) + r.Add(http.MethodGet, "/:param1/foo", handlerFunc) + r.Add(http.MethodGet, "/:param1/bar", handlerFunc) + r.Add(http.MethodGet, "/:param1/bar/:param2", handlerFunc) + + var testCases = []struct { + name string + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + name: "route /a to /:param1", + whenURL: "/a", + expectRoute: "/:param1", + expectParam: map[string]string{"param1": "a"}, + }, + { + name: "route /a/foo to /:param1/foo", + whenURL: "/a/foo", + expectRoute: "/:param1/foo", + expectParam: map[string]string{"param1": "a"}, + }, + { + name: "route /a/bar to /:param1/bar", + whenURL: "/a/bar", + expectRoute: "/:param1/bar", + expectParam: map[string]string{"param1": "a"}, + }, + { + name: "route /a/bar/b to /:param1/bar/:param2", + whenURL: "/a/bar/b", + expectRoute: "/:param1/bar/:param2", + expectParam: map[string]string{ + "param1": "a", + "param2": "b", + }, + }, + { + name: "route /a/bbbbb should return 404", + whenURL: "/a/bbbbb", + expectRoute: nil, + expectError: ErrNotFound, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/a/bar/b", c) - assert.Equal(t, "a", c.Param("param1")) - assert.Equal(t, "b", c.Param("param2")) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + r.Find(method, tc.whenURL, c) + err := c.handler(c) - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/a/bbbbb", c) - he := c.handler(c).(*HTTPError) - assert.Equal(t, http.StatusNotFound, he.Code) + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func testRouterAPI(t *testing.T, api []*Route) { @@ -1487,13 +1911,15 @@ func testRouterAPI(t *testing.T, api []*Route) { } c := e.NewContext(nil, nil).(*context) for _, route := range api { - r.Find(route.Method, route.Path, c) - tokens := strings.Split(route.Path[1:], "/") - for _, token := range tokens { - if token[0] == ':' { - assert.Equal(t, c.Param(token[1:]), token) + t.Run(route.Path, func(t *testing.T) { + r.Find(route.Method, route.Path, c) + tokens := strings.Split(route.Path[1:], "/") + for _, token := range tokens { + if token[0] == ':' { + assert.Equal(t, c.Param(token[1:]), token) + } } - } + }) } } @@ -1552,79 +1978,93 @@ func TestRouterParam1466(t *testing.T) { e := New() r := e.router - r.Add(http.MethodPost, "/users/signup", func(c Context) error { - return nil - }) - r.Add(http.MethodPost, "/users/signup/bulk", func(c Context) error { - return nil - }) - r.Add(http.MethodPost, "/users/survey", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/users/:username", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/interests/:name/users", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/skills/:name/users", func(c Context) error { - return nil - }) + r.Add(http.MethodPost, "/users/signup", handlerFunc) + r.Add(http.MethodPost, "/users/signup/bulk", handlerFunc) + r.Add(http.MethodPost, "/users/survey", handlerFunc) + r.Add(http.MethodGet, "/users/:username", handlerFunc) + r.Add(http.MethodGet, "/interests/:name/users", handlerFunc) + r.Add(http.MethodGet, "/skills/:name/users", handlerFunc) // Additional routes for Issue 1479 - r.Add(http.MethodGet, "/users/:username/likes/projects/ids", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/users/:username/profile", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/users/:username/uploads/:type", func(c Context) error { - return nil - }) - - c := e.NewContext(nil, nil).(*context) - - r.Find(http.MethodGet, "/users/ajitem", c) - assert.Equal(t, "ajitem", c.Param("username")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/sharewithme", c) - assert.Equal(t, "sharewithme", c.Param("username")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/signup", c) - assert.Equal(t, "", c.Param("username")) - // Additional assertions for #1479 - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/sharewithme/likes/projects/ids", c) - assert.Equal(t, "sharewithme", c.Param("username")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/ajitem/likes/projects/ids", c) - assert.Equal(t, "ajitem", c.Param("username")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/sharewithme/profile", c) - assert.Equal(t, "sharewithme", c.Param("username")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/ajitem/profile", c) - assert.Equal(t, "ajitem", c.Param("username")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/sharewithme/uploads/self", c) - assert.Equal(t, "sharewithme", c.Param("username")) - assert.Equal(t, "self", c.Param("type")) - - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/ajitem/uploads/self", c) - assert.Equal(t, "ajitem", c.Param("username")) - assert.Equal(t, "self", c.Param("type")) - - // Issue #1493 - check for routing loop - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/tree/free", c) - assert.Equal(t, "", c.Param("id")) - assert.Equal(t, 0, c.response.Status) + r.Add(http.MethodGet, "/users/:username/likes/projects/ids", handlerFunc) + r.Add(http.MethodGet, "/users/:username/profile", handlerFunc) + r.Add(http.MethodGet, "/users/:username/uploads/:type", handlerFunc) + + var testCases = []struct { + whenURL string + expectRoute interface{} + expectParam map[string]string + }{ + { + whenURL: "/users/ajitem", + expectRoute: "/users/:username", + expectParam: map[string]string{"username": "ajitem"}, + }, + { + whenURL: "/users/sharewithme", + expectRoute: "/users/:username", + expectParam: map[string]string{"username": "sharewithme"}, + }, + { + whenURL: "/users/signup", + expectRoute: nil, // method not found as this route is for POST but request is for GET + expectParam: map[string]string{"username": ""}, + }, + // Additional assertions for #1479 + { + whenURL: "/users/sharewithme/likes/projects/ids", + expectRoute: "/users/:username/likes/projects/ids", + expectParam: map[string]string{"username": "sharewithme"}, + }, + { + whenURL: "/users/ajitem/likes/projects/ids", + expectRoute: "/users/:username/likes/projects/ids", + expectParam: map[string]string{"username": "ajitem"}, + }, + { + whenURL: "/users/sharewithme/profile", + expectRoute: "/users/:username/profile", + expectParam: map[string]string{"username": "sharewithme"}, + }, + { + whenURL: "/users/ajitem/profile", + expectRoute: "/users/:username/profile", + expectParam: map[string]string{"username": "ajitem"}, + }, + { + whenURL: "/users/sharewithme/uploads/self", + expectRoute: "/users/:username/uploads/:type", + expectParam: map[string]string{ + "username": "sharewithme", + "type": "self", + }, + }, + { + whenURL: "/users/ajitem/uploads/self", + expectRoute: "/users/:username/uploads/:type", + expectParam: map[string]string{ + "username": "ajitem", + "type": "self", + }, + }, + { + whenURL: "/users/tree/free", + expectRoute: nil, // not found + expectParam: map[string]string{"id": ""}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) + + r.Find(http.MethodGet, tc.whenURL, c) + c.handler(c) + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } // Issue #1655 @@ -1669,33 +2109,56 @@ func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) { e := New() r := e.router - r.Add(http.MethodGet, "/users/create", handlerHelper("create", 1)) - r.Add(http.MethodGet, "/users/:id/edit", func(c Context) error { - return nil - }) - r.Add(http.MethodGet, "/users/:id/active", func(c Context) error { - return nil - }) - - c := e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/alice/edit", c) - assert.Equal(t, "alice", c.Param("id")) + r.Add(http.MethodGet, "/users/create", handlerFunc) + r.Add(http.MethodGet, "/users/:id/edit", handlerFunc) + r.Add(http.MethodGet, "/users/:id/active", handlerFunc) + + var testCases = []struct { + whenURL string + expectRoute interface{} + expectParam map[string]string + expectStatus int + }{ + { + whenURL: "/users/alice/edit", + expectRoute: "/users/:id/edit", + expectParam: map[string]string{"id": "alice"}, + }, + { + whenURL: "/users/bob/active", + expectRoute: "/users/:id/active", + expectParam: map[string]string{"id": "bob"}, + }, + { + whenURL: "/users/create", + expectRoute: "/users/create", + expectParam: nil, + }, + //This panic before the fix for Issue #1653 + { + whenURL: "/users/createNotFound", + expectStatus: http.StatusNotFound, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/bob/active", c) - assert.Equal(t, "bob", c.Param("id")) + r.Find(http.MethodGet, tc.whenURL, c) + err := c.handler(c) - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/create", c) - c.Handler()(c) - assert.Equal(t, 1, c.Get("create")) - assert.Equal(t, "/users/create", c.Get("path")) - - //This panic before the fix for Issue #1653 - c = e.NewContext(nil, nil).(*context) - r.Find(http.MethodGet, "/users/createNotFound", c) - he := c.Handler()(c).(*HTTPError) - assert.Equal(t, http.StatusNotFound, he.Code) + if tc.expectStatus != 0 { + assert.Error(t, err) + he := err.(*HTTPError) + assert.Equal(t, tc.expectStatus, he.Code) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } } func benchmarkRouterRoutes(b *testing.B, routes []*Route, routesToFind []*Route) { @@ -1765,14 +2228,14 @@ func (n *node) printTree(pfx string, tail bool) { p = prefix(tail, pfx, " ", "│ ") - children := n.staticChildrens + children := n.staticChildren l := len(children) - if n.paramChildren != nil { - n.paramChildren.printTree(p, n.anyChildren == nil && l == 0) + if n.paramChild != nil { + n.paramChild.printTree(p, n.anyChild == nil && l == 0) } - if n.anyChildren != nil { - n.anyChildren.printTree(p, l == 0) + if n.anyChild != nil { + n.anyChild.printTree(p, l == 0) } for i := 0; i < l-1; i++ { children[i].printTree(p, false) From cffd3efa91c8cb95d104ec1cbebd2e1c0a8981c0 Mon Sep 17 00:00:00 2001 From: Seena Fallah Date: Sun, 7 Mar 2021 22:57:01 +0330 Subject: [PATCH 117/446] Avoid context canceled errors (#1789) * Avoid context canceled errors Return 499 Client Closed Request when the client has closed the request before the server could send a response Signed-off-by: Seena Fallah --- middleware/proxy_1_11.go | 24 +++++++++++++++++++++++- middleware/proxy_1_11_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/middleware/proxy_1_11.go b/middleware/proxy_1_11.go index a43927817..17d142d8d 100644 --- a/middleware/proxy_1_11.go +++ b/middleware/proxy_1_11.go @@ -3,13 +3,22 @@ package middleware import ( + "context" "fmt" "net/http" "net/http/httputil" + "strings" "github.com/labstack/echo/v4" ) +// StatusCodeContextCanceled is a custom HTTP status code for situations +// where a client unexpectedly closed the connection to the server. +// As there is no standard error code for "client closed connection", but +// various well-known HTTP clients and server implement this HTTP code we use +// 499 too instead of the more problematic 5xx, which does not allow to detect this situation +const StatusCodeContextCanceled = 499 + func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { proxy := httputil.NewSingleHostReverseProxy(tgt.URL) proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { @@ -17,7 +26,20 @@ func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handle if tgt.Name != "" { desc = fmt.Sprintf("%s(%s)", tgt.Name, tgt.URL.String()) } - c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err))) + // If the client canceled the request (usually by closing the connection), we can report a + // client error (4xx) instead of a server error (5xx) to correctly identify the situation. + // The Go standard library (at of late 2020) wraps the exported, standard + // context.Canceled error with unexported garbage value requiring a substring check, see + // https://github.com/golang/go/blob/6965b01ea248cabb70c3749fd218b36089a21efb/src/net/net.go#L416-L430 + if err == context.Canceled || strings.Contains(err.Error(), "operation was canceled") { + httpError := echo.NewHTTPError(StatusCodeContextCanceled, fmt.Sprintf("client closed connection: %v", err)) + httpError.Internal = err + c.Set("_error", httpError) + } else { + httpError := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err)) + httpError.Internal = err + c.Set("_error", httpError) + } } proxy.Transport = config.Transport proxy.ModifyResponse = config.ModifyResponse diff --git a/middleware/proxy_1_11_test.go b/middleware/proxy_1_11_test.go index 26feaabaa..c3541d5e8 100644 --- a/middleware/proxy_1_11_test.go +++ b/middleware/proxy_1_11_test.go @@ -3,10 +3,13 @@ package middleware import ( + "context" "net/http" "net/http/httptest" "net/url" + "sync" "testing" + "time" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" @@ -51,3 +54,33 @@ func TestProxy_1_11(t *testing.T) { assert.Equal(t, "/api/users", req.URL.Path) assert.Equal(t, http.StatusBadGateway, rec.Code) } + +func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { + var timeoutStop sync.WaitGroup + timeoutStop.Add(1) + HTTPTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + timeoutStop.Wait() // wait until we have canceled the request + w.WriteHeader(http.StatusOK) + })) + defer HTTPTarget.Close() + targetURL, _ := url.Parse(HTTPTarget.URL) + target := &ProxyTarget{ + Name: "target", + URL: targetURL, + } + rb := NewRandomBalancer(nil) + assert.True(t, rb.AddTarget(target)) + e := echo.New() + e.Use(Proxy(rb)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + ctx, cancel := context.WithCancel(req.Context()) + req = req.WithContext(ctx) + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + e.ServeHTTP(rec, req) + timeoutStop.Done() + assert.Equal(t, 499, rec.Code) +} From 5622ecc1808899388fb24aeee70d6ced353a731e Mon Sep 17 00:00:00 2001 From: Martti T Date: Mon, 8 Mar 2021 03:01:02 +0200 Subject: [PATCH 118/446] Fix performance regression caused by path escaping (#1777, #1798, #1799) * Fix performance regression #1777 and avoid double escaping in rewrite/proxy middleware. * Add rewrite test for correct escaping of replacement (#1798) Co-authored-by: Roland Lammel --- echo.go | 16 +++- echo_test.go | 89 ++++++++++++++++++++-- middleware/middleware.go | 25 ++++++- middleware/middleware_test.go | 69 +++++++++++++++++ middleware/proxy_test.go | 128 ++++++++++++++++++++------------ middleware/rewrite_test.go | 134 +++++++++++++++++++++++++++------- 6 files changed, 376 insertions(+), 85 deletions(-) create mode 100644 middleware/middleware_test.go diff --git a/echo.go b/echo.go index 1074ba492..0b49f4112 100644 --- a/echo.go +++ b/echo.go @@ -629,12 +629,12 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { h := NotFoundHandler if e.premiddleware == nil { - e.findRouter(r.Host).Find(r.Method, r.URL.EscapedPath(), c) + e.findRouter(r.Host).Find(r.Method, GetPath(r), c) h = c.Handler() h = applyMiddleware(h, e.middleware...) } else { h = func(c Context) error { - e.findRouter(r.Host).Find(r.Method, r.URL.EscapedPath(), c) + e.findRouter(r.Host).Find(r.Method, GetPath(r), c) h := c.Handler() h = applyMiddleware(h, e.middleware...) return h(c) @@ -909,6 +909,18 @@ func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { } } +// GetPath returns RawPath, if it's empty returns Path from URL +// Difference between RawPath and Path is: +// * Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/. +// * RawPath is an optional field which only gets set if the default encoding is different from Path. +func GetPath(r *http.Request) string { + path := r.URL.RawPath + if path == "" { + path = r.URL.Path + } + return path +} + func (e *Echo) findRouter(host string) *Router { if len(e.routers) > 0 { if r, ok := e.routers[host]; ok { diff --git a/echo_test.go b/echo_test.go index 07661b9f8..35c79cbc0 100644 --- a/echo_test.go +++ b/echo_test.go @@ -468,15 +468,46 @@ func TestEchoRoutes(t *testing.T) { } } -func TestEchoEncodedPath(t *testing.T) { +func TestEchoServeHTTPPathEncoding(t *testing.T) { e := New() + e.GET("/with/slash", func(c Context) error { + return c.String(http.StatusOK, "/with/slash") + }) e.GET("/:id", func(c Context) error { - return c.NoContent(http.StatusOK) + return c.String(http.StatusOK, c.Param("id")) }) - req := httptest.NewRequest(http.MethodGet, "/with%2Fslash", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) + + var testCases = []struct { + name string + whenURL string + expectURL string + expectStatus int + }{ + { + name: "url with encoding is not decoded for routing", + whenURL: "/with%2Fslash", + expectURL: "with%2Fslash", // `%2F` is not decoded to `/` for routing + expectStatus: http.StatusOK, + }, + { + name: "url without encoding is used as is", + whenURL: "/with/slash", + expectURL: "/with/slash", + expectStatus: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectStatus, rec.Code) + assert.Equal(t, tc.expectURL, rec.Body.String()) + }) + } } func TestEchoGroup(t *testing.T) { @@ -1211,3 +1242,49 @@ func TestEcho_StartServer(t *testing.T) { }) } } + +func benchmarkEchoRoutes(b *testing.B, routes []*Route) { + e := New() + req := httptest.NewRequest("GET", "/", nil) + u := req.URL + w := httptest.NewRecorder() + + b.ReportAllocs() + + // Add routes + for _, route := range routes { + e.Add(route.Method, route.Path, func(c Context) error { + return nil + }) + } + + // Find routes + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, route := range routes { + req.Method = route.Method + u.Path = route.Path + e.ServeHTTP(w, req) + } + } +} + +func BenchmarkEchoStaticRoutes(b *testing.B) { + benchmarkEchoRoutes(b, staticRoutes) +} + +func BenchmarkEchoStaticRoutesMisses(b *testing.B) { + benchmarkEchoRoutes(b, staticRoutes) +} + +func BenchmarkEchoGitHubAPI(b *testing.B) { + benchmarkEchoRoutes(b, gitHubAPI) +} + +func BenchmarkEchoGitHubAPIMisses(b *testing.B) { + benchmarkEchoRoutes(b, gitHubAPI) +} + +func BenchmarkEchoParseAPI(b *testing.B) { + benchmarkEchoRoutes(b, parseAPI) +} diff --git a/middleware/middleware.go b/middleware/middleware.go index 8381e3a5d..6bdb0eb79 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -2,6 +2,7 @@ package middleware import ( "net/http" + "net/url" "regexp" "strconv" "strings" @@ -50,10 +51,26 @@ func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string { func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) { for k, v := range rewriteRegex { - replacerRawPath := captureTokens(k, req.URL.EscapedPath()) - if replacerRawPath != nil { - replacerPath := captureTokens(k, req.URL.Path) - req.URL.RawPath, req.URL.Path = replacerRawPath.Replace(v), replacerPath.Replace(v) + rawPath := req.URL.RawPath + if rawPath != "" { + // RawPath is only set when there has been escaping done. In that case Path must be deduced from rewritten RawPath + // because encoded Path could match rules that RawPath did not + if replacer := captureTokens(k, rawPath); replacer != nil { + rawPath = replacer.Replace(v) + + req.URL.RawPath = rawPath + req.URL.Path, _ = url.PathUnescape(rawPath) + + return // rewrite only once + } + + continue + } + + if replacer := captureTokens(k, req.URL.Path); replacer != nil { + req.URL.Path = replacer.Replace(v) + + return // rewrite only once } } } diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go new file mode 100644 index 000000000..bc14c531d --- /dev/null +++ b/middleware/middleware_test.go @@ -0,0 +1,69 @@ +package middleware + +import ( + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "regexp" + "testing" +) + +func TestRewritePath(t *testing.T) { + var testCases = []struct { + whenURL string + expectPath string + expectRawPath string + }{ + { + whenURL: "http://localhost:8080/old", + expectPath: "/new", + expectRawPath: "", + }, + { // encoded `ol%64` (decoded `old`) should not be rewritten to `/new` + whenURL: "/ol%64", // `%64` is decoded `d` + expectPath: "/old", + expectRawPath: "/ol%64", + }, + { + whenURL: "http://localhost:8080/users/+_+/orders/___++++?test=1", + expectPath: "/user/+_+/order/___++++", + expectRawPath: "", + }, + { + whenURL: "http://localhost:8080/users/%20a/orders/%20aa", + expectPath: "/user/ a/order/ aa", + expectRawPath: "", + }, + { + whenURL: "http://localhost:8080/%47%6f%2f", + expectPath: "/Go/", + expectRawPath: "/%47%6f%2f", + }, + { + whenURL: "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F", + expectPath: "/user/jill/order/T/cO4lW/t/Vp/", + expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + }, + { // do nothing, replace nothing + whenURL: "http://localhost:8080/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + expectPath: "/user/jill/order/T/cO4lW/t/Vp/", + expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + }, + } + + rules := map[*regexp.Regexp]string{ + regexp.MustCompile("^/old$"): "/new", + regexp.MustCompile("^/users/(.*?)/orders/(.*?)$"): "/user/$1/order/$2", + } + + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + + rewritePath(rules, req) + + assert.Equal(t, tc.expectPath, req.URL.Path) // Path field is stored in decoded form: /%47%6f%2f becomes /Go/. + assert.Equal(t, tc.expectRawPath, req.URL.RawPath) // RawPath, an optional field which only gets set if the default encoding is different from Path. + }) + } +} diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index eb72f16ee..591981e7f 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -75,10 +75,12 @@ func TestProxy(t *testing.T) { rrb := NewRoundRobinBalancer(targets) e = echo.New() e.Use(Proxy(rrb)) + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) body = rec.Body.String() assert.Equal(t, "target 1", body) + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) body = rec.Body.String() @@ -94,6 +96,7 @@ func TestProxy(t *testing.T) { return nil }, })) + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) assert.Equal(t, "modified", rec.Body.String()) @@ -108,6 +111,7 @@ func TestProxy(t *testing.T) { } } rrb1 := NewRoundRobinBalancer(targets) + e = echo.New() e.Use(contextObserver) e.Use(Proxy(rrb1)) @@ -159,54 +163,84 @@ func TestProxyRealIPHeader(t *testing.T) { } func TestProxyRewrite(t *testing.T) { - // Setup - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) - defer upstream.Close() - url, _ := url.Parse(upstream.URL) - rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - - // Rewrite - e := echo.New() - e.Use(ProxyWithConfig(ProxyConfig{ - Balancer: rrb, - Rewrite: map[string]string{ - "/old": "/new", - "/api/*": "/$1", - "/js/*": "/public/javascripts/$1", - "/users/*/orders/*": "/user/$1/order/$2", + var testCases = []struct { + whenPath string + expectProxiedURI string + expectStatus int + }{ + { + whenPath: "/api/users", + expectProxiedURI: "/users", + expectStatus: http.StatusOK, }, - })) - req.URL, _ = url.Parse("/api/users") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/users", req.URL.EscapedPath()) - assert.Equal(t, http.StatusOK, rec.Code) - req.URL, _ = url.Parse("/js/main.js") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) - assert.Equal(t, http.StatusOK, rec.Code) - req.URL, _ = url.Parse("/old") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/new", req.URL.EscapedPath()) - assert.Equal(t, http.StatusOK, rec.Code) - req.URL, _ = url.Parse("/users/jack/orders/1") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) - assert.Equal(t, http.StatusOK, rec.Code) - req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) - assert.Equal(t, http.StatusOK, rec.Code) - req.URL, _ = url.Parse("/api/new users") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/new%20users", req.URL.EscapedPath()) + { + whenPath: "/js/main.js", + expectProxiedURI: "/public/javascripts/main.js", + expectStatus: http.StatusOK, + }, + { + whenPath: "/old", + expectProxiedURI: "/new", + expectStatus: http.StatusOK, + }, + { + whenPath: "/users/jack/orders/1", + expectProxiedURI: "/user/jack/order/1", + expectStatus: http.StatusOK, + }, + { + whenPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + expectProxiedURI: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + expectStatus: http.StatusOK, + }, + { // ` ` (space) is encoded by httpClient to `%20` when doing request to Echo. `%20` should not be double escaped when proxying request + whenPath: "/api/new users", + expectProxiedURI: "/new%20users", + expectStatus: http.StatusOK, + }, + { // query params should be proxied and not be modified + whenPath: "/api/users?limit=10", + expectProxiedURI: "/users?limit=10", + expectStatus: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.whenPath, func(t *testing.T) { + receivedRequestURI := make(chan string, 1) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server + // we need unmodified target to see if we are encoding/decoding the url in addition to rewrite/replace logic + // if original request had `%2F` we should not magically decode it to `/` as it would change what was requested + receivedRequestURI <- r.RequestURI + })) + defer upstream.Close() + serverURL, _ := url.Parse(upstream.URL) + rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: serverURL}}) + + // Rewrite + e := echo.New() + e.Use(ProxyWithConfig(ProxyConfig{ + Balancer: rrb, + Rewrite: map[string]string{ + "/old": "/new", + "/api/*": "/$1", + "/js/*": "/public/javascripts/$1", + "/users/*/orders/*": "/user/$1/order/$2", + }, + })) + + targetURL, _ := serverURL.Parse(tc.whenPath) + req := httptest.NewRequest(http.MethodGet, targetURL.String(), nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectStatus, rec.Code) + actualRequestURI := <-receivedRequestURI + assert.Equal(t, tc.expectProxiedURI, actualRequestURI) + }) + } } func TestProxyRewriteRegex(t *testing.T) { diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index 84006e32e..cff2714d7 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -12,9 +12,9 @@ import ( "github.com/stretchr/testify/assert" ) -//Assert expected with url.EscapedPath method to obtain the path. -func TestRewrite(t *testing.T) { +func TestRewriteAfterRouting(t *testing.T) { e := echo.New() + // middlewares added with `Use()` are executed after routing is done and do not affect which route handler is matched e.Use(RewriteWithConfig(RewriteConfig{ Rules: map[string]string{ "/old": "/new", @@ -23,30 +23,71 @@ func TestRewrite(t *testing.T) { "/users/*/orders/*": "/user/$1/order/$2", }, })) - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - req.URL, _ = url.Parse("/api/users") - e.ServeHTTP(rec, req) - assert.Equal(t, "/users", req.URL.EscapedPath()) - req.URL, _ = url.Parse("/js/main.js") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/public/javascripts/main.js", req.URL.EscapedPath()) - req.URL, _ = url.Parse("/old") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/new", req.URL.EscapedPath()) - req.URL, _ = url.Parse("/users/jack/orders/1") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/user/jack/order/1", req.URL.EscapedPath()) - req.URL, _ = url.Parse("/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F") - rec = httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", req.URL.EscapedPath()) - req.URL, _ = url.Parse("/api/new users") - e.ServeHTTP(rec, req) - assert.Equal(t, "/new%20users", req.URL.EscapedPath()) + e.GET("/public/*", func(c echo.Context) error { + return c.String(http.StatusOK, c.Param("*")) + }) + e.GET("/*", func(c echo.Context) error { + return c.String(http.StatusOK, c.Param("*")) + }) + + var testCases = []struct { + whenPath string + expectRoutePath string + expectRequestPath string + expectRequestRawPath string + }{ + { + whenPath: "/api/users", + expectRoutePath: "api/users", + expectRequestPath: "/users", + expectRequestRawPath: "", + }, + { + whenPath: "/js/main.js", + expectRoutePath: "js/main.js", + expectRequestPath: "/public/javascripts/main.js", + expectRequestRawPath: "", + }, + { + whenPath: "/users/jack/orders/1", + expectRoutePath: "users/jack/orders/1", + expectRequestPath: "/user/jack/order/1", + expectRequestRawPath: "", + }, + { // no rewrite rule matched. already encoded URL should not be double encoded or changed in any way + whenPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + expectRoutePath: "user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + expectRequestPath: "/user/jill/order/T/cO4lW/t/Vp/", // this is equal to `url.Parse(tc.whenPath)` result + expectRequestRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + }, + { // just rewrite but do not touch encoding. already encoded URL should not be double encoded + whenPath: "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F", + expectRoutePath: "users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F", + expectRequestPath: "/user/jill/order/T/cO4lW/t/Vp/", // this is equal to `url.Parse(tc.whenPath)` result + expectRequestRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + }, + { // ` ` (space) is encoded by httpClient to `%20` when doing request to Echo. `%20` should not be double escaped or changed in any way when rewriting request + whenPath: "/api/new users", + expectRoutePath: "api/new users", + expectRequestPath: "/new users", + expectRequestRawPath: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.whenPath, func(t *testing.T) { + target, _ := url.Parse(tc.whenPath) + req := httptest.NewRequest(http.MethodGet, target.String(), nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, tc.expectRoutePath, rec.Body.String()) + assert.Equal(t, tc.expectRequestPath, req.URL.Path) + assert.Equal(t, tc.expectRequestRawPath, req.URL.RawPath) + }) + } } // Issue #1086 @@ -55,6 +96,7 @@ func TestEchoRewritePreMiddleware(t *testing.T) { r := e.Router() // Rewrite old url to new one + // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches e.Pre(Rewrite(map[string]string{ "/old": "/new", }, @@ -77,6 +119,7 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { e := echo.New() r := e.Router() + // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches e.Pre(RewriteWithConfig(RewriteConfig{ Rules: map[string]string{ "/api/*/mgmt/proj/*/agt": "/api/$1/hosts/$2", @@ -172,3 +215,42 @@ func TestEchoRewriteWithRegexRules(t *testing.T) { }) } } + +// Ensure correct escaping as defined in replacement (issue #1798) +func TestEchoRewriteReplacementEscaping(t *testing.T) { + e := echo.New() + + e.Pre(RewriteWithConfig(RewriteConfig{ + Rules: map[string]string{ + "^/a/*": "/$1?query=param", + "^/b/*": "/$1;part#one", + }, + RegexRules: map[*regexp.Regexp]string{ + regexp.MustCompile("^/x/(.*)"): "/$1?query=param", + regexp.MustCompile("^/y/(.*)"): "/$1;part#one", + }, + })) + + var rec *httptest.ResponseRecorder + var req *http.Request + + testCases := []struct { + requestPath string + expectPath string + }{ + {"/unmatched", "/unmatched"}, + {"/a/test", "/test?query=param"}, + {"/b/foo/bar", "/foo/bar;part#one"}, + {"/x/test", "/test?query=param"}, + {"/y/foo/bar", "/foo/bar;part#one"}, + } + + for _, tc := range testCases { + t.Run(tc.requestPath, func(t *testing.T) { + req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil) + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expectPath, req.URL.Path) + }) + } +} From d6127fe316464d9a2ae0245682a84fc189f6a676 Mon Sep 17 00:00:00 2001 From: Martti T Date: Mon, 8 Mar 2021 03:13:22 +0200 Subject: [PATCH 119/446] Rework timeout middleware to use http.TimeoutHandler implementation (fix #1761) (#1801) --- middleware/timeout.go | 94 ++++++++++++-------- middleware/timeout_test.go | 175 ++++++++++++++++++++++++------------- 2 files changed, 172 insertions(+), 97 deletions(-) diff --git a/middleware/timeout.go b/middleware/timeout.go index 4be557f76..68f464e40 100644 --- a/middleware/timeout.go +++ b/middleware/timeout.go @@ -4,8 +4,8 @@ package middleware import ( "context" - "fmt" "github.com/labstack/echo/v4" + "net/http" "time" ) @@ -14,16 +14,23 @@ type ( TimeoutConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper - // ErrorHandler defines a function which is executed for a timeout - // It can be used to define a custom timeout error - ErrorHandler TimeoutErrorHandlerWithContext + + // ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code + // It can be used to define a custom timeout error message + ErrorMessage string + + // OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after + // request timeouted and we already had sent the error code (503) and message response to the client. + // NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer + // will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()` + OnTimeoutRouteErrorHandler func(err error, c echo.Context) + // Timeout configures a timeout for the middleware, defaults to 0 for no timeout + // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) + // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output + // difference over 500microseconds (0.5millisecond) response seems to be reliable Timeout time.Duration } - - // TimeoutErrorHandlerWithContext is an error handler that is used with the timeout middleware so we can - // handle the error as we see fit - TimeoutErrorHandlerWithContext func(error, echo.Context) error ) var ( @@ -31,7 +38,7 @@ var ( DefaultTimeoutConfig = TimeoutConfig{ Skipper: DefaultSkipper, Timeout: 0, - ErrorHandler: nil, + ErrorMessage: "", } ) @@ -55,39 +62,50 @@ func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc { return next(c) } - ctx, cancel := context.WithTimeout(c.Request().Context(), config.Timeout) - defer cancel() - - // this does a deep clone of the context, wondering if there is a better way to do this? - c.SetRequest(c.Request().Clone(ctx)) - - done := make(chan error, 1) - go func() { - defer func() { - if r := recover(); r != nil { - err, ok := r.(error) - if !ok { - err = fmt.Errorf("panic recovered in timeout middleware: %v", r) - } - c.Logger().Error(err) - done <- err - } - }() - - // This goroutine will keep running even if this middleware times out and - // will be stopped when ctx.Done() is called down the next(c) call chain - done <- next(c) - }() + handlerWrapper := echoHandlerFuncWrapper{ + ctx: c, + handler: next, + errChan: make(chan error, 1), + errHandler: config.OnTimeoutRouteErrorHandler, + } + handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage) + handler.ServeHTTP(c.Response().Writer, c.Request()) select { - case <-ctx.Done(): - if config.ErrorHandler != nil { - return config.ErrorHandler(ctx.Err(), c) - } - return ctx.Err() - case err := <-done: + case err := <-handlerWrapper.errChan: return err + default: + return nil } } } } + +type echoHandlerFuncWrapper struct { + ctx echo.Context + handler echo.HandlerFunc + errHandler func(err error, c echo.Context) + errChan chan error +} + +func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + // replace writer with TimeoutHandler custom one. This will guarantee that + // `writes by h to its ResponseWriter will return ErrHandlerTimeout.` + originalWriter := t.ctx.Response().Writer + t.ctx.Response().Writer = rw + + err := t.handler(t.ctx) + if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded { + if err != nil && t.errHandler != nil { + t.errHandler(err, t.ctx) + } + return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers + } + // we restore original writer only for cases we did not timeout. On timeout we have already sent response to client + // and should not anymore send additional headers/data + // so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body + t.ctx.Response().Writer = originalWriter + if err != nil { + t.errChan <- err + } +} diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index faecc4c53..af4c62647 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -3,7 +3,6 @@ package middleware import ( - "context" "errors" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" @@ -22,6 +21,7 @@ func TestTimeoutSkipper(t *testing.T) { Skipper: func(context echo.Context) bool { return true }, + Timeout: 1 * time.Nanosecond, }) req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -31,18 +31,17 @@ func TestTimeoutSkipper(t *testing.T) { c := e.NewContext(req, rec) err := m(func(c echo.Context) error { - assert.NotEqual(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String()) - return nil + time.Sleep(25 * time.Microsecond) + return errors.New("response from handler") })(c) - assert.NoError(t, err) + // if not skipped we would have not returned error due context timeout logic + assert.EqualError(t, err, "response from handler") } func TestTimeoutWithTimeout0(t *testing.T) { t.Parallel() - m := TimeoutWithConfig(TimeoutConfig{ - Timeout: 0, - }) + m := Timeout() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() @@ -58,10 +57,11 @@ func TestTimeoutWithTimeout0(t *testing.T) { assert.NoError(t, err) } -func TestTimeoutIsCancelable(t *testing.T) { +func TestTimeoutErrorOutInHandler(t *testing.T) { t.Parallel() m := TimeoutWithConfig(TimeoutConfig{ - Timeout: time.Minute, + // Timeout has to be defined or the whole flow for timeout middleware will be skipped + Timeout: 50 * time.Millisecond, }) req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -70,24 +70,6 @@ func TestTimeoutIsCancelable(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) - err := m(func(c echo.Context) error { - assert.EqualValues(t, "*context.timerCtx", reflect.TypeOf(c.Request().Context()).String()) - return nil - })(c) - - assert.NoError(t, err) -} - -func TestTimeoutErrorOutInHandler(t *testing.T) { - t.Parallel() - m := Timeout() - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - - e := echo.New() - c := e.NewContext(req, rec) - err := m(func(c echo.Context) error { return errors.New("err") })(c) @@ -95,34 +77,15 @@ func TestTimeoutErrorOutInHandler(t *testing.T) { assert.Error(t, err) } -func TestTimeoutTimesOutAfterPredefinedTimeoutWithErrorHandler(t *testing.T) { +func TestTimeoutOnTimeoutRouteErrorHandler(t *testing.T) { t.Parallel() - m := TimeoutWithConfig(TimeoutConfig{ - Timeout: time.Second, - ErrorHandler: func(err error, e echo.Context) error { - assert.EqualError(t, err, context.DeadlineExceeded.Error()) - return errors.New("err") - }, - }) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - - e := echo.New() - c := e.NewContext(req, rec) - err := m(func(c echo.Context) error { - time.Sleep(time.Minute) - return nil - })(c) - - assert.EqualError(t, err, errors.New("err").Error()) -} - -func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) { - t.Parallel() + actualErrChan := make(chan error, 1) m := TimeoutWithConfig(TimeoutConfig{ - Timeout: time.Second, + Timeout: 1 * time.Millisecond, + OnTimeoutRouteErrorHandler: func(err error, c echo.Context) { + actualErrChan <- err + }, }) req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -131,12 +94,16 @@ func TestTimeoutTimesOutAfterPredefinedTimeout(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) + stopChan := make(chan struct{}, 0) err := m(func(c echo.Context) error { - time.Sleep(time.Minute) - return nil + <-stopChan + return errors.New("error in route after timeout") })(c) + stopChan <- struct{}{} + assert.NoError(t, err) - assert.EqualError(t, err, context.DeadlineExceeded.Error()) + actualErr := <-actualErrChan + assert.EqualError(t, actualErr, "error in route after timeout") } func TestTimeoutTestRequestClone(t *testing.T) { @@ -148,7 +115,7 @@ func TestTimeoutTestRequestClone(t *testing.T) { m := TimeoutWithConfig(TimeoutConfig{ // Timeout has to be defined or the whole flow for timeout middleware will be skipped - Timeout: time.Second, + Timeout: 1 * time.Second, }) e := echo.New() @@ -178,8 +145,63 @@ func TestTimeoutTestRequestClone(t *testing.T) { func TestTimeoutRecoversPanic(t *testing.T) { t.Parallel() + e := echo.New() + e.Use(Recover()) // recover middleware will handler our panic + e.Use(TimeoutWithConfig(TimeoutConfig{ + Timeout: 50 * time.Millisecond, + })) + + e.GET("/", func(c echo.Context) error { + panic("panic!!!") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + assert.NotPanics(t, func() { + e.ServeHTTP(rec, req) + }) +} + +func TestTimeoutDataRace(t *testing.T) { + t.Parallel() + + timeout := 1 * time.Millisecond + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: timeout, + ErrorMessage: "Timeout! change me", + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) + // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output + // difference over 500microseconds (0.5millisecond) response seems to be reliable + time.Sleep(timeout) // timeout and handler execution time difference is close to zero + return c.String(http.StatusOK, "Hello, World!") + })(c) + + assert.NoError(t, err) + + if rec.Code == http.StatusServiceUnavailable { + assert.Equal(t, "Timeout! change me", rec.Body.String()) + } else { + assert.Equal(t, "Hello, World!", rec.Body.String()) + } +} + +func TestTimeoutWithErrorMessage(t *testing.T) { + t.Parallel() + + timeout := 1 * time.Millisecond m := TimeoutWithConfig(TimeoutConfig{ - Timeout: 25 * time.Millisecond, + Timeout: timeout, + ErrorMessage: "Timeout! change me", }) req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -188,9 +210,44 @@ func TestTimeoutRecoversPanic(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) + stopChan := make(chan struct{}, 0) err := m(func(c echo.Context) error { - panic("panic in handler") + // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) + // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output + // difference over 500microseconds (0.5millisecond) response seems to be reliable + <-stopChan + return c.String(http.StatusOK, "Hello, World!") })(c) + stopChan <- struct{}{} + + assert.NoError(t, err) + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) + assert.Equal(t, "Timeout! change me", rec.Body.String()) +} + +func TestTimeoutWithDefaultErrorMessage(t *testing.T) { + t.Parallel() - assert.Error(t, err, "panic recovered in timeout middleware: panic in handler") + timeout := 1 * time.Millisecond + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: timeout, + ErrorMessage: "", + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + stopChan := make(chan struct{}, 0) + err := m(func(c echo.Context) error { + <-stopChan + return c.String(http.StatusOK, "Hello, World!") + })(c) + stopChan <- struct{}{} + + assert.NoError(t, err) + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) + assert.Equal(t, `Timeout

Timeout

`, rec.Body.String()) } From a97052edaf781a731903d816c9b271028d709131 Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Mon, 8 Mar 2021 02:33:04 +0100 Subject: [PATCH 120/446] Update version to v4.2.1 --- CHANGELOG.md | 22 ++++++++++++++++++++++ echo.go | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 33f5587f8..b50478830 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,27 @@ # Changelog +## v4.2.1 - 2020-03-08 + +**Important notes** + +Due to a datarace the config parameters for the newly added timeout middleware required a change. +See the [docs](https://echo.labstack.com/middleware/timeout). +A performance regression has been fixed, even bringing better performance than before for some routing scenarios. + +**Fixes** + +* Fix performance regression caused by path escaping (#1777, #1798, #1799, aldas) +* Avoid context canceled errors (#1789, clwluvw) +* Improve router to use on stack backtracking (#1791, aldas, stffabi) +* Fix panic in timeout middleware not being not recovered and cause application crash (#1794, aldas) +* Fix Echo.Serve() not serving on HTTP port correctly when TLSListener is used (#1785, #1793, aldas) +* Apply go fmt (#1788, Le0tk0k) +* Uses strings.Equalfold (#1790, rkilingr) +* Improve code quality (#1792, withshubh) + +This release was made possible by our **contributors**: +aldas, clwluvw, lammel, Le0tk0k, maciej-jezierski, rkilingr, stffabi, withshubh + ## v4.2.0 - 2020-02-11 **Important notes** diff --git a/echo.go b/echo.go index 0b49f4112..3fccaf648 100644 --- a/echo.go +++ b/echo.go @@ -234,7 +234,7 @@ const ( const ( // Version of Echo - Version = "4.2.0" + Version = "4.2.1" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 4c2fd1fb042b122e2f96830ddb58aee6c9f90bf3 Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 9 Mar 2021 14:22:11 +0200 Subject: [PATCH 121/446] Allow proxy middleware to use query part in rewrite (fix #1798) (#1802) --- middleware/middleware.go | 46 ++++++++++++++++++++--------------- middleware/middleware_test.go | 29 +++++++++++++++++++--- middleware/proxy.go | 5 ++-- middleware/proxy_test.go | 27 ++++++++++++++------ middleware/rewrite.go | 6 ++--- middleware/rewrite_test.go | 9 +++++-- 6 files changed, 85 insertions(+), 37 deletions(-) diff --git a/middleware/middleware.go b/middleware/middleware.go index 6bdb0eb79..a7ad73a5c 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -2,7 +2,6 @@ package middleware import ( "net/http" - "net/url" "regexp" "strconv" "strings" @@ -49,30 +48,39 @@ func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string { return rulesRegex } -func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) { - for k, v := range rewriteRegex { - rawPath := req.URL.RawPath - if rawPath != "" { - // RawPath is only set when there has been escaping done. In that case Path must be deduced from rewritten RawPath - // because encoded Path could match rules that RawPath did not - if replacer := captureTokens(k, rawPath); replacer != nil { - rawPath = replacer.Replace(v) - - req.URL.RawPath = rawPath - req.URL.Path, _ = url.PathUnescape(rawPath) - - return // rewrite only once - } +func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error { + if len(rewriteRegex) == 0 { + return nil + } - continue + // Depending how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path. + // We only want to use path part for rewriting and therefore trim prefix if it exists + rawURI := req.RequestURI + if rawURI != "" && rawURI[0] != '/' { + prefix := "" + if req.URL.Scheme != "" { + prefix = req.URL.Scheme + "://" } + if req.URL.Host != "" { + prefix += req.URL.Host // host or host:port + } + if prefix != "" { + rawURI = strings.TrimPrefix(rawURI, prefix) + } + } - if replacer := captureTokens(k, req.URL.Path); replacer != nil { - req.URL.Path = replacer.Replace(v) + for k, v := range rewriteRegex { + if replacer := captureTokens(k, rawURI); replacer != nil { + url, err := req.URL.Parse(replacer.Replace(v)) + if err != nil { + return err + } + req.URL = url - return // rewrite only once + return nil // rewrite only once } } + return nil } // DefaultSkipper returns false which processes the middleware. diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index bc14c531d..44f44142c 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -8,11 +8,13 @@ import ( "testing" ) -func TestRewritePath(t *testing.T) { +func TestRewriteURL(t *testing.T) { var testCases = []struct { whenURL string expectPath string expectRawPath string + expectQuery string + expectErr string }{ { whenURL: "http://localhost:8080/old", @@ -28,6 +30,7 @@ func TestRewritePath(t *testing.T) { whenURL: "http://localhost:8080/users/+_+/orders/___++++?test=1", expectPath: "/user/+_+/order/___++++", expectRawPath: "", + expectQuery: "test=1", }, { whenURL: "http://localhost:8080/users/%20a/orders/%20aa", @@ -35,9 +38,10 @@ func TestRewritePath(t *testing.T) { expectRawPath: "", }, { - whenURL: "http://localhost:8080/%47%6f%2f", + whenURL: "http://localhost:8080/%47%6f%2f?test=1", expectPath: "/Go/", expectRawPath: "/%47%6f%2f", + expectQuery: "test=1", }, { whenURL: "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F", @@ -49,21 +53,40 @@ func TestRewritePath(t *testing.T) { expectPath: "/user/jill/order/T/cO4lW/t/Vp/", expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", }, + { + whenURL: "http://localhost:8080/static", + expectPath: "/static/path", + expectRawPath: "", + expectQuery: "role=AUTHOR&limit=1000", + }, + { + whenURL: "/static", + expectPath: "/static/path", + expectRawPath: "", + expectQuery: "role=AUTHOR&limit=1000", + }, } rules := map[*regexp.Regexp]string{ regexp.MustCompile("^/old$"): "/new", regexp.MustCompile("^/users/(.*?)/orders/(.*?)$"): "/user/$1/order/$2", + regexp.MustCompile("^/static$"): "/static/path?role=AUTHOR&limit=1000", } for _, tc := range testCases { t.Run(tc.whenURL, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) - rewritePath(rules, req) + err := rewriteURL(rules, req) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } assert.Equal(t, tc.expectPath, req.URL.Path) // Path field is stored in decoded form: /%47%6f%2f becomes /Go/. assert.Equal(t, tc.expectRawPath, req.URL.RawPath) // RawPath, an optional field which only gets set if the default encoding is different from Path. + assert.Equal(t, tc.expectQuery, req.URL.RawQuery) }) } } diff --git a/middleware/proxy.go b/middleware/proxy.go index 63eec5a20..6f01f3a7c 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -231,8 +231,9 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { tgt := config.Balancer.Next(c) c.Set(config.ContextKey, tgt) - // Set rewrite path and raw path - rewritePath(config.RegexRewrite, req) + if err := rewriteURL(config.RegexRewrite, req); err != nil { + return err + } // Fix header // Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream. diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 591981e7f..93daf735e 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -245,12 +245,16 @@ func TestProxyRewrite(t *testing.T) { func TestProxyRewriteRegex(t *testing.T) { // Setup - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + receivedRequestURI := make(chan string, 1) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server + // we need unmodified target to see if we are encoding/decoding the url in addition to rewrite/replace logic + // if original request had `%2F` we should not magically decode it to `/` as it would change what was requested + receivedRequestURI <- r.RequestURI + })) defer upstream.Close() - url, _ := url.Parse(upstream.URL) - rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() + tmpUrL, _ := url.Parse(upstream.URL) + rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: tmpUrL}}) // Rewrite e := echo.New() @@ -279,14 +283,21 @@ func TestProxyRewriteRegex(t *testing.T) { {"/c/ignore1/test/this", http.StatusOK, "/v3/test/this"}, {"/x/ignore/test", http.StatusOK, "/v4/test"}, {"/y/foo/bar", http.StatusOK, "/v5/bar/foo"}, + // NB: fragment is not added by golang httputil.NewSingleHostReverseProxy implementation + // $2 = `bar?q=1#frag`, $1 = `foo`. replaced uri = `/v5/bar?q=1#frag/foo` but httputil.NewSingleHostReverseProxy does not send `#frag/foo` (currently) + {"/y/foo/bar?q=1#frag", http.StatusOK, "/v5/bar?q=1"}, } for _, tc := range testCases { t.Run(tc.requestPath, func(t *testing.T) { - req.URL, _ = url.Parse(tc.requestPath) - rec = httptest.NewRecorder() + targetURL, _ := url.Parse(tc.requestPath) + req := httptest.NewRequest(http.MethodGet, targetURL.String(), nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) - assert.Equal(t, tc.expectPath, req.URL.EscapedPath()) + + actualRequestURI := <-receivedRequestURI + assert.Equal(t, tc.expectPath, actualRequestURI) assert.Equal(t, tc.statusCode, rec.Code) }) } diff --git a/middleware/rewrite.go b/middleware/rewrite.go index c05d5d84f..e5b0a6b56 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -72,9 +72,9 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { return next(c) } - req := c.Request() - // Set rewrite path and raw path - rewritePath(config.RegexRules, req) + if err := rewriteURL(config.RegexRules, c.Request()); err != nil { + return err + } return next(c) } } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index cff2714d7..0ac04bb2f 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -220,6 +220,8 @@ func TestEchoRewriteWithRegexRules(t *testing.T) { func TestEchoRewriteReplacementEscaping(t *testing.T) { e := echo.New() + // NOTE: these are incorrect regexps as they do not factor in that URI we are replacing could contain ? (query) and # (fragment) parts + // so in reality they append query and fragment part as `$1` matches everything after that prefix e.Pre(RewriteWithConfig(RewriteConfig{ Rules: map[string]string{ "^/a/*": "/$1?query=param", @@ -228,6 +230,7 @@ func TestEchoRewriteReplacementEscaping(t *testing.T) { RegexRules: map[*regexp.Regexp]string{ regexp.MustCompile("^/x/(.*)"): "/$1?query=param", regexp.MustCompile("^/y/(.*)"): "/$1;part#one", + regexp.MustCompile("^/z/(.*)"): "/$1?test=1#escaped%20test", }, })) @@ -236,13 +239,15 @@ func TestEchoRewriteReplacementEscaping(t *testing.T) { testCases := []struct { requestPath string - expectPath string + expect string }{ {"/unmatched", "/unmatched"}, {"/a/test", "/test?query=param"}, {"/b/foo/bar", "/foo/bar;part#one"}, {"/x/test", "/test?query=param"}, {"/y/foo/bar", "/foo/bar;part#one"}, + {"/z/foo/b%20ar", "/foo/b%20ar?test=1#escaped%20test"}, + {"/z/foo/b%20ar?nope=1#yes", "/foo/b%20ar?nope=1#yes?test=1%23escaped%20test"}, // example of appending } for _, tc := range testCases { @@ -250,7 +255,7 @@ func TestEchoRewriteReplacementEscaping(t *testing.T) { req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, tc.expectPath, req.URL.Path) + assert.Equal(t, tc.expect, req.URL.String()) }) } } From dec96f0312ee1e1117366a420c57dd744444d0da Mon Sep 17 00:00:00 2001 From: Martti T Date: Fri, 12 Mar 2021 13:49:09 +0200 Subject: [PATCH 122/446] fix timeout middleware not sending status code when handler returns an error (fix #1804) (#1805) --- middleware/timeout.go | 5 +++++ middleware/timeout_test.go | 26 +++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/middleware/timeout.go b/middleware/timeout.go index 68f464e40..99d436ac2 100644 --- a/middleware/timeout.go +++ b/middleware/timeout.go @@ -106,6 +106,11 @@ func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Reques // so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body t.ctx.Response().Writer = originalWriter if err != nil { + // call global error handler to write error to the client. This is needed or `http.TimeoutHandler` will send status code by itself + // and after that our tries to write status code will not work anymore + t.ctx.Error(err) + // we pass error from handler to middlewares up in handler chain to act on it if needed. But this means that + // global error handler is probably be called twice as `t.ctx.Error` already does that. t.errChan <- err } } diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index af4c62647..8f8fa3049 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -71,10 +71,34 @@ func TestTimeoutErrorOutInHandler(t *testing.T) { c := e.NewContext(req, rec) err := m(func(c echo.Context) error { - return errors.New("err") + return echo.NewHTTPError(http.StatusTeapot, "err") })(c) assert.Error(t, err) + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.Equal(t, "{\"message\":\"err\"}\n", rec.Body.String()) +} + +func TestTimeoutSuccessfulRequest(t *testing.T) { + t.Parallel() + m := TimeoutWithConfig(TimeoutConfig{ + // Timeout has to be defined or the whole flow for timeout middleware will be skipped + Timeout: 50 * time.Millisecond, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + return c.JSON(http.StatusCreated, map[string]string{"data": "ok"}) + })(c) + + assert.NoError(t, err) + assert.Equal(t, http.StatusCreated, rec.Code) + assert.Equal(t, "{\"data\":\"ok\"}\n", rec.Body.String()) } func TestTimeoutOnTimeoutRouteErrorHandler(t *testing.T) { From 67f6346df242da6cf22a7b593b4ec3631063fe35 Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 6 Apr 2021 10:05:33 +0300 Subject: [PATCH 123/446] Fix Bind() when target is array/slice and path/query params complain target not being struct (#1835) For path/query params binding we do not try (silently return) to bind when target is not struct. Recreates PR #1574 and fixes #1565 --- bind.go | 4 +++ bind_test.go | 70 +++++++++++++++++++++++++++++++++++++++------------- echo_test.go | 4 +++ 3 files changed, 61 insertions(+), 17 deletions(-) diff --git a/bind.go b/bind.go index 16c3b7adf..08d398916 100644 --- a/bind.go +++ b/bind.go @@ -134,6 +134,10 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri // !struct if typ.Kind() != reflect.Struct { + if tag == "param" || tag == "query" { + // incompatible type, data is probably to be found in the body + return nil + } return errors.New("binding element must be a struct") } diff --git a/bind_test.go b/bind_test.go index e8868b35b..73398034e 100644 --- a/bind_test.go +++ b/bind_test.go @@ -9,6 +9,7 @@ import ( "mime/multipart" "net/http" "net/http/httptest" + "net/url" "reflect" "strconv" "strings" @@ -187,7 +188,10 @@ func TestToMultipleFields(t *testing.T) { func TestBindJSON(t *testing.T) { assert := assert.New(t) - testBindOkay(assert, strings.NewReader(userJSON), MIMEApplicationJSON) + testBindOkay(assert, strings.NewReader(userJSON), nil, MIMEApplicationJSON) + testBindOkay(assert, strings.NewReader(userJSON), dummyQuery, MIMEApplicationJSON) + testBindArrayOkay(assert, strings.NewReader(usersJSON), nil, MIMEApplicationJSON) + testBindArrayOkay(assert, strings.NewReader(usersJSON), dummyQuery, MIMEApplicationJSON) testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) testBindError(assert, strings.NewReader(userJSONInvalidType), MIMEApplicationJSON, &json.UnmarshalTypeError{}) } @@ -195,11 +199,15 @@ func TestBindJSON(t *testing.T) { func TestBindXML(t *testing.T) { assert := assert.New(t) - testBindOkay(assert, strings.NewReader(userXML), MIMEApplicationXML) + testBindOkay(assert, strings.NewReader(userXML), nil, MIMEApplicationXML) + testBindOkay(assert, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) + testBindArrayOkay(assert, strings.NewReader(userXML), nil, MIMEApplicationXML) + testBindArrayOkay(assert, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationXML, errors.New("")) testBindError(assert, strings.NewReader(userXMLConvertNumberError), MIMEApplicationXML, &strconv.NumError{}) testBindError(assert, strings.NewReader(userXMLUnsupportedTypeError), MIMEApplicationXML, &xml.SyntaxError{}) - testBindOkay(assert, strings.NewReader(userXML), MIMETextXML) + testBindOkay(assert, strings.NewReader(userXML), nil, MIMETextXML) + testBindOkay(assert, strings.NewReader(userXML), dummyQuery, MIMETextXML) testBindError(assert, strings.NewReader(invalidContent), MIMETextXML, errors.New("")) testBindError(assert, strings.NewReader(userXMLConvertNumberError), MIMETextXML, &strconv.NumError{}) testBindError(assert, strings.NewReader(userXMLUnsupportedTypeError), MIMETextXML, &xml.SyntaxError{}) @@ -208,7 +216,8 @@ func TestBindXML(t *testing.T) { func TestBindForm(t *testing.T) { assert := assert.New(t) - testBindOkay(assert, strings.NewReader(userForm), MIMEApplicationForm) + testBindOkay(assert, strings.NewReader(userForm), nil, MIMEApplicationForm) + testBindOkay(assert, strings.NewReader(userForm), dummyQuery, MIMEApplicationForm) e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userForm)) rec := httptest.NewRecorder() @@ -336,14 +345,16 @@ func TestBindUnmarshalTextPtr(t *testing.T) { } func TestBindMultipartForm(t *testing.T) { - body := new(bytes.Buffer) - mw := multipart.NewWriter(body) + bodyBuffer := new(bytes.Buffer) + mw := multipart.NewWriter(bodyBuffer) mw.WriteField("id", "1") mw.WriteField("name", "Jon Snow") mw.Close() + body := bodyBuffer.Bytes() assert := assert.New(t) - testBindOkay(assert, body, mw.FormDataContentType()) + testBindOkay(assert, bytes.NewReader(body), nil, mw.FormDataContentType()) + testBindOkay(assert, bytes.NewReader(body), dummyQuery, mw.FormDataContentType()) } func TestBindUnsupportedMediaType(t *testing.T) { @@ -547,9 +558,13 @@ func assertBindTestStruct(a *assert.Assertions, ts *bindTestStruct) { a.Equal("", ts.GetCantSet()) } -func testBindOkay(assert *assert.Assertions, r io.Reader, ctype string) { +func testBindOkay(assert *assert.Assertions, r io.Reader, query url.Values, ctype string) { e := New() - req := httptest.NewRequest(http.MethodPost, "/", r) + path := "/" + if len(query) > 0 { + path += "?" + query.Encode() + } + req := httptest.NewRequest(http.MethodPost, path, r) rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(HeaderContentType, ctype) @@ -561,6 +576,25 @@ func testBindOkay(assert *assert.Assertions, r io.Reader, ctype string) { } } +func testBindArrayOkay(assert *assert.Assertions, r io.Reader, query url.Values, ctype string) { + e := New() + path := "/" + if len(query) > 0 { + path += "?" + query.Encode() + } + req := httptest.NewRequest(http.MethodPost, path, r) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + req.Header.Set(HeaderContentType, ctype) + u := []user{} + err := c.Bind(&u) + if assert.NoError(err) { + assert.Equal(1, len(u)) + assert.Equal(1, u[0].ID) + assert.Equal("Jon Snow", u[0].Name) + } +} + func testBindError(assert *assert.Assertions, r io.Reader, ctype string, expectedInternal error) { e := New() req := httptest.NewRequest(http.MethodPost, "/", r) @@ -679,15 +713,16 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { expect: &Opts{ID: 0, Node: "xxx"}, // query binding has already modified bind target expectError: "code=400, message=Unmarshal type error: expected=echo.Opts, got=array, field=, offset=1, internal=json: cannot unmarshal array into Go value of type echo.Opts", }, - { // binding query params interferes with body. b.BindBody() should be used to bind only body to slice - name: "nok, GET query params bind failure - trying to bind json array to slice", + { // query param is ignored as we do not know where exactly to bind it in slice + name: "ok, GET bind to struct slice, ignore query param", givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), whenNoPathParams: true, whenBindTarget: &[]Opts{}, - expect: &[]Opts{}, - expectError: "code=400, message=binding element must be a struct, internal=binding element must be a struct", + expect: &[]Opts{ + {ID: 1, Node: ""}, + }, }, { // binding query params interferes with body. b.BindBody() should be used to bind only body to slice name: "ok, POST binding to slice should not be affected query params types", @@ -699,14 +734,15 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { expect: &[]Opts{{ID: 1}}, expectError: "", }, - { // binding path params interferes with body. b.BindBody() should be used to bind only body to slice - name: "nok, GET path params bind failure - trying to bind json array to slice", + { // path param is ignored as we do not know where exactly to bind it in slice + name: "ok, GET bind to struct slice, ignore path param", givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), whenBindTarget: &[]Opts{}, - expect: &[]Opts{}, - expectError: "code=400, message=binding element must be a struct, internal=binding element must be a struct", + expect: &[]Opts{ + {ID: 1, Node: ""}, + }, }, { name: "ok, GET body bind json array to slice", diff --git a/echo_test.go b/echo_test.go index 35c79cbc0..58ecea741 100644 --- a/echo_test.go +++ b/echo_test.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "net/http/httptest" + "net/url" "os" "reflect" "strings" @@ -30,6 +31,7 @@ type ( const ( userJSON = `{"id":1,"name":"Jon Snow"}` + usersJSON = `[{"id":1,"name":"Jon Snow"}]` userXML = `1Jon Snow` userForm = `id=1&name=Jon Snow` invalidContent = "invalid content" @@ -48,6 +50,8 @@ const userXMLPretty = ` Jon Snow ` +var dummyQuery = url.Values{"dummy": []string{"useless"}} + func TestEcho(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) From ae4665cf7a215d14b3ba769bfa355c5420ce10ef Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 6 Apr 2021 10:11:31 +0300 Subject: [PATCH 124/446] Fix panic in redirect middleware on short host name (fix #1811) (#1813) --- middleware/redirect.go | 45 +++--- middleware/redirect_test.go | 263 +++++++++++++++++++++++++++++++----- 2 files changed, 249 insertions(+), 59 deletions(-) diff --git a/middleware/redirect.go b/middleware/redirect.go index 813e5b856..13877db38 100644 --- a/middleware/redirect.go +++ b/middleware/redirect.go @@ -2,6 +2,7 @@ package middleware import ( "net/http" + "strings" "github.com/labstack/echo/v4" ) @@ -40,11 +41,11 @@ func HTTPSRedirect() echo.MiddlewareFunc { // HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config. // See `HTTPSRedirect()`. func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if ok = scheme != "https"; ok { - url = "https://" + host + uri + return redirect(config, func(scheme, host, uri string) (bool, string) { + if scheme != "https" { + return true, "https://" + host + uri } - return + return false, "" }) } @@ -59,11 +60,11 @@ func HTTPSWWWRedirect() echo.MiddlewareFunc { // HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // See `HTTPSWWWRedirect()`. func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if ok = scheme != "https" && host[:4] != www; ok { - url = "https://www." + host + uri + return redirect(config, func(scheme, host, uri string) (bool, string) { + if scheme != "https" && !strings.HasPrefix(host, www) { + return true, "https://www." + host + uri } - return + return false, "" }) } @@ -79,13 +80,11 @@ func HTTPSNonWWWRedirect() echo.MiddlewareFunc { // See `HTTPSNonWWWRedirect()`. func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if ok = scheme != "https"; ok { - if host[:4] == www { - host = host[4:] - } - url = "https://" + host + uri + if scheme != "https" { + host = strings.TrimPrefix(host, www) + return true, "https://" + host + uri } - return + return false, "" }) } @@ -100,11 +99,11 @@ func WWWRedirect() echo.MiddlewareFunc { // WWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // See `WWWRedirect()`. func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if ok = host[:4] != www; ok { - url = scheme + "://www." + host + uri + return redirect(config, func(scheme, host, uri string) (bool, string) { + if !strings.HasPrefix(host, www) { + return true, scheme + "://www." + host + uri } - return + return false, "" }) } @@ -119,17 +118,17 @@ func NonWWWRedirect() echo.MiddlewareFunc { // NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // See `NonWWWRedirect()`. func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if ok = host[:4] == www; ok { - url = scheme + "://" + host[4:] + uri + return redirect(config, func(scheme, host, uri string) (bool, string) { + if strings.HasPrefix(host, www) { + return true, scheme + "://" + host[4:] + uri } - return + return false, "" }) } func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { if config.Skipper == nil { - config.Skipper = DefaultTrailingSlashConfig.Skipper + config.Skipper = DefaultRedirectConfig.Skipper } if config.Code == 0 { config.Code = DefaultRedirectConfig.Code diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go index 082609574..9d1b56205 100644 --- a/middleware/redirect_test.go +++ b/middleware/redirect_test.go @@ -12,62 +12,253 @@ import ( type middlewareGenerator func() echo.MiddlewareFunc func TestRedirectHTTPSRedirect(t *testing.T) { - res := redirectTest(HTTPSRedirect, "labstack.com", nil) - - assert.Equal(t, http.StatusMovedPermanently, res.Code) - assert.Equal(t, "https://labstack.com/", res.Header().Get(echo.HeaderLocation)) -} + var testCases = []struct { + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + whenHost: "labstack.com", + expectLocation: "https://labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "labstack.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + } -func TestHTTPSRedirectBehindTLSTerminationProxy(t *testing.T) { - header := http.Header{} - header.Set(echo.HeaderXForwardedProto, "https") - res := redirectTest(HTTPSRedirect, "labstack.com", header) + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + res := redirectTest(HTTPSRedirect, tc.whenHost, tc.whenHeader) - assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } } func TestRedirectHTTPSWWWRedirect(t *testing.T) { - res := redirectTest(HTTPSWWWRedirect, "labstack.com", nil) - - assert.Equal(t, http.StatusMovedPermanently, res.Code) - assert.Equal(t, "https://www.labstack.com/", res.Header().Get(echo.HeaderLocation)) -} + var testCases = []struct { + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + whenHost: "labstack.com", + expectLocation: "https://www.labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.labstack.com", + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + { + whenHost: "a.com", + expectLocation: "https://www.a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "ip", + expectLocation: "https://www.ip/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "labstack.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + } -func TestRedirectHTTPSWWWRedirectBehindTLSTerminationProxy(t *testing.T) { - header := http.Header{} - header.Set(echo.HeaderXForwardedProto, "https") - res := redirectTest(HTTPSWWWRedirect, "labstack.com", header) + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + res := redirectTest(HTTPSWWWRedirect, tc.whenHost, tc.whenHeader) - assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } } func TestRedirectHTTPSNonWWWRedirect(t *testing.T) { - res := redirectTest(HTTPSNonWWWRedirect, "www.labstack.com", nil) - - assert.Equal(t, http.StatusMovedPermanently, res.Code) - assert.Equal(t, "https://labstack.com/", res.Header().Get(echo.HeaderLocation)) -} + var testCases = []struct { + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + whenHost: "www.labstack.com", + expectLocation: "https://labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "a.com", + expectLocation: "https://a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "ip", + expectLocation: "https://ip/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.labstack.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + } -func TestRedirectHTTPSNonWWWRedirectBehindTLSTerminationProxy(t *testing.T) { - header := http.Header{} - header.Set(echo.HeaderXForwardedProto, "https") - res := redirectTest(HTTPSNonWWWRedirect, "www.labstack.com", header) + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + res := redirectTest(HTTPSNonWWWRedirect, tc.whenHost, tc.whenHeader) - assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } } func TestRedirectWWWRedirect(t *testing.T) { - res := redirectTest(WWWRedirect, "labstack.com", nil) + var testCases = []struct { + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + whenHost: "labstack.com", + expectLocation: "http://www.labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "a.com", + expectLocation: "http://www.a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "ip", + expectLocation: "http://www.ip/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "a.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "https://www.a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.ip", + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + res := redirectTest(WWWRedirect, tc.whenHost, tc.whenHeader) - assert.Equal(t, http.StatusMovedPermanently, res.Code) - assert.Equal(t, "http://www.labstack.com/", res.Header().Get(echo.HeaderLocation)) + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } } func TestRedirectNonWWWRedirect(t *testing.T) { - res := redirectTest(NonWWWRedirect, "www.labstack.com", nil) + var testCases = []struct { + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + whenHost: "www.labstack.com", + expectLocation: "http://labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.a.com", + expectLocation: "http://a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.a.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "https://a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "ip", + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + res := redirectTest(NonWWWRedirect, tc.whenHost, tc.whenHeader) + + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } +} + +func TestNonWWWRedirectWithConfig(t *testing.T) { + var testCases = []struct { + name string + givenCode int + givenSkipFunc func(c echo.Context) bool + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + name: "usual redirect", + whenHost: "www.labstack.com", + expectLocation: "http://labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + name: "redirect is skipped", + givenSkipFunc: func(c echo.Context) bool { + return true // skip always + }, + whenHost: "www.labstack.com", + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + { + name: "redirect with custom status code", + givenCode: http.StatusSeeOther, + whenHost: "www.labstack.com", + expectLocation: "http://labstack.com/", + expectStatusCode: http.StatusSeeOther, + }, + } + + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + middleware := func() echo.MiddlewareFunc { + return NonWWWRedirectWithConfig(RedirectConfig{ + Skipper: tc.givenSkipFunc, + Code: tc.givenCode, + }) + } + res := redirectTest(middleware, tc.whenHost, tc.whenHeader) - assert.Equal(t, http.StatusMovedPermanently, res.Code) - assert.Equal(t, "http://labstack.com/", res.Header().Get(echo.HeaderLocation)) + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } } func redirectTest(fn middlewareGenerator, host string, header http.Header) *httptest.ResponseRecorder { From 10d8c53d55c89ccdc93112f55a680ff49fb320f3 Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 6 Apr 2021 10:12:00 +0300 Subject: [PATCH 125/446] Fix timeout middleware docs (fixes #1816) (#1836) --- middleware/timeout.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/middleware/timeout.go b/middleware/timeout.go index 99d436ac2..5d23ff455 100644 --- a/middleware/timeout.go +++ b/middleware/timeout.go @@ -42,8 +42,8 @@ var ( } ) -// Timeout returns a middleware which recovers from panics anywhere in the chain -// and handles the control to the centralized HTTPErrorHandler. +// Timeout returns a middleware which returns error (503 Service Unavailable error) to client immediately when handler +// call runs for longer than its time limit. NB: timeout does not stop handler execution. func Timeout() echo.MiddlewareFunc { return TimeoutWithConfig(DefaultTimeoutConfig) } From 8da8e161380fd926d4341721f0328f1e94d6d0a2 Mon Sep 17 00:00:00 2001 From: Martti T Date: Wed, 7 Apr 2021 22:45:14 +0300 Subject: [PATCH 126/446] Update version and changelog for 4.2.2 (#1838) --- CHANGELOG.md | 10 ++++++++++ echo.go | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b50478830..c1be77a91 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## v4.2.2 - 2020-04-07 + +**Fixes** + +* Allow proxy middleware to use query part in rewrite (#1802) +* Fix timeout middleware not sending status code when handler returns an error (#1805) +* Fix Bind() when target is array/slice and path/query params complains bind target not being struct (#1835) +* Fix panic in redirect middleware on short host name (#1813) +* Fix timeout middleware docs (#1836) + ## v4.2.1 - 2020-03-08 **Important notes** diff --git a/echo.go b/echo.go index 3fccaf648..a24e3977f 100644 --- a/echo.go +++ b/echo.go @@ -234,7 +234,7 @@ const ( const ( // Version of Echo - Version = "4.2.1" + Version = "4.2.2" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From bb7f2223bbddee817a1e8a940d186a7328025f6e Mon Sep 17 00:00:00 2001 From: Martti T Date: Fri, 9 Apr 2021 10:14:23 +0300 Subject: [PATCH 127/446] Update and tidy dependencies (#1841) --- go.mod | 10 +++++----- go.sum | 36 ++++++++++++++---------------------- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/go.mod b/go.mod index 877117075..2510d10c6 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,12 @@ go 1.15 require ( github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/labstack/gommon v0.3.0 - github.com/mattn/go-colorable v0.1.7 // indirect + github.com/mattn/go-colorable v0.1.8 // indirect github.com/stretchr/testify v1.4.0 github.com/valyala/fasttemplate v1.2.1 - golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a - golang.org/x/net v0.0.0-20200822124328-c89045814202 - golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6 // indirect - golang.org/x/text v0.3.3 // indirect + golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 + golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 + golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57 // indirect + golang.org/x/text v0.3.6 // indirect golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 ) diff --git a/go.sum b/go.sum index 54ba24e67..d18f10fb6 100644 --- a/go.sum +++ b/go.sum @@ -4,12 +4,10 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumC github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0= github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= -github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-colorable v0.1.7 h1:bQGKb3vps/j0E9GfJQ03JyhRuxsvdAanXlT9BTw3mdw= -github.com/mattn/go-colorable v0.1.7/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8= +github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.9 h1:d5US/mDsogSGW37IV293h//ZFaeajb69h+EHFsv2xGg= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= @@ -20,32 +18,26 @@ github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJy github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasttemplate v1.0.1 h1:tY9CJiPnMXf1ERmG2EyK7gNUd+c6RKGD0IfU8WdUSz8= github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM= -golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= -golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w= +golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a h1:aYOabOQFp6Vj6W1F80affTUvO9UxmJRx8K0gsfABByQ= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae h1:/WDfKMnPU+m5M4xB+6x4kaepxRw6jWvR5iDRdvjHgy8= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6 h1:DvY3Zkh7KabQE/kfzMvYvKirSiguP9Q/veMtkYyf0o8= -golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57 h1:F5Gozwx4I1xtr/sr/8CFbb57iKi3297KFs0QDbGN60A= +golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= From a4ab482b605fb260385212b9bcbcfcdef00a0c27 Mon Sep 17 00:00:00 2001 From: Martti T Date: Fri, 16 Apr 2021 12:38:12 +0300 Subject: [PATCH 128/446] Add Go 1.16 to CI and drop 1.12 specific code (#1850) * Correct incorrect years in CHANGELOG.md * CI tests with last 4 versions. Remove 1.12 and below specific code * Rename proxy test --- .github/workflows/echo.yml | 8 +-- CHANGELOG.md | 6 +-- echo_go1.13_test.go | 28 ----------- echo_test.go | 17 +++++++ middleware/csrf.go | 2 +- middleware/csrf_samesite.go | 12 ----- middleware/csrf_samesite_1.12.go | 12 ----- middleware/csrf_samesite_test.go | 33 ------------ middleware/csrf_test.go | 20 ++++++++ middleware/proxy.go | 37 ++++++++++++++ middleware/proxy_1_11.go | 47 ----------------- middleware/proxy_1_11_n.go | 14 ------ middleware/proxy_1_11_test.go | 86 -------------------------------- middleware/proxy_test.go | 73 +++++++++++++++++++++++++++ middleware/timeout.go | 2 - middleware/timeout_test.go | 2 - 16 files changed, 156 insertions(+), 243 deletions(-) delete mode 100644 echo_go1.13_test.go delete mode 100644 middleware/csrf_samesite.go delete mode 100644 middleware/csrf_samesite_1.12.go delete mode 100644 middleware/csrf_samesite_test.go delete mode 100644 middleware/proxy_1_11.go delete mode 100644 middleware/proxy_1_11_n.go delete mode 100644 middleware/proxy_1_11_test.go diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index fb8c50205..ec5517561 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -25,7 +25,9 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - go: [1.12, 1.13, 1.14, 1.15] + # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy + # Echo tests with last four major releases + go: [1.13, 1.14, 1.15, 1.16] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: @@ -59,7 +61,7 @@ jobs: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... - name: Upload coverage to Codecov - if: success() && matrix.go == 1.15 && matrix.os == 'ubuntu-latest' + if: success() && matrix.go == 1.16 && matrix.os == 'ubuntu-latest' uses: codecov/codecov-action@v1 with: token: @@ -69,7 +71,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - go: [1.15] + go: [1.16] name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: diff --git a/CHANGELOG.md b/CHANGELOG.md index c1be77a91..a3117b80c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Changelog -## v4.2.2 - 2020-04-07 +## v4.2.2 - 2021-04-07 **Fixes** @@ -10,7 +10,7 @@ * Fix panic in redirect middleware on short host name (#1813) * Fix timeout middleware docs (#1836) -## v4.2.1 - 2020-03-08 +## v4.2.1 - 2021-03-08 **Important notes** @@ -32,7 +32,7 @@ A performance regression has been fixed, even bringing better performance than b This release was made possible by our **contributors**: aldas, clwluvw, lammel, Le0tk0k, maciej-jezierski, rkilingr, stffabi, withshubh -## v4.2.0 - 2020-02-11 +## v4.2.0 - 2021-02-11 **Important notes** diff --git a/echo_go1.13_test.go b/echo_go1.13_test.go deleted file mode 100644 index 3c488bc63..000000000 --- a/echo_go1.13_test.go +++ /dev/null @@ -1,28 +0,0 @@ -// +build go1.13 - -package echo - -import ( - "errors" - "net/http" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestHTTPError_Unwrap(t *testing.T) { - t.Run("non-internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - - assert.Nil(t, errors.Unwrap(err)) - }) - t.Run("internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - err.SetInternal(errors.New("internal error")) - assert.Equal(t, "internal error", errors.Unwrap(err).Error()) - }) -} diff --git a/echo_test.go b/echo_test.go index 58ecea741..ba498831b 100644 --- a/echo_test.go +++ b/echo_test.go @@ -957,6 +957,23 @@ func TestHTTPError(t *testing.T) { }) } +func TestHTTPError_Unwrap(t *testing.T) { + t.Run("non-internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + + assert.Nil(t, errors.Unwrap(err)) + }) + t.Run("internal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + err.SetInternal(errors.New("internal error")) + assert.Equal(t, "internal error", errors.Unwrap(err).Error()) + }) +} + func TestDefaultHTTPErrorHandler(t *testing.T) { e := New() e.Debug = true diff --git a/middleware/csrf.go b/middleware/csrf.go index 60f809a04..7804997d4 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -110,7 +110,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.CookieMaxAge == 0 { config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge } - if config.CookieSameSite == SameSiteNoneMode { + if config.CookieSameSite == http.SameSiteNoneMode { config.CookieSecure = true } diff --git a/middleware/csrf_samesite.go b/middleware/csrf_samesite.go deleted file mode 100644 index 9a27dc431..000000000 --- a/middleware/csrf_samesite.go +++ /dev/null @@ -1,12 +0,0 @@ -// +build go1.13 - -package middleware - -import ( - "net/http" -) - -const ( - // SameSiteNoneMode required to be redefined for Go 1.12 support (see #1524) - SameSiteNoneMode http.SameSite = http.SameSiteNoneMode -) diff --git a/middleware/csrf_samesite_1.12.go b/middleware/csrf_samesite_1.12.go deleted file mode 100644 index 22076dd6a..000000000 --- a/middleware/csrf_samesite_1.12.go +++ /dev/null @@ -1,12 +0,0 @@ -// +build !go1.13 - -package middleware - -import ( - "net/http" -) - -const ( - // SameSiteNoneMode required to be redefined for Go 1.12 support (see #1524) - SameSiteNoneMode http.SameSite = 4 -) diff --git a/middleware/csrf_samesite_test.go b/middleware/csrf_samesite_test.go deleted file mode 100644 index 26c5bc455..000000000 --- a/middleware/csrf_samesite_test.go +++ /dev/null @@ -1,33 +0,0 @@ -// +build go1.13 - -package middleware - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" -) - -// Test for SameSiteModeNone moved to separate file for Go 1.12 support -func TestCSRFWithSameSiteModeNone(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - csrf := CSRFWithConfig(CSRFConfig{ - CookieSameSite: SameSiteNoneMode, - }) - - h := csrf(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - - r := h(c) - assert.NoError(t, r) - assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"]) - assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"]) -} diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index ebe4dbcde..af1d26394 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -138,3 +138,23 @@ func TestCSRFWithSameSiteDefaultMode(t *testing.T) { fmt.Println(rec.Header()["Set-Cookie"]) assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"]) } + +func TestCSRFWithSameSiteModeNone(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + csrf := CSRFWithConfig(CSRFConfig{ + CookieSameSite: http.SameSiteNoneMode, + }) + + h := csrf(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + r := h(c) + assert.NoError(t, r) + assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"]) + assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"]) +} diff --git a/middleware/proxy.go b/middleware/proxy.go index 6f01f3a7c..6cfd6731e 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -1,13 +1,16 @@ package middleware import ( + "context" "fmt" "io" "math/rand" "net" "net/http" + "net/http/httputil" "net/url" "regexp" + "strings" "sync" "sync/atomic" "time" @@ -264,3 +267,37 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { } } } + +// StatusCodeContextCanceled is a custom HTTP status code for situations +// where a client unexpectedly closed the connection to the server. +// As there is no standard error code for "client closed connection", but +// various well-known HTTP clients and server implement this HTTP code we use +// 499 too instead of the more problematic 5xx, which does not allow to detect this situation +const StatusCodeContextCanceled = 499 + +func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { + proxy := httputil.NewSingleHostReverseProxy(tgt.URL) + proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { + desc := tgt.URL.String() + if tgt.Name != "" { + desc = fmt.Sprintf("%s(%s)", tgt.Name, tgt.URL.String()) + } + // If the client canceled the request (usually by closing the connection), we can report a + // client error (4xx) instead of a server error (5xx) to correctly identify the situation. + // The Go standard library (at of late 2020) wraps the exported, standard + // context.Canceled error with unexported garbage value requiring a substring check, see + // https://github.com/golang/go/blob/6965b01ea248cabb70c3749fd218b36089a21efb/src/net/net.go#L416-L430 + if err == context.Canceled || strings.Contains(err.Error(), "operation was canceled") { + httpError := echo.NewHTTPError(StatusCodeContextCanceled, fmt.Sprintf("client closed connection: %v", err)) + httpError.Internal = err + c.Set("_error", httpError) + } else { + httpError := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err)) + httpError.Internal = err + c.Set("_error", httpError) + } + } + proxy.Transport = config.Transport + proxy.ModifyResponse = config.ModifyResponse + return proxy +} diff --git a/middleware/proxy_1_11.go b/middleware/proxy_1_11.go deleted file mode 100644 index 17d142d8d..000000000 --- a/middleware/proxy_1_11.go +++ /dev/null @@ -1,47 +0,0 @@ -// +build go1.11 - -package middleware - -import ( - "context" - "fmt" - "net/http" - "net/http/httputil" - "strings" - - "github.com/labstack/echo/v4" -) - -// StatusCodeContextCanceled is a custom HTTP status code for situations -// where a client unexpectedly closed the connection to the server. -// As there is no standard error code for "client closed connection", but -// various well-known HTTP clients and server implement this HTTP code we use -// 499 too instead of the more problematic 5xx, which does not allow to detect this situation -const StatusCodeContextCanceled = 499 - -func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { - proxy := httputil.NewSingleHostReverseProxy(tgt.URL) - proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { - desc := tgt.URL.String() - if tgt.Name != "" { - desc = fmt.Sprintf("%s(%s)", tgt.Name, tgt.URL.String()) - } - // If the client canceled the request (usually by closing the connection), we can report a - // client error (4xx) instead of a server error (5xx) to correctly identify the situation. - // The Go standard library (at of late 2020) wraps the exported, standard - // context.Canceled error with unexported garbage value requiring a substring check, see - // https://github.com/golang/go/blob/6965b01ea248cabb70c3749fd218b36089a21efb/src/net/net.go#L416-L430 - if err == context.Canceled || strings.Contains(err.Error(), "operation was canceled") { - httpError := echo.NewHTTPError(StatusCodeContextCanceled, fmt.Sprintf("client closed connection: %v", err)) - httpError.Internal = err - c.Set("_error", httpError) - } else { - httpError := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err)) - httpError.Internal = err - c.Set("_error", httpError) - } - } - proxy.Transport = config.Transport - proxy.ModifyResponse = config.ModifyResponse - return proxy -} diff --git a/middleware/proxy_1_11_n.go b/middleware/proxy_1_11_n.go deleted file mode 100644 index 9a78929fe..000000000 --- a/middleware/proxy_1_11_n.go +++ /dev/null @@ -1,14 +0,0 @@ -// +build !go1.11 - -package middleware - -import ( - "net/http" - "net/http/httputil" - - "github.com/labstack/echo/v4" -) - -func proxyHTTP(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { - return httputil.NewSingleHostReverseProxy(t.URL) -} diff --git a/middleware/proxy_1_11_test.go b/middleware/proxy_1_11_test.go deleted file mode 100644 index c3541d5e8..000000000 --- a/middleware/proxy_1_11_test.go +++ /dev/null @@ -1,86 +0,0 @@ -// +build go1.11 - -package middleware - -import ( - "context" - "net/http" - "net/http/httptest" - "net/url" - "sync" - "testing" - "time" - - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" -) - -func TestProxy_1_11(t *testing.T) { - // Setup - url1, _ := url.Parse("http://127.0.0.1:27121") - url2, _ := url.Parse("http://127.0.0.1:27122") - - targets := []*ProxyTarget{ - { - Name: "target 1", - URL: url1, - }, - { - Name: "target 2", - URL: url2, - }, - } - rb := NewRandomBalancer(nil) - // must add targets: - for _, target := range targets { - assert.True(t, rb.AddTarget(target)) - } - - // must ignore duplicates: - for _, target := range targets { - assert.False(t, rb.AddTarget(target)) - } - - // Random - e := echo.New() - e.Use(Proxy(rb)) - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - - // Remote unreachable - rec = httptest.NewRecorder() - req.URL.Path = "/api/users" - e.ServeHTTP(rec, req) - assert.Equal(t, "/api/users", req.URL.Path) - assert.Equal(t, http.StatusBadGateway, rec.Code) -} - -func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { - var timeoutStop sync.WaitGroup - timeoutStop.Add(1) - HTTPTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - timeoutStop.Wait() // wait until we have canceled the request - w.WriteHeader(http.StatusOK) - })) - defer HTTPTarget.Close() - targetURL, _ := url.Parse(HTTPTarget.URL) - target := &ProxyTarget{ - Name: "target", - URL: targetURL, - } - rb := NewRandomBalancer(nil) - assert.True(t, rb.AddTarget(target)) - e := echo.New() - e.Use(Proxy(rb)) - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - ctx, cancel := context.WithCancel(req.Context()) - req = req.WithContext(ctx) - go func() { - time.Sleep(10 * time.Millisecond) - cancel() - }() - e.ServeHTTP(rec, req) - timeoutStop.Done() - assert.Equal(t, 499, rec.Code) -} diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 93daf735e..7939fc5c2 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -2,6 +2,7 @@ package middleware import ( "bytes" + "context" "fmt" "io/ioutil" "net" @@ -9,7 +10,9 @@ import ( "net/http/httptest" "net/url" "regexp" + "sync" "testing" + "time" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" @@ -302,3 +305,73 @@ func TestProxyRewriteRegex(t *testing.T) { }) } } + +func TestProxyError(t *testing.T) { + // Setup + url1, _ := url.Parse("http://127.0.0.1:27121") + url2, _ := url.Parse("http://127.0.0.1:27122") + + targets := []*ProxyTarget{ + { + Name: "target 1", + URL: url1, + }, + { + Name: "target 2", + URL: url2, + }, + } + rb := NewRandomBalancer(nil) + // must add targets: + for _, target := range targets { + assert.True(t, rb.AddTarget(target)) + } + + // must ignore duplicates: + for _, target := range targets { + assert.False(t, rb.AddTarget(target)) + } + + // Random + e := echo.New() + e.Use(Proxy(rb)) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + // Remote unreachable + rec = httptest.NewRecorder() + req.URL.Path = "/api/users" + e.ServeHTTP(rec, req) + assert.Equal(t, "/api/users", req.URL.Path) + assert.Equal(t, http.StatusBadGateway, rec.Code) +} + +func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { + var timeoutStop sync.WaitGroup + timeoutStop.Add(1) + HTTPTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + timeoutStop.Wait() // wait until we have canceled the request + w.WriteHeader(http.StatusOK) + })) + defer HTTPTarget.Close() + targetURL, _ := url.Parse(HTTPTarget.URL) + target := &ProxyTarget{ + Name: "target", + URL: targetURL, + } + rb := NewRandomBalancer(nil) + assert.True(t, rb.AddTarget(target)) + e := echo.New() + e.Use(Proxy(rb)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + ctx, cancel := context.WithCancel(req.Context()) + req = req.WithContext(ctx) + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + e.ServeHTTP(rec, req) + timeoutStop.Done() + assert.Equal(t, 499, rec.Code) +} diff --git a/middleware/timeout.go b/middleware/timeout.go index 5d23ff455..d56e463b0 100644 --- a/middleware/timeout.go +++ b/middleware/timeout.go @@ -1,5 +1,3 @@ -// +build go1.13 - package middleware import ( diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index 8f8fa3049..f9f1826be 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -1,5 +1,3 @@ -// +build go1.13 - package middleware import ( From 3b07058a1d8f440497dc7be1d2ad9a2767c84a57 Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Sat, 17 Apr 2021 12:47:48 -0700 Subject: [PATCH 129/446] Create LICENSE --- LICENSE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LICENSE b/LICENSE index b5b006b4e..c46d0105f 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2017 LabStack +Copyright (c) 2021 LabStack Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal From 643066594d00891e3151c7ed87244bfeddcd57b9 Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 27 Apr 2021 10:55:31 +0300 Subject: [PATCH 130/446] Fix router not matching param route with trailing slash and implement matching by path+method (#1812) * when url ends with slash first param route is the match (fix #1804) * router should check if method is suitable for matching route and if not then continue search in tree (fix #1808) --- context_test.go | 6 ++- router.go | 113 ++++++++++++++++++++++++++++++++++++++++-------- router_test.go | 79 +++++++++++++++++++++++++++++---- 3 files changed, 169 insertions(+), 29 deletions(-) diff --git a/context_test.go b/context_test.go index 963c91e04..2c61ffb3a 100644 --- a/context_test.go +++ b/context_test.go @@ -464,7 +464,9 @@ func TestContextPath(t *testing.T) { e := New() r := e.Router() - r.Add(http.MethodGet, "/users/:id", nil) + handler := func(c Context) error { return c.String(http.StatusOK, "OK") } + + r.Add(http.MethodGet, "/users/:id", handler) c := e.NewContext(nil, nil) r.Find(http.MethodGet, "/users/1", c) @@ -472,7 +474,7 @@ func TestContextPath(t *testing.T) { assert.Equal("/users/:id", c.Path()) - r.Add(http.MethodGet, "/users/:uid/files/:fid", nil) + r.Add(http.MethodGet, "/users/:uid/files/:fid", handler) c = e.NewContext(nil, nil) r.Find(http.MethodGet, "/users/1/files/1", c) assert.Equal("/users/:uid/files/:fid", c.Path()) diff --git a/router.go b/router.go index 2dd09fae2..5b2474b32 100644 --- a/router.go +++ b/router.go @@ -23,6 +23,10 @@ type ( methodHandler *methodHandler paramChild *node anyChild *node + // isLeaf indicates that node does not have child routes + isLeaf bool + // isHandler indicates that node has at least one handler registered to it + isHandler bool } kind uint8 children []*node @@ -50,6 +54,20 @@ const ( anyLabel = byte('*') ) +func (m *methodHandler) isHandler() bool { + return m.connect != nil || + m.delete != nil || + m.get != nil || + m.head != nil || + m.options != nil || + m.patch != nil || + m.post != nil || + m.propfind != nil || + m.put != nil || + m.trace != nil || + m.report != nil +} + // NewRouter returns a new Router instance. func NewRouter(e *Echo) *Router { return &Router{ @@ -73,6 +91,11 @@ func (r *Router) Add(method, path string, h HandlerFunc) { pnames := []string{} // Param names ppath := path // Pristine path + if h == nil && r.echo.Logger != nil { + // FIXME: in future we should return error + r.echo.Logger.Errorf("Adding route without handler function: %v:%v", method, path) + } + for i, lcpIndex := 0, len(path); i < lcpIndex; i++ { if path[i] == ':' { j := i + 1 @@ -86,6 +109,7 @@ func (r *Router) Add(method, path string, h HandlerFunc) { i, lcpIndex = j, len(path) if i == lcpIndex { + // path node is last fragment of route path. ie. `/users/:id` r.insert(method, path[:i], h, paramKind, ppath, pnames) } else { r.insert(method, path[:i], nil, paramKind, "", nil) @@ -136,6 +160,7 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string currentNode.ppath = ppath currentNode.pnames = pnames } + currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil } else if lcpLen < prefixLen { // Split node n := newNode( @@ -149,7 +174,6 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string currentNode.paramChild, currentNode.anyChild, ) - // Update parent path for all children to new node for _, child := range currentNode.staticChildren { child.parent = n @@ -171,6 +195,8 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string currentNode.pnames = nil currentNode.paramChild = nil currentNode.anyChild = nil + currentNode.isLeaf = false + currentNode.isHandler = false // Only Static children could reach here currentNode.addStaticChild(n) @@ -188,6 +214,7 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string // Only Static children could reach here currentNode.addStaticChild(n) } + currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil } else if lcpLen < searchLen { search = search[lcpLen:] c := currentNode.findChildWithLabel(search[0]) @@ -207,6 +234,7 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string case anyKind: currentNode.anyChild = n } + currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil } else { // Node already exists if h != nil { @@ -233,6 +261,8 @@ func newNode(t kind, pre string, p *node, sc children, mh *methodHandler, ppath methodHandler: mh, paramChild: paramChildren, anyChild: anyChildren, + isLeaf: sc == nil && paramChildren == nil && anyChildren == nil, + isHandler: mh.isHandler(), } } @@ -289,6 +319,12 @@ func (n *node) addHandler(method string, h HandlerFunc) { case REPORT: n.methodHandler.report = h } + + if h != nil { + n.isHandler = true + } else { + n.isHandler = n.methodHandler.isHandler() + } } func (n *node) findHandler(method string) HandlerFunc { @@ -343,6 +379,8 @@ func (r *Router) Find(method, path string, c Context) { currentNode := r.tree // Current node as root var ( + previousBestMatchNode *node + matchedHandler HandlerFunc // search stores the remaining path to check for match. By each iteration we move from start of path to end of the path // and search value gets shorter and shorter. search = path @@ -362,10 +400,11 @@ func (r *Router) Find(method, path string, c Context) { valid = currentNode != nil // Next node type by priority - // NOTE: With the current implementation we never backtrack from an `any` route, so `previous.kind` is - // always `static` or `any` - // If this is changed then for any route next kind would be `static` and this statement should be changed - nextNodeKind = previous.kind + 1 + if previous.kind == anyKind { + nextNodeKind = staticKind + } else { + nextNodeKind = previous.kind + 1 + } if fromKind == staticKind { // when backtracking is done from static kind block we did not change search so nothing to restore @@ -380,6 +419,7 @@ func (r *Router) Find(method, path string, c Context) { // for param/any node.prefix value is always `:` so we can not deduce searchIndex from that and must use pValue // for that index as it would also contain part of path we cut off before moving into node we are backtracking from searchIndex -= len(paramValues[paramIndex]) + paramValues[paramIndex] = "" } search = path[searchIndex:] return @@ -421,7 +461,7 @@ func (r *Router) Find(method, path string, c Context) { // goto Any } else { // Not found (this should never be possible for static node we are looking currently) - return + break } } @@ -429,9 +469,17 @@ func (r *Router) Find(method, path string, c Context) { search = search[lcpLen:] searchIndex = searchIndex + lcpLen - // Finish routing if no remaining search and we are on an leaf node - if search == "" && currentNode.ppath != "" { - break + // Finish routing if no remaining search and we are on a node with handler and matching method type + if search == "" && currentNode.isHandler { + // check if current node has handler registered for http method we are looking for. we store currentNode as + // best matching in case we do no find no more routes matching this path+method + if previousBestMatchNode == nil { + previousBestMatchNode = currentNode + } + if h := currentNode.findHandler(method); h != nil { + matchedHandler = h + break + } } // Static node @@ -446,10 +494,16 @@ func (r *Router) Find(method, path string, c Context) { // Param node if child := currentNode.paramChild; search != "" && child != nil { currentNode = child - // FIXME: when param node does not have any children then param node should act similarly to any node - consider all remaining search as match - i, l := 0, len(search) - for ; i < l && search[i] != '/'; i++ { + i := 0 + l := len(search) + if currentNode.isLeaf { + // when param node does not have any children then param node should act similarly to any node - consider all remaining search as match + i = l + } else { + for ; i < l && search[i] != '/'; i++ { + } } + paramValues[paramIndex] = search[:i] paramIndex++ search = search[i:] @@ -463,29 +517,50 @@ func (r *Router) Find(method, path string, c Context) { // If any node is found, use remaining path for paramValues currentNode = child paramValues[len(currentNode.pnames)-1] = search - break + // update indexes/search in case we need to backtrack when no handler match is found + paramIndex++ + searchIndex += +len(search) + search = "" + + // check if current node has handler registered for http method we are looking for. we store currentNode as + // best matching in case we do no find no more routes matching this path+method + if previousBestMatchNode == nil { + previousBestMatchNode = currentNode + } + if h := currentNode.findHandler(method); h != nil { + matchedHandler = h + break + } } // Let's backtrack to the first possible alternative node of the decision path nk, ok := backtrackToNextNodeKind(anyKind) if !ok { - return // No other possibilities on the decision path + break // No other possibilities on the decision path } else if nk == paramKind { goto Param } else if nk == anyKind { goto Any } else { // Not found - return + break } } - ctx.handler = currentNode.findHandler(method) - ctx.path = currentNode.ppath - ctx.pnames = currentNode.pnames + if currentNode == nil && previousBestMatchNode == nil { + return // nothing matched at all + } - if ctx.handler == nil { + if matchedHandler != nil { + ctx.handler = matchedHandler + } else { + // use previous match as basis. although we have no matching handler we have path match. + // so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404) + currentNode = previousBestMatchNode ctx.handler = currentNode.checkMethodNotAllowed() } + ctx.path = currentNode.ppath + ctx.pnames = currentNode.pnames + return } diff --git a/router_test.go b/router_test.go index 47e499402..71cedf8b6 100644 --- a/router_test.go +++ b/router_test.go @@ -692,11 +692,11 @@ func TestRouterParam(t *testing.T) { expectRoute: "/users/:id", expectParam: map[string]string{"id": "1"}, }, - { // FIXME: this documents current implementation (slash at end is problematic) + { name: "route /users/1/ to /users/:id", whenURL: "/users/1/", - expectRoute: nil, // FIXME: should be "/users/:id", - expectParam: nil, // FIXME: should be map[string]string{"id": "1/"}, + expectRoute: "/users/:id", + expectParam: map[string]string{"id": "1/"}, }, } @@ -716,6 +716,69 @@ func TestRouterParam(t *testing.T) { } } +func TestMethodNotAllowedAndNotFound(t *testing.T) { + e := New() + r := e.router + + // Routes + r.Add(http.MethodGet, "/*", handlerFunc) + r.Add(http.MethodPost, "/users/:id", handlerFunc) + + var testCases = []struct { + name string + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + }{ + { + name: "exact match for route+method", + whenMethod: http.MethodPost, + whenURL: "/users/1", + expectRoute: "/users/:id", + expectParam: map[string]string{"id": "1"}, + }, + { + name: "matches node but not method. sends 405 from best match node", + whenMethod: http.MethodPut, + whenURL: "/users/1", + expectRoute: nil, + expectError: ErrMethodNotAllowed, + }, + { + name: "best match is any route up in tree", + whenMethod: http.MethodGet, + whenURL: "/users/1", + expectRoute: "/*", + expectParam: map[string]string{"*": "users/1"}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) + + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + r.Find(method, tc.whenURL, c) + err := c.handler(c) + + if tc.expectError != nil { + assert.Equal(t, tc.expectError, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectRoute, c.Get("path")) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } +} + func TestRouterTwoParam(t *testing.T) { e := New() r := e.router @@ -740,8 +803,8 @@ func TestRouterParamWithSlash(t *testing.T) { r.Find(http.MethodGet, "/a/1/c/d/2/3", c) // `2/3` should mapped to path `/a/:b/c/d/:e` and into `:e` err := c.handler(c) - assert.Equal(t, nil, c.Get("path")) // FIXME: should be "/a/:b/c/d/:e" - assert.EqualError(t, err, "code=404, message=Not Found") // FIXME: should be .NoError() + assert.Equal(t, "/a/:b/c/d/:e", c.Get("path")) + assert.NoError(t, err) } // Issue #1754 - router needs to backtrack multiple levels upwards in tree to find the matching route @@ -2004,10 +2067,10 @@ func TestRouterParam1466(t *testing.T) { expectRoute: "/users/:username", expectParam: map[string]string{"username": "sharewithme"}, }, - { + { // route `/users/signup` is registered for POST. so param route `/users/:username` (lesser priority) is matched as it has GET handler whenURL: "/users/signup", - expectRoute: nil, // method not found as this route is for POST but request is for GET - expectParam: map[string]string{"username": ""}, + expectRoute: "/users/:username", + expectParam: map[string]string{"username": "signup"}, }, // Additional assertions for #1479 { From 76f186ad3bc749de348548e39e17d32bb834aa9a Mon Sep 17 00:00:00 2001 From: antonindrawan Date: Sat, 8 May 2021 21:19:24 +0200 Subject: [PATCH 131/446] feat(jwt): make KeyFunc public in JWT middleware (#1756) * feat(jwt): make KeyFunc public in JWT middleware It allows a user-defined function to supply the key for a token verification. --- middleware/jwt.go | 65 +++++++++++++++++++++++++++--------------- middleware/jwt_test.go | 30 +++++++++++++++++++ 2 files changed, 72 insertions(+), 23 deletions(-) diff --git a/middleware/jwt.go b/middleware/jwt.go index da00ea56b..fab4d6fdf 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -29,15 +29,19 @@ type ( // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. ErrorHandlerWithContext JWTErrorHandlerWithContext - // Signing key to validate token. Used as fallback if SigningKeys has length 0. - // Required. This or SigningKeys. + // Signing key to validate token. + // This is one of the three options to provide a token validation key. + // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. + // Required if neither user-defined KeyFunc nor SigningKeys is provided. SigningKey interface{} // Map of signing keys to validate token with kid field usage. - // Required. This or SigningKey. + // This is one of the three options to provide a token validation key. + // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. + // Required if neither user-defined KeyFunc nor SigningKey is provided. SigningKeys map[string]interface{} - // Signing method, used to check token signing method. + // Signing method used to check the token's signing algorithm. // Optional. Default value HS256. SigningMethod string @@ -64,7 +68,16 @@ type ( // Optional. Default value "Bearer". AuthScheme string - keyFunc jwt.Keyfunc + // KeyFunc defines a user-defined function that supplies the public key for a token validation. + // The function shall take care of verifying the signing algorithm and selecting the proper key. + // A user-defined KeyFunc can be useful if tokens are issued by an external party. + // + // When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored. + // This is one of the three options to provide a token validation key. + // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. + // Required if neither SigningKeys nor SigningKey is provided. + // Default to an internal implementation verifying the signing algorithm and selecting the proper key. + KeyFunc jwt.Keyfunc } // JWTSuccessHandler defines a function which is executed for a valid token. @@ -99,6 +112,7 @@ var ( TokenLookup: "header:" + echo.HeaderAuthorization, AuthScheme: "Bearer", Claims: jwt.MapClaims{}, + KeyFunc: nil, } ) @@ -123,7 +137,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.Skipper == nil { config.Skipper = DefaultJWTConfig.Skipper } - if config.SigningKey == nil && len(config.SigningKeys) == 0 { + if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil { panic("echo: jwt middleware requires signing key") } if config.SigningMethod == "" { @@ -141,21 +155,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.AuthScheme == "" { config.AuthScheme = DefaultJWTConfig.AuthScheme } - config.keyFunc = func(t *jwt.Token) (interface{}, error) { - // Check the signing method - if t.Method.Alg() != config.SigningMethod { - return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) - } - if len(config.SigningKeys) > 0 { - if kid, ok := t.Header["kid"].(string); ok { - if key, ok := config.SigningKeys[kid]; ok { - return key, nil - } - } - return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"]) - } - - return config.SigningKey, nil + if config.KeyFunc == nil { + config.KeyFunc = config.defaultKeyFunc } // Initialize @@ -196,11 +197,11 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { token := new(jwt.Token) // Issue #647, #656 if _, ok := config.Claims.(jwt.MapClaims); ok { - token, err = jwt.Parse(auth, config.keyFunc) + token, err = jwt.Parse(auth, config.KeyFunc) } else { t := reflect.ValueOf(config.Claims).Type().Elem() claims := reflect.New(t).Interface().(jwt.Claims) - token, err = jwt.ParseWithClaims(auth, claims, config.keyFunc) + token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc) } if err == nil && token.Valid { // Store user information from token into context. @@ -225,6 +226,24 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { } } +// defaultKeyFunc returns a signing key of the given token. +func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) { + // Check the signing method + if t.Method.Alg() != config.SigningMethod { + return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) + } + if len(config.SigningKeys) > 0 { + if kid, ok := t.Header["kid"].(string); ok { + if key, ok := config.SigningKeys[kid]; ok { + return key, nil + } + } + return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"]) + } + + return config.SigningKey, nil +} + // jwtFromHeader returns a `jwtExtractor` that extracts token from the request header. func jwtFromHeader(header string, authScheme string) jwtExtractor { return func(c echo.Context) (string, error) { diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 205721aec..37de31fcc 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "net/http" "net/http/httptest" "net/url" @@ -220,6 +221,35 @@ func TestJWT(t *testing.T) { expErrCode: http.StatusBadRequest, info: "Empty form field", }, + { + hdrAuth: validAuth, + config: JWTConfig{ + KeyFunc: func(*jwt.Token) (interface{}, error) { + return validKey, nil + }, + }, + info: "Valid JWT with a valid key using a user-defined KeyFunc", + }, + { + hdrAuth: validAuth, + config: JWTConfig{ + KeyFunc: func(*jwt.Token) (interface{}, error) { + return invalidKey, nil + }, + }, + expErrCode: http.StatusUnauthorized, + info: "Valid JWT with an invalid key using a user-defined KeyFunc", + }, + { + hdrAuth: validAuth, + config: JWTConfig{ + KeyFunc: func(*jwt.Token) (interface{}, error) { + return nil, errors.New("faulty KeyFunc") + }, + }, + expErrCode: http.StatusUnauthorized, + info: "Token verification does not pass using a user-defined KeyFunc", + }, } { if tc.reqURL == "" { tc.reqURL = "/" From 7256cb22749c462efd039c59dfee929374cb18c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=91=9E=E5=8D=8E?= Date: Sun, 9 May 2021 03:25:11 +0800 Subject: [PATCH 132/446] add a custom error handler to key-auth middleware (#1847) * add a custom error handler to key-auth middleware --- middleware/key_auth.go | 13 ++ middleware/key_auth_test.go | 260 +++++++++++++++++++++++++++++------- 2 files changed, 223 insertions(+), 50 deletions(-) diff --git a/middleware/key_auth.go b/middleware/key_auth.go index 94cfd1429..fd169aa2c 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -30,12 +30,19 @@ type ( // Validator is a function to validate key. // Required. Validator KeyAuthValidator + + // ErrorHandler defines a function which is executed for an invalid key. + // It may be used to define a custom error. + ErrorHandler KeyAuthErrorHandler } // KeyAuthValidator defines a function to validate KeyAuth credentials. KeyAuthValidator func(string, echo.Context) (bool, error) keyExtractor func(echo.Context) (string, error) + + // KeyAuthErrorHandler defines a function which is executed for an invalid key. + KeyAuthErrorHandler func(error, echo.Context) error ) var ( @@ -95,10 +102,16 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { // Extract and verify key key, err := extractor(c) if err != nil { + if config.ErrorHandler != nil { + return config.ErrorHandler(err, c) + } return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } valid, err := config.Validator(key, c) if err != nil { + if config.ErrorHandler != nil { + return config.ErrorHandler(err, c) + } return &echo.HTTPError{ Code: http.StatusUnauthorized, Message: "invalid key", diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go index b874898c8..476b402d9 100644 --- a/middleware/key_auth_test.go +++ b/middleware/key_auth_test.go @@ -1,9 +1,9 @@ package middleware import ( + "errors" "net/http" "net/http/httptest" - "net/url" "strings" "testing" @@ -11,65 +11,225 @@ import ( "github.com/stretchr/testify/assert" ) +func testKeyValidator(key string, c echo.Context) (bool, error) { + switch key { + case "valid-key": + return true, nil + case "error-key": + return false, errors.New("some user defined error") + default: + return false, nil + } +} + func TestKeyAuth(t *testing.T) { + handlerCalled := false + handler := func(c echo.Context) error { + handlerCalled = true + return c.String(http.StatusOK, "test") + } + middlewareChain := KeyAuth(testKeyValidator)(handler) + e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key") rec := httptest.NewRecorder() c := e.NewContext(req, rec) - config := KeyAuthConfig{ - Validator: func(key string, c echo.Context) (bool, error) { - return key == "valid-key", nil - }, - } - h := KeyAuthWithConfig(config)(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - assert := assert.New(t) + err := middlewareChain(c) - // Valid key - auth := DefaultKeyAuthConfig.AuthScheme + " " + "valid-key" - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + assert.NoError(t, err) + assert.True(t, handlerCalled) +} - // Invalid key - auth = DefaultKeyAuthConfig.AuthScheme + " " + "invalid-key" - req.Header.Set(echo.HeaderAuthorization, auth) - he := h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) +func TestKeyAuthWithConfig(t *testing.T) { + var testCases = []struct { + name string + givenRequestFunc func() *http.Request + givenRequest func(req *http.Request) + whenConfig func(conf *KeyAuthConfig) + expectHandlerCalled bool + expectError string + }{ + { + name: "ok, defaults, key from header", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key") + }, + expectHandlerCalled: true, + }, + { + name: "ok, custom skipper", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.Skipper = func(context echo.Context) bool { + return true + } + }, + expectHandlerCalled: true, + }, + { + name: "nok, defaults, invalid key from header, Authorization: Bearer", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key") + }, + expectHandlerCalled: false, + expectError: "code=401, message=Unauthorized", + }, + { + name: "nok, defaults, invalid scheme in header", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "Bear valid-key") + }, + expectHandlerCalled: false, + expectError: "code=400, message=invalid key in the request header", + }, + { + name: "nok, defaults, missing header", + givenRequest: func(req *http.Request) {}, + expectHandlerCalled: false, + expectError: "code=400, message=missing key in request header", + }, + { + name: "ok, custom key lookup, header", + givenRequest: func(req *http.Request) { + req.Header.Set("API-Key", "valid-key") + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "header:API-Key" + }, + expectHandlerCalled: true, + }, + { + name: "nok, custom key lookup, missing header", + givenRequest: func(req *http.Request) { + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "header:API-Key" + }, + expectHandlerCalled: false, + expectError: "code=400, message=missing key in request header", + }, + { + name: "ok, custom key lookup, query", + givenRequest: func(req *http.Request) { + q := req.URL.Query() + q.Add("key", "valid-key") + req.URL.RawQuery = q.Encode() + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "query:key" + }, + expectHandlerCalled: true, + }, + { + name: "nok, custom key lookup, missing query param", + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "query:key" + }, + expectHandlerCalled: false, + expectError: "code=400, message=missing key in the query string", + }, + { + name: "ok, custom key lookup, form", + givenRequestFunc: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("key=valid-key")) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) + return req + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "form:key" + }, + expectHandlerCalled: true, + }, + { + name: "nok, custom key lookup, missing key in form", + givenRequestFunc: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("xxx=valid-key")) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) + return req + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "form:key" + }, + expectHandlerCalled: false, + expectError: "code=400, message=missing key in the form", + }, + { + name: "nok, custom errorHandler, error from extractor", + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "header:token" + conf.ErrorHandler = func(err error, context echo.Context) error { + httpError := echo.NewHTTPError(http.StatusTeapot, "custom") + httpError.Internal = err + return httpError + } + }, + expectHandlerCalled: false, + expectError: "code=418, message=custom, internal=missing key in request header", + }, + { + name: "nok, custom errorHandler, error from validator", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.ErrorHandler = func(err error, context echo.Context) error { + httpError := echo.NewHTTPError(http.StatusTeapot, "custom") + httpError.Internal = err + return httpError + } + }, + expectHandlerCalled: false, + expectError: "code=418, message=custom, internal=some user defined error", + }, + { + name: "nok, defaults, error from validator", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") + }, + whenConfig: func(conf *KeyAuthConfig) {}, + expectHandlerCalled: false, + expectError: "code=401, message=invalid key, internal=some user defined error", + }, + } - // Missing Authorization header - req.Header.Del(echo.HeaderAuthorization) - he = h(c).(*echo.HTTPError) - assert.Equal(http.StatusBadRequest, he.Code) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + handlerCalled := false + handler := func(c echo.Context) error { + handlerCalled = true + return c.String(http.StatusOK, "test") + } + config := KeyAuthConfig{ + Validator: testKeyValidator, + } + if tc.whenConfig != nil { + tc.whenConfig(&config) + } + middlewareChain := KeyAuthWithConfig(config)(handler) - // Key from custom header - config.KeyLookup = "header:API-Key" - h = KeyAuthWithConfig(config)(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - req.Header.Set("API-Key", "valid-key") - assert.NoError(h(c)) + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenRequestFunc != nil { + req = tc.givenRequestFunc() + } + if tc.givenRequest != nil { + tc.givenRequest(req) + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - // Key from query string - config.KeyLookup = "query:key" - h = KeyAuthWithConfig(config)(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - q := req.URL.Query() - q.Add("key", "valid-key") - req.URL.RawQuery = q.Encode() - assert.NoError(h(c)) + err := middlewareChain(c) - // Key from form - config.KeyLookup = "form:key" - h = KeyAuthWithConfig(config)(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - f := make(url.Values) - f.Set("key", "valid-key") - req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) - req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) - c = e.NewContext(req, rec) - assert.NoError(h(c)) + assert.Equal(t, tc.expectHandlerCalled, handlerCalled) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } } From de3f87eb237abd70927e05e70fa860142e530183 Mon Sep 17 00:00:00 2001 From: Kaan Karakaya Date: Sat, 8 May 2021 22:30:06 +0300 Subject: [PATCH 133/446] Jwt lookup from multiple sources (#1845) * Jwt lookup from multiple sources --- middleware/jwt.go | 43 +++++++++++++++++++++++++++++------------- middleware/jwt_test.go | 8 ++++++++ 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/middleware/jwt.go b/middleware/jwt.go index fab4d6fdf..cd35b6215 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -160,17 +160,24 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { } // Initialize - parts := strings.Split(config.TokenLookup, ":") - extractor := jwtFromHeader(parts[1], config.AuthScheme) - switch parts[0] { - case "query": - extractor = jwtFromQuery(parts[1]) - case "param": - extractor = jwtFromParam(parts[1]) - case "cookie": - extractor = jwtFromCookie(parts[1]) - case "form": - extractor = jwtFromForm(parts[1]) + // Split sources + sources := strings.Split(config.TokenLookup, ",") + var extractors []jwtExtractor + for _, source := range sources { + parts := strings.Split(source, ":") + + switch parts[0] { + case "query": + extractors = append(extractors, jwtFromQuery(parts[1])) + case "param": + extractors = append(extractors, jwtFromParam(parts[1])) + case "cookie": + extractors = append(extractors, jwtFromCookie(parts[1])) + case "form": + extractors = append(extractors, jwtFromForm(parts[1])) + case "header": + extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme)) + } } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -182,8 +189,17 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.BeforeFunc != nil { config.BeforeFunc(c) } - - auth, err := extractor(c) + var auth string + var err error + for _, extractor := range extractors { + // Extract token from extractor, if it's not fail break the loop and + // set auth + auth, err = extractor(c) + if err == nil { + break + } + } + // If none of extractor has a token, handle error if err != nil { if config.ErrorHandler != nil { return config.ErrorHandler(err) @@ -194,6 +210,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { } return err } + token := new(jwt.Token) // Issue #647, #656 if _, ok := config.Claims.(jwt.MapClaims); ok { diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 37de31fcc..1a0265917 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -179,6 +179,14 @@ func TestJWT(t *testing.T) { hdrCookie: "jwt=" + token, info: "Valid cookie method", }, + { + config: JWTConfig{ + SigningKey: validKey, + TokenLookup: "query:jwt,cookie:jwt", + }, + hdrCookie: "jwt=" + token, + info: "Multiple jwt lookuop", + }, { config: JWTConfig{ SigningKey: validKey, From b643e6834ef15b9eb3609095a0dd5dcbe7fc5b8a Mon Sep 17 00:00:00 2001 From: Lukas Dietrich Date: Sat, 8 May 2021 21:33:17 +0200 Subject: [PATCH 134/446] Fix #1787: Add support for optional filesystem to the static middleware (#1797) * Add optional filesystem to static middleware. --- middleware/static.go | 80 +++++++++++++++++-------- middleware/static_1_16_test.go | 106 +++++++++++++++++++++++++++++++++ middleware/static_test.go | 29 +++++++++ 3 files changed, 191 insertions(+), 24 deletions(-) create mode 100644 middleware/static_1_16_test.go diff --git a/middleware/static.go b/middleware/static.go index ae79cb5fa..0106f7ce2 100644 --- a/middleware/static.go +++ b/middleware/static.go @@ -42,6 +42,10 @@ type ( // the filesystem path is not doubled // Optional. Default value false. IgnoreBase bool `yaml:"ignoreBase"` + + // Filesystem provides access to the static content. + // Optional. Defaults to http.Dir(config.Root) + Filesystem http.FileSystem `yaml:"-"` } ) @@ -146,6 +150,10 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { if config.Index == "" { config.Index = DefaultStaticConfig.Index } + if config.Filesystem == nil { + config.Filesystem = http.Dir(config.Root) + config.Root = "." + } // Index template t, err := template.New("index").Parse(html) @@ -178,49 +186,73 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { } } - fi, err := os.Stat(name) + file, err := openFile(config.Filesystem, name) if err != nil { - if os.IsNotExist(err) { - if err = next(c); err != nil { - if he, ok := err.(*echo.HTTPError); ok { - if config.HTML5 && he.Code == http.StatusNotFound { - return c.File(filepath.Join(config.Root, config.Index)) - } - } - return - } + if !os.IsNotExist(err) { + return err + } + + if err = next(c); err == nil { + return err + } + + he, ok := err.(*echo.HTTPError) + if !(ok && config.HTML5 && he.Code == http.StatusNotFound) { + return err + } + + file, err = openFile(config.Filesystem, filepath.Join(config.Root, config.Index)) + if err != nil { + return err } - return } - if fi.IsDir() { - index := filepath.Join(name, config.Index) - fi, err = os.Stat(index) + defer file.Close() + + info, err := file.Stat() + if err != nil { + return err + } + if info.IsDir() { + index, err := openFile(config.Filesystem, filepath.Join(name, config.Index)) if err != nil { if config.Browse { - return listDir(t, name, c.Response()) + return listDir(t, name, file, c.Response()) } + if os.IsNotExist(err) { return next(c) } - return } - return c.File(index) + defer index.Close() + + info, err = index.Stat() + if err != nil { + return err + } + + return serveFile(c, index, info) } - return c.File(name) + return serveFile(c, file, info) } } } -func listDir(t *template.Template, name string, res *echo.Response) (err error) { - file, err := os.Open(name) - if err != nil { - return - } - files, err := file.Readdir(-1) +func openFile(fs http.FileSystem, name string) (http.File, error) { + pathWithSlashes := filepath.ToSlash(name) + return fs.Open(pathWithSlashes) +} + +func serveFile(c echo.Context, file http.File, info os.FileInfo) error { + http.ServeContent(c.Response(), c.Request(), info.Name(), info.ModTime(), file) + return nil +} + +func listDir(t *template.Template, name string, dir http.File, res *echo.Response) (err error) { + files, err := dir.Readdir(-1) if err != nil { return } diff --git a/middleware/static_1_16_test.go b/middleware/static_1_16_test.go new file mode 100644 index 000000000..53e02f742 --- /dev/null +++ b/middleware/static_1_16_test.go @@ -0,0 +1,106 @@ +// +build go1.16 + +package middleware + +import ( + "io/fs" + "net/http" + "net/http/httptest" + "os" + "testing" + "testing/fstest" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestStatic_CustomFS(t *testing.T) { + var testCases = []struct { + name string + filesystem fs.FS + root string + whenURL string + expectContains string + expectCode int + }{ + { + name: "ok, serve index with Echo message", + whenURL: "/", + filesystem: os.DirFS("../_fixture"), + expectCode: http.StatusOK, + expectContains: "Echo", + }, + + { + name: "ok, serve index with Echo message", + whenURL: "/_fixture/", + filesystem: os.DirFS(".."), + expectCode: http.StatusOK, + expectContains: "Echo", + }, + { + name: "ok, serve file from map fs", + whenURL: "/file.txt", + filesystem: fstest.MapFS{ + "file.txt": &fstest.MapFile{Data: []byte("file.txt is ok")}, + }, + expectCode: http.StatusOK, + expectContains: "file.txt is ok", + }, + { + name: "nok, missing file in map fs", + whenURL: "/file.txt", + expectCode: http.StatusNotFound, + filesystem: fstest.MapFS{ + "file2.txt": &fstest.MapFile{Data: []byte("file2.txt is ok")}, + }, + }, + { + name: "nok, file is not a subpath of root", + whenURL: `/../../secret.txt`, + root: "/nested/folder", + filesystem: fstest.MapFS{ + "secret.txt": &fstest.MapFile{Data: []byte("this is a secret")}, + }, + expectCode: http.StatusNotFound, + }, + { + name: "nok, backslash is forbidden", + whenURL: `/..\..\secret.txt`, + expectCode: http.StatusNotFound, + root: "/nested/folder", + filesystem: fstest.MapFS{ + "secret.txt": &fstest.MapFile{Data: []byte("this is a secret")}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + config := StaticConfig{ + Root: ".", + Filesystem: http.FS(tc.filesystem), + } + + if tc.root != "" { + config.Root = tc.root + } + + middlewareFunc := StaticWithConfig(config) + e.Use(middlewareFunc) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + if tc.expectContains != "" { + responseBody := rec.Body.String() + assert.Contains(t, responseBody, tc.expectContains) + } + }) + } +} diff --git a/middleware/static_test.go b/middleware/static_test.go index 8c0c97ded..af6641f66 100644 --- a/middleware/static_test.go +++ b/middleware/static_test.go @@ -94,6 +94,32 @@ func TestStatic(t *testing.T) { expectCode: http.StatusNotFound, expectContains: "{\"message\":\"Not Found\"}\n", }, + { + name: "ok, do not serve file, when a handler took care of the request", + whenURL: "/regular-handler", + expectCode: http.StatusOK, + expectContains: "ok", + }, + { + name: "nok, when html5 fail if the index file does not exist", + givenConfig: &StaticConfig{ + Root: "../_fixture", + HTML5: true, + Index: "missing.html", + }, + whenURL: "/random", + expectCode: http.StatusInternalServerError, + }, + { + name: "ok, serve from http.FileSystem", + givenConfig: &StaticConfig{ + Root: "_fixture", + Filesystem: http.Dir(".."), + }, + whenURL: "/", + expectCode: http.StatusOK, + expectContains: "Echo", + }, } for _, tc := range testCases { @@ -115,6 +141,9 @@ func TestStatic(t *testing.T) { } else { // middleware is on root level e.Use(middlewareFunc) + e.GET("/regular-handler", func(c echo.Context) error { + return c.String(http.StatusOK, "ok") + }) } req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) From 2943a3200534a287aff51e37465c2c0ca1b27728 Mon Sep 17 00:00:00 2001 From: Voltboy Date: Wed, 28 Apr 2021 17:45:48 +0300 Subject: [PATCH 135/446] restore originalWriter in case of panic inside echoHandlerFuncWrapper.ServeHTTP method --- middleware/timeout.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/middleware/timeout.go b/middleware/timeout.go index d56e463b0..fb8ae4219 100644 --- a/middleware/timeout.go +++ b/middleware/timeout.go @@ -92,6 +92,15 @@ func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Reques originalWriter := t.ctx.Response().Writer t.ctx.Response().Writer = rw + // in case of panic we restore original writer and call panic again + // so it could be handled with global middleware Recover() + defer func() { + if err := recover(); err != nil { + t.ctx.Response().Writer = originalWriter + panic(err) + } + }() + err := t.handler(t.ctx) if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded { if err != nil && t.errHandler != nil { From 18d7fe11df962e14fdc496d8088870f3ee712e6c Mon Sep 17 00:00:00 2001 From: lipengwei Date: Sun, 25 Apr 2021 09:50:14 +0800 Subject: [PATCH 136/446] Fix #1858: Add query params binding support for anonymous struct pointer filed --- bind.go | 9 +++++++++ bind_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/bind.go b/bind.go index 08d398916..530697ee7 100644 --- a/bind.go +++ b/bind.go @@ -144,11 +144,20 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri for i := 0; i < typ.NumField(); i++ { typeField := typ.Field(i) structField := val.Field(i) + if typeField.Anonymous { + for structField.Kind() == reflect.Ptr { + structField = structField.Elem() + } + } if !structField.CanSet() { continue } structFieldKind := structField.Kind() inputFieldName := typeField.Tag.Get(tag) + if typeField.Anonymous && structField.Kind() == reflect.Struct && inputFieldName != "" { + // if anonymous struct, ignore custom tag + inputFieldName = "" + } if inputFieldName == "" { // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contains fields with tags). diff --git a/bind_test.go b/bind_test.go index 73398034e..163dd9450 100644 --- a/bind_test.go +++ b/bind_test.go @@ -100,6 +100,9 @@ type ( Struct struct { Foo string } + Bar struct { + Baz int `json:"baz" query:"baz"` + } ) func (t *Timestamp) UnmarshalParam(src string) error { @@ -330,6 +333,48 @@ func TestBindUnmarshalParamPtr(t *testing.T) { } } +func TestBindUnmarshalParamAnonymousFieldPtr(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?baz=1", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + *Bar + }{&Bar{}} + err := c.Bind(&result) + if assert.NoError(t, err) { + assert.Equal(t, 1, result.Baz) + } +} + +func TestBindUnmarshalParamAnonymousFieldPtrNil(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?baz=1", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + *Bar + }{} + err := c.Bind(&result) + if assert.NoError(t, err) { + assert.Nil(t, result.Bar) + } +} + +func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, `/?bar={"baz":100}&baz=1`, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + *Bar `json:"bar" query:"bar"` + }{&Bar{}} + err := c.Bind(&result) + if assert.NoError(t, err) { + assert.Equal(t, 1, result.Baz) + } +} + func TestBindUnmarshalTextPtr(t *testing.T) { e := New() req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil) From 1aef300cf47cee2b9a65deb74adbae7fe67b1bc6 Mon Sep 17 00:00:00 2001 From: lipengwei Date: Thu, 29 Apr 2021 09:22:01 +0800 Subject: [PATCH 137/446] explicitly return an error instead of hiding it --- bind.go | 6 +++--- bind_test.go | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/bind.go b/bind.go index 530697ee7..dfdf82d0c 100644 --- a/bind.go +++ b/bind.go @@ -145,7 +145,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri typeField := typ.Field(i) structField := val.Field(i) if typeField.Anonymous { - for structField.Kind() == reflect.Ptr { + if structField.Kind() == reflect.Ptr { structField = structField.Elem() } } @@ -155,8 +155,8 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri structFieldKind := structField.Kind() inputFieldName := typeField.Tag.Get(tag) if typeField.Anonymous && structField.Kind() == reflect.Struct && inputFieldName != "" { - // if anonymous struct, ignore custom tag - inputFieldName = "" + // if anonymous struct with query/param/form tags, report an error + return errors.New("query/param/form tags are not allowed with anonymous struct field") } if inputFieldName == "" { diff --git a/bind_test.go b/bind_test.go index 163dd9450..ff0337082 100644 --- a/bind_test.go +++ b/bind_test.go @@ -370,9 +370,7 @@ func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) { *Bar `json:"bar" query:"bar"` }{&Bar{}} err := c.Bind(&result) - if assert.NoError(t, err) { - assert.Equal(t, 1, result.Baz) - } + assert.Contains(t, err.Error(), "query/param/form tags are not allowed with anonymous struct field") } func TestBindUnmarshalTextPtr(t *testing.T) { From 2acb24adb0fd619dc3c672d7fdde25c2c21061d7 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Tue, 27 Apr 2021 11:51:36 +0300 Subject: [PATCH 138/446] Update version and changelog for 4.3.0 --- CHANGELOG.md | 24 ++++++++++++++++++++++++ echo.go | 2 +- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a3117b80c..f4a74760f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,29 @@ # Changelog +## v4.3.0 - 2021-05-08 + +**Important notes** + +* Route matching has improvements for following cases: + 1. Correctly match routes with parameter part as last part of route (with trailing backslash) + 2. Considering handlers when resolving routes and search for matching http method handler +* Echo minimal Go version is now 1.13. + +**Fixes** + +* When url ends with slash first param route is the match [#1804](https://github.com/labstack/echo/pull/1812) +* Router should check if node is suitable as matching route by path+method and if not then continue search in tree [#1808](https://github.com/labstack/echo/issues/1808) +* Fix timeout middleware not writing response correctly when handler panics [#1864](https://github.com/labstack/echo/pull/1864) +* Fix binder not working with embedded pointer structs [#1861](https://github.com/labstack/echo/pull/1861) +* Add Go 1.16 to CI and drop 1.12 specific code [#1850](https://github.com/labstack/echo/pull/1850) + +**Enhancements** + +* Make KeyFunc public in JWT middleware [#1756](https://github.com/labstack/echo/pull/1756) +* Add support for optional filesystem to the static middleware [#1797](https://github.com/labstack/echo/pull/1797) +* Add a custom error handler to key-auth middleware [#1847](https://github.com/labstack/echo/pull/1847) +* Allow JWT token to be looked up from multiple sources [#1845](https://github.com/labstack/echo/pull/1845) + ## v4.2.2 - 2021-04-07 **Fixes** diff --git a/echo.go b/echo.go index a24e3977f..dd0cbf355 100644 --- a/echo.go +++ b/echo.go @@ -234,7 +234,7 @@ const ( const ( // Version of Echo - Version = "4.2.2" + Version = "4.3.0" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 7846e3fa6b6cdbad700680eb5c3bd8c6e45a9ec7 Mon Sep 17 00:00:00 2001 From: Alexander Pochill Date: Tue, 25 May 2021 14:50:49 +0200 Subject: [PATCH 139/446] feat: Bind data using headers as source (#1866) Currently, echo supports binding data from query, path or body. Sometimes we need to read bind data from headers. It would be nice to automatically bind those using the `bindData` func, which is already well prepared to accept `http.Header`. I didn't add this to the `Bind` func, so this will not happen automatically. Main reason is backwards compatability. It might be confusing if variables are bound from headers when upgrading, and might even have become a security issue as pointed out in #1670. * Add docs for BindHeaders * Add test for BindHeader with invalid data type --- bind.go | 10 +++++++++- bind_test.go | 31 +++++++++++++++++++++++++++++++ echo_test.go | 4 ++-- 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/bind.go b/bind.go index dfdf82d0c..47947ce5b 100644 --- a/bind.go +++ b/bind.go @@ -97,6 +97,14 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { return nil } +// BindHeaders binds HTTP headers to a bindable object +func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error { + if err := b.bindData(i, c.Request().Header, "header"); err != nil { + return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + } + return nil +} + // Bind implements the `Binder#Bind` function. // Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous // step binded values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams. @@ -134,7 +142,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri // !struct if typ.Kind() != reflect.Struct { - if tag == "param" || tag == "query" { + if tag == "param" || tag == "query" || tag == "header" { // incompatible type, data is probably to be found in the body return nil } diff --git a/bind_test.go b/bind_test.go index ff0337082..4ed8dbb50 100644 --- a/bind_test.go +++ b/bind_test.go @@ -269,6 +269,37 @@ func TestBindQueryParamsCaseSensitivePrioritized(t *testing.T) { } } +func TestBindHeaderParam(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Name", "Jon Doe") + req.Header.Set("Id", "2") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + u := new(user) + err := (&DefaultBinder{}).BindHeaders(c, u) + if assert.NoError(t, err) { + assert.Equal(t, 2, u.ID) + assert.Equal(t, "Jon Doe", u.Name) + } +} + +func TestBindHeaderParamBadType(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Id", "salamander") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + u := new(user) + err := (&DefaultBinder{}).BindHeaders(c, u) + assert.Error(t, err) + + httpErr, ok := err.(*HTTPError) + if assert.True(t, ok) { + assert.Equal(t, http.StatusBadRequest, httpErr.Code) + } +} + func TestBindUnmarshalParam(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) diff --git a/echo_test.go b/echo_test.go index ba498831b..e5bd371dd 100644 --- a/echo_test.go +++ b/echo_test.go @@ -24,8 +24,8 @@ import ( type ( user struct { - ID int `json:"id" xml:"id" form:"id" query:"id" param:"id"` - Name string `json:"name" xml:"name" form:"name" query:"name" param:"name"` + ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"` + Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"` } ) From 379bdeaa1ef5e29cbcab2da2d4728c5e6f984ce8 Mon Sep 17 00:00:00 2001 From: Kaan Karakaya Date: Wed, 26 May 2021 09:11:22 +0300 Subject: [PATCH 140/446] docs: Added comment about TokenLookup Signed-off-by: Kaan Karakaya --- middleware/jwt.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/middleware/jwt.go b/middleware/jwt.go index cd35b6215..6c8bcebb4 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -53,7 +53,7 @@ type ( // Optional. Default value jwt.MapClaims Claims jwt.Claims - // TokenLookup is a string in the form of ":" that is used + // TokenLookup is a string in the form of ":" or ":,:" that is used // to extract token from the request. // Optional. Default value "header:Authorization". // Possible values: @@ -62,6 +62,9 @@ type ( // - "param:" // - "cookie:" // - "form:" + // Multiply sources example: + // - "header: Authorization,cookie: myowncookie" + TokenLookup string // AuthScheme to be used in the Authorization header. From 1c24ab8c2b277949191487b859ce0b8b71c45e24 Mon Sep 17 00:00:00 2001 From: harukitosa <13haruki28@gmail.com> Date: Sat, 22 May 2021 21:18:04 +0900 Subject: [PATCH 141/446] fix rateLimiteDoc --- middleware/rate_limiter.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index 46a310d96..0291eb451 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -169,7 +169,8 @@ type ( /* NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with -the provided rate (as req/s). Burst and ExpiresIn will be set to default values. +the provided rate (as req/s). The provided rate less than 1 will be treated as zero. +Burst and ExpiresIn will be set to default values. Example (with 20 requests/sec): From fdacff0d93d54e8065ac7f5072041146f054a201 Mon Sep 17 00:00:00 2001 From: Oleksandr Savchenko Date: Thu, 20 May 2021 17:51:15 +0300 Subject: [PATCH 142/446] Split XFF header only by comma --- context.go | 4 ++-- context_test.go | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/context.go b/context.go index 0cee48ce0..fad8bf7be 100644 --- a/context.go +++ b/context.go @@ -276,9 +276,9 @@ func (c *context) RealIP() string { } // Fall back to legacy behavior if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" { - i := strings.IndexAny(ip, ", ") + i := strings.IndexAny(ip, ",") if i > 0 { - return ip[:i] + return strings.TrimSpace(ip[:i]) } return ip } diff --git a/context_test.go b/context_test.go index 2c61ffb3a..a8b9a9946 100644 --- a/context_test.go +++ b/context_test.go @@ -888,6 +888,14 @@ func TestContext_RealIP(t *testing.T) { }, "127.0.0.1", }, + { + &context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1,127.0.1.1"}}, + }, + }, + "127.0.0.1", + }, { &context{ request: &http.Request{ From 1ac4a8f3d0c6dc6ff8b9d666a3e860e70edcad87 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 6 Jun 2021 21:36:41 +0300 Subject: [PATCH 143/446] Adds JWTConfig.ParseTokenFunc to JWT middleware to allow different libraries implementing JWT parsing. --- middleware/jwt.go | 48 ++++++++--- middleware/jwt_test.go | 192 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 228 insertions(+), 12 deletions(-) diff --git a/middleware/jwt.go b/middleware/jwt.go index 6c8bcebb4..bce478743 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "fmt" "net/http" "reflect" @@ -49,7 +50,8 @@ type ( // Optional. Default value "user". ContextKey string - // Claims are extendable claims data defining token content. + // Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation. + // Not used if custom ParseTokenFunc is set. // Optional. Default value jwt.MapClaims Claims jwt.Claims @@ -74,13 +76,20 @@ type ( // KeyFunc defines a user-defined function that supplies the public key for a token validation. // The function shall take care of verifying the signing algorithm and selecting the proper key. // A user-defined KeyFunc can be useful if tokens are issued by an external party. + // Used by default ParseTokenFunc implementation. // // When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored. // This is one of the three options to provide a token validation key. // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. // Required if neither SigningKeys nor SigningKey is provided. + // Not used if custom ParseTokenFunc is set. // Default to an internal implementation verifying the signing algorithm and selecting the proper key. KeyFunc jwt.Keyfunc + + // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token + // parsing fails or parsed token is invalid. + // Defaults to implementation using `github.com/dgrijalva/jwt-go` as JWT implementation library + ParseTokenFunc func(auth string, c echo.Context) (interface{}, error) } // JWTSuccessHandler defines a function which is executed for a valid token. @@ -140,7 +149,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.Skipper == nil { config.Skipper = DefaultJWTConfig.Skipper } - if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil { + if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil { panic("echo: jwt middleware requires signing key") } if config.SigningMethod == "" { @@ -161,6 +170,9 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.KeyFunc == nil { config.KeyFunc = config.defaultKeyFunc } + if config.ParseTokenFunc == nil { + config.ParseTokenFunc = config.defaultParseToken + } // Initialize // Split sources @@ -214,16 +226,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { return err } - token := new(jwt.Token) - // Issue #647, #656 - if _, ok := config.Claims.(jwt.MapClaims); ok { - token, err = jwt.Parse(auth, config.KeyFunc) - } else { - t := reflect.ValueOf(config.Claims).Type().Elem() - claims := reflect.New(t).Interface().(jwt.Claims) - token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc) - } - if err == nil && token.Valid { + token, err := config.ParseTokenFunc(auth, c) + if err == nil { // Store user information from token into context. c.Set(config.ContextKey, token) if config.SuccessHandler != nil { @@ -246,6 +250,26 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { } } +func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) { + token := new(jwt.Token) + var err error + // Issue #647, #656 + if _, ok := config.Claims.(jwt.MapClaims); ok { + token, err = jwt.Parse(auth, config.KeyFunc) + } else { + t := reflect.ValueOf(config.Claims).Type().Elem() + claims := reflect.New(t).Interface().(jwt.Claims) + token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc) + } + if err != nil { + return nil, err + } + if !token.Valid { + return nil, errors.New("invalid token") + } + return token, nil +} + // defaultKeyFunc returns a signing key of the given token. func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) { // Check the signing method diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 1a0265917..9af4c83d8 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -2,6 +2,7 @@ package middleware import ( "errors" + "fmt" "net/http" "net/http/httptest" "net/url" @@ -404,3 +405,194 @@ func TestJWTwithKID(t *testing.T) { } } } + +func TestJWTConfig_skipper(t *testing.T) { + e := echo.New() + + e.Use(JWTWithConfig(JWTConfig{ + Skipper: func(context echo.Context) bool { + return true // skip everything + }, + SigningKey: []byte("secret"), + })) + + isCalled := false + e.GET("/", func(c echo.Context) error { + isCalled = true + return c.String(http.StatusTeapot, "test") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.True(t, isCalled) +} + +func TestJWTConfig_BeforeFunc(t *testing.T) { + e := echo.New() + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) + + isCalled := false + e.Use(JWTWithConfig(JWTConfig{ + BeforeFunc: func(context echo.Context) { + isCalled = true + }, + SigningKey: []byte("secret"), + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.True(t, isCalled) +} + +func TestJWTConfig_extractorErrorHandling(t *testing.T) { + var testCases = []struct { + name string + given JWTConfig + expectStatusCode int + }{ + { + name: "ok, ErrorHandler is executed", + given: JWTConfig{ + SigningKey: []byte("secret"), + ErrorHandler: func(err error) error { + return echo.NewHTTPError(http.StatusTeapot, "custom_error") + }, + }, + expectStatusCode: http.StatusTeapot, + }, + { + name: "ok, ErrorHandlerWithContext is executed", + given: JWTConfig{ + SigningKey: []byte("secret"), + ErrorHandlerWithContext: func(err error, context echo.Context) error { + return echo.NewHTTPError(http.StatusTeapot, "custom_error") + }, + }, + expectStatusCode: http.StatusTeapot, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusNotImplemented, "should not end up here") + }) + + e.Use(JWTWithConfig(tc.given)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, tc.expectStatusCode, res.Code) + }) + } +} + +func TestJWTConfig_parseTokenErrorHandling(t *testing.T) { + var testCases = []struct { + name string + given JWTConfig + expectErr string + }{ + { + name: "ok, ErrorHandler is executed", + given: JWTConfig{ + SigningKey: []byte("secret"), + ErrorHandler: func(err error) error { + return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error()) + }, + }, + expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n", + }, + { + name: "ok, ErrorHandlerWithContext is executed", + given: JWTConfig{ + SigningKey: []byte("secret"), + ErrorHandlerWithContext: func(err error, context echo.Context) error { + return echo.NewHTTPError(http.StatusTeapot, "ErrorHandlerWithContext: "+err.Error()) + }, + }, + expectErr: "{\"message\":\"ErrorHandlerWithContext: parsing failed\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + //e.Debug = true + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusNotImplemented, "should not end up here") + }) + + config := tc.given + parseTokenCalled := false + config.ParseTokenFunc = func(auth string, c echo.Context) (interface{}, error) { + parseTokenCalled = true + return nil, errors.New("parsing failed") + } + e.Use(JWTWithConfig(config)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + res := httptest.NewRecorder() + + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.Equal(t, tc.expectErr, res.Body.String()) + assert.True(t, parseTokenCalled) + }) + } +} + +func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { + e := echo.New() + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) + + // example of minimal custom ParseTokenFunc implementation. Allows you to use different versions of `github.com/dgrijalva/jwt-go` + // with current JWT middleware + signingKey := []byte("secret") + + config := JWTConfig{ + ParseTokenFunc: func(auth string, c echo.Context) (interface{}, error) { + keyFunc := func(t *jwt.Token) (interface{}, error) { + if t.Method.Alg() != "HS256" { + return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) + } + return signingKey, nil + } + + // claims are of type `jwt.MapClaims` when token is created with `jwt.Parse` + token, err := jwt.Parse(auth, keyFunc) + if err != nil { + return nil, err + } + if !token.Valid { + return nil, errors.New("invalid token") + } + return token, nil + }, + } + + e.Use(JWTWithConfig(config)) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) +} From f20820c0030a0d8c8aa20f63996092faa329fe03 Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Fri, 25 Jun 2021 17:56:07 -0300 Subject: [PATCH 144/446] Adding tests for Echo#Host (#1895) --- echo_test.go | 150 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) diff --git a/echo_test.go b/echo_test.go index e5bd371dd..dc553490b 100644 --- a/echo_test.go +++ b/echo_test.go @@ -472,6 +472,37 @@ func TestEchoRoutes(t *testing.T) { } } +func TestEchoRoutesHandleHostsProperly(t *testing.T) { + e := New() + h := e.Host("route.com") + routes := []*Route{ + {http.MethodGet, "/users/:user/events", ""}, + {http.MethodGet, "/users/:user/events/public", ""}, + {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, + {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, + } + for _, r := range routes { + h.Add(r.Method, r.Path, func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + } + + if assert.Equal(t, len(routes), len(e.Routes())) { + for _, r := range e.Routes() { + found := false + for _, rr := range routes { + if r.Method == rr.Method && r.Path == rr.Path { + found = true + break + } + } + if !found { + t.Errorf("Route %s %s not found", r.Method, r.Path) + } + } + } +} + func TestEchoServeHTTPPathEncoding(t *testing.T) { e := New() e.GET("/with/slash", func(c Context) error { @@ -514,6 +545,109 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) { } } +func TestEchoHost(t *testing.T) { + assert := assert.New(t) + + okHandler := func(c Context) error { return c.String(http.StatusOK, http.StatusText(http.StatusOK)) } + teapotHandler := func(c Context) error { return c.String(http.StatusTeapot, http.StatusText(http.StatusTeapot)) } + acceptHandler := func(c Context) error { return c.String(http.StatusAccepted, http.StatusText(http.StatusAccepted)) } + teapotMiddleware := MiddlewareFunc(func(next HandlerFunc) HandlerFunc { return teapotHandler }) + + e := New() + e.GET("/", acceptHandler) + e.GET("/foo", acceptHandler) + + ok := e.Host("ok.com") + ok.GET("/", okHandler) + ok.GET("/foo", okHandler) + + teapot := e.Host("teapot.com") + teapot.GET("/", teapotHandler) + teapot.GET("/foo", teapotHandler) + + middle := e.Host("middleware.com", teapotMiddleware) + middle.GET("/", okHandler) + middle.GET("/foo", okHandler) + + var testCases = []struct { + name string + whenHost string + whenPath string + expectBody string + expectStatus int + }{ + { + name: "No Host Root", + whenHost: "", + whenPath: "/", + expectBody: http.StatusText(http.StatusAccepted), + expectStatus: http.StatusAccepted, + }, + { + name: "No Host Foo", + whenHost: "", + whenPath: "/foo", + expectBody: http.StatusText(http.StatusAccepted), + expectStatus: http.StatusAccepted, + }, + { + name: "OK Host Root", + whenHost: "ok.com", + whenPath: "/", + expectBody: http.StatusText(http.StatusOK), + expectStatus: http.StatusOK, + }, + { + name: "OK Host Foo", + whenHost: "ok.com", + whenPath: "/foo", + expectBody: http.StatusText(http.StatusOK), + expectStatus: http.StatusOK, + }, + { + name: "Teapot Host Root", + whenHost: "teapot.com", + whenPath: "/", + expectBody: http.StatusText(http.StatusTeapot), + expectStatus: http.StatusTeapot, + }, + { + name: "Teapot Host Foo", + whenHost: "teapot.com", + whenPath: "/foo", + expectBody: http.StatusText(http.StatusTeapot), + expectStatus: http.StatusTeapot, + }, + { + name: "Middleware Host", + whenHost: "middleware.com", + whenPath: "/", + expectBody: http.StatusText(http.StatusTeapot), + expectStatus: http.StatusTeapot, + }, + { + name: "Middleware Host Foo", + whenHost: "middleware.com", + whenPath: "/foo", + expectBody: http.StatusText(http.StatusTeapot), + expectStatus: http.StatusTeapot, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.whenPath, nil) + req.Host = tc.whenHost + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(tc.expectStatus, rec.Code) + assert.Equal(tc.expectBody, rec.Body.String()) + }) + } +} + func TestEchoGroup(t *testing.T) { e := New() buf := new(bytes.Buffer) @@ -1166,6 +1300,22 @@ func TestEchoReverse(t *testing.T) { assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) } +func TestEchoReverseHandleHostProperly(t *testing.T) { + assert := assert.New(t) + + dummyHandler := func(Context) error { return nil } + + e := New() + h := e.Host("the_host") + h.GET("/static", dummyHandler).Name = "/static" + h.GET("/static/*", dummyHandler).Name = "/static/*" + + assert.Equal("/static", e.Reverse("/static")) + assert.Equal("/static", e.Reverse("/static", "missing param")) + assert.Equal("/static/*", e.Reverse("/static/*")) + assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt")) +} + func TestEcho_ListenerAddr(t *testing.T) { e := New() From fd7a8a97ac0e1aa76ab12c50e4c91324a5546b9b Mon Sep 17 00:00:00 2001 From: zacscoding Date: Wed, 23 Jun 2021 01:04:23 +0900 Subject: [PATCH 145/446] Adds RequestIDHandler function to RequestID middleware --- middleware/request_id.go | 6 ++++++ middleware/request_id_test.go | 11 +++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/middleware/request_id.go b/middleware/request_id.go index 21f801f3b..b0baeeb2d 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -14,6 +14,9 @@ type ( // Generator defines a function to generate an ID. // Optional. Default value random.String(32). Generator func() string + + // RequestIDHandler defines a function which is executed for a request id. + RequestIDHandler func(echo.Context, string) } ) @@ -53,6 +56,9 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { rid = config.Generator() } res.Header().Set(echo.HeaderXRequestID, rid) + if config.RequestIDHandler != nil { + config.RequestIDHandler(c, rid) + } return next(c) } diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go index 86eec8c3b..944b3b49e 100644 --- a/middleware/request_id_test.go +++ b/middleware/request_id_test.go @@ -23,13 +23,20 @@ func TestRequestID(t *testing.T) { h(c) assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) - // Custom generator + // Custom generator and handler + customID := "customGenerator" + calledHandler := false rid = RequestIDWithConfig(RequestIDConfig{ - Generator: func() string { return "customGenerator" }, + Generator: func() string { return customID }, + RequestIDHandler: func(_ echo.Context, id string) { + calledHandler = true + assert.Equal(t, customID, id) + }, }) h = rid(handler) h(c) assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") + assert.True(t, calledHandler) } func TestRequestID_IDNotAltered(t *testing.T) { From 5e791b07870ea1317a320987e9a4bf4c3c1013e7 Mon Sep 17 00:00:00 2001 From: Hosh Date: Mon, 5 Jul 2021 20:33:19 +0100 Subject: [PATCH 146/446] Allow for custom JSON encoding implementations (#1880) * Allow for custom JSON encoding implementations Co-authored-by: toimtoimtoim --- bind.go | 13 +++---- context.go | 16 +++----- echo.go | 8 ++++ json.go | 31 ++++++++++++++++ json_test.go | 101 +++++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 151 insertions(+), 18 deletions(-) create mode 100644 json.go create mode 100644 json_test.go diff --git a/bind.go b/bind.go index 47947ce5b..fdf0524c2 100644 --- a/bind.go +++ b/bind.go @@ -2,7 +2,6 @@ package echo import ( "encoding" - "encoding/json" "encoding/xml" "errors" "fmt" @@ -66,13 +65,13 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { ctype := req.Header.Get(HeaderContentType) switch { case strings.HasPrefix(ctype, MIMEApplicationJSON): - if err = json.NewDecoder(req.Body).Decode(i); err != nil { - if ute, ok := err.(*json.UnmarshalTypeError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err) - } else if se, ok := err.(*json.SyntaxError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err) + if err = c.Echo().JSONSerializer.Deserialize(c, i); err != nil { + switch err.(type) { + case *HTTPError: + return err + default: + return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML): if err = xml.NewDecoder(req.Body).Decode(i); err != nil { diff --git a/context.go b/context.go index fad8bf7be..91ab6e480 100644 --- a/context.go +++ b/context.go @@ -2,7 +2,6 @@ package echo import ( "bytes" - "encoding/json" "encoding/xml" "fmt" "io" @@ -457,17 +456,16 @@ func (c *context) String(code int, s string) (err error) { } func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error) { - enc := json.NewEncoder(c.response) - _, pretty := c.QueryParams()["pretty"] - if c.echo.Debug || pretty { - enc.SetIndent("", " ") + indent := "" + if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { + indent = defaultIndent } c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(callback + "(")); err != nil { return } - if err = enc.Encode(i); err != nil { + if err = c.echo.JSONSerializer.Serialize(c, i, indent); err != nil { return } if _, err = c.response.Write([]byte(");")); err != nil { @@ -477,13 +475,9 @@ func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error } func (c *context) json(code int, i interface{}, indent string) error { - enc := json.NewEncoder(c.response) - if indent != "" { - enc.SetIndent("", indent) - } c.writeContentType(MIMEApplicationJSONCharsetUTF8) c.response.Status = code - return enc.Encode(i) + return c.echo.JSONSerializer.Serialize(c, i, indent) } func (c *context) JSON(code int, i interface{}) (err error) { diff --git a/echo.go b/echo.go index dd0cbf355..afb1e27dc 100644 --- a/echo.go +++ b/echo.go @@ -90,6 +90,7 @@ type ( HidePort bool HTTPErrorHandler HTTPErrorHandler Binder Binder + JSONSerializer JSONSerializer Validator Validator Renderer Renderer Logger Logger @@ -125,6 +126,12 @@ type ( Validate(i interface{}) error } + // JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. + JSONSerializer interface { + Serialize(c Context, i interface{}, indent string) error + Deserialize(c Context, i interface{}) error + } + // Renderer is the interface that wraps the Render function. Renderer interface { Render(io.Writer, string, interface{}, Context) error @@ -315,6 +322,7 @@ func New() (e *Echo) { e.TLSServer.Handler = e e.HTTPErrorHandler = e.DefaultHTTPErrorHandler e.Binder = &DefaultBinder{} + e.JSONSerializer = &DefaultJSONSerializer{} e.Logger.SetLevel(log.ERROR) e.StdLogger = stdLog.New(e.Logger.Output(), e.Logger.Prefix()+": ", 0) e.pool.New = func() interface{} { diff --git a/json.go b/json.go new file mode 100644 index 000000000..16b2d0577 --- /dev/null +++ b/json.go @@ -0,0 +1,31 @@ +package echo + +import ( + "encoding/json" + "fmt" + "net/http" +) + +// DefaultJSONSerializer implements JSON encoding using encoding/json. +type DefaultJSONSerializer struct{} + +// Serialize converts an interface into a json and writes it to the response. +// You can optionally use the indent parameter to produce pretty JSONs. +func (d DefaultJSONSerializer) Serialize(c Context, i interface{}, indent string) error { + enc := json.NewEncoder(c.Response()) + if indent != "" { + enc.SetIndent("", indent) + } + return enc.Encode(i) +} + +// Deserialize reads a JSON from a request body and converts it into an interface. +func (d DefaultJSONSerializer) Deserialize(c Context, i interface{}) error { + err := json.NewDecoder(c.Request().Body).Decode(i) + if ute, ok := err.(*json.UnmarshalTypeError); ok { + return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err) + } else if se, ok := err.(*json.SyntaxError); ok { + return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err) + } + return err +} diff --git a/json_test.go b/json_test.go new file mode 100644 index 000000000..27ee43e73 --- /dev/null +++ b/json_test.go @@ -0,0 +1,101 @@ +package echo + +import ( + testify "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// Note this test is deliberately simple as there's not a lot to test. +// Just need to ensure it writes JSONs. The heavy work is done by the context methods. +func TestDefaultJSONCodec_Encode(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodPost, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec).(*context) + + assert := testify.New(t) + + // Echo + assert.Equal(e, c.Echo()) + + // Request + assert.NotNil(c.Request()) + + // Response + assert.NotNil(c.Response()) + + //-------- + // Default JSON encoder + //-------- + + enc := new(DefaultJSONSerializer) + + err := enc.Serialize(c, user{1, "Jon Snow"}, "") + if assert.NoError(err) { + assert.Equal(userJSON+"\n", rec.Body.String()) + } + + req = httptest.NewRequest(http.MethodPost, "/", nil) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec).(*context) + err = enc.Serialize(c, user{1, "Jon Snow"}, " ") + if assert.NoError(err) { + assert.Equal(userJSONPretty+"\n", rec.Body.String()) + } +} + +// Note this test is deliberately simple as there's not a lot to test. +// Just need to ensure it writes JSONs. The heavy work is done by the context methods. +func TestDefaultJSONCodec_Decode(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec).(*context) + + assert := testify.New(t) + + // Echo + assert.Equal(e, c.Echo()) + + // Request + assert.NotNil(c.Request()) + + // Response + assert.NotNil(c.Response()) + + //-------- + // Default JSON encoder + //-------- + + enc := new(DefaultJSONSerializer) + + var u = user{} + err := enc.Deserialize(c, &u) + if assert.NoError(err) { + assert.Equal(u, user{ID: 1, Name: "Jon Snow"}) + } + + var userUnmarshalSyntaxError = user{} + req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(invalidContent)) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec).(*context) + err = enc.Deserialize(c, &userUnmarshalSyntaxError) + assert.IsType(&HTTPError{}, err) + assert.EqualError(err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value") + + var userUnmarshalTypeError = struct { + ID string `json:"id"` + Name string `json:"name"` + }{} + + req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec).(*context) + err = enc.Deserialize(c, &userUnmarshalTypeError) + assert.IsType(&HTTPError{}, err) + assert.EqualError(err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string") + +} From 02de901d7ef52ce2005c88fc946e606decb72a79 Mon Sep 17 00:00:00 2001 From: Pablo Andres Fuente Date: Fri, 9 Jul 2021 23:36:03 -0300 Subject: [PATCH 147/446] Fixing Timeout middleware Context propagation (#1910) This will let middlewares/handler later on the chain to properly handle the Timeout middleware Context cancellation. Fixes #1909 --- middleware/timeout.go | 7 +++++- middleware/timeout_test.go | 44 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/middleware/timeout.go b/middleware/timeout.go index fb8ae4219..731136541 100644 --- a/middleware/timeout.go +++ b/middleware/timeout.go @@ -2,9 +2,10 @@ package middleware import ( "context" - "github.com/labstack/echo/v4" "net/http" "time" + + "github.com/labstack/echo/v4" ) type ( @@ -87,6 +88,10 @@ type echoHandlerFuncWrapper struct { } func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + // replace echo.Context Request with the one provided by TimeoutHandler to let later middlewares/handler on the chain + // handle properly it's cancellation + t.ctx.SetRequest(r) + // replace writer with TimeoutHandler custom one. This will guarantee that // `writes by h to its ResponseWriter will return ErrHandlerTimeout.` originalWriter := t.ctx.Response().Writer diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index f9f1826be..80891e829 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -2,8 +2,6 @@ package middleware import ( "errors" - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" "net/url" @@ -11,6 +9,9 @@ import ( "strings" "testing" "time" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" ) func TestTimeoutSkipper(t *testing.T) { @@ -273,3 +274,42 @@ func TestTimeoutWithDefaultErrorMessage(t *testing.T) { assert.Equal(t, http.StatusServiceUnavailable, rec.Code) assert.Equal(t, `Timeout

Timeout

`, rec.Body.String()) } + +func TestTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) { + t.Parallel() + + timeout := 1 * time.Millisecond + m := TimeoutWithConfig(TimeoutConfig{ + Timeout: timeout, + ErrorMessage: "Timeout! change me", + }) + + handlerFinishedExecution := make(chan bool) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + stopChan := make(chan struct{}) + err := m(func(c echo.Context) error { + // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) + // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output + // difference over 500microseconds (0.5millisecond) response seems to be reliable + <-stopChan + + // The Request Context should have a Deadline set by http.TimeoutHandler + if _, ok := c.Request().Context().Deadline(); !ok { + assert.Fail(t, "No timeout set on Request Context") + } + handlerFinishedExecution <- c.Request().Context().Err() == nil + return c.String(http.StatusOK, "Hello, World!") + })(c) + stopChan <- struct{}{} + + assert.NoError(t, err) + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) + assert.Equal(t, "Timeout! change me", rec.Body.String()) + assert.False(t, <-handlerFinishedExecution) +} From 58366f93e60c740da21d86584ddbdad024d98a9e Mon Sep 17 00:00:00 2001 From: Martti T Date: Mon, 12 Jul 2021 22:35:47 +0300 Subject: [PATCH 148/446] Update version and changelog for 4.4.0 (#1919) --- CHANGELOG.md | 15 +++++++++++++++ echo.go | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f4a74760f..892d70957 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,20 @@ # Changelog +## v4.4.0 - 2021-07-12 + +**Fixes** + +* Split HeaderXForwardedFor header only by comma [#1878](https://github.com/labstack/echo/pull/1878) +* Fix Timeout middleware Context propagation [#1910](https://github.com/labstack/echo/pull/1910) + +**Enhancements** + +* Bind data using headers as source [#1866](https://github.com/labstack/echo/pull/1866) +* Adds JWTConfig.ParseTokenFunc to JWT middleware to allow different libraries implementing JWT parsing. [#1887](https://github.com/labstack/echo/pull/1887) +* Adding tests for Echo#Host [#1895](https://github.com/labstack/echo/pull/1895) +* Adds RequestIDHandler function to RequestID middleware [#1898](https://github.com/labstack/echo/pull/1898) +* Allow for custom JSON encoding implementations [#1880](https://github.com/labstack/echo/pull/1880) + ## v4.3.0 - 2021-05-08 **Important notes** diff --git a/echo.go b/echo.go index afb1e27dc..406e806bc 100644 --- a/echo.go +++ b/echo.go @@ -241,7 +241,7 @@ const ( const ( // Version of Echo - Version = "4.3.0" + Version = "4.4.0" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 647af2acec9ba160cb247e5c26eb0c671a0c1f2a Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 1 Aug 2021 11:12:23 +0300 Subject: [PATCH 149/446] JWT middleware has been changed from `github.com/dgrijalva/jwt-go` to github.com/golang-jwt/jwt` due former library being unmaintained and having security issues. NOTE: `golang-jwt/jwt` now only supports last 2 Go releases. So 1.15+ For detailed information please read https://github.com/labstack/echo/discussions/1940 --- go.mod | 2 +- go.sum | 4 ++-- middleware/jwt.go | 6 ++++-- middleware/jwt_test.go | 6 ++++-- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index 2510d10c6..9cd3529bd 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/labstack/echo/v4 go 1.15 require ( - github.com/dgrijalva/jwt-go v3.2.0+incompatible + github.com/golang-jwt/jwt v3.2.2+incompatible github.com/labstack/gommon v0.3.0 github.com/mattn/go-colorable v0.1.8 // indirect github.com/stretchr/testify v1.4.0 diff --git a/go.sum b/go.sum index d18f10fb6..027e96600 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= -github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0= github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= diff --git a/middleware/jwt.go b/middleware/jwt.go index bce478743..c2e7c06d4 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -1,3 +1,5 @@ +// +build go1.15 + package middleware import ( @@ -7,7 +9,7 @@ import ( "reflect" "strings" - "github.com/dgrijalva/jwt-go" + "github.com/golang-jwt/jwt" "github.com/labstack/echo/v4" ) @@ -88,7 +90,7 @@ type ( // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token // parsing fails or parsed token is invalid. - // Defaults to implementation using `github.com/dgrijalva/jwt-go` as JWT implementation library + // Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library ParseTokenFunc func(auth string, c echo.Context) (interface{}, error) } diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 9af4c83d8..393fd93d3 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -1,3 +1,5 @@ +// +build go1.15 + package middleware import ( @@ -9,7 +11,7 @@ import ( "strings" "testing" - "github.com/dgrijalva/jwt-go" + "github.com/golang-jwt/jwt" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) @@ -562,7 +564,7 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { return c.String(http.StatusTeapot, "test") }) - // example of minimal custom ParseTokenFunc implementation. Allows you to use different versions of `github.com/dgrijalva/jwt-go` + // example of minimal custom ParseTokenFunc implementation. Allows you to use different versions of `github.com/golang-jwt/jwt` // with current JWT middleware signingKey := []byte("secret") From 5b8fa6979f5803ff091e24c9a0331aa1b38de302 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 1 Aug 2021 11:14:47 +0300 Subject: [PATCH 150/446] Update version and changelog for 4.5.0 --- CHANGELOG.md | 23 +++++++++++++++++++++++ echo.go | 2 +- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 892d70957..02eb36fd4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,28 @@ # Changelog +## v4.5.0 - 2021-08-01 + +**Important notes** + +A **BREAKING CHANGE** is introduced for JWT middleware users. +The JWT library used for the JWT middleware had to be changed from [github.com/dgrijalva/jwt-go](https://github.com/dgrijalva/jwt-go) to +[github.com/golang-jwt/jwt](https://github.com/golang-jwt/jwt) due former library being unmaintained and affected by security +issues. +The [github.com/golang-jwt/jwt](https://github.com/golang-jwt/jwt) project is a drop-in replacement, but supports only the latest 2 Go versions. +So for JWT middleware users Go 1.15+ is required. For detailed information please read [#1940](https://github.com/labstack/echo/discussions/) + +To change the library imports in all .go files in your project replace all occurrences of `dgrijalva/jwt-go` with `golang-jwt/jwt`. + +For Linux CLI you can use: +```bash +find -type f -name "*.go" -exec sed -i "s/dgrijalva\/jwt-go/golang-jwt\/jwt/g" {} \; +go mod tidy +``` + +**Fixes** + +* Change JWT library to `github.com/golang-jwt/jwt` [#1946](https://github.com/labstack/echo/pull/1946) + ## v4.4.0 - 2021-07-12 **Fixes** diff --git a/echo.go b/echo.go index 406e806bc..246a62256 100644 --- a/echo.go +++ b/echo.go @@ -241,7 +241,7 @@ const ( const ( // Version of Echo - Version = "4.4.0" + Version = "4.5.0" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From fcda0e8840fcac871a60e3e2a2f314e52072e1f8 Mon Sep 17 00:00:00 2001 From: Kaushal Rohit Date: Tue, 20 Jul 2021 10:36:23 +0530 Subject: [PATCH 151/446] Add Cookie to KeyAuth middleware's KeyLookup --- middleware/key_auth.go | 15 +++++++++++++++ middleware/key_auth_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/middleware/key_auth.go b/middleware/key_auth.go index fd169aa2c..54f3b47f3 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -2,6 +2,7 @@ package middleware import ( "errors" + "fmt" "net/http" "strings" @@ -21,6 +22,7 @@ type ( // - "header:" // - "query:" // - "form:" + // - "cookie:" KeyLookup string `yaml:"key_lookup"` // AuthScheme to be used in the Authorization header. @@ -91,6 +93,8 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { extractor = keyFromQuery(parts[1]) case "form": extractor = keyFromForm(parts[1]) + case "cookie": + extractor = keyFromCookie(parts[1]) } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -164,3 +168,14 @@ func keyFromForm(param string) keyExtractor { return key, nil } } + +// keyFromCookie returns a `keyExtractor` that extracts key from the form. +func keyFromCookie(cookieName string) keyExtractor { + return func(c echo.Context) (string, error) { + key, err := c.Cookie(cookieName) + if err != nil { + return "", fmt.Errorf("missing key in cookies: %w", err) + } + return key.Value, nil + } +} diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go index 476b402d9..0cc513ab0 100644 --- a/middleware/key_auth_test.go +++ b/middleware/key_auth_test.go @@ -157,6 +157,30 @@ func TestKeyAuthWithConfig(t *testing.T) { expectHandlerCalled: false, expectError: "code=400, message=missing key in the form", }, + { + name: "ok, custom key lookup, cookie", + givenRequest: func(req *http.Request) { + req.AddCookie(&http.Cookie{ + Name: "key", + Value: "valid-key", + }) + q := req.URL.Query() + q.Add("key", "valid-key") + req.URL.RawQuery = q.Encode() + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "cookie:key" + }, + expectHandlerCalled: true, + }, + { + name: "nok, custom key lookup, missing cookie param", + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "cookie:key" + }, + expectHandlerCalled: false, + expectError: "code=400, message=missing key in cookies: http: named cookie not present", + }, { name: "nok, custom errorHandler, error from extractor", whenConfig: func(conf *KeyAuthConfig) { From 499097e061fe1757b9734d9892ec25de78f56e53 Mon Sep 17 00:00:00 2001 From: Philipp Thun Date: Mon, 9 Aug 2021 17:21:13 +0200 Subject: [PATCH 152/446] Ignore case of auth scheme in request header Some clients send an authorization header containing the "bearer" keyword in lower case. This led to echo responding with "missing or malformed jwt". Request.BasicAuth (net/http) ignores the basic auth scheme's case since a while: https://go-review.googlesource.com/c/go/+/111516/ --- middleware/jwt.go | 2 +- middleware/jwt_test.go | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/middleware/jwt.go b/middleware/jwt.go index c2e7c06d4..21e33ab82 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -295,7 +295,7 @@ func jwtFromHeader(header string, authScheme string) jwtExtractor { return func(c echo.Context) (string, error) { auth := c.Request().Header.Get(header) l := len(authScheme) - if len(auth) > l+1 && auth[:l] == authScheme { + if len(auth) > l+1 && strings.EqualFold(auth[:l], authScheme) { return auth[l+1:], nil } return "", ErrJWTMissing diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 393fd93d3..5f36ce0a5 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -261,6 +261,11 @@ func TestJWT(t *testing.T) { expErrCode: http.StatusUnauthorized, info: "Token verification does not pass using a user-defined KeyFunc", }, + { + hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token, + config: JWTConfig{SigningKey: validKey}, + info: "Valid JWT with lower case AuthScheme", + }, } { if tc.reqURL == "" { tc.reqURL = "/" From 7d41537e70ce8e3e25c2e8199c798c7cfe3d4299 Mon Sep 17 00:00:00 2001 From: Mohammad Alian Date: Thu, 12 Aug 2021 19:56:04 +0900 Subject: [PATCH 153/446] return first if response is already committed in DefaultHTTPErrorHandler --- echo.go | 21 ++++++++++++--------- echo_test.go | 17 +++++++++++++++++ 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/echo.go b/echo.go index 246a62256..a28fa0c1a 100644 --- a/echo.go +++ b/echo.go @@ -358,6 +358,11 @@ func (e *Echo) Routers() map[string]*Router { // DefaultHTTPErrorHandler is the default HTTP error handler. It sends a JSON response // with status code. func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { + + if c.Response().Committed { + return + } + he, ok := err.(*HTTPError) if ok { if he.Internal != nil { @@ -384,15 +389,13 @@ func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { } // Send response - if !c.Response().Committed { - if c.Request().Method == http.MethodHead { // Issue #608 - err = c.NoContent(he.Code) - } else { - err = c.JSON(code, message) - } - if err != nil { - e.Logger.Error(err) - } + if c.Request().Method == http.MethodHead { // Issue #608 + err = c.NoContent(he.Code) + } else { + err = c.JSON(code, message) + } + if err != nil { + e.Logger.Error(err) } } diff --git a/echo_test.go b/echo_test.go index dc553490b..f28915864 100644 --- a/echo_test.go +++ b/echo_test.go @@ -1124,6 +1124,15 @@ func TestDefaultHTTPErrorHandler(t *testing.T) { "error": "stackinfo", }) }) + e.Any("/early-return", func(c Context) error { + c.String(http.StatusOK, "OK") + return errors.New("ERROR") + }) + e.GET("/internal-error", func(c Context) error { + err := errors.New("internal error message body") + return NewHTTPError(http.StatusBadRequest).SetInternal(err) + }) + // With Debug=true plain response contains error message c, b := request(http.MethodGet, "/plain", e) assert.Equal(t, http.StatusInternalServerError, c) @@ -1136,6 +1145,14 @@ func TestDefaultHTTPErrorHandler(t *testing.T) { c, b = request(http.MethodGet, "/servererror", e) assert.Equal(t, http.StatusInternalServerError, c) assert.Equal(t, "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", b) + // if the body is already set HTTPErrorHandler should not add anything to response body + c, b = request(http.MethodGet, "/early-return", e) + assert.Equal(t, http.StatusOK, c) + assert.Equal(t, "OK", b) + // internal error should be reflected in the message + c, b = request(http.MethodGet, "/internal-error", e) + assert.Equal(t, http.StatusBadRequest, c) + assert.Equal(t, "{\n \"error\": \"code=400, message=Bad Request, internal=internal error message body\",\n \"message\": \"Bad Request\"\n}\n", b) e.Debug = false // With Debug=false the error response is shortened From 3dfe1a7b61b6caa1b331068cbf876c44fca54eeb Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Sun, 15 Aug 2021 09:42:15 -0700 Subject: [PATCH 154/446] Update README.md --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4dec531a2..364f98ac3 100644 --- a/README.md +++ b/README.md @@ -114,8 +114,11 @@ func hello(c echo.Context) error { ## Credits -- [Vishal Rana](https://github.com/vishr) - Author -- [Nitin Rana](https://github.com/nr17) - Consultant +- [Vishal Rana](https://github.com/vishr) (Author) +- [Nitin Rana](https://github.com/nr17) (Consultant) +- [Roland Lammel](https://github.com/lammel) (Maintainer) +- [Martti T.](https://github.com/aldas) (Maintainer) +- [Pablo Andres Fuente](https://github.com/pafuent) (Maintainer) - [Contributors](https://github.com/labstack/echo/graphs/contributors) ## License From 128cb7fd40f0b8d430f2d8b07a966a40237fde93 Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Sun, 15 Aug 2021 10:17:01 -0700 Subject: [PATCH 155/446] docs: add vishr as a contributor for design (#1958) * docs: update README.md [skip ci] * docs: create .all-contributorsrc [skip ci] Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> --- .all-contributorsrc | 24 ++++++++++++++++++++++++ README.md | 23 +++++++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 .all-contributorsrc diff --git a/.all-contributorsrc b/.all-contributorsrc new file mode 100644 index 000000000..4780e40f5 --- /dev/null +++ b/.all-contributorsrc @@ -0,0 +1,24 @@ +{ + "files": [ + "README.md" + ], + "imageSize": 100, + "commit": false, + "contributors": [ + { + "login": "vishr", + "name": "Vishal Rana", + "avatar_url": "https://avatars.githubusercontent.com/u/314036?v=4", + "profile": "http://vishr.com", + "contributions": [ + "design" + ] + } + ], + "contributorsPerLine": 7, + "projectName": "echo", + "projectOwner": "labstack", + "repoType": "github", + "repoHost": "https://github.com", + "skipCi": true +} diff --git a/README.md b/README.md index 364f98ac3..ed03b506e 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,7 @@ + +[![All Contributors](https://img.shields.io/badge/all_contributors-1-orange.svg?style=flat-square)](#contributors-) + [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge) [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4) @@ -124,3 +127,23 @@ func hello(c echo.Context) error { ## License [MIT](https://github.com/labstack/echo/blob/master/LICENSE) + +## Contributors ✨ + +Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)): + + + + + + + + +

Vishal Rana

🎨
+ + + + + + +This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome! \ No newline at end of file From 8b162675bc01f791203b37a1b8b5ba7b47044c51 Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Sun, 15 Aug 2021 10:17:44 -0700 Subject: [PATCH 156/446] Update README.md --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index ed03b506e..1174d84e9 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,6 @@ func hello(c echo.Context) error { - [Roland Lammel](https://github.com/lammel) (Maintainer) - [Martti T.](https://github.com/aldas) (Maintainer) - [Pablo Andres Fuente](https://github.com/pafuent) (Maintainer) -- [Contributors](https://github.com/labstack/echo/graphs/contributors) ## License @@ -146,4 +145,4 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d -This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome! \ No newline at end of file +This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome! From cbf9c6baaa04ec91cc26f76e1411783e928ace64 Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Sun, 15 Aug 2021 10:20:24 -0700 Subject: [PATCH 157/446] docs: add vishr as a contributor for maintenance (#1959) * docs: update README.md [skip ci] * docs: update .all-contributorsrc [skip ci] Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> --- .all-contributorsrc | 3 ++- README.md | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.all-contributorsrc b/.all-contributorsrc index 4780e40f5..604c95456 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -11,7 +11,8 @@ "avatar_url": "https://avatars.githubusercontent.com/u/314036?v=4", "profile": "http://vishr.com", "contributions": [ - "design" + "design", + "maintenance" ] } ], diff --git a/README.md b/README.md index 1174d84e9..d5a860bcd 100644 --- a/README.md +++ b/README.md @@ -136,7 +136,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d - +

Vishal Rana

🎨

Vishal Rana

🎨 🚧
From 59e5078e66cc7353e730c53c0a9648f2237d20ad Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Sun, 15 Aug 2021 10:21:52 -0700 Subject: [PATCH 158/446] docs: add aldas as a contributor for maintenance (#1960) * docs: update README.md [skip ci] * docs: update .all-contributorsrc [skip ci] Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> --- .all-contributorsrc | 9 +++++++++ README.md | 3 ++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/.all-contributorsrc b/.all-contributorsrc index 604c95456..b195823e2 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -14,6 +14,15 @@ "design", "maintenance" ] + }, + { + "login": "aldas", + "name": "Martti T.", + "avatar_url": "https://avatars.githubusercontent.com/u/2320301?v=4", + "profile": "https://github.com/aldas", + "contributions": [ + "maintenance" + ] } ], "contributorsPerLine": 7, diff --git a/README.md b/README.md index d5a860bcd..9f1ede77b 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -[![All Contributors](https://img.shields.io/badge/all_contributors-1-orange.svg?style=flat-square)](#contributors-) +[![All Contributors](https://img.shields.io/badge/all_contributors-2-orange.svg?style=flat-square)](#contributors-) [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge) @@ -137,6 +137,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d +

Vishal Rana

🎨 🚧

Martti T.

🚧
From 6b89450ce3a0bdfc4196e1d77ee79b58f29d5551 Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Sun, 15 Aug 2021 10:23:38 -0700 Subject: [PATCH 159/446] docs: add pafuent as a contributor for maintenance (#1961) * docs: update README.md [skip ci] * docs: update .all-contributorsrc [skip ci] Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> --- .all-contributorsrc | 9 +++++++++ README.md | 3 ++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/.all-contributorsrc b/.all-contributorsrc index b195823e2..771416424 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -23,6 +23,15 @@ "contributions": [ "maintenance" ] + }, + { + "login": "pafuent", + "name": "Pablo Andres Fuente", + "avatar_url": "https://avatars.githubusercontent.com/u/6979945?v=4", + "profile": "https://github.com/pafuent", + "contributions": [ + "maintenance" + ] } ], "contributorsPerLine": 7, diff --git a/README.md b/README.md index 9f1ede77b..2d731788d 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -[![All Contributors](https://img.shields.io/badge/all_contributors-2-orange.svg?style=flat-square)](#contributors-) +[![All Contributors](https://img.shields.io/badge/all_contributors-3-orange.svg?style=flat-square)](#contributors-) [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge) @@ -138,6 +138,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
Vishal Rana

🎨 🚧
Martti T.

🚧 +
Pablo Andres Fuente

🚧 From 560fca0d499a6aec85c8b41fa04df749958315b7 Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Sun, 15 Aug 2021 10:26:37 -0700 Subject: [PATCH 160/446] Update .all-contributorsrc --- .all-contributorsrc | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/.all-contributorsrc b/.all-contributorsrc index 771416424..a1eaa0976 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -11,7 +11,15 @@ "avatar_url": "https://avatars.githubusercontent.com/u/314036?v=4", "profile": "http://vishr.com", "contributions": [ - "design", + "design" + ] + }, + { + "login": "lammel", + "name": "Roland Lammel", + "avatar_url": "https://avatars.githubusercontent.com/u/43678?v=4", + "profile": "https://github.com/lammel", + "contributions": [ "maintenance" ] }, From ac7c1346e87d686148a405c7635a13c380f0891c Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Sun, 15 Aug 2021 10:37:48 -0700 Subject: [PATCH 161/446] docs: add aldas as a contributor for review (#1962) * docs: update README.md [skip ci] * docs: update .all-contributorsrc [skip ci] * docs: update README.md [skip ci] * docs: update .all-contributorsrc [skip ci] Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> Co-authored-by: Vishal Rana --- .all-contributorsrc | 10 ++++++++++ README.md | 24 +----------------------- 2 files changed, 11 insertions(+), 23 deletions(-) diff --git a/.all-contributorsrc b/.all-contributorsrc index a1eaa0976..14fe31b84 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -40,6 +40,16 @@ "contributions": [ "maintenance" ] + }, + { + "login": "aldas", + "name": "Martti T.", + "avatar_url": "https://avatars.githubusercontent.com/u/2320301?v=4", + "profile": "https://github.com/aldas", + "contributions": [ + "maintenance", + "review" + ] } ], "contributorsPerLine": 7, diff --git a/README.md b/README.md index 2d731788d..611869d01 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -[![All Contributors](https://img.shields.io/badge/all_contributors-3-orange.svg?style=flat-square)](#contributors-) +[![All Contributors](https://img.shields.io/badge/all_contributors-2-orange.svg?style=flat-square)](#contributors-) [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge) @@ -126,25 +126,3 @@ func hello(c echo.Context) error { ## License [MIT](https://github.com/labstack/echo/blob/master/LICENSE) - -## Contributors ✨ - -Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)): - - - - - - - - - - -

Vishal Rana

🎨 🚧

Martti T.

🚧

Pablo Andres Fuente

🚧
- - - - - - -This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome! From eaba4c3d398019ee70a370f2217698c8da19ccaf Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Sun, 15 Aug 2021 10:38:48 -0700 Subject: [PATCH 162/446] Delete .all-contributorsrc --- .all-contributorsrc | 61 --------------------------------------------- 1 file changed, 61 deletions(-) delete mode 100644 .all-contributorsrc diff --git a/.all-contributorsrc b/.all-contributorsrc deleted file mode 100644 index 14fe31b84..000000000 --- a/.all-contributorsrc +++ /dev/null @@ -1,61 +0,0 @@ -{ - "files": [ - "README.md" - ], - "imageSize": 100, - "commit": false, - "contributors": [ - { - "login": "vishr", - "name": "Vishal Rana", - "avatar_url": "https://avatars.githubusercontent.com/u/314036?v=4", - "profile": "http://vishr.com", - "contributions": [ - "design" - ] - }, - { - "login": "lammel", - "name": "Roland Lammel", - "avatar_url": "https://avatars.githubusercontent.com/u/43678?v=4", - "profile": "https://github.com/lammel", - "contributions": [ - "maintenance" - ] - }, - { - "login": "aldas", - "name": "Martti T.", - "avatar_url": "https://avatars.githubusercontent.com/u/2320301?v=4", - "profile": "https://github.com/aldas", - "contributions": [ - "maintenance" - ] - }, - { - "login": "pafuent", - "name": "Pablo Andres Fuente", - "avatar_url": "https://avatars.githubusercontent.com/u/6979945?v=4", - "profile": "https://github.com/pafuent", - "contributions": [ - "maintenance" - ] - }, - { - "login": "aldas", - "name": "Martti T.", - "avatar_url": "https://avatars.githubusercontent.com/u/2320301?v=4", - "profile": "https://github.com/aldas", - "contributions": [ - "maintenance", - "review" - ] - } - ], - "contributorsPerLine": 7, - "projectName": "echo", - "projectOwner": "labstack", - "repoType": "github", - "repoHost": "https://github.com", - "skipCi": true -} From d793521d1c5041105d0516642667731f445de33e Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Sun, 15 Aug 2021 10:44:14 -0700 Subject: [PATCH 163/446] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 611869d01..d9ea709e4 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,7 @@ func hello(c echo.Context) error { - [Roland Lammel](https://github.com/lammel) (Maintainer) - [Martti T.](https://github.com/aldas) (Maintainer) - [Pablo Andres Fuente](https://github.com/pafuent) (Maintainer) +- [Contributors](https://github.com/labstack/echo/graphs/contributors) ## License From 7f502b1ff10913aeab28abab64f2bf45952c768d Mon Sep 17 00:00:00 2001 From: pwli Date: Mon, 23 Aug 2021 01:25:09 +0800 Subject: [PATCH 164/446] try to fix #1905 and add some notes (#1947) * fix 1905 and add some notes (cherry picked from commit 9d96199e2dbb6d4374b5a8b6e16fdc0b0d7cb3a7) * fix typo (cherry picked from commit e8ea1bcabb6cdb50b06e1ec0e7c3cce44287d8b7) * Add tests for timeout middleware with full http.Server stack running. Add warning about middleware * Fix example Co-authored-by: lipengwei Co-authored-by: toimtoimtoim --- middleware/timeout.go | 60 ++++++++++++++++- middleware/timeout_test.go | 128 +++++++++++++++++++++++++++++++++++++ 2 files changed, 185 insertions(+), 3 deletions(-) diff --git a/middleware/timeout.go b/middleware/timeout.go index 731136541..768ef8d70 100644 --- a/middleware/timeout.go +++ b/middleware/timeout.go @@ -8,6 +8,53 @@ import ( "github.com/labstack/echo/v4" ) +// --------------------------------------------------------------------------------------------------------------- +// WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING +// WARNING: Timeout middleware causes more problems than it solves. +// WARNING: This middleware should be first middleware as it messes with request Writer and could cause data race if +// it is in other position +// +// Depending on out requirements you could be better of setting timeout to context and +// check its deadline from handler. +// +// For example: create middleware to set timeout to context +// func RequestTimeout(timeout time.Duration) echo.MiddlewareFunc { +// return func(next echo.HandlerFunc) echo.HandlerFunc { +// return func(c echo.Context) error { +// timeoutCtx, cancel := context.WithTimeout(c.Request().Context(), timeout) +// c.SetRequest(c.Request().WithContext(timeoutCtx)) +// defer cancel() +// return next(c) +// } +// } +//} +// +// Create handler that checks for context deadline and runs actual task in separate coroutine +// Note: separate coroutine may not be even if you do not want to process continue executing and +// just want to stop long-running handler to stop and you are using "context aware" methods (ala db queries with ctx) +// e.GET("/", func(c echo.Context) error { +// +// doneCh := make(chan error) +// go func(ctx context.Context) { +// doneCh <- myPossiblyLongRunningBackgroundTaskWithCtx(ctx) +// }(c.Request().Context()) +// +// select { // wait for task to finish or context to timeout/cancelled +// case err := <-doneCh: +// if err != nil { +// return err +// } +// return c.String(http.StatusOK, "OK") +// case <-c.Request().Context().Done(): +// if c.Request().Context().Err() == context.DeadlineExceeded { +// return c.String(http.StatusServiceUnavailable, "timeout") +// } +// return c.Request().Context().Err() +// } +// +// }) +// + type ( // TimeoutConfig defines the config for Timeout middleware. TimeoutConfig struct { @@ -116,13 +163,20 @@ func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Reques // we restore original writer only for cases we did not timeout. On timeout we have already sent response to client // and should not anymore send additional headers/data // so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body - t.ctx.Response().Writer = originalWriter if err != nil { - // call global error handler to write error to the client. This is needed or `http.TimeoutHandler` will send status code by itself - // and after that our tries to write status code will not work anymore + // Error must be written into Writer created in `http.TimeoutHandler` so to get Response into `commited` state. + // So call global error handler to write error to the client. This is needed or `http.TimeoutHandler` will send + // status code by itself and after that our tries to write status code will not work anymore and/or create errors in + // log about `superfluous response.WriteHeader call from` t.ctx.Error(err) // we pass error from handler to middlewares up in handler chain to act on it if needed. But this means that // global error handler is probably be called twice as `t.ctx.Error` already does that. + + // NB: later call of the global error handler or middlewares will not take any effect, as echo.Response will be + // already marked as `committed` because we called global error handler above. + t.ctx.Response().Writer = originalWriter // make sure we restore before we signal original coroutine about the error t.errChan <- err + return } + t.ctx.Response().Writer = originalWriter } diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index 80891e829..aa6402b8d 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -1,7 +1,12 @@ package middleware import ( + "bytes" "errors" + "fmt" + "io/ioutil" + "log" + "net" "net/http" "net/http/httptest" "net/url" @@ -313,3 +318,126 @@ func TestTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) { assert.Equal(t, "Timeout! change me", rec.Body.String()) assert.False(t, <-handlerFinishedExecution) } + +func TestTimeoutWithFullEchoStack(t *testing.T) { + // test timeout with full http server stack running, do see what http.Server.ErrorLog contains + var testCases = []struct { + name string + whenPath string + expectStatusCode int + expectResponse string + expectLogContains []string + expectLogNotContains []string + }{ + { + name: "404 - write response in global error handler", + whenPath: "/404", + expectResponse: "{\"message\":\"Not Found\"}\n", + expectStatusCode: http.StatusNotFound, + expectLogNotContains: []string{"echo:http: superfluous response.WriteHeader call from"}, + expectLogContains: []string{`"status":404,"error":"code=404, message=Not Found"`}, + }, + { + name: "418 - write response in handler", + whenPath: "/", + expectResponse: "{\"message\":\"OK\"}\n", + expectStatusCode: http.StatusTeapot, + expectLogNotContains: []string{"echo:http: superfluous response.WriteHeader call from"}, + expectLogContains: []string{`"status":418,"error":"",`}, + }, + { + name: "503 - handler timeouts, write response in timeout middleware", + whenPath: "/?delay=50ms", + expectResponse: "Timeout

Timeout

", + expectStatusCode: http.StatusServiceUnavailable, + expectLogNotContains: []string{ + "echo:http: superfluous response.WriteHeader call from", + "{", // means that logger was not called. + }, + }, + } + + e := echo.New() + + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) + + // NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first + // FIXME: I have no idea how to fix this without adding mutexes. + e.Use(TimeoutWithConfig(TimeoutConfig{ + Timeout: 15 * time.Millisecond, + })) + e.Use(Logger()) + e.Use(Recover()) + + e.GET("/", func(c echo.Context) error { + var delay time.Duration + if err := echo.QueryParamsBinder(c).Duration("delay", &delay).BindError(); err != nil { + return err + } + if delay > 0 { + time.Sleep(delay) + } + return c.JSON(http.StatusTeapot, map[string]string{"message": "OK"}) + }) + + server, addr, err := startServer(e) + if err != nil { + assert.NoError(t, err) + return + } + defer server.Close() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf.Reset() // this is design this can not be run in parallel + + res, err := http.Get(fmt.Sprintf("http://%v%v", addr, tc.whenPath)) + if err != nil { + assert.NoError(t, err) + return + } + + assert.Equal(t, tc.expectStatusCode, res.StatusCode) + if body, err := ioutil.ReadAll(res.Body); err == nil { + assert.Equal(t, tc.expectResponse, string(body)) + } else { + assert.Fail(t, err.Error()) + } + + logged := buf.String() + for _, subStr := range tc.expectLogContains { + assert.True(t, strings.Contains(logged, subStr)) + } + for _, subStr := range tc.expectLogNotContains { + assert.False(t, strings.Contains(logged, subStr)) + } + }) + } +} + +func startServer(e *echo.Echo) (*http.Server, string, error) { + l, err := net.Listen("tcp", ":0") + if err != nil { + return nil, "", err + } + + s := http.Server{ + Handler: e, + ErrorLog: log.New(e.Logger.Output(), "echo:", 0), + } + + errCh := make(chan error) + go func() { + if err := s.Serve(l); err != http.ErrServerClosed { + errCh <- err + } + }() + + select { + case <-time.After(10 * time.Millisecond): + return &s, l.Addr().String(), nil + case err := <-errCh: + return nil, "", err + } +} From 1e7e67cddb1e5f34bbcb89906b0ef4f8f7e755cc Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 14 Sep 2021 20:57:47 +0300 Subject: [PATCH 165/446] Added request logger middleware which helps to use custom logger library for logging requests (#1980) Added request logger middleware which helps to use custom logger library for logging requests. --- echo.go | 5 + go.mod | 9 +- go.sum | 22 +- middleware/logger_test.go | 73 ++++++ middleware/request_logger.go | 310 ++++++++++++++++++++++ middleware/request_logger_test.go | 417 ++++++++++++++++++++++++++++++ 6 files changed, 823 insertions(+), 13 deletions(-) create mode 100644 middleware/request_logger.go create mode 100644 middleware/request_logger_test.go diff --git a/echo.go b/echo.go index a28fa0c1a..3292aa1a9 100644 --- a/echo.go +++ b/echo.go @@ -357,6 +357,11 @@ func (e *Echo) Routers() map[string]*Router { // DefaultHTTPErrorHandler is the default HTTP error handler. It sends a JSON response // with status code. +// +// NOTE: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). +// When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from +// handler. Then the error that global error handler received will be ignored because we have already "commited" the +// response and status code header has been sent to the client. func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { if c.Response().Committed { diff --git a/go.mod b/go.mod index 9cd3529bd..d2c884abd 100644 --- a/go.mod +++ b/go.mod @@ -6,11 +6,12 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/labstack/gommon v0.3.0 github.com/mattn/go-colorable v0.1.8 // indirect + github.com/mattn/go-isatty v0.0.14 // indirect github.com/stretchr/testify v1.4.0 github.com/valyala/fasttemplate v1.2.1 - golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 - golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 - golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57 // indirect - golang.org/x/text v0.3.6 // indirect + golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 + golang.org/x/net v0.0.0-20210907225631-ff17edfbf26d + golang.org/x/sys v0.0.0-20210908160347-a851e7ddeee0 // indirect + golang.org/x/text v0.3.7 // indirect golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 ) diff --git a/go.sum b/go.sum index 027e96600..92dbd77db 100644 --- a/go.sum +++ b/go.sum @@ -9,8 +9,9 @@ github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= -github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -21,23 +22,26 @@ github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyC github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w= -golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= +golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0= -golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20210907225631-ff17edfbf26d h1:kuk8nKPQ25KCDODLCDXt99tnTVeOyOM8HGvtJ0NzAvw= +golang.org/x/net v0.0.0-20210907225631-ff17edfbf26d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57 h1:F5Gozwx4I1xtr/sr/8CFbb57iKi3297KFs0QDbGN60A= -golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210908160347-a851e7ddeee0 h1:6xxeVXiyYpF8WCTnKKCbjnEdsrwjZYY8TOuk7xP0chg= +golang.org/x/sys v0.0.0-20210908160347-a851e7ddeee0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/middleware/logger_test.go b/middleware/logger_test.go index 4d4515b19..394f62712 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -171,3 +171,76 @@ func TestLoggerCustomTimestamp(t *testing.T) { _, err := time.Parse(customTimeFormat, loggedTime) assert.Error(t, err) } + +func BenchmarkLoggerWithConfig_withoutMapFields(b *testing.B) { + e := echo.New() + + buf := new(bytes.Buffer) + mw := LoggerWithConfig(LoggerConfig{ + Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` + + `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + + `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` + + `"bytes_out":${bytes_out}, "protocol":"${protocol}"}` + "\n", + Output: buf, + })(func(c echo.Context) error { + c.Request().Header.Set(echo.HeaderXRequestID, "123") + c.FormValue("to force parse form") + return c.String(http.StatusTeapot, "OK") + }) + + f := make(url.Values) + f.Set("csrf", "token") + f.Add("multiple", "1") + f.Add("multiple", "2") + req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", strings.NewReader(f.Encode())) + req.Header.Set("Referer", "https://echo.labstack.com/") + req.Header.Set("User-Agent", "curl/7.68.0") + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + mw(c) + buf.Reset() + } +} + +func BenchmarkLoggerWithConfig_withMapFields(b *testing.B) { + e := echo.New() + + buf := new(bytes.Buffer) + mw := LoggerWithConfig(LoggerConfig{ + Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` + + `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + + `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` + + `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}", "protocol":"${protocol}"` + + `"us":"${query:username}", "cf":"${form:csrf}", "Referer2":"${header:Referer}"}` + "\n", + Output: buf, + })(func(c echo.Context) error { + c.Request().Header.Set(echo.HeaderXRequestID, "123") + c.FormValue("to force parse form") + return c.String(http.StatusTeapot, "OK") + }) + + f := make(url.Values) + f.Set("csrf", "token") + f.Add("multiple", "1") + f.Add("multiple", "2") + req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", strings.NewReader(f.Encode())) + req.Header.Set("Referer", "https://echo.labstack.com/") + req.Header.Set("User-Agent", "curl/7.68.0") + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + mw(c) + buf.Reset() + } +} diff --git a/middleware/request_logger.go b/middleware/request_logger.go new file mode 100644 index 000000000..7829a1fd1 --- /dev/null +++ b/middleware/request_logger.go @@ -0,0 +1,310 @@ +package middleware + +import ( + "errors" + "github.com/labstack/echo/v4" + "net/http" + "time" +) + +// Example for `fmt.Printf` +// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ +// LogStatus: true, +// LogURI: true, +// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { +// fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status) +// return nil +// }, +// })) +// +// Example for Zerolog (https://github.com/rs/zerolog) +// logger := zerolog.New(os.Stdout) +// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ +// LogURI: true, +// LogStatus: true, +// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { +// logger.Info(). +// Str("URI", v.URI). +// Int("status", v.Status). +// Msg("request") +// +// return nil +// }, +// })) +// +// Example for Zap (https://github.com/uber-go/zap) +// logger, _ := zap.NewProduction() +// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ +// LogURI: true, +// LogStatus: true, +// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { +// logger.Info("request", +// zap.String("URI", v.URI), +// zap.Int("status", v.Status), +// ) +// +// return nil +// }, +// })) +// +// Example for Logrus (https://github.com/sirupsen/logrus) +// log := logrus.New() +// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ +// LogURI: true, +// LogStatus: true, +// LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error { +// log.WithFields(logrus.Fields{ +// "URI": values.URI, +// "status": values.Status, +// }).Info("request") +// +// return nil +// }, +// })) + +// RequestLoggerConfig is configuration for Request Logger middleware. +type RequestLoggerConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // BeforeNextFunc defines a function that is called before next middleware or handler is called in chain. + BeforeNextFunc func(c echo.Context) + // LogValuesFunc defines a function that is called with values extracted by logger from request/response. + // Mandatory. + LogValuesFunc func(c echo.Context, v RequestLoggerValues) error + + // LogLatency instructs logger to record duration it took to execute rest of the handler chain (next(c) call). + LogLatency bool + // LogProtocol instructs logger to extract request protocol (i.e. `HTTP/1.1` or `HTTP/2`) + LogProtocol bool + // LogRemoteIP instructs logger to extract request remote IP. See `echo.Context.RealIP()` for implementation details. + LogRemoteIP bool + // LogHost instructs logger to extract request host value (i.e. `example.com`) + LogHost bool + // LogMethod instructs logger to extract request method value (i.e. `GET` etc) + LogMethod bool + // LogURI instructs logger to extract request URI (i.e. `/list?lang=en&page=1`) + LogURI bool + // LogURIPath instructs logger to extract request URI path part (i.e. `/list`) + LogURIPath bool + // LogRoutePath instructs logger to extract route path part to which request was matched to (i.e. `/user/:id`) + LogRoutePath bool + // LogRequestID instructs logger to extract request ID from request `X-Request-ID` header or response if request did not have value. + LogRequestID bool + // LogReferer instructs logger to extract request referer values. + LogReferer bool + // LogUserAgent instructs logger to extract request user agent values. + LogUserAgent bool + // LogStatus instructs logger to extract response status code. If handler chain returns an echo.HTTPError, + // the status code is extracted from the echo.HTTPError returned + LogStatus bool + // LogError instructs logger to extract error returned from executed handler chain. + LogError bool + // LogContentLength instructs logger to extract content length header value. Note: this value could be different from + // actual request body size as it could be spoofed etc. + LogContentLength bool + // LogResponseSize instructs logger to extract response content length value. Note: when used with Gzip middleware + // this value may not be always correct. + LogResponseSize bool + // LogHeaders instructs logger to extract given list of headers from request. Note: request can contain more than + // one header with same value so slice of values is been logger for each given header. + // + // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header + // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding". + LogHeaders []string + // LogQueryParams instructs logger to extract given list of query parameters from request URI. Note: request can + // contain more than one query parameter with same name so slice of values is been logger for each given query param name. + LogQueryParams []string + // LogFormValues instructs logger to extract given list of form values from request body+URI. Note: request can + // contain more than one form value with same name so slice of values is been logger for each given form value name. + LogFormValues []string + + timeNow func() time.Time +} + +// RequestLoggerValues contains extracted values from logger. +type RequestLoggerValues struct { + // Latency is duration it took to execute rest of the handler chain (next(c) call). + Latency time.Duration + // Protocol is request protocol (i.e. `HTTP/1.1` or `HTTP/2`) + Protocol string + // RemoteIP is request remote IP. See `echo.Context.RealIP()` for implementation details. + RemoteIP string + // Host is request host value (i.e. `example.com`) + Host string + // Method is request method value (i.e. `GET` etc) + Method string + // URI is request URI (i.e. `/list?lang=en&page=1`) + URI string + // URIPath is request URI path part (i.e. `/list`) + URIPath string + // RoutePath is route path part to which request was matched to (i.e. `/user/:id`) + RoutePath string + // RequestID is request ID from request `X-Request-ID` header or response if request did not have value. + RequestID string + // Referer is request referer values. + Referer string + // UserAgent is request user agent values. + UserAgent string + // Status is response status code. Then handler returns an echo.HTTPError then code from there. + Status int + // Error is error returned from executed handler chain. + Error error + // ContentLength is content length header value. Note: this value could be different from actual request body size + // as it could be spoofed etc. + ContentLength string + // ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct. + ResponseSize int64 + // Headers are list of headers from request. Note: request can contain more than one header with same value so slice + // of values is been logger for each given header. + // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header + // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding". + Headers map[string][]string + // QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter + // with same name so slice of values is been logger for each given query param name. + QueryParams map[string][]string + // FormValues are list of form values from request body+URI. Note: request can contain more than one form value with + // same name so slice of values is been logger for each given form value name. + FormValues map[string][]string +} + +// RequestLoggerWithConfig returns a RequestLogger middleware with config. +func RequestLoggerWithConfig(config RequestLoggerConfig) echo.MiddlewareFunc { + mw, err := config.ToMiddleware() + if err != nil { + panic(err) + } + return mw +} + +// ToMiddleware converts RequestLoggerConfig into middleware or returns an error for invalid configuration. +func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { + if config.Skipper == nil { + config.Skipper = DefaultSkipper + } + now = time.Now + if config.timeNow != nil { + now = config.timeNow + } + + if config.LogValuesFunc == nil { + return nil, errors.New("missing LogValuesFunc callback function for request logger middleware") + } + + logHeaders := len(config.LogHeaders) > 0 + headers := append([]string(nil), config.LogHeaders...) + for i, v := range headers { + headers[i] = http.CanonicalHeaderKey(v) + } + + logQueryParams := len(config.LogQueryParams) > 0 + logFormValues := len(config.LogFormValues) > 0 + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + + req := c.Request() + res := c.Response() + start := now() + + if config.BeforeNextFunc != nil { + config.BeforeNextFunc(c) + } + err := next(c) + + v := RequestLoggerValues{} + if config.LogLatency { + v.Latency = now().Sub(start) + } + if config.LogProtocol { + v.Protocol = req.Proto + } + if config.LogRemoteIP { + v.RemoteIP = c.RealIP() + } + if config.LogHost { + v.Host = req.Host + } + if config.LogMethod { + v.Method = req.Method + } + if config.LogURI { + v.URI = req.RequestURI + } + if config.LogURIPath { + p := req.URL.Path + if p == "" { + p = "/" + } + v.URIPath = p + } + if config.LogRoutePath { + v.RoutePath = c.Path() + } + if config.LogRequestID { + id := req.Header.Get(echo.HeaderXRequestID) + if id == "" { + id = res.Header().Get(echo.HeaderXRequestID) + } + v.RequestID = id + } + if config.LogReferer { + v.Referer = req.Referer() + } + if config.LogUserAgent { + v.UserAgent = req.UserAgent() + } + if config.LogStatus { + v.Status = res.Status + if err != nil { + if httpErr, ok := err.(*echo.HTTPError); ok { + v.Status = httpErr.Code + } + } + } + if config.LogError && err != nil { + v.Error = err + } + if config.LogContentLength { + v.ContentLength = req.Header.Get(echo.HeaderContentLength) + } + if config.LogResponseSize { + v.ResponseSize = res.Size + } + if logHeaders { + v.Headers = map[string][]string{} + for _, header := range headers { + if values, ok := req.Header[header]; ok { + v.Headers[header] = values + } + } + } + if logQueryParams { + queryParams := c.QueryParams() + v.QueryParams = map[string][]string{} + for _, param := range config.LogQueryParams { + if values, ok := queryParams[param]; ok { + v.QueryParams[param] = values + } + } + } + if logFormValues { + v.FormValues = map[string][]string{} + for _, formValue := range config.LogFormValues { + if values, ok := req.Form[formValue]; ok { + v.FormValues[formValue] = values + } + } + } + + if errOnLog := config.LogValuesFunc(c, v); errOnLog != nil { + return errOnLog + } + + return err + } + }, nil +} diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go new file mode 100644 index 000000000..d5d9be08b --- /dev/null +++ b/middleware/request_logger_test.go @@ -0,0 +1,417 @@ +package middleware + +import ( + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" + "testing" + "time" +) + +func TestRequestLoggerWithConfig(t *testing.T) { + e := echo.New() + + var expect RequestLoggerValues + e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + LogRoutePath: true, + LogURI: true, + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + expect = values + return nil + }, + })) + + e.GET("/test", func(c echo.Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.Equal(t, "/test", expect.RoutePath) +} + +func TestRequestLoggerWithConfig_missingOnLogValuesPanics(t *testing.T) { + assert.Panics(t, func() { + RequestLoggerWithConfig(RequestLoggerConfig{ + LogValuesFunc: nil, + }) + }) +} + +func TestRequestLogger_skipper(t *testing.T) { + e := echo.New() + + loggerCalled := false + e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + loggerCalled = true + return nil + }, + })) + + e.GET("/test", func(c echo.Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.False(t, loggerCalled) +} + +func TestRequestLogger_beforeNextFunc(t *testing.T) { + e := echo.New() + + var myLoggerInstance int + e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + BeforeNextFunc: func(c echo.Context) { + c.Set("myLoggerInstance", 42) + }, + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + myLoggerInstance = c.Get("myLoggerInstance").(int) + return nil + }, + })) + + e.GET("/test", func(c echo.Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.Equal(t, 42, myLoggerInstance) +} + +func TestRequestLogger_logError(t *testing.T) { + e := echo.New() + + var expect RequestLoggerValues + e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + LogError: true, + LogStatus: true, + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + expect = values + return nil + }, + })) + + e.GET("/test", func(c echo.Context) error { + return echo.NewHTTPError(http.StatusNotAcceptable, "nope") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusNotAcceptable, rec.Code) + assert.Equal(t, http.StatusNotAcceptable, expect.Status) + assert.EqualError(t, expect.Error, "code=406, message=nope") +} + +func TestRequestLogger_LogValuesFuncError(t *testing.T) { + e := echo.New() + + var expect RequestLoggerValues + e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + LogError: true, + LogStatus: true, + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + expect = values + return echo.NewHTTPError(http.StatusNotAcceptable, "LogValuesFuncError") + }, + })) + + e.GET("/test", func(c echo.Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + // NOTE: when global error handler received error returned from middleware the status has already + // been written to the client and response has been "commited" therefore global error handler does not do anything + // and error that bubbled up in middleware chain will not be reflected in response code. + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.Equal(t, http.StatusTeapot, expect.Status) +} + +func TestRequestLogger_ID(t *testing.T) { + var testCases = []struct { + name string + whenFromRequest bool + expect string + }{ + { + name: "ok, ID is provided from request headers", + whenFromRequest: true, + expect: "123", + }, + { + name: "ok, ID is from response headers", + whenFromRequest: false, + expect: "321", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + var expect RequestLoggerValues + e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + LogRequestID: true, + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + expect = values + return nil + }, + })) + + e.GET("/test", func(c echo.Context) error { + c.Response().Header().Set(echo.HeaderXRequestID, "321") + return c.String(http.StatusTeapot, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + if tc.whenFromRequest { + req.Header.Set(echo.HeaderXRequestID, "123") + } + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.Equal(t, tc.expect, expect.RequestID) + }) + } +} + +func TestRequestLogger_headerIsCaseInsensitive(t *testing.T) { + e := echo.New() + + var expect RequestLoggerValues + mw := RequestLoggerWithConfig(RequestLoggerConfig{ + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + expect = values + return nil + }, + LogHeaders: []string{"referer", "User-Agent"}, + })(func(c echo.Context) error { + c.Request().Header.Set(echo.HeaderXRequestID, "123") + c.FormValue("to force parse form") + return c.String(http.StatusTeapot, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/test?lang=en&checked=1&checked=2", nil) + req.Header.Set("referer", "https://echo.labstack.com/") + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := mw(c) + + assert.NoError(t, err) + assert.Len(t, expect.Headers, 1) + assert.Equal(t, []string{"https://echo.labstack.com/"}, expect.Headers["Referer"]) +} + +func TestRequestLogger_allFields(t *testing.T) { + e := echo.New() + + isFirstNowCall := true + var expect RequestLoggerValues + mw := RequestLoggerWithConfig(RequestLoggerConfig{ + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + expect = values + return nil + }, + LogLatency: true, + LogProtocol: true, + LogRemoteIP: true, + LogHost: true, + LogMethod: true, + LogURI: true, + LogURIPath: true, + LogRoutePath: true, + LogRequestID: true, + LogReferer: true, + LogUserAgent: true, + LogStatus: true, + LogError: true, + LogContentLength: true, + LogResponseSize: true, + LogHeaders: []string{"accept-encoding", "User-Agent"}, + LogQueryParams: []string{"lang", "checked"}, + LogFormValues: []string{"csrf", "multiple"}, + timeNow: func() time.Time { + if isFirstNowCall { + isFirstNowCall = false + return time.Unix(1631045377, 0) + } + return time.Unix(1631045377+10, 0) + }, + })(func(c echo.Context) error { + c.Request().Header.Set(echo.HeaderXRequestID, "123") + c.FormValue("to force parse form") + return c.String(http.StatusTeapot, "OK") + }) + + f := make(url.Values) + f.Set("csrf", "token") + f.Set("multiple", "1") + f.Add("multiple", "2") + reader := strings.NewReader(f.Encode()) + req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", reader) + req.Header.Set("Referer", "https://echo.labstack.com/") + req.Header.Set("User-Agent", "curl/7.68.0") + req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size()))) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) + req.Header.Set(echo.HeaderXRealIP, "8.8.8.8") + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + c.SetPath("/test*") + + err := mw(c) + + assert.NoError(t, err) + assert.Equal(t, 10*time.Second, expect.Latency) + assert.Equal(t, "HTTP/1.1", expect.Protocol) + assert.Equal(t, "8.8.8.8", expect.RemoteIP) + assert.Equal(t, "example.com", expect.Host) + assert.Equal(t, http.MethodPost, expect.Method) + assert.Equal(t, "/test?lang=en&checked=1&checked=2", expect.URI) + assert.Equal(t, "/test", expect.URIPath) + assert.Equal(t, "/test*", expect.RoutePath) + assert.Equal(t, "123", expect.RequestID) + assert.Equal(t, "https://echo.labstack.com/", expect.Referer) + assert.Equal(t, "curl/7.68.0", expect.UserAgent) + assert.Equal(t, 418, expect.Status) + assert.Equal(t, nil, expect.Error) + assert.Equal(t, "32", expect.ContentLength) + assert.Equal(t, int64(2), expect.ResponseSize) + + assert.Len(t, expect.Headers, 1) + assert.Equal(t, []string{"curl/7.68.0"}, expect.Headers["User-Agent"]) + + assert.Len(t, expect.QueryParams, 2) + assert.Equal(t, []string{"en"}, expect.QueryParams["lang"]) + assert.Equal(t, []string{"1", "2"}, expect.QueryParams["checked"]) + + assert.Len(t, expect.FormValues, 2) + assert.Equal(t, []string{"token"}, expect.FormValues["csrf"]) + assert.Equal(t, []string{"1", "2"}, expect.FormValues["multiple"]) +} + +func BenchmarkRequestLogger_withoutMapFields(b *testing.B) { + e := echo.New() + + mw := RequestLoggerWithConfig(RequestLoggerConfig{ + Skipper: nil, + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + return nil + }, + LogLatency: true, + LogProtocol: true, + LogRemoteIP: true, + LogHost: true, + LogMethod: true, + LogURI: true, + LogURIPath: true, + LogRoutePath: true, + LogRequestID: true, + LogReferer: true, + LogUserAgent: true, + LogStatus: true, + LogError: true, + LogContentLength: true, + LogResponseSize: true, + })(func(c echo.Context) error { + c.Request().Header.Set(echo.HeaderXRequestID, "123") + return c.String(http.StatusTeapot, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/test?lang=en", nil) + req.Header.Set("Referer", "https://echo.labstack.com/") + req.Header.Set("User-Agent", "curl/7.68.0") + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + mw(c) + } +} + +func BenchmarkRequestLogger_withMapFields(b *testing.B) { + e := echo.New() + + mw := RequestLoggerWithConfig(RequestLoggerConfig{ + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + return nil + }, + LogLatency: true, + LogProtocol: true, + LogRemoteIP: true, + LogHost: true, + LogMethod: true, + LogURI: true, + LogURIPath: true, + LogRoutePath: true, + LogRequestID: true, + LogReferer: true, + LogUserAgent: true, + LogStatus: true, + LogError: true, + LogContentLength: true, + LogResponseSize: true, + LogHeaders: []string{"accept-encoding", "User-Agent"}, + LogQueryParams: []string{"lang", "checked"}, + LogFormValues: []string{"csrf", "multiple"}, + })(func(c echo.Context) error { + c.Request().Header.Set(echo.HeaderXRequestID, "123") + c.FormValue("to force parse form") + return c.String(http.StatusTeapot, "OK") + }) + + f := make(url.Values) + f.Set("csrf", "token") + f.Add("multiple", "1") + f.Add("multiple", "2") + req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", strings.NewReader(f.Encode())) + req.Header.Set("Referer", "https://echo.labstack.com/") + req.Header.Set("User-Agent", "curl/7.68.0") + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + mw(c) + } +} From f6b45f23769729efc9321678607dc74f8b2261ba Mon Sep 17 00:00:00 2001 From: Martti T Date: Wed, 15 Sep 2021 22:29:35 +0300 Subject: [PATCH 166/446] CI: test against Go 1.17 (#1984) --- .github/workflows/echo.yml | 6 +++--- go.mod | 4 ++-- go.sum | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index ec5517561..266406664 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -27,7 +27,7 @@ jobs: os: [ubuntu-latest, macos-latest, windows-latest] # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy # Echo tests with last four major releases - go: [1.13, 1.14, 1.15, 1.16] + go: [1.14, 1.15, 1.16, 1.17] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: @@ -61,7 +61,7 @@ jobs: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... - name: Upload coverage to Codecov - if: success() && matrix.go == 1.16 && matrix.os == 'ubuntu-latest' + if: success() && matrix.go == 1.17 && matrix.os == 'ubuntu-latest' uses: codecov/codecov-action@v1 with: token: @@ -71,7 +71,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - go: [1.16] + go: [1.17] name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: diff --git a/go.mod b/go.mod index d2c884abd..60a643177 100644 --- a/go.mod +++ b/go.mod @@ -10,8 +10,8 @@ require ( github.com/stretchr/testify v1.4.0 github.com/valyala/fasttemplate v1.2.1 golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 - golang.org/x/net v0.0.0-20210907225631-ff17edfbf26d - golang.org/x/sys v0.0.0-20210908160347-a851e7ddeee0 // indirect + golang.org/x/net v0.0.0-20210913180222-943fd674d43e + golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect golang.org/x/text v0.3.7 // indirect golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 ) diff --git a/go.sum b/go.sum index 92dbd77db..9dcac7c5e 100644 --- a/go.sum +++ b/go.sum @@ -25,8 +25,8 @@ github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+ golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210907225631-ff17edfbf26d h1:kuk8nKPQ25KCDODLCDXt99tnTVeOyOM8HGvtJ0NzAvw= -golang.org/x/net v0.0.0-20210907225631-ff17edfbf26d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20210913180222-943fd674d43e h1:+b/22bPvDYt4NPDcy4xAGCmON713ONAWFeY3Z7I3tR8= +golang.org/x/net v0.0.0-20210913180222-943fd674d43e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -35,8 +35,8 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210908160347-a851e7ddeee0 h1:6xxeVXiyYpF8WCTnKKCbjnEdsrwjZYY8TOuk7xP0chg= -golang.org/x/sys v0.0.0-20210908160347-a851e7ddeee0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 h1:xrCZDmdtoloIiooiA9q0OQb9r8HejIHYoHGhGCe1pGg= +golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= From 9fc4672195953a79193c10508c168ab563e842c8 Mon Sep 17 00:00:00 2001 From: Martti T Date: Sun, 19 Sep 2021 11:39:12 +0300 Subject: [PATCH 167/446] Allow escaping of colon in route path so Google Cloud API "custom methods" https://cloud.google.com/apis/design/custom_methods can be implemented (resolves #1987) (#1988) Allow escaping of colon in route path so Google Cloud API "custom methods" https://cloud.google.com/apis/design/custom_methods could be implemented (resolves #1987) --- router.go | 3 +++ router_test.go | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/router.go b/router.go index 5b2474b32..a8277c8b8 100644 --- a/router.go +++ b/router.go @@ -98,6 +98,9 @@ func (r *Router) Add(method, path string, h HandlerFunc) { for i, lcpIndex := 0, len(path); i < lcpIndex; i++ { if path[i] == ':' { + if i > 0 && path[i-1] == '\\' { + continue + } j := i + 1 r.insert(method, path[:i], nil, staticKind, "", nil) diff --git a/router_test.go b/router_test.go index 71cedf8b6..1cb36b447 100644 --- a/router_test.go +++ b/router_test.go @@ -1118,6 +1118,58 @@ func TestRouterParamStaticConflict(t *testing.T) { } } +func TestRouterParam_escapeColon(t *testing.T) { + // to allow Google cloud API like route paths with colon in them + // i.e. https://service.name/v1/some/resource/name:customVerb <- that `:customVerb` is not path param. It is just a string + e := New() + + e.POST("/files/a/long/file\\:undelete", handlerFunc) + e.POST("/v1/some/resource/name:customVerb", handlerFunc) + + var testCases = []struct { + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError string + }{ + { + whenURL: "/files/a/long/file\\:undelete", + expectRoute: "/files/a/long/file\\:undelete", + expectParam: map[string]string{}, + }, + { + whenURL: "/files/a/long/file\\:notMatching", + expectRoute: nil, + expectError: "code=404, message=Not Found", + expectParam: nil, + }, + { + whenURL: "/v1/some/resource/name:PATCH", + expectRoute: "/v1/some/resource/name:customVerb", + expectParam: map[string]string{"customVerb": ":PATCH"}, + }, + } + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + c := e.NewContext(nil, nil).(*context) + + e.router.Find(http.MethodPost, tc.whenURL, c) + err := c.handler(c) + + assert.Equal(t, tc.expectRoute, c.Get("path")) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } +} + func TestRouterMatchAny(t *testing.T) { e := New() r := e.router From 6a85f48960f2e0060bd3e5002802e2ac89413856 Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Mon, 20 Sep 2021 12:08:18 -0700 Subject: [PATCH 168/446] Update README.md --- README.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/README.md b/README.md index d9ea709e4..364f98ac3 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,4 @@ - -[![All Contributors](https://img.shields.io/badge/all_contributors-2-orange.svg?style=flat-square)](#contributors-) - [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge) [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4) From 4651c7aafe8e37cfcd78d93e580cd501dc40107b Mon Sep 17 00:00:00 2001 From: Martti T Date: Mon, 20 Sep 2021 22:23:52 +0300 Subject: [PATCH 169/446] Update version and changelog for 4.6.0 (#1990) --- CHANGELOG.md | 17 +++++++++++++++++ echo.go | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 02eb36fd4..e9b623647 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,22 @@ # Changelog +## v4.6.0 - 2021-09-20 + +Introduced a new [request logger](https://github.com/labstack/echo/blob/master/middleware/request_logger.go) middleware +to help with cases when you want to use some other logging library in your application. + +**Fixes** + +* fix timeout middleware warning: superfluous response.WriteHeader [#1905](https://github.com/labstack/echo/issues/1905) + +**Enhancements** + +* Add Cookie to KeyAuth middleware's KeyLookup [#1929](https://github.com/labstack/echo/pull/1929) +* JWT middleware should ignore case of auth scheme in request header [#1951](https://github.com/labstack/echo/pull/1951) +* Refactor default error handler to return first if response is already committed [#1956](https://github.com/labstack/echo/pull/1956) +* Added request logger middleware which helps to use custom logger library for logging requests. [#1980](https://github.com/labstack/echo/pull/1980) +* Allow escaping of colon in route path so Google Cloud API "custom methods" could be implemented [#1988](https://github.com/labstack/echo/pull/1988) + ## v4.5.0 - 2021-08-01 **Important notes** diff --git a/echo.go b/echo.go index 3292aa1a9..f4eb8eca6 100644 --- a/echo.go +++ b/echo.go @@ -241,7 +241,7 @@ const ( const ( // Version of Echo - Version = "4.5.0" + Version = "4.6.0" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From a2e6ca7ed66c47ee16343ef77d921e454ce49298 Mon Sep 17 00:00:00 2001 From: Martti T Date: Thu, 23 Sep 2021 23:17:09 +0300 Subject: [PATCH 170/446] Add start time to request logger middleware values (#1991) --- middleware/request_logger.go | 6 +++++- middleware/request_logger_test.go | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/middleware/request_logger.go b/middleware/request_logger.go index 7829a1fd1..1b3e3eaad 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -124,6 +124,8 @@ type RequestLoggerConfig struct { // RequestLoggerValues contains extracted values from logger. type RequestLoggerValues struct { + // StartTime is time recorded before next middleware/handler is executed. + StartTime time.Time // Latency is duration it took to execute rest of the handler chain (next(c) call). Latency time.Duration // Protocol is request protocol (i.e. `HTTP/1.1` or `HTTP/2`) @@ -215,7 +217,9 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { } err := next(c) - v := RequestLoggerValues{} + v := RequestLoggerValues{ + StartTime: start, + } if config.LogLatency { v.Latency = now().Sub(start) } diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go index d5d9be08b..5118b1216 100644 --- a/middleware/request_logger_test.go +++ b/middleware/request_logger_test.go @@ -296,6 +296,7 @@ func TestRequestLogger_allFields(t *testing.T) { err := mw(c) assert.NoError(t, err) + assert.Equal(t, time.Unix(1631045377, 0), expect.StartTime) assert.Equal(t, 10*time.Second, expect.Latency) assert.Equal(t, "HTTP/1.1", expect.Protocol) assert.Equal(t, "8.8.8.8", expect.RemoteIP) From c6f0c667f145b5e5347ba812c9de5a5a4280bac5 Mon Sep 17 00:00:00 2001 From: Martti T Date: Sun, 26 Sep 2021 18:56:43 +0300 Subject: [PATCH 171/446] Update version and changelog for 4.6.1 (#1995) --- CHANGELOG.md | 6 ++++++ echo.go | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e9b623647..f52f264f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## v4.6.1 - 2021-09-26 + +**Enhancements** + +* Add start time to request logger middleware values [#1991](https://github.com/labstack/echo/pull/1991) + ## v4.6.0 - 2021-09-20 Introduced a new [request logger](https://github.com/labstack/echo/blob/master/middleware/request_logger.go) middleware diff --git a/echo.go b/echo.go index f4eb8eca6..df5d35843 100644 --- a/echo.go +++ b/echo.go @@ -241,7 +241,7 @@ const ( const ( // Version of Echo - Version = "4.6.0" + Version = "4.6.1" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 4b88e25e49537dacca73903ccd243f734fdbbe9c Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Wed, 6 Oct 2021 21:47:37 -0700 Subject: [PATCH 172/446] Create FUNDING.yml --- .github/FUNDING.yml | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 .github/FUNDING.yml diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..af410716d --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,12 @@ +# These are supported funding model platforms + +github: [labstack] +patreon: # Replace with a single Patreon username +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] From d604704563de63e42a352ffc51b6d633d9d595e3 Mon Sep 17 00:00:00 2001 From: Kaan Karakaya Date: Tue, 12 Oct 2021 22:52:46 +0300 Subject: [PATCH 173/446] Fix rate limiter example time.Minutes is doesn't exist --- middleware/rate_limiter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index 0291eb451..edcc56a58 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -199,7 +199,7 @@ Characteristics: Example: limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig( - middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minutes}, + middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minute}, ) */ func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) { From 7ef3e0002d928d6c82f29b9e3c5f362526759af2 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Fri, 5 Nov 2021 09:31:16 +0200 Subject: [PATCH 174/446] update dependencies --- go.mod | 13 +++++-------- go.sum | 46 ++++++++++++++++++++-------------------------- 2 files changed, 25 insertions(+), 34 deletions(-) diff --git a/go.mod b/go.mod index 60a643177..2a80d2443 100644 --- a/go.mod +++ b/go.mod @@ -4,14 +4,11 @@ go 1.15 require ( github.com/golang-jwt/jwt v3.2.2+incompatible - github.com/labstack/gommon v0.3.0 - github.com/mattn/go-colorable v0.1.8 // indirect - github.com/mattn/go-isatty v0.0.14 // indirect - github.com/stretchr/testify v1.4.0 + github.com/labstack/gommon v0.3.1 + github.com/stretchr/testify v1.7.0 github.com/valyala/fasttemplate v1.2.1 - golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 - golang.org/x/net v0.0.0-20210913180222-943fd674d43e - golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect + golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 + golang.org/x/net v0.0.0-20211104170005-ce137452f963 golang.org/x/text v0.3.7 // indirect - golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 + golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac ) diff --git a/go.sum b/go.sum index 9dcac7c5e..47f4a9761 100644 --- a/go.sum +++ b/go.sum @@ -1,51 +1,45 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0= -github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= -github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8= -github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= -github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/labstack/gommon v0.3.1 h1:OomWaJXm7xR6L1HmEtGyQf26TEn7V6X88mktX9kee9o= +github.com/labstack/gommon v0.3.1/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= +github.com/mattn/go-colorable v0.1.11 h1:nQ+aFkoE2TMGc0b68U2OKSexC+eq46+XwZzWXHRmPYs= +github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= -golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210913180222-943fd674d43e h1:+b/22bPvDYt4NPDcy4xAGCmON713ONAWFeY3Z7I3tR8= -golang.org/x/net v0.0.0-20210913180222-943fd674d43e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/net v0.0.0-20211104170005-ce137452f963 h1:8gJUadZl+kWvZBqG/LautX0X6qe5qTC2VI/3V3NBRAY= +golang.org/x/net v0.0.0-20211104170005-ce137452f963/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 h1:xrCZDmdtoloIiooiA9q0OQb9r8HejIHYoHGhGCe1pGg= -golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4= +golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac h1:7zkz7BUtwNFFqcowJ+RIgu2MaV/MapERkDIy+mwPyjs= +golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 0c4ad8646ad949d13a8d4ebe073185c1741b4d71 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Fri, 5 Nov 2021 09:43:59 +0200 Subject: [PATCH 175/446] update dependencies use 1.14 for choosing updated deps. Using current tip (1.17) will cause tests fail as some packages are not supporting 1.14. `docker run --rm -it -v $(pwd):/project golang:1.14 /bin/sh -c "cd /project && go get ./... && go mod tidy"` --- go.mod | 6 +++--- go.sum | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 2a80d2443..e5fa0d55f 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,8 @@ require ( github.com/labstack/gommon v0.3.1 github.com/stretchr/testify v1.7.0 github.com/valyala/fasttemplate v1.2.1 - golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 - golang.org/x/net v0.0.0-20211104170005-ce137452f963 + golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 + golang.org/x/net v0.0.0-20210913180222-943fd674d43e golang.org/x/text v0.3.7 // indirect - golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac + golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 ) diff --git a/go.sum b/go.sum index 47f4a9761..8a1ec2f9f 100644 --- a/go.sum +++ b/go.sum @@ -18,15 +18,20 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= +golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210913180222-943fd674d43e h1:+b/22bPvDYt4NPDcy4xAGCmON713ONAWFeY3Z7I3tR8= +golang.org/x/net v0.0.0-20210913180222-943fd674d43e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211104170005-ce137452f963 h1:8gJUadZl+kWvZBqG/LautX0X6qe5qTC2VI/3V3NBRAY= golang.org/x/net v0.0.0-20211104170005-ce137452f963/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -35,6 +40,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= +golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac h1:7zkz7BUtwNFFqcowJ+RIgu2MaV/MapERkDIy+mwPyjs= golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= From 8b4cce5021eebda91db43f80f4a8f2ae2f949282 Mon Sep 17 00:00:00 2001 From: nephtyws Date: Tue, 12 Oct 2021 13:49:20 +0900 Subject: [PATCH 176/446] Sort import order on example in README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 364f98ac3..930cb0344 100644 --- a/README.md +++ b/README.md @@ -66,9 +66,9 @@ go get github.com/labstack/echo/v4 package main import ( - "net/http" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" + "net/http" ) func main() { From bd29ef9e46c6a897973c528e150bcf83d8664799 Mon Sep 17 00:00:00 2001 From: Luka Jajanidze Date: Sun, 21 Nov 2021 19:28:49 +0400 Subject: [PATCH 177/446] added references to Limiter docs for 0-1 behaviour --- middleware/rate_limiter.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index edcc56a58..c947b7dae 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -170,6 +170,9 @@ type ( /* NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with the provided rate (as req/s). The provided rate less than 1 will be treated as zero. +Also rate between 0 and 1 will be treated as zero. +for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. + Burst and ExpiresIn will be set to default values. Example (with 20 requests/sec): From 3f099663f12f363a7d5ad64c851b608ba3f7052f Mon Sep 17 00:00:00 2001 From: Luka Jajanidze Date: Sun, 21 Nov 2021 19:46:08 +0400 Subject: [PATCH 178/446] removed unnecessary comments --- middleware/rate_limiter.go | 1 - 1 file changed, 1 deletion(-) diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index c947b7dae..940d973be 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -170,7 +170,6 @@ type ( /* NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with the provided rate (as req/s). The provided rate less than 1 will be treated as zero. -Also rate between 0 and 1 will be treated as zero. for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. Burst and ExpiresIn will be set to default values. From 902c55355238a226c356d5e6c706a62c76f64197 Mon Sep 17 00:00:00 2001 From: Luka Jajanidze Date: Sun, 21 Nov 2021 19:55:18 +0400 Subject: [PATCH 179/446] Added comments for RateLimiterMemoryStoreConfig and RateLimiterMemoryStore --- middleware/rate_limiter.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index 940d973be..be2b348db 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -153,9 +153,10 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { type ( // RateLimiterMemoryStore is the built-in store implementation for RateLimiter RateLimiterMemoryStore struct { - visitors map[string]*Visitor - mutex sync.Mutex - rate rate.Limit + visitors map[string]*Visitor + mutex sync.Mutex + rate rate.Limit //for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. + burst int expiresIn time.Duration lastCleanup time.Time @@ -223,7 +224,7 @@ func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (s // RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore type RateLimiterMemoryStoreConfig struct { - Rate rate.Limit // Rate of requests allowed to pass as req/s + Rate rate.Limit // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. Burst int // Burst additionally allows a number of requests to pass when rate limit is reached ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up } From b437ee3879f21e973be85ff827754343a34fbe35 Mon Sep 17 00:00:00 2001 From: David Desmarais-Michaud Date: Fri, 3 Dec 2021 05:03:42 -0500 Subject: [PATCH 180/446] stream decompression instead of buffering (#2018) * stream decompression instead of buffering * simple body replace with gzip reader with deferred close * defer resource closes * simply gzip.Reader pool --- middleware/decompress.go | 69 ++++++++++++++-------------------------- 1 file changed, 24 insertions(+), 45 deletions(-) diff --git a/middleware/decompress.go b/middleware/decompress.go index c046359a2..88ec70982 100644 --- a/middleware/decompress.go +++ b/middleware/decompress.go @@ -1,10 +1,8 @@ package middleware import ( - "bytes" "compress/gzip" "io" - "io/ioutil" "net/http" "sync" @@ -43,26 +41,7 @@ type DefaultGzipDecompressPool struct { } func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { - return sync.Pool{ - New: func() interface{} { - // create with an empty reader (but with GZIP header) - w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed) - if err != nil { - return err - } - - b := new(bytes.Buffer) - w.Reset(b) - w.Flush() - w.Close() - - r, err := gzip.NewReader(bytes.NewReader(b.Bytes())) - if err != nil { - return err - } - return r - }, - } + return sync.Pool{New: func() interface{} { return new(gzip.Reader) }} } //Decompress decompresses request body based if content encoding type is set to "gzip" with default config @@ -82,38 +61,38 @@ func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { pool := config.GzipDecompressPool.gzipDecompressPool() + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } - switch c.Request().Header.Get(echo.HeaderContentEncoding) { - case GZIPEncoding: - b := c.Request().Body - - i := pool.Get() - gr, ok := i.(*gzip.Reader) - if !ok { - return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) - } - if err := gr.Reset(b); err != nil { - pool.Put(gr) - if err == io.EOF { //ignore if body is empty - return next(c) - } - return err - } - var buf bytes.Buffer - io.Copy(&buf, gr) + if c.Request().Header.Get(echo.HeaderContentEncoding) != GZIPEncoding { + return next(c) + } - gr.Close() - pool.Put(gr) + i := pool.Get() + gr, ok := i.(*gzip.Reader) + if !ok || gr == nil { + return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) + } + defer pool.Put(gr) - b.Close() // http.Request.Body is closed by the Server, but because we are replacing it, it must be closed here + b := c.Request().Body + defer b.Close() - r := ioutil.NopCloser(&buf) - c.Request().Body = r + if err := gr.Reset(b); err != nil { + if err == io.EOF { //ignore if body is empty + return next(c) + } + return err } + + // only Close gzip reader if it was set to a proper gzip source otherwise it will panic on close. + defer gr.Close() + + c.Request().Body = gr + return next(c) } } From c32fafad68daa7214f0ca005b4614ca38e90b2b8 Mon Sep 17 00:00:00 2001 From: Guilherme Cardoso Date: Tue, 7 Dec 2021 10:56:32 +0000 Subject: [PATCH 181/446] Add support for configurable target header for the request_id middleware --- echo.go | 1 + middleware/request_id.go | 15 +++++++++++---- middleware/request_id_test.go | 31 +++++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/echo.go b/echo.go index df5d35843..ad03dd519 100644 --- a/echo.go +++ b/echo.go @@ -214,6 +214,7 @@ const ( HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" HeaderXRealIP = "X-Real-IP" HeaderXRequestID = "X-Request-ID" + HeaderXCorrelationID = "X-Correlation-ID" HeaderXRequestedWith = "X-Requested-With" HeaderServer = "Server" HeaderOrigin = "Origin" diff --git a/middleware/request_id.go b/middleware/request_id.go index b0baeeb2d..8c5ff6605 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -17,14 +17,18 @@ type ( // RequestIDHandler defines a function which is executed for a request id. RequestIDHandler func(echo.Context, string) + + // TargetHeader defines what header to look for to populate the id + TargetHeader string } ) var ( // DefaultRequestIDConfig is the default RequestID middleware config. DefaultRequestIDConfig = RequestIDConfig{ - Skipper: DefaultSkipper, - Generator: generator, + Skipper: DefaultSkipper, + Generator: generator, + TargetHeader: echo.HeaderXRequestID, } ) @@ -42,6 +46,9 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { if config.Generator == nil { config.Generator = generator } + if config.TargetHeader == "" { + config.TargetHeader = echo.HeaderXRequestID + } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -51,11 +58,11 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { req := c.Request() res := c.Response() - rid := req.Header.Get(echo.HeaderXRequestID) + rid := req.Header.Get(config.TargetHeader) if rid == "" { rid = config.Generator() } - res.Header().Set(echo.HeaderXRequestID, rid) + res.Header().Set(config.TargetHeader, rid) if config.RequestIDHandler != nil { config.RequestIDHandler(c, rid) } diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go index 944b3b49e..21b777826 100644 --- a/middleware/request_id_test.go +++ b/middleware/request_id_test.go @@ -55,3 +55,34 @@ func TestRequestID_IDNotAltered(t *testing.T) { _ = h(c) assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "") } + +func TestRequestIDConfigDifferentHeader(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid := RequestIDWithConfig(RequestIDConfig{TargetHeader: echo.HeaderXCorrelationID}) + h := rid(handler) + h(c) + assert.Len(t, rec.Header().Get(echo.HeaderXCorrelationID), 32) + + // Custom generator and handler + customID := "customGenerator" + calledHandler := false + rid = RequestIDWithConfig(RequestIDConfig{ + Generator: func() string { return customID }, + TargetHeader: echo.HeaderXCorrelationID, + RequestIDHandler: func(_ echo.Context, id string) { + calledHandler = true + assert.Equal(t, customID, id) + }, + }) + h = rid(handler) + h(c) + assert.Equal(t, rec.Header().Get(echo.HeaderXCorrelationID), "customGenerator") + assert.True(t, calledHandler) +} From 7bde9aea068072e08c41148fc230393872d9c49c Mon Sep 17 00:00:00 2001 From: Nao Yonashiro Date: Wed, 15 Dec 2021 17:15:13 +0900 Subject: [PATCH 182/446] Fixed a problem that returned wrong content-encoding when the gzip compressed content was empty (#1921) Fixed a problem that returned wrong content-encoding when the gzip compressed content was empty --- middleware/compress.go | 9 ++++----- middleware/compress_test.go | 21 +++++++++++++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/middleware/compress.go b/middleware/compress.go index 6ae197453..ac6672e9d 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -27,6 +27,7 @@ type ( gzipResponseWriter struct { io.Writer http.ResponseWriter + wroteBody bool } ) @@ -78,8 +79,9 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { } rw := res.Writer w.Reset(rw) + grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw} defer func() { - if res.Size == 0 { + if !grw.wroteBody { if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { res.Header().Del(echo.HeaderContentEncoding) } @@ -92,7 +94,6 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { w.Close() pool.Put(w) }() - grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw} res.Writer = grw } return next(c) @@ -101,9 +102,6 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { } func (w *gzipResponseWriter) WriteHeader(code int) { - if code == http.StatusNoContent { // Issue #489 - w.ResponseWriter.Header().Del(echo.HeaderContentEncoding) - } w.Header().Del(echo.HeaderContentLength) // Issue #444 w.ResponseWriter.WriteHeader(code) } @@ -112,6 +110,7 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) { if w.Header().Get(echo.HeaderContentType) == "" { w.Header().Set(echo.HeaderContentType, http.DetectContentType(b)) } + w.wroteBody = true return w.Writer.Write(b) } diff --git a/middleware/compress_test.go b/middleware/compress_test.go index d16ffca43..b62bffef5 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -106,6 +106,27 @@ func TestGzipNoContent(t *testing.T) { } } +func TestGzipEmpty(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := Gzip()(func(c echo.Context) error { + return c.String(http.StatusOK, "") + }) + if assert.NoError(t, h(c)) { + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get(echo.HeaderContentType)) + r, err := gzip.NewReader(rec.Body) + if assert.NoError(t, err) { + var buf bytes.Buffer + buf.ReadFrom(r) + assert.Equal(t, "", buf.String()) + } + } +} + func TestGzipErrorReturned(t *testing.T) { e := echo.New() e.Use(Gzip()) From 6b5e62b27ea0bc459843e67014360dd35ae8147b Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Thu, 16 Dec 2021 22:58:40 +0200 Subject: [PATCH 183/446] fix: route containing escaped colon should be matchable but is not matched to request path (fixes #2046) --- router.go | 3 +++ router_test.go | 16 ++++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/router.go b/router.go index a8277c8b8..dc93e29c8 100644 --- a/router.go +++ b/router.go @@ -99,6 +99,9 @@ func (r *Router) Add(method, path string, h HandlerFunc) { for i, lcpIndex := 0, len(path); i < lcpIndex; i++ { if path[i] == ':' { if i > 0 && path[i-1] == '\\' { + path = path[:i-1] + path[i:] + i-- + lcpIndex-- continue } j := i + 1 diff --git a/router_test.go b/router_test.go index 1cb36b447..57be74deb 100644 --- a/router_test.go +++ b/router_test.go @@ -1124,6 +1124,8 @@ func TestRouterParam_escapeColon(t *testing.T) { e := New() e.POST("/files/a/long/file\\:undelete", handlerFunc) + e.POST("/multilevel\\:undelete/second\\:something", handlerFunc) + e.POST("/mixed/:id/second\\:something", handlerFunc) e.POST("/v1/some/resource/name:customVerb", handlerFunc) var testCases = []struct { @@ -1133,12 +1135,22 @@ func TestRouterParam_escapeColon(t *testing.T) { expectError string }{ { - whenURL: "/files/a/long/file\\:undelete", + whenURL: "/files/a/long/file:undelete", expectRoute: "/files/a/long/file\\:undelete", expectParam: map[string]string{}, }, { - whenURL: "/files/a/long/file\\:notMatching", + whenURL: "/multilevel:undelete/second:something", + expectRoute: "/multilevel\\:undelete/second\\:something", + expectParam: map[string]string{}, + }, + { + whenURL: "/mixed/123/second:something", + expectRoute: "/mixed/:id/second\\:something", + expectParam: map[string]string{"id": "123"}, + }, + { + whenURL: "/files/a/long/file:notMatching", expectRoute: nil, expectError: "code=404, message=Not Found", expectParam: nil, From 4fffee2ec8a4efe5cca66cdfe17e6eeec59df60a Mon Sep 17 00:00:00 2001 From: Rashad Ansari Date: Thu, 19 Aug 2021 15:00:07 +0200 Subject: [PATCH 184/446] Add custom jwt extractor to jwt config --- middleware/jwt.go | 49 ++++++++++++++++++++++++------------------ middleware/jwt_test.go | 24 +++++++++++++++++++++ 2 files changed, 52 insertions(+), 21 deletions(-) diff --git a/middleware/jwt.go b/middleware/jwt.go index 21e33ab82..43605e377 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -68,9 +68,14 @@ type ( // - "form:" // Multiply sources example: // - "header: Authorization,cookie: myowncookie" - TokenLookup string + // TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context. + // This is one of the two options to provide a token extractor. + // The order of precedence is user-defined TokenLookupFuncs, and TokenLookup. + // You can also provide both if you want. + TokenLookupFuncs []TokenLookupFunc + // AuthScheme to be used in the Authorization header. // Optional. Default value "Bearer". AuthScheme string @@ -103,7 +108,8 @@ type ( // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. JWTErrorHandlerWithContext func(error, echo.Context) error - jwtExtractor func(echo.Context) (string, error) + // TokenLookupFunc defines a function for extracting JWT token from the given context. + TokenLookupFunc func(echo.Context) (string, error) ) // Algorithms @@ -120,13 +126,14 @@ var ( var ( // DefaultJWTConfig is the default JWT auth middleware config. DefaultJWTConfig = JWTConfig{ - Skipper: DefaultSkipper, - SigningMethod: AlgorithmHS256, - ContextKey: "user", - TokenLookup: "header:" + echo.HeaderAuthorization, - AuthScheme: "Bearer", - Claims: jwt.MapClaims{}, - KeyFunc: nil, + Skipper: DefaultSkipper, + SigningMethod: AlgorithmHS256, + ContextKey: "user", + TokenLookup: "header:" + echo.HeaderAuthorization, + TokenLookupFuncs: nil, + AuthScheme: "Bearer", + Claims: jwt.MapClaims{}, + KeyFunc: nil, } ) @@ -163,7 +170,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.Claims == nil { config.Claims = DefaultJWTConfig.Claims } - if config.TokenLookup == "" { + if config.TokenLookup == "" && len(config.TokenLookupFuncs) == 0 { config.TokenLookup = DefaultJWTConfig.TokenLookup } if config.AuthScheme == "" { @@ -179,7 +186,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { // Initialize // Split sources sources := strings.Split(config.TokenLookup, ",") - var extractors []jwtExtractor + var extractors = config.TokenLookupFuncs for _, source := range sources { parts := strings.Split(source, ":") @@ -290,8 +297,8 @@ func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) { return config.SigningKey, nil } -// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header. -func jwtFromHeader(header string, authScheme string) jwtExtractor { +// jwtFromHeader returns a `TokenLookupFunc` that extracts token from the request header. +func jwtFromHeader(header string, authScheme string) TokenLookupFunc { return func(c echo.Context) (string, error) { auth := c.Request().Header.Get(header) l := len(authScheme) @@ -302,8 +309,8 @@ func jwtFromHeader(header string, authScheme string) jwtExtractor { } } -// jwtFromQuery returns a `jwtExtractor` that extracts token from the query string. -func jwtFromQuery(param string) jwtExtractor { +// jwtFromQuery returns a `TokenLookupFunc` that extracts token from the query string. +func jwtFromQuery(param string) TokenLookupFunc { return func(c echo.Context) (string, error) { token := c.QueryParam(param) if token == "" { @@ -313,8 +320,8 @@ func jwtFromQuery(param string) jwtExtractor { } } -// jwtFromParam returns a `jwtExtractor` that extracts token from the url param string. -func jwtFromParam(param string) jwtExtractor { +// jwtFromParam returns a `TokenLookupFunc` that extracts token from the url param string. +func jwtFromParam(param string) TokenLookupFunc { return func(c echo.Context) (string, error) { token := c.Param(param) if token == "" { @@ -324,8 +331,8 @@ func jwtFromParam(param string) jwtExtractor { } } -// jwtFromCookie returns a `jwtExtractor` that extracts token from the named cookie. -func jwtFromCookie(name string) jwtExtractor { +// jwtFromCookie returns a `TokenLookupFunc` that extracts token from the named cookie. +func jwtFromCookie(name string) TokenLookupFunc { return func(c echo.Context) (string, error) { cookie, err := c.Cookie(name) if err != nil { @@ -335,8 +342,8 @@ func jwtFromCookie(name string) jwtExtractor { } } -// jwtFromForm returns a `jwtExtractor` that extracts token from the form field. -func jwtFromForm(name string) jwtExtractor { +// jwtFromForm returns a `TokenLookupFunc` that extracts token from the form field. +func jwtFromForm(name string) TokenLookupFunc { return func(c echo.Context) (string, error) { field := c.FormValue(name) if field == "" { diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 5f36ce0a5..18454d0a7 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -603,3 +603,27 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) { assert.Equal(t, http.StatusTeapot, res.Code) } + +func TestJWTConfig_TokenLookupFuncs(t *testing.T) { + e := echo.New() + + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + e.Use(JWTWithConfig(JWTConfig{ + TokenLookupFuncs: []TokenLookupFunc{ + func(c echo.Context) (string, error) { + return c.Request().Header.Get("X-API-Key"), nil + }, + }, + SigningKey: []byte("secret"), + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusOK, res.Code) +} From 5b26a5257b066ec32acbe918641b76ff05b4a87c Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 4 Dec 2021 20:02:11 +0200 Subject: [PATCH 185/446] `Allow` header support in Router, MethodNotFoundHandler (405) and CORS middleware --- context.go | 7 ++ echo.go | 14 ++- echo_test.go | 3 + middleware/cors.go | 69 +++++++---- middleware/cors_test.go | 252 ++++++++++++++++++++++++++++++---------- router.go | 93 ++++++++++++--- router_test.go | 122 +++++++++++++++++-- 7 files changed, 441 insertions(+), 119 deletions(-) diff --git a/context.go b/context.go index 91ab6e480..ea542cb86 100644 --- a/context.go +++ b/context.go @@ -210,6 +210,13 @@ type ( } ) +const ( + // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain. + // Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests. + // It is added to context only when Router does not find matching method handler for request. + ContextKeyHeaderAllow = "____echo____header_allow" +) + const ( defaultMemory = 32 << 20 // 32 MB indexPage = "index.html" diff --git a/echo.go b/echo.go index ad03dd519..8747039e4 100644 --- a/echo.go +++ b/echo.go @@ -190,8 +190,11 @@ const ( // Headers const ( - HeaderAccept = "Accept" - HeaderAcceptEncoding = "Accept-Encoding" + HeaderAccept = "Accept" + HeaderAcceptEncoding = "Accept-Encoding" + // HeaderAllow is header field that lists the set of methods advertised as supported by the target resource. + // Allow header is mandatory for status 405 (method not found) and useful OPTIONS method responses. + // See: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1 HeaderAllow = "Allow" HeaderAuthorization = "Authorization" HeaderContentDisposition = "Content-Disposition" @@ -302,6 +305,13 @@ var ( } MethodNotAllowedHandler = func(c Context) error { + // 'Allow' header RFC: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1 + // >> An origin server MUST generate an Allow field in a 405 (Method Not Allowed) response + // and MAY do so in any other response. + routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string) + if ok && routerAllowMethods != "" { + c.Response().Header().Set(HeaderAllow, routerAllowMethods) + } return ErrMethodNotAllowed } ) diff --git a/echo_test.go b/echo_test.go index f28915864..13a51b6cc 100644 --- a/echo_test.go +++ b/echo_test.go @@ -716,13 +716,16 @@ func TestEchoNotFound(t *testing.T) { func TestEchoMethodNotAllowed(t *testing.T) { e := New() + e.GET("/", func(c Context) error { return c.String(http.StatusOK, "Echo!") }) req := httptest.NewRequest(http.MethodPost, "/", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusMethodNotAllowed, rec.Code) + assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow)) } func TestEchoContext(t *testing.T) { diff --git a/middleware/cors.go b/middleware/cors.go index d6ef89644..a5122f26e 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -29,6 +29,8 @@ type ( // AllowMethods defines a list methods allowed when accessing the resource. // This is used in response to a preflight request. // Optional. Default value DefaultCORSConfig.AllowMethods. + // If `allowMethods` is left empty will fill for preflight request `Access-Control-Allow-Methods` header value + // from `Allow` header that echo.Router set into context. AllowMethods []string `yaml:"allow_methods"` // AllowHeaders defines a list of request headers that can be used when @@ -41,6 +43,8 @@ type ( // a response to a preflight request, this indicates whether or not the // actual request can be made using credentials. // Optional. Default value false. + // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. + // See http://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html AllowCredentials bool `yaml:"allow_credentials"` // ExposeHeaders defines a whitelist headers that clients are allowed to @@ -80,7 +84,9 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { if len(config.AllowOrigins) == 0 { config.AllowOrigins = DefaultCORSConfig.AllowOrigins } + hasCustomAllowMethods := true if len(config.AllowMethods) == 0 { + hasCustomAllowMethods = false config.AllowMethods = DefaultCORSConfig.AllowMethods } @@ -109,10 +115,28 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { origin := req.Header.Get(echo.HeaderOrigin) allowOrigin := "" - preflight := req.Method == http.MethodOptions res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) - // No Origin provided + // Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method, + // Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request + // For simplicity we just consider method type and later `Origin` header. + preflight := req.Method == http.MethodOptions + + // Although router adds special handler in case of OPTIONS method we avoid calling next for OPTIONS in this middleware + // as CORS requests do not have cookies / authentication headers by default, so we could get stuck in auth + // middlewares by calling next(c). + // But we still want to send `Allow` header as response in case of Non-CORS OPTIONS request as router default + // handler does. + routerAllowMethods := "" + if preflight { + tmpAllowMethods, ok := c.Get(echo.ContextKeyHeaderAllow).(string) + if ok && tmpAllowMethods != "" { + routerAllowMethods = tmpAllowMethods + c.Response().Header().Set(echo.HeaderAllow, routerAllowMethods) + } + } + + // No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain if origin == "" { if !preflight { return next(c) @@ -145,19 +169,15 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } } - // Check allowed origin patterns - for _, re := range allowOriginPatterns { - if allowOrigin == "" { - didx := strings.Index(origin, "://") - if didx == -1 { - continue - } - domAuth := origin[didx+3:] - // to avoid regex cost by invalid long domain - if len(domAuth) > 253 { - break - } - + checkPatterns := false + if allowOrigin == "" { + // to avoid regex cost by invalid (long) domains (253 is domain name max limit) + if len(origin) <= (253+3+4) && strings.Contains(origin, "://") { + checkPatterns = true + } + } + if checkPatterns { + for _, re := range allowOriginPatterns { if match, _ := regexp.MatchString(re, origin); match { allowOrigin = origin break @@ -174,12 +194,13 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { return c.NoContent(http.StatusNoContent) } + res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) + if config.AllowCredentials { + res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") + } + // Simple request if !preflight { - res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) - if config.AllowCredentials { - res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") - } if exposeHeaders != "" { res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders) } @@ -189,11 +210,13 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { // Preflight request res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) - res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) - res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods) - if config.AllowCredentials { - res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") + + if !hasCustomAllowMethods && routerAllowMethods != "" { + res.Header().Set(echo.HeaderAccessControlAllowMethods, routerAllowMethods) + } else { + res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods) } + if allowHeaders != "" { res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders) } else { diff --git a/middleware/cors_test.go b/middleware/cors_test.go index 717abe498..daadbab6e 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -251,114 +251,238 @@ func Test_allowOriginSubdomain(t *testing.T) { } } +func TestCORSWithConfig_AllowMethods(t *testing.T) { + var testCases = []struct { + name string + allowOrigins []string + allowContextKey string + + whenOrigin string + whenAllowMethods []string + + expectAllow string + expectAccessControlAllowMethods string + }{ + { + name: "custom AllowMethods, preflight, no origin, sets only allow header from context key", + allowContextKey: "OPTIONS, GET", + whenAllowMethods: []string{http.MethodGet, http.MethodHead}, + whenOrigin: "", + expectAllow: "OPTIONS, GET", + }, + { + name: "default AllowMethods, preflight, no origin, no allow header in context key and in response", + allowContextKey: "", + whenAllowMethods: nil, + whenOrigin: "", + expectAllow: "", + }, + { + name: "custom AllowMethods, preflight, existing origin, sets both headers different values", + allowContextKey: "OPTIONS, GET", + whenAllowMethods: []string{http.MethodGet, http.MethodHead}, + whenOrigin: "http://google.com", + expectAllow: "OPTIONS, GET", + expectAccessControlAllowMethods: "GET,HEAD", + }, + { + name: "default AllowMethods, preflight, existing origin, sets both headers", + allowContextKey: "OPTIONS, GET", + whenAllowMethods: nil, + whenOrigin: "http://google.com", + expectAllow: "OPTIONS, GET", + expectAccessControlAllowMethods: "OPTIONS, GET", + }, + { + name: "default AllowMethods, preflight, existing origin, no allows, sets only CORS allow methods", + allowContextKey: "", + whenAllowMethods: nil, + whenOrigin: "http://google.com", + expectAllow: "", + expectAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + e.GET("/test", func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + }) + + cors := CORSWithConfig(CORSConfig{ + AllowOrigins: tc.allowOrigins, + AllowMethods: tc.whenAllowMethods, + }) + + req := httptest.NewRequest(http.MethodOptions, "/test", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + req.Header.Set(echo.HeaderOrigin, tc.whenOrigin) + if tc.allowContextKey != "" { + c.Set(echo.ContextKeyHeaderAllow, tc.allowContextKey) + } + + h := cors(echo.NotFoundHandler) + h(c) + + assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow)) + assert.Equal(t, tc.expectAccessControlAllowMethods, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) + }) + } +} + func TestCorsHeaders(t *testing.T) { tests := []struct { - domain, allowedOrigin, method string - expected bool + name string + originDomain string + method string + allowedOrigin string + expected bool + expectStatus int + expectAllowHeader string }{ { - domain: "", // Request does not have Origin header + name: "non-preflight request, allow any origin, missing origin header = no CORS logic done", + originDomain: "", allowedOrigin: "*", method: http.MethodGet, expected: false, + expectStatus: http.StatusOK, }, { - domain: "http://example.com", + name: "non-preflight request, allow any origin, specific origin domain", + originDomain: "http://example.com", allowedOrigin: "*", method: http.MethodGet, expected: true, + expectStatus: http.StatusOK, }, { - domain: "", // Request does not have Origin header + name: "non-preflight request, allow specific origin, missing origin header = no CORS logic done", + originDomain: "", // Request does not have Origin header allowedOrigin: "http://example.com", method: http.MethodGet, expected: false, + expectStatus: http.StatusOK, }, { - domain: "http://bar.com", + name: "non-preflight request, allow specific origin, different origin header = CORS logic failure", + originDomain: "http://bar.com", allowedOrigin: "http://example.com", method: http.MethodGet, expected: false, + expectStatus: http.StatusOK, }, { - domain: "http://example.com", + name: "non-preflight request, allow specific origin, matching origin header = CORS logic done", + originDomain: "http://example.com", allowedOrigin: "http://example.com", method: http.MethodGet, expected: true, + expectStatus: http.StatusOK, }, { - domain: "", // Request does not have Origin header - allowedOrigin: "*", - method: http.MethodOptions, - expected: false, + name: "preflight, allow any origin, missing origin header = no CORS logic done", + originDomain: "", // Request does not have Origin header + allowedOrigin: "*", + method: http.MethodOptions, + expected: false, + expectStatus: http.StatusNoContent, + expectAllowHeader: "OPTIONS, GET, POST", }, { - domain: "http://example.com", - allowedOrigin: "*", - method: http.MethodOptions, - expected: true, + name: "preflight, allow any origin, existing origin header = CORS logic done", + originDomain: "http://example.com", + allowedOrigin: "*", + method: http.MethodOptions, + expected: true, + expectStatus: http.StatusNoContent, + expectAllowHeader: "OPTIONS, GET, POST", }, { - domain: "", // Request does not have Origin header - allowedOrigin: "http://example.com", - method: http.MethodOptions, - expected: false, + name: "preflight, allow any origin, missing origin header = no CORS logic done", + originDomain: "", // Request does not have Origin header + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: false, + expectStatus: http.StatusNoContent, + expectAllowHeader: "OPTIONS, GET, POST", }, { - domain: "http://bar.com", - allowedOrigin: "http://example.com", - method: http.MethodGet, - expected: false, + name: "preflight, allow specific origin, different origin header = no CORS logic done", + originDomain: "http://bar.com", + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: false, + expectStatus: http.StatusNoContent, + expectAllowHeader: "OPTIONS, GET, POST", }, { - domain: "http://example.com", - allowedOrigin: "http://example.com", - method: http.MethodOptions, - expected: true, + name: "preflight, allow specific origin, matching origin header = CORS logic done", + originDomain: "http://example.com", + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: true, + expectStatus: http.StatusNoContent, + expectAllowHeader: "OPTIONS, GET, POST", }, } - e := echo.New() - for _, tt := range tests { - req := httptest.NewRequest(tt.method, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - if tt.domain != "" { - req.Header.Set(echo.HeaderOrigin, tt.domain) - } - cors := CORSWithConfig(CORSConfig{ - AllowOrigins: []string{tt.allowedOrigin}, - //AllowCredentials: true, - //MaxAge: 3600, - }) - h := cors(echo.NotFoundHandler) - h(c) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() - assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary)) + e.Use(CORSWithConfig(CORSConfig{ + AllowOrigins: []string{tc.allowedOrigin}, + //AllowCredentials: true, + //MaxAge: 3600, + })) - expectedAllowOrigin := "" - if tt.allowedOrigin == "*" { - expectedAllowOrigin = "*" - } else { - expectedAllowOrigin = tt.domain - } + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + }) + e.POST("/", func(c echo.Context) error { + return c.String(http.StatusCreated, "OK") + }) - switch { - case tt.expected && tt.method == http.MethodOptions: - assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods) - assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary])) - case tt.expected && tt.method == http.MethodGet: - assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin - default: - assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) - assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin - } + req := httptest.NewRequest(tc.method, "/", nil) + rec := httptest.NewRecorder() + + if tc.originDomain != "" { + req.Header.Set(echo.HeaderOrigin, tc.originDomain) + } + + // we run through whole Echo handler chain to see how CORS works with Router OPTIONS handler + e.ServeHTTP(rec, req) + + assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary)) + assert.Equal(t, tc.expectAllowHeader, rec.Header().Get(echo.HeaderAllow)) + assert.Equal(t, tc.expectStatus, rec.Code) + + expectedAllowOrigin := "" + if tc.allowedOrigin == "*" { + expectedAllowOrigin = "*" + } else { + expectedAllowOrigin = tc.originDomain + } + switch { + case tc.expected && tc.method == http.MethodOptions: + assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods) + assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + + assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary])) + + case tc.expected && tc.method == http.MethodGet: + assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin + default: + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin + } + }) - if tt.method == http.MethodOptions { - assert.Equal(t, http.StatusNoContent, rec.Code) - } } } diff --git a/router.go b/router.go index dc93e29c8..1a2ce561f 100644 --- a/router.go +++ b/router.go @@ -1,6 +1,7 @@ package echo import ( + "bytes" "net/http" ) @@ -31,17 +32,18 @@ type ( kind uint8 children []*node methodHandler struct { - connect HandlerFunc - delete HandlerFunc - get HandlerFunc - head HandlerFunc - options HandlerFunc - patch HandlerFunc - post HandlerFunc - propfind HandlerFunc - put HandlerFunc - trace HandlerFunc - report HandlerFunc + connect HandlerFunc + delete HandlerFunc + get HandlerFunc + head HandlerFunc + options HandlerFunc + patch HandlerFunc + post HandlerFunc + propfind HandlerFunc + put HandlerFunc + trace HandlerFunc + report HandlerFunc + allowHeader string } ) @@ -68,6 +70,51 @@ func (m *methodHandler) isHandler() bool { m.report != nil } +func (m *methodHandler) updateAllowHeader() { + buf := new(bytes.Buffer) + buf.WriteString(http.MethodOptions) + + if m.connect != nil { + buf.WriteString(", ") + buf.WriteString(http.MethodConnect) + } + if m.delete != nil { + buf.WriteString(", ") + buf.WriteString(http.MethodDelete) + } + if m.get != nil { + buf.WriteString(", ") + buf.WriteString(http.MethodGet) + } + if m.head != nil { + buf.WriteString(", ") + buf.WriteString(http.MethodHead) + } + if m.patch != nil { + buf.WriteString(", ") + buf.WriteString(http.MethodPatch) + } + if m.post != nil { + buf.WriteString(", ") + buf.WriteString(http.MethodPost) + } + if m.propfind != nil { + buf.WriteString(", PROPFIND") + } + if m.put != nil { + buf.WriteString(", ") + buf.WriteString(http.MethodPut) + } + if m.trace != nil { + buf.WriteString(", ") + buf.WriteString(http.MethodTrace) + } + if m.report != nil { + buf.WriteString(", REPORT") + } + m.allowHeader = buf.String() +} + // NewRouter returns a new Router instance. func NewRouter(e *Echo) *Router { return &Router{ @@ -326,6 +373,7 @@ func (n *node) addHandler(method string, h HandlerFunc) { n.methodHandler.report = h } + n.methodHandler.updateAllowHeader() if h != nil { n.isHandler = true } else { @@ -362,13 +410,14 @@ func (n *node) findHandler(method string) HandlerFunc { } } -func (n *node) checkMethodNotAllowed() HandlerFunc { - for _, m := range methods { - if h := n.findHandler(m); h != nil { - return MethodNotAllowedHandler - } +func optionsMethodHandler(allowMethods string) func(c Context) error { + return func(c Context) error { + // Note: we are not handling most of the CORS headers here. CORS is handled by CORS middleware + // 'OPTIONS' method RFC: https://httpwg.org/specs/rfc7231.html#OPTIONS + // 'Allow' header RFC: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1 + c.Response().Header().Add(HeaderAllow, allowMethods) + return c.NoContent(http.StatusNoContent) } - return NotFoundHandler } // Find lookup a handler registered for method and path. It also parses URL for path @@ -563,7 +612,15 @@ func (r *Router) Find(method, path string, c Context) { // use previous match as basis. although we have no matching handler we have path match. // so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404) currentNode = previousBestMatchNode - ctx.handler = currentNode.checkMethodNotAllowed() + + ctx.handler = NotFoundHandler + if currentNode.isHandler { + ctx.Set(ContextKeyHeaderAllow, currentNode.methodHandler.allowHeader) + ctx.handler = MethodNotAllowedHandler + if method == http.MethodOptions { + ctx.handler = optionsMethodHandler(currentNode.methodHandler.allowHeader) + } + } } ctx.path = currentNode.ppath ctx.pnames = currentNode.pnames diff --git a/router_test.go b/router_test.go index 57be74deb..5cbb8d9b8 100644 --- a/router_test.go +++ b/router_test.go @@ -3,6 +3,7 @@ package echo import ( "fmt" "net/http" + "net/http/httptest" "strings" "testing" @@ -725,12 +726,13 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) { r.Add(http.MethodPost, "/users/:id", handlerFunc) var testCases = []struct { - name string - whenMethod string - whenURL string - expectRoute interface{} - expectParam map[string]string - expectError error + name string + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + expectAllowHeader string }{ { name: "exact match for route+method", @@ -740,11 +742,12 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) { expectParam: map[string]string{"id": "1"}, }, { - name: "matches node but not method. sends 405 from best match node", - whenMethod: http.MethodPut, - whenURL: "/users/1", - expectRoute: nil, - expectError: ErrMethodNotAllowed, + name: "matches node but not method. sends 405 from best match node", + whenMethod: http.MethodPut, + whenURL: "/users/1", + expectRoute: nil, + expectError: ErrMethodNotAllowed, + expectAllowHeader: "OPTIONS, POST", }, { name: "best match is any route up in tree", @@ -756,7 +759,9 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - c := e.NewContext(nil, nil).(*context) + req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec).(*context) method := http.MethodGet if tc.whenMethod != "" { @@ -775,10 +780,36 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) { assert.Equal(t, expectedValue, c.Param(param)) } checkUnusedParamValues(t, c, tc.expectParam) + + assert.Equal(t, tc.expectAllowHeader, c.Response().Header().Get(HeaderAllow)) }) } } +func TestRouterOptionsMethodHandler(t *testing.T) { + e := New() + + var keyInContext interface{} + e.Use(func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + err := next(c) + keyInContext = c.Get(ContextKeyHeaderAllow) + return err + } + }) + e.GET("/test", func(c Context) error { + return c.String(http.StatusOK, "Echo!") + }) + + req := httptest.NewRequest(http.MethodOptions, "/test", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusNoContent, rec.Code) + assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow)) + assert.Equal(t, "OPTIONS, GET", keyInContext) +} + func TestRouterTwoParam(t *testing.T) { e := New() r := e.router @@ -2288,6 +2319,73 @@ func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) { } } +func TestRouterHandleMethodOptions(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/users", handlerFunc) + r.Add(http.MethodPost, "/users", handlerFunc) + r.Add(http.MethodPut, "/users/:id", handlerFunc) + r.Add(http.MethodGet, "/users/:id", handlerFunc) + + var testCases = []struct { + name string + whenMethod string + whenURL string + expectAllowHeader string + expectStatus int + }{ + { + name: "allows GET and POST handlers", + whenMethod: http.MethodOptions, + whenURL: "/users", + expectAllowHeader: "OPTIONS, GET, POST", + expectStatus: http.StatusNoContent, + }, + { + name: "allows GET and PUT handlers", + whenMethod: http.MethodOptions, + whenURL: "/users/1", + expectAllowHeader: "OPTIONS, GET, PUT", + expectStatus: http.StatusNoContent, + }, + { + name: "GET does not have allows header", + whenMethod: http.MethodGet, + whenURL: "/users", + expectAllowHeader: "", + expectStatus: http.StatusOK, + }, + { + name: "path with no handlers does not set Allows header", + whenMethod: http.MethodOptions, + whenURL: "/notFound", + expectAllowHeader: "", + expectStatus: http.StatusNotFound, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec).(*context) + + r.Find(tc.whenMethod, tc.whenURL, c) + err := c.handler(c) + + if tc.expectStatus >= 400 { + assert.Error(t, err) + he := err.(*HTTPError) + assert.Equal(t, tc.expectStatus, he.Code) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectStatus, rec.Code) + } + assert.Equal(t, tc.expectAllowHeader, c.Response().Header().Get("Allow")) + }) + } +} + func benchmarkRouterRoutes(b *testing.B, routes []*Route, routesToFind []*Route) { e := New() r := e.router From 6f6befe555e9f076b8f4a4c060ad100f1c7e46b4 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Thu, 9 Dec 2021 21:57:20 +0200 Subject: [PATCH 186/446] improve docs --- context.go | 2 +- echo.go | 12 ++++++------ middleware/cors.go | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/context.go b/context.go index ea542cb86..f2421d77b 100644 --- a/context.go +++ b/context.go @@ -214,7 +214,7 @@ const ( // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain. // Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests. // It is added to context only when Router does not find matching method handler for request. - ContextKeyHeaderAllow = "____echo____header_allow" + ContextKeyHeaderAllow = "echo_header_allow" ) const ( diff --git a/echo.go b/echo.go index 8747039e4..427898217 100644 --- a/echo.go +++ b/echo.go @@ -192,9 +192,10 @@ const ( const ( HeaderAccept = "Accept" HeaderAcceptEncoding = "Accept-Encoding" - // HeaderAllow is header field that lists the set of methods advertised as supported by the target resource. - // Allow header is mandatory for status 405 (method not found) and useful OPTIONS method responses. - // See: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1 + // HeaderAllow is the name of the "Allow" header field used to list the set of methods + // advertised as supported by the target resource. Returning an Allow header is mandatory + // for status 405 (method not found) and useful for the OPTIONS method in responses. + // See RFC 7231: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1 HeaderAllow = "Allow" HeaderAuthorization = "Authorization" HeaderContentDisposition = "Content-Disposition" @@ -305,9 +306,8 @@ var ( } MethodNotAllowedHandler = func(c Context) error { - // 'Allow' header RFC: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1 - // >> An origin server MUST generate an Allow field in a 405 (Method Not Allowed) response - // and MAY do so in any other response. + // See RFC 7231 section 7.4.1: An origin server MUST generate an Allow field in a 405 (Method Not Allowed) + // response and MAY do so in any other response. For disabled resources an empty Allow header may be returned routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string) if ok && routerAllowMethods != "" { c.Response().Header().Set(HeaderAllow, routerAllowMethods) diff --git a/middleware/cors.go b/middleware/cors.go index a5122f26e..16259512a 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -172,7 +172,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { checkPatterns := false if allowOrigin == "" { // to avoid regex cost by invalid (long) domains (253 is domain name max limit) - if len(origin) <= (253+3+4) && strings.Contains(origin, "://") { + if len(origin) <= (253+3+5) && strings.Contains(origin, "://") { checkPatterns = true } } From 296c31358a08f0b83d785942778b5e1bc0ea9d5c Mon Sep 17 00:00:00 2001 From: Martti T Date: Sun, 9 Jan 2022 02:41:40 +0200 Subject: [PATCH 187/446] Add list of middlewares to readme including 3rd party projects (#2065) --- README.md | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 930cb0344..885c7bd6d 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,6 @@ [![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo) [![Build Status](http://img.shields.io/travis/labstack/echo.svg?style=flat-square)](https://travis-ci.org/labstack/echo) [![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo) -[![Join the chat at https://gitter.im/labstack/echo](https://img.shields.io/badge/gitter-join%20chat-brightgreen.svg?style=flat-square)](https://gitter.im/labstack/echo) [![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://github.com/labstack/echo/discussions) [![Twitter](https://img.shields.io/badge/twitter-@labstack-55acee.svg?style=flat-square)](https://twitter.com/labstack) [![License](http://img.shields.io/badge/license-mit-blue.svg?style=flat-square)](https://raw.githubusercontent.com/labstack/echo/master/LICENSE) @@ -92,10 +91,21 @@ func hello(c echo.Context) error { } ``` +# Third-party middlewares + +| Repository | Description | +|------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | (by Echo team) [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | +| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator | +| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | +| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | +| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | + +Please send a PR to add your own library here. + ## Help - [Forum](https://github.com/labstack/echo/discussions) -- [Chat](https://gitter.im/labstack/echo) ## Contribute From aada6f95d715491107857d3bdf2f7c5c6ba50339 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Mon, 10 Jan 2022 21:16:35 +0200 Subject: [PATCH 188/446] Fix Echo version number which was not incremented with Release 4.6.2. Now bumped to 4.6.3 --- CHANGELOG.md | 23 +++++++++++++++++++++++ echo.go | 2 +- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f52f264f3..372ed13c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,28 @@ # Changelog +## v4.6.3 - 2022-01-10 + +**Fixes** + +* Fixed Echo version number in greeting message which was not incremented to `4.6.2` [#2066](https://github.com/labstack/echo/issues/2066) + + +## v4.6.2 - 2022-01-08 + +**Fixes** + +* Fixed route containing escaped colon should be matchable but is not matched to request path [#2047](https://github.com/labstack/echo/pull/2047) +* Fixed a problem that returned wrong content-encoding when the gzip compressed content was empty. [#1921](https://github.com/labstack/echo/pull/1921) +* Update (test) dependencies [#2021](https://github.com/labstack/echo/pull/2021) + + +**Enhancements** + +* Add support for configurable target header for the request_id middleware [#2040](https://github.com/labstack/echo/pull/2040) +* Change decompress middleware to use stream decompression instead of buffering [#2018](https://github.com/labstack/echo/pull/2018) +* Documentation updates + + ## v4.6.1 - 2021-09-26 **Enhancements** diff --git a/echo.go b/echo.go index 427898217..9ae2ed4e5 100644 --- a/echo.go +++ b/echo.go @@ -246,7 +246,7 @@ const ( const ( // Version of Echo - Version = "4.6.1" + Version = "4.6.3" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 8d2c45eeff4ed39e9a9794b4869b5b164018b368 Mon Sep 17 00:00:00 2001 From: darkweak Date: Mon, 10 Jan 2022 22:23:35 +0100 Subject: [PATCH 189/446] Add Souin middleware into third-party-middlewares --- README.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 885c7bd6d..20d662c05 100644 --- a/README.md +++ b/README.md @@ -93,13 +93,14 @@ func hello(c echo.Context) error { # Third-party middlewares -| Repository | Description | -|------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Repository | Description | +|------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | (by Echo team) [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | -| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator | -| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | -| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | -| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | +| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator | +| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | +| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | +| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | +| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | Please send a PR to add your own library here. From 94638be9f8c6eeabb6cf7fb7f00a19e3fd2c430a Mon Sep 17 00:00:00 2001 From: Patrick Willner <50421879+heat1q@users.noreply.github.com> Date: Fri, 21 Jan 2022 17:32:53 +0100 Subject: [PATCH 190/446] Add Retry-After header constant --- echo.go | 1 + 1 file changed, 1 insertion(+) diff --git a/echo.go b/echo.go index 9ae2ed4e5..fc7e116f0 100644 --- a/echo.go +++ b/echo.go @@ -207,6 +207,7 @@ const ( HeaderIfModifiedSince = "If-Modified-Since" HeaderLastModified = "Last-Modified" HeaderLocation = "Location" + HeaderRetryAfter = "Retry-After" HeaderUpgrade = "Upgrade" HeaderVary = "Vary" HeaderWWWAuthenticate = "WWW-Authenticate" From db9c708124972d4869ae2a5b3c705f5650924092 Mon Sep 17 00:00:00 2001 From: mikestefanello Date: Thu, 20 Jan 2022 08:23:40 -0500 Subject: [PATCH 191/446] Add pagoda to the README. --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 20d662c05..8b2321f05 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ func hello(c echo.Context) error { | [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | | [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | | [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | +| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. Please send a PR to add your own library here. From eb371a9e64674b9399d18a1252fe91125c39022c Mon Sep 17 00:00:00 2001 From: Clement JACOB Date: Mon, 24 Jan 2022 10:28:48 +0100 Subject: [PATCH 192/446] Adding support for HEAD method query params binding (#2027) * Adding support for HEAD method query params binding. * Update comment for added HEAD method for bind Co-authored-by: Roland Lammel --- bind.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bind.go b/bind.go index fdf0524c2..c841ca010 100644 --- a/bind.go +++ b/bind.go @@ -111,11 +111,11 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { if err := b.BindPathParams(c, i); err != nil { return err } - // Issue #1670 - Query params are binded only for GET/DELETE and NOT for usual request with body (POST/PUT/PATCH) - // Reasoning here is that parameters in query and bind destination struct could have UNEXPECTED matches and results due that. - // i.e. is `&id=1&lang=en` from URL same as `{"id":100,"lang":"de"}` request body and which one should have priority when binding. - // This HTTP method check restores pre v4.1.11 behavior and avoids different problems when query is mixed with body - if c.Request().Method == http.MethodGet || c.Request().Method == http.MethodDelete { + // Only bind query parameters for GET/DELETE/HEAD to avoid unexpected behavior with destination struct binding from body. + // For example a request URL `&id=1&lang=en` with body `{"id":100,"lang":"de"}` would lead to precedence issues. + // The HTTP method check restores pre-v4.1.11 behavior to avoid these problems (see issue #1670) + method := c.Request().Method + if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead { if err = b.BindQueryParams(c, i); err != nil { return err } From f3865f9aa539c1eddd80fdb371d4712a1296bd7b Mon Sep 17 00:00:00 2001 From: sivchari Date: Mon, 24 Jan 2022 18:33:13 +0900 Subject: [PATCH 193/446] Minor syntax fixes (#1994) --- echo.go | 2 +- router.go | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/echo.go b/echo.go index fc7e116f0..d067b8966 100644 --- a/echo.go +++ b/echo.go @@ -654,7 +654,7 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Acquire context c := e.pool.Get().(*context) c.Reset(r, w) - h := NotFoundHandler + var h func(Context) error if e.premiddleware == nil { e.findRouter(r.Host).Find(r.Method, GetPath(r), c) diff --git a/router.go b/router.go index 1a2ce561f..a1de2d6e3 100644 --- a/router.go +++ b/router.go @@ -624,6 +624,4 @@ func (r *Router) Find(method, path string, c Context) { } ctx.path = currentNode.ppath ctx.pnames = currentNode.pnames - - return } From 7c41b93f0c8164c9354be374f670635c1598398b Mon Sep 17 00:00:00 2001 From: ant1k9 <56701963+ant1k9@users.noreply.github.com> Date: Mon, 24 Jan 2022 13:23:41 +0300 Subject: [PATCH 194/446] Add LogErrorFunc to recover middleware (#2072) LogErrorFunc provides more general interface to handle errors in the recover middleware. --- middleware/recover.go | 21 +++++++++++++-- middleware/recover_test.go | 53 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/middleware/recover.go b/middleware/recover.go index 0dbe740da..a621a9efe 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -9,6 +9,9 @@ import ( ) type ( + // LogErrorFunc defines a function for custom logging in the middleware. + LogErrorFunc func(c echo.Context, err error, stack []byte) error + // RecoverConfig defines the config for Recover middleware. RecoverConfig struct { // Skipper defines a function to skip middleware. @@ -30,6 +33,10 @@ type ( // LogLevel is log level to printing stack trace. // Optional. Default value 0 (Print). LogLevel log.Lvl + + // LogErrorFunc defines a function for custom logging in the middleware. + // If it's set you don't need to provide LogLevel for config. + LogErrorFunc LogErrorFunc } ) @@ -41,6 +48,7 @@ var ( DisableStackAll: false, DisablePrintStack: false, LogLevel: 0, + LogErrorFunc: nil, } ) @@ -73,9 +81,18 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { if !ok { err = fmt.Errorf("%v", r) } - stack := make([]byte, config.StackSize) - length := runtime.Stack(stack, !config.DisableStackAll) + var stack []byte + var length int + if !config.DisablePrintStack { + stack = make([]byte, config.StackSize) + length = runtime.Stack(stack, !config.DisableStackAll) + stack = stack[:length] + } + + if config.LogErrorFunc != nil { + err = config.LogErrorFunc(c, err, stack) + } else if !config.DisablePrintStack { msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length]) switch config.LogLevel { case log.DEBUG: diff --git a/middleware/recover_test.go b/middleware/recover_test.go index 644332972..9ac4feedc 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -2,6 +2,7 @@ package middleware import ( "bytes" + "errors" "fmt" "net/http" "net/http/httptest" @@ -81,3 +82,55 @@ func TestRecoverWithConfig_LogLevel(t *testing.T) { }) } } + +func TestRecoverWithConfig_LogErrorFunc(t *testing.T) { + e := echo.New() + e.Logger.SetLevel(log.DEBUG) + + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + testError := errors.New("test") + config := DefaultRecoverConfig + config.LogErrorFunc = func(c echo.Context, err error, stack []byte) error { + msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack) + if errors.Is(err, testError) { + c.Logger().Debug(msg) + } else { + c.Logger().Error(msg) + } + return err + } + + t.Run("first branch case for LogErrorFunc", func(t *testing.T) { + buf.Reset() + h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { + panic(testError) + })) + + h(c) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + + output := buf.String() + assert.Contains(t, output, "PANIC RECOVER") + assert.Contains(t, output, `"level":"DEBUG"`) + }) + + t.Run("else branch case for LogErrorFunc", func(t *testing.T) { + buf.Reset() + h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { + panic("other") + })) + + h(c) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + + output := buf.String() + assert.Contains(t, output, "PANIC RECOVER") + assert.Contains(t, output, `"level":"ERROR"`) + }) +} From 1b1a68fd4f9315fd73e0260ffd650fb1ace6b9b8 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 8 Jan 2022 22:41:34 +0200 Subject: [PATCH 195/446] Improve filesystem support (Go 1.16+). Add field echo.Filesystem, methods: echo.FileFS, echo.StaticFS, group.FileFS, group.StaticFS. Following methods will use echo.Filesystem to server files: echo.File, echo.Static, group.File, group.Static, Context.File --- context.go | 25 ---- context_fs.go | 33 +++++ context_fs_go1.16.go | 47 +++++++ context_fs_go1.16_test.go | 135 ++++++++++++++++++++ echo.go | 53 +------- echo_fs.go | 62 ++++++++++ echo_fs_go1.16.go | 126 +++++++++++++++++++ echo_fs_go1.16_test.go | 251 ++++++++++++++++++++++++++++++++++++++ echo_test.go | 17 +-- go.mod | 2 +- go.sum | 11 +- group.go | 5 - group_fs.go | 9 ++ group_fs_go1.16.go | 34 ++++++ group_fs_go1.16_test.go | 106 ++++++++++++++++ 15 files changed, 819 insertions(+), 97 deletions(-) create mode 100644 context_fs.go create mode 100644 context_fs_go1.16.go create mode 100644 context_fs_go1.16_test.go create mode 100644 echo_fs.go create mode 100644 echo_fs_go1.16.go create mode 100644 echo_fs_go1.16_test.go create mode 100644 group_fs.go create mode 100644 group_fs_go1.16.go create mode 100644 group_fs_go1.16_test.go diff --git a/context.go b/context.go index f2421d77b..a4ecfadfc 100644 --- a/context.go +++ b/context.go @@ -9,8 +9,6 @@ import ( "net" "net/http" "net/url" - "os" - "path/filepath" "strings" "sync" ) @@ -569,29 +567,6 @@ func (c *context) Stream(code int, contentType string, r io.Reader) (err error) return } -func (c *context) File(file string) (err error) { - f, err := os.Open(file) - if err != nil { - return NotFoundHandler(c) - } - defer f.Close() - - fi, _ := f.Stat() - if fi.IsDir() { - file = filepath.Join(file, indexPage) - f, err = os.Open(file) - if err != nil { - return NotFoundHandler(c) - } - defer f.Close() - if fi, err = f.Stat(); err != nil { - return - } - } - http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f) - return -} - func (c *context) Attachment(file, name string) error { return c.contentDisposition(file, name, "attachment") } diff --git a/context_fs.go b/context_fs.go new file mode 100644 index 000000000..11ee84bcd --- /dev/null +++ b/context_fs.go @@ -0,0 +1,33 @@ +//go:build !go1.16 +// +build !go1.16 + +package echo + +import ( + "net/http" + "os" + "path/filepath" +) + +func (c *context) File(file string) (err error) { + f, err := os.Open(file) + if err != nil { + return NotFoundHandler(c) + } + defer f.Close() + + fi, _ := f.Stat() + if fi.IsDir() { + file = filepath.Join(file, indexPage) + f, err = os.Open(file) + if err != nil { + return NotFoundHandler(c) + } + defer f.Close() + if fi, err = f.Stat(); err != nil { + return + } + } + http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f) + return +} diff --git a/context_fs_go1.16.go b/context_fs_go1.16.go new file mode 100644 index 000000000..eeffef507 --- /dev/null +++ b/context_fs_go1.16.go @@ -0,0 +1,47 @@ +//go:build go1.16 +// +build go1.16 + +package echo + +import ( + "errors" + "io" + "io/fs" + "net/http" + "path/filepath" +) + +func (c *context) File(file string) error { + return fsFile(c, file, c.echo.Filesystem) +} + +func (c *context) FileFS(file string, filesystem fs.FS) error { + return fsFile(c, file, filesystem) +} + +func fsFile(c Context, file string, filesystem fs.FS) error { + f, err := filesystem.Open(file) + if err != nil { + return ErrNotFound + } + defer f.Close() + + fi, _ := f.Stat() + if fi.IsDir() { + file = filepath.Join(file, indexPage) + f, err = filesystem.Open(file) + if err != nil { + return ErrNotFound + } + defer f.Close() + if fi, err = f.Stat(); err != nil { + return err + } + } + ff, ok := f.(io.ReadSeeker) + if !ok { + return errors.New("file does not implement io.ReadSeeker") + } + http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) + return nil +} diff --git a/context_fs_go1.16_test.go b/context_fs_go1.16_test.go new file mode 100644 index 000000000..f209e8a06 --- /dev/null +++ b/context_fs_go1.16_test.go @@ -0,0 +1,135 @@ +//go:build go1.16 +// +build go1.16 + +package echo + +import ( + testify "github.com/stretchr/testify/assert" + "io/fs" + "net/http" + "net/http/httptest" + "os" + "testing" +) + +func TestContext_File(t *testing.T) { + var testCases = []struct { + name string + whenFile string + whenFS fs.FS + expectStatus int + expectStartsWith []byte + expectError string + }{ + { + name: "ok, from default file system", + whenFile: "_fixture/images/walle.png", + whenFS: nil, + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "ok, from custom file system", + whenFile: "walle.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, not existent file", + whenFile: "not.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: nil, + expectError: "code=404, message=Not Found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + if tc.whenFS != nil { + e.Filesystem = tc.whenFS + } + + handler := func(ec Context) error { + return ec.(*context).File(tc.whenFile) + } + + req := httptest.NewRequest(http.MethodGet, "/match.png", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + + testify.Equal(t, tc.expectStatus, rec.Code) + if tc.expectError != "" { + testify.EqualError(t, err, tc.expectError) + } else { + testify.NoError(t, err) + } + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + testify.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestContext_FileFS(t *testing.T) { + var testCases = []struct { + name string + whenFile string + whenFS fs.FS + expectStatus int + expectStartsWith []byte + expectError string + }{ + { + name: "ok", + whenFile: "walle.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, not existent file", + whenFile: "not.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: nil, + expectError: "code=404, message=Not Found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + handler := func(ec Context) error { + return ec.(*context).FileFS(tc.whenFile, tc.whenFS) + } + + req := httptest.NewRequest(http.MethodGet, "/match.png", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + + testify.Equal(t, tc.expectStatus, rec.Code) + if tc.expectError != "" { + testify.EqualError(t, err, tc.expectError) + } else { + testify.NoError(t, err) + } + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + testify.Equal(t, tc.expectStartsWith, body) + }) + } +} diff --git a/echo.go b/echo.go index d067b8966..56255c6cb 100644 --- a/echo.go +++ b/echo.go @@ -47,9 +47,6 @@ import ( stdLog "log" "net" "net/http" - "net/url" - "os" - "path/filepath" "reflect" "runtime" "sync" @@ -66,6 +63,7 @@ import ( type ( // Echo is the top-level framework instance. Echo struct { + filesystem common // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get // listener address info (on which interface/port was listener binded) without having data races. @@ -320,8 +318,9 @@ var ( // New creates an instance of Echo. func New() (e *Echo) { e = &Echo{ - Server: new(http.Server), - TLSServer: new(http.Server), + filesystem: createFilesystem(), + Server: new(http.Server), + TLSServer: new(http.Server), AutoTLSManager: autocert.Manager{ Prompt: autocert.AcceptTOS, }, @@ -500,50 +499,6 @@ func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middlew return routes } -// Static registers a new route with path prefix to serve static files from the -// provided root directory. -func (e *Echo) Static(prefix, root string) *Route { - if root == "" { - root = "." // For security we want to restrict to CWD. - } - return e.static(prefix, root, e.GET) -} - -func (common) static(prefix, root string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route) *Route { - h := func(c Context) error { - p, err := url.PathUnescape(c.Param("*")) - if err != nil { - return err - } - - name := filepath.Join(root, filepath.Clean("/"+p)) // "/"+ for security - fi, err := os.Stat(name) - if err != nil { - // The access path does not exist - return NotFoundHandler(c) - } - - // If the request is for a directory and does not end with "/" - p = c.Request().URL.Path // path must not be empty. - if fi.IsDir() && p[len(p)-1] != '/' { - // Redirect to ends with "/" - return c.Redirect(http.StatusMovedPermanently, p+"/") - } - return c.File(name) - } - // Handle added routes based on trailing slash: - // /prefix => exact route "/prefix" + any route "/prefix/*" - // /prefix/ => only any route "/prefix/*" - if prefix != "" { - if prefix[len(prefix)-1] == '/' { - // Only add any route for intentional trailing slash - return get(prefix+"*", h) - } - get(prefix, h) - } - return get(prefix+"/*", h) -} - func (common) file(path, file string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route, m ...MiddlewareFunc) *Route { return get(path, func(c Context) error { diff --git a/echo_fs.go b/echo_fs.go new file mode 100644 index 000000000..c3790545a --- /dev/null +++ b/echo_fs.go @@ -0,0 +1,62 @@ +//go:build !go1.16 +// +build !go1.16 + +package echo + +import ( + "net/http" + "net/url" + "os" + "path/filepath" +) + +type filesystem struct { +} + +func createFilesystem() filesystem { + return filesystem{} +} + +// Static registers a new route with path prefix to serve static files from the +// provided root directory. +func (e *Echo) Static(prefix, root string) *Route { + if root == "" { + root = "." // For security we want to restrict to CWD. + } + return e.static(prefix, root, e.GET) +} + +func (common) static(prefix, root string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route) *Route { + h := func(c Context) error { + p, err := url.PathUnescape(c.Param("*")) + if err != nil { + return err + } + + name := filepath.Join(root, filepath.Clean("/"+p)) // "/"+ for security + fi, err := os.Stat(name) + if err != nil { + // The access path does not exist + return NotFoundHandler(c) + } + + // If the request is for a directory and does not end with "/" + p = c.Request().URL.Path // path must not be empty. + if fi.IsDir() && p[len(p)-1] != '/' { + // Redirect to ends with "/" + return c.Redirect(http.StatusMovedPermanently, p+"/") + } + return c.File(name) + } + // Handle added routes based on trailing slash: + // /prefix => exact route "/prefix" + any route "/prefix/*" + // /prefix/ => only any route "/prefix/*" + if prefix != "" { + if prefix[len(prefix)-1] == '/' { + // Only add any route for intentional trailing slash + return get(prefix+"*", h) + } + get(prefix, h) + } + return get(prefix+"/*", h) +} diff --git a/echo_fs_go1.16.go b/echo_fs_go1.16.go new file mode 100644 index 000000000..b4258e367 --- /dev/null +++ b/echo_fs_go1.16.go @@ -0,0 +1,126 @@ +//go:build go1.16 +// +build go1.16 + +package echo + +import ( + "fmt" + "io/fs" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" +) + +type filesystem struct { + // Filesystem is file system used by Static and File handlers to access files. + // Defaults to os.DirFS(".") + Filesystem fs.FS +} + +func createFilesystem() filesystem { + return filesystem{ + Filesystem: newDefaultFS(), + } +} + +// Static registers a new route with path prefix to serve static files from the provided root directory. +func (e *Echo) Static(pathPrefix, root string) *Route { + subFs, err := subFS(e.Filesystem, root) + if err != nil { + // happens when `root` contains invalid path according to `fs.ValidPath` rules and we are unable to create FS + panic(fmt.Errorf("invalid root given to echo.Static, err %w", err)) + } + return e.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(subFs, false), + ) +} + +// StaticFS registers a new route with path prefix to serve static files from the provided file system. +func (e *Echo) StaticFS(pathPrefix string, fileSystem fs.FS) *Route { + return e.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(fileSystem, false), + ) +} + +// StaticDirectoryHandler creates handler function to serve files from provided file system +// When disablePathUnescaping is set then file name from path is not unescaped and is served as is. +func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc { + return func(c Context) error { + p := c.Param("*") + if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice + tmpPath, err := url.PathUnescape(p) + if err != nil { + return fmt.Errorf("failed to unescape path variable: %w", err) + } + p = tmpPath + } + + // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid + name := filepath.Clean(strings.TrimPrefix(p, "/")) + fi, err := fs.Stat(fileSystem, name) + if err != nil { + return ErrNotFound + } + + // If the request is for a directory and does not end with "/" + p = c.Request().URL.Path // path must not be empty. + if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { + // Redirect to ends with "/" + return c.Redirect(http.StatusMovedPermanently, p+"/") + } + return fsFile(c, name, fileSystem) + } +} + +// FileFS registers a new route with path to serve file from the provided file system. +func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { + return e.GET(path, StaticFileHandler(file, filesystem), m...) +} + +// StaticFileHandler creates handler function to serve file from provided file system +func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { + return func(c Context) error { + return fsFile(c, file, filesystem) + } +} + +// defaultFS emulates os.Open behaviour with filesystem opened by `os.DirFs`. Difference between `os.Open` and `fs.Open` +// is that FS does not allow to open path that start with `..` or `/` etc. For example previously you could have `../images` +// in your application but `fs := os.DirFS("./")` would not allow you to use `fs.Open("../images")` and this would break +// all old applications that rely on being able to traverse up from current executable run path. +// NB: private because you really should use fs.FS implementation instances +type defaultFS struct { + prefix string + fs fs.FS +} + +func newDefaultFS() *defaultFS { + dir, _ := os.Getwd() + return &defaultFS{ + prefix: dir, + fs: os.DirFS(dir), + } +} + +func (fs defaultFS) Open(name string) (fs.File, error) { + return fs.fs.Open(name) +} + +func subFS(currentFs fs.FS, root string) (fs.FS, error) { + if dFS, ok := currentFs.(*defaultFS); ok { + // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS to + // allow cases when root is given as `../somepath` which is not valid for fs.FS + root = filepath.Join(dFS.prefix, root) + return &defaultFS{ + prefix: root, + fs: os.DirFS(root), + }, nil + } + return fs.Sub(currentFs, filepath.Clean(root)) +} diff --git a/echo_fs_go1.16_test.go b/echo_fs_go1.16_test.go new file mode 100644 index 000000000..4a95b105c --- /dev/null +++ b/echo_fs_go1.16_test.go @@ -0,0 +1,251 @@ +//go:build go1.16 +// +build go1.16 + +package echo + +import ( + "github.com/stretchr/testify/assert" + "io/fs" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" +) + +func TestEcho_StaticFS(t *testing.T) { + var testCases = []struct { + name string + givenPrefix string + givenFs fs.FS + whenURL string + expectStatus int + expectHeaderLocation string + expectBodyStartsWith string + }{ + { + name: "ok", + givenPrefix: "/images", + givenFs: os.DirFS("./_fixture/images"), + whenURL: "/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "No file", + givenPrefix: "/images", + givenFs: os.DirFS("_fixture/scripts"), + whenURL: "/images/bolt.png", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory", + givenPrefix: "/images", + givenFs: os.DirFS("_fixture/images"), + whenURL: "/images/", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory Redirect", + givenPrefix: "/", + givenFs: os.DirFS("_fixture/"), + whenURL: "/folder", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory Redirect with non-root path", + givenPrefix: "/static", + givenFs: os.DirFS("_fixture"), + whenURL: "/static", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/static/", + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory 404 (request URL without slash)", + givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" + givenFs: os.DirFS("_fixture"), + whenURL: "/folder", // no trailing slash + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Prefixed directory redirect (without slash redirect to slash)", + givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* + givenFs: os.DirFS("_fixture"), + whenURL: "/folder", // no trailing slash + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory with index.html", + givenPrefix: "/", + givenFs: os.DirFS("_fixture"), + whenURL: "/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending with slash)", + givenPrefix: "/assets/", + givenFs: os.DirFS("_fixture"), + whenURL: "/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending without slash)", + givenPrefix: "/assets", + givenFs: os.DirFS("_fixture"), + whenURL: "/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Sub-directory with index.html", + givenPrefix: "/", + givenFs: os.DirFS("_fixture"), + whenURL: "/folder/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "do not allow directory traversal (backslash - windows separator)", + givenPrefix: "/", + givenFs: os.DirFS("_fixture/"), + whenURL: `/..\\middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "do not allow directory traversal (slash - unix separator)", + givenPrefix: "/", + givenFs: os.DirFS("_fixture/"), + whenURL: `/../middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.StaticFS(tc.givenPrefix, tc.givenFs) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectStatus, rec.Code) + body := rec.Body.String() + if tc.expectBodyStartsWith != "" { + assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) + } else { + assert.Equal(t, "", body) + } + + if tc.expectHeaderLocation != "" { + assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) + } else { + _, ok := rec.Result().Header["Location"] + assert.False(t, ok) + } + }) + } +} + +func TestEcho_FileFS(t *testing.T) { + var testCases = []struct { + name string + whenPath string + whenFile string + whenFS fs.FS + givenURL string + expectCode int + expectStartsWith []byte + }{ + { + name: "ok", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/walle", + expectCode: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, requesting invalid path", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/walle.png", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + { + name: "nok, serving not existent file from filesystem", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "not-existent.png", + givenURL: "/walle", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) + + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestEcho_StaticPanic(t *testing.T) { + var testCases = []struct { + name string + givenRoot string + expectError string + }{ + { + name: "panics for ../", + givenRoot: "../assets", + expectError: "invalid root given to echo.Static, err sub ../assets: invalid name", + }, + { + name: "panics for /", + givenRoot: "/assets", + expectError: "invalid root given to echo.Static, err sub /assets: invalid name", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Filesystem = os.DirFS("./") + + assert.PanicsWithError(t, tc.expectError, func() { + e.Static("/assets", tc.givenRoot) + }) + }) + } +} diff --git a/echo_test.go b/echo_test.go index 13a51b6cc..f175d765b 100644 --- a/echo_test.go +++ b/echo_test.go @@ -211,7 +211,6 @@ func TestEchoStatic(t *testing.T) { } func TestEchoStaticRedirectIndex(t *testing.T) { - assert := assert.New(t) e := New() // HandlerFunc @@ -220,23 +219,25 @@ func TestEchoStaticRedirectIndex(t *testing.T) { errCh := make(chan error) go func() { - errCh <- e.Start("127.0.0.1:1323") + errCh <- e.Start(":0") }() - time.Sleep(200 * time.Millisecond) + err := waitForServerStart(e, errCh, false) + assert.NoError(t, err) - if resp, err := http.Get("http://127.0.0.1:1323/static"); err == nil { + addr := e.ListenerAddr().String() + if resp, err := http.Get("http://" + addr + "/static"); err == nil { // http.Get follows redirects by default defer resp.Body.Close() - assert.Equal(http.StatusOK, resp.StatusCode) + assert.Equal(t, http.StatusOK, resp.StatusCode) if body, err := ioutil.ReadAll(resp.Body); err == nil { - assert.Equal(true, strings.HasPrefix(string(body), "")) + assert.Equal(t, true, strings.HasPrefix(string(body), "")) } else { - assert.Fail(err.Error()) + assert.Fail(t, err.Error()) } } else { - assert.Fail(err.Error()) + assert.NoError(t, err) } if err := e.Close(); err != nil { diff --git a/go.mod b/go.mod index e5fa0d55f..80087d6f9 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/stretchr/testify v1.7.0 github.com/valyala/fasttemplate v1.2.1 golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 - golang.org/x/net v0.0.0-20210913180222-943fd674d43e + golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f golang.org/x/text v0.3.7 // indirect golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 ) diff --git a/go.sum b/go.sum index 8a1ec2f9f..f66734243 100644 --- a/go.sum +++ b/go.sum @@ -20,18 +20,13 @@ github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52 github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210913180222-943fd674d43e h1:+b/22bPvDYt4NPDcy4xAGCmON713ONAWFeY3Z7I3tR8= -golang.org/x/net v0.0.0-20210913180222-943fd674d43e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20211104170005-ce137452f963 h1:8gJUadZl+kWvZBqG/LautX0X6qe5qTC2VI/3V3NBRAY= -golang.org/x/net v0.0.0-20211104170005-ce137452f963/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f h1:OfiFi4JbukWwe3lzw+xunroH1mnC1e2Gy5cxNJApiSY= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -42,8 +37,6 @@ golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac h1:7zkz7BUtwNFFqcowJ+RIgu2MaV/MapERkDIy+mwPyjs= -golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/group.go b/group.go index 426bef9eb..bba470ce8 100644 --- a/group.go +++ b/group.go @@ -102,11 +102,6 @@ func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) { return } -// Static implements `Echo#Static()` for sub-routes within the Group. -func (g *Group) Static(prefix, root string) { - g.static(prefix, root, g.GET) -} - // File implements `Echo#File()` for sub-routes within the Group. func (g *Group) File(path, file string) { g.file(path, file, g.GET) diff --git a/group_fs.go b/group_fs.go new file mode 100644 index 000000000..0a1ce4a94 --- /dev/null +++ b/group_fs.go @@ -0,0 +1,9 @@ +//go:build !go1.16 +// +build !go1.16 + +package echo + +// Static implements `Echo#Static()` for sub-routes within the Group. +func (g *Group) Static(prefix, root string) { + g.static(prefix, root, g.GET) +} diff --git a/group_fs_go1.16.go b/group_fs_go1.16.go new file mode 100644 index 000000000..e276c80ca --- /dev/null +++ b/group_fs_go1.16.go @@ -0,0 +1,34 @@ +//go:build go1.16 +// +build go1.16 + +package echo + +import ( + "fmt" + "io/fs" + "net/http" +) + +// Static implements `Echo#Static()` for sub-routes within the Group. +func (g *Group) Static(pathPrefix, root string) { + subFs, err := subFS(g.echo.Filesystem, root) + if err != nil { + // happens when `root` contains invalid path according to `fs.ValidPath` rules and we are unable to create FS + panic(fmt.Errorf("invalid root given to group.Static, err %w", err)) + } + g.StaticFS(pathPrefix, subFs) +} + +// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. +func (g *Group) StaticFS(pathPrefix string, fileSystem fs.FS) { + g.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(fileSystem, false), + ) +} + +// FileFS implements `Echo#FileFS()` for sub-routes within the Group. +func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { + return g.GET(path, StaticFileHandler(file, filesystem), m...) +} diff --git a/group_fs_go1.16_test.go b/group_fs_go1.16_test.go new file mode 100644 index 000000000..8fabfa1ec --- /dev/null +++ b/group_fs_go1.16_test.go @@ -0,0 +1,106 @@ +//go:build go1.16 +// +build go1.16 + +package echo + +import ( + "github.com/stretchr/testify/assert" + "io/fs" + "net/http" + "net/http/httptest" + "os" + "testing" +) + +func TestGroup_FileFS(t *testing.T) { + var testCases = []struct { + name string + whenPath string + whenFile string + whenFS fs.FS + givenURL string + expectCode int + expectStartsWith []byte + }{ + { + name: "ok", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/assets/walle", + expectCode: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, requesting invalid path", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/assets/walle.png", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + { + name: "nok, serving not existent file from filesystem", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "not-existent.png", + givenURL: "/assets/walle", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + g := e.Group("/assets") + g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) + + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestGroup_StaticPanic(t *testing.T) { + var testCases = []struct { + name string + givenRoot string + expectError string + }{ + { + name: "panics for ../", + givenRoot: "../images", + expectError: "invalid root given to group.Static, err sub ../images: invalid name", + }, + { + name: "panics for /", + givenRoot: "/images", + expectError: "invalid root given to group.Static, err sub /images: invalid name", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Filesystem = os.DirFS("./") + + g := e.Group("/assets") + + assert.PanicsWithError(t, tc.expectError, func() { + g.Static("/images", tc.givenRoot) + }) + }) + } +} From af2a49dbbcf35aea0a5d4a6bb23bde5e98773d48 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Tue, 11 Jan 2022 20:43:02 +0200 Subject: [PATCH 196/446] Fix fs.Sub problems on Windows --- context_fs_go1.16.go | 2 +- echo_fs_go1.16.go | 5 +++-- echo_fs_go1.16_test.go | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/context_fs_go1.16.go b/context_fs_go1.16.go index eeffef507..006848f78 100644 --- a/context_fs_go1.16.go +++ b/context_fs_go1.16.go @@ -28,7 +28,7 @@ func fsFile(c Context, file string, filesystem fs.FS) error { fi, _ := f.Stat() if fi.IsDir() { - file = filepath.Join(file, indexPage) + file = filepath.ToSlash(filepath.Join(file, indexPage)) f, err = filesystem.Open(file) if err != nil { return ErrNotFound diff --git a/echo_fs_go1.16.go b/echo_fs_go1.16.go index b4258e367..1b412bc41 100644 --- a/echo_fs_go1.16.go +++ b/echo_fs_go1.16.go @@ -62,7 +62,7 @@ func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) Handle } // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid - name := filepath.Clean(strings.TrimPrefix(p, "/")) + name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/"))) fi, err := fs.Stat(fileSystem, name) if err != nil { return ErrNotFound @@ -113,6 +113,7 @@ func (fs defaultFS) Open(name string) (fs.File, error) { } func subFS(currentFs fs.FS, root string) (fs.FS, error) { + root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows if dFS, ok := currentFs.(*defaultFS); ok { // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS to // allow cases when root is given as `../somepath` which is not valid for fs.FS @@ -122,5 +123,5 @@ func subFS(currentFs fs.FS, root string) (fs.FS, error) { fs: os.DirFS(root), }, nil } - return fs.Sub(currentFs, filepath.Clean(root)) + return fs.Sub(currentFs, root) } diff --git a/echo_fs_go1.16_test.go b/echo_fs_go1.16_test.go index 4a95b105c..715c69ef3 100644 --- a/echo_fs_go1.16_test.go +++ b/echo_fs_go1.16_test.go @@ -244,7 +244,7 @@ func TestEcho_StaticPanic(t *testing.T) { e.Filesystem = os.DirFS("./") assert.PanicsWithError(t, tc.expectError, func() { - e.Static("/assets", tc.givenRoot) + e.Static("../assets", tc.givenRoot) }) }) } From b830c4ef959b1338227ab28531fedf34cec348ff Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 15 Jan 2022 11:04:07 +0200 Subject: [PATCH 197/446] Improve filesystem support. --- context_fs_go1.16.go | 7 ++++++- context_fs_go1.16_test.go | 18 +++++++++--------- echo_fs_go1.16.go | 34 ++++++++++++++++++++++++++-------- echo_fs_go1.16_test.go | 20 +++++++++++++++++--- group_fs_go1.16.go | 17 ++++++++--------- group_fs_go1.16_test.go | 4 ++-- 6 files changed, 68 insertions(+), 32 deletions(-) diff --git a/context_fs_go1.16.go b/context_fs_go1.16.go index 006848f78..77878c6aa 100644 --- a/context_fs_go1.16.go +++ b/context_fs_go1.16.go @@ -15,6 +15,11 @@ func (c *context) File(file string) error { return fsFile(c, file, c.echo.Filesystem) } +// FileFS serves file from given file system. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. func (c *context) FileFS(file string, filesystem fs.FS) error { return fsFile(c, file, filesystem) } @@ -28,7 +33,7 @@ func fsFile(c Context, file string, filesystem fs.FS) error { fi, _ := f.Stat() if fi.IsDir() { - file = filepath.ToSlash(filepath.Join(file, indexPage)) + file = filepath.Join(file, indexPage) f, err = filesystem.Open(file) if err != nil { return ErrNotFound diff --git a/context_fs_go1.16_test.go b/context_fs_go1.16_test.go index f209e8a06..027d1c483 100644 --- a/context_fs_go1.16_test.go +++ b/context_fs_go1.16_test.go @@ -4,7 +4,7 @@ package echo import ( - testify "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" "io/fs" "net/http" "net/http/httptest" @@ -62,18 +62,18 @@ func TestContext_File(t *testing.T) { err := handler(c) - testify.Equal(t, tc.expectStatus, rec.Code) + assert.Equal(t, tc.expectStatus, rec.Code) if tc.expectError != "" { - testify.EqualError(t, err, tc.expectError) + assert.EqualError(t, err, tc.expectError) } else { - testify.NoError(t, err) + assert.NoError(t, err) } body := rec.Body.Bytes() if len(body) > len(tc.expectStartsWith) { body = body[:len(tc.expectStartsWith)] } - testify.Equal(t, tc.expectStartsWith, body) + assert.Equal(t, tc.expectStartsWith, body) }) } } @@ -118,18 +118,18 @@ func TestContext_FileFS(t *testing.T) { err := handler(c) - testify.Equal(t, tc.expectStatus, rec.Code) + assert.Equal(t, tc.expectStatus, rec.Code) if tc.expectError != "" { - testify.EqualError(t, err, tc.expectError) + assert.EqualError(t, err, tc.expectError) } else { - testify.NoError(t, err) + assert.NoError(t, err) } body := rec.Body.Bytes() if len(body) > len(tc.expectStartsWith) { body = body[:len(tc.expectStartsWith)] } - testify.Equal(t, tc.expectStartsWith, body) + assert.Equal(t, tc.expectStartsWith, body) }) } } diff --git a/echo_fs_go1.16.go b/echo_fs_go1.16.go index 1b412bc41..435459de2 100644 --- a/echo_fs_go1.16.go +++ b/echo_fs_go1.16.go @@ -16,6 +16,10 @@ import ( type filesystem struct { // Filesystem is file system used by Static and File handlers to access files. // Defaults to os.DirFS(".") + // + // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary + // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths + // including `assets/images` as their prefix. Filesystem fs.FS } @@ -26,12 +30,8 @@ func createFilesystem() filesystem { } // Static registers a new route with path prefix to serve static files from the provided root directory. -func (e *Echo) Static(pathPrefix, root string) *Route { - subFs, err := subFS(e.Filesystem, root) - if err != nil { - // happens when `root` contains invalid path according to `fs.ValidPath` rules and we are unable to create FS - panic(fmt.Errorf("invalid root given to echo.Static, err %w", err)) - } +func (e *Echo) Static(pathPrefix, fsRoot string) *Route { + subFs := MustSubFS(e.Filesystem, fsRoot) return e.Add( http.MethodGet, pathPrefix+"*", @@ -40,11 +40,15 @@ func (e *Echo) Static(pathPrefix, root string) *Route { } // StaticFS registers a new route with path prefix to serve static files from the provided file system. -func (e *Echo) StaticFS(pathPrefix string, fileSystem fs.FS) *Route { +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route { return e.Add( http.MethodGet, pathPrefix+"*", - StaticDirectoryHandler(fileSystem, false), + StaticDirectoryHandler(filesystem, false), ) } @@ -125,3 +129,17 @@ func subFS(currentFs fs.FS, root string) (fs.FS, error) { } return fs.Sub(currentFs, root) } + +// MustSubFS creates sub FS from current filesystem or panic on failure. +// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. +// +// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with +// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to +// create sub fs which uses necessary prefix for directory path. +func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { + subFs, err := subFS(currentFs, fsRoot) + if err != nil { + panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err)) + } + return subFs +} diff --git a/echo_fs_go1.16_test.go b/echo_fs_go1.16_test.go index 715c69ef3..07e516555 100644 --- a/echo_fs_go1.16_test.go +++ b/echo_fs_go1.16_test.go @@ -18,6 +18,7 @@ func TestEcho_StaticFS(t *testing.T) { name string givenPrefix string givenFs fs.FS + givenFsRoot string whenURL string expectStatus int expectHeaderLocation string @@ -31,6 +32,14 @@ func TestEcho_StaticFS(t *testing.T) { expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), }, + { + name: "ok, from sub fs", + givenPrefix: "/images", + givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"), + whenURL: "/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, { name: "No file", givenPrefix: "/images", @@ -135,7 +144,12 @@ func TestEcho_StaticFS(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() - e.StaticFS(tc.givenPrefix, tc.givenFs) + + tmpFs := tc.givenFs + if tc.givenFsRoot != "" { + tmpFs = MustSubFS(tmpFs, tc.givenFsRoot) + } + e.StaticFS(tc.givenPrefix, tmpFs) req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) rec := httptest.NewRecorder() @@ -229,12 +243,12 @@ func TestEcho_StaticPanic(t *testing.T) { { name: "panics for ../", givenRoot: "../assets", - expectError: "invalid root given to echo.Static, err sub ../assets: invalid name", + expectError: "can not create sub FS, invalid root given, err: sub ../assets: invalid name", }, { name: "panics for /", givenRoot: "/assets", - expectError: "invalid root given to echo.Static, err sub /assets: invalid name", + expectError: "can not create sub FS, invalid root given, err: sub /assets: invalid name", }, } diff --git a/group_fs_go1.16.go b/group_fs_go1.16.go index e276c80ca..2ba52b5e2 100644 --- a/group_fs_go1.16.go +++ b/group_fs_go1.16.go @@ -4,27 +4,26 @@ package echo import ( - "fmt" "io/fs" "net/http" ) // Static implements `Echo#Static()` for sub-routes within the Group. -func (g *Group) Static(pathPrefix, root string) { - subFs, err := subFS(g.echo.Filesystem, root) - if err != nil { - // happens when `root` contains invalid path according to `fs.ValidPath` rules and we are unable to create FS - panic(fmt.Errorf("invalid root given to group.Static, err %w", err)) - } +func (g *Group) Static(pathPrefix, fsRoot string) { + subFs := MustSubFS(g.echo.Filesystem, fsRoot) g.StaticFS(pathPrefix, subFs) } // StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. -func (g *Group) StaticFS(pathPrefix string, fileSystem fs.FS) { +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) { g.Add( http.MethodGet, pathPrefix+"*", - StaticDirectoryHandler(fileSystem, false), + StaticDirectoryHandler(filesystem, false), ) } diff --git a/group_fs_go1.16_test.go b/group_fs_go1.16_test.go index 8fabfa1ec..d0caa33db 100644 --- a/group_fs_go1.16_test.go +++ b/group_fs_go1.16_test.go @@ -82,12 +82,12 @@ func TestGroup_StaticPanic(t *testing.T) { { name: "panics for ../", givenRoot: "../images", - expectError: "invalid root given to group.Static, err sub ../images: invalid name", + expectError: "can not create sub FS, invalid root given, err: sub ../images: invalid name", }, { name: "panics for /", givenRoot: "/images", - expectError: "invalid root given to group.Static, err sub /images: invalid name", + expectError: "can not create sub FS, invalid root given, err: sub /images: invalid name", }, } From db5bace1c4ec720e63f43615dc471a0e9a4609cf Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 15 Jan 2022 11:11:04 +0200 Subject: [PATCH 198/446] fix Windows --- context_fs_go1.16.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/context_fs_go1.16.go b/context_fs_go1.16.go index 77878c6aa..889ba9d35 100644 --- a/context_fs_go1.16.go +++ b/context_fs_go1.16.go @@ -33,7 +33,7 @@ func fsFile(c Context, file string, filesystem fs.FS) error { fi, _ := f.Stat() if fi.IsDir() { - file = filepath.Join(file, indexPage) + file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows f, err = filesystem.Open(file) if err != nil { return ErrNotFound From feaa6ede6a4bbe4a0f25e8fed022868134ccbc6b Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 15 Jan 2022 11:26:04 +0200 Subject: [PATCH 199/446] improve comments --- context_fs_go1.16.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/context_fs_go1.16.go b/context_fs_go1.16.go index 889ba9d35..c1c724afd 100644 --- a/context_fs_go1.16.go +++ b/context_fs_go1.16.go @@ -33,7 +33,7 @@ func fsFile(c Context, file string, filesystem fs.FS) error { fi, _ := f.Stat() if fi.IsDir() { - file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows + file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect. f, err = filesystem.Open(file) if err != nil { return ErrNotFound From 9e9924d763c1af96f8da799da3d0781866e3190b Mon Sep 17 00:00:00 2001 From: Eng Zer Jun Date: Tue, 25 Jan 2022 00:09:49 +0800 Subject: [PATCH 200/446] build: upgrade `go` directive in `go.mod` to 1.17 (#2049) This commit enables support for module graph pruning and lazy module loading for projects that are at Go 1.17 or higher. Reference: https://go.dev/ref/mod#go-mod-file-go Reference: https://go.dev/ref/mod#graph-pruning Reference: https://go.dev/ref/mod#lazy-loading Signed-off-by: Eng Zer Jun --- go.mod | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 80087d6f9..4de2bdde1 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/labstack/echo/v4 -go 1.15 +go 1.17 require ( github.com/golang-jwt/jwt v3.2.2+incompatible @@ -9,6 +9,16 @@ require ( github.com/valyala/fasttemplate v1.2.1 golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f - golang.org/x/text v0.3.7 // indirect golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 ) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/mattn/go-colorable v0.1.11 // indirect + github.com/mattn/go-isatty v0.0.14 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + golang.org/x/sys v0.0.0-20211103235746-7861aae1554b // indirect + golang.org/x/text v0.3.7 // indirect + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect +) From 4a1ccdfdc520eb90573a97a7d04fd9fc300c1629 Mon Sep 17 00:00:00 2001 From: Martti T Date: Mon, 24 Jan 2022 22:03:45 +0200 Subject: [PATCH 201/446] JWT, KeyAuth, CSRF multivalue extractors (#2060) * CSRF, JWT, KeyAuth middleware support for multivalue value extractors * Add flag to JWT and KeyAuth middleware to allow continuing execution `next(c)` when error handler decides to swallow the error (returns nil). --- echo.go | 4 +- middleware/csrf.go | 110 +++---- middleware/csrf_test.go | 276 +++++++++++++--- middleware/extractor.go | 184 +++++++++++ middleware/extractor_test.go | 587 +++++++++++++++++++++++++++++++++++ middleware/jwt.go | 188 ++++------- middleware/jwt_test.go | 308 +++++++++++++----- middleware/key_auth.go | 173 +++++------ middleware/key_auth_test.go | 123 +++++++- middleware/middleware.go | 4 +- 10 files changed, 1562 insertions(+), 395 deletions(-) create mode 100644 middleware/extractor.go create mode 100644 middleware/extractor_test.go diff --git a/echo.go b/echo.go index 56255c6cb..2e63cc6b1 100644 --- a/echo.go +++ b/echo.go @@ -111,10 +111,10 @@ type ( } // MiddlewareFunc defines a function to process middleware. - MiddlewareFunc func(HandlerFunc) HandlerFunc + MiddlewareFunc func(next HandlerFunc) HandlerFunc // HandlerFunc defines a function to serve HTTP requests. - HandlerFunc func(Context) error + HandlerFunc func(c Context) error // HTTPErrorHandler is a centralized HTTP error handler. HTTPErrorHandler func(error, Context) diff --git a/middleware/csrf.go b/middleware/csrf.go index 7804997d4..61299f5ca 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -2,9 +2,7 @@ package middleware import ( "crypto/subtle" - "errors" "net/http" - "strings" "time" "github.com/labstack/echo/v4" @@ -21,13 +19,15 @@ type ( TokenLength uint8 `yaml:"token_length"` // Optional. Default value 32. - // TokenLookup is a string in the form of ":" that is used + // TokenLookup is a string in the form of ":" or ":,:" that is used // to extract token from the request. // Optional. Default value "header:X-CSRF-Token". // Possible values: - // - "header:" - // - "form:" + // - "header:" or "header::" // - "query:" + // - "form:" + // Multiple sources example: + // - "header:X-CSRF-Token,query:csrf" TokenLookup string `yaml:"token_lookup"` // Context key to store generated CSRF token into context. @@ -62,12 +62,11 @@ type ( // Optional. Default value SameSiteDefaultMode. CookieSameSite http.SameSite `yaml:"cookie_same_site"` } - - // csrfTokenExtractor defines a function that takes `echo.Context` and returns - // either a token or an error. - csrfTokenExtractor func(echo.Context) (string, error) ) +// ErrCSRFInvalid is returned when CSRF check fails +var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") + var ( // DefaultCSRFConfig is the default CSRF middleware config. DefaultCSRFConfig = CSRFConfig{ @@ -114,14 +113,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { config.CookieSecure = true } - // Initialize - parts := strings.Split(config.TokenLookup, ":") - extractor := csrfTokenFromHeader(parts[1]) - switch parts[0] { - case "form": - extractor = csrfTokenFromForm(parts[1]) - case "query": - extractor = csrfTokenFromQuery(parts[1]) + extractors, err := createExtractors(config.TokenLookup, "") + if err != nil { + panic(err) } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -130,28 +124,50 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return next(c) } - req := c.Request() - k, err := c.Cookie(config.CookieName) token := "" - - // Generate token - if err != nil { - token = random.String(config.TokenLength) + if k, err := c.Cookie(config.CookieName); err != nil { + token = random.String(config.TokenLength) // Generate token } else { - // Reuse token - token = k.Value + token = k.Value // Reuse token } - switch req.Method { + switch c.Request().Method { case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: default: // Validate token only for requests which are not defined as 'safe' by RFC7231 - clientToken, err := extractor(c) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + var lastExtractorErr error + var lastTokenErr error + outer: + for _, extractor := range extractors { + clientTokens, err := extractor(c) + if err != nil { + lastExtractorErr = err + continue + } + + for _, clientToken := range clientTokens { + if validateCSRFToken(token, clientToken) { + lastTokenErr = nil + lastExtractorErr = nil + break outer + } + lastTokenErr = ErrCSRFInvalid + } } - if !validateCSRFToken(token, clientToken) { - return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") + if lastTokenErr != nil { + return lastTokenErr + } else if lastExtractorErr != nil { + // ugly part to preserve backwards compatible errors. someone could rely on them + if lastExtractorErr == errQueryExtractorValueMissing { + lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string") + } else if lastExtractorErr == errFormExtractorValueMissing { + lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter") + } else if lastExtractorErr == errHeaderExtractorValueMissing { + lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header") + } else { + lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error()) + } + return lastExtractorErr } } @@ -184,38 +200,6 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { } } -// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the -// provided request header. -func csrfTokenFromHeader(header string) csrfTokenExtractor { - return func(c echo.Context) (string, error) { - return c.Request().Header.Get(header), nil - } -} - -// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the -// provided form parameter. -func csrfTokenFromForm(param string) csrfTokenExtractor { - return func(c echo.Context) (string, error) { - token := c.FormValue(param) - if token == "" { - return "", errors.New("missing csrf token in the form parameter") - } - return token, nil - } -} - -// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the -// provided query parameter. -func csrfTokenFromQuery(param string) csrfTokenExtractor { - return func(c echo.Context) (string, error) { - token := c.QueryParam(param) - if token == "" { - return "", errors.New("missing csrf token in the query string") - } - return token, nil - } -} - func validateCSRFToken(token, clientToken string) bool { return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1 } diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index af1d26394..9aff82a98 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -1,7 +1,6 @@ package middleware import ( - "fmt" "net/http" "net/http/httptest" "net/url" @@ -13,14 +12,205 @@ import ( "github.com/stretchr/testify/assert" ) +func TestCSRF_tokenExtractors(t *testing.T) { + var testCases = []struct { + name string + whenTokenLookup string + whenCookieName string + givenCSRFCookie string + givenMethod string + givenQueryTokens map[string][]string + givenFormTokens map[string][]string + givenHeaderTokens map[string][]string + expectError string + }{ + { + name: "ok, multiple token lookups sources, succeeds on last one", + whenTokenLookup: "header:X-CSRF-Token,form:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenHeaderTokens: map[string][]string{ + echo.HeaderXCSRFToken: {"invalid_token"}, + }, + givenFormTokens: map[string][]string{ + "csrf": {"token"}, + }, + }, + { + name: "ok, token from POST form", + whenTokenLookup: "form:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenFormTokens: map[string][]string{ + "csrf": {"token"}, + }, + }, + { + name: "ok, token from POST form, second token passes", + whenTokenLookup: "form:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenFormTokens: map[string][]string{ + "csrf": {"invalid", "token"}, + }, + }, + { + name: "nok, invalid token from POST form", + whenTokenLookup: "form:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenFormTokens: map[string][]string{ + "csrf": {"invalid_token"}, + }, + expectError: "code=403, message=invalid csrf token", + }, + { + name: "nok, missing token from POST form", + whenTokenLookup: "form:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenFormTokens: map[string][]string{}, + expectError: "code=400, message=missing csrf token in the form parameter", + }, + { + name: "ok, token from POST header", + whenTokenLookup: "", // will use defaults + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenHeaderTokens: map[string][]string{ + echo.HeaderXCSRFToken: {"token"}, + }, + }, + { + name: "ok, token from POST header, second token passes", + whenTokenLookup: "header:" + echo.HeaderXCSRFToken, + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenHeaderTokens: map[string][]string{ + echo.HeaderXCSRFToken: {"invalid", "token"}, + }, + }, + { + name: "nok, invalid token from POST header", + whenTokenLookup: "header:" + echo.HeaderXCSRFToken, + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenHeaderTokens: map[string][]string{ + echo.HeaderXCSRFToken: {"invalid_token"}, + }, + expectError: "code=403, message=invalid csrf token", + }, + { + name: "nok, missing token from POST header", + whenTokenLookup: "header:" + echo.HeaderXCSRFToken, + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenHeaderTokens: map[string][]string{}, + expectError: "code=400, message=missing csrf token in request header", + }, + { + name: "ok, token from PUT query param", + whenTokenLookup: "query:csrf-param", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{ + "csrf-param": {"token"}, + }, + }, + { + name: "ok, token from PUT query form, second token passes", + whenTokenLookup: "query:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{ + "csrf": {"invalid", "token"}, + }, + }, + { + name: "nok, invalid token from PUT query form", + whenTokenLookup: "query:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{ + "csrf": {"invalid_token"}, + }, + expectError: "code=403, message=invalid csrf token", + }, + { + name: "nok, missing token from PUT query form", + whenTokenLookup: "query:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{}, + expectError: "code=400, message=missing csrf token in the query string", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + q := make(url.Values) + for queryParam, values := range tc.givenQueryTokens { + for _, v := range values { + q.Add(queryParam, v) + } + } + + f := make(url.Values) + for formKey, values := range tc.givenFormTokens { + for _, v := range values { + f.Add(formKey, v) + } + } + + var req *http.Request + switch tc.givenMethod { + case http.MethodGet: + req = httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil) + case http.MethodPost, http.MethodPut: + req = httptest.NewRequest(http.MethodPost, "/?"+q.Encode(), strings.NewReader(f.Encode())) + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + } + + for header, values := range tc.givenHeaderTokens { + for _, v := range values { + req.Header.Add(header, v) + } + } + + if tc.givenCSRFCookie != "" { + req.Header.Set(echo.HeaderCookie, "_csrf="+tc.givenCSRFCookie) + } + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + csrf := CSRFWithConfig(CSRFConfig{ + TokenLookup: tc.whenTokenLookup, + CookieName: tc.whenCookieName, + }) + + h := csrf(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + err := h(c) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + func TestCSRF(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - csrf := CSRFWithConfig(CSRFConfig{ - TokenLength: 16, - }) + csrf := CSRF() h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -43,7 +233,7 @@ func TestCSRF(t *testing.T) { assert.Error(t, h(c)) // Valid CSRF token - token := random.String(16) + token := random.String(32) req.Header.Set(echo.HeaderCookie, "_csrf="+token) req.Header.Set(echo.HeaderXCSRFToken, token) if assert.NoError(t, h(c)) { @@ -51,38 +241,6 @@ func TestCSRF(t *testing.T) { } } -func TestCSRFTokenFromForm(t *testing.T) { - f := make(url.Values) - f.Set("csrf", "token") - e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) - req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) - c := e.NewContext(req, nil) - token, err := csrfTokenFromForm("csrf")(c) - if assert.NoError(t, err) { - assert.Equal(t, "token", token) - } - _, err = csrfTokenFromForm("invalid")(c) - assert.Error(t, err) -} - -func TestCSRFTokenFromQuery(t *testing.T) { - q := make(url.Values) - q.Set("csrf", "token") - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) - req.URL.RawQuery = q.Encode() - c := e.NewContext(req, nil) - token, err := csrfTokenFromQuery("csrf")(c) - if assert.NoError(t, err) { - assert.Equal(t, "token", token) - } - _, err = csrfTokenFromQuery("invalid")(c) - assert.Error(t, err) - csrfTokenFromQuery("csrf") -} - func TestCSRFSetSameSiteMode(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -135,7 +293,6 @@ func TestCSRFWithSameSiteDefaultMode(t *testing.T) { r := h(c) assert.NoError(t, r) - fmt.Println(rec.Header()["Set-Cookie"]) assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"]) } @@ -158,3 +315,46 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) { assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"]) assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"]) } + +func TestCSRFConfig_skipper(t *testing.T) { + var testCases = []struct { + name string + whenSkip bool + expectCookies int + }{ + { + name: "do skip", + whenSkip: true, + expectCookies: 0, + }, + { + name: "do not skip", + whenSkip: false, + expectCookies: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + csrf := CSRFWithConfig(CSRFConfig{ + Skipper: func(c echo.Context) bool { + return tc.whenSkip + }, + }) + + h := csrf(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + r := h(c) + assert.NoError(t, r) + cookie := rec.Header()["Set-Cookie"] + assert.Len(t, cookie, tc.expectCookies) + }) + } +} diff --git a/middleware/extractor.go b/middleware/extractor.go new file mode 100644 index 000000000..a57ed4e13 --- /dev/null +++ b/middleware/extractor.go @@ -0,0 +1,184 @@ +package middleware + +import ( + "errors" + "fmt" + "github.com/labstack/echo/v4" + "net/textproto" + "strings" +) + +const ( + // extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource exhaustion + // attack vector + extractorLimit = 20 +) + +var errHeaderExtractorValueMissing = errors.New("missing value in request header") +var errHeaderExtractorValueInvalid = errors.New("invalid value in request header") +var errQueryExtractorValueMissing = errors.New("missing value in the query string") +var errParamExtractorValueMissing = errors.New("missing value in path params") +var errCookieExtractorValueMissing = errors.New("missing value in cookies") +var errFormExtractorValueMissing = errors.New("missing value in the form") + +// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context. +type ValuesExtractor func(c echo.Context) ([]string, error) + +func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, error) { + if lookups == "" { + return nil, nil + } + sources := strings.Split(lookups, ",") + var extractors = make([]ValuesExtractor, 0) + for _, source := range sources { + parts := strings.Split(source, ":") + if len(parts) < 2 { + return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source) + } + + switch parts[0] { + case "query": + extractors = append(extractors, valuesFromQuery(parts[1])) + case "param": + extractors = append(extractors, valuesFromParam(parts[1])) + case "cookie": + extractors = append(extractors, valuesFromCookie(parts[1])) + case "form": + extractors = append(extractors, valuesFromForm(parts[1])) + case "header": + prefix := "" + if len(parts) > 2 { + prefix = parts[2] + } else if authScheme != "" && parts[1] == echo.HeaderAuthorization { + // backwards compatibility for JWT and KeyAuth: + // * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer " etc + // * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that + // behaviour for default values and Authorization header. + prefix = authScheme + if !strings.HasSuffix(prefix, " ") { + prefix += " " + } + } + extractors = append(extractors, valuesFromHeader(parts[1], prefix)) + } + } + return extractors, nil +} + +// valuesFromHeader returns a functions that extracts values from the request header. +// valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static +// prefix like `Authorization: ` where part that we want to remove is ` ` +// note the space at the end. In case of basic authentication `Authorization: Basic ` prefix we want to remove +// is `Basic `. In case of JWT tokens `Authorization: Bearer ` prefix is `Bearer `. +// If prefix is left empty the whole value is returned. +func valuesFromHeader(header string, valuePrefix string) ValuesExtractor { + prefixLen := len(valuePrefix) + // standard library parses http.Request header keys in canonical form but we may provide something else so fix this + header = textproto.CanonicalMIMEHeaderKey(header) + return func(c echo.Context) ([]string, error) { + values := c.Request().Header.Values(header) + if len(values) == 0 { + return nil, errHeaderExtractorValueMissing + } + + result := make([]string, 0) + for i, value := range values { + if prefixLen == 0 { + result = append(result, value) + if i >= extractorLimit-1 { + break + } + continue + } + if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) { + result = append(result, value[prefixLen:]) + if i >= extractorLimit-1 { + break + } + } + } + + if len(result) == 0 { + if prefixLen > 0 { + return nil, errHeaderExtractorValueInvalid + } + return nil, errHeaderExtractorValueMissing + } + return result, nil + } +} + +// valuesFromQuery returns a function that extracts values from the query string. +func valuesFromQuery(param string) ValuesExtractor { + return func(c echo.Context) ([]string, error) { + result := c.QueryParams()[param] + if len(result) == 0 { + return nil, errQueryExtractorValueMissing + } else if len(result) > extractorLimit-1 { + result = result[:extractorLimit] + } + return result, nil + } +} + +// valuesFromParam returns a function that extracts values from the url param string. +func valuesFromParam(param string) ValuesExtractor { + return func(c echo.Context) ([]string, error) { + result := make([]string, 0) + paramVales := c.ParamValues() + for i, p := range c.ParamNames() { + if param == p { + result = append(result, paramVales[i]) + if i >= extractorLimit-1 { + break + } + } + } + if len(result) == 0 { + return nil, errParamExtractorValueMissing + } + return result, nil + } +} + +// valuesFromCookie returns a function that extracts values from the named cookie. +func valuesFromCookie(name string) ValuesExtractor { + return func(c echo.Context) ([]string, error) { + cookies := c.Cookies() + if len(cookies) == 0 { + return nil, errCookieExtractorValueMissing + } + + result := make([]string, 0) + for i, cookie := range cookies { + if name == cookie.Name { + result = append(result, cookie.Value) + if i >= extractorLimit-1 { + break + } + } + } + if len(result) == 0 { + return nil, errCookieExtractorValueMissing + } + return result, nil + } +} + +// valuesFromForm returns a function that extracts values from the form field. +func valuesFromForm(name string) ValuesExtractor { + return func(c echo.Context) ([]string, error) { + if parseErr := c.Request().ParseForm(); parseErr != nil { + return nil, fmt.Errorf("valuesFromForm parse form failed: %w", parseErr) + } + values := c.Request().Form[name] + if len(values) == 0 { + return nil, errFormExtractorValueMissing + } + if len(values) > extractorLimit-1 { + values = values[:extractorLimit] + } + result := append([]string{}, values...) + return result, nil + } +} diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go new file mode 100644 index 000000000..ae4b30a8a --- /dev/null +++ b/middleware/extractor_test.go @@ -0,0 +1,587 @@ +package middleware + +import ( + "fmt" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +type pathParam struct { + name string + value string +} + +func setPathParams(c echo.Context, params []pathParam) { + names := make([]string, 0, len(params)) + values := make([]string, 0, len(params)) + for _, pp := range params { + names = append(names, pp.name) + values = append(values, pp.value) + } + c.SetParamNames(names...) + c.SetParamValues(values...) +} + +func TestCreateExtractors(t *testing.T) { + var testCases = []struct { + name string + givenRequest func() *http.Request + givenPathParams []pathParam + whenLoopups string + expectValues []string + expectCreateError string + expectError string + }{ + { + name: "ok, header", + givenRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, "Bearer token") + return req + }, + whenLoopups: "header:Authorization:Bearer ", + expectValues: []string{"token"}, + }, + { + name: "ok, form", + givenRequest: func() *http.Request { + f := make(url.Values) + f.Set("name", "Jon Snow") + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + return req + }, + whenLoopups: "form:name", + expectValues: []string{"Jon Snow"}, + }, + { + name: "ok, cookie", + givenRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderCookie, "_csrf=token") + return req + }, + whenLoopups: "cookie:_csrf", + expectValues: []string{"token"}, + }, + { + name: "ok, param", + givenPathParams: []pathParam{ + {name: "id", value: "123"}, + }, + whenLoopups: "param:id", + expectValues: []string{"123"}, + }, + { + name: "ok, query", + givenRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/?id=999", nil) + return req + }, + whenLoopups: "query:id", + expectValues: []string{"999"}, + }, + { + name: "nok, invalid lookup", + whenLoopups: "query", + expectCreateError: "extractor source for lookup could not be split into needed parts: query", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenRequest != nil { + req = tc.givenRequest() + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if tc.givenPathParams != nil { + setPathParams(c, tc.givenPathParams) + } + + extractors, err := createExtractors(tc.whenLoopups, "") + if tc.expectCreateError != "" { + assert.EqualError(t, err, tc.expectCreateError) + return + } + assert.NoError(t, err) + + for _, e := range extractors { + values, eErr := e(c) + assert.Equal(t, tc.expectValues, values) + if tc.expectError != "" { + assert.EqualError(t, eErr, tc.expectError) + return + } + assert.NoError(t, eErr) + } + }) + } +} + +func TestValuesFromHeader(t *testing.T) { + exampleRequest := func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") + } + + var testCases = []struct { + name string + givenRequest func(req *http.Request) + whenName string + whenValuePrefix string + expectValues []string + expectError string + }{ + { + name: "ok, single value", + givenRequest: exampleRequest, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "basic ", + expectValues: []string{"dXNlcjpwYXNzd29yZA=="}, + }, + { + name: "ok, single value, case insensitive", + givenRequest: exampleRequest, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "Basic ", + expectValues: []string{"dXNlcjpwYXNzd29yZA=="}, + }, + { + name: "ok, multiple value", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") + req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") + }, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "basic ", + expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"}, + }, + { + name: "ok, empty prefix", + givenRequest: exampleRequest, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "", + expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="}, + }, + { + name: "nok, no matching due different prefix", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") + req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") + }, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "Bearer ", + expectError: errHeaderExtractorValueInvalid.Error(), + }, + { + name: "nok, no matching due different prefix", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") + req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") + }, + whenName: echo.HeaderWWWAuthenticate, + whenValuePrefix: "", + expectError: errHeaderExtractorValueMissing.Error(), + }, + { + name: "nok, no headers", + givenRequest: nil, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "basic ", + expectError: errHeaderExtractorValueMissing.Error(), + }, + { + name: "ok, prefix, cut values over extractorLimit", + givenRequest: func(req *http.Request) { + for i := 1; i <= 25; i++ { + req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("basic %v", i)) + } + }, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "basic ", + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + { + name: "ok, cut values over extractorLimit", + givenRequest: func(req *http.Request) { + for i := 1; i <= 25; i++ { + req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("%v", i)) + } + }, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "", + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenRequest != nil { + tc.givenRequest(req) + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix) + + values, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValuesFromQuery(t *testing.T) { + var testCases = []struct { + name string + givenQueryPart string + whenName string + expectValues []string + expectError string + }{ + { + name: "ok, single value", + givenQueryPart: "?id=123&name=test", + whenName: "id", + expectValues: []string{"123"}, + }, + { + name: "ok, multiple value", + givenQueryPart: "?id=123&id=456&name=test", + whenName: "id", + expectValues: []string{"123", "456"}, + }, + { + name: "nok, missing value", + givenQueryPart: "?id=123&name=test", + whenName: "nope", + expectError: errQueryExtractorValueMissing.Error(), + }, + { + name: "ok, cut values over extractorLimit", + givenQueryPart: "?name=test" + + "&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" + + "&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" + + "&id=21&id=22&id=23&id=24&id=25", + whenName: "id", + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + extractor := valuesFromQuery(tc.whenName) + + values, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValuesFromParam(t *testing.T) { + examplePathParams := []pathParam{ + {name: "id", value: "123"}, + {name: "gid", value: "456"}, + {name: "gid", value: "789"}, + } + examplePathParams20 := make([]pathParam, 0) + for i := 1; i < 25; i++ { + examplePathParams20 = append(examplePathParams20, pathParam{name: "id", value: fmt.Sprintf("%v", i)}) + } + + var testCases = []struct { + name string + givenPathParams []pathParam + whenName string + expectValues []string + expectError string + }{ + { + name: "ok, single value", + givenPathParams: examplePathParams, + whenName: "id", + expectValues: []string{"123"}, + }, + { + name: "ok, multiple value", + givenPathParams: examplePathParams, + whenName: "gid", + expectValues: []string{"456", "789"}, + }, + { + name: "nok, no values", + givenPathParams: nil, + whenName: "nope", + expectValues: nil, + expectError: errParamExtractorValueMissing.Error(), + }, + { + name: "nok, no matching value", + givenPathParams: examplePathParams, + whenName: "nope", + expectValues: nil, + expectError: errParamExtractorValueMissing.Error(), + }, + { + name: "ok, cut values over extractorLimit", + givenPathParams: examplePathParams20, + whenName: "id", + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if tc.givenPathParams != nil { + setPathParams(c, tc.givenPathParams) + } + + extractor := valuesFromParam(tc.whenName) + + values, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValuesFromCookie(t *testing.T) { + exampleRequest := func(req *http.Request) { + req.Header.Set(echo.HeaderCookie, "_csrf=token") + } + + var testCases = []struct { + name string + givenRequest func(req *http.Request) + whenName string + expectValues []string + expectError string + }{ + { + name: "ok, single value", + givenRequest: exampleRequest, + whenName: "_csrf", + expectValues: []string{"token"}, + }, + { + name: "ok, multiple value", + givenRequest: func(req *http.Request) { + req.Header.Add(echo.HeaderCookie, "_csrf=token") + req.Header.Add(echo.HeaderCookie, "_csrf=token2") + }, + whenName: "_csrf", + expectValues: []string{"token", "token2"}, + }, + { + name: "nok, no matching cookie", + givenRequest: exampleRequest, + whenName: "xxx", + expectValues: nil, + expectError: errCookieExtractorValueMissing.Error(), + }, + { + name: "nok, no cookies at all", + givenRequest: nil, + whenName: "xxx", + expectValues: nil, + expectError: errCookieExtractorValueMissing.Error(), + }, + { + name: "ok, cut values over extractorLimit", + givenRequest: func(req *http.Request) { + for i := 1; i < 25; i++ { + req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", i)) + } + }, + whenName: "_csrf", + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenRequest != nil { + tc.givenRequest(req) + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + extractor := valuesFromCookie(tc.whenName) + + values, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValuesFromForm(t *testing.T) { + examplePostFormRequest := func(mod func(v *url.Values)) *http.Request { + f := make(url.Values) + f.Set("name", "Jon Snow") + f.Set("emails[]", "jon@labstack.com") + if mod != nil { + mod(&f) + } + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + + return req + } + exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request { + f := make(url.Values) + f.Set("name", "Jon Snow") + f.Set("emails[]", "jon@labstack.com") + if mod != nil { + mod(&f) + } + + req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil) + return req + } + + var testCases = []struct { + name string + givenRequest *http.Request + whenName string + expectValues []string + expectError string + }{ + { + name: "ok, POST form, single value", + givenRequest: examplePostFormRequest(nil), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com"}, + }, + { + name: "ok, POST form, multiple value", + givenRequest: examplePostFormRequest(func(v *url.Values) { + v.Add("emails[]", "snow@labstack.com") + }), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, + }, + { + name: "ok, GET form, single value", + givenRequest: exampleGetFormRequest(nil), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com"}, + }, + { + name: "ok, GET form, multiple value", + givenRequest: examplePostFormRequest(func(v *url.Values) { + v.Add("emails[]", "snow@labstack.com") + }), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, + }, + { + name: "nok, POST form, value missing", + givenRequest: examplePostFormRequest(nil), + whenName: "nope", + expectError: errFormExtractorValueMissing.Error(), + }, + { + name: "nok, POST form, form parsing error", + givenRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Body = nil + return req + }(), + whenName: "name", + expectError: "valuesFromForm parse form failed: missing form body", + }, + { + name: "ok, cut values over extractorLimit", + givenRequest: examplePostFormRequest(func(v *url.Values) { + for i := 1; i < 25; i++ { + v.Add("id[]", fmt.Sprintf("%v", i)) + } + }), + whenName: "id[]", + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := tc.givenRequest + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + extractor := valuesFromForm(tc.whenName) + + values, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/middleware/jwt.go b/middleware/jwt.go index 43605e377..bec5167e2 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -1,3 +1,4 @@ +//go:build go1.15 // +build go1.15 package middleware @@ -5,12 +6,10 @@ package middleware import ( "errors" "fmt" - "net/http" - "reflect" - "strings" - "github.com/golang-jwt/jwt" "github.com/labstack/echo/v4" + "net/http" + "reflect" ) type ( @@ -22,7 +21,8 @@ type ( // BeforeFunc defines a function which is executed just before the middleware. BeforeFunc BeforeFunc - // SuccessHandler defines a function which is executed for a valid token. + // SuccessHandler defines a function which is executed for a valid token before middleware chain continues with next + // middleware or handler. SuccessHandler JWTSuccessHandler // ErrorHandler defines a function which is executed for an invalid token. @@ -32,6 +32,13 @@ type ( // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. ErrorHandlerWithContext JWTErrorHandlerWithContext + // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to + // ignore the error (by returning `nil`). + // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. + // In that case you can use ErrorHandlerWithContext to set a default public JWT token value in the request context + // and continue. Some logic down the remaining execution chain needs to check that (public) token value then. + ContinueOnIgnoredError bool + // Signing key to validate token. // This is one of the three options to provide a token validation key. // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. @@ -61,20 +68,25 @@ type ( // to extract token from the request. // Optional. Default value "header:Authorization". // Possible values: - // - "header:" + // - "header:" or "header::" + // `` is argument value to cut/trim prefix of the extracted value. This is useful if header + // value has static prefix like `Authorization: ` where part that we + // want to cut is ` ` note the space at the end. + // In case of JWT tokens `Authorization: Bearer ` prefix we cut is `Bearer `. + // If prefix is left empty the whole value is returned. // - "query:" // - "param:" // - "cookie:" // - "form:" - // Multiply sources example: - // - "header: Authorization,cookie: myowncookie" + // Multiple sources example: + // - "header:Authorization,cookie:myowncookie" TokenLookup string // TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context. // This is one of the two options to provide a token extractor. // The order of precedence is user-defined TokenLookupFuncs, and TokenLookup. // You can also provide both if you want. - TokenLookupFuncs []TokenLookupFunc + TokenLookupFuncs []ValuesExtractor // AuthScheme to be used in the Authorization header. // Optional. Default value "Bearer". @@ -100,16 +112,13 @@ type ( } // JWTSuccessHandler defines a function which is executed for a valid token. - JWTSuccessHandler func(echo.Context) + JWTSuccessHandler func(c echo.Context) // JWTErrorHandler defines a function which is executed for an invalid token. - JWTErrorHandler func(error) error + JWTErrorHandler func(err error) error // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. - JWTErrorHandlerWithContext func(error, echo.Context) error - - // TokenLookupFunc defines a function for extracting JWT token from the given context. - TokenLookupFunc func(echo.Context) (string, error) + JWTErrorHandlerWithContext func(err error, c echo.Context) error ) // Algorithms @@ -183,25 +192,12 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { config.ParseTokenFunc = config.defaultParseToken } - // Initialize - // Split sources - sources := strings.Split(config.TokenLookup, ",") - var extractors = config.TokenLookupFuncs - for _, source := range sources { - parts := strings.Split(source, ":") - - switch parts[0] { - case "query": - extractors = append(extractors, jwtFromQuery(parts[1])) - case "param": - extractors = append(extractors, jwtFromParam(parts[1])) - case "cookie": - extractors = append(extractors, jwtFromCookie(parts[1])) - case "form": - extractors = append(extractors, jwtFromForm(parts[1])) - case "header": - extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme)) - } + extractors, err := createExtractors(config.TokenLookup, config.AuthScheme) + if err != nil { + panic(err) + } + if len(config.TokenLookupFuncs) > 0 { + extractors = append(config.TokenLookupFuncs, extractors...) } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -213,48 +209,54 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.BeforeFunc != nil { config.BeforeFunc(c) } - var auth string - var err error + + var lastExtractorErr error + var lastTokenErr error for _, extractor := range extractors { - // Extract token from extractor, if it's not fail break the loop and - // set auth - auth, err = extractor(c) - if err == nil { - break + auths, err := extractor(c) + if err != nil { + lastExtractorErr = ErrJWTMissing // backwards compatibility: all extraction errors are same (unlike KeyAuth) + continue } - } - // If none of extractor has a token, handle error - if err != nil { - if config.ErrorHandler != nil { - return config.ErrorHandler(err) + for _, auth := range auths { + token, err := config.ParseTokenFunc(auth, c) + if err != nil { + lastTokenErr = err + continue + } + // Store user information from token into context. + c.Set(config.ContextKey, token) + if config.SuccessHandler != nil { + config.SuccessHandler(c) + } + return next(c) } - - if config.ErrorHandlerWithContext != nil { - return config.ErrorHandlerWithContext(err, c) - } - return err } - - token, err := config.ParseTokenFunc(auth, c) - if err == nil { - // Store user information from token into context. - c.Set(config.ContextKey, token) - if config.SuccessHandler != nil { - config.SuccessHandler(c) - } - return next(c) + // we are here only when we did not successfully extract or parse any of the tokens + err := lastTokenErr + if err == nil { // prioritize token errors over extracting errors + err = lastExtractorErr } if config.ErrorHandler != nil { return config.ErrorHandler(err) } if config.ErrorHandlerWithContext != nil { - return config.ErrorHandlerWithContext(err, c) + tmpErr := config.ErrorHandlerWithContext(err, c) + if config.ContinueOnIgnoredError && tmpErr == nil { + return next(c) + } + return tmpErr } - return &echo.HTTPError{ - Code: ErrJWTInvalid.Code, - Message: ErrJWTInvalid.Message, - Internal: err, + + // backwards compatible errors codes + if lastTokenErr != nil { + return &echo.HTTPError{ + Code: ErrJWTInvalid.Code, + Message: ErrJWTInvalid.Message, + Internal: err, + } } + return err // this is lastExtractorErr value } } } @@ -296,59 +298,3 @@ func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) { return config.SigningKey, nil } - -// jwtFromHeader returns a `TokenLookupFunc` that extracts token from the request header. -func jwtFromHeader(header string, authScheme string) TokenLookupFunc { - return func(c echo.Context) (string, error) { - auth := c.Request().Header.Get(header) - l := len(authScheme) - if len(auth) > l+1 && strings.EqualFold(auth[:l], authScheme) { - return auth[l+1:], nil - } - return "", ErrJWTMissing - } -} - -// jwtFromQuery returns a `TokenLookupFunc` that extracts token from the query string. -func jwtFromQuery(param string) TokenLookupFunc { - return func(c echo.Context) (string, error) { - token := c.QueryParam(param) - if token == "" { - return "", ErrJWTMissing - } - return token, nil - } -} - -// jwtFromParam returns a `TokenLookupFunc` that extracts token from the url param string. -func jwtFromParam(param string) TokenLookupFunc { - return func(c echo.Context) (string, error) { - token := c.Param(param) - if token == "" { - return "", ErrJWTMissing - } - return token, nil - } -} - -// jwtFromCookie returns a `TokenLookupFunc` that extracts token from the named cookie. -func jwtFromCookie(name string) TokenLookupFunc { - return func(c echo.Context) (string, error) { - cookie, err := c.Cookie(name) - if err != nil { - return "", ErrJWTMissing - } - return cookie.Value, nil - } -} - -// jwtFromForm returns a `TokenLookupFunc` that extracts token from the form field. -func jwtFromForm(name string) TokenLookupFunc { - return func(c echo.Context) (string, error) { - field := c.FormValue(name) - if field == "" { - return "", ErrJWTMissing - } - return field, nil - } -} diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 18454d0a7..eee9df966 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -1,3 +1,4 @@ +//go:build go1.15 // +build go1.15 package middleware @@ -28,6 +29,26 @@ type jwtCustomClaims struct { jwtCustomInfo } +func TestJWT(t *testing.T) { + e := echo.New() + + e.GET("/", func(c echo.Context) error { + token := c.Get("user").(*jwt.Token) + return c.JSON(http.StatusOK, token.Claims) + }) + + e.Use(JWT([]byte("secret"))) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, "bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + res := httptest.NewRecorder() + + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String()) +} + func TestJWTRace(t *testing.T) { e := echo.New() handler := func(c echo.Context) error { @@ -64,8 +85,7 @@ func TestJWTRace(t *testing.T) { assert.Equal(t, claims.Admin, true) } -func TestJWT(t *testing.T) { - e := echo.New() +func TestJWTConfig(t *testing.T) { handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } @@ -74,7 +94,8 @@ func TestJWT(t *testing.T) { invalidKey := []byte("invalid-key") validAuth := DefaultJWTConfig.AuthScheme + " " + token - for _, tc := range []struct { + testCases := []struct { + name string expPanic bool expErrCode int // 0 for Success config JWTConfig @@ -82,166 +103,166 @@ func TestJWT(t *testing.T) { hdrAuth string hdrCookie string // test.Request doesn't provide SetCookie(); use name=val formValues map[string]string - info string }{ { + name: "No signing key provided", expPanic: true, - info: "No signing key provided", }, { + name: "Unexpected signing method", expErrCode: http.StatusBadRequest, config: JWTConfig{ SigningKey: validKey, SigningMethod: "RS256", }, - info: "Unexpected signing method", }, { + name: "Invalid key", expErrCode: http.StatusUnauthorized, hdrAuth: validAuth, config: JWTConfig{SigningKey: invalidKey}, - info: "Invalid key", }, { + name: "Valid JWT", hdrAuth: validAuth, config: JWTConfig{SigningKey: validKey}, - info: "Valid JWT", }, { + name: "Valid JWT with custom AuthScheme", hdrAuth: "Token" + " " + token, config: JWTConfig{AuthScheme: "Token", SigningKey: validKey}, - info: "Valid JWT with custom AuthScheme", }, { + name: "Valid JWT with custom claims", hdrAuth: validAuth, config: JWTConfig{ Claims: &jwtCustomClaims{}, SigningKey: []byte("secret"), }, - info: "Valid JWT with custom claims", }, { + name: "Invalid Authorization header", hdrAuth: "invalid-auth", expErrCode: http.StatusBadRequest, config: JWTConfig{SigningKey: validKey}, - info: "Invalid Authorization header", }, { + name: "Empty header auth field", config: JWTConfig{SigningKey: validKey}, expErrCode: http.StatusBadRequest, - info: "Empty header auth field", }, { + name: "Valid query method", config: JWTConfig{ SigningKey: validKey, TokenLookup: "query:jwt", }, reqURL: "/?a=b&jwt=" + token, - info: "Valid query method", }, { + name: "Invalid query param name", config: JWTConfig{ SigningKey: validKey, TokenLookup: "query:jwt", }, reqURL: "/?a=b&jwtxyz=" + token, expErrCode: http.StatusBadRequest, - info: "Invalid query param name", }, { + name: "Invalid query param value", config: JWTConfig{ SigningKey: validKey, TokenLookup: "query:jwt", }, reqURL: "/?a=b&jwt=invalid-token", expErrCode: http.StatusUnauthorized, - info: "Invalid query param value", }, { + name: "Empty query", config: JWTConfig{ SigningKey: validKey, TokenLookup: "query:jwt", }, reqURL: "/?a=b", expErrCode: http.StatusBadRequest, - info: "Empty query", }, { + name: "Valid param method", config: JWTConfig{ SigningKey: validKey, TokenLookup: "param:jwt", }, reqURL: "/" + token, - info: "Valid param method", }, { + name: "Valid cookie method", config: JWTConfig{ SigningKey: validKey, TokenLookup: "cookie:jwt", }, hdrCookie: "jwt=" + token, - info: "Valid cookie method", }, { + name: "Multiple jwt lookuop", config: JWTConfig{ SigningKey: validKey, TokenLookup: "query:jwt,cookie:jwt", }, hdrCookie: "jwt=" + token, - info: "Multiple jwt lookuop", }, { + name: "Invalid token with cookie method", config: JWTConfig{ SigningKey: validKey, TokenLookup: "cookie:jwt", }, expErrCode: http.StatusUnauthorized, hdrCookie: "jwt=invalid", - info: "Invalid token with cookie method", }, { + name: "Empty cookie", config: JWTConfig{ SigningKey: validKey, TokenLookup: "cookie:jwt", }, expErrCode: http.StatusBadRequest, - info: "Empty cookie", }, { + name: "Valid form method", config: JWTConfig{ SigningKey: validKey, TokenLookup: "form:jwt", }, formValues: map[string]string{"jwt": token}, - info: "Valid form method", }, { + name: "Invalid token with form method", config: JWTConfig{ SigningKey: validKey, TokenLookup: "form:jwt", }, expErrCode: http.StatusUnauthorized, formValues: map[string]string{"jwt": "invalid"}, - info: "Invalid token with form method", }, { + name: "Empty form field", config: JWTConfig{ SigningKey: validKey, TokenLookup: "form:jwt", }, expErrCode: http.StatusBadRequest, - info: "Empty form field", }, { + name: "Valid JWT with a valid key using a user-defined KeyFunc", hdrAuth: validAuth, config: JWTConfig{ KeyFunc: func(*jwt.Token) (interface{}, error) { return validKey, nil }, }, - info: "Valid JWT with a valid key using a user-defined KeyFunc", }, { + name: "Valid JWT with an invalid key using a user-defined KeyFunc", hdrAuth: validAuth, config: JWTConfig{ KeyFunc: func(*jwt.Token) (interface{}, error) { @@ -249,9 +270,9 @@ func TestJWT(t *testing.T) { }, }, expErrCode: http.StatusUnauthorized, - info: "Valid JWT with an invalid key using a user-defined KeyFunc", }, { + name: "Token verification does not pass using a user-defined KeyFunc", hdrAuth: validAuth, config: JWTConfig{ KeyFunc: func(*jwt.Token) (interface{}, error) { @@ -259,67 +280,70 @@ func TestJWT(t *testing.T) { }, }, expErrCode: http.StatusUnauthorized, - info: "Token verification does not pass using a user-defined KeyFunc", }, { + name: "Valid JWT with lower case AuthScheme", hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token, config: JWTConfig{SigningKey: validKey}, - info: "Valid JWT with lower case AuthScheme", }, - } { - if tc.reqURL == "" { - tc.reqURL = "/" - } + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + if tc.reqURL == "" { + tc.reqURL = "/" + } - var req *http.Request - if len(tc.formValues) > 0 { - form := url.Values{} - for k, v := range tc.formValues { - form.Set(k, v) + var req *http.Request + if len(tc.formValues) > 0 { + form := url.Values{} + for k, v := range tc.formValues { + form.Set(k, v) + } + req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode())) + req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded") + req.ParseForm() + } else { + req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil) } - req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode())) - req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded") - req.ParseForm() - } else { - req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil) - } - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) - req.Header.Set(echo.HeaderCookie, tc.hdrCookie) - c := e.NewContext(req, res) + res := httptest.NewRecorder() + req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) + req.Header.Set(echo.HeaderCookie, tc.hdrCookie) + c := e.NewContext(req, res) - if tc.reqURL == "/"+token { - c.SetParamNames("jwt") - c.SetParamValues(token) - } + if tc.reqURL == "/"+token { + c.SetParamNames("jwt") + c.SetParamValues(token) + } - if tc.expPanic { - assert.Panics(t, func() { - JWTWithConfig(tc.config) - }, tc.info) - continue - } + if tc.expPanic { + assert.Panics(t, func() { + JWTWithConfig(tc.config) + }, tc.name) + return + } - if tc.expErrCode != 0 { - h := JWTWithConfig(tc.config)(handler) - he := h(c).(*echo.HTTPError) - assert.Equal(t, tc.expErrCode, he.Code, tc.info) - continue - } + if tc.expErrCode != 0 { + h := JWTWithConfig(tc.config)(handler) + he := h(c).(*echo.HTTPError) + assert.Equal(t, tc.expErrCode, he.Code, tc.name) + return + } - h := JWTWithConfig(tc.config)(handler) - if assert.NoError(t, h(c), tc.info) { - user := c.Get("user").(*jwt.Token) - switch claims := user.Claims.(type) { - case jwt.MapClaims: - assert.Equal(t, claims["name"], "John Doe", tc.info) - case *jwtCustomClaims: - assert.Equal(t, claims.Name, "John Doe", tc.info) - assert.Equal(t, claims.Admin, true, tc.info) - default: - panic("unexpected type of claims") + h := JWTWithConfig(tc.config)(handler) + if assert.NoError(t, h(c), tc.name) { + user := c.Get("user").(*jwt.Token) + switch claims := user.Claims.(type) { + case jwt.MapClaims: + assert.Equal(t, claims["name"], "John Doe", tc.name) + case *jwtCustomClaims: + assert.Equal(t, claims.Name, "John Doe", tc.name) + assert.Equal(t, claims.Admin, true, tc.name) + default: + panic("unexpected type of claims") + } } - } + }) } } @@ -608,13 +632,14 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) { e := echo.New() e.GET("/", func(c echo.Context) error { - return c.String(http.StatusOK, "test") + token := c.Get("user").(*jwt.Token) + return c.JSON(http.StatusOK, token.Claims) }) e.Use(JWTWithConfig(JWTConfig{ - TokenLookupFuncs: []TokenLookupFunc{ - func(c echo.Context) (string, error) { - return c.Request().Header.Get("X-API-Key"), nil + TokenLookupFuncs: []ValuesExtractor{ + func(c echo.Context) ([]string, error) { + return []string{c.Request().Header.Get("X-API-Key")}, nil }, }, SigningKey: []byte("secret"), @@ -626,4 +651,129 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) { e.ServeHTTP(res, req) assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String()) +} + +func TestJWTConfig_SuccessHandler(t *testing.T) { + var testCases = []struct { + name string + givenToken string + expectCalled bool + expectStatus int + }{ + { + name: "ok, success handler is called", + givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ", + expectCalled: true, + expectStatus: http.StatusOK, + }, + { + name: "nok, success handler is not called", + givenToken: "x.x.x", + expectCalled: false, + expectStatus: http.StatusUnauthorized, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + e.GET("/", func(c echo.Context) error { + token := c.Get("user").(*jwt.Token) + return c.JSON(http.StatusOK, token.Claims) + }) + + wasCalled := false + e.Use(JWTWithConfig(JWTConfig{ + SuccessHandler: func(c echo.Context) { + wasCalled = true + }, + SigningKey: []byte("secret"), + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken) + res := httptest.NewRecorder() + + e.ServeHTTP(res, req) + + assert.Equal(t, tc.expectCalled, wasCalled) + assert.Equal(t, tc.expectStatus, res.Code) + }) + } +} + +func TestJWTConfig_ContinueOnIgnoredError(t *testing.T) { + var testCases = []struct { + name string + whenContinueOnIgnoredError bool + givenToken string + expectStatus int + expectBody string + }{ + { + name: "no error handler is called", + whenContinueOnIgnoredError: true, + givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ", + expectStatus: http.StatusTeapot, + expectBody: "", + }, + { + name: "ContinueOnIgnoredError is false and error handler is called for missing token", + whenContinueOnIgnoredError: false, + givenToken: "", + // empty response with 200. This emulates previous behaviour when error handler swallowed the error + expectStatus: http.StatusOK, + expectBody: "", + }, + { + name: "error handler is called for missing token", + whenContinueOnIgnoredError: true, + givenToken: "", + expectStatus: http.StatusTeapot, + expectBody: "public-token", + }, + { + name: "error handler is called for invalid token", + whenContinueOnIgnoredError: true, + givenToken: "x.x.x", + expectStatus: http.StatusUnauthorized, + expectBody: "{\"message\":\"Unauthorized\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + e.GET("/", func(c echo.Context) error { + testValue, _ := c.Get("test").(string) + return c.String(http.StatusTeapot, testValue) + }) + + e.Use(JWTWithConfig(JWTConfig{ + ContinueOnIgnoredError: tc.whenContinueOnIgnoredError, + SigningKey: []byte("secret"), + ErrorHandlerWithContext: func(err error, c echo.Context) error { + if err == ErrJWTMissing { + c.Set("test", "public-token") + return nil + } + return echo.ErrUnauthorized + }, + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenToken != "" { + req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken) + } + res := httptest.NewRecorder() + + e.ServeHTTP(res, req) + + assert.Equal(t, tc.expectStatus, res.Code) + assert.Equal(t, tc.expectBody, res.Body.String()) + }) + } } diff --git a/middleware/key_auth.go b/middleware/key_auth.go index 54f3b47f3..e8a6b0853 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -2,11 +2,8 @@ package middleware import ( "errors" - "fmt" - "net/http" - "strings" - "github.com/labstack/echo/v4" + "net/http" ) type ( @@ -15,15 +12,21 @@ type ( // Skipper defines a function to skip middleware. Skipper Skipper - // KeyLookup is a string in the form of ":" that is used + // KeyLookup is a string in the form of ":" or ":,:" that is used // to extract key from the request. // Optional. Default value "header:Authorization". // Possible values: - // - "header:" + // - "header:" or "header::" + // `` is argument value to cut/trim prefix of the extracted value. This is useful if header + // value has static prefix like `Authorization: ` where part that we + // want to cut is ` ` note the space at the end. + // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. // - "query:" // - "form:" // - "cookie:" - KeyLookup string `yaml:"key_lookup"` + // Multiple sources example: + // - "header:Authorization,header:X-Api-Key" + KeyLookup string // AuthScheme to be used in the Authorization header. // Optional. Default value "Bearer". @@ -36,15 +39,20 @@ type ( // ErrorHandler defines a function which is executed for an invalid key. // It may be used to define a custom error. ErrorHandler KeyAuthErrorHandler + + // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to + // ignore the error (by returning `nil`). + // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. + // In that case you can use ErrorHandler to set a default public key auth value in the request context + // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. + ContinueOnIgnoredError bool } // KeyAuthValidator defines a function to validate KeyAuth credentials. - KeyAuthValidator func(string, echo.Context) (bool, error) - - keyExtractor func(echo.Context) (string, error) + KeyAuthValidator func(auth string, c echo.Context) (bool, error) // KeyAuthErrorHandler defines a function which is executed for an invalid key. - KeyAuthErrorHandler func(error, echo.Context) error + KeyAuthErrorHandler func(err error, c echo.Context) error ) var ( @@ -56,6 +64,21 @@ var ( } ) +// ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups +type ErrKeyAuthMissing struct { + Err error +} + +// Error returns errors text +func (e *ErrKeyAuthMissing) Error() string { + return e.Err.Error() +} + +// Unwrap unwraps error +func (e *ErrKeyAuthMissing) Unwrap() error { + return e.Err +} + // KeyAuth returns an KeyAuth middleware. // // For valid key it calls the next handler. @@ -85,16 +108,9 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { panic("echo: key-auth middleware requires a validator function") } - // Initialize - parts := strings.Split(config.KeyLookup, ":") - extractor := keyFromHeader(parts[1], config.AuthScheme) - switch parts[0] { - case "query": - extractor = keyFromQuery(parts[1]) - case "form": - extractor = keyFromForm(parts[1]) - case "cookie": - extractor = keyFromCookie(parts[1]) + extractors, err := createExtractors(config.KeyLookup, config.AuthScheme) + if err != nil { + panic(err) } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -103,79 +119,62 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { return next(c) } - // Extract and verify key - key, err := extractor(c) - if err != nil { - if config.ErrorHandler != nil { - return config.ErrorHandler(err, c) + var lastExtractorErr error + var lastValidatorErr error + for _, extractor := range extractors { + keys, err := extractor(c) + if err != nil { + lastExtractorErr = err + continue + } + for _, key := range keys { + valid, err := config.Validator(key, c) + if err != nil { + lastValidatorErr = err + continue + } + if valid { + return next(c) + } + lastValidatorErr = errors.New("invalid key") } - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - valid, err := config.Validator(key, c) - if err != nil { - if config.ErrorHandler != nil { - return config.ErrorHandler(err, c) + + // we are here only when we did not successfully extract and validate any of keys + err := lastValidatorErr + if err == nil { // prioritize validator errors over extracting errors + // ugly part to preserve backwards compatible errors. someone could rely on them + if lastExtractorErr == errQueryExtractorValueMissing { + err = errors.New("missing key in the query string") + } else if lastExtractorErr == errCookieExtractorValueMissing { + err = errors.New("missing key in cookies") + } else if lastExtractorErr == errFormExtractorValueMissing { + err = errors.New("missing key in the form") + } else if lastExtractorErr == errHeaderExtractorValueMissing { + err = errors.New("missing key in request header") + } else if lastExtractorErr == errHeaderExtractorValueInvalid { + err = errors.New("invalid key in the request header") + } else { + err = lastExtractorErr } + err = &ErrKeyAuthMissing{Err: err} + } + + if config.ErrorHandler != nil { + tmpErr := config.ErrorHandler(err, c) + if config.ContinueOnIgnoredError && tmpErr == nil { + return next(c) + } + return tmpErr + } + if lastValidatorErr != nil { // prioritize validator errors over extracting errors return &echo.HTTPError{ Code: http.StatusUnauthorized, - Message: "invalid key", - Internal: err, + Message: "Unauthorized", + Internal: lastValidatorErr, } - } else if valid { - return next(c) } - return echo.ErrUnauthorized - } - } -} - -// keyFromHeader returns a `keyExtractor` that extracts key from the request header. -func keyFromHeader(header string, authScheme string) keyExtractor { - return func(c echo.Context) (string, error) { - auth := c.Request().Header.Get(header) - if auth == "" { - return "", errors.New("missing key in request header") - } - if header == echo.HeaderAuthorization { - l := len(authScheme) - if len(auth) > l+1 && auth[:l] == authScheme { - return auth[l+1:], nil - } - return "", errors.New("invalid key in the request header") - } - return auth, nil - } -} - -// keyFromQuery returns a `keyExtractor` that extracts key from the query string. -func keyFromQuery(param string) keyExtractor { - return func(c echo.Context) (string, error) { - key := c.QueryParam(param) - if key == "" { - return "", errors.New("missing key in the query string") - } - return key, nil - } -} - -// keyFromForm returns a `keyExtractor` that extracts key from the form. -func keyFromForm(param string) keyExtractor { - return func(c echo.Context) (string, error) { - key := c.FormValue(param) - if key == "" { - return "", errors.New("missing key in the form") - } - return key, nil - } -} - -// keyFromCookie returns a `keyExtractor` that extracts key from the form. -func keyFromCookie(cookieName string) keyExtractor { - return func(c echo.Context) (string, error) { - key, err := c.Cookie(cookieName) - if err != nil { - return "", fmt.Errorf("missing key in cookies: %w", err) + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - return key.Value, nil } } diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go index 0cc513ab0..ff8968c38 100644 --- a/middleware/key_auth_test.go +++ b/middleware/key_auth_test.go @@ -76,7 +76,7 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key") }, expectHandlerCalled: false, - expectError: "code=401, message=Unauthorized", + expectError: "code=401, message=Unauthorized, internal=invalid key", }, { name: "nok, defaults, invalid scheme in header", @@ -92,6 +92,17 @@ func TestKeyAuthWithConfig(t *testing.T) { expectHandlerCalled: false, expectError: "code=400, message=missing key in request header", }, + { + name: "ok, custom key lookup from multiple places, query and header", + givenRequest: func(req *http.Request) { + req.URL.RawQuery = "key=invalid-key" + req.Header.Set("API-Key", "valid-key") + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "query:key,header:API-Key" + }, + expectHandlerCalled: true, + }, { name: "ok, custom key lookup, header", givenRequest: func(req *http.Request) { @@ -179,7 +190,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "cookie:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in cookies: http: named cookie not present", + expectError: "code=400, message=missing key in cookies", }, { name: "nok, custom errorHandler, error from extractor", @@ -216,7 +227,7 @@ func TestKeyAuthWithConfig(t *testing.T) { }, whenConfig: func(conf *KeyAuthConfig) {}, expectHandlerCalled: false, - expectError: "code=401, message=invalid key, internal=some user defined error", + expectError: "code=401, message=Unauthorized, internal=some user defined error", }, } @@ -257,3 +268,109 @@ func TestKeyAuthWithConfig(t *testing.T) { }) } } + +func TestKeyAuthWithConfig_panicsOnInvalidLookup(t *testing.T) { + assert.PanicsWithError( + t, + "extractor source for lookup could not be split into needed parts: a", + func() { + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + KeyAuthWithConfig(KeyAuthConfig{ + Validator: testKeyValidator, + KeyLookup: "a", + })(handler) + }, + ) +} + +func TestKeyAuthWithConfig_panicsOnEmptyValidator(t *testing.T) { + assert.PanicsWithValue( + t, + "echo: key-auth middleware requires a validator function", + func() { + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + KeyAuthWithConfig(KeyAuthConfig{ + Validator: nil, + })(handler) + }, + ) +} + +func TestKeyAuthWithConfig_ContinueOnIgnoredError(t *testing.T) { + var testCases = []struct { + name string + whenContinueOnIgnoredError bool + givenKey string + expectStatus int + expectBody string + }{ + { + name: "no error handler is called", + whenContinueOnIgnoredError: true, + givenKey: "valid-key", + expectStatus: http.StatusTeapot, + expectBody: "", + }, + { + name: "ContinueOnIgnoredError is false and error handler is called for missing token", + whenContinueOnIgnoredError: false, + givenKey: "", + // empty response with 200. This emulates previous behaviour when error handler swallowed the error + expectStatus: http.StatusOK, + expectBody: "", + }, + { + name: "error handler is called for missing token", + whenContinueOnIgnoredError: true, + givenKey: "", + expectStatus: http.StatusTeapot, + expectBody: "public-auth", + }, + { + name: "error handler is called for invalid token", + whenContinueOnIgnoredError: true, + givenKey: "x.x.x", + expectStatus: http.StatusUnauthorized, + expectBody: "{\"message\":\"Unauthorized\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + e.GET("/", func(c echo.Context) error { + testValue, _ := c.Get("test").(string) + return c.String(http.StatusTeapot, testValue) + }) + + e.Use(KeyAuthWithConfig(KeyAuthConfig{ + Validator: testKeyValidator, + ErrorHandler: func(err error, c echo.Context) error { + if _, ok := err.(*ErrKeyAuthMissing); ok { + c.Set("test", "public-auth") + return nil + } + return echo.ErrUnauthorized + }, + KeyLookup: "header:X-API-Key", + ContinueOnIgnoredError: tc.whenContinueOnIgnoredError, + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenKey != "" { + req.Header.Set("X-API-Key", tc.givenKey) + } + res := httptest.NewRecorder() + + e.ServeHTTP(res, req) + + assert.Equal(t, tc.expectStatus, res.Code) + assert.Equal(t, tc.expectBody, res.Body.String()) + }) + } +} diff --git a/middleware/middleware.go b/middleware/middleware.go index a7ad73a5c..f250ca49a 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -12,10 +12,10 @@ import ( type ( // Skipper defines a function to skip middleware. Returning true skips processing // the middleware. - Skipper func(echo.Context) bool + Skipper func(c echo.Context) bool // BeforeFunc defines a function which is executed just before the middleware. - BeforeFunc func(echo.Context) + BeforeFunc func(c echo.Context) ) func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { From 6cb3b7c046b9fe73f0a1341cd2eea8071a60cc04 Mon Sep 17 00:00:00 2001 From: eric <65116642+nonbutAworker@users.noreply.github.com> Date: Wed, 23 Feb 2022 15:22:20 +0800 Subject: [PATCH 202/446] remove redundant 0 in make chan (#2101) * remove 0 in make(chan,0) to fix go-staticcheck problem --- echo_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/echo_test.go b/echo_test.go index f175d765b..d31e7b604 100644 --- a/echo_test.go +++ b/echo_test.go @@ -961,7 +961,7 @@ func TestEchoStartTLSByteString(t *testing.T) { e := New() e.HideBanner = true - errChan := make(chan error, 0) + errChan := make(chan error) go func() { errChan <- e.StartTLS(":0", test.cert, test.key) @@ -999,7 +999,7 @@ func TestEcho_StartAutoTLS(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() - errChan := make(chan error, 0) + errChan := make(chan error) go func() { errChan <- e.StartAutoTLS(tc.addr) From 27b404bbc5290de56044a906c9f1692a08b64e29 Mon Sep 17 00:00:00 2001 From: eric <65116642+nonbutAworker@users.noreply.github.com> Date: Wed, 23 Feb 2022 19:28:20 +0800 Subject: [PATCH 203/446] remove unused notFoundHandler in echo struct (#2102) * remove unused notFoundHandler in echo --- echo.go | 1 - 1 file changed, 1 deletion(-) diff --git a/echo.go b/echo.go index 2e63cc6b1..88a732ee7 100644 --- a/echo.go +++ b/echo.go @@ -75,7 +75,6 @@ type ( maxParam *int router *Router routers map[string]*Router - notFoundHandler HandlerFunc pool sync.Pool Server *http.Server TLSServer *http.Server From 124825ee629f32aade886f1aeb76e0c6f70c7faa Mon Sep 17 00:00:00 2001 From: Yusuf Eyisan Date: Tue, 1 Mar 2022 10:56:46 +0300 Subject: [PATCH 204/446] Bugfix/1834 Fix X-Real-IP bug (#2007) * Fix incorrect return ip value for RealIpHeader * Improve test file to compare correct real IPs to each other and have better comments * Refactor ip extractor tests to be more readable (longer but readable) Co-authored-by: toimtoimtoim --- echo.go | 6 +- ip.go | 142 +++++++++- ip_test.go | 777 +++++++++++++++++++++++++++++++++++++++-------------- 3 files changed, 711 insertions(+), 214 deletions(-) diff --git a/echo.go b/echo.go index 88a732ee7..b658de4d7 100644 --- a/echo.go +++ b/echo.go @@ -214,9 +214,9 @@ const ( HeaderXForwardedSsl = "X-Forwarded-Ssl" HeaderXUrlScheme = "X-Url-Scheme" HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" - HeaderXRealIP = "X-Real-IP" - HeaderXRequestID = "X-Request-ID" - HeaderXCorrelationID = "X-Correlation-ID" + HeaderXRealIP = "X-Real-Ip" + HeaderXRequestID = "X-Request-Id" + HeaderXCorrelationID = "X-Correlation-Id" HeaderXRequestedWith = "X-Requested-With" HeaderServer = "Server" HeaderOrigin = "Origin" diff --git a/ip.go b/ip.go index 39cb421fd..46d464cf9 100644 --- a/ip.go +++ b/ip.go @@ -6,6 +6,130 @@ import ( "strings" ) +/** +By: https://github.com/tmshn (See: https://github.com/labstack/echo/pull/1478 , https://github.com/labstack/echox/pull/134 ) +Source: https://echo.labstack.com/guide/ip-address/ + +IP address plays fundamental role in HTTP; it's used for access control, auditing, geo-based access analysis and more. +Echo provides handy method [`Context#RealIP()`](https://godoc.org/github.com/labstack/echo#Context) for that. + +However, it is not trivial to retrieve the _real_ IP address from requests especially when you put L7 proxies before the application. +In such situation, _real_ IP needs to be relayed on HTTP layer from proxies to your app, but you must not trust HTTP headers unconditionally. +Otherwise, you might give someone a chance of deceiving you. **A security risk!** + +To retrieve IP address reliably/securely, you must let your application be aware of the entire architecture of your infrastructure. +In Echo, this can be done by configuring `Echo#IPExtractor` appropriately. +This guides show you why and how. + +> Note: if you dont' set `Echo#IPExtractor` explicitly, Echo fallback to legacy behavior, which is not a good choice. + +Let's start from two questions to know the right direction: + +1. Do you put any HTTP (L7) proxy in front of the application? + - It includes both cloud solutions (such as AWS ALB or GCP HTTP LB) and OSS ones (such as Nginx, Envoy or Istio ingress gateway). +2. If yes, what HTTP header do your proxies use to pass client IP to the application? + +## Case 1. With no proxy + +If you put no proxy (e.g.: directory facing to the internet), all you need to (and have to) see is IP address from network layer. +Any HTTP header is untrustable because the clients have full control what headers to be set. + +In this case, use `echo.ExtractIPDirect()`. + +```go +e.IPExtractor = echo.ExtractIPDirect() +``` + +## Case 2. With proxies using `X-Forwarded-For` header + +[`X-Forwared-For` (XFF)](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For) is the popular header +to relay clients' IP addresses. +At each hop on the proxies, they append the request IP address at the end of the header. + +Following example diagram illustrates this behavior. + +```text +┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ +│ "Origin" │───────────>│ Proxy 1 │───────────>│ Proxy 2 │───────────>│ Your app │ +│ (IP: a) │ │ (IP: b) │ │ (IP: c) │ │ │ +└──────────┘ └──────────┘ └──────────┘ └──────────┘ + +Case 1. +XFF: "" "a" "a, b" + ~~~~~~ +Case 2. +XFF: "x" "x, a" "x, a, b" + ~~~~~~~~~ + ↑ What your app will see +``` + +In this case, use **first _untrustable_ IP reading from right**. Never use first one reading from left, as it is +configurable by client. Here "trustable" means "you are sure the IP address belongs to your infrastructre". +In above example, if `b` and `c` are trustable, the IP address of the client is `a` for both cases, never be `x`. + +In Echo, use `ExtractIPFromXFFHeader(...TrustOption)`. + +```go +e.IPExtractor = echo.ExtractIPFromXFFHeader() +``` + +By default, it trusts internal IP addresses (loopback, link-local unicast, private-use and unique local address +from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and +[RFC4193](https://tools.ietf.org/html/rfc4193)). +To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s. + +E.g.: + +```go +e.IPExtractor = echo.ExtractIPFromXFFHeader( + TrustLinkLocal(false), + TrustIPRanges(lbIPRange), +) +``` + +- Ref: https://godoc.org/github.com/labstack/echo#TrustOption + +## Case 3. With proxies using `X-Real-IP` header + +`X-Real-IP` is another HTTP header to relay clients' IP addresses, but it carries only one address unlike XFF. + +If your proxies set this header, use `ExtractIPFromRealIPHeader(...TrustOption)`. + +```go +e.IPExtractor = echo.ExtractIPFromRealIPHeader() +``` + +Again, it trusts internal IP addresses by default (loopback, link-local unicast, private-use and unique local address +from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and +[RFC4193](https://tools.ietf.org/html/rfc4193)). +To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s. + +- Ref: https://godoc.org/github.com/labstack/echo#TrustOption + +> **Never forget** to configure the outermost proxy (i.e.; at the edge of your infrastructure) **not to pass through incoming headers**. +> Otherwise there is a chance of fraud, as it is what clients can control. + +## About default behavior + +In default behavior, Echo sees all of first XFF header, X-Real-IP header and IP from network layer. + +As you might already notice, after reading this article, this is not good. +Sole reason this is default is just backward compatibility. + +## Private IP ranges + +See: https://en.wikipedia.org/wiki/Private_network + +Private IPv4 address ranges (RFC 1918): +* 10.0.0.0 – 10.255.255.255 (24-bit block) +* 172.16.0.0 – 172.31.255.255 (20-bit block) +* 192.168.0.0 – 192.168.255.255 (16-bit block) + +Private IPv6 address ranges: +* fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA) + +*/ + type ipChecker struct { trustLoopback bool trustLinkLocal bool @@ -52,6 +176,7 @@ func newIPChecker(configs []TrustOption) *ipChecker { return checker } +// Go1.16+ added `ip.IsPrivate()` but until that use this implementation func isPrivateIPRange(ip net.IP) bool { if ip4 := ip.To4(); ip4 != nil { return ip4[0] == 10 || @@ -87,10 +212,12 @@ type IPExtractor func(*http.Request) string // ExtractIPDirect extracts IP address using actual IP address. // Use this if your server faces to internet directory (i.e.: uses no proxy). func ExtractIPDirect() IPExtractor { - return func(req *http.Request) string { - ra, _, _ := net.SplitHostPort(req.RemoteAddr) - return ra - } + return extractIP +} + +func extractIP(req *http.Request) string { + ra, _, _ := net.SplitHostPort(req.RemoteAddr) + return ra } // ExtractIPFromRealIPHeader extracts IP address using x-real-ip header. @@ -98,14 +225,13 @@ func ExtractIPDirect() IPExtractor { func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { checker := newIPChecker(options) return func(req *http.Request) string { - directIP := ExtractIPDirect()(req) realIP := req.Header.Get(HeaderXRealIP) if realIP != "" { - if ip := net.ParseIP(directIP); ip != nil && checker.trust(ip) { + if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) { return realIP } } - return directIP + return extractIP(req) } } @@ -115,7 +241,7 @@ func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor { checker := newIPChecker(options) return func(req *http.Request) string { - directIP := ExtractIPDirect()(req) + directIP := extractIP(req) xffs := req.Header[HeaderXForwardedFor] if len(xffs) == 0 { return directIP diff --git a/ip_test.go b/ip_test.go index 5acc11798..755900d3d 100644 --- a/ip_test.go +++ b/ip_test.go @@ -1,235 +1,606 @@ package echo import ( + "github.com/stretchr/testify/assert" "net" "net/http" - "strings" "testing" - - testify "github.com/stretchr/testify/assert" ) -const ( - // For RemoteAddr - ipForRemoteAddrLoopback = "127.0.0.1" // From 127.0.0.0/8 - sampleRemoteAddrLoopback = ipForRemoteAddrLoopback + ":8080" - ipForRemoteAddrExternal = "203.0.113.1" - sampleRemoteAddrExternal = ipForRemoteAddrExternal + ":8080" - // For x-real-ip - ipForRealIP = "203.0.113.10" - // For XFF - ipForXFF1LinkLocal = "169.254.0.101" // From 169.254.0.0/16 - ipForXFF2Private = "192.168.0.102" // From 192.168.0.0/16 - ipForXFF3External = "2001:db8::103" - ipForXFF4Private = "fc00::104" // From fc00::/7 - ipForXFF5External = "198.51.100.105" - ipForXFF6External = "192.0.2.106" - ipForXFFBroken = "this.is.broken.lol" - // keys for test cases - ipTestReqKeyNoHeader = "no header" - ipTestReqKeyRealIPExternal = "x-real-ip; remote addr external" - ipTestReqKeyRealIPInternal = "x-real-ip; remote addr internal" - ipTestReqKeyRealIPAndXFFExternal = "x-real-ip and xff; remote addr external" - ipTestReqKeyRealIPAndXFFInternal = "x-real-ip and xff; remote addr internal" - ipTestReqKeyXFFExternal = "xff; remote addr external" - ipTestReqKeyXFFInternal = "xff; remote addr internal" - ipTestReqKeyBrokenXFF = "broken xff" -) +func mustParseCIDR(s string) *net.IPNet { + _, IPNet, err := net.ParseCIDR(s) + if err != nil { + panic(err) + } + return IPNet +} + +func TestIPChecker_TrustOption(t *testing.T) { + var testCases = []struct { + name string + givenOptions []TrustOption + whenIP string + expect bool + }{ + { + name: "ip is within trust range, trusts additional private IPV6 network", + givenOptions: []TrustOption{ + TrustLoopback(false), + TrustLinkLocal(false), + TrustPrivateNet(false), + // this is private IPv6 ip + // CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48 + // Address: 2001:0db8:0000:0000:0000:0000:0000:0103 + // Range start: 2001:0db8:0000:0000:0000:0000:0000:0000 + // Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff + TrustIPRange(mustParseCIDR("2001:db8::103/48")), + }, + whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", + expect: true, + }, + { + name: "ip is within trust range, trusts additional private IPV6 network", + givenOptions: []TrustOption{ + TrustIPRange(mustParseCIDR("2001:db8::103/48")), + }, + whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", + expect: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker(tc.givenOptions) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustIPRange(t *testing.T) { + var testCases = []struct { + name string + givenRange string + whenIP string + expect bool + }{ + { + name: "ip is within trust range, IPV6 network range", + // CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48 + // Address: 2001:0db8:0000:0000:0000:0000:0000:0103 + // Range start: 2001:0db8:0000:0000:0000:0000:0000:0000 + // Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff + givenRange: "2001:db8::103/48", + whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", + expect: true, + }, + { + name: "ip is outside (upper bounds) of trust range, IPV6 network range", + givenRange: "2001:db8::103/48", + whenIP: "2001:0db8:0001:0000:0000:0000:0000:0000", + expect: false, + }, + { + name: "ip is outside (lower bounds) of trust range, IPV6 network range", + givenRange: "2001:db8::103/48", + whenIP: "2001:0db7:ffff:ffff:ffff:ffff:ffff:ffff", + expect: false, + }, + { + name: "ip is within trust range, IPV4 network range", + // CIDR Notation: 8.8.8.8/24 + // Address: 8.8.8.8 + // Range start: 8.8.8.0 + // Range end: 8.8.8.255 + givenRange: "8.8.8.0/24", + whenIP: "8.8.8.8", + expect: true, + }, + { + name: "ip is within trust range, IPV4 network range", + // CIDR Notation: 8.8.8.8/24 + // Address: 8.8.8.8 + // Range start: 8.8.8.0 + // Range end: 8.8.8.255 + givenRange: "8.8.8.0/24", + whenIP: "8.8.8.8", + expect: true, + }, + { + name: "ip is outside (upper bounds) of trust range, IPV4 network range", + givenRange: "8.8.8.0/24", + whenIP: "8.8.9.0", + expect: false, + }, + { + name: "ip is outside (lower bounds) of trust range, IPV4 network range", + givenRange: "8.8.8.0/24", + whenIP: "8.8.7.255", + expect: false, + }, + { + name: "public ip, trust everything in IPV4 network range", + givenRange: "0.0.0.0/0", + whenIP: "8.8.8.8", + expect: true, + }, + { + name: "internal ip, trust everything in IPV4 network range", + givenRange: "0.0.0.0/0", + whenIP: "127.0.10.1", + expect: true, + }, + { + name: "public ip, trust everything in IPV6 network range", + givenRange: "::/0", + whenIP: "2a00:1450:4026:805::200e", + expect: true, + }, + { + name: "internal ip, trust everything in IPV6 network range", + givenRange: "::/0", + whenIP: "0:0:0:0:0:0:0:1", + expect: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cidr := mustParseCIDR(tc.givenRange) + checker := newIPChecker([]TrustOption{ + TrustLoopback(false), // disable to avoid interference + TrustLinkLocal(false), // disable to avoid interference + TrustPrivateNet(false), // disable to avoid interference + + TrustIPRange(cidr), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} -var ( - sampleXFF = strings.Join([]string{ - ipForXFF6External, ipForXFF5External, ipForXFF4Private, ipForXFF3External, ipForXFF2Private, ipForXFF1LinkLocal, - }, ", ") +func TestTrustPrivateNet(t *testing.T) { + var testCases = []struct { + name string + whenIP string + expect bool + }{ + { + name: "do not trust public IPv4 address", + whenIP: "8.8.8.8", + expect: false, + }, + { + name: "do not trust public IPv6 address", + whenIP: "2a00:1450:4026:805::200e", + expect: false, + }, - requests = map[string]*http.Request{ - ipTestReqKeyNoHeader: &http.Request{ - RemoteAddr: sampleRemoteAddrExternal, + { // Class A: 10.0.0.0 — 10.255.255.255 + name: "do not trust IPv4 just outside of class A (lower bounds)", + whenIP: "9.255.255.255", + expect: false, + }, + { + name: "do not trust IPv4 just outside of class A (upper bounds)", + whenIP: "11.0.0.0", + expect: false, + }, + { + name: "trust IPv4 of class A (lower bounds)", + whenIP: "10.0.0.0", + expect: true, + }, + { + name: "trust IPv4 of class A (upper bounds)", + whenIP: "10.255.255.255", + expect: true, + }, + + { // Class B: 172.16.0.0 — 172.31.255.255 + name: "do not trust IPv4 just outside of class B (lower bounds)", + whenIP: "172.15.255.255", + expect: false, + }, + { + name: "do not trust IPv4 just outside of class B (upper bounds)", + whenIP: "172.32.0.0", + expect: false, + }, + { + name: "trust IPv4 of class B (lower bounds)", + whenIP: "172.16.0.0", + expect: true, + }, + { + name: "trust IPv4 of class B (upper bounds)", + whenIP: "172.31.255.255", + expect: true, + }, + + { // Class C: 192.168.0.0 — 192.168.255.255 + name: "do not trust IPv4 just outside of class C (lower bounds)", + whenIP: "192.167.255.255", + expect: false, + }, + { + name: "do not trust IPv4 just outside of class C (upper bounds)", + whenIP: "192.169.0.0", + expect: false, + }, + { + name: "trust IPv4 of class C (lower bounds)", + whenIP: "192.168.0.0", + expect: true, + }, + { + name: "trust IPv4 of class C (upper bounds)", + whenIP: "192.168.255.255", + expect: true, + }, + + { // fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA) + // splits the address block in two equally sized halves, fc00::/8 and fd00::/8. + // https://en.wikipedia.org/wiki/Unique_local_address + name: "trust IPv6 private address", + whenIP: "fdfc:3514:2cb3:4bd5::", + expect: true, + }, + { + name: "do not trust IPv6 just out of /fd (upper bounds)", + whenIP: "/fe00:0000:0000:0000:0000", + expect: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker([]TrustOption{ + TrustLoopback(false), // disable to avoid interference + TrustLinkLocal(false), // disable to avoid interference + + TrustPrivateNet(true), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustLinkLocal(t *testing.T) { + var testCases = []struct { + name string + whenIP string + expect bool + }{ + { + name: "trust link local IPv4 address (lower bounds)", + whenIP: "169.254.0.0", + expect: true, + }, + { + name: "trust link local IPv4 address (upper bounds)", + whenIP: "169.254.255.255", + expect: true, }, - ipTestReqKeyRealIPExternal: &http.Request{ - Header: http.Header{ - "X-Real-Ip": []string{ipForRealIP}, + { + name: "do not trust link local IPv4 address (outside of lower bounds)", + whenIP: "169.253.255.255", + expect: false, + }, + { + name: "do not trust link local IPv4 address (outside of upper bounds)", + whenIP: "169.255.0.0", + expect: false, + }, + { + name: "trust link local IPv6 address ", + whenIP: "fe80::1", + expect: true, + }, + { + name: "do not trust link local IPv6 address ", + whenIP: "fec0::1", + expect: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker([]TrustOption{ + TrustLoopback(false), // disable to avoid interference + TrustPrivateNet(false), // disable to avoid interference + + TrustLinkLocal(true), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustLoopback(t *testing.T) { + var testCases = []struct { + name string + whenIP string + expect bool + }{ + { + name: "trust IPv4 as localhost", + whenIP: "127.0.0.1", + expect: true, + }, + { + name: "trust IPv6 as localhost", + whenIP: "::1", + expect: true, + }, + { + name: "do not trust public ip as localhost", + whenIP: "8.8.8.8", + expect: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker([]TrustOption{ + TrustLinkLocal(false), // disable to avoid interference + TrustPrivateNet(false), // disable to avoid interference + + TrustLoopback(true), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestExtractIPDirect(t *testing.T) { + var testCases = []struct { + name string + whenRequest http.Request + expectIP string + }{ + { + name: "request has no headers, extracts IP from request remote addr", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1:8080", }, - RemoteAddr: sampleRemoteAddrExternal, + expectIP: "203.0.113.1", }, - ipTestReqKeyRealIPInternal: &http.Request{ - Header: http.Header{ - "X-Real-Ip": []string{ipForRealIP}, + { + name: "request is from external IP has X-Real-Ip header, extractor still extracts IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.10"}, + }, + RemoteAddr: "203.0.113.1:8080", }, - RemoteAddr: sampleRemoteAddrLoopback, + expectIP: "203.0.113.1", }, - ipTestReqKeyRealIPAndXFFExternal: &http.Request{ - Header: http.Header{ - "X-Real-Ip": []string{ipForRealIP}, - HeaderXForwardedFor: []string{sampleXFF}, + { + name: "request is from internal IP and has Real-IP header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.10"}, + }, + RemoteAddr: "127.0.0.1:8080", }, - RemoteAddr: sampleRemoteAddrExternal, + expectIP: "127.0.0.1", }, - ipTestReqKeyRealIPAndXFFInternal: &http.Request{ - Header: http.Header{ - "X-Real-Ip": []string{ipForRealIP}, - HeaderXForwardedFor: []string{sampleXFF}, + { + name: "request is from external IP and has XFF + Real-IP header, extractor still extracts external IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.10"}, + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "203.0.113.1:8080", }, - RemoteAddr: sampleRemoteAddrLoopback, + expectIP: "203.0.113.1", }, - ipTestReqKeyXFFExternal: &http.Request{ - Header: http.Header{ - HeaderXForwardedFor: []string{sampleXFF}, + { + name: "request is from internal IP and has XFF + Real-IP header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"127.0.0.1"}, + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "127.0.0.1:8080", }, - RemoteAddr: sampleRemoteAddrExternal, + expectIP: "127.0.0.1", }, - ipTestReqKeyXFFInternal: &http.Request{ - Header: http.Header{ - HeaderXForwardedFor: []string{sampleXFF}, + { + name: "request is from external IP and has XFF header, extractor still extracts external IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "203.0.113.1:8080", }, - RemoteAddr: sampleRemoteAddrLoopback, + expectIP: "203.0.113.1", }, - ipTestReqKeyBrokenXFF: &http.Request{ - Header: http.Header{ - HeaderXForwardedFor: []string{ipForXFFBroken + ", " + ipForXFF1LinkLocal}, + { + name: "request is from internal IP and has XFF header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "127.0.0.1:8080", }, - RemoteAddr: sampleRemoteAddrLoopback, + expectIP: "127.0.0.1", + }, + { + name: "request is from internal IP and has INVALID XFF header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"this.is.broken.lol, 169.254.0.101"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", }, } -) -func TestExtractIP(t *testing.T) { - _, ipv4AllRange, _ := net.ParseCIDR("0.0.0.0/0") - _, ipv6AllRange, _ := net.ParseCIDR("::/0") - _, ipForXFF3ExternalRange, _ := net.ParseCIDR(ipForXFF3External + "/48") - _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR(ipForRemoteAddrExternal + "/24") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + extractedIP := ExtractIPDirect()(&tc.whenRequest) + assert.Equal(t, tc.expectIP, extractedIP) + }) + } +} + +func TestExtractIPFromRealIPHeader(t *testing.T) { + _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") - tests := map[string]*struct { - extractor IPExtractor - expectedIPs map[string]string + var testCases = []struct { + name string + givenTrustOptions []TrustOption + whenRequest http.Request + expectIP string }{ - "ExtractIPDirect": { - ExtractIPDirect(), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromRealIPHeader(default)": { - ExtractIPFromRealIPHeader(), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRealIP, - ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPAndXFFInternal: ipForRealIP, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromRealIPHeader(trust only direct-facing proxy)": { - ExtractIPFromRealIPHeader(TrustLoopback(false), TrustLinkLocal(false), TrustPrivateNet(false), TrustIPRange(ipForRemoteAddrExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRealIP, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForRealIP, - ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromRealIPHeader(trust direct-facing proxy)": { - ExtractIPFromRealIPHeader(TrustIPRange(ipForRemoteAddrExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRealIP, - ipTestReqKeyRealIPInternal: ipForRealIP, - ipTestReqKeyRealIPAndXFFExternal: ipForRealIP, - ipTestReqKeyRealIPAndXFFInternal: ipForRealIP, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(default)": { - ExtractIPFromXFFHeader(), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPAndXFFInternal: ipForXFF3External, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForXFF3External, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(trust only direct-facing proxy)": { - ExtractIPFromXFFHeader(TrustLoopback(false), TrustLinkLocal(false), TrustPrivateNet(false), TrustIPRange(ipForRemoteAddrExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForXFF1LinkLocal, - ipTestReqKeyRealIPAndXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyXFFExternal: ipForXFF1LinkLocal, - ipTestReqKeyXFFInternal: ipForRemoteAddrLoopback, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(trust direct-facing proxy)": { - ExtractIPFromXFFHeader(TrustIPRange(ipForRemoteAddrExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForXFF3External, - ipTestReqKeyRealIPAndXFFInternal: ipForXFF3External, - ipTestReqKeyXFFExternal: ipForXFF3External, - ipTestReqKeyXFFInternal: ipForXFF3External, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(trust everything)": { - // This is similar to legacy behavior, but ignores x-real-ip header. - ExtractIPFromXFFHeader(TrustIPRange(ipv4AllRange), TrustIPRange(ipv6AllRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForXFF6External, - ipTestReqKeyRealIPAndXFFInternal: ipForXFF6External, - ipTestReqKeyXFFExternal: ipForXFF6External, - ipTestReqKeyXFFInternal: ipForXFF6External, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - "ExtractIPFromXFFHeader(trust ipForXFF3External)": { - // This trusts private network also after "additional" trust ranges unlike `TrustNProxies(1)` doesn't - ExtractIPFromXFFHeader(TrustIPRange(ipForXFF3ExternalRange)), - map[string]string{ - ipTestReqKeyNoHeader: ipForRemoteAddrExternal, - ipTestReqKeyRealIPExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPInternal: ipForRemoteAddrLoopback, - ipTestReqKeyRealIPAndXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyRealIPAndXFFInternal: ipForXFF5External, - ipTestReqKeyXFFExternal: ipForRemoteAddrExternal, - ipTestReqKeyXFFInternal: ipForXFF5External, - ipTestReqKeyBrokenXFF: ipForRemoteAddrLoopback, - }, - }, - } - for name, test := range tests { - t.Run(name, func(t *testing.T) { - assert := testify.New(t) - for key, req := range requests { - actual := test.extractor(req) - expected := test.expectedIPs[key] - assert.Equal(expected, actual, "Request: %s", key) - } + { + name: "request has no headers, extracts IP from request remote addr", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has INVALID external X-Real-Ip header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"xxx.yyy.zzz.ccc"}, // <-- this is invalid + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.199"}, // <-- this is untrusted + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.199"}, + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.199", + }, + { + name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.199"}, + HeaderXForwardedFor: []string{"203.0.113.198, 203.0.113.197"}, // <-- should not affect anything + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.199", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + extractedIP := ExtractIPFromRealIPHeader(tc.givenTrustOptions...)(&tc.whenRequest) + assert.Equal(t, tc.expectIP, extractedIP) + }) + } +} + +func TestExtractIPFromXFFHeader(t *testing.T) { + _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") + + var testCases = []struct { + name string + givenTrustOptions []TrustOption + whenRequest http.Request + expectIP string + }{ + { + name: "request has no headers, extracts IP from request remote addr", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request has INVALID external XFF header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"xxx.yyy.zzz.ccc, 127.0.0.2"}, // <-- this is invalid + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + { + name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"127.0.0.3, 127.0.0.2, 127.0.0.1"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.3", + }, + { + name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"203.0.113.199"}, // <-- this is untrusted + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header", + givenTrustOptions: []TrustOption{ + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, + // from request its seems that request has been proxied through 6 servers. + // 1) 203.0.1.100 (this is external IP set by 203.0.100.100 which we do not trust - could be spoofed) + // 2) 203.0.100.100 (this is outside of our network but set by 203.0.113.199 which we trust to set correct IPs) + // 3) 203.0.113.199 (we trust, for example maybe our proxy from some other office) + // 4) 192.168.1.100 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products) + // 5) 127.0.0.1 (is proxy on localhost. maybe we have Nginx in front of our Echo instance doing some routing) + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"203.0.1.100, 203.0.100.100, 203.0.113.199, 192.168.1.100"}, + }, + RemoteAddr: "127.0.0.1:8080", // IP of proxy upstream of our APP + }, + expectIP: "203.0.100.100", // this is first trusted IP in XFF chain + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + extractedIP := ExtractIPFromXFFHeader(tc.givenTrustOptions...)(&tc.whenRequest) + assert.Equal(t, tc.expectIP, extractedIP) }) } } From 7e719b46e290993a7f396819808998e3ae0becf4 Mon Sep 17 00:00:00 2001 From: Wagner Souza Date: Tue, 1 Mar 2022 20:11:28 -0300 Subject: [PATCH 205/446] Add cache-control and connection headers (#2103) Co-authored-by: Wagner Souza --- echo.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/echo.go b/echo.go index b658de4d7..6143403f3 100644 --- a/echo.go +++ b/echo.go @@ -220,6 +220,8 @@ const ( HeaderXRequestedWith = "X-Requested-With" HeaderServer = "Server" HeaderOrigin = "Origin" + HeaderCacheControl = "Cache-Control" + HeaderConnection = "Connection" // Access control HeaderAccessControlRequestMethod = "Access-Control-Request-Method" From d66712b252b09751742243aaae56fdd5628ce4d2 Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Wed, 2 Mar 2022 22:59:19 +0100 Subject: [PATCH 206/446] Update direct golang deps --- go.mod | 10 +++++----- go.sum | 23 +++++++++++++---------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index 4de2bdde1..f09e32cf9 100644 --- a/go.mod +++ b/go.mod @@ -7,18 +7,18 @@ require ( github.com/labstack/gommon v0.3.1 github.com/stretchr/testify v1.7.0 github.com/valyala/fasttemplate v1.2.1 - golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 - golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f - golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 + golang.org/x/crypto v0.0.0-20220214200702-86341886e292 + golang.org/x/net v0.0.0-20220225172249-27dd8689420f + golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-colorable v0.1.11 // indirect + github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.0.0-20211103235746-7861aae1554b // indirect + golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect golang.org/x/text v0.3.7 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index f66734243..7b86ace06 100644 --- a/go.sum +++ b/go.sum @@ -5,8 +5,9 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/labstack/gommon v0.3.1 h1:OomWaJXm7xR6L1HmEtGyQf26TEn7V6X88mktX9kee9o= github.com/labstack/gommon v0.3.1/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= -github.com/mattn/go-colorable v0.1.11 h1:nQ+aFkoE2TMGc0b68U2OKSexC+eq46+XwZzWXHRmPYs= github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -18,25 +19,27 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= -golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f h1:OfiFi4JbukWwe3lzw+xunroH1mnC1e2Gy5cxNJApiSY= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc= +golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 h1:M73Iuj3xbbb9Uk1DYhzydthsj6oOd6l9bpuFcNoUvTs= +golang.org/x/time v0.0.0-20220224211638-0e9765cccd65/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From da85d23d685ce31105f1a88682edaeb284223c53 Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Wed, 2 Mar 2022 23:11:46 +0100 Subject: [PATCH 207/446] Revert "Update direct golang deps" This reverts commit d66712b252b09751742243aaae56fdd5628ce4d2. --- go.mod | 10 +++++----- go.sum | 23 ++++++++++------------- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/go.mod b/go.mod index f09e32cf9..4de2bdde1 100644 --- a/go.mod +++ b/go.mod @@ -7,18 +7,18 @@ require ( github.com/labstack/gommon v0.3.1 github.com/stretchr/testify v1.7.0 github.com/valyala/fasttemplate v1.2.1 - golang.org/x/crypto v0.0.0-20220214200702-86341886e292 - golang.org/x/net v0.0.0-20220225172249-27dd8689420f - golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 + golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 + golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f + golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-colorable v0.1.12 // indirect + github.com/mattn/go-colorable v0.1.11 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect + golang.org/x/sys v0.0.0-20211103235746-7861aae1554b // indirect golang.org/x/text v0.3.7 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index 7b86ace06..f66734243 100644 --- a/go.sum +++ b/go.sum @@ -5,9 +5,8 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/labstack/gommon v0.3.1 h1:OomWaJXm7xR6L1HmEtGyQf26TEn7V6X88mktX9kee9o= github.com/labstack/gommon v0.3.1/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= +github.com/mattn/go-colorable v0.1.11 h1:nQ+aFkoE2TMGc0b68U2OKSexC+eq46+XwZzWXHRmPYs= github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= -github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -19,27 +18,25 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE= -golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= +golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f h1:OfiFi4JbukWwe3lzw+xunroH1mnC1e2Gy5cxNJApiSY= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs= -golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 h1:M73Iuj3xbbb9Uk1DYhzydthsj6oOd6l9bpuFcNoUvTs= -golang.org/x/time v0.0.0-20220224211638-0e9765cccd65/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= +golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From 5ebed440aeec1abf7f08cca41cb02f6aaf0d7f6a Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Wed, 2 Mar 2022 23:16:19 +0100 Subject: [PATCH 208/446] Update version to v4.7.0 --- CHANGELOG.md | 21 +++++++++++++++++++++ echo.go | 2 +- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 372ed13c5..461ac89c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,26 @@ # Changelog +## v4.7.0 - 2022-03-01 + +**Enhancements** + +* Add JWT, KeyAuth, CSRF multivalue extractors [#2060](https://github.com/labstack/echo/pull/2060) +* Add LogErrorFunc to recover middleware [#2072](https://github.com/labstack/echo/pull/2072) +* Add support for HEAD method query params binding [#2027](https://github.com/labstack/echo/pull/2027) +* Improve filesystem support with echo.FileFS, echo.StaticFS, group.FileFS, group.StaticFS [#2064](https://github.com/labstack/echo/pull/2064) + +**Fixes** + +* Fix X-Real-IP bug, improve tests [#2007](https://github.com/labstack/echo/pull/2007) +* Minor syntax fixes [#1994](https://github.com/labstack/echo/pull/1994), [#2102](https://github.com/labstack/echo/pull/2102), [#2102](https://github.com/labstack/echo/pull/2102) + +**General** + +* Add cache-control and connection headers [#2103](https://github.com/labstack/echo/pull/2103) +* Add Retry-After header constant [#2078](https://github.com/labstack/echo/pull/2078) +* Upgrade `go` directive in `go.mod` to 1.17 [#2049](https://github.com/labstack/echo/pull/2049) +* Add Pagoda [#2077](https://github.com/labstack/echo/pull/2077) and Souin [#2069](https://github.com/labstack/echo/pull/2069) to 3rd-party middlewares in README + ## v4.6.3 - 2022-01-10 **Fixes** diff --git a/echo.go b/echo.go index 6143403f3..143f9ffe3 100644 --- a/echo.go +++ b/echo.go @@ -246,7 +246,7 @@ const ( const ( // Version of Echo - Version = "4.6.3" + Version = "4.7.0" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 3f5b733425617138573e3768381278f619561f7e Mon Sep 17 00:00:00 2001 From: Martti T Date: Sun, 13 Mar 2022 15:05:12 +0200 Subject: [PATCH 209/446] Fix `e.Static`, `.File()`, `c.Attachment()` being picky with paths starting with `./`, `../` and `/` after 4.7.0 introduced echo.Filesystem support (Go1.16+) (#2123) * Fix `e.Static`, `.File()`, `c.Attachment()` being picky with paths starting with `./`, `../` and `/` after 4.7.0 introduced echo.Filesystem support (Go1.16+) --- echo_fs_go1.16.go | 40 ++++++++++++++++++++++++------- echo_test.go | 61 +++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 88 insertions(+), 13 deletions(-) diff --git a/echo_fs_go1.16.go b/echo_fs_go1.16.go index 435459de2..eb17768ab 100644 --- a/echo_fs_go1.16.go +++ b/echo_fs_go1.16.go @@ -10,6 +10,7 @@ import ( "net/url" "os" "path/filepath" + "runtime" "strings" ) @@ -94,10 +95,12 @@ func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { } } -// defaultFS emulates os.Open behaviour with filesystem opened by `os.DirFs`. Difference between `os.Open` and `fs.Open` -// is that FS does not allow to open path that start with `..` or `/` etc. For example previously you could have `../images` -// in your application but `fs := os.DirFS("./")` would not allow you to use `fs.Open("../images")` and this would break -// all old applications that rely on being able to traverse up from current executable run path. +// defaultFS exists to preserve pre v4.7.0 behaviour where files were open by `os.Open`. +// v4.7 introduced `echo.Filesystem` field which is Go1.16+ `fs.Fs` interface. +// Difference between `os.Open` and `fs.Open` is that FS does not allow opening path that start with `.`, `..` or `/` +// etc. For example previously you could have `../images` in your application but `fs := os.DirFS("./")` would not +// allow you to use `fs.Open("../images")` and this would break all old applications that rely on being able to +// traverse up from current executable run path. // NB: private because you really should use fs.FS implementation instances type defaultFS struct { prefix string @@ -108,20 +111,26 @@ func newDefaultFS() *defaultFS { dir, _ := os.Getwd() return &defaultFS{ prefix: dir, - fs: os.DirFS(dir), + fs: nil, } } func (fs defaultFS) Open(name string) (fs.File, error) { + if fs.fs == nil { + return os.Open(name) + } return fs.fs.Open(name) } func subFS(currentFs fs.FS, root string) (fs.FS, error) { root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows if dFS, ok := currentFs.(*defaultFS); ok { - // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS to - // allow cases when root is given as `../somepath` which is not valid for fs.FS - root = filepath.Join(dFS.prefix, root) + // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. + // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we + // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs + if isRelativePath(root) { + root = filepath.Join(dFS.prefix, root) + } return &defaultFS{ prefix: root, fs: os.DirFS(root), @@ -130,6 +139,21 @@ func subFS(currentFs fs.FS, root string) (fs.FS, error) { return fs.Sub(currentFs, root) } +func isRelativePath(path string) bool { + if path == "" { + return true + } + if path[0] == '/' { + return false + } + if runtime.GOOS == "windows" && strings.IndexByte(path, ':') != -1 { + // https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file?redirectedfrom=MSDN#file_and_directory_names + // https://docs.microsoft.com/en-us/dotnet/standard/io/file-path-formats + return false + } + return true +} + // MustSubFS creates sub FS from current filesystem or panic on failure. // Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. // diff --git a/echo_test.go b/echo_test.go index d31e7b604..0e1e42be0 100644 --- a/echo_test.go +++ b/echo_test.go @@ -84,6 +84,14 @@ func TestEchoStatic(t *testing.T) { expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), }, + { + name: "ok with relative path for root points to directory", + givenPrefix: "/images", + givenRoot: "./_fixture/images", + whenURL: "/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, { name: "No file", givenPrefix: "/images", @@ -246,11 +254,54 @@ func TestEchoStaticRedirectIndex(t *testing.T) { } func TestEchoFile(t *testing.T) { - e := New() - e.File("/walle", "_fixture/images/walle.png") - c, b := request(http.MethodGet, "/walle", e) - assert.Equal(t, http.StatusOK, c) - assert.NotEmpty(t, b) + var testCases = []struct { + name string + givenPath string + givenFile string + whenPath string + expectCode int + expectStartsWith string + }{ + { + name: "ok", + givenPath: "/walle", + givenFile: "_fixture/images/walle.png", + whenPath: "/walle", + expectCode: http.StatusOK, + expectStartsWith: string([]byte{0x89, 0x50, 0x4e}), + }, + { + name: "ok with relative path", + givenPath: "/", + givenFile: "./go.mod", + whenPath: "/", + expectCode: http.StatusOK, + expectStartsWith: "module github.com/labstack/echo/v", + }, + { + name: "nok file does not exist", + givenPath: "/", + givenFile: "./this-file-does-not-exist", + whenPath: "/", + expectCode: http.StatusNotFound, + expectStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() // we are using echo.defaultFS instance + e.File(tc.givenPath, tc.givenFile) + + c, b := request(http.MethodGet, tc.whenPath, e) + assert.Equal(t, tc.expectCode, c) + + if len(b) > len(tc.expectStartsWith) { + b = b[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, b) + }) + } } func TestEchoMiddleware(t *testing.T) { From 54efc3850dd205bbffe650763533310cae170f4d Mon Sep 17 00:00:00 2001 From: eric <65116642+nonbutAworker@users.noreply.github.com> Date: Sun, 13 Mar 2022 21:31:39 +0800 Subject: [PATCH 210/446] remove some unused code (#2116) * remove unused code --- binder_test.go | 2 +- router_test.go | 31 ------------------------------- 2 files changed, 1 insertion(+), 32 deletions(-) diff --git a/binder_test.go b/binder_test.go index 946906a96..034967793 100644 --- a/binder_test.go +++ b/binder_test.go @@ -54,7 +54,7 @@ func TestBindingError_Error(t *testing.T) { func TestBindingError_ErrorJSON(t *testing.T) { err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error")) - resp, err := json.Marshal(err) + resp, _ := json.Marshal(err) assert.Equal(t, `{"field":"id","message":"bind failed"}`, string(resp)) } diff --git a/router_test.go b/router_test.go index 5cbb8d9b8..457566b90 100644 --- a/router_test.go +++ b/router_test.go @@ -1,7 +1,6 @@ package echo import ( - "fmt" "net/http" "net/http/httptest" "strings" @@ -2446,33 +2445,3 @@ func BenchmarkRouterGooglePlusAPIMisses(b *testing.B) { func BenchmarkRouterParamsAndAnyAPI(b *testing.B) { benchmarkRouterRoutes(b, paramAndAnyAPI, paramAndAnyAPIToFind) } - -func (n *node) printTree(pfx string, tail bool) { - p := prefix(tail, pfx, "└── ", "├── ") - fmt.Printf("%s%s, %p: type=%d, parent=%p, handler=%v, pnames=%v\n", p, n.prefix, n, n.kind, n.parent, n.methodHandler, n.pnames) - - p = prefix(tail, pfx, " ", "│ ") - - children := n.staticChildren - l := len(children) - - if n.paramChild != nil { - n.paramChild.printTree(p, n.anyChild == nil && l == 0) - } - if n.anyChild != nil { - n.anyChild.printTree(p, l == 0) - } - for i := 0; i < l-1; i++ { - children[i].printTree(p, false) - } - if l > 0 { - children[l-1].printTree(p, true) - } -} - -func prefix(tail bool, p, on, off string) string { - if tail { - return fmt.Sprintf("%s%s", p, on) - } - return fmt.Sprintf("%s%s", p, off) -} From b445958c3ce4cf34997a67ef73a30cd870170945 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 13 Mar 2022 18:20:30 +0200 Subject: [PATCH 211/446] Update version and changelog for 4.7.1 --- CHANGELOG.md | 11 +++++++++++ echo.go | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 461ac89c9..7d1d9086a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## v4.7.1 - 2022-03-13 + +**Fixes** + +* Fix `e.Static`, `.File()`, `c.Attachment()` being picky with paths starting with `./`, `../` and `/` after 4.7.0 introduced echo.Filesystem support (Go1.16+) [#2123](https://github.com/labstack/echo/pull/2123) + +**Enhancements** + +* Remove some unused code [#2116](https://github.com/labstack/echo/pull/2116) + + ## v4.7.0 - 2022-03-01 **Enhancements** diff --git a/echo.go b/echo.go index 143f9ffe3..5b3087269 100644 --- a/echo.go +++ b/echo.go @@ -246,7 +246,7 @@ const ( const ( // Version of Echo - Version = "4.7.0" + Version = "4.7.1" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 05df10c62f8a753e342623afb7dec8dbf4ef3f59 Mon Sep 17 00:00:00 2001 From: Gabriel Nelle Date: Mon, 14 Mar 2022 10:44:07 +0100 Subject: [PATCH 212/446] fix nil pointer exception when calling Start again after address binding error --- echo.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/echo.go b/echo.go index 5b3087269..c2f23c194 100644 --- a/echo.go +++ b/echo.go @@ -732,7 +732,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) { return s.Serve(e.Listener) } -func (e *Echo) configureServer(s *http.Server) (err error) { +func (e *Echo) configureServer(s *http.Server) error { // Setup e.colorer.SetOutput(e.Logger.Output()) s.ErrorLog = e.StdLogger @@ -747,10 +747,11 @@ func (e *Echo) configureServer(s *http.Server) (err error) { if s.TLSConfig == nil { if e.Listener == nil { - e.Listener, err = newListener(s.Addr, e.ListenerNetwork) + l, err := newListener(s.Addr, e.ListenerNetwork) if err != nil { return err } + e.Listener = l } if !e.HidePort { e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) @@ -791,7 +792,7 @@ func (e *Echo) TLSListenerAddr() net.Addr { } // StartH2CServer starts a custom http/2 server with h2c (HTTP/2 Cleartext). -func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { +func (e *Echo) StartH2CServer(address string, h2s *http2.Server) error { e.startupMutex.Lock() // Setup s := e.Server @@ -808,11 +809,12 @@ func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) { } if e.Listener == nil { - e.Listener, err = newListener(s.Addr, e.ListenerNetwork) + l, err := newListener(s.Addr, e.ListenerNetwork) if err != nil { e.startupMutex.Unlock() return err } + e.Listener = l } if !e.HidePort { e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) From 5c38c3b770c2e477f17266e09fe77ee07ab70dfe Mon Sep 17 00:00:00 2001 From: Becir Basic Date: Wed, 16 Mar 2022 00:29:42 +0100 Subject: [PATCH 213/446] Recover middleware should not log panic for aborted handler (#2134, fixes #2133) Co-authored-by: Becir Basic --- middleware/recover.go | 4 ++++ middleware/recover_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/middleware/recover.go b/middleware/recover.go index a621a9efe..7b6128533 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -2,6 +2,7 @@ package middleware import ( "fmt" + "net/http" "runtime" "github.com/labstack/echo/v4" @@ -77,6 +78,9 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { defer func() { if r := recover(); r != nil { + if r == http.ErrAbortHandler { + panic(r) + } err, ok := r.(error) if !ok { err = fmt.Errorf("%v", r) diff --git a/middleware/recover_test.go b/middleware/recover_test.go index 9ac4feedc..b27f3b41c 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -28,6 +28,35 @@ func TestRecover(t *testing.T) { assert.Contains(t, buf.String(), "PANIC RECOVER") } +func TestRecoverErrAbortHandler(t *testing.T) { + e := echo.New() + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := Recover()(echo.HandlerFunc(func(c echo.Context) error { + panic(http.ErrAbortHandler) + })) + defer func() { + r := recover() + if r == nil { + assert.Fail(t, "expecting `http.ErrAbortHandler`, got `nil`") + } else { + if err, ok := r.(error); ok { + assert.ErrorIs(t, err, http.ErrAbortHandler) + } else { + assert.Fail(t, "not of error type") + } + } + }() + + h(c) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.NotContains(t, buf.String(), "PANIC RECOVER") +} + func TestRecoverWithConfig_LogLevel(t *testing.T) { tests := []struct { logLevel log.Lvl From 01d7d01bbc1948cd308b2ae93a131654e6dba195 Mon Sep 17 00:00:00 2001 From: Martti T Date: Wed, 16 Mar 2022 01:43:20 +0200 Subject: [PATCH 214/446] Fix CSRF middleware not being able to extract token from `multipart/form-data` form (#2136, fixes #2135) --- middleware/extractor.go | 4 ++-- middleware/extractor_test.go | 39 +++++++++++++++++++++++++++--------- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/middleware/extractor.go b/middleware/extractor.go index a57ed4e13..afdfd8195 100644 --- a/middleware/extractor.go +++ b/middleware/extractor.go @@ -168,8 +168,8 @@ func valuesFromCookie(name string) ValuesExtractor { // valuesFromForm returns a function that extracts values from the form field. func valuesFromForm(name string) ValuesExtractor { return func(c echo.Context) ([]string, error) { - if parseErr := c.Request().ParseForm(); parseErr != nil { - return nil, fmt.Errorf("valuesFromForm parse form failed: %w", parseErr) + if c.Request().Form == nil { + _ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does } values := c.Request().Form[name] if len(values) == 0 { diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go index ae4b30a8a..2e898f541 100644 --- a/middleware/extractor_test.go +++ b/middleware/extractor_test.go @@ -1,9 +1,11 @@ package middleware import ( + "bytes" "fmt" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" + "mime/multipart" "net/http" "net/http/httptest" "net/url" @@ -499,6 +501,25 @@ func TestValuesFromForm(t *testing.T) { return req } + exampleMultiPartFormRequest := func(mod func(w *multipart.Writer)) *http.Request { + var b bytes.Buffer + w := multipart.NewWriter(&b) + w.WriteField("name", "Jon Snow") + w.WriteField("emails[]", "jon@labstack.com") + if mod != nil { + mod(w) + } + + fw, _ := w.CreateFormFile("upload", "my.file") + fw.Write([]byte(`
hi
`)) + w.Close() + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(b.String())) + req.Header.Add(echo.HeaderContentType, w.FormDataContentType()) + + return req + } + var testCases = []struct { name string givenRequest *http.Request @@ -520,6 +541,14 @@ func TestValuesFromForm(t *testing.T) { whenName: "emails[]", expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, }, + { + name: "ok, POST multipart/form, multiple value", + givenRequest: exampleMultiPartFormRequest(func(w *multipart.Writer) { + w.WriteField("emails[]", "snow@labstack.com") + }), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, + }, { name: "ok, GET form, single value", givenRequest: exampleGetFormRequest(nil), @@ -540,16 +569,6 @@ func TestValuesFromForm(t *testing.T) { whenName: "nope", expectError: errFormExtractorValueMissing.Error(), }, - { - name: "nok, POST form, form parsing error", - givenRequest: func() *http.Request { - req := httptest.NewRequest(http.MethodPost, "/", nil) - req.Body = nil - return req - }(), - whenName: "name", - expectError: "valuesFromForm parse form failed: missing form body", - }, { name: "ok, cut values over extractorLimit", givenRequest: examplePostFormRequest(func(v *url.Values) { From 1919cf4491fa46624a34eb1fb2dd13d414343b64 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 13 Mar 2022 16:00:02 +0200 Subject: [PATCH 215/446] Timeout middleware write race --- middleware/timeout.go | 130 ++++++++++++++++++++++++------------- middleware/timeout_test.go | 9 ++- 2 files changed, 91 insertions(+), 48 deletions(-) diff --git a/middleware/timeout.go b/middleware/timeout.go index 768ef8d70..4e8836c85 100644 --- a/middleware/timeout.go +++ b/middleware/timeout.go @@ -2,10 +2,10 @@ package middleware import ( "context" + "github.com/labstack/echo/v4" "net/http" + "sync" "time" - - "github.com/labstack/echo/v4" ) // --------------------------------------------------------------------------------------------------------------- @@ -55,29 +55,27 @@ import ( // }) // -type ( - // TimeoutConfig defines the config for Timeout middleware. - TimeoutConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code - // It can be used to define a custom timeout error message - ErrorMessage string - - // OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after - // request timeouted and we already had sent the error code (503) and message response to the client. - // NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer - // will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()` - OnTimeoutRouteErrorHandler func(err error, c echo.Context) - - // Timeout configures a timeout for the middleware, defaults to 0 for no timeout - // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) - // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output - // difference over 500microseconds (0.5millisecond) response seems to be reliable - Timeout time.Duration - } -) +// TimeoutConfig defines the config for Timeout middleware. +type TimeoutConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code + // It can be used to define a custom timeout error message + ErrorMessage string + + // OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after + // request timeouted and we already had sent the error code (503) and message response to the client. + // NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer + // will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()` + OnTimeoutRouteErrorHandler func(err error, c echo.Context) + + // Timeout configures a timeout for the middleware, defaults to 0 for no timeout + // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) + // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output + // difference over 500microseconds (0.5millisecond) response seems to be reliable + Timeout time.Duration +} var ( // DefaultTimeoutConfig is the default Timeout middleware config. @@ -94,10 +92,17 @@ func Timeout() echo.MiddlewareFunc { return TimeoutWithConfig(DefaultTimeoutConfig) } -// TimeoutWithConfig returns a Timeout middleware with config. -// See: `Timeout()`. +// TimeoutWithConfig returns a Timeout middleware with config or panics on invalid configuration. func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc { - // Defaults + mw, err := config.ToMiddleware() + if err != nil { + panic(err) + } + return mw +} + +// ToMiddleware converts Config to middleware or returns an error for invalid configuration +func (config TimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultTimeoutConfig.Skipper } @@ -108,26 +113,29 @@ func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc { return next(c) } + errChan := make(chan error, 1) handlerWrapper := echoHandlerFuncWrapper{ + writer: &ignorableWriter{ResponseWriter: c.Response().Writer}, ctx: c, handler: next, - errChan: make(chan error, 1), + errChan: errChan, errHandler: config.OnTimeoutRouteErrorHandler, } handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage) - handler.ServeHTTP(c.Response().Writer, c.Request()) + handler.ServeHTTP(handlerWrapper.writer, c.Request()) select { - case err := <-handlerWrapper.errChan: + case err := <-errChan: return err default: return nil } } - } + }, nil } type echoHandlerFuncWrapper struct { + writer *ignorableWriter ctx echo.Context handler echo.HandlerFunc errHandler func(err error, c echo.Context) @@ -160,23 +168,53 @@ func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Reques } return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers } - // we restore original writer only for cases we did not timeout. On timeout we have already sent response to client - // and should not anymore send additional headers/data - // so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body if err != nil { - // Error must be written into Writer created in `http.TimeoutHandler` so to get Response into `commited` state. - // So call global error handler to write error to the client. This is needed or `http.TimeoutHandler` will send - // status code by itself and after that our tries to write status code will not work anymore and/or create errors in - // log about `superfluous response.WriteHeader call from` - t.ctx.Error(err) - // we pass error from handler to middlewares up in handler chain to act on it if needed. But this means that - // global error handler is probably be called twice as `t.ctx.Error` already does that. - - // NB: later call of the global error handler or middlewares will not take any effect, as echo.Response will be - // already marked as `committed` because we called global error handler above. - t.ctx.Response().Writer = originalWriter // make sure we restore before we signal original coroutine about the error + // This is needed as `http.TimeoutHandler` will write status code by itself on error and after that our tries to write + // status code will not work anymore as Echo.Response thinks it has been already "committed" and further writes + // create errors in log about `superfluous response.WriteHeader call from` + t.writer.Ignore(true) + t.ctx.Response().Writer = originalWriter // make sure we restore writer before we signal original coroutine about the error + // we pass error from handler to middlewares up in handler chain to act on it if needed. t.errChan <- err return } + // we restore original writer only for cases we did not timeout. On timeout we have already sent response to client + // and should not anymore send additional headers/data + // so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body t.ctx.Response().Writer = originalWriter } + +// ignorableWriter is ResponseWriter implementations that allows us to mark writer to ignore further write calls. This +// is handy in cases when you do not have direct control of code being executed (3rd party middleware) but want to make +// sure that external code will not be able to write response to the client. +// Writer is coroutine safe for writes. +type ignorableWriter struct { + http.ResponseWriter + + lock sync.Mutex + ignoreWrites bool +} + +func (w *ignorableWriter) Ignore(ignore bool) { + w.lock.Lock() + w.ignoreWrites = ignore + w.lock.Unlock() +} + +func (w *ignorableWriter) WriteHeader(code int) { + w.lock.Lock() + defer w.lock.Unlock() + if w.ignoreWrites { + return + } + w.ResponseWriter.WriteHeader(code) +} + +func (w *ignorableWriter) Write(b []byte) (int, error) { + w.lock.Lock() + defer w.lock.Unlock() + if w.ignoreWrites { + return len(b), nil + } + return w.ResponseWriter.Write(b) +} diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index aa6402b8d..7fb802a9a 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -74,13 +74,18 @@ func TestTimeoutErrorOutInHandler(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) + rec.Code = 1 // we want to be sure that even 200 will not be sent err := m(func(c echo.Context) error { + // this error must not be written to the client response. Middlewares upstream of timeout middleware must be able + // to handle returned error and this can be done only then handler has not yet committed (written status code) + // the response. return echo.NewHTTPError(http.StatusTeapot, "err") })(c) assert.Error(t, err) - assert.Equal(t, http.StatusTeapot, rec.Code) - assert.Equal(t, "{\"message\":\"err\"}\n", rec.Body.String()) + assert.EqualError(t, err, "code=418, message=err") + assert.Equal(t, 1, rec.Code) + assert.Equal(t, "", rec.Body.String()) } func TestTimeoutSuccessfulRequest(t *testing.T) { From ec92fedf21e817d2d52004a4178292404beb9eaa Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Wed, 16 Mar 2022 08:43:59 +0200 Subject: [PATCH 216/446] Update version and changelog for 4.7.2 --- CHANGELOG.md | 13 +++++++++++++ echo.go | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d1d9086a..ba75d71f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## v4.7.2 - 2022-03-16 + +**Fixes** + +* Fix nil pointer exception when calling Start again after address binding error [#2131](https://github.com/labstack/echo/pull/2131) +* Fix CSRF middleware not being able to extract token from multipart/form-data form [#2136](https://github.com/labstack/echo/pull/2136) +* Fix Timeout middleware write race [#2126](https://github.com/labstack/echo/pull/2126) + +**Enhancements** + +* Recover middleware should not log panic for aborted handler [#2134](https://github.com/labstack/echo/pull/2134) + + ## v4.7.1 - 2022-03-13 **Fixes** diff --git a/echo.go b/echo.go index c2f23c194..8829619c7 100644 --- a/echo.go +++ b/echo.go @@ -246,7 +246,7 @@ const ( const ( // Version of Echo - Version = "4.7.1" + Version = "4.7.2" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 59d2eaa4ac35c4dca41b6545bd410b95f60fe354 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 13 Mar 2022 17:30:02 +0200 Subject: [PATCH 217/446] Add new value binding methods (UnixTimeMilli,TextUnmarshaler,JSONUnmarshaler) to ValueBinder --- binder.go | 121 ++++++++++++++++++++-- binder_test.go | 274 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 385 insertions(+), 10 deletions(-) diff --git a/binder.go b/binder.go index 0900ce8dc..4409c174a 100644 --- a/binder.go +++ b/binder.go @@ -1,6 +1,8 @@ package echo import ( + "encoding" + "encoding/json" "fmt" "net/http" "strconv" @@ -52,8 +54,11 @@ import ( * time * duration * BindUnmarshaler() interface + * TextUnmarshaler() interface + * JSONUnmarshaler() interface * UnixTime() - converts unix time (integer) to time.Time - * UnixTimeNano() - converts unix time with nano second precision (integer) to time.Time + * UnixTimeMilli() - converts unix time with millisecond precision (integer) to time.Time + * UnixTimeNano() - converts unix time with nanosecond precision (integer) to time.Time * CustomFunc() - callback function for your custom conversion logic. Signature `func(values []string) []error` */ @@ -321,6 +326,78 @@ func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshal return b } +// JSONUnmarshaler binds parameter to destination implementing json.Unmarshaler interface +func (b *ValueBinder) JSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + return b + } + + if err := dest.UnmarshalJSON([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err)) + } + return b +} + +// MustJSONUnmarshaler requires parameter value to exist to be bind to destination implementing json.Unmarshaler interface. +// Returns error when value does not exist +func (b *ValueBinder) MustJSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil)) + return b + } + + if err := dest.UnmarshalJSON([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err)) + } + return b +} + +// TextUnmarshaler binds parameter to destination implementing encoding.TextUnmarshaler interface +func (b *ValueBinder) TextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + return b + } + + if err := dest.UnmarshalText([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err)) + } + return b +} + +// MustTextUnmarshaler requires parameter value to exist to be bind to destination implementing encoding.TextUnmarshaler interface. +// Returns error when value does not exist +func (b *ValueBinder) MustTextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil)) + return b + } + + if err := dest.UnmarshalText([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err)) + } + return b +} + // BindWithDelimiter binds parameter to destination by suitable conversion function. // Delimiter is used before conversion to split parameter value to separate values func (b *ValueBinder) BindWithDelimiter(sourceParam string, dest interface{}, delimiter string) *ValueBinder { @@ -1161,7 +1238,7 @@ func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]tim // Note: // * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder { - return b.unixTime(sourceParam, dest, false, false) + return b.unixTime(sourceParam, dest, false, time.Second) } // MustUnixTime requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding @@ -1172,10 +1249,31 @@ func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder // Note: // * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder { - return b.unixTime(sourceParam, dest, true, false) + return b.unixTime(sourceParam, dest, true, time.Second) +} + +// UnixTimeMilli binds parameter to time.Time variable (in local Time corresponding to the given Unix time in millisecond precision). +// +// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 +// +// Note: +// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, false, time.Millisecond) } -// UnixTimeNano binds parameter to time.Time variable (in local Time corresponding to the given Unix time in nano second precision). +// MustUnixTimeMilli requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding +// to the given Unix time in millisecond precision). Returns error when value does not exist. +// +// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 +// +// Note: +// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, true, time.Millisecond) +} + +// UnixTimeNano binds parameter to time.Time variable (in local Time corresponding to the given Unix time in nanosecond precision). // // Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 // Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00 @@ -1185,7 +1283,7 @@ func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBi // * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal // * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { - return b.unixTime(sourceParam, dest, false, true) + return b.unixTime(sourceParam, dest, false, time.Nanosecond) } // MustUnixTimeNano requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding @@ -1199,10 +1297,10 @@ func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBi // * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal // * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { - return b.unixTime(sourceParam, dest, true, true) + return b.unixTime(sourceParam, dest, true, time.Nanosecond) } -func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, isNano bool) *ValueBinder { +func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, precision time.Duration) *ValueBinder { if b.failFast && b.errors != nil { return b } @@ -1221,10 +1319,13 @@ func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExi return b } - if isNano { - *dest = time.Unix(0, n) - } else { + switch precision { + case time.Second: *dest = time.Unix(n, 0) + case time.Millisecond: + *dest = time.Unix(n/1e3, (n%1e3)*1e6) // TODO: time.UnixMilli(n) exists since Go1.17 switch to that when min version allows + case time.Nanosecond: + *dest = time.Unix(0, n) } return b } diff --git a/binder_test.go b/binder_test.go index 034967793..910bbfc50 100644 --- a/binder_test.go +++ b/binder_test.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/stretchr/testify/assert" "io" + "math/big" "net/http" "net/http/httptest" "strconv" @@ -2187,6 +2188,188 @@ func TestValueBinder_BindUnmarshaler(t *testing.T) { } } +func TestValueBinder_JSONUnmarshaler(t *testing.T) { + example := big.NewInt(999) + + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue big.Int + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: big.Int{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: big.Int{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=xxx", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest big.Int + var err error + if tc.whenMust { + err = b.MustJSONUnmarshaler("param", &dest).BindError() + } else { + err = b.JSONUnmarshaler("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_TextUnmarshaler(t *testing.T) { + example := big.NewInt(999) + + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue big.Int + expectError string + }{ + { + name: "ok, binds value", + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: big.Int{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: big.Int{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=xxx", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest big.Int + var err error + if tc.whenMust { + err = b.MustTextUnmarshaler("param", &dest).BindError() + } else { + err = b.TextUnmarshaler("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + func TestValueBinder_BindWithDelimiter_types(t *testing.T) { var testCases = []struct { name string @@ -2529,6 +2712,97 @@ func TestValueBinder_UnixTime(t *testing.T) { } } +func TestValueBinder_UnixTimeMilli(t *testing.T) { + exampleTime, _ := time.Parse(time.RFC3339Nano, "2022-03-13T15:13:30.140000000+00:00") // => 1647184410140 + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue time.Time + expectError string + }{ + { + name: "ok, binds value, unix time in milliseconds", + whenURL: "/search?param=1647184410140¶m=1647184410199", + expectValue: exampleTime, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: time.Time{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1647184410140¶m=1647184410199", + expectValue: exampleTime, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: time.Time{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustUnixTimeMilli("param", &dest).BindError() + } else { + err = b.UnixTimeMilli("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano()) + assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC)) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + func TestValueBinder_UnixTimeNano(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43.000000000+00:00") // => 1609180603 exampleTimeNano, _ := time.Parse(time.RFC3339Nano, "2020-12-28T18:36:43.123456789+00:00") // => 1609180603123456789 From 63c62bcbe521dd060e38392f60a1437764d0794c Mon Sep 17 00:00:00 2001 From: Roland Lammel Date: Wed, 16 Mar 2022 00:56:50 +0100 Subject: [PATCH 218/446] Tidy up comments for value binders --- binder.go | 86 +++++++++++++++++++++++++++---------------------------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/binder.go b/binder.go index 4409c174a..5a6cf9d9b 100644 --- a/binder.go +++ b/binder.go @@ -209,7 +209,7 @@ func (b *ValueBinder) CustomFunc(sourceParam string, customFunc func(values []st return b.customFunc(sourceParam, customFunc, false) } -// MustCustomFunc requires parameter values to exist to be bind with Func. Returns error when value does not exist. +// MustCustomFunc requires parameter values to exist to bind with Func. Returns error when value does not exist. func (b *ValueBinder) MustCustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder { return b.customFunc(sourceParam, customFunc, true) } @@ -246,7 +246,7 @@ func (b *ValueBinder) String(sourceParam string, dest *string) *ValueBinder { return b } -// MustString requires parameter value to exist to be bind to string variable. Returns error when value does not exist +// MustString requires parameter value to exist to bind to string variable. Returns error when value does not exist func (b *ValueBinder) MustString(sourceParam string, dest *string) *ValueBinder { if b.failFast && b.errors != nil { return b @@ -275,7 +275,7 @@ func (b *ValueBinder) Strings(sourceParam string, dest *[]string) *ValueBinder { return b } -// MustStrings requires parameter values to exist to be bind to slice of string variables. Returns error when value does not exist +// MustStrings requires parameter values to exist to bind to slice of string variables. Returns error when value does not exist func (b *ValueBinder) MustStrings(sourceParam string, dest *[]string) *ValueBinder { if b.failFast && b.errors != nil { return b @@ -307,7 +307,7 @@ func (b *ValueBinder) BindUnmarshaler(sourceParam string, dest BindUnmarshaler) return b } -// MustBindUnmarshaler requires parameter value to exist to be bind to destination implementing BindUnmarshaler interface. +// MustBindUnmarshaler requires parameter value to exist to bind to destination implementing BindUnmarshaler interface. // Returns error when value does not exist func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder { if b.failFast && b.errors != nil { @@ -343,7 +343,7 @@ func (b *ValueBinder) JSONUnmarshaler(sourceParam string, dest json.Unmarshaler) return b } -// MustJSONUnmarshaler requires parameter value to exist to be bind to destination implementing json.Unmarshaler interface. +// MustJSONUnmarshaler requires parameter value to exist to bind to destination implementing json.Unmarshaler interface. // Returns error when value does not exist func (b *ValueBinder) MustJSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder { if b.failFast && b.errors != nil { @@ -379,7 +379,7 @@ func (b *ValueBinder) TextUnmarshaler(sourceParam string, dest encoding.TextUnma return b } -// MustTextUnmarshaler requires parameter value to exist to be bind to destination implementing encoding.TextUnmarshaler interface. +// MustTextUnmarshaler requires parameter value to exist to bind to destination implementing encoding.TextUnmarshaler interface. // Returns error when value does not exist func (b *ValueBinder) MustTextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder { if b.failFast && b.errors != nil { @@ -404,7 +404,7 @@ func (b *ValueBinder) BindWithDelimiter(sourceParam string, dest interface{}, de return b.bindWithDelimiter(sourceParam, dest, delimiter, false) } -// MustBindWithDelimiter requires parameter value to exist to be bind destination by suitable conversion function. +// MustBindWithDelimiter requires parameter value to exist to bind destination by suitable conversion function. // Delimiter is used before conversion to split parameter value to separate values func (b *ValueBinder) MustBindWithDelimiter(sourceParam string, dest interface{}, delimiter string) *ValueBinder { return b.bindWithDelimiter(sourceParam, dest, delimiter, true) @@ -453,7 +453,7 @@ func (b *ValueBinder) Int64(sourceParam string, dest *int64) *ValueBinder { return b.intValue(sourceParam, dest, 64, false) } -// MustInt64 requires parameter value to exist to be bind to int64 variable. Returns error when value does not exist +// MustInt64 requires parameter value to exist to bind to int64 variable. Returns error when value does not exist func (b *ValueBinder) MustInt64(sourceParam string, dest *int64) *ValueBinder { return b.intValue(sourceParam, dest, 64, true) } @@ -463,7 +463,7 @@ func (b *ValueBinder) Int32(sourceParam string, dest *int32) *ValueBinder { return b.intValue(sourceParam, dest, 32, false) } -// MustInt32 requires parameter value to exist to be bind to int32 variable. Returns error when value does not exist +// MustInt32 requires parameter value to exist to bind to int32 variable. Returns error when value does not exist func (b *ValueBinder) MustInt32(sourceParam string, dest *int32) *ValueBinder { return b.intValue(sourceParam, dest, 32, true) } @@ -473,7 +473,7 @@ func (b *ValueBinder) Int16(sourceParam string, dest *int16) *ValueBinder { return b.intValue(sourceParam, dest, 16, false) } -// MustInt16 requires parameter value to exist to be bind to int16 variable. Returns error when value does not exist +// MustInt16 requires parameter value to exist to bind to int16 variable. Returns error when value does not exist func (b *ValueBinder) MustInt16(sourceParam string, dest *int16) *ValueBinder { return b.intValue(sourceParam, dest, 16, true) } @@ -483,7 +483,7 @@ func (b *ValueBinder) Int8(sourceParam string, dest *int8) *ValueBinder { return b.intValue(sourceParam, dest, 8, false) } -// MustInt8 requires parameter value to exist to be bind to int8 variable. Returns error when value does not exist +// MustInt8 requires parameter value to exist to bind to int8 variable. Returns error when value does not exist func (b *ValueBinder) MustInt8(sourceParam string, dest *int8) *ValueBinder { return b.intValue(sourceParam, dest, 8, true) } @@ -493,7 +493,7 @@ func (b *ValueBinder) Int(sourceParam string, dest *int) *ValueBinder { return b.intValue(sourceParam, dest, 0, false) } -// MustInt requires parameter value to exist to be bind to int variable. Returns error when value does not exist +// MustInt requires parameter value to exist to bind to int variable. Returns error when value does not exist func (b *ValueBinder) MustInt(sourceParam string, dest *int) *ValueBinder { return b.intValue(sourceParam, dest, 0, true) } @@ -621,7 +621,7 @@ func (b *ValueBinder) Int64s(sourceParam string, dest *[]int64) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInt64s requires parameter value to exist to be bind to int64 slice variable. Returns error when value does not exist +// MustInt64s requires parameter value to exist to bind to int64 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt64s(sourceParam string, dest *[]int64) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -631,7 +631,7 @@ func (b *ValueBinder) Int32s(sourceParam string, dest *[]int32) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInt32s requires parameter value to exist to be bind to int32 slice variable. Returns error when value does not exist +// MustInt32s requires parameter value to exist to bind to int32 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt32s(sourceParam string, dest *[]int32) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -641,7 +641,7 @@ func (b *ValueBinder) Int16s(sourceParam string, dest *[]int16) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInt16s requires parameter value to exist to be bind to int16 slice variable. Returns error when value does not exist +// MustInt16s requires parameter value to exist to bind to int16 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt16s(sourceParam string, dest *[]int16) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -651,7 +651,7 @@ func (b *ValueBinder) Int8s(sourceParam string, dest *[]int8) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInt8s requires parameter value to exist to be bind to int8 slice variable. Returns error when value does not exist +// MustInt8s requires parameter value to exist to bind to int8 slice variable. Returns error when value does not exist func (b *ValueBinder) MustInt8s(sourceParam string, dest *[]int8) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -661,7 +661,7 @@ func (b *ValueBinder) Ints(sourceParam string, dest *[]int) *ValueBinder { return b.intsValue(sourceParam, dest, false) } -// MustInts requires parameter value to exist to be bind to int slice variable. Returns error when value does not exist +// MustInts requires parameter value to exist to bind to int slice variable. Returns error when value does not exist func (b *ValueBinder) MustInts(sourceParam string, dest *[]int) *ValueBinder { return b.intsValue(sourceParam, dest, true) } @@ -671,7 +671,7 @@ func (b *ValueBinder) Uint64(sourceParam string, dest *uint64) *ValueBinder { return b.uintValue(sourceParam, dest, 64, false) } -// MustUint64 requires parameter value to exist to be bind to uint64 variable. Returns error when value does not exist +// MustUint64 requires parameter value to exist to bind to uint64 variable. Returns error when value does not exist func (b *ValueBinder) MustUint64(sourceParam string, dest *uint64) *ValueBinder { return b.uintValue(sourceParam, dest, 64, true) } @@ -681,7 +681,7 @@ func (b *ValueBinder) Uint32(sourceParam string, dest *uint32) *ValueBinder { return b.uintValue(sourceParam, dest, 32, false) } -// MustUint32 requires parameter value to exist to be bind to uint32 variable. Returns error when value does not exist +// MustUint32 requires parameter value to exist to bind to uint32 variable. Returns error when value does not exist func (b *ValueBinder) MustUint32(sourceParam string, dest *uint32) *ValueBinder { return b.uintValue(sourceParam, dest, 32, true) } @@ -691,7 +691,7 @@ func (b *ValueBinder) Uint16(sourceParam string, dest *uint16) *ValueBinder { return b.uintValue(sourceParam, dest, 16, false) } -// MustUint16 requires parameter value to exist to be bind to uint16 variable. Returns error when value does not exist +// MustUint16 requires parameter value to exist to bind to uint16 variable. Returns error when value does not exist func (b *ValueBinder) MustUint16(sourceParam string, dest *uint16) *ValueBinder { return b.uintValue(sourceParam, dest, 16, true) } @@ -701,7 +701,7 @@ func (b *ValueBinder) Uint8(sourceParam string, dest *uint8) *ValueBinder { return b.uintValue(sourceParam, dest, 8, false) } -// MustUint8 requires parameter value to exist to be bind to uint8 variable. Returns error when value does not exist +// MustUint8 requires parameter value to exist to bind to uint8 variable. Returns error when value does not exist func (b *ValueBinder) MustUint8(sourceParam string, dest *uint8) *ValueBinder { return b.uintValue(sourceParam, dest, 8, true) } @@ -711,7 +711,7 @@ func (b *ValueBinder) Byte(sourceParam string, dest *byte) *ValueBinder { return b.uintValue(sourceParam, dest, 8, false) } -// MustByte requires parameter value to exist to be bind to byte variable. Returns error when value does not exist +// MustByte requires parameter value to exist to bind to byte variable. Returns error when value does not exist func (b *ValueBinder) MustByte(sourceParam string, dest *byte) *ValueBinder { return b.uintValue(sourceParam, dest, 8, true) } @@ -721,7 +721,7 @@ func (b *ValueBinder) Uint(sourceParam string, dest *uint) *ValueBinder { return b.uintValue(sourceParam, dest, 0, false) } -// MustUint requires parameter value to exist to be bind to uint variable. Returns error when value does not exist +// MustUint requires parameter value to exist to bind to uint variable. Returns error when value does not exist func (b *ValueBinder) MustUint(sourceParam string, dest *uint) *ValueBinder { return b.uintValue(sourceParam, dest, 0, true) } @@ -849,7 +849,7 @@ func (b *ValueBinder) Uint64s(sourceParam string, dest *[]uint64) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUint64s requires parameter value to exist to be bind to uint64 slice variable. Returns error when value does not exist +// MustUint64s requires parameter value to exist to bind to uint64 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint64s(sourceParam string, dest *[]uint64) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -859,7 +859,7 @@ func (b *ValueBinder) Uint32s(sourceParam string, dest *[]uint32) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUint32s requires parameter value to exist to be bind to uint32 slice variable. Returns error when value does not exist +// MustUint32s requires parameter value to exist to bind to uint32 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint32s(sourceParam string, dest *[]uint32) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -869,7 +869,7 @@ func (b *ValueBinder) Uint16s(sourceParam string, dest *[]uint16) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUint16s requires parameter value to exist to be bind to uint16 slice variable. Returns error when value does not exist +// MustUint16s requires parameter value to exist to bind to uint16 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint16s(sourceParam string, dest *[]uint16) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -879,7 +879,7 @@ func (b *ValueBinder) Uint8s(sourceParam string, dest *[]uint8) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUint8s requires parameter value to exist to be bind to uint8 slice variable. Returns error when value does not exist +// MustUint8s requires parameter value to exist to bind to uint8 slice variable. Returns error when value does not exist func (b *ValueBinder) MustUint8s(sourceParam string, dest *[]uint8) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -889,7 +889,7 @@ func (b *ValueBinder) Uints(sourceParam string, dest *[]uint) *ValueBinder { return b.uintsValue(sourceParam, dest, false) } -// MustUints requires parameter value to exist to be bind to uint slice variable. Returns error when value does not exist +// MustUints requires parameter value to exist to bind to uint slice variable. Returns error when value does not exist func (b *ValueBinder) MustUints(sourceParam string, dest *[]uint) *ValueBinder { return b.uintsValue(sourceParam, dest, true) } @@ -899,7 +899,7 @@ func (b *ValueBinder) Bool(sourceParam string, dest *bool) *ValueBinder { return b.boolValue(sourceParam, dest, false) } -// MustBool requires parameter value to exist to be bind to bool variable. Returns error when value does not exist +// MustBool requires parameter value to exist to bind to bool variable. Returns error when value does not exist func (b *ValueBinder) MustBool(sourceParam string, dest *bool) *ValueBinder { return b.boolValue(sourceParam, dest, true) } @@ -964,7 +964,7 @@ func (b *ValueBinder) Bools(sourceParam string, dest *[]bool) *ValueBinder { return b.boolsValue(sourceParam, dest, false) } -// MustBools requires parameter values to exist to be bind to slice of bool variables. Returns error when values does not exist +// MustBools requires parameter values to exist to bind to slice of bool variables. Returns error when values does not exist func (b *ValueBinder) MustBools(sourceParam string, dest *[]bool) *ValueBinder { return b.boolsValue(sourceParam, dest, true) } @@ -974,7 +974,7 @@ func (b *ValueBinder) Float64(sourceParam string, dest *float64) *ValueBinder { return b.floatValue(sourceParam, dest, 64, false) } -// MustFloat64 requires parameter value to exist to be bind to float64 variable. Returns error when value does not exist +// MustFloat64 requires parameter value to exist to bind to float64 variable. Returns error when value does not exist func (b *ValueBinder) MustFloat64(sourceParam string, dest *float64) *ValueBinder { return b.floatValue(sourceParam, dest, 64, true) } @@ -984,7 +984,7 @@ func (b *ValueBinder) Float32(sourceParam string, dest *float32) *ValueBinder { return b.floatValue(sourceParam, dest, 32, false) } -// MustFloat32 requires parameter value to exist to be bind to float32 variable. Returns error when value does not exist +// MustFloat32 requires parameter value to exist to bind to float32 variable. Returns error when value does not exist func (b *ValueBinder) MustFloat32(sourceParam string, dest *float32) *ValueBinder { return b.floatValue(sourceParam, dest, 32, true) } @@ -1069,7 +1069,7 @@ func (b *ValueBinder) Float64s(sourceParam string, dest *[]float64) *ValueBinder return b.floatsValue(sourceParam, dest, false) } -// MustFloat64s requires parameter values to exist to be bind to slice of float64 variables. Returns error when values does not exist +// MustFloat64s requires parameter values to exist to bind to slice of float64 variables. Returns error when values does not exist func (b *ValueBinder) MustFloat64s(sourceParam string, dest *[]float64) *ValueBinder { return b.floatsValue(sourceParam, dest, true) } @@ -1079,7 +1079,7 @@ func (b *ValueBinder) Float32s(sourceParam string, dest *[]float32) *ValueBinder return b.floatsValue(sourceParam, dest, false) } -// MustFloat32s requires parameter values to exist to be bind to slice of float32 variables. Returns error when values does not exist +// MustFloat32s requires parameter values to exist to bind to slice of float32 variables. Returns error when values does not exist func (b *ValueBinder) MustFloat32s(sourceParam string, dest *[]float32) *ValueBinder { return b.floatsValue(sourceParam, dest, true) } @@ -1089,7 +1089,7 @@ func (b *ValueBinder) Time(sourceParam string, dest *time.Time, layout string) * return b.time(sourceParam, dest, layout, false) } -// MustTime requires parameter value to exist to be bind to time.Time variable. Returns error when value does not exist +// MustTime requires parameter value to exist to bind to time.Time variable. Returns error when value does not exist func (b *ValueBinder) MustTime(sourceParam string, dest *time.Time, layout string) *ValueBinder { return b.time(sourceParam, dest, layout, true) } @@ -1120,7 +1120,7 @@ func (b *ValueBinder) Times(sourceParam string, dest *[]time.Time, layout string return b.times(sourceParam, dest, layout, false) } -// MustTimes requires parameter values to exist to be bind to slice of time.Time variables. Returns error when values does not exist +// MustTimes requires parameter values to exist to bind to slice of time.Time variables. Returns error when values does not exist func (b *ValueBinder) MustTimes(sourceParam string, dest *[]time.Time, layout string) *ValueBinder { return b.times(sourceParam, dest, layout, true) } @@ -1161,7 +1161,7 @@ func (b *ValueBinder) Duration(sourceParam string, dest *time.Duration) *ValueBi return b.duration(sourceParam, dest, false) } -// MustDuration requires parameter value to exist to be bind to time.Duration variable. Returns error when value does not exist +// MustDuration requires parameter value to exist to bind to time.Duration variable. Returns error when value does not exist func (b *ValueBinder) MustDuration(sourceParam string, dest *time.Duration) *ValueBinder { return b.duration(sourceParam, dest, true) } @@ -1192,7 +1192,7 @@ func (b *ValueBinder) Durations(sourceParam string, dest *[]time.Duration) *Valu return b.durationsValue(sourceParam, dest, false) } -// MustDurations requires parameter values to exist to be bind to slice of time.Duration variables. Returns error when values does not exist +// MustDurations requires parameter values to exist to bind to slice of time.Duration variables. Returns error when values does not exist func (b *ValueBinder) MustDurations(sourceParam string, dest *[]time.Duration) *ValueBinder { return b.durationsValue(sourceParam, dest, true) } @@ -1241,7 +1241,7 @@ func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder return b.unixTime(sourceParam, dest, false, time.Second) } -// MustUnixTime requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding +// MustUnixTime requires parameter value to exist to bind to time.Duration variable (in local time corresponding // to the given Unix time). Returns error when value does not exist. // // Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 @@ -1252,7 +1252,7 @@ func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBi return b.unixTime(sourceParam, dest, true, time.Second) } -// UnixTimeMilli binds parameter to time.Time variable (in local Time corresponding to the given Unix time in millisecond precision). +// UnixTimeMilli binds parameter to time.Time variable (in local time corresponding to the given Unix time in millisecond precision). // // Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 // @@ -1262,7 +1262,7 @@ func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueB return b.unixTime(sourceParam, dest, false, time.Millisecond) } -// MustUnixTimeMilli requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding +// MustUnixTimeMilli requires parameter value to exist to bind to time.Duration variable (in local time corresponding // to the given Unix time in millisecond precision). Returns error when value does not exist. // // Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 @@ -1273,7 +1273,7 @@ func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *Va return b.unixTime(sourceParam, dest, true, time.Millisecond) } -// UnixTimeNano binds parameter to time.Time variable (in local Time corresponding to the given Unix time in nanosecond precision). +// UnixTimeNano binds parameter to time.Time variable (in local time corresponding to the given Unix time in nanosecond precision). // // Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 // Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00 @@ -1286,7 +1286,7 @@ func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBi return b.unixTime(sourceParam, dest, false, time.Nanosecond) } -// MustUnixTimeNano requires parameter value to exist to be bind to time.Duration variable (in local Time corresponding +// MustUnixTimeNano requires parameter value to exist to bind to time.Duration variable (in local Time corresponding // to the given Unix time value in nano second precision). Returns error when value does not exist. // // Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 From 572466d92889a5c946885ec90d5a94d7ad25b0a3 Mon Sep 17 00:00:00 2001 From: gemaizi <864321211@qq.com> Date: Mon, 21 Mar 2022 23:45:06 +0800 Subject: [PATCH 219/446] Fix body_limit middleware unit test --- middleware/body_limit_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index 0e8642a06..8981534d4 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -33,12 +33,13 @@ func TestBodyLimit(t *testing.T) { assert.Equal(hw, rec.Body.Bytes()) } - // Based on content read (overlimit) + // Based on content length (overlimit) he := BodyLimit("2B")(h)(c).(*echo.HTTPError) assert.Equal(http.StatusRequestEntityTooLarge, he.Code) // Based on content read (within limit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) if assert.NoError(BodyLimit("2M")(h)(c)) { @@ -48,6 +49,7 @@ func TestBodyLimit(t *testing.T) { // Based on content read (overlimit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) he = BodyLimit("2B")(h)(c).(*echo.HTTPError) From a987b6577c5ade3d4cd3ece29db43487e975b597 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Fri, 29 Apr 2022 21:57:14 +0300 Subject: [PATCH 220/446] Update Github CI flow to use Go 1.18, bump actions versions --- .github/workflows/echo.yml | 59 ++++++++++++-------------------------- Makefile | 6 ++-- 2 files changed, 22 insertions(+), 43 deletions(-) diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 266406664..69535f09c 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -19,6 +19,7 @@ on: - '_fixture/**' - '.github/**' - 'codecov.yml' + workflow_dispatch: jobs: test: @@ -27,33 +28,22 @@ jobs: os: [ubuntu-latest, macos-latest, windows-latest] # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy # Echo tests with last four major releases - go: [1.14, 1.15, 1.16, 1.17] + go: [1.16, 1.17, 1.18] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: - - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v1 - with: - go-version: ${{ matrix.go }} - - - name: Set GOPATH and PATH - run: | - echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV - echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH - shell: bash - - - name: Set build variables - run: | - echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV - echo "GO111MODULE=on" >> $GITHUB_ENV - - name: Checkout Code - uses: actions/checkout@v1 + uses: actions/checkout@v3 with: ref: ${{ github.ref }} + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + - name: Install Dependencies - run: go get -v golang.org/x/lint/golint + run: go install golang.org/x/lint/golint@latest - name: Run Tests run: | @@ -61,7 +51,7 @@ jobs: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... - name: Upload coverage to Codecov - if: success() && matrix.go == 1.17 && matrix.os == 'ubuntu-latest' + if: success() && matrix.go == 1.18 && matrix.os == 'ubuntu-latest' uses: codecov/codecov-action@v1 with: token: @@ -71,39 +61,28 @@ jobs: strategy: matrix: os: [ubuntu-latest] - go: [1.17] + go: [1.18] name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: - - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v1 - with: - go-version: ${{ matrix.go }} - - - name: Set GOPATH and PATH - run: | - echo "GOPATH=$(dirname $GITHUB_WORKSPACE)" >> $GITHUB_ENV - echo "$(dirname $GITHUB_WORKSPACE)/bin" >> $GITHUB_PATH - shell: bash - - - name: Set build variables - run: | - echo "GOPROXY=https://proxy.golang.org" >> $GITHUB_ENV - echo "GO111MODULE=on" >> $GITHUB_ENV - - name: Checkout Code (Previous) - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: ref: ${{ github.base_ref }} path: previous - name: Checkout Code (New) - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: path: new + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + - name: Install Dependencies - run: go get -v golang.org/x/perf/cmd/benchstat + run: go install golang.org/x/perf/cmd/benchstat@latest - name: Run Benchmark (Previous) run: | diff --git a/Makefile b/Makefile index 48061f7e2..a6c4aaa90 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ tag: check: lint vet race ## Check project init: - @go get -u golang.org/x/lint/golint + @go install golang.org/x/lint/golint@latest lint: ## Lint the files @golint -set_exit_status ${PKG_LIST} @@ -29,6 +29,6 @@ benchmark: ## Run benchmarks help: ## Display this help screen @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -goversion ?= "1.15" -test_version: ## Run tests inside Docker with given version (defaults to 1.15 oldest supported). Example: make test_version goversion=1.15 +goversion ?= "1.16" +test_version: ## Run tests inside Docker with given version (defaults to 1.15 oldest supported). Example: make test_version goversion=1.16 @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" From 2e02ce3dd88f4404c87e8a3a410ae8676fad6521 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 21 May 2022 19:27:22 +0300 Subject: [PATCH 221/446] Timeout mw: fix datarace in tests when we are getting data from buffer. Run each test in their own server so multiple tests cases will not cause datarace getting data out of logger buffer. --- middleware/timeout_test.go | 62 ++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index 7fb802a9a..bba48a80f 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -362,40 +362,38 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { }, } - e := echo.New() - - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) - - // NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first - // FIXME: I have no idea how to fix this without adding mutexes. - e.Use(TimeoutWithConfig(TimeoutConfig{ - Timeout: 15 * time.Millisecond, - })) - e.Use(Logger()) - e.Use(Recover()) - - e.GET("/", func(c echo.Context) error { - var delay time.Duration - if err := echo.QueryParamsBinder(c).Duration("delay", &delay).BindError(); err != nil { - return err - } - if delay > 0 { - time.Sleep(delay) - } - return c.JSON(http.StatusTeapot, map[string]string{"message": "OK"}) - }) - - server, addr, err := startServer(e) - if err != nil { - assert.NoError(t, err) - return - } - defer server.Close() - for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - buf.Reset() // this is design this can not be run in parallel + e := echo.New() + + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) + + // NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first + // FIXME: I have no idea how to fix this without adding mutexes. + e.Use(TimeoutWithConfig(TimeoutConfig{ + Timeout: 15 * time.Millisecond, + })) + e.Use(Logger()) + e.Use(Recover()) + + e.GET("/", func(c echo.Context) error { + var delay time.Duration + if err := echo.QueryParamsBinder(c).Duration("delay", &delay).BindError(); err != nil { + return err + } + if delay > 0 { + time.Sleep(delay) + } + return c.JSON(http.StatusTeapot, map[string]string{"message": "OK"}) + }) + + server, addr, err := startServer(e) + if err != nil { + assert.NoError(t, err) + return + } + defer server.Close() res, err := http.Get(fmt.Sprintf("http://%v%v", addr, tc.whenPath)) if err != nil { From 28797c761df73cef962bbe92395089b60275680a Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 21 May 2022 20:58:15 +0300 Subject: [PATCH 222/446] Timeout mw: fix datarace in tests when we are getting data from buffer (in test) and writing to logger at the same time. --- middleware/timeout_test.go | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index bba48a80f..6da6a3866 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -12,6 +12,7 @@ import ( "net/url" "reflect" "strings" + "sync" "testing" "time" @@ -366,7 +367,7 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { t.Run(tc.name, func(t *testing.T) { e := echo.New() - buf := new(bytes.Buffer) + buf := new(coroutineSafeBuffer) e.Logger.SetOutput(buf) // NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first @@ -419,6 +420,36 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { } } +// as we are spawning multiple coroutines - one for http server, one for request, one by timeout middleware, one by testcase +// we are accessing logger (writing/reading) from multiple coroutines and causing dataraces (most often reported on macos) +// we could be writing to logger in logger middleware and at the same time our tests is getting logger buffer contents +// in testcase coroutine. +type coroutineSafeBuffer struct { + bytes.Buffer + lock sync.RWMutex +} + +func (b *coroutineSafeBuffer) Write(p []byte) (n int, err error) { + b.lock.Lock() + defer b.lock.Unlock() + + return b.Buffer.Write(p) +} + +func (b *coroutineSafeBuffer) Bytes() []byte { + b.lock.RLock() + defer b.lock.RUnlock() + + return b.Buffer.Bytes() +} + +func (b *coroutineSafeBuffer) String() string { + b.lock.RLock() + defer b.lock.RUnlock() + + return b.Buffer.String() +} + func startServer(e *echo.Echo) (*http.Server, string, error) { l, err := net.Listen("tcp", ":0") if err != nil { From d5f883707bc2cce801e261959c7a8dd5f111f702 Mon Sep 17 00:00:00 2001 From: Martti T Date: Sun, 22 May 2022 00:21:50 +0300 Subject: [PATCH 223/446] =?UTF-8?q?Timeout=20mw:=20rework=20how=20test=20w?= =?UTF-8?q?aits=20for=20timeout.=20Using=20sleep=20as=20delay=20i=E2=80=A6?= =?UTF-8?q?=20(#2187)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Timeout mw: rework how test waits for timeout. Using sleep as delay is problematic when CI worker is slower than usual. --- middleware/timeout_test.go | 49 ++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index 6da6a3866..56eb7bc74 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -2,6 +2,7 @@ package middleware import ( "bytes" + "context" "errors" "fmt" "io/ioutil" @@ -328,12 +329,13 @@ func TestTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) { func TestTimeoutWithFullEchoStack(t *testing.T) { // test timeout with full http server stack running, do see what http.Server.ErrorLog contains var testCases = []struct { - name string - whenPath string - expectStatusCode int - expectResponse string - expectLogContains []string - expectLogNotContains []string + name string + whenPath string + whenForceHandlerTimeout bool + expectStatusCode int + expectResponse string + expectLogContains []string + expectLogNotContains []string }{ { name: "404 - write response in global error handler", @@ -352,14 +354,15 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { expectLogContains: []string{`"status":418,"error":"",`}, }, { - name: "503 - handler timeouts, write response in timeout middleware", - whenPath: "/?delay=50ms", - expectResponse: "Timeout

Timeout

", - expectStatusCode: http.StatusServiceUnavailable, + name: "503 - handler timeouts, write response in timeout middleware", + whenForceHandlerTimeout: true, + whenPath: "/", + expectResponse: "Timeout

Timeout

", + expectStatusCode: http.StatusServiceUnavailable, expectLogNotContains: []string{ "echo:http: superfluous response.WriteHeader call from", - "{", // means that logger was not called. }, + expectLogContains: []string{"http: Handler timeout"}, }, } @@ -371,21 +374,18 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { e.Logger.SetOutput(buf) // NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first - // FIXME: I have no idea how to fix this without adding mutexes. e.Use(TimeoutWithConfig(TimeoutConfig{ Timeout: 15 * time.Millisecond, })) e.Use(Logger()) e.Use(Recover()) + wg := sync.WaitGroup{} + if tc.whenForceHandlerTimeout { + wg.Add(1) // make `wg.Wait()` block until we release it with `wg.Done()` + } e.GET("/", func(c echo.Context) error { - var delay time.Duration - if err := echo.QueryParamsBinder(c).Duration("delay", &delay).BindError(); err != nil { - return err - } - if delay > 0 { - time.Sleep(delay) - } + wg.Wait() return c.JSON(http.StatusTeapot, map[string]string{"message": "OK"}) }) @@ -401,6 +401,13 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { assert.NoError(t, err) return } + if tc.whenForceHandlerTimeout { + wg.Done() + // shutdown waits for server to shutdown. this way we wait logger mw to be executed + ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond) + defer cancel() + server.Shutdown(ctx) + } assert.Equal(t, tc.expectStatusCode, res.StatusCode) if body, err := ioutil.ReadAll(res.Body); err == nil { @@ -411,10 +418,10 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { logged := buf.String() for _, subStr := range tc.expectLogContains { - assert.True(t, strings.Contains(logged, subStr)) + assert.True(t, strings.Contains(logged, subStr), "expected logs to contain: %v, logged: '%v'", subStr, logged) } for _, subStr := range tc.expectLogNotContains { - assert.False(t, strings.Contains(logged, subStr)) + assert.False(t, strings.Contains(logged, subStr), "expected logs not to contain: %v, logged: '%v'", subStr, logged) } }) } From b0453b98e0508cde2cf8915f647c6167e01a0683 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?welling=20guzm=C3=A1n?= <1531291+wellingguzman@users.noreply.github.com> Date: Fri, 27 May 2022 18:44:51 +0200 Subject: [PATCH 224/446] fix: basic auth invalid base64 string (#2191) * fix: basic auth returns 400 on invalid base64 string --- middleware/basic_auth.go | 6 +++++- middleware/basic_auth_test.go | 6 ++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 8cf1ed9fc..52ef1042f 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "strconv" "strings" + "net/http" "github.com/labstack/echo/v4" ) @@ -74,10 +75,13 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { l := len(basic) if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { + // Invalid base64 shouldn't be treated as error + // instead should be treated as invalid client input b, err := base64.StdEncoding.DecodeString(auth[l+1:]) if err != nil { - return err + return echo.NewHTTPError(http.StatusBadRequest).SetInternal(err) } + cred := string(b) for i := 0; i < len(cred); i++ { if cred[i] == ':' { diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 76039db0a..4c355aa16 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -58,6 +58,12 @@ func TestBasicAuth(t *testing.T) { assert.Equal(http.StatusUnauthorized, he.Code) assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate)) + // Invalid base64 string + auth = basic + " invalidString" + req.Header.Set(echo.HeaderAuthorization, auth) + he = h(c).(*echo.HTTPError) + assert.Equal(http.StatusBadRequest, he.Code) + // Missing Authorization header req.Header.Del(echo.HeaderAuthorization) he = h(c).(*echo.HTTPError) From 0644cd6ecdef4473c38c9f298a15512a47c8db42 Mon Sep 17 00:00:00 2001 From: lkeix <53435330+lkeix@users.noreply.github.com> Date: Sat, 28 May 2022 02:15:58 +0900 Subject: [PATCH 225/446] fix: duplicated findStaticChild process at findChildWithLabel (#2176) --- router.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/router.go b/router.go index a1de2d6e3..b5e50d94f 100644 --- a/router.go +++ b/router.go @@ -333,10 +333,8 @@ func (n *node) findStaticChild(l byte) *node { } func (n *node) findChildWithLabel(l byte) *node { - for _, c := range n.staticChildren { - if c.label == l { - return c - } + if c := n.findStaticChild(l); c != nil { + return c } if l == paramLabel { return n.paramChild From ddb66e1ba272fe7580dae6f8543763bc4a760fbd Mon Sep 17 00:00:00 2001 From: moznion Date: Mon, 4 Jul 2022 21:57:39 -0700 Subject: [PATCH 226/446] Add logger middleware template variables: `${time_unix_milli}` and `${time_unix_micro}` (#2206) This patch introduces two template variables `${time_unix_milli}` and `${time_unix_micro}` into the logger middleware. Currently, there is no way to interpolate that UNIX milli and micro seconds timestamp in a log entry, and go 1.17 or later runtime supports the utility functions `time#UnixMilli()` and `time#UnixMicro()` so this patch adds them as well. see also: https://github.com/golang/go/issues/44196 Signed-off-by: moznion --- middleware/logger.go | 8 +++++++ middleware/logger_test.go | 47 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/middleware/logger.go b/middleware/logger.go index 9baac4769..a21df8f39 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -23,6 +23,8 @@ type ( // Tags to construct the logger format. // // - time_unix + // - time_unix_milli + // - time_unix_micro // - time_unix_nano // - time_rfc3339 // - time_rfc3339_nano @@ -126,6 +128,12 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { switch tag { case "time_unix": return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10)) + case "time_unix_milli": + // go 1.17 or later, it supports time#UnixMilli() + return buf.WriteString(strconv.FormatInt(time.Now().UnixNano()/1000000, 10)) + case "time_unix_micro": + // go 1.17 or later, it supports time#UnixMicro() + return buf.WriteString(strconv.FormatInt(time.Now().UnixNano()/1000, 10)) case "time_unix_nano": return buf.WriteString(strconv.FormatInt(time.Now().UnixNano(), 10)) case "time_rfc3339": diff --git a/middleware/logger_test.go b/middleware/logger_test.go index 394f62712..ab889bfda 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strconv" "strings" "testing" "time" @@ -244,3 +245,49 @@ func BenchmarkLoggerWithConfig_withMapFields(b *testing.B) { buf.Reset() } } + +func TestLoggerTemplateWithTimeUnixMilli(t *testing.T) { + buf := new(bytes.Buffer) + + e := echo.New() + e.Use(LoggerWithConfig(LoggerConfig{ + Format: `${time_unix_milli}`, + Output: buf, + })) + + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + unixMillis, err := strconv.ParseInt(buf.String(), 10, 64) + assert.NoError(t, err) + assert.WithinDuration(t, time.Unix(unixMillis/1000, 0), time.Now(), 3*time.Second) +} + +func TestLoggerTemplateWithTimeUnixMicro(t *testing.T) { + buf := new(bytes.Buffer) + + e := echo.New() + e.Use(LoggerWithConfig(LoggerConfig{ + Format: `${time_unix_micro}`, + Output: buf, + })) + + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + unixMicros, err := strconv.ParseInt(buf.String(), 10, 64) + assert.NoError(t, err) + assert.WithinDuration(t, time.Unix(unixMicros/1000000, 0), time.Now(), 3*time.Second) +} From 9bf1e3c8ce6de029186945e4ee82f89f2e2cc661 Mon Sep 17 00:00:00 2001 From: Artem Iurin Date: Mon, 11 Jul 2022 20:25:41 +0300 Subject: [PATCH 227/446] Allow different param names in different methods with same path scheme (#2209) * Change methodHandler element type to methodContext Signed-off-by: ortyomka * Allow different param names in the smae path with different methods Signed-off-by: ortyomka * Rename methodContext to routeMethod Add paramsCount in each node for perfomance Signed-off-by: ortyomka * Add backtracking to nearest path Signed-off-by: ortyomka * Remove params from NotAllowed Signed-off-by: ortyomka --- router.go | 210 ++++++++++++++++++++++++++----------------------- router_test.go | 29 ++++++- 2 files changed, 141 insertions(+), 98 deletions(-) diff --git a/router.go b/router.go index b5e50d94f..6a8615d83 100644 --- a/router.go +++ b/router.go @@ -19,30 +19,35 @@ type ( prefix string parent *node staticChildren children - ppath string - pnames []string - methodHandler *methodHandler + originalPath string + methods *routeMethods paramChild *node anyChild *node + paramsCount int // isLeaf indicates that node does not have child routes isLeaf bool // isHandler indicates that node has at least one handler registered to it isHandler bool } - kind uint8 - children []*node - methodHandler struct { - connect HandlerFunc - delete HandlerFunc - get HandlerFunc - head HandlerFunc - options HandlerFunc - patch HandlerFunc - post HandlerFunc - propfind HandlerFunc - put HandlerFunc - trace HandlerFunc - report HandlerFunc + kind uint8 + children []*node + routeMethod struct { + ppath string + pnames []string + handler HandlerFunc + } + routeMethods struct { + connect *routeMethod + delete *routeMethod + get *routeMethod + head *routeMethod + options *routeMethod + patch *routeMethod + post *routeMethod + propfind *routeMethod + put *routeMethod + trace *routeMethod + report *routeMethod allowHeader string } ) @@ -56,7 +61,7 @@ const ( anyLabel = byte('*') ) -func (m *methodHandler) isHandler() bool { +func (m *routeMethods) isHandler() bool { return m.connect != nil || m.delete != nil || m.get != nil || @@ -70,7 +75,7 @@ func (m *methodHandler) isHandler() bool { m.report != nil } -func (m *methodHandler) updateAllowHeader() { +func (m *routeMethods) updateAllowHeader() { buf := new(bytes.Buffer) buf.WriteString(http.MethodOptions) @@ -119,7 +124,7 @@ func (m *methodHandler) updateAllowHeader() { func NewRouter(e *Echo) *Router { return &Router{ tree: &node{ - methodHandler: new(methodHandler), + methods: new(routeMethods), }, routes: map[string]*Route{}, echo: e, @@ -153,7 +158,7 @@ func (r *Router) Add(method, path string, h HandlerFunc) { } j := i + 1 - r.insert(method, path[:i], nil, staticKind, "", nil) + r.insert(method, path[:i], staticKind, routeMethod{}) for ; i < lcpIndex && path[i] != '/'; i++ { } @@ -163,23 +168,23 @@ func (r *Router) Add(method, path string, h HandlerFunc) { if i == lcpIndex { // path node is last fragment of route path. ie. `/users/:id` - r.insert(method, path[:i], h, paramKind, ppath, pnames) + r.insert(method, path[:i], paramKind, routeMethod{ppath, pnames, h}) } else { - r.insert(method, path[:i], nil, paramKind, "", nil) + r.insert(method, path[:i], paramKind, routeMethod{}) } } else if path[i] == '*' { - r.insert(method, path[:i], nil, staticKind, "", nil) + r.insert(method, path[:i], staticKind, routeMethod{}) pnames = append(pnames, "*") - r.insert(method, path[:i+1], h, anyKind, ppath, pnames) + r.insert(method, path[:i+1], anyKind, routeMethod{ppath, pnames, h}) } } - r.insert(method, path, h, staticKind, ppath, pnames) + r.insert(method, path, staticKind, routeMethod{ppath, pnames, h}) } -func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string, pnames []string) { +func (r *Router) insert(method, path string, t kind, rm routeMethod) { // Adjust max param - paramLen := len(pnames) + paramLen := len(rm.pnames) if *r.echo.maxParam < paramLen { *r.echo.maxParam = paramLen } @@ -207,11 +212,11 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string // At root node currentNode.label = search[0] currentNode.prefix = search - if h != nil { + if rm.handler != nil { currentNode.kind = t - currentNode.addHandler(method, h) - currentNode.ppath = ppath - currentNode.pnames = pnames + currentNode.addMethod(method, &rm) + currentNode.paramsCount = len(rm.pnames) + currentNode.originalPath = rm.ppath } currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil } else if lcpLen < prefixLen { @@ -221,9 +226,9 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string currentNode.prefix[lcpLen:], currentNode, currentNode.staticChildren, - currentNode.methodHandler, - currentNode.ppath, - currentNode.pnames, + currentNode.originalPath, + currentNode.methods, + currentNode.paramsCount, currentNode.paramChild, currentNode.anyChild, ) @@ -243,9 +248,9 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string currentNode.label = currentNode.prefix[0] currentNode.prefix = currentNode.prefix[:lcpLen] currentNode.staticChildren = nil - currentNode.methodHandler = new(methodHandler) - currentNode.ppath = "" - currentNode.pnames = nil + currentNode.originalPath = "" + currentNode.methods = new(routeMethods) + currentNode.paramsCount = 0 currentNode.paramChild = nil currentNode.anyChild = nil currentNode.isLeaf = false @@ -257,13 +262,19 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string if lcpLen == searchLen { // At parent node currentNode.kind = t - currentNode.addHandler(method, h) - currentNode.ppath = ppath - currentNode.pnames = pnames + if rm.handler != nil { + currentNode.addMethod(method, &rm) + currentNode.paramsCount = len(rm.pnames) + currentNode.originalPath = rm.ppath + } } else { // Create child node - n = newNode(t, search[lcpLen:], currentNode, nil, new(methodHandler), ppath, pnames, nil, nil) - n.addHandler(method, h) + n = newNode(t, search[lcpLen:], currentNode, nil, "", new(routeMethods), 0, nil, nil) + if rm.handler != nil { + n.addMethod(method, &rm) + n.paramsCount = len(rm.pnames) + n.originalPath = rm.ppath + } // Only Static children could reach here currentNode.addStaticChild(n) } @@ -277,8 +288,12 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string continue } // Create child node - n := newNode(t, search, currentNode, nil, new(methodHandler), ppath, pnames, nil, nil) - n.addHandler(method, h) + n := newNode(t, search, currentNode, nil, rm.ppath, new(routeMethods), 0, nil, nil) + if rm.handler != nil { + n.addMethod(method, &rm) + n.paramsCount = len(rm.pnames) + } + switch t { case staticKind: currentNode.addStaticChild(n) @@ -290,28 +305,26 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil } else { // Node already exists - if h != nil { - currentNode.addHandler(method, h) - currentNode.ppath = ppath - if len(currentNode.pnames) == 0 { // Issue #729 - currentNode.pnames = pnames - } + if rm.handler != nil { + currentNode.addMethod(method, &rm) + currentNode.paramsCount = len(rm.pnames) + currentNode.originalPath = rm.ppath } } return } } -func newNode(t kind, pre string, p *node, sc children, mh *methodHandler, ppath string, pnames []string, paramChildren, anyChildren *node) *node { +func newNode(t kind, pre string, p *node, sc children, originalPath string, mh *routeMethods, paramsCount int, paramChildren, anyChildren *node) *node { return &node{ kind: t, label: pre[0], prefix: pre, parent: p, staticChildren: sc, - ppath: ppath, - pnames: pnames, - methodHandler: mh, + originalPath: originalPath, + methods: mh, + paramsCount: paramsCount, paramChild: paramChildren, anyChild: anyChildren, isLeaf: sc == nil && paramChildren == nil && anyChildren == nil, @@ -345,64 +358,60 @@ func (n *node) findChildWithLabel(l byte) *node { return nil } -func (n *node) addHandler(method string, h HandlerFunc) { +func (n *node) addMethod(method string, h *routeMethod) { switch method { case http.MethodConnect: - n.methodHandler.connect = h + n.methods.connect = h case http.MethodDelete: - n.methodHandler.delete = h + n.methods.delete = h case http.MethodGet: - n.methodHandler.get = h + n.methods.get = h case http.MethodHead: - n.methodHandler.head = h + n.methods.head = h case http.MethodOptions: - n.methodHandler.options = h + n.methods.options = h case http.MethodPatch: - n.methodHandler.patch = h + n.methods.patch = h case http.MethodPost: - n.methodHandler.post = h + n.methods.post = h case PROPFIND: - n.methodHandler.propfind = h + n.methods.propfind = h case http.MethodPut: - n.methodHandler.put = h + n.methods.put = h case http.MethodTrace: - n.methodHandler.trace = h + n.methods.trace = h case REPORT: - n.methodHandler.report = h + n.methods.report = h } - n.methodHandler.updateAllowHeader() - if h != nil { - n.isHandler = true - } else { - n.isHandler = n.methodHandler.isHandler() - } + n.methods.updateAllowHeader() + n.isHandler = true } -func (n *node) findHandler(method string) HandlerFunc { +func (n *node) findMethod(method string) *routeMethod { switch method { case http.MethodConnect: - return n.methodHandler.connect + return n.methods.connect case http.MethodDelete: - return n.methodHandler.delete + return n.methods.delete case http.MethodGet: - return n.methodHandler.get + return n.methods.get case http.MethodHead: - return n.methodHandler.head + return n.methods.head case http.MethodOptions: - return n.methodHandler.options + return n.methods.options case http.MethodPatch: - return n.methodHandler.patch + return n.methods.patch case http.MethodPost: - return n.methodHandler.post + return n.methods.post case PROPFIND: - return n.methodHandler.propfind + return n.methods.propfind case http.MethodPut: - return n.methodHandler.put + return n.methods.put case http.MethodTrace: - return n.methodHandler.trace + return n.methods.trace case REPORT: - return n.methodHandler.report + return n.methods.report default: return nil } @@ -433,7 +442,7 @@ func (r *Router) Find(method, path string, c Context) { var ( previousBestMatchNode *node - matchedHandler HandlerFunc + matchedRouteMethod *routeMethod // search stores the remaining path to check for match. By each iteration we move from start of path to end of the path // and search value gets shorter and shorter. search = path @@ -529,8 +538,8 @@ func (r *Router) Find(method, path string, c Context) { if previousBestMatchNode == nil { previousBestMatchNode = currentNode } - if h := currentNode.findHandler(method); h != nil { - matchedHandler = h + if h := currentNode.findMethod(method); h != nil { + matchedRouteMethod = h break } } @@ -569,7 +578,8 @@ func (r *Router) Find(method, path string, c Context) { if child := currentNode.anyChild; child != nil { // If any node is found, use remaining path for paramValues currentNode = child - paramValues[len(currentNode.pnames)-1] = search + paramValues[currentNode.paramsCount-1] = search + // update indexes/search in case we need to backtrack when no handler match is found paramIndex++ searchIndex += +len(search) @@ -580,8 +590,8 @@ func (r *Router) Find(method, path string, c Context) { if previousBestMatchNode == nil { previousBestMatchNode = currentNode } - if h := currentNode.findHandler(method); h != nil { - matchedHandler = h + if h := currentNode.findMethod(method); h != nil { + matchedRouteMethod = h break } } @@ -604,22 +614,28 @@ func (r *Router) Find(method, path string, c Context) { return // nothing matched at all } - if matchedHandler != nil { - ctx.handler = matchedHandler + var rPath string + var rPNames []string + if matchedRouteMethod != nil { + ctx.handler = matchedRouteMethod.handler + rPath = matchedRouteMethod.ppath + rPNames = matchedRouteMethod.pnames } else { // use previous match as basis. although we have no matching handler we have path match. // so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404) currentNode = previousBestMatchNode + rPath = currentNode.originalPath + rPNames = nil // no params here ctx.handler = NotFoundHandler if currentNode.isHandler { - ctx.Set(ContextKeyHeaderAllow, currentNode.methodHandler.allowHeader) + ctx.Set(ContextKeyHeaderAllow, currentNode.methods.allowHeader) ctx.handler = MethodNotAllowedHandler if method == http.MethodOptions { - ctx.handler = optionsMethodHandler(currentNode.methodHandler.allowHeader) + ctx.handler = optionsMethodHandler(currentNode.methods.allowHeader) } } } - ctx.path = currentNode.ppath - ctx.pnames = currentNode.pnames + ctx.path = rPath + ctx.pnames = rPNames } diff --git a/router_test.go b/router_test.go index 457566b90..8645a26c1 100644 --- a/router_test.go +++ b/router_test.go @@ -2318,6 +2318,33 @@ func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) { } } +// Issue #1726 +func TestRouterDifferentParamsInPath(t *testing.T) { + e := New() + r := e.router + r.Add(http.MethodPut, "/*", func(Context) error { + return nil + }) + r.Add(http.MethodPut, "/users/:vid/files/:gid", func(Context) error { + return nil + }) + r.Add(http.MethodGet, "/users/:uid/files/:fid", func(Context) error { + return nil + }) + c := e.NewContext(nil, nil).(*context) + + r.Find(http.MethodGet, "/users/1/files/2", c) + assert.Equal(t, "1", c.Param("uid")) + assert.Equal(t, "2", c.Param("fid")) + + r.Find(http.MethodGet, "/users/1/shouldBacktrackToFirstAnyRouteAnd405", c) + assert.Equal(t, "/*", c.Path()) + + r.Find(http.MethodPut, "/users/3/files/4", c) + assert.Equal(t, "3", c.Param("vid")) + assert.Equal(t, "4", c.Param("gid")) +} + func TestRouterHandleMethodOptions(t *testing.T) { e := New() r := e.router @@ -2380,7 +2407,7 @@ func TestRouterHandleMethodOptions(t *testing.T) { assert.NoError(t, err) assert.Equal(t, tc.expectStatus, rec.Code) } - assert.Equal(t, tc.expectAllowHeader, c.Response().Header().Get("Allow")) + assert.Equal(t, tc.expectAllowHeader, c.Response().Header().Get(HeaderAllow)) }) } } From 690e3392d984dcbdb9f41a7915ddc0d383311974 Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 12 Jul 2022 21:53:41 +0300 Subject: [PATCH 228/446] Add support for registering handlers for 404 routes (#2217) --- echo.go | 13 ++++ echo_test.go | 64 +++++++++++++++++ group.go | 7 ++ group_test.go | 65 +++++++++++++++++ router.go | 59 +++++++++++----- router_test.go | 185 +++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 375 insertions(+), 18 deletions(-) diff --git a/echo.go b/echo.go index 8829619c7..5b10d586e 100644 --- a/echo.go +++ b/echo.go @@ -183,6 +183,8 @@ const ( PROPFIND = "PROPFIND" // REPORT Method can be used to get information about a resource, see rfc 3253 REPORT = "REPORT" + // RouteNotFound is special method type for routes handling "route not found" (404) cases + RouteNotFound = "echo_route_not_found" ) // Headers @@ -480,6 +482,16 @@ func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { return e.Add(http.MethodTrace, path, h, m...) } +// RouteNotFound registers a special-case route which is executed when no other route is found (i.e. HTTP 404 cases) +// for current request URL. +// Path supports static and named/any parameters just like other http method is defined. Generally path is ended with +// wildcard/match-any character (`/*`, `/download/*` etc). +// +// Example: `e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })` +func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { + return e.Add(RouteNotFound, path, h, m...) +} + // Any registers a new route for all HTTP methods and path with matching handler // in the router with optional route-level middleware. func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { @@ -515,6 +527,7 @@ func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route { func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { name := handlerName(handler) router := e.findRouter(host) + // FIXME: when handler+middleware are both nil ... make it behave like handler removal router.Add(method, path, func(c Context) error { h := applyMiddleware(handler, middleware...) return h(c) diff --git a/echo_test.go b/echo_test.go index 0e1e42be0..64796b3b5 100644 --- a/echo_test.go +++ b/echo_test.go @@ -766,6 +766,70 @@ func TestEchoNotFound(t *testing.T) { assert.Equal(t, http.StatusNotFound, rec.Code) } +func TestEcho_RouteNotFound(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectCode int + }{ + { + name: "404, route to static not found handler /a/c/xx", + whenURL: "/a/c/xx", + expectRoute: "GET /a/c/xx", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to path param not found handler /a/:file", + whenURL: "/a/echo.exe", + expectRoute: "GET /a/:file", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to any not found handler /*", + whenURL: "/b/echo.exe", + expectRoute: "GET /*", + expectCode: http.StatusNotFound, + }, + { + name: "200, route /a/c/df to /a/c/df", + whenURL: "/a/c/df", + expectRoute: "GET /a/c/df", + expectCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + okHandler := func(c Context) error { + return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) + } + notFoundHandler := func(c Context) error { + return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) + } + + e.GET("/", okHandler) + e.GET("/a/c/df", okHandler) + e.GET("/a/b*", okHandler) + e.PUT("/*", okHandler) + + e.RouteNotFound("/a/c/xx", notFoundHandler) // static + e.RouteNotFound("/a/:file", notFoundHandler) // param + e.RouteNotFound("/*", notFoundHandler) // any + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectRoute, rec.Body.String()) + }) + } +} + func TestEchoMethodNotAllowed(t *testing.T) { e := New() diff --git a/group.go b/group.go index bba470ce8..28ce0dd9a 100644 --- a/group.go +++ b/group.go @@ -107,6 +107,13 @@ func (g *Group) File(path, file string) { g.file(path, file, g.GET) } +// RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group. +// +// Example: `g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })` +func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { + return g.Add(RouteNotFound, path, h, m...) +} + // Add implements `Echo#Add()` for sub-routes within the Group. func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { // Combine into a new slice to avoid accidentally passing the same slice for diff --git a/group_test.go b/group_test.go index c51fd91eb..24f191677 100644 --- a/group_test.go +++ b/group_test.go @@ -119,3 +119,68 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { assert.Equal(t, "/*", m) } + +func TestGroup_RouteNotFound(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectCode int + }{ + { + name: "404, route to static not found handler /group/a/c/xx", + whenURL: "/group/a/c/xx", + expectRoute: "GET /group/a/c/xx", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to path param not found handler /group/a/:file", + whenURL: "/group/a/echo.exe", + expectRoute: "GET /group/a/:file", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to any not found handler /group/*", + whenURL: "/group/b/echo.exe", + expectRoute: "GET /group/*", + expectCode: http.StatusNotFound, + }, + { + name: "200, route /group/a/c/df to /group/a/c/df", + whenURL: "/group/a/c/df", + expectRoute: "GET /group/a/c/df", + expectCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + g := e.Group("/group") + + okHandler := func(c Context) error { + return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) + } + notFoundHandler := func(c Context) error { + return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) + } + + g.GET("/", okHandler) + g.GET("/a/c/df", okHandler) + g.GET("/a/b*", okHandler) + g.PUT("/*", okHandler) + + g.RouteNotFound("/a/c/xx", notFoundHandler) // static + g.RouteNotFound("/a/:file", notFoundHandler) // param + g.RouteNotFound("/*", notFoundHandler) // any + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectRoute, rec.Body.String()) + }) + } +} diff --git a/router.go b/router.go index 6a8615d83..74bc7659b 100644 --- a/router.go +++ b/router.go @@ -28,6 +28,9 @@ type ( isLeaf bool // isHandler indicates that node has at least one handler registered to it isHandler bool + + // notFoundHandler is handler registered with RouteNotFound method and is executed for 404 cases + notFoundHandler *routeMethod } kind uint8 children []*node @@ -73,6 +76,7 @@ func (m *routeMethods) isHandler() bool { m.put != nil || m.trace != nil || m.report != nil + // RouteNotFound/404 is not considered as a handler } func (m *routeMethods) updateAllowHeader() { @@ -382,6 +386,9 @@ func (n *node) addMethod(method string, h *routeMethod) { n.methods.trace = h case REPORT: n.methods.report = h + case RouteNotFound: + n.notFoundHandler = h + return // RouteNotFound/404 is not considered as a handler so no further logic needs to be executed } n.methods.updateAllowHeader() @@ -412,7 +419,7 @@ func (n *node) findMethod(method string) *routeMethod { return n.methods.trace case REPORT: return n.methods.report - default: + default: // RouteNotFound/404 is not considered as a handler return nil } } @@ -515,7 +522,7 @@ func (r *Router) Find(method, path string, c Context) { // No matching prefix, let's backtrack to the first possible alternative node of the decision path nk, ok := backtrackToNextNodeKind(staticKind) if !ok { - return // No other possibilities on the decision path + return // No other possibilities on the decision path, handler will be whatever context is reset to. } else if nk == paramKind { goto Param // NOTE: this case (backtracking from static node to previous any node) can not happen by current any matching logic. Any node is end of search currently @@ -531,15 +538,21 @@ func (r *Router) Find(method, path string, c Context) { search = search[lcpLen:] searchIndex = searchIndex + lcpLen - // Finish routing if no remaining search and we are on a node with handler and matching method type - if search == "" && currentNode.isHandler { - // check if current node has handler registered for http method we are looking for. we store currentNode as - // best matching in case we do no find no more routes matching this path+method - if previousBestMatchNode == nil { - previousBestMatchNode = currentNode - } - if h := currentNode.findMethod(method); h != nil { - matchedRouteMethod = h + // Finish routing if is no request path remaining to search + if search == "" { + // in case of node that is handler we have exact method type match or something for 405 to use + if currentNode.isHandler { + // check if current node has handler registered for http method we are looking for. we store currentNode as + // best matching in case we do no find no more routes matching this path+method + if previousBestMatchNode == nil { + previousBestMatchNode = currentNode + } + if h := currentNode.findMethod(method); h != nil { + matchedRouteMethod = h + break + } + } else if currentNode.notFoundHandler != nil { + matchedRouteMethod = currentNode.notFoundHandler break } } @@ -559,7 +572,8 @@ func (r *Router) Find(method, path string, c Context) { i := 0 l := len(search) if currentNode.isLeaf { - // when param node does not have any children then param node should act similarly to any node - consider all remaining search as match + // when param node does not have any children (path param is last piece of route path) then param node should + // act similarly to any node - consider all remaining search as match i = l } else { for ; i < l && search[i] != '/'; i++ { @@ -585,13 +599,16 @@ func (r *Router) Find(method, path string, c Context) { searchIndex += +len(search) search = "" - // check if current node has handler registered for http method we are looking for. we store currentNode as - // best matching in case we do no find no more routes matching this path+method + if h := currentNode.findMethod(method); h != nil { + matchedRouteMethod = h + break + } + // we store currentNode as best matching in case we do not find more routes matching this path+method. Needed for 405 if previousBestMatchNode == nil { previousBestMatchNode = currentNode } - if h := currentNode.findMethod(method); h != nil { - matchedRouteMethod = h + if currentNode.notFoundHandler != nil { + matchedRouteMethod = currentNode.notFoundHandler break } } @@ -614,12 +631,14 @@ func (r *Router) Find(method, path string, c Context) { return // nothing matched at all } + // matchedHandler could be method+path handler that we matched or notFoundHandler from node with matching path + // user provided not found (404) handler has priority over generic method not found (405) handler or global 404 handler var rPath string var rPNames []string if matchedRouteMethod != nil { - ctx.handler = matchedRouteMethod.handler rPath = matchedRouteMethod.ppath rPNames = matchedRouteMethod.pnames + ctx.handler = matchedRouteMethod.handler } else { // use previous match as basis. although we have no matching handler we have path match. // so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404) @@ -628,7 +647,11 @@ func (r *Router) Find(method, path string, c Context) { rPath = currentNode.originalPath rPNames = nil // no params here ctx.handler = NotFoundHandler - if currentNode.isHandler { + if currentNode.notFoundHandler != nil { + rPath = currentNode.notFoundHandler.ppath + rPNames = currentNode.notFoundHandler.pnames + ctx.handler = currentNode.notFoundHandler.handler + } else if currentNode.isHandler { ctx.Set(ContextKeyHeaderAllow, currentNode.methods.allowHeader) ctx.handler = MethodNotAllowedHandler if method == http.MethodOptions { diff --git a/router_test.go b/router_test.go index 8645a26c1..34f325d33 100644 --- a/router_test.go +++ b/router_test.go @@ -1101,6 +1101,191 @@ func TestRouterBacktrackingFromMultipleParamKinds(t *testing.T) { } } +func TestNotFoundRouteAnyKind(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectID int + expectParam map[string]string + }{ + { + name: "route not existent /xx to not found handler /*", + whenURL: "/xx", + expectRoute: "/*", + expectID: 4, + expectParam: map[string]string{"*": "xx"}, + }, + { + name: "route not existent /a/xx to not found handler /a/*", + whenURL: "/a/xx", + expectRoute: "/a/*", + expectID: 5, + expectParam: map[string]string{"*": "xx"}, + }, + { + name: "route not existent /a/c/dxxx to not found handler /a/c/d*", + whenURL: "/a/c/dxxx", + expectRoute: "/a/c/d*", + expectID: 6, + expectParam: map[string]string{"*": "xxx"}, + }, + { + name: "route /a/c/df to /a/c/df", + whenURL: "/a/c/df", + expectRoute: "/a/c/df", + expectID: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/", handlerHelper("ID", 0)) + r.Add(http.MethodGet, "/a/c/df", handlerHelper("ID", 1)) + r.Add(http.MethodGet, "/a/b*", handlerHelper("ID", 2)) + r.Add(http.MethodPut, "/*", handlerHelper("ID", 3)) + + r.Add(RouteNotFound, "/a/c/d*", handlerHelper("ID", 6)) + r.Add(RouteNotFound, "/a/*", handlerHelper("ID", 5)) + r.Add(RouteNotFound, "/*", handlerHelper("ID", 4)) + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, tc.whenURL, c) + + c.handler(c) + + testValue, _ := c.Get("ID").(int) + assert.Equal(t, tc.expectID, testValue) + assert.Equal(t, tc.expectRoute, c.Path()) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } +} + +func TestNotFoundRouteParamKind(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectID int + expectParam map[string]string + }{ + { + name: "route not existent /xx to not found handler /:file", + whenURL: "/xx", + expectRoute: "/:file", + expectID: 4, + expectParam: map[string]string{"file": "xx"}, + }, + { + name: "route not existent /a/xx to not found handler /a/:file", + whenURL: "/a/xx", + expectRoute: "/a/:file", + expectID: 5, + expectParam: map[string]string{"file": "xx"}, + }, + { + name: "route not existent /a/c/dxxx to not found handler /a/c/d:file", + whenURL: "/a/c/dxxx", + expectRoute: "/a/c/d:file", + expectID: 6, + expectParam: map[string]string{"file": "xxx"}, + }, + { + name: "route /a/c/df to /a/c/df", + whenURL: "/a/c/df", + expectRoute: "/a/c/df", + expectID: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/", handlerHelper("ID", 0)) + r.Add(http.MethodGet, "/a/c/df", handlerHelper("ID", 1)) + r.Add(http.MethodGet, "/a/b*", handlerHelper("ID", 2)) + r.Add(http.MethodPut, "/*", handlerHelper("ID", 3)) + + r.Add(RouteNotFound, "/a/c/d:file", handlerHelper("ID", 6)) + r.Add(RouteNotFound, "/a/:file", handlerHelper("ID", 5)) + r.Add(RouteNotFound, "/:file", handlerHelper("ID", 4)) + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, tc.whenURL, c) + + c.handler(c) + + testValue, _ := c.Get("ID").(int) + assert.Equal(t, tc.expectID, testValue) + assert.Equal(t, tc.expectRoute, c.Path()) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } +} + +func TestNotFoundRouteStaticKind(t *testing.T) { + // note: static not found handler is quite silly thing to have but we still support it + var testCases = []struct { + name string + whenURL string + expectRoute interface{} + expectID int + expectParam map[string]string + }{ + { + name: "route not existent / to not found handler /", + whenURL: "/", + expectRoute: "/", + expectID: 3, + expectParam: map[string]string{}, + }, + { + name: "route /a to /a", + whenURL: "/a", + expectRoute: "/a", + expectID: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodPut, "/", handlerHelper("ID", 0)) + r.Add(http.MethodGet, "/a", handlerHelper("ID", 1)) + r.Add(http.MethodPut, "/*", handlerHelper("ID", 2)) + + r.Add(RouteNotFound, "/", handlerHelper("ID", 3)) + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodGet, tc.whenURL, c) + + c.handler(c) + + testValue, _ := c.Get("ID").(int) + assert.Equal(t, tc.expectID, testValue) + assert.Equal(t, tc.expectRoute, c.Path()) + for param, expectedValue := range tc.expectParam { + assert.Equal(t, expectedValue, c.Param(param)) + } + checkUnusedParamValues(t, c, tc.expectParam) + }) + } +} + // Issue #1509 func TestRouterParamStaticConflict(t *testing.T) { e := New() From 70acd57105b15b2db18c730b52083a76d05babf7 Mon Sep 17 00:00:00 2001 From: Martti T Date: Wed, 13 Jul 2022 08:16:27 +0300 Subject: [PATCH 229/446] Fix case when routeNotFound handler is lost when new route is added to the router (#2219) --- router.go | 51 ++++++++++++++++++++++++++++++++++---------------- router_test.go | 37 ++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 16 deletions(-) diff --git a/router.go b/router.go index 74bc7659b..90102a294 100644 --- a/router.go +++ b/router.go @@ -224,7 +224,12 @@ func (r *Router) insert(method, path string, t kind, rm routeMethod) { } currentNode.isLeaf = currentNode.staticChildren == nil && currentNode.paramChild == nil && currentNode.anyChild == nil } else if lcpLen < prefixLen { - // Split node + // Split node into two before we insert new node. + // This happens when we are inserting path that is submatch of any existing inserted paths. + // For example, we have node `/test` and now are about to insert `/te/*`. In that case + // 1. overlapping part is `/te` that is used as parent node + // 2. `st` is part from existing node that is not matching - it gets its own node (child to `/te`) + // 3. `/*` is the new part we are about to insert (child to `/te`) n := newNode( currentNode.kind, currentNode.prefix[lcpLen:], @@ -235,6 +240,7 @@ func (r *Router) insert(method, path string, t kind, rm routeMethod) { currentNode.paramsCount, currentNode.paramChild, currentNode.anyChild, + currentNode.notFoundHandler, ) // Update parent path for all children to new node for _, child := range currentNode.staticChildren { @@ -259,6 +265,7 @@ func (r *Router) insert(method, path string, t kind, rm routeMethod) { currentNode.anyChild = nil currentNode.isLeaf = false currentNode.isHandler = false + currentNode.notFoundHandler = nil // Only Static children could reach here currentNode.addStaticChild(n) @@ -273,7 +280,7 @@ func (r *Router) insert(method, path string, t kind, rm routeMethod) { } } else { // Create child node - n = newNode(t, search[lcpLen:], currentNode, nil, "", new(routeMethods), 0, nil, nil) + n = newNode(t, search[lcpLen:], currentNode, nil, "", new(routeMethods), 0, nil, nil, nil) if rm.handler != nil { n.addMethod(method, &rm) n.paramsCount = len(rm.pnames) @@ -292,7 +299,7 @@ func (r *Router) insert(method, path string, t kind, rm routeMethod) { continue } // Create child node - n := newNode(t, search, currentNode, nil, rm.ppath, new(routeMethods), 0, nil, nil) + n := newNode(t, search, currentNode, nil, rm.ppath, new(routeMethods), 0, nil, nil, nil) if rm.handler != nil { n.addMethod(method, &rm) n.paramsCount = len(rm.pnames) @@ -319,20 +326,32 @@ func (r *Router) insert(method, path string, t kind, rm routeMethod) { } } -func newNode(t kind, pre string, p *node, sc children, originalPath string, mh *routeMethods, paramsCount int, paramChildren, anyChildren *node) *node { +func newNode( + t kind, + pre string, + p *node, + sc children, + originalPath string, + methods *routeMethods, + paramsCount int, + paramChildren, + anyChildren *node, + notFoundHandler *routeMethod, +) *node { return &node{ - kind: t, - label: pre[0], - prefix: pre, - parent: p, - staticChildren: sc, - originalPath: originalPath, - methods: mh, - paramsCount: paramsCount, - paramChild: paramChildren, - anyChild: anyChildren, - isLeaf: sc == nil && paramChildren == nil && anyChildren == nil, - isHandler: mh.isHandler(), + kind: t, + label: pre[0], + prefix: pre, + parent: p, + staticChildren: sc, + originalPath: originalPath, + methods: methods, + paramsCount: paramsCount, + paramChild: paramChildren, + anyChild: anyChildren, + isLeaf: sc == nil && paramChildren == nil && anyChildren == nil, + isHandler: methods.isHandler(), + notFoundHandler: notFoundHandler, } } diff --git a/router_test.go b/router_test.go index 34f325d33..1b0c409b6 100644 --- a/router_test.go +++ b/router_test.go @@ -1286,6 +1286,43 @@ func TestNotFoundRouteStaticKind(t *testing.T) { } } +func TestRouter_notFoundRouteWithNodeSplitting(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/test*", handlerHelper("ID", 0)) + r.Add(RouteNotFound, "/*", handlerHelper("ID", 1)) + r.Add(RouteNotFound, "/test", handlerHelper("ID", 2)) + + // Tree before: + // 1 `/` + // 1.1 `*` (any) ID=1 + // 1.2 `test` (static) ID=2 + // 1.2.1 `*` (any) ID=0 + + // node with path `test` has routeNotFound handler from previous Add call. Now when we insert `/te/st*` into router tree + // This means that node `test` is split into `te` and `st` nodes and new node `/st*` is inserted. + // On that split `/test` routeNotFound handler must not be lost. + r.Add(http.MethodGet, "/te/st*", handlerHelper("ID", 3)) + // Tree after: + // 1 `/` + // 1.1 `*` (any) ID=1 + // 1.2 `te` (static) + // 1.2.1 `st` (static) ID=2 + // 1.2.1.1 `*` (any) ID=0 + // 1.2.2 `/st` (static) + // 1.2.2.1 `*` (any) ID=3 + + c := e.NewContext(nil, nil).(*context) + r.Find(http.MethodPut, "/test", c) + + c.handler(c) + + testValue, _ := c.Get("ID").(int) + assert.Equal(t, 2, testValue) + assert.Equal(t, "/test", c.Path()) +} + // Issue #1509 func TestRouterParamStaticConflict(t *testing.T) { e := New() From a9879ffa6b6fe73e43a0e062884f83eb959e6c1a Mon Sep 17 00:00:00 2001 From: Daniel Price Date: Thu, 21 Jul 2022 17:40:44 +0000 Subject: [PATCH 230/446] Middlewares should use errors.As() instead of type assertion on HTTPError - Helps consumers who want to wrap HTTPError, and other use cases --- middleware/request_logger.go | 6 ++++-- middleware/static.go | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/middleware/request_logger.go b/middleware/request_logger.go index 1b3e3eaad..7a4d9822e 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -2,9 +2,10 @@ package middleware import ( "errors" - "github.com/labstack/echo/v4" "net/http" "time" + + "github.com/labstack/echo/v4" ) // Example for `fmt.Printf` @@ -264,7 +265,8 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.LogStatus { v.Status = res.Status if err != nil { - if httpErr, ok := err.(*echo.HTTPError); ok { + var httpErr *echo.HTTPError + if errors.As(err, &httpErr) { v.Status = httpErr.Code } } diff --git a/middleware/static.go b/middleware/static.go index 0106f7ce2..27ccf4117 100644 --- a/middleware/static.go +++ b/middleware/static.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "fmt" "html/template" "net/http" @@ -196,8 +197,8 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { return err } - he, ok := err.(*echo.HTTPError) - if !(ok && config.HTML5 && he.Code == http.StatusNotFound) { + var he *echo.HTTPError + if !(errors.As(err, &he) && config.HTML5 && he.Code == http.StatusNotFound) { return err } From 61422dd7de9b0359708ff56b67099b91b5954c31 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 6 Aug 2022 23:29:21 +0300 Subject: [PATCH 231/446] Update CI-flow (Go 1.19 +deps) --- .github/workflows/echo.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 69535f09c..db60f7f84 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -28,7 +28,7 @@ jobs: os: [ubuntu-latest, macos-latest, windows-latest] # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy # Echo tests with last four major releases - go: [1.16, 1.17, 1.18] + go: [1.16, 1.17, 1.18, 1.19] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: @@ -51,8 +51,8 @@ jobs: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... - name: Upload coverage to Codecov - if: success() && matrix.go == 1.18 && matrix.os == 'ubuntu-latest' - uses: codecov/codecov-action@v1 + if: success() && matrix.go == 1.19 && matrix.os == 'ubuntu-latest' + uses: codecov/codecov-action@v3 with: token: fail_ci_if_error: false @@ -61,7 +61,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - go: [1.18] + go: [1.19] name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: From a327884b682cc1458580ac583e04dc610429b6fc Mon Sep 17 00:00:00 2001 From: go-woo Date: Mon, 8 Aug 2022 15:18:59 +0800 Subject: [PATCH 232/446] add:README.md-Third-party middlewares-github.com/go-woo/protoc-gen-echo --- README.md | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 8b2321f05..17e6ed934 100644 --- a/README.md +++ b/README.md @@ -93,15 +93,16 @@ func hello(c echo.Context) error { # Third-party middlewares -| Repository | Description | -|------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | (by Echo team) [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | -| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator | -| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | -| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | -| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | -| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | -| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. +| Repository | Description | +|------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | (by Echo team) [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | +| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator | +| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | +| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | +| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | +| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | +| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. | +| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code | Please send a PR to add your own library here. From cba12a570e8caa1fafabc2d41afc97bd7a83d758 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 6 Aug 2022 23:25:43 +0300 Subject: [PATCH 233/446] Allow arbitrary HTTP method types to be added as routes --- echo.go | 5 +++- router.go | 19 ++++++++++-- router_test.go | 80 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 3 deletions(-) diff --git a/echo.go b/echo.go index 5b10d586e..5738578df 100644 --- a/echo.go +++ b/echo.go @@ -492,8 +492,11 @@ func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *R return e.Add(RouteNotFound, path, h, m...) } -// Any registers a new route for all HTTP methods and path with matching handler +// Any registers a new route for all HTTP methods (supported by Echo) and path with matching handler // in the router with optional route-level middleware. +// +// Note: this method only adds specific set of supported HTTP methods as handler and is not true +// "catch-any-arbitrary-method" way of matching requests. func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { routes := make([]*Route, len(methods)) for i, m := range methods { diff --git a/router.go b/router.go index 90102a294..23c5bd3ba 100644 --- a/router.go +++ b/router.go @@ -51,6 +51,7 @@ type ( put *routeMethod trace *routeMethod report *routeMethod + anyOther map[string]*routeMethod allowHeader string } ) @@ -75,7 +76,8 @@ func (m *routeMethods) isHandler() bool { m.propfind != nil || m.put != nil || m.trace != nil || - m.report != nil + m.report != nil || + len(m.anyOther) != 0 // RouteNotFound/404 is not considered as a handler } @@ -121,6 +123,10 @@ func (m *routeMethods) updateAllowHeader() { if m.report != nil { buf.WriteString(", REPORT") } + for method := range m.anyOther { // for simplicity, we use map and therefore order is not deterministic here + buf.WriteString(", ") + buf.WriteString(method) + } m.allowHeader = buf.String() } @@ -408,6 +414,15 @@ func (n *node) addMethod(method string, h *routeMethod) { case RouteNotFound: n.notFoundHandler = h return // RouteNotFound/404 is not considered as a handler so no further logic needs to be executed + default: + if n.methods.anyOther == nil { + n.methods.anyOther = make(map[string]*routeMethod) + } + if h.handler == nil { + delete(n.methods.anyOther, method) + } else { + n.methods.anyOther[method] = h + } } n.methods.updateAllowHeader() @@ -439,7 +454,7 @@ func (n *node) findMethod(method string) *routeMethod { case REPORT: return n.methods.report default: // RouteNotFound/404 is not considered as a handler - return nil + return n.methods.anyOther[method] } } diff --git a/router_test.go b/router_test.go index 1b0c409b6..a95421011 100644 --- a/router_test.go +++ b/router_test.go @@ -716,6 +716,67 @@ func TestRouterParam(t *testing.T) { } } +func TestRouter_addAndMatchAllSupportedMethods(t *testing.T) { + var testCases = []struct { + name string + givenNoAddRoute bool + whenMethod string + expectPath string + expectError string + }{ + {name: "ok, CONNECT", whenMethod: http.MethodConnect}, + {name: "ok, DELETE", whenMethod: http.MethodDelete}, + {name: "ok, GET", whenMethod: http.MethodGet}, + {name: "ok, HEAD", whenMethod: http.MethodHead}, + {name: "ok, OPTIONS", whenMethod: http.MethodOptions}, + {name: "ok, PATCH", whenMethod: http.MethodPatch}, + {name: "ok, POST", whenMethod: http.MethodPost}, + {name: "ok, PROPFIND", whenMethod: PROPFIND}, + {name: "ok, PUT", whenMethod: http.MethodPut}, + {name: "ok, TRACE", whenMethod: http.MethodTrace}, + {name: "ok, REPORT", whenMethod: REPORT}, + {name: "ok, NON_TRADITIONAL_METHOD", whenMethod: "NON_TRADITIONAL_METHOD"}, + { + name: "ok, NOT_EXISTING_METHOD", + whenMethod: "NOT_EXISTING_METHOD", + givenNoAddRoute: true, + expectPath: "/*", + expectError: "code=405, message=Method Not Allowed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + e.GET("/*", handlerFunc) + + if !tc.givenNoAddRoute { + e.Add(tc.whenMethod, "/my/*", handlerFunc) + } + + req := httptest.NewRequest(tc.whenMethod, "/my/some-url", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec).(*context) + + e.router.Find(tc.whenMethod, "/my/some-url", c) + err := c.handler(c) + + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + expectPath := "/my/*" + if tc.expectPath != "" { + expectPath = tc.expectPath + } + assert.Equal(t, expectPath, c.Path()) + }) + } +} + func TestMethodNotAllowedAndNotFound(t *testing.T) { e := New() r := e.router @@ -2634,6 +2695,25 @@ func TestRouterHandleMethodOptions(t *testing.T) { } } +func TestRouterAllowHeaderForAnyOtherMethodType(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/users", handlerFunc) + r.Add("COPY", "/users", handlerFunc) + r.Add("LOCK", "/users", handlerFunc) + + req := httptest.NewRequest("TEST", "/users", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec).(*context) + + r.Find("TEST", "/users", c) + err := c.handler(c) + + assert.EqualError(t, err, "code=405, message=Method Not Allowed") + assert.ElementsMatch(t, []string{"COPY", "GET", "LOCK", "OPTIONS"}, strings.Split(c.Response().Header().Get(HeaderAllow), ", ")) +} + func benchmarkRouterRoutes(b *testing.B, routes []*Route, routesToFind []*Route) { e := New() r := e.router From d48197db7af19becf2363496493ed0e2a8d1caea Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Wed, 10 Aug 2022 22:49:43 +0300 Subject: [PATCH 234/446] Changelog for 4.8.0 --- CHANGELOG.md | 31 +++++++++++++++++++++++++++++++ echo.go | 2 +- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ba75d71f6..2fcb2ff7f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,36 @@ # Changelog +## v4.8.0 - 2022-08-10 + +**Most notable things** + +You can now add any arbitrary HTTP method type as a route [#2237](https://github.com/labstack/echo/pull/2237) +```go +e.Add("COPY", "/*", func(c echo.Context) error + return c.String(http.StatusOK, "OK COPY") +}) +``` + +You can add custom 404 handler for specific paths [#2217](https://github.com/labstack/echo/pull/2217) +```go +e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) }) + +g := e.Group("/images") +g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) }) +``` + +**Enhancements** + +* Add new value binding methods (UnixTimeMilli,TextUnmarshaler,JSONUnmarshaler) to Valuebinder [#2127](https://github.com/labstack/echo/pull/2127) +* Refactor: body_limit middleware unit test [#2145](https://github.com/labstack/echo/pull/2145) +* Refactor: Timeout mw: rework how test waits for timeout. [#2187](https://github.com/labstack/echo/pull/2187) +* BasicAuth middleware returns 500 InternalServerError on invalid base64 strings but should return 400 [#2191](https://github.com/labstack/echo/pull/2191) +* Refactor: duplicated findStaticChild process at findChildWithLabel [#2176](https://github.com/labstack/echo/pull/2176) +* Allow different param names in different methods with same path scheme [#2209](https://github.com/labstack/echo/pull/2209) +* Add support for registering handlers for different 404 routes [#2217](https://github.com/labstack/echo/pull/2217) +* Middlewares should use errors.As() instead of type assertion on HTTPError [#2227](https://github.com/labstack/echo/pull/2227) +* Allow arbitrary HTTP method types to be added as routes [#2237](https://github.com/labstack/echo/pull/2237) + ## v4.7.2 - 2022-03-16 **Fixes** diff --git a/echo.go b/echo.go index 5738578df..291c4047e 100644 --- a/echo.go +++ b/echo.go @@ -248,7 +248,7 @@ const ( const ( // Version of Echo - Version = "4.7.2" + Version = "4.8.0" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From fb57d96a6dc0e35c7016153ac6d028ced1ca0e41 Mon Sep 17 00:00:00 2001 From: Kamandlou Date: Fri, 19 Aug 2022 21:36:40 +0430 Subject: [PATCH 235/446] replace GET constance with stdlib constance --- bind_test.go | 6 +++--- context_test.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bind_test.go b/bind_test.go index 4ed8dbb50..a7801da92 100644 --- a/bind_test.go +++ b/bind_test.go @@ -330,7 +330,7 @@ func TestBindUnmarshalParam(t *testing.T) { func TestBindUnmarshalText(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) + req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { @@ -406,7 +406,7 @@ func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) { func TestBindUnmarshalTextPtr(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil) + req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { @@ -462,7 +462,7 @@ func TestBindbindData(t *testing.T) { func TestBindParam(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) c.SetPath("/users/:id/:name") diff --git a/context_test.go b/context_test.go index a8b9a9946..377a740e4 100644 --- a/context_test.go +++ b/context_test.go @@ -728,7 +728,7 @@ func TestContext_QueryString(t *testing.T) { queryString := "query=string&var=val" - req := httptest.NewRequest(GET, "/?"+queryString, nil) + req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil) c := e.NewContext(req, nil) testify.Equal(t, queryString, c.QueryString()) @@ -739,7 +739,7 @@ func TestContext_Request(t *testing.T) { testify.Nil(t, c.Request()) - req := httptest.NewRequest(GET, "/path", nil) + req := httptest.NewRequest(http.MethodGet, "/path", nil) c.SetRequest(req) testify.Equal(t, req, c.Request()) From 534bbb81e3a13f04c7c1513d23659ae599e421c0 Mon Sep 17 00:00:00 2001 From: Kamandlou Date: Fri, 19 Aug 2022 21:38:38 +0430 Subject: [PATCH 236/446] replace POST constance with stdlib constance --- bind_test.go | 2 +- context_test.go | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bind_test.go b/bind_test.go index a7801da92..822008153 100644 --- a/bind_test.go +++ b/bind_test.go @@ -492,7 +492,7 @@ func TestBindParam(t *testing.T) { // Bind something with param and post data payload body := bytes.NewBufferString(`{ "name": "Jon Snow" }`) e2 := New() - req2 := httptest.NewRequest(POST, "/", body) + req2 := httptest.NewRequest(http.MethodPost, "/", body) req2.Header.Set(HeaderContentType, MIMEApplicationJSON) rec2 := httptest.NewRecorder() diff --git a/context_test.go b/context_test.go index 377a740e4..b25e11c60 100644 --- a/context_test.go +++ b/context_test.go @@ -32,7 +32,7 @@ var testUser = user{1, "Jon Snow"} func BenchmarkAllocJSONP(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) @@ -46,7 +46,7 @@ func BenchmarkAllocJSONP(b *testing.B) { func BenchmarkAllocJSON(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) @@ -60,7 +60,7 @@ func BenchmarkAllocJSON(b *testing.B) { func BenchmarkAllocXML(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) @@ -849,7 +849,7 @@ func TestContext_IsWebSocket(t *testing.T) { func TestContext_Bind(t *testing.T) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) c := e.NewContext(req, nil) u := new(user) From d77e8c09b21bb23fa8d2dc0250c998d0da0815fa Mon Sep 17 00:00:00 2001 From: Mojtaba Arezoumand Date: Thu, 1 Sep 2022 12:21:55 +0430 Subject: [PATCH 237/446] Added ErrorHandler and ErrorHandlerWithContext in CSRF middleware (#2257) * feat: add error handler to csrf middleware Co-authored-by: Mojtaba Arezoomand --- middleware/csrf.go | 18 ++++++++++++++++-- middleware/csrf_test.go | 22 ++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/middleware/csrf.go b/middleware/csrf.go index 61299f5ca..ea90fdba7 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -61,7 +61,13 @@ type ( // Indicates SameSite mode of the CSRF cookie. // Optional. Default value SameSiteDefaultMode. CookieSameSite http.SameSite `yaml:"cookie_same_site"` + + // ErrorHandler defines a function which is executed for returning custom errors. + ErrorHandler CSRFErrorHandler } + + // CSRFErrorHandler is a function which is executed for creating custom errors. + CSRFErrorHandler func(err error, c echo.Context) error ) // ErrCSRFInvalid is returned when CSRF check fails @@ -154,8 +160,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { lastTokenErr = ErrCSRFInvalid } } + var finalErr error if lastTokenErr != nil { - return lastTokenErr + finalErr = lastTokenErr } else if lastExtractorErr != nil { // ugly part to preserve backwards compatible errors. someone could rely on them if lastExtractorErr == errQueryExtractorValueMissing { @@ -167,7 +174,14 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { } else { lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error()) } - return lastExtractorErr + finalErr = lastExtractorErr + } + + if finalErr != nil { + if config.ErrorHandler != nil { + return config.ErrorHandler(finalErr, c) + } + return finalErr } } diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 9aff82a98..6bccdbe4d 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -358,3 +358,25 @@ func TestCSRFConfig_skipper(t *testing.T) { }) } } + +func TestCSRFErrorHandling(t *testing.T) { + cfg := CSRFConfig{ + ErrorHandler: func(err error, c echo.Context) error { + return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed") + }, + } + + e := echo.New() + e.POST("/", func(c echo.Context) error { + return c.String(http.StatusNotImplemented, "should not end up here") + }) + + e.Use(CSRFWithConfig(cfg)) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.Equal(t, "{\"message\":\"error_handler_executed\"}\n", res.Body.String()) +} From 0ac4d74402391912ff6da733bb09fd4c3980b4e1 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 4 Sep 2022 22:44:32 +0300 Subject: [PATCH 238/446] Fix #2259 open redirect vulnerability in echo.StaticDirectoryHandler (used by e.Static, e.StaticFs etc) remove pre Go1.16 and after differences --- binder_go1.15_test.go | 265 ------------------ binder_test.go | 222 ++++++++++++++- context_fs.go | 42 ++- context_fs_go1.16.go | 52 ---- ...xt_fs_go1.16_test.go => context_fs_test.go | 3 - echo_fs.go | 179 +++++++++--- echo_fs_go1.16.go | 169 ----------- echo_fs_go1.16_test.go => echo_fs_test.go | 12 +- group_fs.go | 31 +- group_fs_go1.16.go | 33 --- group_fs_go1.16_test.go => group_fs_test.go | 3 - middleware/static_1_16_test.go | 106 ------- middleware/static_test.go | 103 +++++++ 13 files changed, 534 insertions(+), 686 deletions(-) delete mode 100644 binder_go1.15_test.go delete mode 100644 context_fs_go1.16.go rename context_fs_go1.16_test.go => context_fs_test.go (98%) delete mode 100644 echo_fs_go1.16.go rename echo_fs_go1.16_test.go => echo_fs_test.go (95%) delete mode 100644 group_fs_go1.16.go rename group_fs_go1.16_test.go => group_fs_test.go (98%) delete mode 100644 middleware/static_1_16_test.go diff --git a/binder_go1.15_test.go b/binder_go1.15_test.go deleted file mode 100644 index 018628c3a..000000000 --- a/binder_go1.15_test.go +++ /dev/null @@ -1,265 +0,0 @@ -// +build go1.15 - -package echo - -/** - Since version 1.15 time.Time and time.Duration error message pattern has changed (values are wrapped now in \"\") - So pre 1.15 these tests fail with similar error: - - expected: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param" - actual : "code=400, message=failed to bind field value to Duration, internal=time: invalid duration nope, field=param" -*/ - -import ( - "errors" - "github.com/stretchr/testify/assert" - "io" - "net/http" - "net/http/httptest" - "testing" - "time" -) - -func createTestContext15(URL string, body io.Reader, pathParams map[string]string) Context { - e := New() - req := httptest.NewRequest(http.MethodGet, URL, body) - if body != nil { - req.Header.Set(HeaderContentType, MIMEApplicationJSON) - } - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - if len(pathParams) > 0 { - names := make([]string, 0) - values := make([]string, 0) - for name, value := range pathParams { - names = append(names, name) - values = append(values, value) - } - c.SetParamNames(names...) - c.SetParamValues(values...) - } - - return c -} - -func TestValueBinder_TimeError(t *testing.T) { - var testCases = []struct { - name string - givenFailFast bool - givenBindErrors []error - whenURL string - whenMust bool - whenLayout string - expectValue time.Time - expectError string - }{ - { - name: "nok, conversion fails, value is not changed", - whenURL: "/search?param=nope¶m=100", - expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", - }, - { - name: "nok (must), conversion fails, value is not changed", - whenMust: true, - whenURL: "/search?param=nope¶m=100", - expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := createTestContext15(tc.whenURL, nil, nil) - b := QueryParamsBinder(c).FailFast(tc.givenFailFast) - if tc.givenFailFast { - b.errors = []error{errors.New("previous error")} - } - - dest := time.Time{} - var err error - if tc.whenMust { - err = b.MustTime("param", &dest, tc.whenLayout).BindError() - } else { - err = b.Time("param", &dest, tc.whenLayout).BindError() - } - - assert.Equal(t, tc.expectValue, dest) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValueBinder_TimesError(t *testing.T) { - var testCases = []struct { - name string - givenFailFast bool - givenBindErrors []error - whenURL string - whenMust bool - whenLayout string - expectValue []time.Time - expectError string - }{ - { - name: "nok, fail fast without binding value", - givenFailFast: true, - whenURL: "/search?param=1¶m=100", - expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param", - }, - { - name: "nok, conversion fails, value is not changed", - whenURL: "/search?param=nope¶m=100", - expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", - }, - { - name: "nok (must), conversion fails, value is not changed", - whenMust: true, - whenURL: "/search?param=nope¶m=100", - expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := createTestContext15(tc.whenURL, nil, nil) - b := QueryParamsBinder(c).FailFast(tc.givenFailFast) - b.errors = tc.givenBindErrors - - layout := time.RFC3339 - if tc.whenLayout != "" { - layout = tc.whenLayout - } - - var dest []time.Time - var err error - if tc.whenMust { - err = b.MustTimes("param", &dest, layout).BindError() - } else { - err = b.Times("param", &dest, layout).BindError() - } - - assert.Equal(t, tc.expectValue, dest) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValueBinder_DurationError(t *testing.T) { - var testCases = []struct { - name string - givenFailFast bool - givenBindErrors []error - whenURL string - whenMust bool - expectValue time.Duration - expectError string - }{ - { - name: "nok, conversion fails, value is not changed", - whenURL: "/search?param=nope¶m=100", - expectValue: 0, - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", - }, - { - name: "nok (must), conversion fails, value is not changed", - whenMust: true, - whenURL: "/search?param=nope¶m=100", - expectValue: 0, - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := createTestContext15(tc.whenURL, nil, nil) - b := QueryParamsBinder(c).FailFast(tc.givenFailFast) - if tc.givenFailFast { - b.errors = []error{errors.New("previous error")} - } - - var dest time.Duration - var err error - if tc.whenMust { - err = b.MustDuration("param", &dest).BindError() - } else { - err = b.Duration("param", &dest).BindError() - } - - assert.Equal(t, tc.expectValue, dest) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValueBinder_DurationsError(t *testing.T) { - var testCases = []struct { - name string - givenFailFast bool - givenBindErrors []error - whenURL string - whenMust bool - expectValue []time.Duration - expectError string - }{ - { - name: "nok, fail fast without binding value", - givenFailFast: true, - whenURL: "/search?param=1¶m=100", - expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, internal=time: missing unit in duration \"1\", field=param", - }, - { - name: "nok, conversion fails, value is not changed", - whenURL: "/search?param=nope¶m=100", - expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", - }, - { - name: "nok (must), conversion fails, value is not changed", - whenMust: true, - whenURL: "/search?param=nope¶m=100", - expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - c := createTestContext15(tc.whenURL, nil, nil) - b := QueryParamsBinder(c).FailFast(tc.givenFailFast) - b.errors = tc.givenBindErrors - - var dest []time.Duration - var err error - if tc.whenMust { - err = b.MustDurations("param", &dest).BindError() - } else { - err = b.Durations("param", &dest).BindError() - } - - assert.Equal(t, tc.expectValue, dest) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - }) - } -} diff --git a/binder_test.go b/binder_test.go index 910bbfc50..0b27cae64 100644 --- a/binder_test.go +++ b/binder_test.go @@ -1,4 +1,3 @@ -// run tests as external package to get real feel for API package echo import ( @@ -3029,3 +3028,224 @@ func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) { } } } + +func TestValueBinder_TimeError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + whenLayout string + expectValue time.Time + expectError string + }{ + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustTime("param", &dest, tc.whenLayout).BindError() + } else { + err = b.Time("param", &dest, tc.whenLayout).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_TimesError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + whenLayout string + expectValue []time.Time + expectError string + }{ + { + name: "nok, fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + layout := time.RFC3339 + if tc.whenLayout != "" { + layout = tc.whenLayout + } + + var dest []time.Time + var err error + if tc.whenMust { + err = b.MustTimes("param", &dest, layout).BindError() + } else { + err = b.Times("param", &dest, layout).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_DurationError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue time.Duration + expectError string + }{ + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: 0, + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: 0, + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest time.Duration + var err error + if tc.whenMust { + err = b.MustDuration("param", &dest).BindError() + } else { + err = b.Duration("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_DurationsError(t *testing.T) { + var testCases = []struct { + name string + givenFailFast bool + givenBindErrors []error + whenURL string + whenMust bool + expectValue []time.Duration + expectError string + }{ + { + name: "nok, fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, internal=time: missing unit in duration \"1\", field=param", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + var dest []time.Duration + var err error + if tc.whenMust { + err = b.MustDurations("param", &dest).BindError() + } else { + err = b.Durations("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/context_fs.go b/context_fs.go index 11ee84bcd..1038f892e 100644 --- a/context_fs.go +++ b/context_fs.go @@ -1,33 +1,49 @@ -//go:build !go1.16 -// +build !go1.16 - package echo import ( + "errors" + "io" + "io/fs" "net/http" - "os" "path/filepath" ) -func (c *context) File(file string) (err error) { - f, err := os.Open(file) +func (c *context) File(file string) error { + return fsFile(c, file, c.echo.Filesystem) +} + +// FileFS serves file from given file system. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (c *context) FileFS(file string, filesystem fs.FS) error { + return fsFile(c, file, filesystem) +} + +func fsFile(c Context, file string, filesystem fs.FS) error { + f, err := filesystem.Open(file) if err != nil { - return NotFoundHandler(c) + return ErrNotFound } defer f.Close() fi, _ := f.Stat() if fi.IsDir() { - file = filepath.Join(file, indexPage) - f, err = os.Open(file) + file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect. + f, err = filesystem.Open(file) if err != nil { - return NotFoundHandler(c) + return ErrNotFound } defer f.Close() if fi, err = f.Stat(); err != nil { - return + return err } } - http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f) - return + ff, ok := f.(io.ReadSeeker) + if !ok { + return errors.New("file does not implement io.ReadSeeker") + } + http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) + return nil } diff --git a/context_fs_go1.16.go b/context_fs_go1.16.go deleted file mode 100644 index c1c724afd..000000000 --- a/context_fs_go1.16.go +++ /dev/null @@ -1,52 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "errors" - "io" - "io/fs" - "net/http" - "path/filepath" -) - -func (c *context) File(file string) error { - return fsFile(c, file, c.echo.Filesystem) -} - -// FileFS serves file from given file system. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (c *context) FileFS(file string, filesystem fs.FS) error { - return fsFile(c, file, filesystem) -} - -func fsFile(c Context, file string, filesystem fs.FS) error { - f, err := filesystem.Open(file) - if err != nil { - return ErrNotFound - } - defer f.Close() - - fi, _ := f.Stat() - if fi.IsDir() { - file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect. - f, err = filesystem.Open(file) - if err != nil { - return ErrNotFound - } - defer f.Close() - if fi, err = f.Stat(); err != nil { - return err - } - } - ff, ok := f.(io.ReadSeeker) - if !ok { - return errors.New("file does not implement io.ReadSeeker") - } - http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) - return nil -} diff --git a/context_fs_go1.16_test.go b/context_fs_test.go similarity index 98% rename from context_fs_go1.16_test.go rename to context_fs_test.go index 027d1c483..51346c956 100644 --- a/context_fs_go1.16_test.go +++ b/context_fs_test.go @@ -1,6 +1,3 @@ -//go:build go1.16 -// +build go1.16 - package echo import ( diff --git a/echo_fs.go b/echo_fs.go index c3790545a..b8526da9e 100644 --- a/echo_fs.go +++ b/echo_fs.go @@ -1,62 +1,175 @@ -//go:build !go1.16 -// +build !go1.16 - package echo import ( + "fmt" + "io/fs" "net/http" "net/url" "os" "path/filepath" + "runtime" + "strings" ) type filesystem struct { + // Filesystem is file system used by Static and File handlers to access files. + // Defaults to os.DirFS(".") + // + // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary + // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths + // including `assets/images` as their prefix. + Filesystem fs.FS } func createFilesystem() filesystem { - return filesystem{} + return filesystem{ + Filesystem: newDefaultFS(), + } } -// Static registers a new route with path prefix to serve static files from the -// provided root directory. -func (e *Echo) Static(prefix, root string) *Route { - if root == "" { - root = "." // For security we want to restrict to CWD. - } - return e.static(prefix, root, e.GET) +// Static registers a new route with path prefix to serve static files from the provided root directory. +func (e *Echo) Static(pathPrefix, fsRoot string) *Route { + subFs := MustSubFS(e.Filesystem, fsRoot) + return e.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(subFs, false), + ) } -func (common) static(prefix, root string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route) *Route { - h := func(c Context) error { - p, err := url.PathUnescape(c.Param("*")) - if err != nil { - return err +// StaticFS registers a new route with path prefix to serve static files from the provided file system. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route { + return e.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(filesystem, false), + ) +} + +// StaticDirectoryHandler creates handler function to serve files from provided file system +// When disablePathUnescaping is set then file name from path is not unescaped and is served as is. +func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc { + return func(c Context) error { + p := c.Param("*") + if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice + tmpPath, err := url.PathUnescape(p) + if err != nil { + return fmt.Errorf("failed to unescape path variable: %w", err) + } + p = tmpPath } - name := filepath.Join(root, filepath.Clean("/"+p)) // "/"+ for security - fi, err := os.Stat(name) + // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid + name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/"))) + fi, err := fs.Stat(fileSystem, name) if err != nil { - // The access path does not exist - return NotFoundHandler(c) + return ErrNotFound } // If the request is for a directory and does not end with "/" p = c.Request().URL.Path // path must not be empty. - if fi.IsDir() && p[len(p)-1] != '/' { + if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { // Redirect to ends with "/" - return c.Redirect(http.StatusMovedPermanently, p+"/") + return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/")) } - return c.File(name) - } - // Handle added routes based on trailing slash: - // /prefix => exact route "/prefix" + any route "/prefix/*" - // /prefix/ => only any route "/prefix/*" - if prefix != "" { - if prefix[len(prefix)-1] == '/' { - // Only add any route for intentional trailing slash - return get(prefix+"*", h) + return fsFile(c, name, fileSystem) + } +} + +// FileFS registers a new route with path to serve file from the provided file system. +func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { + return e.GET(path, StaticFileHandler(file, filesystem), m...) +} + +// StaticFileHandler creates handler function to serve file from provided file system +func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { + return func(c Context) error { + return fsFile(c, file, filesystem) + } +} + +// defaultFS exists to preserve pre v4.7.0 behaviour where files were open by `os.Open`. +// v4.7 introduced `echo.Filesystem` field which is Go1.16+ `fs.Fs` interface. +// Difference between `os.Open` and `fs.Open` is that FS does not allow opening path that start with `.`, `..` or `/` +// etc. For example previously you could have `../images` in your application but `fs := os.DirFS("./")` would not +// allow you to use `fs.Open("../images")` and this would break all old applications that rely on being able to +// traverse up from current executable run path. +// NB: private because you really should use fs.FS implementation instances +type defaultFS struct { + prefix string + fs fs.FS +} + +func newDefaultFS() *defaultFS { + dir, _ := os.Getwd() + return &defaultFS{ + prefix: dir, + fs: nil, + } +} + +func (fs defaultFS) Open(name string) (fs.File, error) { + if fs.fs == nil { + return os.Open(name) + } + return fs.fs.Open(name) +} + +func subFS(currentFs fs.FS, root string) (fs.FS, error) { + root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows + if dFS, ok := currentFs.(*defaultFS); ok { + // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. + // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we + // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs + if isRelativePath(root) { + root = filepath.Join(dFS.prefix, root) } - get(prefix, h) + return &defaultFS{ + prefix: root, + fs: os.DirFS(root), + }, nil + } + return fs.Sub(currentFs, root) +} + +func isRelativePath(path string) bool { + if path == "" { + return true + } + if path[0] == '/' { + return false + } + if runtime.GOOS == "windows" && strings.IndexByte(path, ':') != -1 { + // https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file?redirectedfrom=MSDN#file_and_directory_names + // https://docs.microsoft.com/en-us/dotnet/standard/io/file-path-formats + return false + } + return true +} + +// MustSubFS creates sub FS from current filesystem or panic on failure. +// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. +// +// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with +// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to +// create sub fs which uses necessary prefix for directory path. +func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { + subFs, err := subFS(currentFs, fsRoot) + if err != nil { + panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err)) + } + return subFs +} + +func sanitizeURI(uri string) string { + // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri + // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash + if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') { + uri = "/" + strings.TrimLeft(uri, `/\`) } - return get(prefix+"/*", h) + return uri } diff --git a/echo_fs_go1.16.go b/echo_fs_go1.16.go deleted file mode 100644 index eb17768ab..000000000 --- a/echo_fs_go1.16.go +++ /dev/null @@ -1,169 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "fmt" - "io/fs" - "net/http" - "net/url" - "os" - "path/filepath" - "runtime" - "strings" -) - -type filesystem struct { - // Filesystem is file system used by Static and File handlers to access files. - // Defaults to os.DirFS(".") - // - // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary - // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths - // including `assets/images` as their prefix. - Filesystem fs.FS -} - -func createFilesystem() filesystem { - return filesystem{ - Filesystem: newDefaultFS(), - } -} - -// Static registers a new route with path prefix to serve static files from the provided root directory. -func (e *Echo) Static(pathPrefix, fsRoot string) *Route { - subFs := MustSubFS(e.Filesystem, fsRoot) - return e.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(subFs, false), - ) -} - -// StaticFS registers a new route with path prefix to serve static files from the provided file system. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route { - return e.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(filesystem, false), - ) -} - -// StaticDirectoryHandler creates handler function to serve files from provided file system -// When disablePathUnescaping is set then file name from path is not unescaped and is served as is. -func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc { - return func(c Context) error { - p := c.Param("*") - if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice - tmpPath, err := url.PathUnescape(p) - if err != nil { - return fmt.Errorf("failed to unescape path variable: %w", err) - } - p = tmpPath - } - - // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid - name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/"))) - fi, err := fs.Stat(fileSystem, name) - if err != nil { - return ErrNotFound - } - - // If the request is for a directory and does not end with "/" - p = c.Request().URL.Path // path must not be empty. - if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { - // Redirect to ends with "/" - return c.Redirect(http.StatusMovedPermanently, p+"/") - } - return fsFile(c, name, fileSystem) - } -} - -// FileFS registers a new route with path to serve file from the provided file system. -func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { - return e.GET(path, StaticFileHandler(file, filesystem), m...) -} - -// StaticFileHandler creates handler function to serve file from provided file system -func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { - return func(c Context) error { - return fsFile(c, file, filesystem) - } -} - -// defaultFS exists to preserve pre v4.7.0 behaviour where files were open by `os.Open`. -// v4.7 introduced `echo.Filesystem` field which is Go1.16+ `fs.Fs` interface. -// Difference between `os.Open` and `fs.Open` is that FS does not allow opening path that start with `.`, `..` or `/` -// etc. For example previously you could have `../images` in your application but `fs := os.DirFS("./")` would not -// allow you to use `fs.Open("../images")` and this would break all old applications that rely on being able to -// traverse up from current executable run path. -// NB: private because you really should use fs.FS implementation instances -type defaultFS struct { - prefix string - fs fs.FS -} - -func newDefaultFS() *defaultFS { - dir, _ := os.Getwd() - return &defaultFS{ - prefix: dir, - fs: nil, - } -} - -func (fs defaultFS) Open(name string) (fs.File, error) { - if fs.fs == nil { - return os.Open(name) - } - return fs.fs.Open(name) -} - -func subFS(currentFs fs.FS, root string) (fs.FS, error) { - root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows - if dFS, ok := currentFs.(*defaultFS); ok { - // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. - // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we - // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs - if isRelativePath(root) { - root = filepath.Join(dFS.prefix, root) - } - return &defaultFS{ - prefix: root, - fs: os.DirFS(root), - }, nil - } - return fs.Sub(currentFs, root) -} - -func isRelativePath(path string) bool { - if path == "" { - return true - } - if path[0] == '/' { - return false - } - if runtime.GOOS == "windows" && strings.IndexByte(path, ':') != -1 { - // https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file?redirectedfrom=MSDN#file_and_directory_names - // https://docs.microsoft.com/en-us/dotnet/standard/io/file-path-formats - return false - } - return true -} - -// MustSubFS creates sub FS from current filesystem or panic on failure. -// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. -// -// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with -// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to -// create sub fs which uses necessary prefix for directory path. -func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { - subFs, err := subFS(currentFs, fsRoot) - if err != nil { - panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err)) - } - return subFs -} diff --git a/echo_fs_go1.16_test.go b/echo_fs_test.go similarity index 95% rename from echo_fs_go1.16_test.go rename to echo_fs_test.go index 07e516555..eb072a28d 100644 --- a/echo_fs_go1.16_test.go +++ b/echo_fs_test.go @@ -1,6 +1,3 @@ -//go:build go1.16 -// +build go1.16 - package echo import ( @@ -139,6 +136,15 @@ func TestEcho_StaticFS(t *testing.T) { expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", }, + { + name: "open redirect vulnerability", + givenPrefix: "/", + givenFs: os.DirFS("_fixture/"), + whenURL: "/open.redirect.hackercom%2f..", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/open.redirect.hackercom/../", // location starting with `//open` would be very bad + expectBodyStartsWith: "", + }, } for _, tc := range testCases { diff --git a/group_fs.go b/group_fs.go index 0a1ce4a94..aedc4c6a9 100644 --- a/group_fs.go +++ b/group_fs.go @@ -1,9 +1,30 @@ -//go:build !go1.16 -// +build !go1.16 - package echo +import ( + "io/fs" + "net/http" +) + // Static implements `Echo#Static()` for sub-routes within the Group. -func (g *Group) Static(prefix, root string) { - g.static(prefix, root, g.GET) +func (g *Group) Static(pathPrefix, fsRoot string) { + subFs := MustSubFS(g.echo.Filesystem, fsRoot) + g.StaticFS(pathPrefix, subFs) +} + +// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) { + g.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(filesystem, false), + ) +} + +// FileFS implements `Echo#FileFS()` for sub-routes within the Group. +func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { + return g.GET(path, StaticFileHandler(file, filesystem), m...) } diff --git a/group_fs_go1.16.go b/group_fs_go1.16.go deleted file mode 100644 index 2ba52b5e2..000000000 --- a/group_fs_go1.16.go +++ /dev/null @@ -1,33 +0,0 @@ -//go:build go1.16 -// +build go1.16 - -package echo - -import ( - "io/fs" - "net/http" -) - -// Static implements `Echo#Static()` for sub-routes within the Group. -func (g *Group) Static(pathPrefix, fsRoot string) { - subFs := MustSubFS(g.echo.Filesystem, fsRoot) - g.StaticFS(pathPrefix, subFs) -} - -// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) { - g.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(filesystem, false), - ) -} - -// FileFS implements `Echo#FileFS()` for sub-routes within the Group. -func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { - return g.GET(path, StaticFileHandler(file, filesystem), m...) -} diff --git a/group_fs_go1.16_test.go b/group_fs_test.go similarity index 98% rename from group_fs_go1.16_test.go rename to group_fs_test.go index d0caa33db..958d9efb1 100644 --- a/group_fs_go1.16_test.go +++ b/group_fs_test.go @@ -1,6 +1,3 @@ -//go:build go1.16 -// +build go1.16 - package echo import ( diff --git a/middleware/static_1_16_test.go b/middleware/static_1_16_test.go deleted file mode 100644 index 53e02f742..000000000 --- a/middleware/static_1_16_test.go +++ /dev/null @@ -1,106 +0,0 @@ -// +build go1.16 - -package middleware - -import ( - "io/fs" - "net/http" - "net/http/httptest" - "os" - "testing" - "testing/fstest" - - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" -) - -func TestStatic_CustomFS(t *testing.T) { - var testCases = []struct { - name string - filesystem fs.FS - root string - whenURL string - expectContains string - expectCode int - }{ - { - name: "ok, serve index with Echo message", - whenURL: "/", - filesystem: os.DirFS("../_fixture"), - expectCode: http.StatusOK, - expectContains: "Echo", - }, - - { - name: "ok, serve index with Echo message", - whenURL: "/_fixture/", - filesystem: os.DirFS(".."), - expectCode: http.StatusOK, - expectContains: "Echo", - }, - { - name: "ok, serve file from map fs", - whenURL: "/file.txt", - filesystem: fstest.MapFS{ - "file.txt": &fstest.MapFile{Data: []byte("file.txt is ok")}, - }, - expectCode: http.StatusOK, - expectContains: "file.txt is ok", - }, - { - name: "nok, missing file in map fs", - whenURL: "/file.txt", - expectCode: http.StatusNotFound, - filesystem: fstest.MapFS{ - "file2.txt": &fstest.MapFile{Data: []byte("file2.txt is ok")}, - }, - }, - { - name: "nok, file is not a subpath of root", - whenURL: `/../../secret.txt`, - root: "/nested/folder", - filesystem: fstest.MapFS{ - "secret.txt": &fstest.MapFile{Data: []byte("this is a secret")}, - }, - expectCode: http.StatusNotFound, - }, - { - name: "nok, backslash is forbidden", - whenURL: `/..\..\secret.txt`, - expectCode: http.StatusNotFound, - root: "/nested/folder", - filesystem: fstest.MapFS{ - "secret.txt": &fstest.MapFile{Data: []byte("this is a secret")}, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - - config := StaticConfig{ - Root: ".", - Filesystem: http.FS(tc.filesystem), - } - - if tc.root != "" { - config.Root = tc.root - } - - middlewareFunc := StaticWithConfig(config) - e.Use(middlewareFunc) - - req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectCode, rec.Code) - if tc.expectContains != "" { - responseBody := rec.Body.String() - assert.Contains(t, responseBody, tc.expectContains) - } - }) - } -} diff --git a/middleware/static_test.go b/middleware/static_test.go index af6641f66..f26d97a95 100644 --- a/middleware/static_test.go +++ b/middleware/static_test.go @@ -1,10 +1,13 @@ package middleware import ( + "io/fs" "net/http" "net/http/httptest" + "os" "strings" "testing" + "testing/fstest" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" @@ -207,6 +210,15 @@ func TestStatic_GroupWithStatic(t *testing.T) { expectHeaderLocation: "/group/folder/", expectBodyStartsWith: "", }, + { + name: "Directory redirect", + givenPrefix: "/", + givenRoot: "../_fixture", + whenURL: "/group/folder%2f..", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/group/folder/../", + expectBodyStartsWith: "", + }, { name: "Prefixed directory 404 (request URL without slash)", givenGroup: "_fixture", @@ -306,3 +318,94 @@ func TestStatic_GroupWithStatic(t *testing.T) { }) } } + +func TestStatic_CustomFS(t *testing.T) { + var testCases = []struct { + name string + filesystem fs.FS + root string + whenURL string + expectContains string + expectCode int + }{ + { + name: "ok, serve index with Echo message", + whenURL: "/", + filesystem: os.DirFS("../_fixture"), + expectCode: http.StatusOK, + expectContains: "Echo", + }, + + { + name: "ok, serve index with Echo message", + whenURL: "/_fixture/", + filesystem: os.DirFS(".."), + expectCode: http.StatusOK, + expectContains: "Echo", + }, + { + name: "ok, serve file from map fs", + whenURL: "/file.txt", + filesystem: fstest.MapFS{ + "file.txt": &fstest.MapFile{Data: []byte("file.txt is ok")}, + }, + expectCode: http.StatusOK, + expectContains: "file.txt is ok", + }, + { + name: "nok, missing file in map fs", + whenURL: "/file.txt", + expectCode: http.StatusNotFound, + filesystem: fstest.MapFS{ + "file2.txt": &fstest.MapFile{Data: []byte("file2.txt is ok")}, + }, + }, + { + name: "nok, file is not a subpath of root", + whenURL: `/../../secret.txt`, + root: "/nested/folder", + filesystem: fstest.MapFS{ + "secret.txt": &fstest.MapFile{Data: []byte("this is a secret")}, + }, + expectCode: http.StatusNotFound, + }, + { + name: "nok, backslash is forbidden", + whenURL: `/..\..\secret.txt`, + expectCode: http.StatusNotFound, + root: "/nested/folder", + filesystem: fstest.MapFS{ + "secret.txt": &fstest.MapFile{Data: []byte("this is a secret")}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + config := StaticConfig{ + Root: ".", + Filesystem: http.FS(tc.filesystem), + } + + if tc.root != "" { + config.Root = tc.root + } + + middlewareFunc := StaticWithConfig(config) + e.Use(middlewareFunc) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + if tc.expectContains != "" { + responseBody := rec.Body.String() + assert.Contains(t, responseBody, tc.expectContains) + } + }) + } +} From 16d3b65eb09664fe95dc6e4e7e2a082010d0eb68 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 4 Sep 2022 22:58:19 +0300 Subject: [PATCH 239/446] Changelog for 4.9.0 --- CHANGELOG.md | 12 ++++++++++++ echo.go | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2fcb2ff7f..e8f42200e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## v4.9.0 - 2022-09-04 + +**Security** + +* Fix open redirect vulnerability in handlers serving static directories (e.Static, e.StaticFs, echo.StaticDirectoryHandler) [#2260](https://github.com/labstack/echo/pull/2260) + +**Enhancements** + +* Allow configuring ErrorHandler in CSRF middleware [#2257](https://github.com/labstack/echo/pull/2257) +* Replace HTTP method constants in tests with stdlib constants [#2247](https://github.com/labstack/echo/pull/2247) + + ## v4.8.0 - 2022-08-10 **Most notable things** diff --git a/echo.go b/echo.go index 291c4047e..5ae8a1424 100644 --- a/echo.go +++ b/echo.go @@ -248,7 +248,7 @@ const ( const ( // Version of Echo - Version = "4.8.0" + Version = "4.9.0" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 50e7e569f0660b2b02ef457687b451072c106c40 Mon Sep 17 00:00:00 2001 From: Daniel Price Date: Mon, 12 Sep 2022 18:53:44 +0000 Subject: [PATCH 240/446] Improve CORS documentation * Provide links to further reading * Provide security warnings * Document undocumented wildcard feature * Update to go-1.19 style links --- middleware/cors.go | 88 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 68 insertions(+), 20 deletions(-) diff --git a/middleware/cors.go b/middleware/cors.go index 16259512a..25cf983a7 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -15,46 +15,85 @@ type ( // Skipper defines a function to skip middleware. Skipper Skipper - // AllowOrigin defines a list of origins that may access the resource. + // AllowOrigins determines the value of the Access-Control-Allow-Origin + // response header. This header defines a list of origins that may access the + // resource. The wildcard characters '*' and '?' are supported and are + // converted to regex fragments '.*' and '.' accordingly. + // + // Security: use extreme caution when handling the origin, and carefully + // validate any logic. Remember that attackers may register hostile domain names. + // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // // Optional. Default value []string{"*"}. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin AllowOrigins []string `yaml:"allow_origins"` // AllowOriginFunc is a custom function to validate the origin. It takes the // origin as an argument and returns true if allowed or false otherwise. If // an error is returned, it is returned by the handler. If this option is // set, AllowOrigins is ignored. + // + // Security: use extreme caution when handling the origin, and carefully + // validate any logic. Remember that attackers may register hostile domain names. + // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // // Optional. AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` - // AllowMethods defines a list methods allowed when accessing the resource. - // This is used in response to a preflight request. + // AllowMethods determines the value of the Access-Control-Allow-Methods + // response header. This header specified the list of methods allowed when + // accessing the resource. This is used in response to a preflight request. + // // Optional. Default value DefaultCORSConfig.AllowMethods. - // If `allowMethods` is left empty will fill for preflight request `Access-Control-Allow-Methods` header value + // If `allowMethods` is left empty, this middleware will fill for preflight + // request `Access-Control-Allow-Methods` header value // from `Allow` header that echo.Router set into context. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods AllowMethods []string `yaml:"allow_methods"` - // AllowHeaders defines a list of request headers that can be used when - // making the actual request. This is in response to a preflight request. + // AllowHeaders determines the value of the Access-Control-Allow-Headers + // response header. This header is used in response to a preflight request to + // indicate which HTTP headers can be used when making the actual request. + // // Optional. Default value []string{}. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers AllowHeaders []string `yaml:"allow_headers"` - // AllowCredentials indicates whether or not the response to the request - // can be exposed when the credentials flag is true. When used as part of - // a response to a preflight request, this indicates whether or not the - // actual request can be made using credentials. - // Optional. Default value false. + // AllowCredentials determines the value of the + // Access-Control-Allow-Credentials response header. This header indicates + // whether or not the response to the request can be exposed when the + // credentials mode (Request.credentials) is true. When used as part of a + // response to a preflight request, this indicates whether or not the actual + // request can be made using credentials. See also + // [MDN: Access-Control-Allow-Credentials]. + // + // Optional. Default value false, in which case the header is not set. + // // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. - // See http://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // See "Exploiting CORS misconfigurations for Bitcoins and bounties", + // https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials AllowCredentials bool `yaml:"allow_credentials"` - // ExposeHeaders defines a whitelist headers that clients are allowed to - // access. - // Optional. Default value []string{}. + // ExposeHeaders determines the value of Access-Control-Expose-Headers, which + // defines a list of headers that clients are allowed to access. + // + // Optional. Default value []string{}, in which case the header is not set. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header ExposeHeaders []string `yaml:"expose_headers"` - // MaxAge indicates how long (in seconds) the results of a preflight request - // can be cached. - // Optional. Default value 0. + // MaxAge determines the value of the Access-Control-Max-Age response header. + // This header indicates how long (in seconds) the results of a preflight + // request can be cached. + // + // Optional. Default value 0. The header is set only if MaxAge > 0. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age MaxAge int `yaml:"max_age"` } ) @@ -69,13 +108,22 @@ var ( ) // CORS returns a Cross-Origin Resource Sharing (CORS) middleware. -// See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS +// See also [MDN: Cross-Origin Resource Sharing (CORS)]. +// +// Security: Poorly configured CORS can compromise security because it allows +// relaxation of the browser's Same-Origin policy. See [Exploiting CORS +// misconfigurations for Bitcoins and bounties] and [Portswigger: Cross-origin +// resource sharing (CORS)] for more details. +// +// [MDN: Cross-Origin Resource Sharing (CORS)]: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS +// [Exploiting CORS misconfigurations for Bitcoins and bounties]: https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html +// [Portswigger: Cross-origin resource sharing (CORS)]: https://portswigger.net/web-security/cors func CORS() echo.MiddlewareFunc { return CORSWithConfig(DefaultCORSConfig) } // CORSWithConfig returns a CORS middleware with config. -// See: `CORS()`. +// See: [CORS]. func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { // Defaults if config.Skipper == nil { From 666938e523c62170646fc2320cc7d97bcacdfd6f Mon Sep 17 00:00:00 2001 From: Amir Hossein <77993374+Kamandlou@users.noreply.github.com> Date: Wed, 14 Sep 2022 10:10:39 +0430 Subject: [PATCH 241/446] tests: error handling on closing body (#2254) * tidy up tests --- echo_test.go | 53 ++++++++++++++++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/echo_test.go b/echo_test.go index 64796b3b5..7fd77c836 100644 --- a/echo_test.go +++ b/echo_test.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "errors" "fmt" + "io" "io/ioutil" "net" "net/http" @@ -235,7 +236,12 @@ func TestEchoStaticRedirectIndex(t *testing.T) { addr := e.ListenerAddr().String() if resp, err := http.Get("http://" + addr + "/static"); err == nil { // http.Get follows redirects by default - defer resp.Body.Close() + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + assert.Fail(t, err.Error()) + } + }(resp.Body) assert.Equal(t, http.StatusOK, resp.StatusCode) if body, err := ioutil.ReadAll(resp.Body); err == nil { @@ -380,7 +386,10 @@ func TestEchoWrapHandler(t *testing.T) { c := e.NewContext(req, rec) h := WrapHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - w.Write([]byte("test")) + _, err := w.Write([]byte("test")) + if err != nil { + assert.Fail(t, err.Error()) + } })) if assert.NoError(t, h(c)) { assert.Equal(t, http.StatusOK, rec.Code) @@ -482,16 +491,16 @@ func TestEchoURL(t *testing.T) { g := e.Group("/group") g.GET("/users/:uid/files/:fid", getFile) - assert := assert.New(t) + assertion := assert.New(t) - assert.Equal("/static/file", e.URL(static)) - assert.Equal("/users/:id", e.URL(getUser)) - assert.Equal("/users/1", e.URL(getUser, "1")) - assert.Equal("/users/1", e.URL(getUser, "1")) - assert.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt")) - assert.Equal("/documents/*", e.URL(getAny)) - assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1")) - assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1")) + assertion.Equal("/static/file", e.URL(static)) + assertion.Equal("/users/:id", e.URL(getUser)) + assertion.Equal("/users/1", e.URL(getUser, "1")) + assertion.Equal("/users/1", e.URL(getUser, "1")) + assertion.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt")) + assertion.Equal("/documents/*", e.URL(getAny)) + assertion.Equal("/group/users/1/files/:fid", e.URL(getFile, "1")) + assertion.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1")) } func TestEchoRoutes(t *testing.T) { @@ -598,7 +607,7 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) { } func TestEchoHost(t *testing.T) { - assert := assert.New(t) + assertion := assert.New(t) okHandler := func(c Context) error { return c.String(http.StatusOK, http.StatusText(http.StatusOK)) } teapotHandler := func(c Context) error { return c.String(http.StatusTeapot, http.StatusText(http.StatusTeapot)) } @@ -694,8 +703,8 @@ func TestEchoHost(t *testing.T) { e.ServeHTTP(rec, req) - assert.Equal(tc.expectStatus, rec.Code) - assert.Equal(tc.expectBody, rec.Body.String()) + assertion.Equal(tc.expectStatus, rec.Code) + assertion.Equal(tc.expectBody, rec.Body.String()) }) } } @@ -1231,7 +1240,7 @@ func TestDefaultHTTPErrorHandler(t *testing.T) { e := New() e.Debug = true e.Any("/plain", func(c Context) error { - return errors.New("An error occurred") + return errors.New("an error occurred") }) e.Any("/badrequest", func(c Context) error { return NewHTTPError(http.StatusBadRequest, "Invalid request") @@ -1244,7 +1253,10 @@ func TestDefaultHTTPErrorHandler(t *testing.T) { }) }) e.Any("/early-return", func(c Context) error { - c.String(http.StatusOK, "OK") + err := c.String(http.StatusOK, "OK") + if err != nil { + assert.Fail(t, err.Error()) + } return errors.New("ERROR") }) e.GET("/internal-error", func(c Context) error { @@ -1255,7 +1267,7 @@ func TestDefaultHTTPErrorHandler(t *testing.T) { // With Debug=true plain response contains error message c, b := request(http.MethodGet, "/plain", e) assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\n \"error\": \"An error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", b) + assert.Equal(t, "{\n \"error\": \"an error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", b) // and special handling for HTTPError c, b = request(http.MethodGet, "/badrequest", e) assert.Equal(t, http.StatusBadRequest, c) @@ -1379,7 +1391,12 @@ func TestEchoListenerNetwork(t *testing.T) { assert.NoError(t, err) if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil { - defer resp.Body.Close() + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + assert.Fail(t, err.Error()) + } + }(resp.Body) assert.Equal(t, http.StatusOK, resp.StatusCode) if body, err := ioutil.ReadAll(resp.Body); err == nil { From 79221d91cadcbb59190ca0b334092e605f29c91e Mon Sep 17 00:00:00 2001 From: Martti T Date: Wed, 5 Oct 2022 06:36:12 +0300 Subject: [PATCH 242/446] Update readme about supported Go versions (#2291) --- README.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 17e6ed934..509b97351 100644 --- a/README.md +++ b/README.md @@ -11,12 +11,11 @@ ## Supported Go versions +Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions. + As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules). Therefore a Go version capable of understanding /vN suffixed imports is required: -- 1.9.7+ -- 1.10.3+ -- 1.14+ Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended way of using Echo going forward. From 4c44305b23757c65b46cb08894ae536d8bfed5df Mon Sep 17 00:00:00 2001 From: Amir Hossein <77993374+Kamandlou@users.noreply.github.com> Date: Thu, 6 Oct 2022 12:04:00 +0330 Subject: [PATCH 243/446] update tests (#2275) update tests --- bind_test.go | 217 ++++++++++++++++++++++++--------------------------- context.go | 2 +- 2 files changed, 104 insertions(+), 115 deletions(-) diff --git a/bind_test.go b/bind_test.go index 822008153..c35283dcf 100644 --- a/bind_test.go +++ b/bind_test.go @@ -190,44 +190,40 @@ func TestToMultipleFields(t *testing.T) { } func TestBindJSON(t *testing.T) { - assert := assert.New(t) - testBindOkay(assert, strings.NewReader(userJSON), nil, MIMEApplicationJSON) - testBindOkay(assert, strings.NewReader(userJSON), dummyQuery, MIMEApplicationJSON) - testBindArrayOkay(assert, strings.NewReader(usersJSON), nil, MIMEApplicationJSON) - testBindArrayOkay(assert, strings.NewReader(usersJSON), dummyQuery, MIMEApplicationJSON) - testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) - testBindError(assert, strings.NewReader(userJSONInvalidType), MIMEApplicationJSON, &json.UnmarshalTypeError{}) + testBindOkay(t, strings.NewReader(userJSON), nil, MIMEApplicationJSON) + testBindOkay(t, strings.NewReader(userJSON), dummyQuery, MIMEApplicationJSON) + testBindArrayOkay(t, strings.NewReader(usersJSON), nil, MIMEApplicationJSON) + testBindArrayOkay(t, strings.NewReader(usersJSON), dummyQuery, MIMEApplicationJSON) + testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) + testBindError(t, strings.NewReader(userJSONInvalidType), MIMEApplicationJSON, &json.UnmarshalTypeError{}) } func TestBindXML(t *testing.T) { - assert := assert.New(t) - - testBindOkay(assert, strings.NewReader(userXML), nil, MIMEApplicationXML) - testBindOkay(assert, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) - testBindArrayOkay(assert, strings.NewReader(userXML), nil, MIMEApplicationXML) - testBindArrayOkay(assert, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) - testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationXML, errors.New("")) - testBindError(assert, strings.NewReader(userXMLConvertNumberError), MIMEApplicationXML, &strconv.NumError{}) - testBindError(assert, strings.NewReader(userXMLUnsupportedTypeError), MIMEApplicationXML, &xml.SyntaxError{}) - testBindOkay(assert, strings.NewReader(userXML), nil, MIMETextXML) - testBindOkay(assert, strings.NewReader(userXML), dummyQuery, MIMETextXML) - testBindError(assert, strings.NewReader(invalidContent), MIMETextXML, errors.New("")) - testBindError(assert, strings.NewReader(userXMLConvertNumberError), MIMETextXML, &strconv.NumError{}) - testBindError(assert, strings.NewReader(userXMLUnsupportedTypeError), MIMETextXML, &xml.SyntaxError{}) + testBindOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML) + testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) + testBindArrayOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML) + testBindArrayOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) + testBindError(t, strings.NewReader(invalidContent), MIMEApplicationXML, errors.New("")) + testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMEApplicationXML, &strconv.NumError{}) + testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMEApplicationXML, &xml.SyntaxError{}) + testBindOkay(t, strings.NewReader(userXML), nil, MIMETextXML) + testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMETextXML) + testBindError(t, strings.NewReader(invalidContent), MIMETextXML, errors.New("")) + testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMETextXML, &strconv.NumError{}) + testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMETextXML, &xml.SyntaxError{}) } func TestBindForm(t *testing.T) { - assert := assert.New(t) - testBindOkay(assert, strings.NewReader(userForm), nil, MIMEApplicationForm) - testBindOkay(assert, strings.NewReader(userForm), dummyQuery, MIMEApplicationForm) + testBindOkay(t, strings.NewReader(userForm), nil, MIMEApplicationForm) + testBindOkay(t, strings.NewReader(userForm), dummyQuery, MIMEApplicationForm) e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userForm)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(HeaderContentType, MIMEApplicationForm) err := c.Bind(&[]struct{ Field string }{}) - assert.Error(err) + assert.Error(t, err) } func TestBindQueryParams(t *testing.T) { @@ -317,14 +313,13 @@ func TestBindUnmarshalParam(t *testing.T) { err := c.Bind(&result) ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)) - assert := assert.New(t) - if assert.NoError(err) { + if assert.NoError(t, err) { // assert.Equal( Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T) - assert.Equal(ts, result.T) - assert.Equal(StringArray([]string{"one", "two", "three"}), result.SA) - assert.Equal([]Timestamp{ts, ts}, result.TA) - assert.Equal(Struct{""}, result.ST) // child struct does not have a field with matching tag - assert.Equal("baz", result.StWithTag.Foo) // child struct has field with matching tag + assert.Equal(t, ts, result.T) + assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA) + assert.Equal(t, []Timestamp{ts, ts}, result.TA) + assert.Equal(t, Struct{""}, result.ST) // child struct does not have a field with matching tag + assert.Equal(t, "baz", result.StWithTag.Foo) // child struct has field with matching tag } } @@ -426,38 +421,35 @@ func TestBindMultipartForm(t *testing.T) { mw.Close() body := bodyBuffer.Bytes() - assert := assert.New(t) - testBindOkay(assert, bytes.NewReader(body), nil, mw.FormDataContentType()) - testBindOkay(assert, bytes.NewReader(body), dummyQuery, mw.FormDataContentType()) + testBindOkay(t, bytes.NewReader(body), nil, mw.FormDataContentType()) + testBindOkay(t, bytes.NewReader(body), dummyQuery, mw.FormDataContentType()) } func TestBindUnsupportedMediaType(t *testing.T) { - assert := assert.New(t) - testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) + testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) } func TestBindbindData(t *testing.T) { - a := assert.New(t) ts := new(bindTestStruct) b := new(DefaultBinder) err := b.bindData(ts, values, "form") - a.NoError(err) - - a.Equal(0, ts.I) - a.Equal(int8(0), ts.I8) - a.Equal(int16(0), ts.I16) - a.Equal(int32(0), ts.I32) - a.Equal(int64(0), ts.I64) - a.Equal(uint(0), ts.UI) - a.Equal(uint8(0), ts.UI8) - a.Equal(uint16(0), ts.UI16) - a.Equal(uint32(0), ts.UI32) - a.Equal(uint64(0), ts.UI64) - a.Equal(false, ts.B) - a.Equal(float32(0), ts.F32) - a.Equal(float64(0), ts.F64) - a.Equal("", ts.S) - a.Equal("", ts.cantSet) + assert.NoError(t, err) + + assert.Equal(t, 0, ts.I) + assert.Equal(t, int8(0), ts.I8) + assert.Equal(t, int16(0), ts.I16) + assert.Equal(t, int32(0), ts.I32) + assert.Equal(t, int64(0), ts.I64) + assert.Equal(t, uint(0), ts.UI) + assert.Equal(t, uint8(0), ts.UI8) + assert.Equal(t, uint16(0), ts.UI16) + assert.Equal(t, uint32(0), ts.UI32) + assert.Equal(t, uint64(0), ts.UI64) + assert.Equal(t, false, ts.B) + assert.Equal(t, float32(0), ts.F32) + assert.Equal(t, float64(0), ts.F64) + assert.Equal(t, "", ts.S) + assert.Equal(t, "", ts.cantSet) } func TestBindParam(t *testing.T) { @@ -528,7 +520,6 @@ func TestBindUnmarshalTypeError(t *testing.T) { } func TestBindSetWithProperType(t *testing.T) { - assert := assert.New(t) ts := new(bindTestStruct) typ := reflect.TypeOf(ts).Elem() val := reflect.ValueOf(ts).Elem() @@ -543,9 +534,9 @@ func TestBindSetWithProperType(t *testing.T) { } val := values[typeField.Name][0] err := setWithProperType(typeField.Type.Kind(), val, structField) - assert.NoError(err) + assert.NoError(t, err) } - assertBindTestStruct(assert, ts) + assertBindTestStruct(t, ts) type foo struct { Bar bytes.Buffer @@ -553,56 +544,54 @@ func TestBindSetWithProperType(t *testing.T) { v := &foo{} typ = reflect.TypeOf(v).Elem() val = reflect.ValueOf(v).Elem() - assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) + assert.Error(t, setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) } func TestBindSetFields(t *testing.T) { - assert := assert.New(t) ts := new(bindTestStruct) val := reflect.ValueOf(ts).Elem() // Int - if assert.NoError(setIntField("5", 0, val.FieldByName("I"))) { - assert.Equal(5, ts.I) + if assert.NoError(t, setIntField("5", 0, val.FieldByName("I"))) { + assert.Equal(t, 5, ts.I) } - if assert.NoError(setIntField("", 0, val.FieldByName("I"))) { - assert.Equal(0, ts.I) + if assert.NoError(t, setIntField("", 0, val.FieldByName("I"))) { + assert.Equal(t, 0, ts.I) } // Uint - if assert.NoError(setUintField("10", 0, val.FieldByName("UI"))) { - assert.Equal(uint(10), ts.UI) + if assert.NoError(t, setUintField("10", 0, val.FieldByName("UI"))) { + assert.Equal(t, uint(10), ts.UI) } - if assert.NoError(setUintField("", 0, val.FieldByName("UI"))) { - assert.Equal(uint(0), ts.UI) + if assert.NoError(t, setUintField("", 0, val.FieldByName("UI"))) { + assert.Equal(t, uint(0), ts.UI) } // Float - if assert.NoError(setFloatField("15.5", 0, val.FieldByName("F32"))) { - assert.Equal(float32(15.5), ts.F32) + if assert.NoError(t, setFloatField("15.5", 0, val.FieldByName("F32"))) { + assert.Equal(t, float32(15.5), ts.F32) } - if assert.NoError(setFloatField("", 0, val.FieldByName("F32"))) { - assert.Equal(float32(0.0), ts.F32) + if assert.NoError(t, setFloatField("", 0, val.FieldByName("F32"))) { + assert.Equal(t, float32(0.0), ts.F32) } // Bool - if assert.NoError(setBoolField("true", val.FieldByName("B"))) { - assert.Equal(true, ts.B) + if assert.NoError(t, setBoolField("true", val.FieldByName("B"))) { + assert.Equal(t, true, ts.B) } - if assert.NoError(setBoolField("", val.FieldByName("B"))) { - assert.Equal(false, ts.B) + if assert.NoError(t, setBoolField("", val.FieldByName("B"))) { + assert.Equal(t, false, ts.B) } ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T")) - if assert.NoError(err) { - assert.Equal(ok, true) - assert.Equal(Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) + if assert.NoError(t, err) { + assert.Equal(t, ok, true) + assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) } } func BenchmarkBindbindDataWithTags(b *testing.B) { b.ReportAllocs() - assert := assert.New(b) ts := new(bindTestStructWithTags) binder := new(DefaultBinder) var err error @@ -610,29 +599,29 @@ func BenchmarkBindbindDataWithTags(b *testing.B) { for i := 0; i < b.N; i++ { err = binder.bindData(ts, values, "form") } - assert.NoError(err) - assertBindTestStruct(assert, (*bindTestStruct)(ts)) + assert.NoError(b, err) + assertBindTestStruct(b, (*bindTestStruct)(ts)) } -func assertBindTestStruct(a *assert.Assertions, ts *bindTestStruct) { - a.Equal(0, ts.I) - a.Equal(int8(8), ts.I8) - a.Equal(int16(16), ts.I16) - a.Equal(int32(32), ts.I32) - a.Equal(int64(64), ts.I64) - a.Equal(uint(0), ts.UI) - a.Equal(uint8(8), ts.UI8) - a.Equal(uint16(16), ts.UI16) - a.Equal(uint32(32), ts.UI32) - a.Equal(uint64(64), ts.UI64) - a.Equal(true, ts.B) - a.Equal(float32(32.5), ts.F32) - a.Equal(float64(64.5), ts.F64) - a.Equal("test", ts.S) - a.Equal("", ts.GetCantSet()) +func assertBindTestStruct(tb testing.TB, ts *bindTestStruct) { + assert.Equal(tb, 0, ts.I) + assert.Equal(tb, int8(8), ts.I8) + assert.Equal(tb, int16(16), ts.I16) + assert.Equal(tb, int32(32), ts.I32) + assert.Equal(tb, int64(64), ts.I64) + assert.Equal(tb, uint(0), ts.UI) + assert.Equal(tb, uint8(8), ts.UI8) + assert.Equal(tb, uint16(16), ts.UI16) + assert.Equal(tb, uint32(32), ts.UI32) + assert.Equal(tb, uint64(64), ts.UI64) + assert.Equal(tb, true, ts.B) + assert.Equal(tb, float32(32.5), ts.F32) + assert.Equal(tb, float64(64.5), ts.F64) + assert.Equal(tb, "test", ts.S) + assert.Equal(tb, "", ts.GetCantSet()) } -func testBindOkay(assert *assert.Assertions, r io.Reader, query url.Values, ctype string) { +func testBindOkay(t *testing.T, r io.Reader, query url.Values, ctype string) { e := New() path := "/" if len(query) > 0 { @@ -644,13 +633,13 @@ func testBindOkay(assert *assert.Assertions, r io.Reader, query url.Values, ctyp req.Header.Set(HeaderContentType, ctype) u := new(user) err := c.Bind(u) - if assert.NoError(err) { - assert.Equal(1, u.ID) - assert.Equal("Jon Snow", u.Name) + if assert.Equal(t, nil, err) { + assert.Equal(t, 1, u.ID) + assert.Equal(t, "Jon Snow", u.Name) } } -func testBindArrayOkay(assert *assert.Assertions, r io.Reader, query url.Values, ctype string) { +func testBindArrayOkay(t *testing.T, r io.Reader, query url.Values, ctype string) { e := New() path := "/" if len(query) > 0 { @@ -662,14 +651,14 @@ func testBindArrayOkay(assert *assert.Assertions, r io.Reader, query url.Values, req.Header.Set(HeaderContentType, ctype) u := []user{} err := c.Bind(&u) - if assert.NoError(err) { - assert.Equal(1, len(u)) - assert.Equal(1, u[0].ID) - assert.Equal("Jon Snow", u[0].Name) + if assert.NoError(t, err) { + assert.Equal(t, 1, len(u)) + assert.Equal(t, 1, u[0].ID) + assert.Equal(t, "Jon Snow", u[0].Name) } } -func testBindError(assert *assert.Assertions, r io.Reader, ctype string, expectedInternal error) { +func testBindError(t *testing.T, r io.Reader, ctype string, expectedInternal error) { e := New() req := httptest.NewRequest(http.MethodPost, "/", r) rec := httptest.NewRecorder() @@ -681,14 +670,14 @@ func testBindError(assert *assert.Assertions, r io.Reader, ctype string, expecte switch { case strings.HasPrefix(ctype, MIMEApplicationJSON), strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML), strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): - if assert.IsType(new(HTTPError), err) { - assert.Equal(http.StatusBadRequest, err.(*HTTPError).Code) - assert.IsType(expectedInternal, err.(*HTTPError).Internal) + if assert.IsType(t, new(HTTPError), err) { + assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).Code) + assert.IsType(t, expectedInternal, err.(*HTTPError).Internal) } default: - if assert.IsType(new(HTTPError), err) { - assert.Equal(ErrUnsupportedMediaType, err) - assert.IsType(expectedInternal, err.(*HTTPError).Internal) + if assert.IsType(t, new(HTTPError), err) { + assert.Equal(t, ErrUnsupportedMediaType, err) + assert.IsType(t, expectedInternal, err.(*HTTPError).Internal) } } } diff --git a/context.go b/context.go index a4ecfadfc..5567100b9 100644 --- a/context.go +++ b/context.go @@ -181,7 +181,7 @@ type ( // Logger returns the `Logger` instance. Logger() Logger - // Set the logger + // SetLogger Set the logger SetLogger(l Logger) // Echo returns the `Echo` instance. From 1d5f335f4092c6d8778b511d2b75bd09d1f9e61e Mon Sep 17 00:00:00 2001 From: Martti T Date: Wed, 12 Oct 2022 21:47:21 +0300 Subject: [PATCH 244/446] refactor assertions (#2301) --- echo_test.go | 58 +++++++++++++++-------------------- middleware/basic_auth_test.go | 18 +++++------ middleware/body_dump_test.go | 12 +++----- middleware/body_limit_test.go | 18 +++++------ middleware/compress_test.go | 26 ++++++++-------- middleware/decompress_test.go | 18 +++++------ middleware/jwt_test.go | 12 +++----- 7 files changed, 71 insertions(+), 91 deletions(-) diff --git a/echo_test.go b/echo_test.go index 7fd77c836..6bece4fd3 100644 --- a/echo_test.go +++ b/echo_test.go @@ -491,16 +491,14 @@ func TestEchoURL(t *testing.T) { g := e.Group("/group") g.GET("/users/:uid/files/:fid", getFile) - assertion := assert.New(t) - - assertion.Equal("/static/file", e.URL(static)) - assertion.Equal("/users/:id", e.URL(getUser)) - assertion.Equal("/users/1", e.URL(getUser, "1")) - assertion.Equal("/users/1", e.URL(getUser, "1")) - assertion.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt")) - assertion.Equal("/documents/*", e.URL(getAny)) - assertion.Equal("/group/users/1/files/:fid", e.URL(getFile, "1")) - assertion.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1")) + assert.Equal(t, "/static/file", e.URL(static)) + assert.Equal(t, "/users/:id", e.URL(getUser)) + assert.Equal(t, "/users/1", e.URL(getUser, "1")) + assert.Equal(t, "/users/1", e.URL(getUser, "1")) + assert.Equal(t, "/documents/foo.txt", e.URL(getAny, "foo.txt")) + assert.Equal(t, "/documents/*", e.URL(getAny)) + assert.Equal(t, "/group/users/1/files/:fid", e.URL(getFile, "1")) + assert.Equal(t, "/group/users/1/files/1", e.URL(getFile, "1", "1")) } func TestEchoRoutes(t *testing.T) { @@ -607,8 +605,6 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) { } func TestEchoHost(t *testing.T) { - assertion := assert.New(t) - okHandler := func(c Context) error { return c.String(http.StatusOK, http.StatusText(http.StatusOK)) } teapotHandler := func(c Context) error { return c.String(http.StatusTeapot, http.StatusText(http.StatusTeapot)) } acceptHandler := func(c Context) error { return c.String(http.StatusAccepted, http.StatusText(http.StatusAccepted)) } @@ -703,8 +699,8 @@ func TestEchoHost(t *testing.T) { e.ServeHTTP(rec, req) - assertion.Equal(tc.expectStatus, rec.Code) - assertion.Equal(tc.expectBody, rec.Body.String()) + assert.Equal(t, tc.expectStatus, rec.Code) + assert.Equal(t, tc.expectBody, rec.Body.String()) }) } } @@ -1429,8 +1425,6 @@ func TestEchoListenerNetworkInvalid(t *testing.T) { } func TestEchoReverse(t *testing.T) { - assert := assert.New(t) - e := New() dummyHandler := func(Context) error { return nil } @@ -1440,22 +1434,20 @@ func TestEchoReverse(t *testing.T) { e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux" e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*" - assert.Equal("/static", e.Reverse("/static")) - assert.Equal("/static", e.Reverse("/static", "missing param")) - assert.Equal("/static/*", e.Reverse("/static/*")) - assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt")) - - assert.Equal("/params/:foo", e.Reverse("/params/:foo")) - assert.Equal("/params/one", e.Reverse("/params/:foo", "one")) - assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux")) - assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one")) - assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two")) - assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) + assert.Equal(t, "/static", e.Reverse("/static")) + assert.Equal(t, "/static", e.Reverse("/static", "missing param")) + assert.Equal(t, "/static/*", e.Reverse("/static/*")) + assert.Equal(t, "/static/foo.txt", e.Reverse("/static/*", "foo.txt")) + + assert.Equal(t, "/params/:foo", e.Reverse("/params/:foo")) + assert.Equal(t, "/params/one", e.Reverse("/params/:foo", "one")) + assert.Equal(t, "/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux")) + assert.Equal(t, "/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one")) + assert.Equal(t, "/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two")) + assert.Equal(t, "/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) } func TestEchoReverseHandleHostProperly(t *testing.T) { - assert := assert.New(t) - dummyHandler := func(Context) error { return nil } e := New() @@ -1463,10 +1455,10 @@ func TestEchoReverseHandleHostProperly(t *testing.T) { h.GET("/static", dummyHandler).Name = "/static" h.GET("/static/*", dummyHandler).Name = "/static/*" - assert.Equal("/static", e.Reverse("/static")) - assert.Equal("/static", e.Reverse("/static", "missing param")) - assert.Equal("/static/*", e.Reverse("/static/*")) - assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt")) + assert.Equal(t, "/static", e.Reverse("/static")) + assert.Equal(t, "/static", e.Reverse("/static", "missing param")) + assert.Equal(t, "/static/*", e.Reverse("/static/*")) + assert.Equal(t, "/static/foo.txt", e.Reverse("/static/*", "foo.txt")) } func TestEcho_ListenerAddr(t *testing.T) { diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 4c355aa16..20e769214 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -26,12 +26,10 @@ func TestBasicAuth(t *testing.T) { return c.String(http.StatusOK, "test") }) - assert := assert.New(t) - // Valid credentials auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + assert.NoError(t, h(c)) h = BasicAuthWithConfig(BasicAuthConfig{ Skipper: nil, @@ -44,34 +42,34 @@ func TestBasicAuth(t *testing.T) { // Valid credentials auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + assert.NoError(t, h(c)) // Case-insensitive header scheme auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + assert.NoError(t, h(c)) // Invalid credentials auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")) req.Header.Set(echo.HeaderAuthorization, auth) he := h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) - assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate)) + assert.Equal(t, http.StatusUnauthorized, he.Code) + assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate)) // Invalid base64 string auth = basic + " invalidString" req.Header.Set(echo.HeaderAuthorization, auth) he = h(c).(*echo.HTTPError) - assert.Equal(http.StatusBadRequest, he.Code) + assert.Equal(t, http.StatusBadRequest, he.Code) // Missing Authorization header req.Header.Del(echo.HeaderAuthorization) he = h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) + assert.Equal(t, http.StatusUnauthorized, he.Code) // Invalid Authorization header auth = base64.StdEncoding.EncodeToString([]byte("invalid")) req.Header.Set(echo.HeaderAuthorization, auth) he = h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) + assert.Equal(t, http.StatusUnauthorized, he.Code) } diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go index e6e00f726..533971a47 100644 --- a/middleware/body_dump_test.go +++ b/middleware/body_dump_test.go @@ -33,13 +33,11 @@ func TestBodyDump(t *testing.T) { responseBody = string(resBody) }) - assert := assert.New(t) - - if assert.NoError(mw(h)(c)) { - assert.Equal(requestBody, hw) - assert.Equal(responseBody, hw) - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(hw, rec.Body.String()) + if assert.NoError(t, mw(h)(c)) { + assert.Equal(t, requestBody, hw) + assert.Equal(t, responseBody, hw) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.String()) } // Must set default skipper diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index 8981534d4..b891767c5 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -25,26 +25,24 @@ func TestBodyLimit(t *testing.T) { return c.String(http.StatusOK, string(body)) } - assert := assert.New(t) - // Based on content length (within limit) - if assert.NoError(BodyLimit("2M")(h)(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(hw, rec.Body.Bytes()) + if assert.NoError(t, BodyLimit("2M")(h)(c)) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) } // Based on content length (overlimit) he := BodyLimit("2B")(h)(c).(*echo.HTTPError) - assert.Equal(http.StatusRequestEntityTooLarge, he.Code) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) // Based on content read (within limit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) - if assert.NoError(BodyLimit("2M")(h)(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("Hello, World!", rec.Body.String()) + if assert.NoError(t, BodyLimit("2M")(h)(c)) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "Hello, World!", rec.Body.String()) } // Based on content read (overlimit) @@ -53,7 +51,7 @@ func TestBodyLimit(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec) he = BodyLimit("2B")(h)(c).(*echo.HTTPError) - assert.Equal(http.StatusRequestEntityTooLarge, he.Code) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) } func TestBodyLimitReader(t *testing.T) { diff --git a/middleware/compress_test.go b/middleware/compress_test.go index b62bffef5..c8dd2e1f4 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -26,9 +26,7 @@ func TestGzip(t *testing.T) { }) h(c) - assert := assert.New(t) - - assert.Equal("test", rec.Body.String()) + assert.Equal(t, "test", rec.Body.String()) // Gzip req = httptest.NewRequest(http.MethodGet, "/", nil) @@ -36,14 +34,14 @@ func TestGzip(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec) h(c) - assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) - assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) r, err := gzip.NewReader(rec.Body) - if assert.NoError(err) { + if assert.NoError(t, err) { buf := new(bytes.Buffer) defer r.Close() buf.ReadFrom(r) - assert.Equal("test", buf.String()) + assert.Equal(t, "test", buf.String()) } chunkBuf := make([]byte, 5) @@ -63,21 +61,21 @@ func TestGzip(t *testing.T) { c.Response().Flush() // Read the first part of the data - assert.True(rec.Flushed) - assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.True(t, rec.Flushed) + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) r.Reset(rec.Body) _, err = io.ReadFull(r, chunkBuf) - assert.NoError(err) - assert.Equal("test\n", string(chunkBuf)) + assert.NoError(t, err) + assert.Equal(t, "test\n", string(chunkBuf)) // Write and flush the second part of the data c.Response().Write([]byte("test\n")) c.Response().Flush() _, err = io.ReadFull(r, chunkBuf) - assert.NoError(err) - assert.Equal("test\n", string(chunkBuf)) + assert.NoError(t, err) + assert.Equal(t, "test\n", string(chunkBuf)) // Write the final part of the data and return c.Response().Write([]byte("test")) @@ -87,7 +85,7 @@ func TestGzip(t *testing.T) { buf := new(bytes.Buffer) defer r.Close() buf.ReadFrom(r) - assert.Equal("test", buf.String()) + assert.Equal(t, "test", buf.String()) } func TestGzipNoContent(t *testing.T) { diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 51fa6b0f1..42c6250ac 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -28,8 +28,7 @@ func TestDecompress(t *testing.T) { }) h(c) - assert := assert.New(t) - assert.Equal("test", rec.Body.String()) + assert.Equal(t, "test", rec.Body.String()) // Decompress body := `{"name": "echo"}` @@ -39,10 +38,10 @@ func TestDecompress(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec) h(c) - assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := ioutil.ReadAll(req.Body) - assert.NoError(err) - assert.Equal(body, string(b)) + assert.NoError(t, err) + assert.Equal(t, body, string(b)) } func TestDecompressDefaultConfig(t *testing.T) { @@ -57,8 +56,7 @@ func TestDecompressDefaultConfig(t *testing.T) { }) h(c) - assert := assert.New(t) - assert.Equal("test", rec.Body.String()) + assert.Equal(t, "test", rec.Body.String()) // Decompress body := `{"name": "echo"}` @@ -68,10 +66,10 @@ func TestDecompressDefaultConfig(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec) h(c) - assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := ioutil.ReadAll(req.Body) - assert.NoError(err) - assert.Equal(body, string(b)) + assert.NoError(t, err) + assert.Equal(t, body, string(b)) } func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index eee9df966..90e8cad81 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -348,8 +348,6 @@ func TestJWTConfig(t *testing.T) { } func TestJWTwithKID(t *testing.T) { - test := assert.New(t) - e := echo.New() handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") @@ -417,19 +415,19 @@ func TestJWTwithKID(t *testing.T) { if tc.expErrCode != 0 { h := JWTWithConfig(tc.config)(handler) he := h(c).(*echo.HTTPError) - test.Equal(tc.expErrCode, he.Code, tc.info) + assert.Equal(t, tc.expErrCode, he.Code, tc.info) continue } h := JWTWithConfig(tc.config)(handler) - if test.NoError(h(c), tc.info) { + if assert.NoError(t, h(c), tc.info) { user := c.Get("user").(*jwt.Token) switch claims := user.Claims.(type) { case jwt.MapClaims: - test.Equal(claims["name"], "John Doe", tc.info) + assert.Equal(t, claims["name"], "John Doe", tc.info) case *jwtCustomClaims: - test.Equal(claims.Name, "John Doe", tc.info) - test.Equal(claims.Admin, true, tc.info) + assert.Equal(t, claims.Name, "John Doe", tc.info) + assert.Equal(t, claims.Admin, true, tc.info) default: panic("unexpected type of claims") } From 56f63c3036bb4be6a454f5c1ac9efafca0af36f9 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Wed, 12 Oct 2022 21:58:55 +0300 Subject: [PATCH 245/446] bump github.com/labstack/gommon dependency version --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 4de2bdde1..158d23d1d 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.17 require ( github.com/golang-jwt/jwt v3.2.2+incompatible - github.com/labstack/gommon v0.3.1 + github.com/labstack/gommon v0.4.0 github.com/stretchr/testify v1.7.0 github.com/valyala/fasttemplate v1.2.1 golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 diff --git a/go.sum b/go.sum index f66734243..5545943be 100644 --- a/go.sum +++ b/go.sum @@ -3,8 +3,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/labstack/gommon v0.3.1 h1:OomWaJXm7xR6L1HmEtGyQf26TEn7V6X88mktX9kee9o= -github.com/labstack/gommon v0.3.1/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= +github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8= +github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= github.com/mattn/go-colorable v0.1.11 h1:nQ+aFkoE2TMGc0b68U2OKSexC+eq46+XwZzWXHRmPYs= github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= From 8ad22302f2a6b3451f5ecc3aec1b32a2c48412de Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Wed, 12 Oct 2022 22:05:39 +0300 Subject: [PATCH 246/446] Changelog for v4.9.1 --- CHANGELOG.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e8f42200e..8b71fb8e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,19 @@ # Changelog +## v4.9.1 - 2022-10-12 + +**Fixes** + +* Fix logger panicing (when template is set to empty) by bumping dependency version [#2295](https://github.com/labstack/echo/issues/2295) + +**Enhancements** + +* Improve CORS documentation [#2272](https://github.com/labstack/echo/pull/2272) +* Update readme about supported Go versions [#2291](https://github.com/labstack/echo/pull/2291) +* Tests: improve error handling on closing body [#2254](https://github.com/labstack/echo/pull/2254) +* Tests: refactor some of the assertions in tests [#2275](https://github.com/labstack/echo/pull/2275) +* Tests: refactor assertions [#2301](https://github.com/labstack/echo/pull/2301) + ## v4.9.0 - 2022-09-04 **Security** From b02e78ba55d1d642cc045b503c83780eb79943d4 Mon Sep 17 00:00:00 2001 From: Patrick Brueckner Date: Fri, 14 Oct 2022 16:29:14 +0200 Subject: [PATCH 247/446] bump x/text to 0.3.8 see https://go.dev/issue/56152, https://ossindex.sonatype.org/vulnerability/CVE-2022-32149?component-type=golang&component-name=golang.org%2Fx%2Ftext&utm_source=nancy-client&utm_medium=integration&utm_content=1.0.41 --- go.mod | 8 ++++---- go.sum | 8 ++++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index 158d23d1d..e9f611ccf 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,8 @@ require ( github.com/labstack/gommon v0.4.0 github.com/stretchr/testify v1.7.0 github.com/valyala/fasttemplate v1.2.1 - golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 - golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f + golang.org/x/crypto v0.0.0-20221012134737-56aed061732a + golang.org/x/net v0.0.0-20221014081412-f15817d10f9b golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 ) @@ -18,7 +18,7 @@ require ( github.com/mattn/go-isatty v0.0.14 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.0.0-20211103235746-7861aae1554b // indirect - golang.org/x/text v0.3.7 // indirect + golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect + golang.org/x/text v0.3.8 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index 5545943be..57e0ac243 100644 --- a/go.sum +++ b/go.sum @@ -20,9 +20,13 @@ github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52 github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20221012134737-56aed061732a h1:NmSIgad6KjE6VvHciPZuNRTKxGhlPfD6OA87W/PLkqg= +golang.org/x/crypto v0.0.0-20221012134737-56aed061732a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f h1:OfiFi4JbukWwe3lzw+xunroH1mnC1e2Gy5cxNJApiSY= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20221014081412-f15817d10f9b h1:tvrvnPFcdzp294diPnrdZZZ8XUt2Tyj7svb7X52iDuU= +golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -30,11 +34,15 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= From 8f2bf82982a6005447b68bbbae739c6fcced5140 Mon Sep 17 00:00:00 2001 From: Patrick Brueckner Date: Mon, 17 Oct 2022 10:45:27 +0200 Subject: [PATCH 248/446] go mod tidy --- go.sum | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/go.sum b/go.sum index 57e0ac243..f5dbb44fb 100644 --- a/go.sum +++ b/go.sum @@ -18,34 +18,45 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= -golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20221012134737-56aed061732a h1:NmSIgad6KjE6VvHciPZuNRTKxGhlPfD6OA87W/PLkqg= golang.org/x/crypto v0.0.0-20221012134737-56aed061732a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f h1:OfiFi4JbukWwe3lzw+xunroH1mnC1e2Gy5cxNJApiSY= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20221014081412-f15817d10f9b h1:tvrvnPFcdzp294diPnrdZZZ8XUt2Tyj7svb7X52iDuU= golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 0ce73028d0815e0ecec80964cc2da42d98fafa33 Mon Sep 17 00:00:00 2001 From: Hristo Hristov Date: Sat, 29 Oct 2022 21:54:23 +0300 Subject: [PATCH 249/446] [suggestion] Add helper interface for ProxyBalancer interface (#2316) * [suggestion] Add helper interface for ProxyBalancer interface * Update proxy_test.go * addressed code review comments * address pr comments * clean up * return error --- middleware/proxy.go | 17 ++++++++++++- middleware/proxy_test.go | 52 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/middleware/proxy.go b/middleware/proxy.go index 6cfd6731e..d2cd2aa6d 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -72,6 +72,11 @@ type ( Next(echo.Context) *ProxyTarget } + // TargetProvider defines an interface that gives the opportunity for balancer to return custom errors when selecting target. + TargetProvider interface { + NextTarget(echo.Context) (*ProxyTarget, error) + } + commonBalancer struct { targets []*ProxyTarget mutex sync.RWMutex @@ -223,6 +228,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { } } + provider, isTargetProvider := config.Balancer.(TargetProvider) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { if config.Skipper(c) { @@ -231,7 +237,16 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { req := c.Request() res := c.Response() - tgt := config.Balancer.Next(c) + + var tgt *ProxyTarget + if isTargetProvider { + tgt, err = provider.NextTarget(c) + if err != nil { + return err + } + } else { + tgt = config.Balancer.Next(c) + } c.Set(config.ContextKey, tgt) if err := rewriteURL(config.RegexRewrite, req); err != nil { diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 7939fc5c2..0ded50a1f 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -18,7 +18,7 @@ import ( "github.com/stretchr/testify/assert" ) -//Assert expected with url.EscapedPath method to obtain the path. +// Assert expected with url.EscapedPath method to obtain the path. func TestProxy(t *testing.T) { // Setup t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -31,7 +31,6 @@ func TestProxy(t *testing.T) { })) defer t2.Close() url2, _ := url.Parse(t2.URL) - targets := []*ProxyTarget{ { Name: "target 1", @@ -122,6 +121,55 @@ func TestProxy(t *testing.T) { e.ServeHTTP(rec, req) } +type testProvider struct { + *commonBalancer + target *ProxyTarget + err error +} + +func (p *testProvider) Next(c echo.Context) *ProxyTarget { + return &ProxyTarget{} +} + +func (p *testProvider) NextTarget(c echo.Context) (*ProxyTarget, error) { + return p.target, p.err +} + +func TestTargetProvider(t *testing.T) { + t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "target 1") + })) + defer t1.Close() + url1, _ := url.Parse(t1.URL) + + e := echo.New() + tp := &testProvider{commonBalancer: new(commonBalancer)} + tp.target = &ProxyTarget{Name: "target 1", URL: url1} + e.Use(Proxy(tp)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + e.ServeHTTP(rec, req) + body := rec.Body.String() + assert.Equal(t, "target 1", body) +} + +func TestFailNextTarget(t *testing.T) { + url1, err := url.Parse("http://dummy:8080") + assert.Nil(t, err) + + e := echo.New() + tp := &testProvider{commonBalancer: new(commonBalancer)} + tp.target = &ProxyTarget{Name: "target 1", URL: url1} + tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target") + + e.Use(Proxy(tp)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + e.ServeHTTP(rec, req) + body := rec.Body.String() + assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body) +} + func TestProxyRealIPHeader(t *testing.T) { // Setup upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) From b010b69329fa3da8b858d3d8b45cacf61783f369 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Fri, 4 Nov 2022 11:32:53 +0200 Subject: [PATCH 250/446] Bump dependencies and add notes about Go releases we support --- .github/workflows/echo.yml | 6 ++++-- CHANGELOG.md | 12 +++++++++++ Makefile | 4 ++-- go.mod | 20 ++++++++--------- go.sum | 44 ++++++++++++++++++++++---------------- 5 files changed, 54 insertions(+), 32 deletions(-) diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index db60f7f84..c2bd41e1b 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -27,8 +27,10 @@ jobs: matrix: os: [ubuntu-latest, macos-latest, windows-latest] # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy - # Echo tests with last four major releases - go: [1.16, 1.17, 1.18, 1.19] + # Echo tests with last four major releases (unless there are pressing vulnerabilities) + # As we depend on `golang.org/x/` libraries which only support last 2 Go releases we could have situations when + # we derive from last four major releases promise. + go: [1.17, 1.18, 1.19] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b71fb8e4..c629350c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## v4.10.0 - 2022-xx-xx + +**Security** + +This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are +several vulnerabilities fixed in these libraries. + +Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise. + + + + ## v4.9.1 - 2022-10-12 **Fixes** diff --git a/Makefile b/Makefile index a6c4aaa90..3b7651983 100644 --- a/Makefile +++ b/Makefile @@ -29,6 +29,6 @@ benchmark: ## Run benchmarks help: ## Display this help screen @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -goversion ?= "1.16" -test_version: ## Run tests inside Docker with given version (defaults to 1.15 oldest supported). Example: make test_version goversion=1.16 +goversion ?= "1.17" +test_version: ## Run tests inside Docker with given version (defaults to 1.17 oldest supported). Example: make test_version goversion=1.17 @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" diff --git a/go.mod b/go.mod index e9f611ccf..73fd6d900 100644 --- a/go.mod +++ b/go.mod @@ -5,20 +5,20 @@ go 1.17 require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/labstack/gommon v0.4.0 - github.com/stretchr/testify v1.7.0 - github.com/valyala/fasttemplate v1.2.1 - golang.org/x/crypto v0.0.0-20221012134737-56aed061732a - golang.org/x/net v0.0.0-20221014081412-f15817d10f9b - golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 + github.com/stretchr/testify v1.8.1 + github.com/valyala/fasttemplate v1.2.2 + golang.org/x/crypto v0.2.0 + golang.org/x/net v0.2.0 + golang.org/x/time v0.2.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-colorable v0.1.11 // indirect - github.com/mattn/go-isatty v0.0.14 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.16 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect - golang.org/x/text v0.3.8 // indirect - gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect + golang.org/x/sys v0.2.0 // indirect + golang.org/x/text v0.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f5dbb44fb..b052ff9d6 100644 --- a/go.sum +++ b/go.sum @@ -5,54 +5,61 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8= github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= -github.com/mattn/go-colorable v0.1.11 h1:nQ+aFkoE2TMGc0b68U2OKSexC+eq46+XwZzWXHRmPYs= github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= +github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20221012134737-56aed061732a h1:NmSIgad6KjE6VvHciPZuNRTKxGhlPfD6OA87W/PLkqg= -golang.org/x/crypto v0.0.0-20221012134737-56aed061732a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.2.0 h1:BRXPfhNivWL5Yq0BGQ39a2sW6t44aODpfxkWjYdzewE= +golang.org/x/crypto v0.2.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.0.0-20221014081412-f15817d10f9b h1:tvrvnPFcdzp294diPnrdZZZ8XUt2Tyj7svb7X52iDuU= -golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU= +golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg= -golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= -golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/time v0.2.0 h1:52I/1L54xyEQAYdtcSuxtiT84KGYTBGXwayxmIpNJhE= +golang.org/x/time v0.2.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= @@ -60,5 +67,6 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 35184a893b98d2e2451d061c210a61e2d50030a0 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 12 Nov 2022 23:14:59 +0200 Subject: [PATCH 251/446] Expose middleware.CreateExtractors function so we can use it from echo-contrib repository --- middleware/csrf.go | 2 +- middleware/extractor.go | 20 ++++++++++++++++++++ middleware/extractor_test.go | 2 +- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/middleware/csrf.go b/middleware/csrf.go index ea90fdba7..8661c9f89 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -119,7 +119,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { config.CookieSecure = true } - extractors, err := createExtractors(config.TokenLookup, "") + extractors, err := CreateExtractors(config.TokenLookup) if err != nil { panic(err) } diff --git a/middleware/extractor.go b/middleware/extractor.go index afdfd8195..5d9cee6d0 100644 --- a/middleware/extractor.go +++ b/middleware/extractor.go @@ -24,6 +24,26 @@ var errFormExtractorValueMissing = errors.New("missing value in the form") // ValuesExtractor defines a function for extracting values (keys/tokens) from the given context. type ValuesExtractor func(c echo.Context) ([]string, error) +// CreateExtractors creates ValuesExtractors from given lookups. +// Lookups is a string in the form of ":" or ":,:" that is used +// to extract key from the request. +// Possible values: +// - "header:" or "header::" +// `` is argument value to cut/trim prefix of the extracted value. This is useful if header +// value has static prefix like `Authorization: ` where part that we +// want to cut is ` ` note the space at the end. +// In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. +// - "query:" +// - "param:" +// - "form:" +// - "cookie:" +// +// Multiple sources example: +// - "header:Authorization,header:X-Api-Key" +func CreateExtractors(lookups string) ([]ValuesExtractor, error) { + return createExtractors(lookups, "") +} + func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, error) { if lookups == "" { return nil, nil diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go index 2e898f541..428c5563e 100644 --- a/middleware/extractor_test.go +++ b/middleware/extractor_test.go @@ -110,7 +110,7 @@ func TestCreateExtractors(t *testing.T) { setPathParams(c, tc.givenPathParams) } - extractors, err := createExtractors(tc.whenLoopups, "") + extractors, err := CreateExtractors(tc.whenLoopups) if tc.expectCreateError != "" { assert.EqualError(t, err, tc.expectCreateError) return From a97d4bfb7b609b7daba06f9cf4d2eb3e95482174 Mon Sep 17 00:00:00 2001 From: lkeix Date: Tue, 25 Oct 2022 23:12:15 +0900 Subject: [PATCH 252/446] fix func(Context) error to HandlerFunc --- echo.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/echo.go b/echo.go index 5ae8a1424..856e6bde5 100644 --- a/echo.go +++ b/echo.go @@ -626,7 +626,7 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Acquire context c := e.pool.Get().(*context) c.Reset(r, w) - var h func(Context) error + var h HandlerFunc if e.premiddleware == nil { e.findRouter(r.Host).Find(r.Method, GetPath(r), c) From fd2b102d3ee1bd538e13418f6382e98cff77fdb0 Mon Sep 17 00:00:00 2001 From: wanghaha-dev Date: Sun, 6 Nov 2022 22:00:51 +0800 Subject: [PATCH 253/446] Modify comment syntax error --- echo.go | 2 +- middleware/slash.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/echo.go b/echo.go index 856e6bde5..068558aba 100644 --- a/echo.go +++ b/echo.go @@ -565,7 +565,7 @@ func (e *Echo) Group(prefix string, m ...MiddlewareFunc) (g *Group) { return } -// URI generates a URI from handler. +// URI generates an URI from handler. func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string { name := handlerName(handler) return e.Reverse(name, params...) diff --git a/middleware/slash.go b/middleware/slash.go index 4188675b0..a3bf807ec 100644 --- a/middleware/slash.go +++ b/middleware/slash.go @@ -33,7 +33,7 @@ func AddTrailingSlash() echo.MiddlewareFunc { return AddTrailingSlashWithConfig(DefaultTrailingSlashConfig) } -// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config. +// AddTrailingSlashWithConfig returns an AddTrailingSlash middleware with config. // See `AddTrailingSlash()`. func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc { // Defaults From be23ab67ccdcb94ff99f49638b252be03ae274a6 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 19 Nov 2022 22:05:53 +0200 Subject: [PATCH 254/446] Add new method HTTPError.WithInternal --- echo.go | 55 ++++++++++++++++++++++++++++++---------------------- echo_test.go | 22 +++++++++++++++++++-- 2 files changed, 52 insertions(+), 25 deletions(-) diff --git a/echo.go b/echo.go index 068558aba..28f6565d7 100644 --- a/echo.go +++ b/echo.go @@ -3,34 +3,34 @@ Package echo implements high performance, minimalist Go web framework. Example: - package main + package main - import ( - "net/http" + import ( + "net/http" - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" - ) + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + ) - // Handler - func hello(c echo.Context) error { - return c.String(http.StatusOK, "Hello, World!") - } + // Handler + func hello(c echo.Context) error { + return c.String(http.StatusOK, "Hello, World!") + } - func main() { - // Echo instance - e := echo.New() + func main() { + // Echo instance + e := echo.New() - // Middleware - e.Use(middleware.Logger()) - e.Use(middleware.Recover()) + // Middleware + e.Use(middleware.Logger()) + e.Use(middleware.Recover()) - // Routes - e.GET("/", hello) + // Routes + e.GET("/", hello) - // Start server - e.Logger.Fatal(e.Start(":1323")) - } + // Start server + e.Logger.Fatal(e.Start(":1323")) + } Learn more at https://echo.labstack.com */ @@ -884,6 +884,15 @@ func (he *HTTPError) SetInternal(err error) *HTTPError { return he } +// WithInternal returns clone of HTTPError with err set to HTTPError.Internal field +func (he *HTTPError) WithInternal(err error) *HTTPError { + return &HTTPError{ + Code: he.Code, + Message: he.Message, + Internal: err, + } +} + // Unwrap satisfies the Go 1.13 error wrapper interface. func (he *HTTPError) Unwrap() error { return he.Internal @@ -913,8 +922,8 @@ func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { // GetPath returns RawPath, if it's empty returns Path from URL // Difference between RawPath and Path is: -// * Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/. -// * RawPath is an optional field which only gets set if the default encoding is different from Path. +// - Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/. +// - RawPath is an optional field which only gets set if the default encoding is different from Path. func GetPath(r *http.Request) string { path := r.URL.RawPath if path == "" { diff --git a/echo_test.go b/echo_test.go index 6bece4fd3..aa2954c92 100644 --- a/echo_test.go +++ b/echo_test.go @@ -1206,13 +1206,22 @@ func TestHTTPError(t *testing.T) { assert.Equal(t, "code=400, message=map[code:12]", err.Error()) }) - t.Run("internal", func(t *testing.T) { + + t.Run("internal and SetInternal", func(t *testing.T) { err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ "code": 12, }) err.SetInternal(errors.New("internal error")) assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error()) }) + + t.Run("internal and WithInternal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + err = err.WithInternal(errors.New("internal error")) + assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error()) + }) } func TestHTTPError_Unwrap(t *testing.T) { @@ -1223,13 +1232,22 @@ func TestHTTPError_Unwrap(t *testing.T) { assert.Nil(t, errors.Unwrap(err)) }) - t.Run("internal", func(t *testing.T) { + + t.Run("unwrap internal and SetInternal", func(t *testing.T) { err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ "code": 12, }) err.SetInternal(errors.New("internal error")) assert.Equal(t, "internal error", errors.Unwrap(err).Error()) }) + + t.Run("unwrap internal and WithInternal", func(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ + "code": 12, + }) + err = err.WithInternal(errors.New("internal error")) + assert.Equal(t, "internal error", errors.Unwrap(err).Error()) + }) } func TestDefaultHTTPErrorHandler(t *testing.T) { From 3c4d3b3083d71aa26ae4db7206091bad02c38a6f Mon Sep 17 00:00:00 2001 From: zeek Date: Mon, 21 Nov 2022 21:29:43 +0900 Subject: [PATCH 255/446] Replace "io/ioutil" "io/ioutil" pakcage has been deprecated since Go 1.16. --- echo.go | 4 ++-- echo_test.go | 17 ++++++++--------- group_test.go | 4 ++-- middleware/body_dump.go | 5 ++--- middleware/body_dump_test.go | 4 ++-- middleware/body_limit_test.go | 10 +++++----- middleware/compress.go | 5 ++--- middleware/compress_test.go | 4 ++-- middleware/decompress_test.go | 12 ++++++------ middleware/proxy_test.go | 4 ++-- middleware/rewrite_test.go | 4 ++-- middleware/timeout_test.go | 4 ++-- 12 files changed, 37 insertions(+), 40 deletions(-) diff --git a/echo.go b/echo.go index 28f6565d7..2b632c980 100644 --- a/echo.go +++ b/echo.go @@ -43,10 +43,10 @@ import ( "errors" "fmt" "io" - "io/ioutil" stdLog "log" "net" "net/http" + "os" "reflect" "runtime" "sync" @@ -700,7 +700,7 @@ func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err erro func filepathOrContent(fileOrContent interface{}) (content []byte, err error) { switch v := fileOrContent.(type) { case string: - return ioutil.ReadFile(v) + return os.ReadFile(v) case []byte: return v, nil default: diff --git a/echo_test.go b/echo_test.go index aa2954c92..b0d1ccd28 100644 --- a/echo_test.go +++ b/echo_test.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "net/http" "net/http/httptest" @@ -244,7 +243,7 @@ func TestEchoStaticRedirectIndex(t *testing.T) { }(resp.Body) assert.Equal(t, http.StatusOK, resp.StatusCode) - if body, err := ioutil.ReadAll(resp.Body); err == nil { + if body, err := io.ReadAll(resp.Body); err == nil { assert.Equal(t, true, strings.HasPrefix(string(body), "")) } else { assert.Fail(t, err.Error()) @@ -1032,9 +1031,9 @@ func TestEchoStartTLSAndStart(t *testing.T) { } func TestEchoStartTLSByteString(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") + cert, err := os.ReadFile("_fixture/certs/cert.pem") require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") + key, err := os.ReadFile("_fixture/certs/key.pem") require.NoError(t, err) testCases := []struct { @@ -1413,7 +1412,7 @@ func TestEchoListenerNetwork(t *testing.T) { }(resp.Body) assert.Equal(t, http.StatusOK, resp.StatusCode) - if body, err := ioutil.ReadAll(resp.Body); err == nil { + if body, err := io.ReadAll(resp.Body); err == nil { assert.Equal(t, "OK", string(body)) } else { assert.Fail(t, err.Error()) @@ -1495,9 +1494,9 @@ func TestEcho_ListenerAddr(t *testing.T) { } func TestEcho_TLSListenerAddr(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") + cert, err := os.ReadFile("_fixture/certs/cert.pem") require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") + key, err := os.ReadFile("_fixture/certs/key.pem") require.NoError(t, err) e := New() @@ -1515,9 +1514,9 @@ func TestEcho_TLSListenerAddr(t *testing.T) { } func TestEcho_StartServer(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") + cert, err := os.ReadFile("_fixture/certs/cert.pem") require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") + key, err := os.ReadFile("_fixture/certs/key.pem") require.NoError(t, err) certs, err := tls.X509KeyPair(cert, key) require.NoError(t, err) diff --git a/group_test.go b/group_test.go index 24f191677..01c304d0c 100644 --- a/group_test.go +++ b/group_test.go @@ -1,9 +1,9 @@ package echo import ( - "io/ioutil" "net/http" "net/http/httptest" + "os" "testing" "github.com/stretchr/testify/assert" @@ -32,7 +32,7 @@ func TestGroupFile(t *testing.T) { e := New() g := e.Group("/group") g.File("/walle", "_fixture/images/walle.png") - expectedData, err := ioutil.ReadFile("_fixture/images/walle.png") + expectedData, err := os.ReadFile("_fixture/images/walle.png") assert.Nil(t, err) req := httptest.NewRequest(http.MethodGet, "/group/walle", nil) rec := httptest.NewRecorder() diff --git a/middleware/body_dump.go b/middleware/body_dump.go index ebd0d0ab2..fa7891b16 100644 --- a/middleware/body_dump.go +++ b/middleware/body_dump.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" "io" - "io/ioutil" "net" "net/http" @@ -68,9 +67,9 @@ func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc { // Request reqBody := []byte{} if c.Request().Body != nil { // Read - reqBody, _ = ioutil.ReadAll(c.Request().Body) + reqBody, _ = io.ReadAll(c.Request().Body) } - c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset + c.Request().Body = io.NopCloser(bytes.NewBuffer(reqBody)) // Reset // Response resBody := new(bytes.Buffer) diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go index 533971a47..de1de3356 100644 --- a/middleware/body_dump_test.go +++ b/middleware/body_dump_test.go @@ -2,7 +2,7 @@ package middleware import ( "errors" - "io/ioutil" + "io" "net/http" "net/http/httptest" "strings" @@ -19,7 +19,7 @@ func TestBodyDump(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c echo.Context) error { - body, err := ioutil.ReadAll(c.Request().Body) + body, err := io.ReadAll(c.Request().Body) if err != nil { return err } diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index b891767c5..8ffed55a4 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -2,7 +2,7 @@ package middleware import ( "bytes" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -18,7 +18,7 @@ func TestBodyLimit(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) h := func(c echo.Context) error { - body, err := ioutil.ReadAll(c.Request().Body) + body, err := io.ReadAll(c.Request().Body) if err != nil { return err } @@ -67,18 +67,18 @@ func TestBodyLimitReader(t *testing.T) { } reader := &limitedReader{ BodyLimitConfig: config, - reader: ioutil.NopCloser(bytes.NewReader(hw)), + reader: io.NopCloser(bytes.NewReader(hw)), context: e.NewContext(req, rec), } // read all should return ErrStatusRequestEntityTooLarge - _, err := ioutil.ReadAll(reader) + _, err := io.ReadAll(reader) he := err.(*echo.HTTPError) assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) // reset reader and read two bytes must succeed bt := make([]byte, 2) - reader.Reset(ioutil.NopCloser(bytes.NewReader(hw)), e.NewContext(req, rec)) + reader.Reset(io.NopCloser(bytes.NewReader(hw)), e.NewContext(req, rec)) n, err := reader.Read(bt) assert.Equal(t, 2, n) assert.Equal(t, nil, err) diff --git a/middleware/compress.go b/middleware/compress.go index ac6672e9d..9e5f61069 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -4,7 +4,6 @@ import ( "bufio" "compress/gzip" "io" - "io/ioutil" "net" "net/http" "strings" @@ -89,7 +88,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { // nothing is written to body or error is returned. // See issue #424, #407. res.Writer = rw - w.Reset(ioutil.Discard) + w.Reset(io.Discard) } w.Close() pool.Put(w) @@ -135,7 +134,7 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { func gzipCompressPool(config GzipConfig) sync.Pool { return sync.Pool{ New: func() interface{} { - w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level) + w, err := gzip.NewWriterLevel(io.Discard, config.Level) if err != nil { return err } diff --git a/middleware/compress_test.go b/middleware/compress_test.go index c8dd2e1f4..714548e8b 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -4,9 +4,9 @@ import ( "bytes" "compress/gzip" "io" - "io/ioutil" "net/http" "net/http/httptest" + "os" "testing" "github.com/labstack/echo/v4" @@ -173,7 +173,7 @@ func TestGzipWithStatic(t *testing.T) { r, err := gzip.NewReader(rec.Body) if assert.NoError(t, err) { defer r.Close() - want, err := ioutil.ReadFile("../_fixture/images/walle.png") + want, err := os.ReadFile("../_fixture/images/walle.png") if assert.NoError(t, err) { buf := new(bytes.Buffer) buf.ReadFrom(r) diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 42c6250ac..2e73ba80e 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -4,7 +4,7 @@ import ( "bytes" "compress/gzip" "errors" - "io/ioutil" + "io" "net/http" "net/http/httptest" "strings" @@ -39,7 +39,7 @@ func TestDecompress(t *testing.T) { c = e.NewContext(req, rec) h(c) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) - b, err := ioutil.ReadAll(req.Body) + b, err := io.ReadAll(req.Body) assert.NoError(t, err) assert.Equal(t, body, string(b)) } @@ -67,7 +67,7 @@ func TestDecompressDefaultConfig(t *testing.T) { c = e.NewContext(req, rec) h(c) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) - b, err := ioutil.ReadAll(req.Body) + b, err := io.ReadAll(req.Body) assert.NoError(t, err) assert.Equal(t, body, string(b)) } @@ -82,7 +82,7 @@ func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { e.NewContext(req, rec) e.ServeHTTP(rec, req) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) - b, err := ioutil.ReadAll(req.Body) + b, err := io.ReadAll(req.Body) assert.NoError(t, err) assert.NotEqual(t, b, body) assert.Equal(t, b, gz) @@ -132,7 +132,7 @@ func TestDecompressSkipper(t *testing.T) { c := e.NewContext(req, rec) e.ServeHTTP(rec, req) assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8) - reqBody, err := ioutil.ReadAll(c.Request().Body) + reqBody, err := io.ReadAll(c.Request().Body) assert.NoError(t, err) assert.Equal(t, body, string(reqBody)) } @@ -161,7 +161,7 @@ func TestDecompressPoolError(t *testing.T) { c := e.NewContext(req, rec) e.ServeHTTP(rec, req) assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) - reqBody, err := ioutil.ReadAll(c.Request().Body) + reqBody, err := io.ReadAll(c.Request().Body) assert.NoError(t, err) assert.Equal(t, body, string(reqBody)) assert.Equal(t, rec.Code, http.StatusInternalServerError) diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 0ded50a1f..4b1dbef92 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -4,7 +4,7 @@ import ( "bytes" "context" "fmt" - "io/ioutil" + "io" "net" "net/http" "net/http/httptest" @@ -93,7 +93,7 @@ func TestProxy(t *testing.T) { e.Use(ProxyWithConfig(ProxyConfig{ Balancer: rrb, ModifyResponse: func(res *http.Response) error { - res.Body = ioutil.NopCloser(bytes.NewBuffer([]byte("modified"))) + res.Body = io.NopCloser(bytes.NewBuffer([]byte("modified"))) res.Header.Set("X-Modified", "1") return nil }, diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index 0ac04bb2f..47d707c30 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -1,7 +1,7 @@ package middleware import ( - "io/ioutil" + "io" "net/http" "net/http/httptest" "net/url" @@ -142,7 +142,7 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { assert.Equal(t, http.StatusOK, rec.Code) defer rec.Result().Body.Close() - bodyBytes, _ := ioutil.ReadAll(rec.Result().Body) + bodyBytes, _ := io.ReadAll(rec.Result().Body) assert.Equal(t, "hosts", string(bodyBytes)) } } diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index 56eb7bc74..dbac3fbc3 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -5,7 +5,7 @@ import ( "context" "errors" "fmt" - "io/ioutil" + "io" "log" "net" "net/http" @@ -410,7 +410,7 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { } assert.Equal(t, tc.expectStatusCode, res.StatusCode) - if body, err := ioutil.ReadAll(res.Body); err == nil { + if body, err := io.ReadAll(res.Body); err == nil { assert.Equal(t, tc.expectResponse, string(body)) } else { assert.Fail(t, err.Error()) From a0c211542ccfd9a1bc21dd676dfa4b38140f5e60 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Mon, 21 Nov 2022 16:05:30 +0200 Subject: [PATCH 256/446] Add staticcheck to CI flow --- .github/workflows/echo.yml | 16 +++++++++++----- Makefile | 2 ++ middleware/jwt.go | 2 +- middleware/proxy_test.go | 3 +-- middleware/timeout_test.go | 6 +++--- 5 files changed, 18 insertions(+), 11 deletions(-) diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index c2bd41e1b..e35e7f107 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -44,13 +44,19 @@ jobs: with: go-version: ${{ matrix.go }} - - name: Install Dependencies - run: go install golang.org/x/lint/golint@latest - - name: Run Tests + run: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... + + - name: Install dependencies for checks run: | - golint -set_exit_status ./... - go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... + go install golang.org/x/lint/golint@latest + go install honnef.co/go/tools/cmd/staticcheck@latest + + - name: Run golint + run: golint -set_exit_status ./... + + - name: Run staticcheck + run: staticcheck ./... - name: Upload coverage to Codecov if: success() && matrix.go == 1.19 && matrix.os == 'ubuntu-latest' diff --git a/Makefile b/Makefile index 3b7651983..6aff6a89f 100644 --- a/Makefile +++ b/Makefile @@ -10,8 +10,10 @@ check: lint vet race ## Check project init: @go install golang.org/x/lint/golint@latest + @go install honnef.co/go/tools/cmd/staticcheck@latest lint: ## Lint the files + @staticcheck ${PKG_LIST} @golint -set_exit_status ${PKG_LIST} vet: ## Vet the files diff --git a/middleware/jwt.go b/middleware/jwt.go index bec5167e2..ef6ad4ebc 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -262,7 +262,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { } func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) { - token := new(jwt.Token) + var token *jwt.Token var err error // Issue #647, #656 if _, ok := config.Claims.(jwt.MapClaims); ok { diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 4b1dbef92..a1b7f2cae 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -384,10 +384,9 @@ func TestProxyError(t *testing.T) { e := echo.New() e.Use(Proxy(rb)) req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() // Remote unreachable - rec = httptest.NewRecorder() + rec := httptest.NewRecorder() req.URL.Path = "/api/users" e.ServeHTTP(rec, req) assert.Equal(t, "/api/users", req.URL.Path) diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index dbac3fbc3..6f60753c6 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -129,7 +129,7 @@ func TestTimeoutOnTimeoutRouteErrorHandler(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) - stopChan := make(chan struct{}, 0) + stopChan := make(chan struct{}) err := m(func(c echo.Context) error { <-stopChan return errors.New("error in route after timeout") @@ -245,7 +245,7 @@ func TestTimeoutWithErrorMessage(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) - stopChan := make(chan struct{}, 0) + stopChan := make(chan struct{}) err := m(func(c echo.Context) error { // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output @@ -275,7 +275,7 @@ func TestTimeoutWithDefaultErrorMessage(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) - stopChan := make(chan struct{}, 0) + stopChan := make(chan struct{}) err := m(func(c echo.Context) error { <-stopChan return c.String(http.StatusOK, "Hello, World!") From 36ff0b3fbd9bbf406e91d37ff14c34f3c17c749f Mon Sep 17 00:00:00 2001 From: OHZEKI Naoki <0h23k1.n40k1@gmail.com> Date: Wed, 23 Nov 2022 09:13:50 +0900 Subject: [PATCH 257/446] Replace relative path determination from proprietary to std --- echo_fs.go | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/echo_fs.go b/echo_fs.go index b8526da9e..9f83a0351 100644 --- a/echo_fs.go +++ b/echo_fs.go @@ -7,7 +7,6 @@ import ( "net/url" "os" "path/filepath" - "runtime" "strings" ) @@ -125,7 +124,7 @@ func subFS(currentFs fs.FS, root string) (fs.FS, error) { // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs - if isRelativePath(root) { + if !filepath.IsAbs(root) { root = filepath.Join(dFS.prefix, root) } return &defaultFS{ @@ -136,21 +135,6 @@ func subFS(currentFs fs.FS, root string) (fs.FS, error) { return fs.Sub(currentFs, root) } -func isRelativePath(path string) bool { - if path == "" { - return true - } - if path[0] == '/' { - return false - } - if runtime.GOOS == "windows" && strings.IndexByte(path, ':') != -1 { - // https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file?redirectedfrom=MSDN#file_and_directory_names - // https://docs.microsoft.com/en-us/dotnet/standard/io/file-path-formats - return false - } - return true -} - // MustSubFS creates sub FS from current filesystem or panic on failure. // Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. // From 754479694694f6d0bd280d453c10d63348e4e3b7 Mon Sep 17 00:00:00 2001 From: Wim Date: Thu, 24 Nov 2022 21:17:31 +0100 Subject: [PATCH 258/446] Remove square brackets from ipv6 addresses in XFF (#2182) Remove square brackets from ipv6 addresses in XFF --- context.go | 7 ++++- context_test.go | 36 ++++++++++++++++++++++- ip.go | 7 ++++- ip_test.go | 78 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 125 insertions(+), 3 deletions(-) diff --git a/context.go b/context.go index 5567100b9..df2228a6f 100644 --- a/context.go +++ b/context.go @@ -282,11 +282,16 @@ func (c *context) RealIP() string { if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" { i := strings.IndexAny(ip, ",") if i > 0 { - return strings.TrimSpace(ip[:i]) + xffip := strings.TrimSpace(ip[:i]) + xffip = strings.TrimPrefix(xffip, "[") + xffip = strings.TrimSuffix(xffip, "]") + return xffip } return ip } if ip := c.request.Header.Get(HeaderXRealIP); ip != "" { + ip = strings.TrimPrefix(ip, "[") + ip = strings.TrimSuffix(ip, "]") return ip } ra, _, _ := net.SplitHostPort(c.request.RemoteAddr) diff --git a/context_test.go b/context_test.go index b25e11c60..11a63cfce 100644 --- a/context_test.go +++ b/context_test.go @@ -97,7 +97,6 @@ func (responseWriterErr) Write([]byte) (int, error) { } func (responseWriterErr) WriteHeader(statusCode int) { - } func TestContext(t *testing.T) { @@ -904,6 +903,30 @@ func TestContext_RealIP(t *testing.T) { }, "127.0.0.1", }, + { + &context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348], 2001:db8::1, "}}, + }, + }, + "2001:db8:85a3:8d3:1319:8a2e:370:7348", + }, + { + &context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348],[2001:db8::1]"}}, + }, + }, + "2001:db8:85a3:8d3:1319:8a2e:370:7348", + }, + { + &context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"2001:db8:85a3:8d3:1319:8a2e:370:7348"}}, + }, + }, + "2001:db8:85a3:8d3:1319:8a2e:370:7348", + }, { &context{ request: &http.Request{ @@ -914,6 +937,17 @@ func TestContext_RealIP(t *testing.T) { }, "192.168.0.1", }, + { + &context{ + request: &http.Request{ + Header: http.Header{ + "X-Real-Ip": []string{"[2001:db8::1]"}, + }, + }, + }, + "2001:db8::1", + }, + { &context{ request: &http.Request{ diff --git a/ip.go b/ip.go index 46d464cf9..1bcd756ae 100644 --- a/ip.go +++ b/ip.go @@ -227,6 +227,8 @@ func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { return func(req *http.Request) string { realIP := req.Header.Get(HeaderXRealIP) if realIP != "" { + realIP = strings.TrimPrefix(realIP, "[") + realIP = strings.TrimSuffix(realIP, "]") if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) { return realIP } @@ -248,7 +250,10 @@ func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor { } ips := append(strings.Split(strings.Join(xffs, ","), ","), directIP) for i := len(ips) - 1; i >= 0; i-- { - ip := net.ParseIP(strings.TrimSpace(ips[i])) + ips[i] = strings.TrimSpace(ips[i]) + ips[i] = strings.TrimPrefix(ips[i], "[") + ips[i] = strings.TrimSuffix(ips[i], "]") + ip := net.ParseIP(ips[i]) if ip == nil { // Unable to parse IP; cannot trust entire records return directIP diff --git a/ip_test.go b/ip_test.go index 755900d3d..38c4a1cac 100644 --- a/ip_test.go +++ b/ip_test.go @@ -459,6 +459,7 @@ func TestExtractIPDirect(t *testing.T) { func TestExtractIPFromRealIPHeader(t *testing.T) { _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") + _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64") var testCases = []struct { name string @@ -493,6 +494,16 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, expectIP: "203.0.113.1", }, + { + name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"[2001:db8::113:199]"}, // <-- this is untrusted + }, + RemoteAddr: "[2001:db8::113:1]:8080", + }, + expectIP: "2001:db8::113:1", + }, { name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" @@ -506,6 +517,19 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, expectIP: "203.0.113.199", }, + { + name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"[2001:db8::113:199]"}, + }, + RemoteAddr: "[2001:db8::113:1]:8080", + }, + expectIP: "2001:db8::113:199", + }, { name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" @@ -520,6 +544,20 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, expectIP: "203.0.113.199", }, + { + name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"[2001:db8::113:199]"}, + HeaderXForwardedFor: []string{"[2001:db8::113:198], [2001:db8::113:197]"}, // <-- should not affect anything + }, + RemoteAddr: "[2001:db8::113:1]:8080", + }, + expectIP: "2001:db8::113:199", + }, } for _, tc := range testCases { @@ -532,6 +570,7 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { func TestExtractIPFromXFFHeader(t *testing.T) { _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") + _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64") var testCases = []struct { name string @@ -566,6 +605,16 @@ func TestExtractIPFromXFFHeader(t *testing.T) { }, expectIP: "127.0.0.3", }, + { + name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"[fe80::3], [fe80::2], [fe80::1]"}, + }, + RemoteAddr: "[fe80::1]:8080", + }, + expectIP: "fe80::3", + }, { name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr", whenRequest: http.Request{ @@ -576,6 +625,16 @@ func TestExtractIPFromXFFHeader(t *testing.T) { }, expectIP: "203.0.113.1", }, + { + name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"[2001:db8::1]"}, // <-- this is untrusted + }, + RemoteAddr: "[2001:db8::2]:8080", + }, + expectIP: "2001:db8::2", + }, { name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header", givenTrustOptions: []TrustOption{ @@ -595,6 +654,25 @@ func TestExtractIPFromXFFHeader(t *testing.T) { }, expectIP: "203.0.100.100", // this is first trusted IP in XFF chain }, + { + name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header", + givenTrustOptions: []TrustOption{ + TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64" + }, + // from request its seems that request has been proxied through 6 servers. + // 1) 2001:db8:1::1:100 (this is external IP set by 2001:db8:2::100:100 which we do not trust - could be spoofed) + // 2) 2001:db8:2::100:100 (this is outside of our network but set by 2001:db8::113:199 which we trust to set correct IPs) + // 3) 2001:db8::113:199 (we trust, for example maybe our proxy from some other office) + // 4) fd12:3456:789a:1::1 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products) + // 5) fe80::1 (is proxy on localhost. maybe we have Nginx in front of our Echo instance doing some routing) + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"[2001:db8:1::1:100], [2001:db8:2::100:100], [2001:db8::113:199], [fd12:3456:789a:1::1]"}, + }, + RemoteAddr: "[fe80::1]:8080", // IP of proxy upstream of our APP + }, + expectIP: "2001:db8:2::100:100", // this is first trusted IP in XFF chain + }, } for _, tc := range testCases { From 466bf80e418b3a3a4f9adafa911a51e3ba317db0 Mon Sep 17 00:00:00 2001 From: Martti T Date: Fri, 25 Nov 2022 13:27:52 +0200 Subject: [PATCH 259/446] Add testcases for some BodyLimit middleware configuration options (#2350) * Add testcases for some BodyLimit middleware configuration options --- middleware/body_limit_test.go | 89 +++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index 8ffed55a4..2bfce372a 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -83,3 +83,92 @@ func TestBodyLimitReader(t *testing.T) { assert.Equal(t, 2, n) assert.Equal(t, nil, err) } + +func TestBodyLimitWithConfig_Skipper(t *testing.T) { + e := echo.New() + h := func(c echo.Context) error { + body, err := io.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + mw := BodyLimitWithConfig(BodyLimitConfig{ + Skipper: func(c echo.Context) bool { + return true + }, + Limit: "2B", // if not skipped this limit would make request to fail limit check + }) + + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) +} + +func TestBodyLimitWithConfig(t *testing.T) { + var testCases = []struct { + name string + givenLimit string + whenBody []byte + expectBody []byte + expectError string + }{ + { + name: "ok, body is less than limit", + givenLimit: "10B", + whenBody: []byte("123456789"), + expectBody: []byte("123456789"), + expectError: "", + }, + { + name: "nok, body is more than limit", + givenLimit: "9B", + whenBody: []byte("1234567890"), + expectBody: []byte(nil), + expectError: "code=413, message=Request Entity Too Large", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + h := func(c echo.Context) error { + body, err := io.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + mw := BodyLimitWithConfig(BodyLimitConfig{ + Limit: tc.givenLimit, + }) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(tc.whenBody)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := mw(h)(c) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + // not testing status as middlewares return error instead of committing it and OK cases are anyway 200 + assert.Equal(t, tc.expectBody, rec.Body.Bytes()) + }) + } +} + +func TestBodyLimit_panicOnInvalidLimit(t *testing.T) { + assert.PanicsWithError( + t, + "echo: invalid body-limit=", + func() { BodyLimit("") }, + ) +} From 8d4ac4c907696c1aecb237d088f9b648c00f274b Mon Sep 17 00:00:00 2001 From: Martti T Date: Wed, 30 Nov 2022 15:47:23 +0200 Subject: [PATCH 260/446] Additional configuration options for RequestLogger and Logger middleware (#2341) * Add `middleware.RequestLoggerConfig.HandleError` configuration option to handle error within middleware with global error handler thus setting response status code decided by error handler and not derived from error itself. * Add `middleware.LoggerConfig.CustomTagFunc` so Logger middleware can add custom text to logged row. --- context.go | 6 +- echo.go | 2 +- middleware/logger.go | 11 ++++ middleware/logger_test.go | 22 +++++++ middleware/request_logger.go | 102 ++++++++++++++++++++++-------- middleware/request_logger_test.go | 52 +++++++++++++-- 6 files changed, 162 insertions(+), 33 deletions(-) diff --git a/context.go b/context.go index df2228a6f..b3a7ce8d0 100644 --- a/context.go +++ b/context.go @@ -169,7 +169,11 @@ type ( // Redirect redirects the request to a provided URL with status code. Redirect(code int, url string) error - // Error invokes the registered HTTP error handler. Generally used by middleware. + // Error invokes the registered global HTTP error handler. Generally used by middleware. + // A side-effect of calling global error handler is that now Response has been committed (sent to the client) and + // middlewares up in chain can not change Response status code or Response body anymore. + // + // Avoid using this method in handlers as no middleware will be able to effectively handle errors after that. Error(err error) // Handler returns the matched handler by router. diff --git a/echo.go b/echo.go index 2b632c980..2f54d7711 100644 --- a/echo.go +++ b/echo.go @@ -116,7 +116,7 @@ type ( HandlerFunc func(c Context) error // HTTPErrorHandler is a centralized HTTP error handler. - HTTPErrorHandler func(error, Context) + HTTPErrorHandler func(err error, c Context) // Validator is the interface that wraps the Validate function. Validator interface { diff --git a/middleware/logger.go b/middleware/logger.go index a21df8f39..8bf335ffb 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -47,6 +47,7 @@ type ( // - header: // - query: // - form: + // - custom (see CustomTagFunc field) // // Example "${remote_ip} ${status}" // @@ -56,6 +57,11 @@ type ( // Optional. Default value DefaultLoggerConfig.CustomTimeFormat. CustomTimeFormat string `yaml:"custom_time_format"` + // CustomTagFunc is function called for `${custom}` tag to output user implemented text by writing it to buf. + // Make sure that outputted text creates valid JSON string with other logged tags. + // Optional. + CustomTagFunc func(c echo.Context, buf *bytes.Buffer) (int, error) + // Output is a writer where logs in JSON format are written. // Optional. Default value os.Stdout. Output io.Writer @@ -126,6 +132,11 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { switch tag { + case "custom": + if config.CustomTagFunc == nil { + return 0, nil + } + return config.CustomTagFunc(c, buf) case "time_unix": return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10)) case "time_unix_milli": diff --git a/middleware/logger_test.go b/middleware/logger_test.go index ab889bfda..a0568611e 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -173,6 +173,28 @@ func TestLoggerCustomTimestamp(t *testing.T) { assert.Error(t, err) } +func TestLoggerCustomTagFunc(t *testing.T) { + e := echo.New() + buf := new(bytes.Buffer) + e.Use(LoggerWithConfig(LoggerConfig{ + Format: `{"method":"${method}",${custom}}` + "\n", + CustomTagFunc: func(c echo.Context, buf *bytes.Buffer) (int, error) { + return buf.WriteString(`"tag":"my-value"`) + }, + Output: buf, + })) + + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "custom time stamp test") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, `{"method":"GET","tag":"my-value"}`+"\n", buf.String()) +} + func BenchmarkLoggerWithConfig_withoutMapFields(b *testing.B) { e := echo.New() diff --git a/middleware/request_logger.go b/middleware/request_logger.go index 7a4d9822e..b9e369255 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -10,10 +10,16 @@ import ( // Example for `fmt.Printf` // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ -// LogStatus: true, -// LogURI: true, +// LogStatus: true, +// LogURI: true, +// LogError: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { -// fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status) +// if v.Error == nil { +// fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status) +// } else { +// fmt.Printf("REQUEST_ERROR: uri: %v, status: %v, err: %v\n", v.URI, v.Status, v.Error) +// } // return nil // }, // })) @@ -21,14 +27,23 @@ import ( // Example for Zerolog (https://github.com/rs/zerolog) // logger := zerolog.New(os.Stdout) // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ -// LogURI: true, -// LogStatus: true, +// LogURI: true, +// LogStatus: true, +// LogError: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { -// logger.Info(). -// Str("URI", v.URI). -// Int("status", v.Status). -// Msg("request") -// +// if v.Error == nil { +// logger.Info(). +// Str("URI", v.URI). +// Int("status", v.Status). +// Msg("request") +// } else { +// logger.Error(). +// Err(v.Error). +// Str("URI", v.URI). +// Int("status", v.Status). +// Msg("request error") +// } // return nil // }, // })) @@ -36,29 +51,47 @@ import ( // Example for Zap (https://github.com/uber-go/zap) // logger, _ := zap.NewProduction() // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ -// LogURI: true, -// LogStatus: true, +// LogURI: true, +// LogStatus: true, +// LogError: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code // LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { -// logger.Info("request", -// zap.String("URI", v.URI), -// zap.Int("status", v.Status), -// ) -// +// if v.Error == nil { +// logger.Info("request", +// zap.String("URI", v.URI), +// zap.Int("status", v.Status), +// ) +// } else { +// logger.Error("request error", +// zap.String("URI", v.URI), +// zap.Int("status", v.Status), +// zap.Error(v.Error), +// ) +// } // return nil // }, // })) // // Example for Logrus (https://github.com/sirupsen/logrus) -// log := logrus.New() +// log := logrus.New() // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ -// LogURI: true, -// LogStatus: true, -// LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error { -// log.WithFields(logrus.Fields{ -// "URI": values.URI, -// "status": values.Status, -// }).Info("request") -// +// LogURI: true, +// LogStatus: true, +// LogError: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code +// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { +// if v.Error == nil { +// log.WithFields(logrus.Fields{ +// "URI": v.URI, +// "status": v.Status, +// }).Info("request") +// } else { +// log.WithFields(logrus.Fields{ +// "URI": v.URI, +// "status": v.Status, +// "error": v.Error, +// }).Error("request error") +// } // return nil // }, // })) @@ -74,6 +107,13 @@ type RequestLoggerConfig struct { // Mandatory. LogValuesFunc func(c echo.Context, v RequestLoggerValues) error + // HandleError instructs logger to call global error handler when next middleware/handler returns an error. + // This is useful when you have custom error handler that can decide to use different status codes. + // + // A side-effect of calling global error handler is that now Response has been committed and sent to the client + // and middlewares up in chain can not change Response status code or response body. + HandleError bool + // LogLatency instructs logger to record duration it took to execute rest of the handler chain (next(c) call). LogLatency bool // LogProtocol instructs logger to extract request protocol (i.e. `HTTP/1.1` or `HTTP/2`) @@ -217,6 +257,9 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { config.BeforeNextFunc(c) } err := next(c) + if config.HandleError { + c.Error(err) + } v := RequestLoggerValues{ StartTime: start, @@ -264,7 +307,9 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { } if config.LogStatus { v.Status = res.Status - if err != nil { + if err != nil && !config.HandleError { + // this block should not be executed in case of HandleError=true as the global error handler will decide + // the status code. In that case status code could be different from what err contains. var httpErr *echo.HTTPError if errors.As(err, &httpErr) { v.Status = httpErr.Code @@ -310,6 +355,9 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { return errOnLog } + // in case of HandleError=true we are returning the error that we already have handled with global error handler + // this is deliberate as this error could be useful for upstream middlewares and default global error handler + // will ignore that error when it bubbles up in middleware chain. return err } }, nil diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go index 5118b1216..51d617abb 100644 --- a/middleware/request_logger_test.go +++ b/middleware/request_logger_test.go @@ -103,12 +103,12 @@ func TestRequestLogger_beforeNextFunc(t *testing.T) { func TestRequestLogger_logError(t *testing.T) { e := echo.New() - var expect RequestLoggerValues + var actual RequestLoggerValues e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ LogError: true, LogStatus: true, LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { - expect = values + actual = values return nil }, })) @@ -123,8 +123,52 @@ func TestRequestLogger_logError(t *testing.T) { e.ServeHTTP(rec, req) assert.Equal(t, http.StatusNotAcceptable, rec.Code) - assert.Equal(t, http.StatusNotAcceptable, expect.Status) - assert.EqualError(t, expect.Error, "code=406, message=nope") + assert.Equal(t, http.StatusNotAcceptable, actual.Status) + assert.EqualError(t, actual.Error, "code=406, message=nope") +} + +func TestRequestLogger_HandleError(t *testing.T) { + e := echo.New() + + var actual RequestLoggerValues + e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + timeNow: func() time.Time { + return time.Unix(1631045377, 0).UTC() + }, + HandleError: true, + LogError: true, + LogStatus: true, + LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + actual = values + return nil + }, + })) + + // to see if "HandleError" works we create custom error handler that uses its own status codes + e.HTTPErrorHandler = func(err error, c echo.Context) { + if c.Response().Committed { + return + } + c.JSON(http.StatusTeapot, "custom error handler") + } + + e.GET("/test", func(c echo.Context) error { + return echo.NewHTTPError(http.StatusForbidden, "nope") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + + expect := RequestLoggerValues{ + StartTime: time.Unix(1631045377, 0).UTC(), + Status: http.StatusTeapot, + Error: echo.NewHTTPError(http.StatusForbidden, "nope"), + } + assert.Equal(t, expect, actual) } func TestRequestLogger_LogValuesFuncError(t *testing.T) { From 135c511f5dd1b6cecbbc632d01f2640075e2c4d4 Mon Sep 17 00:00:00 2001 From: Kanji Yomoda Date: Mon, 5 Dec 2022 03:38:45 +0900 Subject: [PATCH 261/446] Add request route with "route" tag to logger middleware (#2162) --- middleware/logger.go | 3 +++ middleware/logger_test.go | 9 +++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/middleware/logger.go b/middleware/logger.go index 8bf335ffb..7958d873b 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -35,6 +35,7 @@ type ( // - host // - method // - path + // - route // - protocol // - referer // - user_agent @@ -173,6 +174,8 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { p = "/" } return buf.WriteString(p) + case "route": + return buf.WriteString(c.Path()) case "protocol": return buf.WriteString(req.Proto) case "referer": diff --git a/middleware/logger_test.go b/middleware/logger_test.go index a0568611e..9f35a70bc 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -92,17 +92,17 @@ func TestLoggerTemplate(t *testing.T) { e.Use(LoggerWithConfig(LoggerConfig{ Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` + `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + - `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` + + `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "route":"${route}", "referer":"${referer}",` + `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}", "protocol":"${protocol}"` + `"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n", Output: buf, })) - e.GET("/", func(c echo.Context) error { + e.GET("/users/:id", func(c echo.Context) error { return c.String(http.StatusOK, "Header Logged") }) - req := httptest.NewRequest(http.MethodGet, "/?username=apagano-param&password=secret", nil) + req := httptest.NewRequest(http.MethodGet, "/users/1?username=apagano-param&password=secret", nil) req.RequestURI = "/" req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") req.Header.Add("Referer", "google.com") @@ -127,7 +127,8 @@ func TestLoggerTemplate(t *testing.T) { "hexvalue": false, "GET": true, "127.0.0.1": true, - "\"path\":\"/\"": true, + "\"path\":\"/users/1\"": true, + "\"route\":\"/users/:id\"": true, "\"uri\":\"/\"": true, "\"status\":200": true, "\"bytes_in\":0": true, From 40eb889d14001640d8e9c48f22db8388dcf0feb5 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 8 Dec 2022 19:37:15 +0200 Subject: [PATCH 262/446] build: harden echo.yml permissions Signed-off-by: Alex --- .github/workflows/echo.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index e35e7f107..7a2db7a9a 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -21,6 +21,9 @@ on: - 'codecov.yml' workflow_dispatch: +permissions: + contents: read # to fetch code (actions/checkout) + jobs: test: strategy: From bc75cc2b17254ef70f9d0dab263a5fb81ae8520e Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Tue, 13 Dec 2022 10:38:20 +0200 Subject: [PATCH 263/446] Add govulncheck to CI and bump dependencies. Refactor GitHub workflows. --- .github/workflows/checks.yml | 48 ++++++++++++++++++++++++++++++++++++ .github/workflows/echo.yml | 42 +++++++------------------------ go.mod | 6 ++--- go.sum | 10 +++++--- 4 files changed, 67 insertions(+), 39 deletions(-) create mode 100644 .github/workflows/checks.yml diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml new file mode 100644 index 000000000..907b2858a --- /dev/null +++ b/.github/workflows/checks.yml @@ -0,0 +1,48 @@ +name: Run checks + +on: + push: + branches: + - master + pull_request: + branches: + - master + workflow_dispatch: + +permissions: + contents: read # to fetch code (actions/checkout) + +env: + # run static analysis only with the latest Go version + LATEST_GO_VERSION: 1.19 + +jobs: + check: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@v3 + + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v3 + with: + go-version: ${{ env.LATEST_GO_VERSION }} + check-latest: true + + - name: Run golint + run: | + go install golang.org/x/lint/golint@latest + golint -set_exit_status ./... + + - name: Run staticcheck + run: | + go install honnef.co/go/tools/cmd/staticcheck@latest + staticcheck ./... + + - name: Run govulncheck + run: | + go version + go install golang.org/x/vuln/cmd/govulncheck@latest + govulncheck ./... + + diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 7a2db7a9a..e41c80ab7 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -4,26 +4,18 @@ on: push: branches: - master - paths: - - '**.go' - - 'go.*' - - '_fixture/**' - - '.github/**' - - 'codecov.yml' pull_request: branches: - master - paths: - - '**.go' - - 'go.*' - - '_fixture/**' - - '.github/**' - - 'codecov.yml' workflow_dispatch: permissions: contents: read # to fetch code (actions/checkout) +env: + # run coverage and benchmarks only with the latest Go version + LATEST_GO_VERSION: 1.19 + jobs: test: strategy: @@ -39,8 +31,6 @@ jobs: steps: - name: Checkout Code uses: actions/checkout@v3 - with: - ref: ${{ github.ref }} - name: Set up Go ${{ matrix.go }} uses: actions/setup-go@v3 @@ -50,31 +40,17 @@ jobs: - name: Run Tests run: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... - - name: Install dependencies for checks - run: | - go install golang.org/x/lint/golint@latest - go install honnef.co/go/tools/cmd/staticcheck@latest - - - name: Run golint - run: golint -set_exit_status ./... - - - name: Run staticcheck - run: staticcheck ./... - - name: Upload coverage to Codecov - if: success() && matrix.go == 1.19 && matrix.os == 'ubuntu-latest' + if: success() && matrix.go == env.LATEST_GO_VERSION && matrix.os == 'ubuntu-latest' uses: codecov/codecov-action@v3 with: token: fail_ci_if_error: false + benchmark: needs: test - strategy: - matrix: - os: [ubuntu-latest] - go: [1.19] - name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }} - runs-on: ${{ matrix.os }} + name: Benchmark comparison + runs-on: ubuntu-latest steps: - name: Checkout Code (Previous) uses: actions/checkout@v3 @@ -90,7 +66,7 @@ jobs: - name: Set up Go ${{ matrix.go }} uses: actions/setup-go@v3 with: - go-version: ${{ matrix.go }} + go-version: ${{ env.LATEST_GO_VERSION }} - name: Install Dependencies run: go install golang.org/x/perf/cmd/benchstat@latest diff --git a/go.mod b/go.mod index 73fd6d900..3b833310d 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/stretchr/testify v1.8.1 github.com/valyala/fasttemplate v1.2.2 golang.org/x/crypto v0.2.0 - golang.org/x/net v0.2.0 + golang.org/x/net v0.4.0 golang.org/x/time v0.2.0 ) @@ -18,7 +18,7 @@ require ( github.com/mattn/go-isatty v0.0.16 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.2.0 // indirect - golang.org/x/text v0.4.0 // indirect + golang.org/x/sys v0.3.0 // indirect + golang.org/x/text v0.5.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index b052ff9d6..825c35155 100644 --- a/go.sum +++ b/go.sum @@ -35,8 +35,9 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= +golang.org/x/net v0.4.0 h1:Q5QPcMlvfxFTAPV0+07Xz/MpK9NTXu2VDUuy0FeMfaU= +golang.org/x/net v0.4.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -48,16 +49,19 @@ golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= +golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= +golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= +golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.2.0 h1:52I/1L54xyEQAYdtcSuxtiT84KGYTBGXwayxmIpNJhE= golang.org/x/time v0.2.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= From 895121d178019edbce8826cb13abe3be3af64124 Mon Sep 17 00:00:00 2001 From: yagikota Date: Sat, 17 Dec 2022 18:16:00 +0900 Subject: [PATCH 264/446] Fix rate limiter docs (#2366) * Improve wording for the comment of Burst * Improve rate limiter docs --- middleware/rate_limiter.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index be2b348db..f7fae83c6 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -155,7 +155,7 @@ type ( RateLimiterMemoryStore struct { visitors map[string]*Visitor mutex sync.Mutex - rate rate.Limit //for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. + rate rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. burst int expiresIn time.Duration @@ -170,15 +170,16 @@ type ( /* NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with -the provided rate (as req/s). The provided rate less than 1 will be treated as zero. +the provided rate (as req/s). for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. Burst and ExpiresIn will be set to default values. +Note that if the provided rate is a float number and Burst is zero, Burst will be treated as the rounded down value of the rate. + Example (with 20 requests/sec): limiterStore := middleware.NewRateLimiterMemoryStore(20) - */ func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) { return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ @@ -188,7 +189,7 @@ func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) /* NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore -with the provided configuration. Rate must be provided. Burst will be set to the value of +with the provided configuration. Rate must be provided. Burst will be set to the rounded down value of the configured rate if not provided or set to 0. The build-in memory store is usually capable for modest loads. For higher loads other @@ -225,7 +226,7 @@ func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (s // RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore type RateLimiterMemoryStoreConfig struct { Rate rate.Limit // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. - Burst int // Burst additionally allows a number of requests to pass when rate limit is reached + Burst int // Burst is maximum number of requests to pass at the same moment. It additionally allows a number of requests to pass when rate limit is reached. ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up } From f1cf1ec930e388333798c133629afa18dad00241 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 12 Nov 2022 18:35:19 +0200 Subject: [PATCH 265/446] Fix adding route with host overwrites default host route with same method+path in list of routes. --- echo.go | 44 +++------------- echo_test.go | 88 +++++++++++++++++++++++++------- router.go | 46 +++++++++++++++++ router_test.go | 136 +++++++++++++++++++++++++++++++++++++++---------- 4 files changed, 232 insertions(+), 82 deletions(-) diff --git a/echo.go b/echo.go index 2f54d7711..fc2f556ef 100644 --- a/echo.go +++ b/echo.go @@ -37,7 +37,6 @@ Learn more at https://echo.labstack.com package echo import ( - "bytes" stdContext "context" "crypto/tls" "errors" @@ -528,20 +527,13 @@ func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route { } func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - name := handlerName(handler) router := e.findRouter(host) - // FIXME: when handler+middleware are both nil ... make it behave like handler removal - router.Add(method, path, func(c Context) error { + //FIXME: when handler+middleware are both nil ... make it behave like handler removal + name := handlerName(handler) + return router.add(method, path, name, func(c Context) error { h := applyMiddleware(handler, middleware...) return h(c) }) - r := &Route{ - Method: method, - Path: path, - Name: name, - } - e.router.routes[method+path] = r - return r } // Add registers a new route for an HTTP method and path with matching handler @@ -578,35 +570,13 @@ func (e *Echo) URL(h HandlerFunc, params ...interface{}) string { // Reverse generates an URL from route name and provided parameters. func (e *Echo) Reverse(name string, params ...interface{}) string { - uri := new(bytes.Buffer) - ln := len(params) - n := 0 - for _, r := range e.router.routes { - if r.Name == name { - for i, l := 0, len(r.Path); i < l; i++ { - if (r.Path[i] == ':' || r.Path[i] == '*') && n < ln { - for ; i < l && r.Path[i] != '/'; i++ { - } - uri.WriteString(fmt.Sprintf("%v", params[n])) - n++ - } - if i < l { - uri.WriteByte(r.Path[i]) - } - } - break - } - } - return uri.String() + return e.router.Reverse(name, params...) } -// Routes returns the registered routes. +// Routes returns the registered routes for default router. +// In case when Echo serves multiple hosts/domains use `e.Routers()["domain2.site"].Routes()` to get specific host routes. func (e *Echo) Routes() []*Route { - routes := make([]*Route, 0, len(e.router.routes)) - for _, v := range e.router.routes { - routes = append(routes, v) - } - return routes + return e.router.Routes() } // AcquireContext returns an empty `Context` instance from the pool. diff --git a/echo_test.go b/echo_test.go index b0d1ccd28..250396928 100644 --- a/echo_test.go +++ b/echo_test.go @@ -530,9 +530,9 @@ func TestEchoRoutes(t *testing.T) { } } -func TestEchoRoutesHandleHostsProperly(t *testing.T) { +func TestEchoRoutesHandleAdditionalHosts(t *testing.T) { e := New() - h := e.Host("route.com") + domain2Router := e.Host("domain2.router.com") routes := []*Route{ {http.MethodGet, "/users/:user/events", ""}, {http.MethodGet, "/users/:user/events/public", ""}, @@ -540,24 +540,61 @@ func TestEchoRoutesHandleHostsProperly(t *testing.T) { {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, } for _, r := range routes { - h.Add(r.Method, r.Path, func(c Context) error { + domain2Router.Add(r.Method, r.Path, func(c Context) error { return c.String(http.StatusOK, "OK") }) } + e.Add(http.MethodGet, "/api", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) - if assert.Equal(t, len(routes), len(e.Routes())) { - for _, r := range e.Routes() { - found := false - for _, rr := range routes { - if r.Method == rr.Method && r.Path == rr.Path { - found = true - break - } + domain2Routes := e.Routers()["domain2.router.com"].Routes() + + assert.Len(t, domain2Routes, len(routes)) + for _, r := range domain2Routes { + found := false + for _, rr := range routes { + if r.Method == rr.Method && r.Path == rr.Path { + found = true + break } - if !found { - t.Errorf("Route %s %s not found", r.Method, r.Path) + } + if !found { + t.Errorf("Route %s %s not found", r.Method, r.Path) + } + } +} + +func TestEchoRoutesHandleDefaultHost(t *testing.T) { + e := New() + routes := []*Route{ + {http.MethodGet, "/users/:user/events", ""}, + {http.MethodGet, "/users/:user/events/public", ""}, + {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, + {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, + } + for _, r := range routes { + e.Add(r.Method, r.Path, func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + } + e.Host("subdomain.mysite.site").Add(http.MethodGet, "/api", func(c Context) error { + return c.String(http.StatusOK, "OK") + }) + + defaultRouterRoutes := e.Routes() + assert.Len(t, defaultRouterRoutes, len(routes)) + for _, r := range defaultRouterRoutes { + found := false + for _, rr := range routes { + if r.Method == rr.Method && r.Path == rr.Path { + found = true + break } } + if !found { + t.Errorf("Route %s %s not found", r.Method, r.Path) + } } } @@ -1468,14 +1505,27 @@ func TestEchoReverseHandleHostProperly(t *testing.T) { dummyHandler := func(Context) error { return nil } e := New() + + // routes added to the default router are different form different hosts + e.GET("/static", dummyHandler).Name = "default-host /static" + e.GET("/static/*", dummyHandler).Name = "xxx" + + // different host h := e.Host("the_host") - h.GET("/static", dummyHandler).Name = "/static" - h.GET("/static/*", dummyHandler).Name = "/static/*" + h.GET("/static", dummyHandler).Name = "host2 /static" + h.GET("/static/v2/*", dummyHandler).Name = "xxx" + + assert.Equal(t, "/static", e.Reverse("default-host /static")) + // when actual route does not have params and we provide some to Reverse we should get that route url back + assert.Equal(t, "/static", e.Reverse("default-host /static", "missing param")) + + host2Router := e.Routers()["the_host"] + assert.Equal(t, "/static", host2Router.Reverse("host2 /static")) + assert.Equal(t, "/static", host2Router.Reverse("host2 /static", "missing param")) + + assert.Equal(t, "/static/v2/*", host2Router.Reverse("xxx")) + assert.Equal(t, "/static/v2/foo.txt", host2Router.Reverse("xxx", "foo.txt")) - assert.Equal(t, "/static", e.Reverse("/static")) - assert.Equal(t, "/static", e.Reverse("/static", "missing param")) - assert.Equal(t, "/static/*", e.Reverse("/static/*")) - assert.Equal(t, "/static/foo.txt", e.Reverse("/static/*", "foo.txt")) } func TestEcho_ListenerAddr(t *testing.T) { diff --git a/router.go b/router.go index 23c5bd3ba..86a986a29 100644 --- a/router.go +++ b/router.go @@ -2,6 +2,7 @@ package echo import ( "bytes" + "fmt" "net/http" ) @@ -141,6 +142,51 @@ func NewRouter(e *Echo) *Router { } } +// Routes returns the registered routes. +func (r *Router) Routes() []*Route { + routes := make([]*Route, 0, len(r.routes)) + for _, v := range r.routes { + routes = append(routes, v) + } + return routes +} + +// Reverse generates an URL from route name and provided parameters. +func (r *Router) Reverse(name string, params ...interface{}) string { + uri := new(bytes.Buffer) + ln := len(params) + n := 0 + for _, route := range r.routes { + if route.Name == name { + for i, l := 0, len(route.Path); i < l; i++ { + if (route.Path[i] == ':' || route.Path[i] == '*') && n < ln { + for ; i < l && route.Path[i] != '/'; i++ { + } + uri.WriteString(fmt.Sprintf("%v", params[n])) + n++ + } + if i < l { + uri.WriteByte(route.Path[i]) + } + } + break + } + } + return uri.String() +} + +func (r *Router) add(method, path, name string, h HandlerFunc) *Route { + r.Add(method, path, h) + + route := &Route{ + Method: method, + Path: path, + Name: name, + } + r.routes[method+path] = route + return route +} + // Add registers a new route for method and path with matching handler. func (r *Router) Add(method, path string, h HandlerFunc) { // Validate path diff --git a/router_test.go b/router_test.go index a95421011..825170a3f 100644 --- a/router_test.go +++ b/router_test.go @@ -914,19 +914,22 @@ func TestRouterParamWithSlash(t *testing.T) { // Searching route for "/a/c/f" should match "/a/*/f" // When route `4) /a/*/f` is not added then request for "/a/c/f" should match "/:e/c/f" // -// +----------+ -// +-----+ "/" root +--------------------+--------------------------+ -// | +----------+ | | -// | | | -// +-------v-------+ +---v---------+ +-------v---+ -// | "a/" (static) +---------------+ | ":" (param) | | "*" (any) | -// +-+----------+--+ | +-----------+-+ +-----------+ -// | | | | +// +----------+ +// +-----+ "/" root +--------------------+--------------------------+ +// | +----------+ | | +// | | | +// +-------v-------+ +---v---------+ +-------v---+ +// | "a/" (static) +---------------+ | ":" (param) | | "*" (any) | +// +-+----------+--+ | +-----------+-+ +-----------+ +// | | | | +// // +---------------v+ +-- ---v------+ +------v----+ +-----v-----------+ // | "c/d" (static) | | ":" (param) | | "*" (any) | | "/c/f" (static) | // +---------+------+ +--------+----+ +----------++ +-----------------+ -// | | | -// | | | +// +// | | | +// | | | +// // +---------v----+ +------v--------+ +------v--------+ // | "f" (static) | | "/c" (static) | | "/f" (static) | // +--------------+ +---------------+ +---------------+ @@ -998,22 +1001,22 @@ func TestRouteMultiLevelBacktracking(t *testing.T) { // // Request for "/a/c/f" should match "/:e/c/f" // -// +-0,7--------+ -// | "/" (root) |----------------------------------+ -// +------------+ | -// | | | -// | | | -// +-1,6-----------+ | | +-8-----------+ +------v----+ -// | "a/" (static) +<--+ +--------->+ ":" (param) | | "*" (any) | -// +---------------+ +-------------+ +-----------+ -// | | | -// +-2--------v-----+ +v-3,5--------+ +-9------v--------+ -// | "c/d" (static) | | ":" (param) | | "/c/f" (static) | -// +----------------+ +-------------+ +-----------------+ -// | -// +-4--v----------+ -// | "/c" (static) | -// +---------------+ +// +-0,7--------+ +// | "/" (root) |----------------------------------+ +// +------------+ | +// | | | +// | | | +// +-1,6-----------+ | | +-8-----------+ +------v----+ +// | "a/" (static) +<--+ +--------->+ ":" (param) | | "*" (any) | +// +---------------+ +-------------+ +-----------+ +// | | | +// +-2--------v-----+ +v-3,5--------+ +-9------v--------+ +// | "c/d" (static) | | ":" (param) | | "/c/f" (static) | +// +----------------+ +-------------+ +-----------------+ +// | +// +-4--v----------+ +// | "/c" (static) | +// +---------------+ func TestRouteMultiLevelBacktracking2(t *testing.T) { e := New() r := e.router @@ -2695,6 +2698,87 @@ func TestRouterHandleMethodOptions(t *testing.T) { } } +func TestRouter_Routes(t *testing.T) { + type rr struct { + method string + path string + name string + } + var testCases = []struct { + name string + givenRoutes []rr + expect []rr + }{ + { + name: "ok, multiple", + givenRoutes: []rr{ + {method: http.MethodGet, path: "/static", name: "/static"}, + {method: http.MethodGet, path: "/static/*", name: "/static/*"}, + }, + expect: []rr{ + {method: http.MethodGet, path: "/static", name: "/static"}, + {method: http.MethodGet, path: "/static/*", name: "/static/*"}, + }, + }, + { + name: "ok, no routes", + givenRoutes: []rr{}, + expect: []rr{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dummyHandler := func(Context) error { return nil } + + e := New() + route := e.router + + for _, tmp := range tc.givenRoutes { + route.add(tmp.method, tmp.path, tmp.name, dummyHandler) + } + + // Add does not add route. because of backwards compatibility we can not change this method signature + route.Add("LOCK", "/users", handlerFunc) + + result := route.Routes() + assert.Len(t, result, len(tc.expect)) + for _, r := range result { + for _, tmp := range tc.expect { + if tmp.name == r.Name { + assert.Equal(t, tmp.method, r.Method) + assert.Equal(t, tmp.path, r.Path) + } + } + } + }) + } +} + +func TestRouter_Reverse(t *testing.T) { + e := New() + r := e.router + dummyHandler := func(Context) error { return nil } + + r.add(http.MethodGet, "/static", "/static", dummyHandler) + r.add(http.MethodGet, "/static/*", "/static/*", dummyHandler) + r.add(http.MethodGet, "/params/:foo", "/params/:foo", dummyHandler) + r.add(http.MethodGet, "/params/:foo/bar/:qux", "/params/:foo/bar/:qux", dummyHandler) + r.add(http.MethodGet, "/params/:foo/bar/:qux/*", "/params/:foo/bar/:qux/*", dummyHandler) + + assert.Equal(t, "/static", r.Reverse("/static")) + assert.Equal(t, "/static", r.Reverse("/static", "missing param")) + assert.Equal(t, "/static/*", r.Reverse("/static/*")) + assert.Equal(t, "/static/foo.txt", r.Reverse("/static/*", "foo.txt")) + + assert.Equal(t, "/params/:foo", r.Reverse("/params/:foo")) + assert.Equal(t, "/params/one", r.Reverse("/params/:foo", "one")) + assert.Equal(t, "/params/:foo/bar/:qux", r.Reverse("/params/:foo/bar/:qux")) + assert.Equal(t, "/params/one/bar/:qux", r.Reverse("/params/:foo/bar/:qux", "one")) + assert.Equal(t, "/params/one/bar/two", r.Reverse("/params/:foo/bar/:qux", "one", "two")) + assert.Equal(t, "/params/one/bar/two/three", r.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) +} + func TestRouterAllowHeaderForAnyOtherMethodType(t *testing.T) { e := New() r := e.router From 45402bb393fa37386fcc5f9127eaff9dc18f67c6 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 12 Nov 2022 19:38:02 +0200 Subject: [PATCH 266/446] Add echo.OnAddRouteHandler field. As name says - this handler is called when new route is registered. --- echo.go | 39 ++++++++++++++++++++++++++++----------- echo_test.go | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 11 deletions(-) diff --git a/echo.go b/echo.go index fc2f556ef..e3e1c0370 100644 --- a/echo.go +++ b/echo.go @@ -61,20 +61,28 @@ import ( type ( // Echo is the top-level framework instance. + // + // Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these + // fields from handlers/middlewares and changing field values at the same time leads to data-races. + // Same rule applies to adding new routes after server has been started - Adding a route is not Goroutine safe action. Echo struct { filesystem common // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get // listener address info (on which interface/port was listener binded) without having data races. - startupMutex sync.RWMutex + startupMutex sync.RWMutex + colorer *color.Color + + // premiddleware are middlewares that are run before routing is done. In case pre-middleware returns an error router + // will not be called at all and execution ends up in global error handler. + premiddleware []MiddlewareFunc + middleware []MiddlewareFunc + maxParam *int + router *Router + routers map[string]*Router + pool sync.Pool + StdLogger *stdLog.Logger - colorer *color.Color - premiddleware []MiddlewareFunc - middleware []MiddlewareFunc - maxParam *int - router *Router - routers map[string]*Router - pool sync.Pool Server *http.Server TLSServer *http.Server Listener net.Listener @@ -92,6 +100,9 @@ type ( Logger Logger IPExtractor IPExtractor ListenerNetwork string + + // OnAddRouteHandler is called when Echo adds new route to specific host router. + OnAddRouteHandler func(host string, route Route, handler HandlerFunc, middleware []MiddlewareFunc) } // Route contains a handler and information for matching against requests. @@ -526,14 +537,20 @@ func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route { return e.file(path, file, e.GET, m...) } -func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { +func (e *Echo) add(host, method, path string, handler HandlerFunc, middlewares ...MiddlewareFunc) *Route { router := e.findRouter(host) //FIXME: when handler+middleware are both nil ... make it behave like handler removal name := handlerName(handler) - return router.add(method, path, name, func(c Context) error { - h := applyMiddleware(handler, middleware...) + route := router.add(method, path, name, func(c Context) error { + h := applyMiddleware(handler, middlewares...) return h(c) }) + + if e.OnAddRouteHandler != nil { + e.OnAddRouteHandler(host, *route, handler, middlewares) + } + + return route } // Add registers a new route for an HTTP method and path with matching handler diff --git a/echo_test.go b/echo_test.go index 250396928..2f66c8c6c 100644 --- a/echo_test.go +++ b/echo_test.go @@ -1478,6 +1478,44 @@ func TestEchoListenerNetworkInvalid(t *testing.T) { assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323")) } +func TestEcho_OnAddRouteHandler(t *testing.T) { + type rr struct { + host string + route Route + handler HandlerFunc + middleware []MiddlewareFunc + } + dummyHandler := func(Context) error { return nil } + e := New() + + added := make([]rr, 0) + e.OnAddRouteHandler = func(host string, route Route, handler HandlerFunc, middleware []MiddlewareFunc) { + added = append(added, rr{ + host: host, + route: route, + handler: handler, + middleware: middleware, + }) + } + + e.GET("/static", NotFoundHandler) + e.Host("domain.site").GET("/static/*", dummyHandler, func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + return next(c) + } + }) + + assert.Len(t, added, 2) + + assert.Equal(t, "", added[0].host) + assert.Equal(t, Route{Method: http.MethodGet, Path: "/static", Name: "github.com/labstack/echo/v4.glob..func1"}, added[0].route) + assert.Len(t, added[0].middleware, 0) + + assert.Equal(t, "domain.site", added[1].host) + assert.Equal(t, Route{Method: http.MethodGet, Path: "/static/*", Name: "github.com/labstack/echo/v4.TestEcho_OnAddRouteHandler.func1"}, added[1].route) + assert.Len(t, added[1].middleware, 1) +} + func TestEchoReverse(t *testing.T) { e := New() dummyHandler := func(Context) error { return nil } From 0056cc8ec00f393d68fd742ea8c1a3c3053d80db Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 3 Dec 2022 19:34:21 +0200 Subject: [PATCH 267/446] Improve comments wording --- echo.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/echo.go b/echo.go index e3e1c0370..e3e7b2fe0 100644 --- a/echo.go +++ b/echo.go @@ -64,7 +64,7 @@ type ( // // Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these // fields from handlers/middlewares and changing field values at the same time leads to data-races. - // Same rule applies to adding new routes after server has been started - Adding a route is not Goroutine safe action. + // Adding new routes after the server has been started is also not safe! Echo struct { filesystem common @@ -73,8 +73,8 @@ type ( startupMutex sync.RWMutex colorer *color.Color - // premiddleware are middlewares that are run before routing is done. In case pre-middleware returns an error router - // will not be called at all and execution ends up in global error handler. + // premiddleware are middlewares that are run before routing is done. In case a pre-middleware returns + // an error the router is not executed and the request will end up in the global error handler. premiddleware []MiddlewareFunc middleware []MiddlewareFunc maxParam *int From a69727e2b95ed346dbdcaf2f194aedab41baa779 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Tue, 27 Dec 2022 21:08:37 +0200 Subject: [PATCH 268/446] Mark JWT middleware deprecated --- middleware/jwt.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/middleware/jwt.go b/middleware/jwt.go index ef6ad4ebc..bd628264e 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -154,6 +154,8 @@ var ( // // See: https://jwt.io/introduction // See `JWTConfig.TokenLookup` +// +// Deprecated: Please use https://github.com/labstack/echo-jwt instead func JWT(key interface{}) echo.MiddlewareFunc { c := DefaultJWTConfig c.SigningKey = key @@ -162,6 +164,8 @@ func JWT(key interface{}) echo.MiddlewareFunc { // JWTWithConfig returns a JWT auth middleware with config. // See: `JWT()`. +// +// Deprecated: Please use https://github.com/labstack/echo-jwt instead func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { // Defaults if config.Skipper == nil { From f36d5662fbb1850f03c9ac78f02a699a492ecc2d Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Tue, 27 Dec 2022 21:09:32 +0200 Subject: [PATCH 269/446] Changelog for 4.10.0 --- CHANGELOG.md | 33 +++++++++++++++++++++++++++++---- echo.go | 2 +- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c629350c0..c1c3c1074 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,15 +1,40 @@ # Changelog -## v4.10.0 - 2022-xx-xx +## v4.10.0 - 2022-12-27 **Security** -This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are -several vulnerabilities fixed in these libraries. +* We are deprecating JWT middleware in this repository. Please use https://github.com/labstack/echo-jwt instead. -Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise. + JWT middleware is moved to separate repository to allow us to bump/upgrade version of JWT implementation (`github.com/golang-jwt/jwt`) we are using +which we can not do in Echo core because this would break backwards compatibility guarantees we try to maintain. +* This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are + several vulnerabilities fixed in these libraries. + Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise. + + +**Enhancements** + +* Bump x/text to 0.3.8 [#2305](https://github.com/labstack/echo/pull/2305) +* Bump dependencies and add notes about Go releases we support [#2336](https://github.com/labstack/echo/pull/2336) +* Add helper interface for ProxyBalancer interface [#2316](https://github.com/labstack/echo/pull/2316) +* Expose `middleware.CreateExtractors` function so we can use it from echo-contrib repository [#2338](https://github.com/labstack/echo/pull/2338) +* Refactor func(Context) error to HandlerFunc [#2315](https://github.com/labstack/echo/pull/2315) +* Improve function comments [#2329](https://github.com/labstack/echo/pull/2329) +* Add new method HTTPError.WithInternal [#2340](https://github.com/labstack/echo/pull/2340) +* Replace io/ioutil package usages [#2342](https://github.com/labstack/echo/pull/2342) +* Add staticcheck to CI flow [#2343](https://github.com/labstack/echo/pull/2343) +* Replace relative path determination from proprietary to std [#2345](https://github.com/labstack/echo/pull/2345) +* Remove square brackets from ipv6 addresses in XFF (X-Forwarded-For header) [#2182](https://github.com/labstack/echo/pull/2182) +* Add testcases for some BodyLimit middleware configuration options [#2350](https://github.com/labstack/echo/pull/2350) +* Additional configuration options for RequestLogger and Logger middleware [#2341](https://github.com/labstack/echo/pull/2341) +* Add route to request log [#2162](https://github.com/labstack/echo/pull/2162) +* GitHub Workflows security hardening [#2358](https://github.com/labstack/echo/pull/2358) +* Add govulncheck to CI and bump dependencies [#2362](https://github.com/labstack/echo/pull/2362) +* Fix rate limiter docs [#2366](https://github.com/labstack/echo/pull/2366) +* Refactor how `e.Routes()` work and introduce `e.OnAddRouteHandler` callback [#2337](https://github.com/labstack/echo/pull/2337) ## v4.9.1 - 2022-10-12 diff --git a/echo.go b/echo.go index e3e7b2fe0..f6d89b966 100644 --- a/echo.go +++ b/echo.go @@ -258,7 +258,7 @@ const ( const ( // Version of Echo - Version = "4.9.0" + Version = "4.10.0" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 24a30611dfc07e427dc771a16ef9bb0dd94c4c2e Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Mon, 2 Jan 2023 21:39:15 +0200 Subject: [PATCH 270/446] Add new JWT repository to the README --- README.md | 39 +++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 509b97351..a388adc48 100644 --- a/README.md +++ b/README.md @@ -11,12 +11,12 @@ ## Supported Go versions -Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions. +Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with +older versions. As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules). Therefore a Go version capable of understanding /vN suffixed imports is required: - Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended way of using Echo going forward. @@ -90,18 +90,29 @@ func hello(c echo.Context) error { } ``` -# Third-party middlewares - -| Repository | Description | -|------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | (by Echo team) [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | -| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator | -| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | -| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | -| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | -| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | -| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. | -| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code | +# Official middleware repositories + +Following list of middleware is maintained by Echo team. + +| Repository | Description | +|------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware | +| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | + +# Third-party middleware repositories + +Be careful when adding 3rd party middleware. Echo teams does not have time or manpower to guarantee safety and quality +of middlewares in this list. + +| Repository | Description | +|------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [deepmap/oapi-codegen](https://github.com/deepmap/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator | +| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | +| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | +| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | +| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | +| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. | +| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code | Please send a PR to add your own library here. From 08093a4a1dbdcc90c2f6659ef02300d6eccef7f1 Mon Sep 17 00:00:00 2001 From: Brie Taylor Date: Fri, 27 Jan 2023 12:58:54 -0800 Subject: [PATCH 271/446] Return an empty string for ctx.path if there is no registered path --- router.go | 1 - router_test.go | 12 ++++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/router.go b/router.go index 86a986a29..597660d39 100644 --- a/router.go +++ b/router.go @@ -524,7 +524,6 @@ func optionsMethodHandler(allowMethods string) func(c Context) error { // - Return it `Echo#ReleaseContext()`. func (r *Router) Find(method, path string, c Context) { ctx := c.(*context) - ctx.path = path currentNode := r.tree // Current node as root var ( diff --git a/router_test.go b/router_test.go index 825170a3f..619cce092 100644 --- a/router_test.go +++ b/router_test.go @@ -674,6 +674,18 @@ func TestRouterStatic(t *testing.T) { assert.Equal(t, path, c.Get("path")) } +func TestRouterNoRoutablePath(t *testing.T) { + e := New() + r := e.router + c := e.NewContext(nil, nil).(*context) + + r.Find(http.MethodGet, "/notfound", c) + c.handler(c) + + // No routable path, don't set Path. + assert.Equal(t, "", c.Path()) +} + func TestRouterParam(t *testing.T) { e := New() r := e.router From 82a964c657e26b68998393d3e7291f1a474447f8 Mon Sep 17 00:00:00 2001 From: Hakan Kutluay <77051856+hakankutluay@users.noreply.github.com> Date: Wed, 1 Feb 2023 23:38:20 +0300 Subject: [PATCH 272/446] Add context timeout middleware (#2380) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add context timeout middleware Co-authored-by: Erhan Akpınar Co-authored-by: @erhanakp --- middleware/context_timeout.go | 72 +++++++++ middleware/context_timeout_test.go | 226 +++++++++++++++++++++++++++++ 2 files changed, 298 insertions(+) create mode 100644 middleware/context_timeout.go create mode 100644 middleware/context_timeout_test.go diff --git a/middleware/context_timeout.go b/middleware/context_timeout.go new file mode 100644 index 000000000..be260e188 --- /dev/null +++ b/middleware/context_timeout.go @@ -0,0 +1,72 @@ +package middleware + +import ( + "context" + "errors" + "time" + + "github.com/labstack/echo/v4" +) + +// ContextTimeoutConfig defines the config for ContextTimeout middleware. +type ContextTimeoutConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // ErrorHandler is a function when error aries in middeware execution. + ErrorHandler func(err error, c echo.Context) error + + // Timeout configures a timeout for the middleware, defaults to 0 for no timeout + Timeout time.Duration +} + +// ContextTimeout returns a middleware which returns error (503 Service Unavailable error) to client +// when underlying method returns context.DeadlineExceeded error. +func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc { + return ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: timeout}) +} + +// ContextTimeoutWithConfig returns a Timeout middleware with config. +func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc { + mw, err := config.ToMiddleware() + if err != nil { + panic(err) + } + return mw +} + +// ToMiddleware converts Config to middleware. +func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) { + if config.Timeout == 0 { + return nil, errors.New("timeout must be set") + } + if config.Skipper == nil { + config.Skipper = DefaultSkipper + } + if config.ErrorHandler == nil { + config.ErrorHandler = func(err error, c echo.Context) error { + if err != nil && errors.Is(err, context.DeadlineExceeded) { + return echo.ErrServiceUnavailable.WithInternal(err) + } + return err + } + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + + timeoutContext, cancel := context.WithTimeout(c.Request().Context(), config.Timeout) + defer cancel() + + c.SetRequest(c.Request().WithContext(timeoutContext)) + + if err := next(c); err != nil { + return config.ErrorHandler(err, c) + } + return nil + } + }, nil +} diff --git a/middleware/context_timeout_test.go b/middleware/context_timeout_test.go new file mode 100644 index 000000000..605ca8e65 --- /dev/null +++ b/middleware/context_timeout_test.go @@ -0,0 +1,226 @@ +package middleware + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestContextTimeoutSkipper(t *testing.T) { + t.Parallel() + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + Skipper: func(context echo.Context) bool { + return true + }, + Timeout: 10 * time.Millisecond, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil { + return err + } + + return errors.New("response from handler") + })(c) + + // if not skipped we would have not returned error due context timeout logic + assert.EqualError(t, err, "response from handler") +} + +func TestContextTimeoutWithTimeout0(t *testing.T) { + t.Parallel() + assert.Panics(t, func() { + ContextTimeout(time.Duration(0)) + }) +} + +func TestContextTimeoutErrorOutInHandler(t *testing.T) { + t.Parallel() + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + // Timeout has to be defined or the whole flow for timeout middleware will be skipped + Timeout: 10 * time.Millisecond, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + rec.Code = 1 // we want to be sure that even 200 will not be sent + err := m(func(c echo.Context) error { + // this error must not be written to the client response. Middlewares upstream of timeout middleware must be able + // to handle returned error and this can be done only then handler has not yet committed (written status code) + // the response. + return echo.NewHTTPError(http.StatusTeapot, "err") + })(c) + + assert.Error(t, err) + assert.EqualError(t, err, "code=418, message=err") + assert.Equal(t, 1, rec.Code) + assert.Equal(t, "", rec.Body.String()) +} + +func TestContextTimeoutSuccessfulRequest(t *testing.T) { + t.Parallel() + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + // Timeout has to be defined or the whole flow for timeout middleware will be skipped + Timeout: 10 * time.Millisecond, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + return c.JSON(http.StatusCreated, map[string]string{"data": "ok"}) + })(c) + + assert.NoError(t, err) + assert.Equal(t, http.StatusCreated, rec.Code) + assert.Equal(t, "{\"data\":\"ok\"}\n", rec.Body.String()) +} + +func TestContextTimeoutTestRequestClone(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode())) + req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"}) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + // Timeout has to be defined or the whole flow for timeout middleware will be skipped + Timeout: 1 * time.Second, + }) + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + // Cookie test + cookie, err := c.Request().Cookie("cookie") + if assert.NoError(t, err) { + assert.EqualValues(t, "cookie", cookie.Name) + assert.EqualValues(t, "value", cookie.Value) + } + + // Form values + if assert.NoError(t, c.Request().ParseForm()) { + assert.EqualValues(t, "value", c.Request().FormValue("form")) + } + + // Query string + assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0]) + return nil + })(c) + + assert.NoError(t, err) +} + +func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) { + t.Parallel() + + timeout := 10 * time.Millisecond + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + Timeout: timeout, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil { + return err + } + return c.String(http.StatusOK, "Hello, World!") + })(c) + + assert.IsType(t, &echo.HTTPError{}, err) + assert.Error(t, err) + assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code) + assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message) +} + +func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) { + t.Parallel() + + timeoutErrorHandler := func(err error, c echo.Context) error { + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return &echo.HTTPError{ + Code: http.StatusServiceUnavailable, + Message: "Timeout! change me", + } + } + return err + } + return nil + } + + timeout := 10 * time.Millisecond + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + Timeout: timeout, + ErrorHandler: timeoutErrorHandler, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c echo.Context) error { + // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) + // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output + // difference over 500microseconds (0.5millisecond) response seems to be reliable + + if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil { + return err + } + + // The Request Context should have a Deadline set by http.ContextTimeoutHandler + if _, ok := c.Request().Context().Deadline(); !ok { + assert.Fail(t, "No timeout set on Request Context") + } + return c.String(http.StatusOK, "Hello, World!") + })(c) + + assert.IsType(t, &echo.HTTPError{}, err) + assert.Error(t, err) + assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code) + assert.Equal(t, "Timeout! change me", err.(*echo.HTTPError).Message) +} + +func sleepWithContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + + defer func() { + _ = timer.Stop() + }() + + select { + case <-ctx.Done(): + return context.DeadlineExceeded + case <-timer.C: + return nil + } +} From 6b09f3ffeb5085bf23a3e0749155752f574c331b Mon Sep 17 00:00:00 2001 From: Roman Garanin Date: Tue, 7 Feb 2023 21:59:38 +0100 Subject: [PATCH 273/446] Update link to jaegertracing Added https:// prefix, without it github markdown rendering does strange things --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a388adc48..fe78b6ed1 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ Following list of middleware is maintained by Echo team. | Repository | Description | |------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware | -| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | +| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](https://github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | # Third-party middleware repositories From 45da0f888b8d642125b860af1c996a71f3f50bec Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 19 Feb 2023 10:14:05 +0200 Subject: [PATCH 274/446] remove .travis.yml --- .travis.yml | 21 --------------------- 1 file changed, 21 deletions(-) delete mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 67d45ad78..000000000 --- a/.travis.yml +++ /dev/null @@ -1,21 +0,0 @@ -arch: - - amd64 - - ppc64le - -language: go -go: - - 1.14.x - - 1.15.x - - tip -env: - - GO111MODULE=on -install: - - go get -v golang.org/x/lint/golint -script: - - golint -set_exit_status ./... - - go test -race -coverprofile=coverage.txt -covermode=atomic ./... -after_success: - - bash <(curl -s https://codecov.io/bash) -matrix: - allow_failures: - - go: tip From a3998ac96ad155e132e08bdae67f26a379f99385 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 19 Feb 2023 10:16:07 +0200 Subject: [PATCH 275/446] Upgrade deps due to the latest golang.org/x/net vulnerability --- .github/workflows/checks.yml | 2 +- .github/workflows/echo.yml | 4 ++-- go.mod | 12 ++++++------ go.sum | 30 ++++++++++++++---------------- 4 files changed, 23 insertions(+), 25 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 907b2858a..d2d3386c4 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -14,7 +14,7 @@ permissions: env: # run static analysis only with the latest Go version - LATEST_GO_VERSION: 1.19 + LATEST_GO_VERSION: "1.20" jobs: check: diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index e41c80ab7..e06183d5e 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -14,7 +14,7 @@ permissions: env: # run coverage and benchmarks only with the latest Go version - LATEST_GO_VERSION: 1.19 + LATEST_GO_VERSION: "1.20" jobs: test: @@ -25,7 +25,7 @@ jobs: # Echo tests with last four major releases (unless there are pressing vulnerabilities) # As we depend on `golang.org/x/` libraries which only support last 2 Go releases we could have situations when # we derive from last four major releases promise. - go: [1.17, 1.18, 1.19] + go: ["1.18", "1.19", "1.20"] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: diff --git a/go.mod b/go.mod index 3b833310d..265b0aafc 100644 --- a/go.mod +++ b/go.mod @@ -7,18 +7,18 @@ require ( github.com/labstack/gommon v0.4.0 github.com/stretchr/testify v1.8.1 github.com/valyala/fasttemplate v1.2.2 - golang.org/x/crypto v0.2.0 - golang.org/x/net v0.4.0 - golang.org/x/time v0.2.0 + golang.org/x/crypto v0.6.0 + golang.org/x/net v0.7.0 + golang.org/x/time v0.3.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.16 // indirect + github.com/mattn/go-isatty v0.0.17 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.3.0 // indirect - golang.org/x/text v0.5.0 // indirect + golang.org/x/sys v0.5.0 // indirect + golang.org/x/text v0.7.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 825c35155..79ff318c5 100644 --- a/go.sum +++ b/go.sum @@ -9,8 +9,9 @@ github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= -github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= +github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -29,15 +30,15 @@ github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.2.0 h1:BRXPfhNivWL5Yq0BGQ39a2sW6t44aODpfxkWjYdzewE= -golang.org/x/crypto v0.2.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= +golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= +golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= -golang.org/x/net v0.4.0 h1:Q5QPcMlvfxFTAPV0+07Xz/MpK9NTXu2VDUuy0FeMfaU= -golang.org/x/net v0.4.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -49,21 +50,18 @@ golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= -golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= -golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= -golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/time v0.2.0 h1:52I/1L54xyEQAYdtcSuxtiT84KGYTBGXwayxmIpNJhE= -golang.org/x/time v0.2.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= From 2c25767e45bdcb881645ebb7f962c4f3c2adc20c Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 19 Feb 2023 10:38:34 +0200 Subject: [PATCH 276/446] remediate flaky timeout tests --- middleware/timeout_test.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index 6f60753c6..98d96baef 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -375,7 +375,7 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { // NOTE: timeout middleware is first as it changes Response.Writer and causes data race for logger middleware if it is not first e.Use(TimeoutWithConfig(TimeoutConfig{ - Timeout: 15 * time.Millisecond, + Timeout: 100 * time.Millisecond, })) e.Use(Logger()) e.Use(Recover()) @@ -403,8 +403,13 @@ func TestTimeoutWithFullEchoStack(t *testing.T) { } if tc.whenForceHandlerTimeout { wg.Done() + // extremely short periods are not reliable for tests when it comes to goroutines. We can not guarantee in which + // order scheduler decides do execute: 1) request goroutine, 2) timeout timer goroutine. + // most of the time we get result we expect but Mac OS seems to be quite flaky + time.Sleep(50 * time.Millisecond) + // shutdown waits for server to shutdown. this way we wait logger mw to be executed - ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() server.Shutdown(ctx) } From b888a30fe394deeeb14e18226be51b5928115dd3 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 19 Feb 2023 21:04:05 +0200 Subject: [PATCH 277/446] Changelog for v4.10.1 --- CHANGELOG.md | 15 +++++++++++++++ echo.go | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c1c3c1074..28b6f8653 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,20 @@ # Changelog +## v4.10.1 - 2023-02-19 + +**Security** + +* Upgrade deps due to the latest golang.org/x/net vulnerability [#2402](https://github.com/labstack/echo/pull/2402) + + +**Enhancements** + +* Add new JWT repository to the README [#2377](https://github.com/labstack/echo/pull/2377) +* Return an empty string for ctx.path if there is no registered path [#2385](https://github.com/labstack/echo/pull/2385) +* Add context timeout middleware [#2380](https://github.com/labstack/echo/pull/2380) +* Update link to jaegertracing [#2394](https://github.com/labstack/echo/pull/2394) + + ## v4.10.0 - 2022-12-27 **Security** diff --git a/echo.go b/echo.go index f6d89b966..7199c45ac 100644 --- a/echo.go +++ b/echo.go @@ -258,7 +258,7 @@ const ( const ( // Version of Echo - Version = "4.10.0" + Version = "4.10.1" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 04ba8e2f9d3f39d7c05f3f0340d27ccec6535e7f Mon Sep 17 00:00:00 2001 From: Ara Park Date: Wed, 22 Feb 2023 06:32:11 +0900 Subject: [PATCH 278/446] Add more http error values (#2277) * Add more HTTP error constants --- echo.go | 65 +++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 47 insertions(+), 18 deletions(-) diff --git a/echo.go b/echo.go index 7199c45ac..ef99c22d7 100644 --- a/echo.go +++ b/echo.go @@ -291,24 +291,53 @@ var ( // Errors var ( - ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) - ErrNotFound = NewHTTPError(http.StatusNotFound) - ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) - ErrForbidden = NewHTTPError(http.StatusForbidden) - ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) - ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) - ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests) - ErrBadRequest = NewHTTPError(http.StatusBadRequest) - ErrBadGateway = NewHTTPError(http.StatusBadGateway) - ErrInternalServerError = NewHTTPError(http.StatusInternalServerError) - ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout) - ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable) - ErrValidatorNotRegistered = errors.New("validator not registered") - ErrRendererNotRegistered = errors.New("renderer not registered") - ErrInvalidRedirectCode = errors.New("invalid redirect status code") - ErrCookieNotFound = errors.New("cookie not found") - ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") - ErrInvalidListenerNetwork = errors.New("invalid listener network") + ErrBadRequest = NewHTTPError(http.StatusBadRequest) // HTTP 400 Bad Request + ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) // HTTP 401 Unauthorized + ErrPaymentRequired = NewHTTPError(http.StatusPaymentRequired) // HTTP 402 Payment Required + ErrForbidden = NewHTTPError(http.StatusForbidden) // HTTP 403 Forbidden + ErrNotFound = NewHTTPError(http.StatusNotFound) // HTTP 404 Not Found + ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) // HTTP 405 Method Not Allowed + ErrNotAcceptable = NewHTTPError(http.StatusNotAcceptable) // HTTP 406 Not Acceptable + ErrProxyAuthRequired = NewHTTPError(http.StatusProxyAuthRequired) // HTTP 407 Proxy AuthRequired + ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout) // HTTP 408 Request Timeout + ErrConflict = NewHTTPError(http.StatusConflict) // HTTP 409 Conflict + ErrGone = NewHTTPError(http.StatusGone) // HTTP 410 Gone + ErrLengthRequired = NewHTTPError(http.StatusLengthRequired) // HTTP 411 Length Required + ErrPreconditionFailed = NewHTTPError(http.StatusPreconditionFailed) // HTTP 412 Precondition Failed + ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) // HTTP 413 Payload Too Large + ErrRequestURITooLong = NewHTTPError(http.StatusRequestURITooLong) // HTTP 414 URI Too Long + ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) // HTTP 415 Unsupported Media Type + ErrRequestedRangeNotSatisfiable = NewHTTPError(http.StatusRequestedRangeNotSatisfiable) // HTTP 416 Range Not Satisfiable + ErrExpectationFailed = NewHTTPError(http.StatusExpectationFailed) // HTTP 417 Expectation Failed + ErrTeapot = NewHTTPError(http.StatusTeapot) // HTTP 418 I'm a teapot + ErrMisdirectedRequest = NewHTTPError(http.StatusMisdirectedRequest) // HTTP 421 Misdirected Request + ErrUnprocessableEntity = NewHTTPError(http.StatusUnprocessableEntity) // HTTP 422 Unprocessable Entity + ErrLocked = NewHTTPError(http.StatusLocked) // HTTP 423 Locked + ErrFailedDependency = NewHTTPError(http.StatusFailedDependency) // HTTP 424 Failed Dependency + ErrTooEarly = NewHTTPError(http.StatusTooEarly) // HTTP 425 Too Early + ErrUpgradeRequired = NewHTTPError(http.StatusUpgradeRequired) // HTTP 426 Upgrade Required + ErrPreconditionRequired = NewHTTPError(http.StatusPreconditionRequired) // HTTP 428 Precondition Required + ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests) // HTTP 429 Too Many Requests + ErrRequestHeaderFieldsTooLarge = NewHTTPError(http.StatusRequestHeaderFieldsTooLarge) // HTTP 431 Request Header Fields Too Large + ErrUnavailableForLegalReasons = NewHTTPError(http.StatusUnavailableForLegalReasons) // HTTP 451 Unavailable For Legal Reasons + ErrInternalServerError = NewHTTPError(http.StatusInternalServerError) // HTTP 500 Internal Server Error + ErrNotImplemented = NewHTTPError(http.StatusNotImplemented) // HTTP 501 Not Implemented + ErrBadGateway = NewHTTPError(http.StatusBadGateway) // HTTP 502 Bad Gateway + ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable) // HTTP 503 Service Unavailable + ErrGatewayTimeout = NewHTTPError(http.StatusGatewayTimeout) // HTTP 504 Gateway Timeout + ErrHTTPVersionNotSupported = NewHTTPError(http.StatusHTTPVersionNotSupported) // HTTP 505 HTTP Version Not Supported + ErrVariantAlsoNegotiates = NewHTTPError(http.StatusVariantAlsoNegotiates) // HTTP 506 Variant Also Negotiates + ErrInsufficientStorage = NewHTTPError(http.StatusInsufficientStorage) // HTTP 507 Insufficient Storage + ErrLoopDetected = NewHTTPError(http.StatusLoopDetected) // HTTP 508 Loop Detected + ErrNotExtended = NewHTTPError(http.StatusNotExtended) // HTTP 510 Not Extended + ErrNetworkAuthenticationRequired = NewHTTPError(http.StatusNetworkAuthenticationRequired) // HTTP 511 Network Authentication Required + + ErrValidatorNotRegistered = errors.New("validator not registered") + ErrRendererNotRegistered = errors.New("renderer not registered") + ErrInvalidRedirectCode = errors.New("invalid redirect status code") + ErrCookieNotFound = errors.New("cookie not found") + ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") + ErrInvalidListenerNetwork = errors.New("invalid listener network") ) // Error handlers From 7c7531002d4fb5fd2fc573a5e32f6482cd54f153 Mon Sep 17 00:00:00 2001 From: Martti T Date: Wed, 22 Feb 2023 00:00:52 +0200 Subject: [PATCH 279/446] Clean on go1.20 (#2406) * Fix tests failing on Go 1.20 on Windows. Clean works differently on 1.20. Use path.Clean instead with some workaround related to errors. --- middleware/context_timeout_test.go | 11 +++++------ middleware/static.go | 28 +++++++++++----------------- middleware/static_other.go | 12 ++++++++++++ middleware/static_windows.go | 23 +++++++++++++++++++++++ 4 files changed, 51 insertions(+), 23 deletions(-) create mode 100644 middleware/static_other.go create mode 100644 middleware/static_windows.go diff --git a/middleware/context_timeout_test.go b/middleware/context_timeout_test.go index 605ca8e65..24c6203e7 100644 --- a/middleware/context_timeout_test.go +++ b/middleware/context_timeout_test.go @@ -148,7 +148,7 @@ func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) { c := e.NewContext(req, rec) err := m(func(c echo.Context) error { - if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil { + if err := sleepWithContext(c.Request().Context(), time.Duration(80*time.Millisecond)); err != nil { return err } return c.String(http.StatusOK, "Hello, World!") @@ -176,7 +176,7 @@ func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) { return nil } - timeout := 10 * time.Millisecond + timeout := 50 * time.Millisecond m := ContextTimeoutWithConfig(ContextTimeoutConfig{ Timeout: timeout, ErrorHandler: timeoutErrorHandler, @@ -189,11 +189,10 @@ func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) { c := e.NewContext(req, rec) err := m(func(c echo.Context) error { - // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) - // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output - // difference over 500microseconds (0.5millisecond) response seems to be reliable + // NOTE: Very short periods are not reliable for tests due to Go routine scheduling and the unpredictable order + // for 1) request and 2) time goroutine. For most OS this works as expected, but MacOS seems most flaky. - if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil { + if err := sleepWithContext(c.Request().Context(), 100*time.Millisecond); err != nil { return err } diff --git a/middleware/static.go b/middleware/static.go index 27ccf4117..24a5f59b9 100644 --- a/middleware/static.go +++ b/middleware/static.go @@ -8,7 +8,6 @@ import ( "net/url" "os" "path" - "path/filepath" "strings" "github.com/labstack/echo/v4" @@ -157,9 +156,9 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { } // Index template - t, err := template.New("index").Parse(html) - if err != nil { - panic(fmt.Sprintf("echo: %v", err)) + t, tErr := template.New("index").Parse(html) + if tErr != nil { + panic(fmt.Errorf("echo: %w", tErr)) } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -176,7 +175,7 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { if err != nil { return } - name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security + name := path.Join(config.Root, path.Clean("/"+p)) // "/"+ for security if config.IgnoreBase { routePath := path.Base(strings.TrimRight(c.Path(), "/*")) @@ -187,12 +186,14 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { } } - file, err := openFile(config.Filesystem, name) + file, err := config.Filesystem.Open(name) if err != nil { - if !os.IsNotExist(err) { + if !isIgnorableOpenFileError(err) { return err } + // file with that path did not exist, so we continue down in middleware/handler chain, hoping that we end up in + // handler that is meant to handle this request if err = next(c); err == nil { return err } @@ -202,7 +203,7 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { return err } - file, err = openFile(config.Filesystem, filepath.Join(config.Root, config.Index)) + file, err = config.Filesystem.Open(path.Join(config.Root, config.Index)) if err != nil { return err } @@ -216,15 +217,13 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { } if info.IsDir() { - index, err := openFile(config.Filesystem, filepath.Join(name, config.Index)) + index, err := config.Filesystem.Open(path.Join(name, config.Index)) if err != nil { if config.Browse { return listDir(t, name, file, c.Response()) } - if os.IsNotExist(err) { - return next(c) - } + return next(c) } defer index.Close() @@ -242,11 +241,6 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { } } -func openFile(fs http.FileSystem, name string) (http.File, error) { - pathWithSlashes := filepath.ToSlash(name) - return fs.Open(pathWithSlashes) -} - func serveFile(c echo.Context, file http.File, info os.FileInfo) error { http.ServeContent(c.Response(), c.Request(), info.Name(), info.ModTime(), file) return nil diff --git a/middleware/static_other.go b/middleware/static_other.go new file mode 100644 index 000000000..0337b22af --- /dev/null +++ b/middleware/static_other.go @@ -0,0 +1,12 @@ +//go:build !windows + +package middleware + +import ( + "os" +) + +// We ignore these errors as there could be handler that matches request path. +func isIgnorableOpenFileError(err error) bool { + return os.IsNotExist(err) +} diff --git a/middleware/static_windows.go b/middleware/static_windows.go new file mode 100644 index 000000000..0ab119859 --- /dev/null +++ b/middleware/static_windows.go @@ -0,0 +1,23 @@ +package middleware + +import ( + "os" +) + +// We ignore these errors as there could be handler that matches request path. +// +// As of Go 1.20 filepath.Clean has different behaviour on OS related filesystems so we need to use path.Clean +// on Windows which has some caveats. The Open methods might return different errors than earlier versions and +// as of 1.20 path checks are more strict on the provided path and considers [UNC](https://en.wikipedia.org/wiki/Path_(computing)#UNC) +// paths with missing host etc parts as invalid. Previously it would result you `fs.ErrNotExist`. +// +// For 1.20@Windows we need to treat those errors the same as `fs.ErrNotExists` so we can continue handling +// errors in the middleware/handler chain. Otherwise we might end up with status 500 instead of finding a route +// or return 404 not found. +func isIgnorableOpenFileError(err error) bool { + if os.IsNotExist(err) { + return true + } + errTxt := err.Error() + return errTxt == "http: invalid or unsafe file path" || errTxt == "invalid path" +} From ef4aea97ef344bf0f61ba3b50844987b7dac8169 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Tue, 21 Feb 2023 12:20:30 +0200 Subject: [PATCH 280/446] use different variable name so returned function would not accidentally be able to use it in future and cause data race --- middleware/csrf.go | 6 +++--- middleware/jwt.go | 6 +++--- middleware/key_auth.go | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/middleware/csrf.go b/middleware/csrf.go index 8661c9f89..6899700c7 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -119,9 +119,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { config.CookieSecure = true } - extractors, err := CreateExtractors(config.TokenLookup) - if err != nil { - panic(err) + extractors, cErr := CreateExtractors(config.TokenLookup) + if cErr != nil { + panic(cErr) } return func(next echo.HandlerFunc) echo.HandlerFunc { diff --git a/middleware/jwt.go b/middleware/jwt.go index bd628264e..bc318c976 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -196,9 +196,9 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { config.ParseTokenFunc = config.defaultParseToken } - extractors, err := createExtractors(config.TokenLookup, config.AuthScheme) - if err != nil { - panic(err) + extractors, cErr := createExtractors(config.TokenLookup, config.AuthScheme) + if cErr != nil { + panic(cErr) } if len(config.TokenLookupFuncs) > 0 { extractors = append(config.TokenLookupFuncs, extractors...) diff --git a/middleware/key_auth.go b/middleware/key_auth.go index e8a6b0853..f6fcc5d69 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -108,9 +108,9 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { panic("echo: key-auth middleware requires a validator function") } - extractors, err := createExtractors(config.KeyLookup, config.AuthScheme) - if err != nil { - panic(err) + extractors, cErr := createExtractors(config.KeyLookup, config.AuthScheme) + if cErr != nil { + panic(cErr) } return func(next echo.HandlerFunc) echo.HandlerFunc { From f909660bb9fa0fed50a897a5169422e3bd92106b Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Tue, 21 Feb 2023 12:21:49 +0200 Subject: [PATCH 281/446] Add middleware.CORSConfig.UnsafeWildcardOriginWithAllowCredentials to make UNSAFE usages of wildcard origin + allow cretentials less likely. --- middleware/cors.go | 11 +- middleware/cors_test.go | 282 ++++++++++++++++++++++++++-------------- 2 files changed, 193 insertions(+), 100 deletions(-) diff --git a/middleware/cors.go b/middleware/cors.go index 25cf983a7..149de347a 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -79,6 +79,15 @@ type ( // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials AllowCredentials bool `yaml:"allow_credentials"` + // UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials + // flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header. + // + // This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties) + // attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject. + // + // Optional. Default value is false. + UnsafeWildcardOriginWithAllowCredentials bool `yaml:"unsafe_wildcard_origin_with_allow_credentials"` + // ExposeHeaders determines the value of Access-Control-Expose-Headers, which // defines a list of headers that clients are allowed to access. // @@ -203,7 +212,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } else { // Check allowed origins for _, o := range config.AllowOrigins { - if o == "*" && config.AllowCredentials { + if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials { allowOrigin = origin break } diff --git a/middleware/cors_test.go b/middleware/cors_test.go index daadbab6e..c1bb91eb3 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -11,106 +11,190 @@ import ( ) func TestCORS(t *testing.T) { - e := echo.New() + var testCases = []struct { + name string + givenMW echo.MiddlewareFunc + whenMethod string + whenHeaders map[string]string + expectHeaders map[string]string + notExpectHeaders map[string]string + }{ + { + name: "ok, wildcard origin", + whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"}, + expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "*"}, + }, + { + name: "ok, wildcard AllowedOrigin with no Origin header in request", + notExpectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: ""}, + }, + { + name: "ok, specific AllowOrigins and AllowCredentials", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, + AllowCredentials: true, + MaxAge: 3600, + }), + whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"}, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "localhost", + echo.HeaderAccessControlAllowCredentials: "true", + }, + }, + { + name: "ok, preflight request with matching origin for `AllowOrigins`", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, + AllowCredentials: true, + MaxAge: 3600, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "localhost", + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + echo.HeaderAccessControlAllowCredentials: "true", + echo.HeaderAccessControlMaxAge: "3600", + }, + }, + { + name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"*"}, + AllowCredentials: true, + MaxAge: 3600, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "*", // Note: browsers will ignore and complain about responses having `*` + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + echo.HeaderAccessControlAllowCredentials: "true", + echo.HeaderAccessControlMaxAge: "3600", + }, + }, + { + name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` false", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"*"}, + AllowCredentials: false, // important for this testcase + MaxAge: 3600, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "*", + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + echo.HeaderAccessControlMaxAge: "3600", + }, + notExpectHeaders: map[string]string{ + echo.HeaderAccessControlAllowCredentials: "", + }, + }, + { + name: "ok, INSECURE preflight request with wildcard `AllowOrigins` and `AllowCredentials` true", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"*"}, + AllowCredentials: true, + UnsafeWildcardOriginWithAllowCredentials: true, // important for this testcase + MaxAge: 3600, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "localhost", // This could end up as cross-origin attack + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + echo.HeaderAccessControlAllowCredentials: "true", + echo.HeaderAccessControlMaxAge: "3600", + }, + }, + { + name: "ok, preflight request with Access-Control-Request-Headers", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"*"}, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + echo.HeaderAccessControlRequestHeaders: "Special-Request-Header", + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "*", + echo.HeaderAccessControlAllowHeaders: "Special-Request-Header", + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + }, + }, + { + name: "ok, preflight request with `AllowOrigins` which allow all subdomains aaa with *", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"http://*.example.com"}, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{echo.HeaderOrigin: "http://aaa.example.com"}, + expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://aaa.example.com"}, + }, + { + name: "ok, preflight request with `AllowOrigins` which allow all subdomains bbb with *", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"http://*.example.com"}, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{echo.HeaderOrigin: "http://bbb.example.com"}, + expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://bbb.example.com"}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + mw := CORS() + if tc.givenMW != nil { + mw = tc.givenMW + } + h := mw(func(c echo.Context) error { + return nil + }) - // Wildcard origin - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := CORS()(echo.NotFoundHandler) - req.Header.Set(echo.HeaderOrigin, "localhost") - h(c) - assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - - // Wildcard AllowedOrigin with no Origin header in request - req = httptest.NewRequest(http.MethodGet, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = CORS()(echo.NotFoundHandler) - h(c) - assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) - - // Allow origins - req = httptest.NewRequest(http.MethodGet, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"localhost"}, - AllowCredentials: true, - MaxAge: 3600, - })(echo.NotFoundHandler) - req.Header.Set(echo.HeaderOrigin, "localhost") - h(c) - assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) - - // Preflight request - req = httptest.NewRequest(http.MethodOptions, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - req.Header.Set(echo.HeaderOrigin, "localhost") - req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) - cors := CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"localhost"}, - AllowCredentials: true, - MaxAge: 3600, - }) - h = cors(echo.NotFoundHandler) - h(c) - assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) - assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) - assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge)) - - // Preflight request with `AllowOrigins` * - req = httptest.NewRequest(http.MethodOptions, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - req.Header.Set(echo.HeaderOrigin, "localhost") - req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) - cors = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"*"}, - AllowCredentials: true, - MaxAge: 3600, - }) - h = cors(echo.NotFoundHandler) - h(c) - assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) - assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) - assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge)) - - // Preflight request with Access-Control-Request-Headers - req = httptest.NewRequest(http.MethodOptions, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - req.Header.Set(echo.HeaderOrigin, "localhost") - req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) - req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header") - cors = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"*"}, - }) - h = cors(echo.NotFoundHandler) - h(c) - assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders)) - assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) - - // Preflight request with `AllowOrigins` which allow all subdomains with * - req = httptest.NewRequest(http.MethodOptions, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - req.Header.Set(echo.HeaderOrigin, "http://aaa.example.com") - cors = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"http://*.example.com"}, - }) - h = cors(echo.NotFoundHandler) - h(c) - assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - - req.Header.Set(echo.HeaderOrigin, "http://bbb.example.com") - h(c) - assert.Equal(t, "http://bbb.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + req := httptest.NewRequest(method, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + for k, v := range tc.whenHeaders { + req.Header.Set(k, v) + } + + err := h(c) + + assert.NoError(t, err) + header := rec.Header() + for k, v := range tc.expectHeaders { + assert.Equal(t, v, header.Get(k), "header: `%v` should be `%v`", k, v) + } + for k, v := range tc.notExpectHeaders { + if v == "" { + assert.Len(t, header.Values(k), 0, "header: `%v` should not be set", k) + } else { + assert.NotEqual(t, v, header.Get(k), "header: `%v` should not be `%v`", k, v) + } + } + }) + } } func Test_allowOriginScheme(t *testing.T) { From 47844c9b7f83e5bf4efbe1f449bf2a155f465da8 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Wed, 22 Feb 2023 00:55:31 +0200 Subject: [PATCH 282/446] Changelog for v4.10.2 --- CHANGELOG.md | 12 ++++++++++++ echo.go | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 28b6f8653..831842497 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## v4.10.2 - 2023-02-22 + +**Security** + +* `filepath.Clean` behaviour has changed in Go 1.20 - adapt to it [#2406](https://github.com/labstack/echo/pull/2406) +* Add `middleware.CORSConfig.UnsafeWildcardOriginWithAllowCredentials` to make UNSAFE usages of wildcard origin + allow cretentials less likely [#2405](https://github.com/labstack/echo/pull/2405) + +**Enhancements** + +* Add more HTTP error values [#2277](https://github.com/labstack/echo/pull/2277) + + ## v4.10.1 - 2023-02-19 **Security** diff --git a/echo.go b/echo.go index ef99c22d7..085a3a7f2 100644 --- a/echo.go +++ b/echo.go @@ -258,7 +258,7 @@ const ( const ( // Version of Echo - Version = "4.10.1" + Version = "4.10.2" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 1e575b7b56d7f1478d889bbd7464f124efe9bc1e Mon Sep 17 00:00:00 2001 From: Omkar <42245836+Omkar-C@users.noreply.github.com> Date: Fri, 24 Feb 2023 16:39:40 +0530 Subject: [PATCH 283/446] Added a optional config variable to disable centralized error handler in recovery middleware (#2410) Added a config variable to disable centralized error handler in recovery middleware --- middleware/recover.go | 15 +++++++++++++-- middleware/recover_test.go | 23 ++++++++++++++++++++++- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/middleware/recover.go b/middleware/recover.go index 7b6128533..36d41aa64 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -38,6 +38,11 @@ type ( // LogErrorFunc defines a function for custom logging in the middleware. // If it's set you don't need to provide LogLevel for config. LogErrorFunc LogErrorFunc + + // DisableErrorHandler disables the call to centralized HTTPErrorHandler. + // The recovered error is then passed back to upstream middleware, instead of swallowing the error. + // Optional. Default value false. + DisableErrorHandler bool `yaml:"disable_error_handler"` } ) @@ -50,6 +55,7 @@ var ( DisablePrintStack: false, LogLevel: 0, LogErrorFunc: nil, + DisableErrorHandler: false, } ) @@ -71,7 +77,7 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c echo.Context) (returnErr error) { if config.Skipper(c) { return next(c) } @@ -113,7 +119,12 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { c.Logger().Print(msg) } } - c.Error(err) + + if(!config.DisableErrorHandler) { + c.Error(err) + } else { + returnErr = err + } } }() return next(c) diff --git a/middleware/recover_test.go b/middleware/recover_test.go index b27f3b41c..3e0d35d79 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -23,7 +23,8 @@ func TestRecover(t *testing.T) { h := Recover()(echo.HandlerFunc(func(c echo.Context) error { panic("test") })) - h(c) + err := h(c) + assert.NoError(t, err) assert.Equal(t, http.StatusInternalServerError, rec.Code) assert.Contains(t, buf.String(), "PANIC RECOVER") } @@ -163,3 +164,23 @@ func TestRecoverWithConfig_LogErrorFunc(t *testing.T) { assert.Contains(t, output, `"level":"ERROR"`) }) } + +func TestRecoverWithDisabled_ErrorHandler(t *testing.T) { + e := echo.New() + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + config := DefaultRecoverConfig + config.DisableErrorHandler = true + h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { + panic("test") + })) + err := h(c) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Contains(t, buf.String(), "PANIC RECOVER") + assert.EqualError(t, err, "test") +} From 5b36ce36127b2c011e6f0b905958d2544eef8820 Mon Sep 17 00:00:00 2001 From: Becir Basic Date: Fri, 24 Feb 2023 19:32:41 +0100 Subject: [PATCH 284/446] Fixes the concurrency issue of calling the `Next()` proxy target on RRB (#2409) * Fixes the concurrency issue of calling the `Next()` proxy target on round robin balancer - fixed concurrency issue in `AddTarget()` - moved `rand.New()` to the random balancer initializer func. - internal code reorganized eliminating unnecessary pointer redirection - employing `sync.Mutex` instead of `RWMutex` which brings additional overhead of tracking readers and writers. No need for that since the guarded code has no long-running operations, hence no realistic congestion. - added additional guards without which the code would otherwise panic (e.g., the case where a random value is calculation when targets list is empty) - added descriptions for func return values, what to expect in which case. - Improve code test coverage --------- Co-authored-by: Becir Basic --- middleware/proxy.go | 59 +++++++++++++++++++++++++++------------- middleware/proxy_test.go | 15 ++++++++-- 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/middleware/proxy.go b/middleware/proxy.go index d2cd2aa6d..74f49de8a 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -12,7 +12,6 @@ import ( "regexp" "strings" "sync" - "sync/atomic" "time" "github.com/labstack/echo/v4" @@ -79,19 +78,20 @@ type ( commonBalancer struct { targets []*ProxyTarget - mutex sync.RWMutex + mutex sync.Mutex } // RandomBalancer implements a random load balancing technique. randomBalancer struct { - *commonBalancer + commonBalancer random *rand.Rand } // RoundRobinBalancer implements a round-robin load balancing technique. roundRobinBalancer struct { - *commonBalancer - i uint32 + commonBalancer + // tracking the index on `targets` slice for the next `*ProxyTarget` to be used + i int } ) @@ -143,32 +143,37 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { // NewRandomBalancer returns a random proxy balancer. func NewRandomBalancer(targets []*ProxyTarget) ProxyBalancer { - b := &randomBalancer{commonBalancer: new(commonBalancer)} + b := randomBalancer{} b.targets = targets - return b + b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) + return &b } // NewRoundRobinBalancer returns a round-robin proxy balancer. func NewRoundRobinBalancer(targets []*ProxyTarget) ProxyBalancer { - b := &roundRobinBalancer{commonBalancer: new(commonBalancer)} + b := roundRobinBalancer{} b.targets = targets - return b + return &b } -// AddTarget adds an upstream target to the list. +// AddTarget adds an upstream target to the list and returns `true`. +// +// However, if a target with the same name already exists then the operation is aborted returning `false`. func (b *commonBalancer) AddTarget(target *ProxyTarget) bool { + b.mutex.Lock() + defer b.mutex.Unlock() for _, t := range b.targets { if t.Name == target.Name { return false } } - b.mutex.Lock() - defer b.mutex.Unlock() b.targets = append(b.targets, target) return true } -// RemoveTarget removes an upstream target from the list. +// RemoveTarget removes an upstream target from the list by name. +// +// Returns `true` on success, `false` if no target with the name is found. func (b *commonBalancer) RemoveTarget(name string) bool { b.mutex.Lock() defer b.mutex.Unlock() @@ -182,20 +187,36 @@ func (b *commonBalancer) RemoveTarget(name string) bool { } // Next randomly returns an upstream target. +// +// Note: `nil` is returned in case upstream target list is empty. func (b *randomBalancer) Next(c echo.Context) *ProxyTarget { - if b.random == nil { - b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) + b.mutex.Lock() + defer b.mutex.Unlock() + if len(b.targets) == 0 { + return nil + } else if len(b.targets) == 1 { + return b.targets[0] } - b.mutex.RLock() - defer b.mutex.RUnlock() return b.targets[b.random.Intn(len(b.targets))] } // Next returns an upstream target using round-robin technique. +// +// Note: `nil` is returned in case upstream target list is empty. func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget { - b.i = b.i % uint32(len(b.targets)) + b.mutex.Lock() + defer b.mutex.Unlock() + if len(b.targets) == 0 { + return nil + } else if len(b.targets) == 1 { + return b.targets[0] + } + // reset the index if out of bounds + if b.i >= len(b.targets) { + b.i = 0 + } t := b.targets[b.i] - atomic.AddUint32(&b.i, 1) + b.i++ return t } diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index a1b7f2cae..122dddeba 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -122,7 +122,7 @@ func TestProxy(t *testing.T) { } type testProvider struct { - *commonBalancer + commonBalancer target *ProxyTarget err error } @@ -143,7 +143,7 @@ func TestTargetProvider(t *testing.T) { url1, _ := url.Parse(t1.URL) e := echo.New() - tp := &testProvider{commonBalancer: new(commonBalancer)} + tp := &testProvider{} tp.target = &ProxyTarget{Name: "target 1", URL: url1} e.Use(Proxy(tp)) rec := httptest.NewRecorder() @@ -158,7 +158,7 @@ func TestFailNextTarget(t *testing.T) { assert.Nil(t, err) e := echo.New() - tp := &testProvider{commonBalancer: new(commonBalancer)} + tp := &testProvider{} tp.target = &ProxyTarget{Name: "target 1", URL: url1} tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target") @@ -422,3 +422,12 @@ func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { timeoutStop.Done() assert.Equal(t, 499, rec.Code) } + +// Assert balancer with empty targets does return `nil` on `Next()` +func TestProxyBalancerWithNoTargets(t *testing.T) { + rb := NewRandomBalancer(nil) + assert.Nil(t, rb.Next(nil)) + + rrb := NewRoundRobinBalancer([]*ProxyTarget{}) + assert.Nil(t, rrb.Next(nil)) +} From ec642f7df11b7e0d5231af0b35d12438bf48498e Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Thu, 23 Feb 2023 21:02:12 +0200 Subject: [PATCH 285/446] Fix group.RouteNotFound not working when group has attached middlewares --- group.go | 10 +++++--- group_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 4 deletions(-) diff --git a/group.go b/group.go index 28ce0dd9a..749a5caab 100644 --- a/group.go +++ b/group.go @@ -23,10 +23,12 @@ func (g *Group) Use(middleware ...MiddlewareFunc) { if len(g.middleware) == 0 { return } - // Allow all requests to reach the group as they might get dropped if router - // doesn't find a match, making none of the group middleware process. - g.Any("", NotFoundHandler) - g.Any("/*", NotFoundHandler) + // group level middlewares are different from Echo `Pre` and `Use` middlewares (those are global). Group level middlewares + // are only executed if they are added to the Router with route. + // So we register catch all route (404 is a safe way to emulate route match) for this group and now during routing the + // Router would find route to match our request path and therefore guarantee the middleware(s) will get executed. + g.RouteNotFound("", NotFoundHandler) + g.RouteNotFound("/*", NotFoundHandler) } // CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. diff --git a/group_test.go b/group_test.go index 01c304d0c..d22f564b0 100644 --- a/group_test.go +++ b/group_test.go @@ -184,3 +184,73 @@ func TestGroup_RouteNotFound(t *testing.T) { }) } } + +func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) { + var testCases = []struct { + name string + givenCustom404 bool + whenURL string + expectBody interface{} + expectCode int + }{ + { + name: "ok, custom 404 handler is called with middleware", + givenCustom404: true, + whenURL: "/group/test3", + expectBody: "GET /group/*", + expectCode: http.StatusNotFound, + }, + { + name: "ok, default group 404 handler is called with middleware", + givenCustom404: false, + whenURL: "/group/test3", + expectBody: "{\"message\":\"Not Found\"}\n", + expectCode: http.StatusNotFound, + }, + { + name: "ok, (no slash) default group 404 handler is called with middleware", + givenCustom404: false, + whenURL: "/group", + expectBody: "{\"message\":\"Not Found\"}\n", + expectCode: http.StatusNotFound, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + okHandler := func(c Context) error { + return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) + } + notFoundHandler := func(c Context) error { + return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) + } + + e := New() + e.GET("/test1", okHandler) + e.RouteNotFound("/*", notFoundHandler) + + g := e.Group("/group") + g.GET("/test1", okHandler) + + middlewareCalled := false + g.Use(func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + middlewareCalled = true + return next(c) + } + }) + if tc.givenCustom404 { + g.RouteNotFound("/*", notFoundHandler) + } + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.True(t, middlewareCalled) + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectBody, rec.Body.String()) + }) + } +} From f22ba6725c66896efdef029aaeb3bfc471f171c3 Mon Sep 17 00:00:00 2001 From: ivansmaliakou Date: Wed, 15 Mar 2023 22:50:00 +0100 Subject: [PATCH 286/446] documentation: changed description for `Bind()` method of `Context interface`. Because `Bind()`` binds not only request body, but also path and query params --- context.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/context.go b/context.go index b3a7ce8d0..27da28a9c 100644 --- a/context.go +++ b/context.go @@ -100,8 +100,8 @@ type ( // Set saves data in the context. Set(key string, val interface{}) - // Bind binds the request body into provided type `i`. The default binder - // does it based on Content-Type header. + // Bind binds path params, query params and the request body into provided type `i`. The default binder + // binds body based on Content-Type header. Bind(i interface{}) error // Validate validates provided `i`. It is usually called after `Context#Bind()`. From c0bc886b78b8214cdbd79953899338a37658dd48 Mon Sep 17 00:00:00 2001 From: imxyb Date: Tue, 28 Mar 2023 16:42:55 +0800 Subject: [PATCH 287/446] refactor: use strings.ReplaceAll directly --- middleware/cors.go | 4 ++-- middleware/middleware.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/middleware/cors.go b/middleware/cors.go index 149de347a..6ddb540af 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -150,8 +150,8 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { allowOriginPatterns := []string{} for _, origin := range config.AllowOrigins { pattern := regexp.QuoteMeta(origin) - pattern = strings.Replace(pattern, "\\*", ".*", -1) - pattern = strings.Replace(pattern, "\\?", ".", -1) + pattern = strings.ReplaceAll(pattern, "\\*", ".*") + pattern = strings.ReplaceAll(pattern, "\\?", ".") pattern = "^" + pattern + "$" allowOriginPatterns = append(allowOriginPatterns, pattern) } diff --git a/middleware/middleware.go b/middleware/middleware.go index f250ca49a..664f71f45 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -38,9 +38,9 @@ func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string { rulesRegex := map[*regexp.Regexp]string{} for k, v := range rewrite { k = regexp.QuoteMeta(k) - k = strings.Replace(k, `\*`, "(.*?)", -1) + k = strings.ReplaceAll(k, `\*`, "(.*?)") if strings.HasPrefix(k, `\^`) { - k = strings.Replace(k, `\^`, "^", -1) + k = strings.ReplaceAll(k, `\^`, "^") } k = k + "$" rulesRegex[regexp.MustCompile(k)] = v From a7802ea523e56c79a1b9e9620c48d68bcff5212e Mon Sep 17 00:00:00 2001 From: imxyb Date: Tue, 28 Mar 2023 17:25:11 +0800 Subject: [PATCH 288/446] add supprt for go1.20 http.rwUnwrapper --- response.go | 7 +++++++ response_test.go | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/response.go b/response.go index 84f7c9e7e..d9c9aa6e0 100644 --- a/response.go +++ b/response.go @@ -94,6 +94,13 @@ func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { return r.Writer.(http.Hijacker).Hijack() } +// Unwrap returns the original http.ResponseWriter. +// ResponseController can be used to access the original http.ResponseWriter. +// See [https://go.dev/blog/go1.20] +func (r *Response) Unwrap() http.ResponseWriter { + return r.Writer +} + func (r *Response) reset(w http.ResponseWriter) { r.beforeFuncs = nil r.afterFuncs = nil diff --git a/response_test.go b/response_test.go index d95e079f9..e4fd636d8 100644 --- a/response_test.go +++ b/response_test.go @@ -72,3 +72,11 @@ func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) { assert.Equal(t, http.StatusOK, rec.Code) } + +func TestResponse_Unwrap(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + res := &Response{echo: e, Writer: rec} + + assert.Equal(t, rec, res.Unwrap()) +} From de1c798143d316b94331a0c0e15d2d519d94aad2 Mon Sep 17 00:00:00 2001 From: Simba Peng <1531315@qq.com> Date: Fri, 7 Apr 2023 16:00:17 +0800 Subject: [PATCH 289/446] Check whether is nil before invoking centralized error handling. --- middleware/recover.go | 15 ++++++++------- middleware/request_logger.go | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/middleware/recover.go b/middleware/recover.go index 36d41aa64..0466cfe56 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -37,6 +37,7 @@ type ( // LogErrorFunc defines a function for custom logging in the middleware. // If it's set you don't need to provide LogLevel for config. + // If this function returns nil, the centralized HTTPErrorHandler will not be called. LogErrorFunc LogErrorFunc // DisableErrorHandler disables the call to centralized HTTPErrorHandler. @@ -49,12 +50,12 @@ type ( var ( // DefaultRecoverConfig is the default Recover middleware config. DefaultRecoverConfig = RecoverConfig{ - Skipper: DefaultSkipper, - StackSize: 4 << 10, // 4 KB - DisableStackAll: false, - DisablePrintStack: false, - LogLevel: 0, - LogErrorFunc: nil, + Skipper: DefaultSkipper, + StackSize: 4 << 10, // 4 KB + DisableStackAll: false, + DisablePrintStack: false, + LogLevel: 0, + LogErrorFunc: nil, DisableErrorHandler: false, } ) @@ -120,7 +121,7 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { } } - if(!config.DisableErrorHandler) { + if err != nil && !config.DisableErrorHandler { c.Error(err) } else { returnErr = err diff --git a/middleware/request_logger.go b/middleware/request_logger.go index b9e369255..8e312e8d8 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -257,7 +257,7 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { config.BeforeNextFunc(c) } err := next(c) - if config.HandleError { + if err != nil && config.HandleError { c.Error(err) } From 7d54690cdc4be1effb746ee60d950f937aa9e897 Mon Sep 17 00:00:00 2001 From: Mihard Date: Sun, 16 Apr 2023 20:13:47 +0200 Subject: [PATCH 290/446] Proper colon support in reverse (#2416) * Adds support of the escaped colon in echo.Reverse --------- Co-authored-by: Mihard --- echo_test.go | 109 ++++++++++++++++++++++++++++++++++++++++++--------- router.go | 7 +++- 2 files changed, 96 insertions(+), 20 deletions(-) diff --git a/echo_test.go b/echo_test.go index 2f66c8c6c..eab25db33 100644 --- a/echo_test.go +++ b/echo_test.go @@ -1517,26 +1517,97 @@ func TestEcho_OnAddRouteHandler(t *testing.T) { } func TestEchoReverse(t *testing.T) { - e := New() - dummyHandler := func(Context) error { return nil } + var testCases = []struct { + name string + whenRouteName string + whenParams []interface{} + expect string + }{ + { + name: "ok,static with no params", + whenRouteName: "/static", + expect: "/static", + }, + { + name: "ok,static with non existent param", + whenRouteName: "/static", + whenParams: []interface{}{"missing param"}, + expect: "/static", + }, + { + name: "ok, wildcard with no params", + whenRouteName: "/static/*", + expect: "/static/*", + }, + { + name: "ok, wildcard with params", + whenRouteName: "/static/*", + whenParams: []interface{}{"foo.txt"}, + expect: "/static/foo.txt", + }, + { + name: "ok, single param without param", + whenRouteName: "/params/:foo", + expect: "/params/:foo", + }, + { + name: "ok, single param with param", + whenRouteName: "/params/:foo", + whenParams: []interface{}{"one"}, + expect: "/params/one", + }, + { + name: "ok, multi param without params", + whenRouteName: "/params/:foo/bar/:qux", + expect: "/params/:foo/bar/:qux", + }, + { + name: "ok, multi param with one param", + whenRouteName: "/params/:foo/bar/:qux", + whenParams: []interface{}{"one"}, + expect: "/params/one/bar/:qux", + }, + { + name: "ok, multi param with all params", + whenRouteName: "/params/:foo/bar/:qux", + whenParams: []interface{}{"one", "two"}, + expect: "/params/one/bar/two", + }, + { + name: "ok, multi param + wildcard with all params", + whenRouteName: "/params/:foo/bar/:qux/*", + whenParams: []interface{}{"one", "two", "three"}, + expect: "/params/one/bar/two/three", + }, + { + name: "ok, backslash is not escaped", + whenRouteName: "/backslash", + whenParams: []interface{}{"test"}, + expect: `/a\b/test`, + }, + { + name: "ok, escaped colon verbs", + whenRouteName: "/params:customVerb", + whenParams: []interface{}{"PATCH"}, + expect: `/params:PATCH`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + dummyHandler := func(Context) error { return nil } + + e.GET("/static", dummyHandler).Name = "/static" + e.GET("/static/*", dummyHandler).Name = "/static/*" + e.GET("/params/:foo", dummyHandler).Name = "/params/:foo" + e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux" + e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*" + e.GET("/a\\b/:x", dummyHandler).Name = "/backslash" + e.GET("/params\\::customVerb", dummyHandler).Name = "/params:customVerb" - e.GET("/static", dummyHandler).Name = "/static" - e.GET("/static/*", dummyHandler).Name = "/static/*" - e.GET("/params/:foo", dummyHandler).Name = "/params/:foo" - e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux" - e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*" - - assert.Equal(t, "/static", e.Reverse("/static")) - assert.Equal(t, "/static", e.Reverse("/static", "missing param")) - assert.Equal(t, "/static/*", e.Reverse("/static/*")) - assert.Equal(t, "/static/foo.txt", e.Reverse("/static/*", "foo.txt")) - - assert.Equal(t, "/params/:foo", e.Reverse("/params/:foo")) - assert.Equal(t, "/params/one", e.Reverse("/params/:foo", "one")) - assert.Equal(t, "/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux")) - assert.Equal(t, "/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one")) - assert.Equal(t, "/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two")) - assert.Equal(t, "/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three")) + assert.Equal(t, tc.expect, e.Reverse(tc.whenRouteName, tc.whenParams...)) + }) + } } func TestEchoReverseHandleHostProperly(t *testing.T) { diff --git a/router.go b/router.go index 597660d39..50a6385ab 100644 --- a/router.go +++ b/router.go @@ -159,7 +159,12 @@ func (r *Router) Reverse(name string, params ...interface{}) string { for _, route := range r.routes { if route.Name == name { for i, l := 0, len(route.Path); i < l; i++ { - if (route.Path[i] == ':' || route.Path[i] == '*') && n < ln { + hasBackslash := route.Path[i] == '\\' + if hasBackslash && i+1 < l && route.Path[i+1] == ':' { + i++ // backslash before colon escapes that colon. in that case skip backslash + } + if n < ln && (route.Path[i] == '*' || (!hasBackslash && route.Path[i] == ':')) { + // in case of `*` wildcard or `:` (unescaped colon) param we replace everything till next slash or end of path for ; i < l && route.Path[i] != '/'; i++ { } uri.WriteString(fmt.Sprintf("%v", params[n])) From 0d47b7e6a93dfc9778a028c1a33d53d7ca52748e Mon Sep 17 00:00:00 2001 From: cui fliter Date: Fri, 21 Apr 2023 16:06:57 +0800 Subject: [PATCH 291/446] fix misuses of a vs an Signed-off-by: cui fliter --- echo.go | 2 +- router.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/echo.go b/echo.go index 085a3a7f2..9028b7a71 100644 --- a/echo.go +++ b/echo.go @@ -614,7 +614,7 @@ func (e *Echo) URL(h HandlerFunc, params ...interface{}) string { return e.URI(h, params...) } -// Reverse generates an URL from route name and provided parameters. +// Reverse generates a URL from route name and provided parameters. func (e *Echo) Reverse(name string, params ...interface{}) string { return e.router.Reverse(name, params...) } diff --git a/router.go b/router.go index 50a6385ab..ee6f3fa48 100644 --- a/router.go +++ b/router.go @@ -151,7 +151,7 @@ func (r *Router) Routes() []*Route { return routes } -// Reverse generates an URL from route name and provided parameters. +// Reverse generates a URL from route name and provided parameters. func (r *Router) Reverse(name string, params ...interface{}) string { uri := new(bytes.Buffer) ln := len(params) From deb17d2388a74cd4133f46c2dedfb7601da1db0a Mon Sep 17 00:00:00 2001 From: Samuel Berthe Date: Sun, 30 Apr 2023 22:39:52 +0200 Subject: [PATCH 292/446] Doc: adding slog.Handler for Echo logging --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index fe78b6ed1..ea8f30f64 100644 --- a/README.md +++ b/README.md @@ -110,6 +110,7 @@ of middlewares in this list. | [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | | [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | | [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | +| [github.com/samber/slog-echo](https://github.com/samber/slog-echo) | Go [slog](https://pkg.go.dev/golang.org/x/exp/slog) logging library wrapper for Echo logger interface. | | [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | | [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. | | [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code | From 0ae74648b9045eac5b5978061044c314a6fcd63a Mon Sep 17 00:00:00 2001 From: mikemherron <15673068+mikemherron@users.noreply.github.com> Date: Fri, 12 May 2023 18:36:24 +0100 Subject: [PATCH 293/446] Support retries of failed proxy requests (#2414) Support retries of failed proxy requests --- middleware/proxy.go | 162 +++++++++++++++----- middleware/proxy_test.go | 316 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 437 insertions(+), 41 deletions(-) diff --git a/middleware/proxy.go b/middleware/proxy.go index 74f49de8a..e4f98d9ed 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -29,6 +29,33 @@ type ( // Required. Balancer ProxyBalancer + // RetryCount defines the number of times a failed proxied request should be retried + // using the next available ProxyTarget. Defaults to 0, meaning requests are never retried. + RetryCount int + + // RetryFilter defines a function used to determine if a failed request to a + // ProxyTarget should be retried. The RetryFilter will only be called when the number + // of previous retries is less than RetryCount. If the function returns true, the + // request will be retried. The provided error indicates the reason for the request + // failure. When the ProxyTarget is unavailable, the error will be an instance of + // echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error + // will indicate an internal error in the Proxy middleware. When a RetryFilter is not + // specified, all requests that fail with http.StatusBadGateway will be retried. A custom + // RetryFilter can be provided to only retry specific requests. Note that RetryFilter is + // only called when the request to the target fails, or an internal error in the Proxy + // middleware has occurred. Successful requests that return a non-200 response code cannot + // be retried. + RetryFilter func(c echo.Context, e error) bool + + // ErrorHandler defines a function which can be used to return custom errors from + // the Proxy middleware. ErrorHandler is only invoked when there has been + // either an internal error in the Proxy middleware or the ProxyTarget is + // unavailable. Due to the way requests are proxied, ErrorHandler is not invoked + // when a ProxyTarget returns a non-200 response. In these cases, the response + // is already written so errors cannot be modified. ErrorHandler is only + // invoked after all retry attempts have been exhausted. + ErrorHandler func(c echo.Context, err error) error + // Rewrite defines URL path rewrite rules. The values captured in asterisk can be // retrieved by index e.g. $1, $2 and so on. // Examples: @@ -71,7 +98,8 @@ type ( Next(echo.Context) *ProxyTarget } - // TargetProvider defines an interface that gives the opportunity for balancer to return custom errors when selecting target. + // TargetProvider defines an interface that gives the opportunity for balancer + // to return custom errors when selecting target. TargetProvider interface { NextTarget(echo.Context) (*ProxyTarget, error) } @@ -107,14 +135,14 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { in, _, err := c.Response().Hijack() if err != nil { - c.Set("_error", fmt.Sprintf("proxy raw, hijack error=%v, url=%s", t.URL, err)) + c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL)) return } defer in.Close() out, err := net.Dial("tcp", t.URL.Host) if err != nil { - c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err))) + c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL))) return } defer out.Close() @@ -122,7 +150,7 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { // Write header err = r.Write(out) if err != nil { - c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err))) + c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", err, t.URL))) return } @@ -136,7 +164,7 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { go cp(in, out) err = <-errCh if err != nil && err != io.EOF { - c.Set("_error", fmt.Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err)) + c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err, t.URL)) } }) } @@ -200,7 +228,12 @@ func (b *randomBalancer) Next(c echo.Context) *ProxyTarget { return b.targets[b.random.Intn(len(b.targets))] } -// Next returns an upstream target using round-robin technique. +// Next returns an upstream target using round-robin technique. In the case +// where a previously failed request is being retried, the round-robin +// balancer will attempt to use the next target relative to the original +// request. If the list of targets held by the balancer is modified while a +// failed request is being retried, it is possible that the balancer will +// return the original failed target. // // Note: `nil` is returned in case upstream target list is empty. func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget { @@ -211,13 +244,29 @@ func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget { } else if len(b.targets) == 1 { return b.targets[0] } - // reset the index if out of bounds - if b.i >= len(b.targets) { - b.i = 0 + + var i int + const lastIdxKey = "_round_robin_last_index" + // This request is a retry, start from the index of the previous + // target to ensure we don't attempt to retry the request with + // the same failed target + if c.Get(lastIdxKey) != nil { + i = c.Get(lastIdxKey).(int) + i++ + if i >= len(b.targets) { + i = 0 + } + } else { + // This is a first time request, use the global index + if b.i >= len(b.targets) { + b.i = 0 + } + i = b.i + b.i++ } - t := b.targets[b.i] - b.i++ - return t + + c.Set(lastIdxKey, i) + return b.targets[i] } // Proxy returns a Proxy middleware. @@ -232,14 +281,26 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc { // ProxyWithConfig returns a Proxy middleware with config. // See: `Proxy()` func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { + if config.Balancer == nil { + panic("echo: proxy middleware requires balancer") + } // Defaults if config.Skipper == nil { config.Skipper = DefaultProxyConfig.Skipper } - if config.Balancer == nil { - panic("echo: proxy middleware requires balancer") + if config.RetryFilter == nil { + config.RetryFilter = func(c echo.Context, e error) bool { + if httpErr, ok := e.(*echo.HTTPError); ok { + return httpErr.Code == http.StatusBadGateway + } + return false + } + } + if config.ErrorHandler == nil { + config.ErrorHandler = func(c echo.Context, err error) error { + return err + } } - if config.Rewrite != nil { if config.RegexRewrite == nil { config.RegexRewrite = make(map[*regexp.Regexp]string) @@ -250,28 +311,17 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { } provider, isTargetProvider := config.Balancer.(TargetProvider) + return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c echo.Context) error { if config.Skipper(c) { return next(c) } req := c.Request() res := c.Response() - - var tgt *ProxyTarget - if isTargetProvider { - tgt, err = provider.NextTarget(c) - if err != nil { - return err - } - } else { - tgt = config.Balancer.Next(c) - } - c.Set(config.ContextKey, tgt) - if err := rewriteURL(config.RegexRewrite, req); err != nil { - return err + return config.ErrorHandler(c, err) } // Fix header @@ -287,19 +337,49 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { req.Header.Set(echo.HeaderXForwardedFor, c.RealIP()) } - // Proxy - switch { - case c.IsWebSocket(): - proxyRaw(tgt, c).ServeHTTP(res, req) - case req.Header.Get(echo.HeaderAccept) == "text/event-stream": - default: - proxyHTTP(tgt, c, config).ServeHTTP(res, req) - } - if e, ok := c.Get("_error").(error); ok { - err = e - } + retries := config.RetryCount + for { + var tgt *ProxyTarget + var err error + if isTargetProvider { + tgt, err = provider.NextTarget(c) + if err != nil { + return config.ErrorHandler(c, err) + } + } else { + tgt = config.Balancer.Next(c) + } - return + c.Set(config.ContextKey, tgt) + + //If retrying a failed request, clear any previous errors from + //context here so that balancers have the option to check for + //errors that occurred using previous target + if retries < config.RetryCount { + c.Set("_error", nil) + } + + // Proxy + switch { + case c.IsWebSocket(): + proxyRaw(tgt, c).ServeHTTP(res, req) + case req.Header.Get(echo.HeaderAccept) == "text/event-stream": + default: + proxyHTTP(tgt, c, config).ServeHTTP(res, req) + } + + err, hasError := c.Get("_error").(error) + if !hasError { + return nil + } + + retry := retries > 0 && config.RetryFilter(c, err) + if !retry { + return config.ErrorHandler(c, err) + } + + retries-- + } } } } diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 122dddeba..1b5ba6cbe 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -3,6 +3,7 @@ package middleware import ( "bytes" "context" + "errors" "fmt" "io" "net" @@ -393,6 +394,321 @@ func TestProxyError(t *testing.T) { assert.Equal(t, http.StatusBadGateway, rec.Code) } +func TestProxyRetries(t *testing.T) { + + newServer := func(res int) (*url.URL, *httptest.Server) { + server := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(res) + }), + ) + targetURL, _ := url.Parse(server.URL) + return targetURL, server + } + + targetURL, server := newServer(http.StatusOK) + defer server.Close() + goodTarget := &ProxyTarget{ + Name: "Good", + URL: targetURL, + } + + targetURL, server = newServer(http.StatusBadRequest) + defer server.Close() + goodTargetWith40X := &ProxyTarget{ + Name: "Good with 40X", + URL: targetURL, + } + + targetURL, _ = url.Parse("http://127.0.0.1:27121") + badTarget := &ProxyTarget{ + Name: "Bad", + URL: targetURL, + } + + alwaysRetryFilter := func(c echo.Context, e error) bool { return true } + neverRetryFilter := func(c echo.Context, e error) bool { return false } + + testCases := []struct { + name string + retryCount int + retryFilters []func(c echo.Context, e error) bool + targets []*ProxyTarget + expectedResponse int + }{ + { + name: "retry count 0 does not attempt retry on fail", + targets: []*ProxyTarget{ + badTarget, + goodTarget, + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 1 does not attempt retry on success", + retryCount: 1, + targets: []*ProxyTarget{ + goodTarget, + }, + expectedResponse: http.StatusOK, + }, + { + name: "retry count 1 does retry on handler return true", + retryCount: 1, + retryFilters: []func(c echo.Context, e error) bool{ + alwaysRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + goodTarget, + }, + expectedResponse: http.StatusOK, + }, + { + name: "retry count 1 does not retry on handler return false", + retryCount: 1, + retryFilters: []func(c echo.Context, e error) bool{ + neverRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + goodTarget, + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 2 returns error when no more retries left", + retryCount: 2, + retryFilters: []func(c echo.Context, e error) bool{ + alwaysRetryFilter, + alwaysRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + badTarget, + badTarget, + goodTarget, //Should never be reached as only 2 retries + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 2 returns error when retries left but handler returns false", + retryCount: 3, + retryFilters: []func(c echo.Context, e error) bool{ + alwaysRetryFilter, + alwaysRetryFilter, + neverRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + badTarget, + badTarget, + goodTarget, //Should never be reached as retry handler returns false on 2nd check + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 3 succeeds", + retryCount: 3, + retryFilters: []func(c echo.Context, e error) bool{ + alwaysRetryFilter, + alwaysRetryFilter, + alwaysRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + badTarget, + badTarget, + goodTarget, + }, + expectedResponse: http.StatusOK, + }, + { + name: "40x responses are not retried", + retryCount: 1, + targets: []*ProxyTarget{ + goodTargetWith40X, + goodTarget, + }, + expectedResponse: http.StatusBadRequest, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + retryFilterCall := 0 + retryFilter := func(c echo.Context, e error) bool { + if len(tc.retryFilters) == 0 { + assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall)) + } + + retryFilterCall++ + + nextRetryFilter := tc.retryFilters[0] + tc.retryFilters = tc.retryFilters[1:] + + return nextRetryFilter(c, e) + } + + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Balancer: NewRoundRobinBalancer(tc.targets), + RetryCount: tc.retryCount, + RetryFilter: retryFilter, + }, + )) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedResponse, rec.Code) + if len(tc.retryFilters) > 0 { + assert.FailNow(t, fmt.Sprintf("expected %d more retry handler calls", len(tc.retryFilters))) + } + }) + } +} + +func TestProxyRetryWithBackendTimeout(t *testing.T) { + + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.ResponseHeaderTimeout = time.Millisecond * 500 + + timeoutBackend := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(1 * time.Second) + w.WriteHeader(404) + }), + ) + defer timeoutBackend.Close() + + timeoutTargetURL, _ := url.Parse(timeoutBackend.URL) + goodBackend := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + }), + ) + defer goodBackend.Close() + + goodTargetURL, _ := url.Parse(goodBackend.URL) + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Transport: transport, + Balancer: NewRoundRobinBalancer([]*ProxyTarget{ + { + Name: "Timeout", + URL: timeoutTargetURL, + }, + { + Name: "Good", + URL: goodTargetURL, + }, + }), + RetryCount: 1, + }, + )) + + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, 200, rec.Code) + }() + } + + wg.Wait() + +} + +func TestProxyErrorHandler(t *testing.T) { + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + goodURL, _ := url.Parse(server.URL) + defer server.Close() + goodTarget := &ProxyTarget{ + Name: "Good", + URL: goodURL, + } + + badURL, _ := url.Parse("http://127.0.0.1:27121") + badTarget := &ProxyTarget{ + Name: "Bad", + URL: badURL, + } + + transformedError := errors.New("a new error") + + testCases := []struct { + name string + target *ProxyTarget + errorHandler func(c echo.Context, e error) error + expectFinalError func(t *testing.T, err error) + }{ + { + name: "Error handler not invoked when request success", + target: goodTarget, + errorHandler: func(c echo.Context, e error) error { + assert.FailNow(t, "error handler should not be invoked") + return e + }, + }, + { + name: "Error handler invoked when request fails", + target: badTarget, + errorHandler: func(c echo.Context, e error) error { + httpErr, ok := e.(*echo.HTTPError) + assert.True(t, ok, "expected http error to be passed to handler") + assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler") + return transformedError + }, + expectFinalError: func(t *testing.T, err error) { + assert.Equal(t, transformedError, err, "transformed error not returned from proxy") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Balancer: NewRoundRobinBalancer([]*ProxyTarget{tc.target}), + ErrorHandler: tc.errorHandler, + }, + )) + + errorHandlerCalled := false + e.HTTPErrorHandler = func(err error, c echo.Context) { + errorHandlerCalled = true + tc.expectFinalError(t, err) + e.DefaultHTTPErrorHandler(err, c) + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + if !errorHandlerCalled && tc.expectFinalError != nil { + t.Fatalf("error handler was not called") + } + + }) + } +} + func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { var timeoutStop sync.WaitGroup timeoutStop.Add(1) From 8e425c04311cc1efb896e7d5a7d7cbcafbf03a60 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Fri, 12 May 2023 20:14:59 +0300 Subject: [PATCH 294/446] gofmt fixes to comments --- bind.go | 2 +- binder.go | 16 ++++++++-------- middleware/basic_auth.go | 2 +- middleware/decompress.go | 6 +++--- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/bind.go b/bind.go index c841ca010..374a2aec5 100644 --- a/bind.go +++ b/bind.go @@ -114,7 +114,7 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { // Only bind query parameters for GET/DELETE/HEAD to avoid unexpected behavior with destination struct binding from body. // For example a request URL `&id=1&lang=en` with body `{"id":100,"lang":"de"}` would lead to precedence issues. // The HTTP method check restores pre-v4.1.11 behavior to avoid these problems (see issue #1670) - method := c.Request().Method + method := c.Request().Method if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead { if err = b.BindQueryParams(c, i); err != nil { return err diff --git a/binder.go b/binder.go index 5a6cf9d9b..29cceca0b 100644 --- a/binder.go +++ b/binder.go @@ -1236,7 +1236,7 @@ func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]tim // Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, false, time.Second) } @@ -1247,7 +1247,7 @@ func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder // Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, true, time.Second) } @@ -1257,7 +1257,7 @@ func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBi // Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, false, time.Millisecond) } @@ -1268,7 +1268,7 @@ func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueB // Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, true, time.Millisecond) } @@ -1280,8 +1280,8 @@ func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *Va // Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal -// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, false, time.Nanosecond) } @@ -1294,8 +1294,8 @@ func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBi // Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00 // // Note: -// * time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal -// * Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { return b.unixTime(sourceParam, dest, true, time.Nanosecond) } diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 52ef1042f..f9e8caafe 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -2,9 +2,9 @@ package middleware import ( "encoding/base64" + "net/http" "strconv" "strings" - "net/http" "github.com/labstack/echo/v4" ) diff --git a/middleware/decompress.go b/middleware/decompress.go index 88ec70982..a73c9738b 100644 --- a/middleware/decompress.go +++ b/middleware/decompress.go @@ -20,7 +20,7 @@ type ( } ) -//GZIPEncoding content-encoding header if set to "gzip", decompress body contents. +// GZIPEncoding content-encoding header if set to "gzip", decompress body contents. const GZIPEncoding string = "gzip" // Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers @@ -44,12 +44,12 @@ func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { return sync.Pool{New: func() interface{} { return new(gzip.Reader) }} } -//Decompress decompresses request body based if content encoding type is set to "gzip" with default config +// Decompress decompresses request body based if content encoding type is set to "gzip" with default config func Decompress() echo.MiddlewareFunc { return DecompressWithConfig(DefaultDecompressConfig) } -//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config +// DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { // Defaults if config.Skipper == nil { From fbfe2167f1d20a2febe59770ca0500652df6c27e Mon Sep 17 00:00:00 2001 From: Martin Desrumaux <9059840+gnuletik@users.noreply.github.com> Date: Mon, 29 May 2023 22:26:53 +0200 Subject: [PATCH 295/446] fix(DefaultHTTPErrorHandler): return error message when message is an error (#2456) * fix(DefaultHTTPErrorHandler): return error message when message is an error --- echo.go | 9 ++- echo_test.go | 192 +++++++++++++++++++++++++++++++++++---------------- 2 files changed, 141 insertions(+), 60 deletions(-) diff --git a/echo.go b/echo.go index 9028b7a71..e21635466 100644 --- a/echo.go +++ b/echo.go @@ -39,6 +39,7 @@ package echo import ( stdContext "context" "crypto/tls" + "encoding/json" "errors" "fmt" "io" @@ -438,12 +439,18 @@ func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { // Issue #1426 code := he.Code message := he.Message - if m, ok := he.Message.(string); ok { + + switch m := he.Message.(type) { + case string: if e.Debug { message = Map{"message": m, "error": err.Error()} } else { message = Map{"message": m} } + case json.Marshaler: + // do nothing - this type knows how to format itself to JSON + case error: + message = Map{"message": m.Error()} } // Send response diff --git a/echo_test.go b/echo_test.go index eab25db33..a352e4026 100644 --- a/echo_test.go +++ b/echo_test.go @@ -1286,67 +1286,141 @@ func TestHTTPError_Unwrap(t *testing.T) { }) } +type customError struct { + s string +} + +func (ce *customError) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`{"x":"%v"}`, ce.s)), nil +} + +func (ce *customError) Error() string { + return ce.s +} + func TestDefaultHTTPErrorHandler(t *testing.T) { - e := New() - e.Debug = true - e.Any("/plain", func(c Context) error { - return errors.New("an error occurred") - }) - e.Any("/badrequest", func(c Context) error { - return NewHTTPError(http.StatusBadRequest, "Invalid request") - }) - e.Any("/servererror", func(c Context) error { - return NewHTTPError(http.StatusInternalServerError, map[string]interface{}{ - "code": 33, - "message": "Something bad happened", - "error": "stackinfo", - }) - }) - e.Any("/early-return", func(c Context) error { - err := c.String(http.StatusOK, "OK") - if err != nil { - assert.Fail(t, err.Error()) - } - return errors.New("ERROR") - }) - e.GET("/internal-error", func(c Context) error { - err := errors.New("internal error message body") - return NewHTTPError(http.StatusBadRequest).SetInternal(err) - }) + var testCases = []struct { + name string + givenDebug bool + whenPath string + expectCode int + expectBody string + }{ + { + name: "with Debug=true plain response contains error message", + givenDebug: true, + whenPath: "/plain", + expectCode: http.StatusInternalServerError, + expectBody: "{\n \"error\": \"an error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", + }, + { + name: "with Debug=true special handling for HTTPError", + givenDebug: true, + whenPath: "/badrequest", + expectCode: http.StatusBadRequest, + expectBody: "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n", + }, + { + name: "with Debug=true complex errors are serialized to pretty JSON", + givenDebug: true, + whenPath: "/servererror", + expectCode: http.StatusInternalServerError, + expectBody: "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", + }, + { + name: "with Debug=true if the body is already set HTTPErrorHandler should not add anything to response body", + givenDebug: true, + whenPath: "/early-return", + expectCode: http.StatusOK, + expectBody: "OK", + }, + { + name: "with Debug=true internal error should be reflected in the message", + givenDebug: true, + whenPath: "/internal-error", + expectCode: http.StatusBadRequest, + expectBody: "{\n \"error\": \"code=400, message=Bad Request, internal=internal error message body\",\n \"message\": \"Bad Request\"\n}\n", + }, + { + name: "with Debug=false the error response is shortened", + whenPath: "/plain", + expectCode: http.StatusInternalServerError, + expectBody: "{\"message\":\"Internal Server Error\"}\n", + }, + { + name: "with Debug=false the error response is shortened", + whenPath: "/badrequest", + expectCode: http.StatusBadRequest, + expectBody: "{\"message\":\"Invalid request\"}\n", + }, + { + name: "with Debug=false No difference for error response with non plain string errors", + whenPath: "/servererror", + expectCode: http.StatusInternalServerError, + expectBody: "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n", + }, + { + name: "with Debug=false when httpError contains an error", + whenPath: "/error-in-httperror", + expectCode: http.StatusBadRequest, + expectBody: "{\"message\":\"error in httperror\"}\n", + }, + { + name: "with Debug=false when httpError contains an error", + whenPath: "/customerror-in-httperror", + expectCode: http.StatusBadRequest, + expectBody: "{\"x\":\"custom error msg\"}\n", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Debug = tc.givenDebug // With Debug=true plain response contains error message - // With Debug=true plain response contains error message - c, b := request(http.MethodGet, "/plain", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\n \"error\": \"an error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", b) - // and special handling for HTTPError - c, b = request(http.MethodGet, "/badrequest", e) - assert.Equal(t, http.StatusBadRequest, c) - assert.Equal(t, "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n", b) - // complex errors are serialized to pretty JSON - c, b = request(http.MethodGet, "/servererror", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", b) - // if the body is already set HTTPErrorHandler should not add anything to response body - c, b = request(http.MethodGet, "/early-return", e) - assert.Equal(t, http.StatusOK, c) - assert.Equal(t, "OK", b) - // internal error should be reflected in the message - c, b = request(http.MethodGet, "/internal-error", e) - assert.Equal(t, http.StatusBadRequest, c) - assert.Equal(t, "{\n \"error\": \"code=400, message=Bad Request, internal=internal error message body\",\n \"message\": \"Bad Request\"\n}\n", b) - - e.Debug = false - // With Debug=false the error response is shortened - c, b = request(http.MethodGet, "/plain", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\"message\":\"Internal Server Error\"}\n", b) - c, b = request(http.MethodGet, "/badrequest", e) - assert.Equal(t, http.StatusBadRequest, c) - assert.Equal(t, "{\"message\":\"Invalid request\"}\n", b) - // No difference for error response with non plain string errors - c, b = request(http.MethodGet, "/servererror", e) - assert.Equal(t, http.StatusInternalServerError, c) - assert.Equal(t, "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n", b) + e.Any("/plain", func(c Context) error { + return errors.New("an error occurred") + }) + + e.Any("/badrequest", func(c Context) error { // and special handling for HTTPError + return NewHTTPError(http.StatusBadRequest, "Invalid request") + }) + + e.Any("/servererror", func(c Context) error { // complex errors are serialized to pretty JSON + return NewHTTPError(http.StatusInternalServerError, map[string]interface{}{ + "code": 33, + "message": "Something bad happened", + "error": "stackinfo", + }) + }) + + // if the body is already set HTTPErrorHandler should not add anything to response body + e.Any("/early-return", func(c Context) error { + err := c.String(http.StatusOK, "OK") + if err != nil { + assert.Fail(t, err.Error()) + } + return errors.New("ERROR") + }) + + // internal error should be reflected in the message + e.GET("/internal-error", func(c Context) error { + err := errors.New("internal error message body") + return NewHTTPError(http.StatusBadRequest).SetInternal(err) + }) + + e.GET("/error-in-httperror", func(c Context) error { + return NewHTTPError(http.StatusBadRequest, errors.New("error in httperror")) + }) + + e.GET("/customerror-in-httperror", func(c Context) error { + return NewHTTPError(http.StatusBadRequest, &customError{s: "custom error msg"}) + }) + + c, b := request(http.MethodGet, tc.whenPath, e) + assert.Equal(t, tc.expectCode, c) + assert.Equal(t, tc.expectBody, b) + }) + } } func TestEchoClose(t *testing.T) { From 42f07ed880400b8bb80906dfec8138c572748ae8 Mon Sep 17 00:00:00 2001 From: Ingo Oppermann Date: Wed, 31 May 2023 07:53:33 +0200 Subject: [PATCH 296/446] gzip response only if it exceeds a minimal length (#2267) * gzip response only if it exceeds a minimal length If the response is too short, e.g. a few bytes, compressing the response makes it even larger. The new parameter MinLength to the GzipConfig struct allows to set a threshold (in bytes) as of which response size the compression should be applied. If the response is shorter, no compression will be applied. --- middleware/compress.go | 91 ++++++++++++++++++++++++++-- middleware/compress_test.go | 117 ++++++++++++++++++++++++++++++++++++ 2 files changed, 202 insertions(+), 6 deletions(-) diff --git a/middleware/compress.go b/middleware/compress.go index 9e5f61069..cbe29fc32 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -2,6 +2,7 @@ package middleware import ( "bufio" + "bytes" "compress/gzip" "io" "net" @@ -21,12 +22,30 @@ type ( // Gzip compression level. // Optional. Default value -1. Level int `yaml:"level"` + + // Length threshold before gzip compression is applied. + // Optional. Default value 0. + // + // Most of the time you will not need to change the default. Compressing + // a short response might increase the transmitted data because of the + // gzip format overhead. Compressing the response will also consume CPU + // and time on the server and the client (for decompressing). Depending on + // your use case such a threshold might be useful. + // + // See also: + // https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits + MinLength int } gzipResponseWriter struct { io.Writer http.ResponseWriter - wroteBody bool + wroteHeader bool + wroteBody bool + minLength int + minLengthExceeded bool + buffer *bytes.Buffer + code int } ) @@ -37,8 +56,9 @@ const ( var ( // DefaultGzipConfig is the default Gzip middleware config. DefaultGzipConfig = GzipConfig{ - Skipper: DefaultSkipper, - Level: -1, + Skipper: DefaultSkipper, + Level: -1, + MinLength: 0, } ) @@ -58,8 +78,12 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { if config.Level == 0 { config.Level = DefaultGzipConfig.Level } + if config.MinLength < 0 { + config.MinLength = DefaultGzipConfig.MinLength + } pool := gzipCompressPool(config) + bpool := bufferPool() return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -70,7 +94,6 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { res := c.Response() res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) { - res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 i := pool.Get() w, ok := i.(*gzip.Writer) if !ok { @@ -78,7 +101,11 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { } rw := res.Writer w.Reset(rw) - grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw} + + buf := bpool.Get().(*bytes.Buffer) + buf.Reset() + + grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf} defer func() { if !grw.wroteBody { if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { @@ -89,8 +116,17 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { // See issue #424, #407. res.Writer = rw w.Reset(io.Discard) + } else if !grw.minLengthExceeded { + // Write uncompressed response + res.Writer = rw + if grw.wroteHeader { + grw.ResponseWriter.WriteHeader(grw.code) + } + grw.buffer.WriteTo(rw) + w.Reset(io.Discard) } w.Close() + bpool.Put(buf) pool.Put(w) }() res.Writer = grw @@ -102,7 +138,11 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { func (w *gzipResponseWriter) WriteHeader(code int) { w.Header().Del(echo.HeaderContentLength) // Issue #444 - w.ResponseWriter.WriteHeader(code) + + w.wroteHeader = true + + // Delay writing of the header until we know if we'll actually compress the response + w.code = code } func (w *gzipResponseWriter) Write(b []byte) (int, error) { @@ -110,10 +150,40 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) { w.Header().Set(echo.HeaderContentType, http.DetectContentType(b)) } w.wroteBody = true + + if !w.minLengthExceeded { + n, err := w.buffer.Write(b) + + if w.buffer.Len() >= w.minLength { + w.minLengthExceeded = true + + // The minimum length is exceeded, add Content-Encoding header and write the header + w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 + if w.wroteHeader { + w.ResponseWriter.WriteHeader(w.code) + } + + return w.Writer.Write(w.buffer.Bytes()) + } + + return n, err + } + return w.Writer.Write(b) } func (w *gzipResponseWriter) Flush() { + if !w.minLengthExceeded { + // Enforce compression because we will not know how much more data will come + w.minLengthExceeded = true + w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 + if w.wroteHeader { + w.ResponseWriter.WriteHeader(w.code) + } + + w.Writer.Write(w.buffer.Bytes()) + } + w.Writer.(*gzip.Writer).Flush() if flusher, ok := w.ResponseWriter.(http.Flusher); ok { flusher.Flush() @@ -142,3 +212,12 @@ func gzipCompressPool(config GzipConfig) sync.Pool { }, } } + +func bufferPool() sync.Pool { + return sync.Pool{ + New: func() interface{} { + b := &bytes.Buffer{} + return b + }, + } +} diff --git a/middleware/compress_test.go b/middleware/compress_test.go index 714548e8b..e43e2d633 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -88,6 +88,123 @@ func TestGzip(t *testing.T) { assert.Equal(t, "test", buf.String()) } +func TestGzipWithMinLength(t *testing.T) { + assert := assert.New(t) + + e := echo.New() + // Minimal response length + e.Use(GzipWithConfig(GzipConfig{MinLength: 10})) + e.GET("/", func(c echo.Context) error { + c.Response().Write([]byte("foobarfoobar")) + return nil + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + r, err := gzip.NewReader(rec.Body) + if assert.NoError(err) { + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal("foobarfoobar", buf.String()) + } +} + +func TestGzipWithMinLengthTooShort(t *testing.T) { + assert := assert.New(t) + + e := echo.New() + // Minimal response length + e.Use(GzipWithConfig(GzipConfig{MinLength: 10})) + e.GET("/", func(c echo.Context) error { + c.Response().Write([]byte("test")) + return nil + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal("", rec.Header().Get(echo.HeaderContentEncoding)) + assert.Contains(rec.Body.String(), "test") +} + +func TestGzipWithMinLengthChunked(t *testing.T) { + assert := assert.New(t) + + e := echo.New() + + // Gzip chunked + chunkBuf := make([]byte, 5) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + + var r *gzip.Reader = nil + + c := e.NewContext(req, rec) + GzipWithConfig(GzipConfig{MinLength: 10})(func(c echo.Context) error { + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Transfer-Encoding", "chunked") + + // Write and flush the first part of the data + c.Response().Write([]byte("test\n")) + c.Response().Flush() + + // Read the first part of the data + assert.True(rec.Flushed) + assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + + var err error + r, err = gzip.NewReader(rec.Body) + assert.NoError(err) + + _, err = io.ReadFull(r, chunkBuf) + assert.NoError(err) + assert.Equal("test\n", string(chunkBuf)) + + // Write and flush the second part of the data + c.Response().Write([]byte("test\n")) + c.Response().Flush() + + _, err = io.ReadFull(r, chunkBuf) + assert.NoError(err) + assert.Equal("test\n", string(chunkBuf)) + + // Write the final part of the data and return + c.Response().Write([]byte("test")) + return nil + })(c) + + assert.NotNil(r) + + buf := new(bytes.Buffer) + + buf.ReadFrom(r) + assert.Equal("test", buf.String()) + + r.Close() +} + +func TestGzipWithMinLengthNoContent(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c echo.Context) error { + return c.NoContent(http.StatusNoContent) + }) + if assert.NoError(t, h(c)) { + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) + assert.Equal(t, 0, len(rec.Body.Bytes())) + } +} + func TestGzipNoContent(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) From 44ead54c8c99850dbdeaac16842ff0fe7e5dbeb5 Mon Sep 17 00:00:00 2001 From: bahdanmelchankatote <124774625+bahdanmelchankatote@users.noreply.github.com> Date: Mon, 10 Jul 2023 12:24:39 +0300 Subject: [PATCH 297/446] Upgrade packages (#2475) --- go.mod | 10 +++++----- go.sum | 25 +++++++++++++++++++------ 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/go.mod b/go.mod index 265b0aafc..fe2fd4e54 100644 --- a/go.mod +++ b/go.mod @@ -7,18 +7,18 @@ require ( github.com/labstack/gommon v0.4.0 github.com/stretchr/testify v1.8.1 github.com/valyala/fasttemplate v1.2.2 - golang.org/x/crypto v0.6.0 - golang.org/x/net v0.7.0 + golang.org/x/crypto v0.11.0 + golang.org/x/net v0.12.0 golang.org/x/time v0.3.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.17 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.5.0 // indirect - golang.org/x/text v0.7.0 // indirect + golang.org/x/sys v0.10.0 // indirect + golang.org/x/text v0.11.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 79ff318c5..41490b181 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27k github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -30,17 +32,20 @@ github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= -golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= +golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= +golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -50,21 +55,29 @@ golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= +golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From 1ee8e22faa4ee7ada2dd6927665113ac8a35e62f Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 11 Jul 2023 23:36:05 +0300 Subject: [PATCH 298/446] do not use global timeNow variables (#2477) --- middleware/rate_limiter.go | 21 ++++++++++----------- middleware/rate_limiter_test.go | 13 +++++-------- middleware/request_logger.go | 2 +- 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index f7fae83c6..1d24df52a 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -160,6 +160,8 @@ type ( burst int expiresIn time.Duration lastCleanup time.Time + + timeNow func() time.Time } // Visitor signifies a unique user's limiter details Visitor struct { @@ -219,7 +221,8 @@ func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (s store.burst = int(config.Rate) } store.visitors = make(map[string]*Visitor) - store.lastCleanup = now() + store.timeNow = time.Now + store.lastCleanup = store.timeNow() return } @@ -244,12 +247,13 @@ func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { limiter.Limiter = rate.NewLimiter(store.rate, store.burst) store.visitors[identifier] = limiter } - limiter.lastSeen = now() - if now().Sub(store.lastCleanup) > store.expiresIn { + now := store.timeNow() + limiter.lastSeen = now + if now.Sub(store.lastCleanup) > store.expiresIn { store.cleanupStaleVisitors() } store.mutex.Unlock() - return limiter.AllowN(now(), 1), nil + return limiter.AllowN(store.timeNow(), 1), nil } /* @@ -258,14 +262,9 @@ of users who haven't visited again after the configured expiry time has elapsed */ func (store *RateLimiterMemoryStore) cleanupStaleVisitors() { for id, visitor := range store.visitors { - if now().Sub(visitor.lastSeen) > store.expiresIn { + if store.timeNow().Sub(visitor.lastSeen) > store.expiresIn { delete(store.visitors, id) } } - store.lastCleanup = now() + store.lastCleanup = store.timeNow() } - -/* -actual time method which is mocked in test file -*/ -var now = time.Now diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go index 89d9a6edc..0f7c9141d 100644 --- a/middleware/rate_limiter_test.go +++ b/middleware/rate_limiter_test.go @@ -2,7 +2,6 @@ package middleware import ( "errors" - "fmt" "math/rand" "net/http" "net/http/httptest" @@ -340,7 +339,7 @@ func TestRateLimiterMemoryStore_Allow(t *testing.T) { for i, tc := range testCases { t.Logf("Running testcase #%d => %v", i, time.Duration(i)*220*time.Millisecond) - now = func() time.Time { + inMemoryStore.timeNow = func() time.Time { return time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).Add(time.Duration(i) * 220 * time.Millisecond) } allowed, _ := inMemoryStore.Allow(tc.id) @@ -350,24 +349,22 @@ func TestRateLimiterMemoryStore_Allow(t *testing.T) { func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) - now = time.Now - fmt.Println(now()) inMemoryStore.visitors = map[string]*Visitor{ "A": { Limiter: rate.NewLimiter(1, 3), - lastSeen: now(), + lastSeen: time.Now(), }, "B": { Limiter: rate.NewLimiter(1, 3), - lastSeen: now().Add(-1 * time.Minute), + lastSeen: time.Now().Add(-1 * time.Minute), }, "C": { Limiter: rate.NewLimiter(1, 3), - lastSeen: now().Add(-5 * time.Minute), + lastSeen: time.Now().Add(-5 * time.Minute), }, "D": { Limiter: rate.NewLimiter(1, 3), - lastSeen: now().Add(-10 * time.Minute), + lastSeen: time.Now().Add(-10 * time.Minute), }, } diff --git a/middleware/request_logger.go b/middleware/request_logger.go index 8e312e8d8..ce76230c7 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -225,7 +225,7 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultSkipper } - now = time.Now + now := time.Now if config.timeNow != nil { now = config.timeNow } From ac7a9621a17a875d95eee089b555518772854989 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Fri, 14 Jul 2023 08:58:51 +0300 Subject: [PATCH 299/446] bump version to 4.10.0 --- echo.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/echo.go b/echo.go index e21635466..e91226ed6 100644 --- a/echo.go +++ b/echo.go @@ -259,7 +259,7 @@ const ( const ( // Version of Echo - Version = "4.10.2" + Version = "4.11.0" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 60af056959d5ddfb0e8db8dec2f597d72e27a58b Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Fri, 14 Jul 2023 08:59:02 +0300 Subject: [PATCH 300/446] Changelog for v4.11.0 --- CHANGELOG.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 831842497..8c405e205 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,31 @@ # Changelog +## v4.11.0 - 2023-07-14 + + +**Fixes** + +* Fixes the proxy middleware concurrency issue of calling the Next() proxy target on Round Robin Balancer [#2409](https://github.com/labstack/echo/pull/2409) +* Fix `group.RouteNotFound` not working when group has attached middlewares [#2411](https://github.com/labstack/echo/pull/2411) +* Fix global error handler return error message when message is an error [#2456](https://github.com/labstack/echo/pull/2456) +* Do not use global timeNow variables [#2477](https://github.com/labstack/echo/pull/2477) + + +**Enhancements** + +* Added a optional config variable to disable centralized error handler in recovery middleware [#2410](https://github.com/labstack/echo/pull/2410) +* refactor: use `strings.ReplaceAll` directly [#2424](https://github.com/labstack/echo/pull/2424) +* Add support for Go1.20 `http.rwUnwrapper` to Response struct [#2425](https://github.com/labstack/echo/pull/2425) +* Check whether is nil before invoking centralized error handling [#2429](https://github.com/labstack/echo/pull/2429) +* Proper colon support in `echo.Reverse` method [#2416](https://github.com/labstack/echo/pull/2416) +* Fix misuses of a vs an in documentation comments [#2436](https://github.com/labstack/echo/pull/2436) +* Add link to slog.Handler library for Echo logging into README.md [#2444](https://github.com/labstack/echo/pull/2444) +* In proxy middleware Support retries of failed proxy requests [#2414](https://github.com/labstack/echo/pull/2414) +* gofmt fixes to comments [#2452](https://github.com/labstack/echo/pull/2452) +* gzip response only if it exceeds a minimal length [#2267](https://github.com/labstack/echo/pull/2267) +* Upgrade packages [#2475](https://github.com/labstack/echo/pull/2475) + + ## v4.10.2 - 2023-02-22 **Security** From 130be0742560d4e7502537077a2667b4fe5adce8 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 16 Jul 2023 20:15:24 +0300 Subject: [PATCH 301/446] fix gzip not sending response code for no content responses (404, 301/302 redirects etc) --- middleware/compress.go | 6 +++++ middleware/compress_test.go | 52 +++++++++++++++++++++++-------------- 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/middleware/compress.go b/middleware/compress.go index cbe29fc32..3e9bd3201 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -107,10 +107,16 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf} defer func() { + // There are different reasons for cases when we have not yet written response to the client and now need to do so. + // a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now. + // b) body is shorter than our minimum length threshold and being buffered currently and needs to be written if !grw.wroteBody { if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { res.Header().Del(echo.HeaderContentEncoding) } + if grw.wroteHeader { + rw.WriteHeader(grw.code) + } // We have to reset response to it's pristine state when // nothing is written to body or error is returned. // See issue #424, #407. diff --git a/middleware/compress_test.go b/middleware/compress_test.go index e43e2d633..0ed16c813 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -89,8 +89,6 @@ func TestGzip(t *testing.T) { } func TestGzipWithMinLength(t *testing.T) { - assert := assert.New(t) - e := echo.New() // Minimal response length e.Use(GzipWithConfig(GzipConfig{MinLength: 10})) @@ -103,19 +101,17 @@ func TestGzipWithMinLength(t *testing.T) { req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) r, err := gzip.NewReader(rec.Body) - if assert.NoError(err) { + if assert.NoError(t, err) { buf := new(bytes.Buffer) defer r.Close() buf.ReadFrom(r) - assert.Equal("foobarfoobar", buf.String()) + assert.Equal(t, "foobarfoobar", buf.String()) } } func TestGzipWithMinLengthTooShort(t *testing.T) { - assert := assert.New(t) - e := echo.New() // Minimal response length e.Use(GzipWithConfig(GzipConfig{MinLength: 10})) @@ -127,13 +123,29 @@ func TestGzipWithMinLengthTooShort(t *testing.T) { req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal("", rec.Header().Get(echo.HeaderContentEncoding)) - assert.Contains(rec.Body.String(), "test") + assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding)) + assert.Contains(t, rec.Body.String(), "test") } -func TestGzipWithMinLengthChunked(t *testing.T) { - assert := assert.New(t) +func TestGzipWithResponseWithoutBody(t *testing.T) { + e := echo.New() + + e.Use(Gzip()) + e.GET("/", func(c echo.Context) error { + return c.Redirect(http.StatusMovedPermanently, "http://localhost") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusMovedPermanently, rec.Code) + assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding)) +} + +func TestGzipWithMinLengthChunked(t *testing.T) { e := echo.New() // Gzip chunked @@ -155,36 +167,36 @@ func TestGzipWithMinLengthChunked(t *testing.T) { c.Response().Flush() // Read the first part of the data - assert.True(rec.Flushed) - assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.True(t, rec.Flushed) + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) var err error r, err = gzip.NewReader(rec.Body) - assert.NoError(err) + assert.NoError(t, err) _, err = io.ReadFull(r, chunkBuf) - assert.NoError(err) - assert.Equal("test\n", string(chunkBuf)) + assert.NoError(t, err) + assert.Equal(t, "test\n", string(chunkBuf)) // Write and flush the second part of the data c.Response().Write([]byte("test\n")) c.Response().Flush() _, err = io.ReadFull(r, chunkBuf) - assert.NoError(err) - assert.Equal("test\n", string(chunkBuf)) + assert.NoError(t, err) + assert.Equal(t, "test\n", string(chunkBuf)) // Write the final part of the data and return c.Response().Write([]byte("test")) return nil })(c) - assert.NotNil(r) + assert.NotNil(t, r) buf := new(bytes.Buffer) buf.ReadFrom(r) - assert.Equal("test", buf.String()) + assert.Equal(t, "test", buf.String()) r.Close() } From a2e7085094bda23a674c887f0e93f4a15245c439 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sun, 16 Jul 2023 20:36:11 +0300 Subject: [PATCH 302/446] Changelog for v4.11.1 --- CHANGELOG.md | 7 +++++++ echo.go | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c405e205..fef7bb987 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## v4.11.1 - 2023-07-16 + +**Fixes** + +* Fix `Gzip` middleware not sending response code for no content responses (404, 301/302 redirects etc) [#2481](https://github.com/labstack/echo/pull/2481) + + ## v4.11.0 - 2023-07-14 diff --git a/echo.go b/echo.go index e91226ed6..22a5b7af9 100644 --- a/echo.go +++ b/echo.go @@ -259,7 +259,7 @@ const ( const ( // Version of Echo - Version = "4.11.0" + Version = "4.11.1" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 18d32589cdf962ac188a8c6a655ae973d17508c4 Mon Sep 17 00:00:00 2001 From: Vishal Rana <314036+vishr@users.noreply.github.com> Date: Tue, 18 Jul 2023 08:51:02 -0700 Subject: [PATCH 303/446] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ea8f30f64..c24a40c87 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge) [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4) [![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo) -[![Build Status](http://img.shields.io/travis/labstack/echo.svg?style=flat-square)](https://travis-ci.org/labstack/echo) +![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/labstack/echo/echo.yml?style=flat-square) [![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo) [![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://github.com/labstack/echo/discussions) [![Twitter](https://img.shields.io/badge/twitter-@labstack-55acee.svg?style=flat-square)](https://twitter.com/labstack) From 4598a4a7458f69f2532ec825411967f6e82adfbd Mon Sep 17 00:00:00 2001 From: Vishal Rana <314036+vishr@users.noreply.github.com> Date: Tue, 18 Jul 2023 09:20:05 -0700 Subject: [PATCH 304/446] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c24a40c87..18accea75 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge) [![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4) [![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo) -![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/labstack/echo/echo.yml?style=flat-square) +[![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/labstack/echo/echo.yml?style=flat-square)](https://github.com/labstack/echo/actions) [![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo) [![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://github.com/labstack/echo/discussions) [![Twitter](https://img.shields.io/badge/twitter-@labstack-55acee.svg?style=flat-square)](https://twitter.com/labstack) From 3f8ae15b57624dcd04bac482e454c9b665476d9f Mon Sep 17 00:00:00 2001 From: Mobina Noori <91049843+mobinanoorii@users.noreply.github.com> Date: Fri, 21 Jul 2023 11:37:25 +0330 Subject: [PATCH 305/446] delete unused context in body_limit.go (#2483) * delete unused context in body_limit.go --------- Co-authored-by: mobinanoori018 --- middleware/body_limit.go | 10 ++++------ middleware/body_limit_test.go | 6 +----- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/middleware/body_limit.go b/middleware/body_limit.go index b436bd595..99e3ac547 100644 --- a/middleware/body_limit.go +++ b/middleware/body_limit.go @@ -23,9 +23,8 @@ type ( limitedReader struct { BodyLimitConfig - reader io.ReadCloser - read int64 - context echo.Context + reader io.ReadCloser + read int64 } ) @@ -80,7 +79,7 @@ func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { // Based on content read r := pool.Get().(*limitedReader) - r.Reset(req.Body, c) + r.Reset(req.Body) defer pool.Put(r) req.Body = r @@ -102,9 +101,8 @@ func (r *limitedReader) Close() error { return r.reader.Close() } -func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) { +func (r *limitedReader) Reset(reader io.ReadCloser) { r.reader = reader - r.context = context r.read = 0 } diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index 2bfce372a..0fd66ee0f 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -56,9 +56,6 @@ func TestBodyLimit(t *testing.T) { func TestBodyLimitReader(t *testing.T) { hw := []byte("Hello, World!") - e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) - rec := httptest.NewRecorder() config := BodyLimitConfig{ Skipper: DefaultSkipper, @@ -68,7 +65,6 @@ func TestBodyLimitReader(t *testing.T) { reader := &limitedReader{ BodyLimitConfig: config, reader: io.NopCloser(bytes.NewReader(hw)), - context: e.NewContext(req, rec), } // read all should return ErrStatusRequestEntityTooLarge @@ -78,7 +74,7 @@ func TestBodyLimitReader(t *testing.T) { // reset reader and read two bytes must succeed bt := make([]byte, 2) - reader.Reset(io.NopCloser(bytes.NewReader(hw)), e.NewContext(req, rec)) + reader.Reset(io.NopCloser(bytes.NewReader(hw))) n, err := reader.Read(bt) assert.Equal(t, 2, n) assert.Equal(t, nil, err) From 626f13e33830665e08d9d40e333dd13d9de8e672 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Fri, 21 Jul 2023 09:49:27 +0300 Subject: [PATCH 306/446] CSRF/RequestID mw: switch math/random usage to crypto/random --- middleware/csrf.go | 4 ++-- middleware/csrf_test.go | 3 +-- middleware/rate_limiter_test.go | 3 +-- middleware/request_id.go | 5 ++--- middleware/util.go | 17 +++++++++++++++++ middleware/util_test.go | 24 ++++++++++++++++++++++++ 6 files changed, 47 insertions(+), 9 deletions(-) diff --git a/middleware/csrf.go b/middleware/csrf.go index 6899700c7..adf12210b 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -6,7 +6,6 @@ import ( "time" "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" ) type ( @@ -103,6 +102,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.TokenLength == 0 { config.TokenLength = DefaultCSRFConfig.TokenLength } + if config.TokenLookup == "" { config.TokenLookup = DefaultCSRFConfig.TokenLookup } @@ -132,7 +132,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { token := "" if k, err := c.Cookie(config.CookieName); err != nil { - token = random.String(config.TokenLength) // Generate token + token = randomString(config.TokenLength) } else { token = k.Value // Reuse token } diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 6bccdbe4d..6b20297ee 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -8,7 +8,6 @@ import ( "testing" "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" "github.com/stretchr/testify/assert" ) @@ -233,7 +232,7 @@ func TestCSRF(t *testing.T) { assert.Error(t, h(c)) // Valid CSRF token - token := random.String(32) + token := randomString(32) req.Header.Set(echo.HeaderCookie, "_csrf="+token) req.Header.Set(echo.HeaderXCSRFToken, token) if assert.NoError(t, h(c)) { diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go index 0f7c9141d..f66961fe2 100644 --- a/middleware/rate_limiter_test.go +++ b/middleware/rate_limiter_test.go @@ -10,7 +10,6 @@ import ( "time" "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" "github.com/stretchr/testify/assert" "golang.org/x/time/rate" ) @@ -410,7 +409,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) { func generateAddressList(count int) []string { addrs := make([]string, count) for i := 0; i < count; i++ { - addrs[i] = random.String(15) + addrs[i] = randomString(15) } return addrs } diff --git a/middleware/request_id.go b/middleware/request_id.go index 8c5ff6605..e29c8f50d 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -2,7 +2,6 @@ package middleware import ( "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" ) type ( @@ -12,7 +11,7 @@ type ( Skipper Skipper // Generator defines a function to generate an ID. - // Optional. Default value random.String(32). + // Optional. Defaults to generator for random string of length 32. Generator func() string // RequestIDHandler defines a function which is executed for a request id. @@ -73,5 +72,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { } func generator() string { - return random.String(32) + return randomString(32) } diff --git a/middleware/util.go b/middleware/util.go index ab951a0e9..aa34d78f3 100644 --- a/middleware/util.go +++ b/middleware/util.go @@ -1,6 +1,8 @@ package middleware import ( + "crypto/rand" + "fmt" "strings" ) @@ -52,3 +54,18 @@ func matchSubdomain(domain, pattern string) bool { } return false } + +func randomString(length uint8) string { + charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + + bytes := make([]byte, length) + _, err := rand.Read(bytes) + if err != nil { + // we are out of random. let the request fail + panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err)) + } + for i, b := range bytes { + bytes[i] = charset[b%byte(len(charset))] + } + return string(bytes) +} diff --git a/middleware/util_test.go b/middleware/util_test.go index df1d26295..7562d4a5f 100644 --- a/middleware/util_test.go +++ b/middleware/util_test.go @@ -93,3 +93,27 @@ func Test_matchSubdomain(t *testing.T) { assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern)) } } + +func TestRandomString(t *testing.T) { + var testCases = []struct { + name string + whenLength uint8 + expect string + }{ + { + name: "ok, 16", + whenLength: 16, + }, + { + name: "ok, 32", + whenLength: 32, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + uid := randomString(tc.whenLength) + assert.Len(t, uid, int(tc.whenLength)) + }) + } +} From b3ec8e0fdd9d904aa5b1b95479da20c4961a59eb Mon Sep 17 00:00:00 2001 From: Trim21 Date: Sat, 22 Jul 2023 12:08:34 +0800 Subject: [PATCH 307/446] fix(sec): `randomString` bias (#2492) * fix(sec): `randomString` bias when using bytes vs int64 * use pooled buffed random reader --- middleware/util.go | 45 +++++++++++++++++++++++++++++++---------- middleware/util_test.go | 29 ++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 11 deletions(-) diff --git a/middleware/util.go b/middleware/util.go index aa34d78f3..0aa0420fc 100644 --- a/middleware/util.go +++ b/middleware/util.go @@ -1,9 +1,11 @@ package middleware import ( + "bufio" "crypto/rand" - "fmt" + "io" "strings" + "sync" ) func matchScheme(domain, pattern string) bool { @@ -55,17 +57,38 @@ func matchSubdomain(domain, pattern string) bool { return false } +// https://tip.golang.org/doc/go1.19#:~:text=Read%20no%20longer%20buffers%20random%20data%20obtained%20from%20the%20operating%20system%20between%20calls +var randomReaderPool = sync.Pool{New: func() interface{} { + return bufio.NewReader(rand.Reader) +}} + +const randomStringCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +const randomStringCharsetLen = 52 // len(randomStringCharset) +const randomStringMaxByte = 255 - (256 % randomStringCharsetLen) + func randomString(length uint8) string { - charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + reader := randomReaderPool.Get().(*bufio.Reader) + defer randomReaderPool.Put(reader) - bytes := make([]byte, length) - _, err := rand.Read(bytes) - if err != nil { - // we are out of random. let the request fail - panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err)) - } - for i, b := range bytes { - bytes[i] = charset[b%byte(len(charset))] + b := make([]byte, length) + r := make([]byte, length+(length/4)) // perf: avoid read from rand.Reader many times + var i uint8 = 0 + + for { + _, err := io.ReadFull(reader, r) + if err != nil { + panic("unexpected error happened when reading from bufio.NewReader(crypto/rand.Reader)") + } + for _, rb := range r { + if rb > randomStringMaxByte { + // Skip this number to avoid bias. + continue + } + b[i] = randomStringCharset[rb%randomStringCharsetLen] + i++ + if i == length { + return string(b) + } + } } - return string(bytes) } diff --git a/middleware/util_test.go b/middleware/util_test.go index 7562d4a5f..d0f20bba6 100644 --- a/middleware/util_test.go +++ b/middleware/util_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_matchScheme(t *testing.T) { @@ -117,3 +118,31 @@ func TestRandomString(t *testing.T) { }) } } + +func TestRandomStringBias(t *testing.T) { + t.Parallel() + const slen = 33 + const loop = 100000 + + counts := make(map[rune]int) + var count int64 + + for i := 0; i < loop; i++ { + s := randomString(slen) + require.Equal(t, slen, len(s)) + for _, b := range s { + counts[b]++ + count++ + } + } + + require.Equal(t, randomStringCharsetLen, len(counts)) + + avg := float64(count) / float64(len(counts)) + for k, n := range counts { + diff := float64(n) / avg + if diff < 0.95 || diff > 1.05 { + t.Errorf("Bias on '%c': expected average %f, got %d", k, avg, n) + } + } +} From e6b96f8873fed46e71e0d34cddb81c533167f954 Mon Sep 17 00:00:00 2001 From: Trim21 Date: Sun, 23 Jul 2023 04:47:35 +0800 Subject: [PATCH 308/446] docs: add comments to util.go `randomString` (#2494) * Update util.go --- middleware/util.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/middleware/util.go b/middleware/util.go index 0aa0420fc..4d2d172fc 100644 --- a/middleware/util.go +++ b/middleware/util.go @@ -74,6 +74,12 @@ func randomString(length uint8) string { r := make([]byte, length+(length/4)) // perf: avoid read from rand.Reader many times var i uint8 = 0 + // security note: + // we can't just simply do b[i]=randomStringCharset[rb%len(randomStringCharset)], + // len(len(randomStringCharset)) is 52, and rb is [0, 255], 256 = 52 * 4 + 48. + // make the first 48 characters more possibly to be generated then others. + // So we have to skip bytes when rb > randomStringMaxByte + for { _, err := io.ReadFull(reader, r) if err != nil { From 77d5ae6a9173d89c49e008607d08df7ba41336f0 Mon Sep 17 00:00:00 2001 From: Martti T Date: Sat, 12 Aug 2023 09:01:30 +0300 Subject: [PATCH 309/446] Use Go 1.21 in CI (#2505) --- .github/workflows/checks.yml | 4 ++-- .github/workflows/echo.yml | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index d2d3386c4..440f0ec52 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -14,7 +14,7 @@ permissions: env: # run static analysis only with the latest Go version - LATEST_GO_VERSION: "1.20" + LATEST_GO_VERSION: "1.21" jobs: check: @@ -24,7 +24,7 @@ jobs: uses: actions/checkout@v3 - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: go-version: ${{ env.LATEST_GO_VERSION }} check-latest: true diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index e06183d5e..c240dd0c5 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -14,7 +14,7 @@ permissions: env: # run coverage and benchmarks only with the latest Go version - LATEST_GO_VERSION: "1.20" + LATEST_GO_VERSION: "1.21" jobs: test: @@ -25,7 +25,7 @@ jobs: # Echo tests with last four major releases (unless there are pressing vulnerabilities) # As we depend on `golang.org/x/` libraries which only support last 2 Go releases we could have situations when # we derive from last four major releases promise. - go: ["1.18", "1.19", "1.20"] + go: ["1.18", "1.19", "1.20", "1.21"] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: @@ -33,7 +33,7 @@ jobs: uses: actions/checkout@v3 - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} @@ -64,7 +64,7 @@ jobs: path: new - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: go-version: ${{ env.LATEST_GO_VERSION }} From 3950c444b726c1de9131d4dee4c9ae708768f26c Mon Sep 17 00:00:00 2001 From: eiei114 <60887155+eiei114@users.noreply.github.com> Date: Thu, 14 Sep 2023 04:41:58 +0900 Subject: [PATCH 310/446] fix some typos (#2511) --- middleware/context_timeout.go | 2 +- middleware/proxy_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/middleware/context_timeout.go b/middleware/context_timeout.go index be260e188..1937693f1 100644 --- a/middleware/context_timeout.go +++ b/middleware/context_timeout.go @@ -13,7 +13,7 @@ type ContextTimeoutConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper - // ErrorHandler is a function when error aries in middeware execution. + // ErrorHandler is a function when error aries in middleware execution. ErrorHandler func(err error, c echo.Context) error // Timeout configures a timeout for the middleware, defaults to 0 for no timeout diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 1b5ba6cbe..415d68e77 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -188,7 +188,7 @@ func TestProxyRealIPHeader(t *testing.T) { tests := []*struct { hasRealIPheader bool hasIPExtractor bool - extectedXRealIP string + expectedXRealIP string }{ {false, false, remoteAddrIP}, {false, true, extractedRealIP}, @@ -210,7 +210,7 @@ func TestProxyRealIPHeader(t *testing.T) { e.IPExtractor = nil } e.ServeHTTP(rec, req) - assert.Equal(t, tt.extectedXRealIP, req.Header.Get(echo.HeaderXRealIP), "hasRealIPheader: %t / hasIPExtractor: %t", tt.hasRealIPheader, tt.hasIPExtractor) + assert.Equal(t, tt.expectedXRealIP, req.Header.Get(echo.HeaderXRealIP), "hasRealIPheader: %t / hasIPExtractor: %t", tt.hasRealIPheader, tt.hasIPExtractor) } } From 4bc3e475e3137b6402933eec5e6fde641e0d2320 Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 19 Sep 2023 08:24:47 +0300 Subject: [PATCH 311/446] cors middleware: allow sending `Access-Control-Max-Age: 0` value with config.MaxAge being negative number. (#2518) --- middleware/cors.go | 11 ++++++--- middleware/cors_test.go | 53 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/middleware/cors.go b/middleware/cors.go index 6ddb540af..10504359f 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -99,8 +99,9 @@ type ( // MaxAge determines the value of the Access-Control-Max-Age response header. // This header indicates how long (in seconds) the results of a preflight // request can be cached. + // The header is set only if MaxAge != 0, negative value sends "0" which instructs browsers not to cache that response. // - // Optional. Default value 0. The header is set only if MaxAge > 0. + // Optional. Default value 0 - meaning header is not sent. // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age MaxAge int `yaml:"max_age"` @@ -159,7 +160,11 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { allowMethods := strings.Join(config.AllowMethods, ",") allowHeaders := strings.Join(config.AllowHeaders, ",") exposeHeaders := strings.Join(config.ExposeHeaders, ",") - maxAge := strconv.Itoa(config.MaxAge) + + maxAge := "0" + if config.MaxAge > 0 { + maxAge = strconv.Itoa(config.MaxAge) + } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -282,7 +287,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { res.Header().Set(echo.HeaderAccessControlAllowHeaders, h) } } - if config.MaxAge > 0 { + if config.MaxAge != 0 { res.Header().Set(echo.HeaderAccessControlMaxAge, maxAge) } return c.NoContent(http.StatusNoContent) diff --git a/middleware/cors_test.go b/middleware/cors_test.go index c1bb91eb3..797600c5c 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -60,6 +60,59 @@ func TestCORS(t *testing.T) { echo.HeaderAccessControlMaxAge: "3600", }, }, + { + name: "ok, preflight request when `Access-Control-Max-Age` is set", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, + AllowCredentials: true, + MaxAge: 1, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlMaxAge: "1", + }, + }, + { + name: "ok, preflight request when `Access-Control-Max-Age` is set to 0 - not to cache response", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, + AllowCredentials: true, + MaxAge: -1, // forces `Access-Control-Max-Age: 0` + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlMaxAge: "0", + }, + }, + { + name: "ok, CORS check are skipped", + givenMW: CORSWithConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, + AllowCredentials: true, + Skipper: func(c echo.Context) bool { + return true + }, + }), + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + notExpectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "localhost", + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + echo.HeaderAccessControlAllowCredentials: "true", + echo.HeaderAccessControlMaxAge: "3600", + }, + }, { name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true", givenMW: CORSWithConfig(CORSConfig{ From 5780908c7cb110a8c4d56a62e32dc5cbc030a5ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0tefan=20Baebler?= Date: Wed, 11 Oct 2023 06:14:52 +0200 Subject: [PATCH 312/446] Fix CVE-2023-39325 / CVE-2023-44487 (#2527) Bump golang.org/x/net from v0.12.0 to v0.17.0 Related: * https://github.com/golang/go/issues/63417 * https://www.cve.org/CVERecord?id=CVE-2023-44487 --- go.mod | 8 ++++---- go.sum | 20 +++++++++----------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index fe2fd4e54..960b1ab7f 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,8 @@ require ( github.com/labstack/gommon v0.4.0 github.com/stretchr/testify v1.8.1 github.com/valyala/fasttemplate v1.2.2 - golang.org/x/crypto v0.11.0 - golang.org/x/net v0.12.0 + golang.org/x/crypto v0.14.0 + golang.org/x/net v0.17.0 golang.org/x/time v0.3.0 ) @@ -18,7 +18,7 @@ require ( github.com/mattn/go-isatty v0.0.19 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.10.0 // indirect - golang.org/x/text v0.11.0 // indirect + golang.org/x/sys v0.13.0 // indirect + golang.org/x/text v0.13.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 41490b181..b40dfd062 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,6 @@ github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxec github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= -github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -32,8 +30,8 @@ github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= -golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -41,8 +39,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= -golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -58,20 +56,20 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= -golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= +golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= -golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= From 89ae0e5f2ca6d01665255fd2e479ba98ab5ff4c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0tefan=20Baebler?= Date: Wed, 11 Oct 2023 06:47:09 +0200 Subject: [PATCH 313/446] Bump dependancies (#2522) Bump: * golang.org/x/net v0.12.0 -> v0.15.0 * golang.org/x/crypto v0.11.0 -> v0.13.0 * github.com/stretchr/testify v1.8.1 -> v1.8.4 go mod tidy --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 960b1ab7f..367dcb8cb 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.17 require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/labstack/gommon v0.4.0 - github.com/stretchr/testify v1.8.1 + github.com/stretchr/testify v1.8.4 github.com/valyala/fasttemplate v1.2.2 golang.org/x/crypto v0.14.0 golang.org/x/net v0.17.0 diff --git a/go.sum b/go.sum index b40dfd062..5b8ba6bcb 100644 --- a/go.sum +++ b/go.sum @@ -20,8 +20,8 @@ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpE github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= From 98a523756d875bc13475bcb6237f09e771cbe321 Mon Sep 17 00:00:00 2001 From: Martti T Date: Wed, 11 Oct 2023 08:32:23 +0300 Subject: [PATCH 314/446] Changelog for v4.11.2 (#2529) --- CHANGELOG.md | 16 ++++++++++++++++ echo.go | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fef7bb987..40016c9ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,21 @@ # Changelog +## v4.11.2 - 2023-10-11 + +**Security** + +* Bump golang.org/x/net to prevent CVE-2023-39325 / CVE-2023-44487 HTTP/2 Rapid Reset Attack [#2527](https://github.com/labstack/echo/pull/2527) +* fix(sec): randomString bias introduced by #2490 [#2492](https://github.com/labstack/echo/pull/2492) +* CSRF/RequestID mw: switch math/random usage to crypto/random [#2490](https://github.com/labstack/echo/pull/2490) + +**Enhancements** + +* Delete unused context in body_limit.go [#2483](https://github.com/labstack/echo/pull/2483) +* Use Go 1.21 in CI [#2505](https://github.com/labstack/echo/pull/2505) +* Fix some typos [#2511](https://github.com/labstack/echo/pull/2511) +* Allow CORS middleware to send Access-Control-Max-Age: 0 [#2518](https://github.com/labstack/echo/pull/2518) +* Bump dependancies [#2522](https://github.com/labstack/echo/pull/2522) + ## v4.11.1 - 2023-07-16 **Fixes** diff --git a/echo.go b/echo.go index 22a5b7af9..8bdf97539 100644 --- a/echo.go +++ b/echo.go @@ -259,7 +259,7 @@ const ( const ( // Version of Echo - Version = "4.11.1" + Version = "4.11.2" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 69a0de84158fd7cad326599d145c2248bcc15a69 Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 24 Oct 2023 21:12:13 +0300 Subject: [PATCH 315/446] Mark unmarshallable yaml struct tags as ignored (#2536) --- middleware/cors.go | 2 +- middleware/rewrite.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/middleware/cors.go b/middleware/cors.go index 10504359f..7ace2f224 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -39,7 +39,7 @@ type ( // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html // // Optional. - AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` + AllowOriginFunc func(origin string) (bool, error) `yaml:"-"` // AllowMethods determines the value of the Access-Control-Allow-Methods // response header. This header specified the list of methods allowed when diff --git a/middleware/rewrite.go b/middleware/rewrite.go index e5b0a6b56..2090eac04 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -27,7 +27,7 @@ type ( // Example: // "^/old/[0.9]+/": "/new", // "^/api/.+?/(.*)": "/v2/$1", - RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"` + RegexRules map[*regexp.Regexp]string `yaml:"-"` } ) From c7d6d4373fdfbef5d6f44df0a8ef410c198420ee Mon Sep 17 00:00:00 2001 From: Kai Ratzeburg Date: Sun, 5 Nov 2023 17:01:01 +0100 Subject: [PATCH 316/446] proxy middleware: reuse echo request context (#2537) --- middleware/proxy.go | 4 +++ middleware/proxy_test.go | 60 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/middleware/proxy.go b/middleware/proxy.go index e4f98d9ed..16b00d645 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -359,6 +359,10 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { c.Set("_error", nil) } + // This is needed for ProxyConfig.ModifyResponse and/or ProxyConfig.Transport to be able to process the Request + // that Balancer may have replaced with c.SetRequest. + req = c.Request() + // Proxy switch { case c.IsWebSocket(): diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 415d68e77..1c93ba031 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -747,3 +747,63 @@ func TestProxyBalancerWithNoTargets(t *testing.T) { rrb := NewRoundRobinBalancer([]*ProxyTarget{}) assert.Nil(t, rrb.Next(nil)) } + +type testContextKey string + +type customBalancer struct { + target *ProxyTarget +} + +func (b *customBalancer) AddTarget(target *ProxyTarget) bool { + return false +} + +func (b *customBalancer) RemoveTarget(name string) bool { + return false +} + +func (b *customBalancer) Next(c echo.Context) *ProxyTarget { + ctx := context.WithValue(c.Request().Context(), testContextKey("FROM_BALANCER"), "CUSTOM_BALANCER") + c.SetRequest(c.Request().WithContext(ctx)) + return b.target +} + +func TestModifyResponseUseContext(t *testing.T) { + server := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }), + ) + defer server.Close() + + targetURL, _ := url.Parse(server.URL) + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Balancer: &customBalancer{ + target: &ProxyTarget{ + Name: "tst", + URL: targetURL, + }, + }, + RetryCount: 1, + ModifyResponse: func(res *http.Response) error { + val := res.Request.Context().Value(testContextKey("FROM_BALANCER")) + if valStr, ok := val.(string); ok { + res.Header.Set("FROM_BALANCER", valStr) + } + return nil + }, + }, + )) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "OK", rec.Body.String()) + assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER")) +} From 50ebcd8d7c17457489df7bcbbcaa3745c687fd32 Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 7 Nov 2023 13:40:22 +0200 Subject: [PATCH 317/446] refactor context tests to be separate functions (#2540) --- binder.go | 2 +- context_test.go | 770 +++++++++++++++++++++++++++--------------------- json_test.go | 38 ++- 3 files changed, 448 insertions(+), 362 deletions(-) diff --git a/binder.go b/binder.go index 29cceca0b..8e7b81413 100644 --- a/binder.go +++ b/binder.go @@ -1323,7 +1323,7 @@ func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExi case time.Second: *dest = time.Unix(n, 0) case time.Millisecond: - *dest = time.Unix(n/1e3, (n%1e3)*1e6) // TODO: time.UnixMilli(n) exists since Go1.17 switch to that when min version allows + *dest = time.UnixMilli(n) case time.Nanosecond: *dest = time.Unix(0, n) } diff --git a/context_test.go b/context_test.go index 11a63cfce..85b221446 100644 --- a/context_test.go +++ b/context_test.go @@ -19,7 +19,7 @@ import ( "time" "github.com/labstack/gommon/log" - testify "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" ) type ( @@ -85,303 +85,401 @@ func (t *Template) Render(w io.Writer, name string, data interface{}, c Context) return t.templates.ExecuteTemplate(w, name, data) } -type responseWriterErr struct { -} - -func (responseWriterErr) Header() http.Header { - return http.Header{} -} +func TestContextEcho(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() -func (responseWriterErr) Write([]byte) (int, error) { - return 0, errors.New("err") -} + c := e.NewContext(req, rec).(*context) -func (responseWriterErr) WriteHeader(statusCode int) { + assert.Equal(t, e, c.Echo()) } -func TestContext(t *testing.T) { +func TestContextRequest(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() + c := e.NewContext(req, rec).(*context) - assert := testify.New(t) + assert.NotNil(t, c.Request()) + assert.Equal(t, req, c.Request()) +} - // Echo - assert.Equal(e, c.Echo()) +func TestContextResponse(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() - // Request - assert.NotNil(c.Request()) + c := e.NewContext(req, rec).(*context) - // Response - assert.NotNil(c.Response()) + assert.NotNil(t, c.Response()) +} - //-------- - // Render - //-------- +func TestContextRenderTemplate(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec).(*context) tmpl := &Template{ templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")), } c.echo.Renderer = tmpl err := c.Render(http.StatusOK, "hello", "Jon Snow") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("Hello, Jon Snow!", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "Hello, Jon Snow!", rec.Body.String()) } +} + +func TestContextRenderErrorsOnNoRenderer(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec).(*context) c.echo.Renderer = nil - err = c.Render(http.StatusOK, "hello", "Jon Snow") - assert.Error(err) - - // JSON - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON+"\n", rec.Body.String()) - } - - // JSON with "?pretty" - req = httptest.NewRequest(http.MethodGet, "/?pretty", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSONPretty+"\n", rec.Body.String()) - } - req = httptest.NewRequest(http.MethodGet, "/", nil) // reset - - // JSONPretty - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSONPretty+"\n", rec.Body.String()) - } - - // JSON (error) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.JSON(http.StatusOK, make(chan bool)) - assert.Error(err) - - // JSONP - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + assert.Error(t, c.Render(http.StatusOK, "hello", "Jon Snow")) +} + +func TestContextJSON(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + c := e.NewContext(req, rec).(*context) + + err := c.JSON(http.StatusOK, user{1, "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON+"\n", rec.Body.String()) + } +} + +func TestContextJSONErrorsOut(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + c := e.NewContext(req, rec).(*context) + + err := c.JSON(http.StatusOK, make(chan bool)) + assert.EqualError(t, err, "json: unsupported type: chan bool") +} + +func TestContextJSONPrettyURL(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + + err := c.JSON(http.StatusOK, user{1, "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) + } +} + +func TestContextJSONPretty(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + err := c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) + } +} + +func TestContextJSONWithEmptyIntent(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + u := user{1, "Jon Snow"} + emptyIndent := "" + buf := new(bytes.Buffer) + + enc := json.NewEncoder(buf) + enc.SetIndent(emptyIndent, emptyIndent) + _ = enc.Encode(u) + err := c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, buf.String(), rec.Body.String()) + } +} + +func TestContextJSONP(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + callback := "callback" - err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(callback+"("+userJSON+"\n);", rec.Body.String()) - } - - // XML - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.XML(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXML, rec.Body.String()) - } - - // XML with "?pretty" - req = httptest.NewRequest(http.MethodGet, "/?pretty", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.XML(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) - } - req = httptest.NewRequest(http.MethodGet, "/", nil) - - // XML (error) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.XML(http.StatusOK, make(chan bool)) - assert.Error(err) - - // XML response write error - c = e.NewContext(req, rec).(*context) - c.response.Writer = responseWriterErr{} - err = c.XML(0, 0) - testify.Error(t, err) - - // XMLPretty - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) - } - - t.Run("empty indent", func(t *testing.T) { - var ( - u = user{1, "Jon Snow"} - buf = new(bytes.Buffer) - emptyIndent = "" - ) - - t.Run("json", func(t *testing.T) { - buf.Reset() - assert := testify.New(t) - - // New JSONBlob with empty indent - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - enc := json.NewEncoder(buf) - enc.SetIndent(emptyIndent, emptyIndent) - err = enc.Encode(u) - err = c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(buf.String(), rec.Body.String()) - } - }) + err := c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String()) + } +} - t.Run("xml", func(t *testing.T) { - buf.Reset() - assert := testify.New(t) - - // New XMLBlob with empty indent - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - enc := xml.NewEncoder(buf) - enc.Indent(emptyIndent, emptyIndent) - err = enc.Encode(u) - err = c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+buf.String(), rec.Body.String()) - } - }) - }) +func TestContextJSONBlob(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) - // Legacy JSONBlob - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) data, err := json.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + assert.NoError(t, err) err = c.JSONBlob(http.StatusOK, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON, rec.Body.String()) - } - - // Legacy JSONPBlob - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - callback = "callback" - data, err = json.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON, rec.Body.String()) + } +} + +func TestContextJSONPBlob(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + callback := "callback" + data, err := json.Marshal(user{1, "Jon Snow"}) + assert.NoError(t, err) err = c.JSONPBlob(http.StatusOK, callback, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(callback+"("+userJSON+");", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, callback+"("+userJSON+");", rec.Body.String()) + } +} + +func TestContextXML(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + err := c.XML(http.StatusOK, user{1, "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXML, rec.Body.String()) + } +} + +func TestContextXMLPrettyURL(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + + err := c.XML(http.StatusOK, user{1, "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String()) } +} - // Legacy XMLBlob - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - data, err = xml.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) +func TestContextXMLPretty(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + err := c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String()) + } +} + +func TestContextXMLBlob(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + data, err := xml.Marshal(user{1, "Jon Snow"}) + assert.NoError(t, err) err = c.XMLBlob(http.StatusOK, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXML, rec.Body.String()) - } - - // String - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.String(http.StatusOK, "Hello, World!") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal("Hello, World!", rec.Body.String()) - } - - // HTML - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.HTML(http.StatusOK, "Hello, World!") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal("Hello, World!", rec.Body.String()) - } - - // Stream - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXML, rec.Body.String()) + } +} + +func TestContextXMLWithEmptyIntent(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + u := user{1, "Jon Snow"} + emptyIndent := "" + buf := new(bytes.Buffer) + + enc := xml.NewEncoder(buf) + enc.Indent(emptyIndent, emptyIndent) + _ = enc.Encode(u) + err := c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+buf.String(), rec.Body.String()) + } +} + +type responseWriterErr struct { +} + +func (responseWriterErr) Header() http.Header { + return http.Header{} +} + +func (responseWriterErr) Write([]byte) (int, error) { + return 0, errors.New("responseWriterErr") +} + +func (responseWriterErr) WriteHeader(statusCode int) { +} + +func TestContextXMLError(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + c.response.Writer = responseWriterErr{} + + err := c.XML(http.StatusOK, make(chan bool)) + assert.EqualError(t, err, "responseWriterErr") +} + +func TestContextString(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + + err := c.String(http.StatusOK, "Hello, World!") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hello, World!", rec.Body.String()) + } +} + +func TestContextHTML(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + + err := c.HTML(http.StatusOK, "Hello, World!") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hello, World!", rec.Body.String()) + } +} + +func TestContextStream(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + r := strings.NewReader("response from a stream") - err = c.Stream(http.StatusOK, "application/octet-stream", r) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("application/octet-stream", rec.Header().Get(HeaderContentType)) - assert.Equal("response from a stream", rec.Body.String()) - } - - // Attachment - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.Attachment("_fixture/images/walle.png", "walle.png") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(219885, rec.Body.Len()) - } - - // Inline - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.Inline("_fixture/images/walle.png", "walle.png") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(219885, rec.Body.Len()) - } - - // NoContent - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + err := c.Stream(http.StatusOK, "application/octet-stream", r) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType)) + assert.Equal(t, "response from a stream", rec.Body.String()) + } +} + +func TestContextAttachment(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + + err := c.Attachment("_fixture/images/walle.png", "walle.png") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) + assert.Equal(t, 219885, rec.Body.Len()) + } +} + +func TestContextInline(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + + err := c.Inline("_fixture/images/walle.png", "walle.png") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) + assert.Equal(t, 219885, rec.Body.Len()) + } +} + +func TestContextNoContent(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) + c.NoContent(http.StatusOK) - assert.Equal(http.StatusOK, rec.Code) + assert.Equal(t, http.StatusOK, rec.Code) +} + +func TestContextError(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec).(*context) - // Error - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) c.Error(errors.New("error")) - assert.Equal(http.StatusInternalServerError, rec.Code) + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.True(t, c.Response().Committed) +} + +func TestContextReset(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) - // Reset c.SetParamNames("foo") c.SetParamValues("bar") c.Set("foe", "ban") c.query = url.Values(map[string][]string{"fon": {"baz"}}) + c.Reset(req, httptest.NewRecorder()) - assert.Equal(0, len(c.ParamValues())) - assert.Equal(0, len(c.ParamNames())) - assert.Equal(0, len(c.store)) - assert.Equal("", c.Path()) - assert.Equal(0, len(c.QueryParams())) + + assert.Len(t, c.ParamValues(), 0) + assert.Len(t, c.ParamNames(), 0) + assert.Len(t, c.Path(), 0) + assert.Len(t, c.QueryParams(), 0) + assert.Len(t, c.store, 0) } func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { @@ -391,11 +489,10 @@ func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { c := e.NewContext(req, rec).(*context) err := c.JSON(http.StatusCreated, user{1, "Jon Snow"}) - assert := testify.New(t) - if assert.NoError(err) { - assert.Equal(http.StatusCreated, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusCreated, rec.Code) + assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON+"\n", rec.Body.String()) } } @@ -406,9 +503,8 @@ func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) { c := e.NewContext(req, rec).(*context) err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()}) - assert := testify.New(t) - if assert.Error(err) { - assert.False(c.response.Committed) + if assert.Error(t, err) { + assert.False(t, c.response.Committed) } } @@ -422,22 +518,20 @@ func TestContextCookie(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) - assert := testify.New(t) - // Read single cookie, err := c.Cookie("theme") - if assert.NoError(err) { - assert.Equal("theme", cookie.Name) - assert.Equal("light", cookie.Value) + if assert.NoError(t, err) { + assert.Equal(t, "theme", cookie.Name) + assert.Equal(t, "light", cookie.Value) } // Read multiple for _, cookie := range c.Cookies() { switch cookie.Name { case "theme": - assert.Equal("light", cookie.Value) + assert.Equal(t, "light", cookie.Value) case "user": - assert.Equal("Jon Snow", cookie.Value) + assert.Equal(t, "Jon Snow", cookie.Value) } } @@ -452,11 +546,11 @@ func TestContextCookie(t *testing.T) { HttpOnly: true, } c.SetCookie(cookie) - assert.Contains(rec.Header().Get(HeaderSetCookie), "SSID") - assert.Contains(rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") - assert.Contains(rec.Header().Get(HeaderSetCookie), "labstack.com") - assert.Contains(rec.Header().Get(HeaderSetCookie), "Secure") - assert.Contains(rec.Header().Get(HeaderSetCookie), "HttpOnly") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "labstack.com") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Secure") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly") } func TestContextPath(t *testing.T) { @@ -469,14 +563,12 @@ func TestContextPath(t *testing.T) { c := e.NewContext(nil, nil) r.Find(http.MethodGet, "/users/1", c) - assert := testify.New(t) - - assert.Equal("/users/:id", c.Path()) + assert.Equal(t, "/users/:id", c.Path()) r.Add(http.MethodGet, "/users/:uid/files/:fid", handler) c = e.NewContext(nil, nil) r.Find(http.MethodGet, "/users/1/files/1", c) - assert.Equal("/users/:uid/files/:fid", c.Path()) + assert.Equal(t, "/users/:uid/files/:fid", c.Path()) } func TestContextPathParam(t *testing.T) { @@ -486,15 +578,15 @@ func TestContextPathParam(t *testing.T) { // ParamNames c.SetParamNames("uid", "fid") - testify.EqualValues(t, []string{"uid", "fid"}, c.ParamNames()) + assert.EqualValues(t, []string{"uid", "fid"}, c.ParamNames()) // ParamValues c.SetParamValues("101", "501") - testify.EqualValues(t, []string{"101", "501"}, c.ParamValues()) + assert.EqualValues(t, []string{"101", "501"}, c.ParamValues()) // Param - testify.Equal(t, "501", c.Param("fid")) - testify.Equal(t, "", c.Param("undefined")) + assert.Equal(t, "501", c.Param("fid")) + assert.Equal(t, "", c.Param("undefined")) } func TestContextGetAndSetParam(t *testing.T) { @@ -507,23 +599,21 @@ func TestContextGetAndSetParam(t *testing.T) { // round-trip param values with modification paramVals := c.ParamValues() - testify.EqualValues(t, []string{""}, c.ParamValues()) + assert.EqualValues(t, []string{""}, c.ParamValues()) paramVals[0] = "bar" c.SetParamValues(paramVals...) - testify.EqualValues(t, []string{"bar"}, c.ParamValues()) + assert.EqualValues(t, []string{"bar"}, c.ParamValues()) // shouldn't explode during Reset() afterwards! - testify.NotPanics(t, func() { + assert.NotPanics(t, func() { c.Reset(nil, nil) }) } // Issue #1655 func TestContextSetParamNamesShouldUpdateEchoMaxParam(t *testing.T) { - assert := testify.New(t) - e := New() - assert.Equal(0, *e.maxParam) + assert.Equal(t, 0, *e.maxParam) expectedOneParam := []string{"one"} expectedTwoParams := []string{"one", "two"} @@ -533,23 +623,23 @@ func TestContextSetParamNamesShouldUpdateEchoMaxParam(t *testing.T) { c := e.NewContext(nil, nil) c.SetParamNames("1", "2") c.SetParamValues(expectedTwoParams...) - assert.Equal(2, *e.maxParam) - assert.EqualValues(expectedTwoParams, c.ParamValues()) + assert.Equal(t, 2, *e.maxParam) + assert.EqualValues(t, expectedTwoParams, c.ParamValues()) c.SetParamNames("1") - assert.Equal(2, *e.maxParam) + assert.Equal(t, 2, *e.maxParam) // Here for backward compatibility the ParamValues remains as they are - assert.EqualValues(expectedOneParam, c.ParamValues()) + assert.EqualValues(t, expectedOneParam, c.ParamValues()) c.SetParamNames("1", "2", "3") - assert.Equal(3, *e.maxParam) + assert.Equal(t, 3, *e.maxParam) // Here for backward compatibility the ParamValues remains as they are, but the len is extended to e.maxParam - assert.EqualValues(expectedThreeParams, c.ParamValues()) + assert.EqualValues(t, expectedThreeParams, c.ParamValues()) c.SetParamValues("A", "B", "C", "D") - assert.Equal(3, *e.maxParam) + assert.Equal(t, 3, *e.maxParam) // Here D shouldn't be returned - assert.EqualValues(expectedABCParams, c.ParamValues()) + assert.EqualValues(t, expectedABCParams, c.ParamValues()) } func TestContextFormValue(t *testing.T) { @@ -563,13 +653,13 @@ func TestContextFormValue(t *testing.T) { c := e.NewContext(req, nil) // FormValue - testify.Equal(t, "Jon Snow", c.FormValue("name")) - testify.Equal(t, "jon@labstack.com", c.FormValue("email")) + assert.Equal(t, "Jon Snow", c.FormValue("name")) + assert.Equal(t, "jon@labstack.com", c.FormValue("email")) // FormParams params, err := c.FormParams() - if testify.NoError(t, err) { - testify.Equal(t, url.Values{ + if assert.NoError(t, err) { + assert.Equal(t, url.Values{ "name": []string{"Jon Snow"}, "email": []string{"jon@labstack.com"}, }, params) @@ -580,8 +670,8 @@ func TestContextFormValue(t *testing.T) { req.Header.Add(HeaderContentType, MIMEMultipartForm) c = e.NewContext(req, nil) params, err = c.FormParams() - testify.Nil(t, params) - testify.Error(t, err) + assert.Nil(t, params) + assert.Error(t, err) } func TestContextQueryParam(t *testing.T) { @@ -593,11 +683,11 @@ func TestContextQueryParam(t *testing.T) { c := e.NewContext(req, nil) // QueryParam - testify.Equal(t, "Jon Snow", c.QueryParam("name")) - testify.Equal(t, "jon@labstack.com", c.QueryParam("email")) + assert.Equal(t, "Jon Snow", c.QueryParam("name")) + assert.Equal(t, "jon@labstack.com", c.QueryParam("email")) // QueryParams - testify.Equal(t, url.Values{ + assert.Equal(t, url.Values{ "name": []string{"Jon Snow"}, "email": []string{"jon@labstack.com"}, }, c.QueryParams()) @@ -608,7 +698,7 @@ func TestContextFormFile(t *testing.T) { buf := new(bytes.Buffer) mr := multipart.NewWriter(buf) w, err := mr.CreateFormFile("file", "test") - if testify.NoError(t, err) { + if assert.NoError(t, err) { w.Write([]byte("test")) } mr.Close() @@ -617,8 +707,8 @@ func TestContextFormFile(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.FormFile("file") - if testify.NoError(t, err) { - testify.Equal(t, "test", f.Filename) + if assert.NoError(t, err) { + assert.Equal(t, "test", f.Filename) } } @@ -633,8 +723,8 @@ func TestContextMultipartForm(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.MultipartForm() - if testify.NoError(t, err) { - testify.NotNil(t, f) + if assert.NoError(t, err) { + assert.NotNil(t, f) } } @@ -643,16 +733,16 @@ func TestContextRedirect(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - testify.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) - testify.Equal(t, http.StatusMovedPermanently, rec.Code) - testify.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) - testify.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) + assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) + assert.Equal(t, http.StatusMovedPermanently, rec.Code) + assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) + assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) } func TestContextStore(t *testing.T) { var c Context = new(context) c.Set("name", "Jon Snow") - testify.Equal(t, "Jon Snow", c.Get("name")) + assert.Equal(t, "Jon Snow", c.Get("name")) } func BenchmarkContext_Store(b *testing.B) { @@ -682,19 +772,19 @@ func TestContextHandler(t *testing.T) { c := e.NewContext(nil, nil) r.Find(http.MethodGet, "/handler", c) err := c.Handler()(c) - testify.Equal(t, "handler", b.String()) - testify.NoError(t, err) + assert.Equal(t, "handler", b.String()) + assert.NoError(t, err) } func TestContext_SetHandler(t *testing.T) { var c Context = new(context) - testify.Nil(t, c.Handler()) + assert.Nil(t, c.Handler()) c.SetHandler(func(c Context) error { return nil }) - testify.NotNil(t, c.Handler()) + assert.NotNil(t, c.Handler()) } func TestContext_Path(t *testing.T) { @@ -703,7 +793,7 @@ func TestContext_Path(t *testing.T) { var c Context = new(context) c.SetPath(path) - testify.Equal(t, path, c.Path()) + assert.Equal(t, path, c.Path()) } type validator struct{} @@ -716,10 +806,10 @@ func TestContext_Validate(t *testing.T) { e := New() c := e.NewContext(nil, nil) - testify.Error(t, c.Validate(struct{}{})) + assert.Error(t, c.Validate(struct{}{})) e.Validator = &validator{} - testify.NoError(t, c.Validate(struct{}{})) + assert.NoError(t, c.Validate(struct{}{})) } func TestContext_QueryString(t *testing.T) { @@ -730,18 +820,18 @@ func TestContext_QueryString(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil) c := e.NewContext(req, nil) - testify.Equal(t, queryString, c.QueryString()) + assert.Equal(t, queryString, c.QueryString()) } func TestContext_Request(t *testing.T) { var c Context = new(context) - testify.Nil(t, c.Request()) + assert.Nil(t, c.Request()) req := httptest.NewRequest(http.MethodGet, "/path", nil) c.SetRequest(req) - testify.Equal(t, req, c.Request()) + assert.Equal(t, req, c.Request()) } func TestContext_Scheme(t *testing.T) { @@ -798,14 +888,14 @@ func TestContext_Scheme(t *testing.T) { } for _, tt := range tests { - testify.Equal(t, tt.s, tt.c.Scheme()) + assert.Equal(t, tt.s, tt.c.Scheme()) } } func TestContext_IsWebSocket(t *testing.T) { tests := []struct { c Context - ws testify.BoolAssertionFunc + ws assert.BoolAssertionFunc }{ { &context{ @@ -813,7 +903,7 @@ func TestContext_IsWebSocket(t *testing.T) { Header: http.Header{HeaderUpgrade: []string{"websocket"}}, }, }, - testify.True, + assert.True, }, { &context{ @@ -821,13 +911,13 @@ func TestContext_IsWebSocket(t *testing.T) { Header: http.Header{HeaderUpgrade: []string{"Websocket"}}, }, }, - testify.True, + assert.True, }, { &context{ request: &http.Request{}, }, - testify.False, + assert.False, }, { &context{ @@ -835,7 +925,7 @@ func TestContext_IsWebSocket(t *testing.T) { Header: http.Header{HeaderUpgrade: []string{"other"}}, }, }, - testify.False, + assert.False, }, } @@ -854,8 +944,8 @@ func TestContext_Bind(t *testing.T) { req.Header.Add(HeaderContentType, MIMEApplicationJSON) err := c.Bind(u) - testify.NoError(t, err) - testify.Equal(t, &user{1, "Jon Snow"}, u) + assert.NoError(t, err) + assert.Equal(t, &user{1, "Jon Snow"}, u) } func TestContext_Logger(t *testing.T) { @@ -863,15 +953,15 @@ func TestContext_Logger(t *testing.T) { c := e.NewContext(nil, nil) log1 := c.Logger() - testify.NotNil(t, log1) + assert.NotNil(t, log1) log2 := log.New("echo2") c.SetLogger(log2) - testify.Equal(t, log2, c.Logger()) + assert.Equal(t, log2, c.Logger()) // Resetting the context returns the initial logger c.Reset(nil, nil) - testify.Equal(t, log1, c.Logger()) + assert.Equal(t, log1, c.Logger()) } func TestContext_RealIP(t *testing.T) { @@ -959,6 +1049,6 @@ func TestContext_RealIP(t *testing.T) { } for _, tt := range tests { - testify.Equal(t, tt.s, tt.c.RealIP()) + assert.Equal(t, tt.s, tt.c.RealIP()) } } diff --git a/json_test.go b/json_test.go index 27ee43e73..8fb9ebc96 100644 --- a/json_test.go +++ b/json_test.go @@ -1,7 +1,7 @@ package echo import ( - testify "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" "strings" @@ -16,16 +16,14 @@ func TestDefaultJSONCodec_Encode(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) - assert := testify.New(t) - // Echo - assert.Equal(e, c.Echo()) + assert.Equal(t, e, c.Echo()) // Request - assert.NotNil(c.Request()) + assert.NotNil(t, c.Request()) // Response - assert.NotNil(c.Response()) + assert.NotNil(t, c.Response()) //-------- // Default JSON encoder @@ -34,16 +32,16 @@ func TestDefaultJSONCodec_Encode(t *testing.T) { enc := new(DefaultJSONSerializer) err := enc.Serialize(c, user{1, "Jon Snow"}, "") - if assert.NoError(err) { - assert.Equal(userJSON+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, userJSON+"\n", rec.Body.String()) } req = httptest.NewRequest(http.MethodPost, "/", nil) rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = enc.Serialize(c, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(userJSONPretty+"\n", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } } @@ -55,16 +53,14 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec).(*context) - assert := testify.New(t) - // Echo - assert.Equal(e, c.Echo()) + assert.Equal(t, e, c.Echo()) // Request - assert.NotNil(c.Request()) + assert.NotNil(t, c.Request()) // Response - assert.NotNil(c.Response()) + assert.NotNil(t, c.Response()) //-------- // Default JSON encoder @@ -74,8 +70,8 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { var u = user{} err := enc.Deserialize(c, &u) - if assert.NoError(err) { - assert.Equal(u, user{ID: 1, Name: "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, u, user{ID: 1, Name: "Jon Snow"}) } var userUnmarshalSyntaxError = user{} @@ -83,8 +79,8 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = enc.Deserialize(c, &userUnmarshalSyntaxError) - assert.IsType(&HTTPError{}, err) - assert.EqualError(err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value") + assert.IsType(t, &HTTPError{}, err) + assert.EqualError(t, err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value") var userUnmarshalTypeError = struct { ID string `json:"id"` @@ -95,7 +91,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { rec = httptest.NewRecorder() c = e.NewContext(req, rec).(*context) err = enc.Deserialize(c, &userUnmarshalTypeError) - assert.IsType(&HTTPError{}, err) - assert.EqualError(err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string") + assert.IsType(t, &HTTPError{}, err) + assert.EqualError(t, err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string") } From 14daeb968049b71296a80b91abd3883afd02b4d1 Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 7 Nov 2023 14:10:06 +0200 Subject: [PATCH 318/446] Security: c.Attachment and c.Inline should escape name in `Content-Disposition` header to avoid 'Reflect File Download' vulnerability. (#2541) This is same as Go std does it https://github.com/golang/go/blob/9d836d41d0d9df3acabf7f9607d3b09188a9bfc6/src/mime/multipart/writer.go#L132 --- context.go | 4 ++- context_test.go | 82 +++++++++++++++++++++++++++++++++++++------------ 2 files changed, 65 insertions(+), 21 deletions(-) diff --git a/context.go b/context.go index 27da28a9c..6a1811685 100644 --- a/context.go +++ b/context.go @@ -584,8 +584,10 @@ func (c *context) Inline(file, name string) error { return c.contentDisposition(file, name, "inline") } +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + func (c *context) contentDisposition(file, name, dispositionType string) error { - c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf("%s; filename=%q", dispositionType, name)) + c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf(`%s; filename="%s"`, dispositionType, quoteEscaper.Replace(name))) return c.File(file) } diff --git a/context_test.go b/context_test.go index 85b221446..01a8784b8 100644 --- a/context_test.go +++ b/context_test.go @@ -414,30 +414,72 @@ func TestContextStream(t *testing.T) { } func TestContextAttachment(t *testing.T) { - e := New() - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) - c := e.NewContext(req, rec).(*context) - - err := c.Attachment("_fixture/images/walle.png", "walle.png") - if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(t, 219885, rec.Body.Len()) + var testCases = []struct { + name string + whenName string + expectHeader string + }{ + { + name: "ok", + whenName: "walle.png", + expectHeader: `attachment; filename="walle.png"`, + }, + { + name: "ok, escape quotes in malicious filename", + whenName: `malicious.sh"; \"; dummy=.txt`, + expectHeader: `attachment; filename="malicious.sh\"; \\\"; dummy=.txt"`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + err := c.Attachment("_fixture/images/walle.png", tc.whenName) + if assert.NoError(t, err) { + assert.Equal(t, tc.expectHeader, rec.Header().Get(HeaderContentDisposition)) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, 219885, rec.Body.Len()) + } + }) } } func TestContextInline(t *testing.T) { - e := New() - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) - c := e.NewContext(req, rec).(*context) - - err := c.Inline("_fixture/images/walle.png", "walle.png") - if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(t, 219885, rec.Body.Len()) + var testCases = []struct { + name string + whenName string + expectHeader string + }{ + { + name: "ok", + whenName: "walle.png", + expectHeader: `inline; filename="walle.png"`, + }, + { + name: "ok, escape quotes in malicious filename", + whenName: `malicious.sh"; \"; dummy=.txt`, + expectHeader: `inline; filename="malicious.sh\"; \\\"; dummy=.txt"`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec).(*context) + + err := c.Inline("_fixture/images/walle.png", tc.whenName) + if assert.NoError(t, err) { + assert.Equal(t, tc.expectHeader, rec.Header().Get(HeaderContentDisposition)) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, 219885, rec.Body.Len()) + } + }) } } From 4b26cde851bc7a51e624c04dcc5d37be1ce0c84f Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 7 Nov 2023 14:19:32 +0200 Subject: [PATCH 319/446] Changelog for v4.11.3 (#2542) --- CHANGELOG.md | 13 +++++++++++++ echo.go | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 40016c9ed..8490ab2c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## v4.11.3 - 2023-11-07 + +**Security** + +* 'c.Attachment' and 'c.Inline' should escape filename in 'Content-Disposition' header to avoid 'Reflect File Download' vulnerability. [#2541](https://github.com/labstack/echo/pull/2541) + +**Enhancements** + +* Tests: refactor context tests to be separate functions [#2540](https://github.com/labstack/echo/pull/2540) +* Proxy middleware: reuse echo request context [#2537](https://github.com/labstack/echo/pull/2537) +* Mark unmarshallable yaml struct tags as ignored [#2536](https://github.com/labstack/echo/pull/2536) + + ## v4.11.2 - 2023-10-11 **Security** diff --git a/echo.go b/echo.go index 8bdf97539..0ac644924 100644 --- a/echo.go +++ b/echo.go @@ -259,7 +259,7 @@ const ( const ( // Version of Echo - Version = "4.11.2" + Version = "4.11.3" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 584cb85a6b749846ac26a8cd151244ab281f2abc Mon Sep 17 00:00:00 2001 From: Martti T Date: Tue, 7 Nov 2023 15:09:43 +0200 Subject: [PATCH 320/446] request logger: add example for Slog https://pkg.go.dev/log/slog (#2543) --- middleware/request_logger.go | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/middleware/request_logger.go b/middleware/request_logger.go index ce76230c7..f82f6b622 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -8,6 +8,30 @@ import ( "github.com/labstack/echo/v4" ) +// Example for `slog` https://pkg.go.dev/log/slog +// logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) +// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ +// LogStatus: true, +// LogURI: true, +// LogError: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code +// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { +// if v.Error == nil { +// logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST", +// slog.String("uri", v.URI), +// slog.Int("status", v.Status), +// ) +// } else { +// logger.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR", +// slog.String("uri", v.URI), +// slog.Int("status", v.Status), +// slog.String("err", v.Error.Error()), +// ) +// } +// return nil +// }, +// })) +// // Example for `fmt.Printf` // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogStatus: true, From 287a82c228efce23fac50e84d37e8690896bf5a5 Mon Sep 17 00:00:00 2001 From: Nicu Maxian Date: Tue, 19 Dec 2023 18:07:23 +0200 Subject: [PATCH 321/446] Upgrade golang.org/x/crypto to v0.17.0 to fix vulnerability issue (#2562) Co-authored-by: Nicu MAXIAN --- go.mod | 6 +++--- go.sum | 10 +++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 367dcb8cb..e4944b016 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/labstack/gommon v0.4.0 github.com/stretchr/testify v1.8.4 github.com/valyala/fasttemplate v1.2.2 - golang.org/x/crypto v0.14.0 + golang.org/x/crypto v0.17.0 golang.org/x/net v0.17.0 golang.org/x/time v0.3.0 ) @@ -18,7 +18,7 @@ require ( github.com/mattn/go-isatty v0.0.19 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.13.0 // indirect - golang.org/x/text v0.13.0 // indirect + golang.org/x/sys v0.15.0 // indirect + golang.org/x/text v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 5b8ba6bcb..5664e0e6c 100644 --- a/go.sum +++ b/go.sum @@ -30,8 +30,9 @@ github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= +golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -56,20 +57,23 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= +golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= From 209c6a199af0d6443f640528351064ba31b5f864 Mon Sep 17 00:00:00 2001 From: Martti T Date: Wed, 20 Dec 2023 15:17:20 +0200 Subject: [PATCH 322/446] Update deps and mark Go version to 1.18 as this is what golang.org/x/* use. (#2563) --- go.mod | 10 ++++----- go.sum | 70 +++++++--------------------------------------------------- 2 files changed, 13 insertions(+), 67 deletions(-) diff --git a/go.mod b/go.mod index e4944b016..089ffb140 100644 --- a/go.mod +++ b/go.mod @@ -1,21 +1,21 @@ module github.com/labstack/echo/v4 -go 1.17 +go 1.18 require ( github.com/golang-jwt/jwt v3.2.2+incompatible - github.com/labstack/gommon v0.4.0 + github.com/labstack/gommon v0.4.2 github.com/stretchr/testify v1.8.4 github.com/valyala/fasttemplate v1.2.2 golang.org/x/crypto v0.17.0 - golang.org/x/net v0.17.0 - golang.org/x/time v0.3.0 + golang.org/x/net v0.19.0 + golang.org/x/time v0.5.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect golang.org/x/sys v0.15.0 // indirect diff --git a/go.sum b/go.sum index 5664e0e6c..0584b7e59 100644 --- a/go.sum +++ b/go.sum @@ -1,89 +1,35 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8= -github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= -github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= +github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= -golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= -golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 226e4f076a69de85b71cf059d8a3c0fa8feafcaf Mon Sep 17 00:00:00 2001 From: Martti T Date: Wed, 20 Dec 2023 15:24:30 +0200 Subject: [PATCH 323/446] Changelog for v4.11.4 (#2564) Changelog for v4.11.4 --- CHANGELOG.md | 12 ++++++++++++ echo.go | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8490ab2c8..cc17e28d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## v4.11.4 - 2023-12-20 + +**Security** + +* Upgrade golang.org/x/crypto to v0.17.0 to fix vulnerability [issue](https://pkg.go.dev/vuln/GO-2023-2402) [#2562](https://github.com/labstack/echo/pull/2562) + +**Enhancements** + +* Update deps and mark Go version to 1.18 as this is what golang.org/x/* use [#2563](https://github.com/labstack/echo/pull/2563) +* Request logger: add example for Slog https://pkg.go.dev/log/slog [#2543](https://github.com/labstack/echo/pull/2543) + + ## v4.11.3 - 2023-11-07 **Security** diff --git a/echo.go b/echo.go index 0ac644924..9924ac86d 100644 --- a/echo.go +++ b/echo.go @@ -259,7 +259,7 @@ const ( const ( // Version of Echo - Version = "4.11.3" + Version = "4.11.4" website = "https://echo.labstack.com" // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo banner = ` From 60fc2fb1b76f5613fc41aa9315cad6e8c96c6859 Mon Sep 17 00:00:00 2001 From: Martti T Date: Wed, 20 Dec 2023 15:32:51 +0200 Subject: [PATCH 324/446] binder: make binding to Map work better with string destinations (#2554) --- bind.go | 22 ++++++++++++++++++--- bind_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/bind.go b/bind.go index 374a2aec5..6f41ce587 100644 --- a/bind.go +++ b/bind.go @@ -131,10 +131,26 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri typ := reflect.TypeOf(destination).Elem() val := reflect.ValueOf(destination).Elem() - // Map - if typ.Kind() == reflect.Map { + // Support binding to limited Map destinations: + // - map[string][]string, + // - map[string]string <-- (binds first value from data slice) + // - map[string]interface{} + // You are better off binding to struct but there are user who want this map feature. Source of data for these cases are: + // params,query,header,form as these sources produce string values, most of the time slice of strings, actually. + if typ.Kind() == reflect.Map && typ.Key().Kind() == reflect.String { + k := typ.Elem().Kind() + isElemInterface := k == reflect.Interface + isElemString := k == reflect.String + isElemSliceOfStrings := k == reflect.Slice && typ.Elem().Elem().Kind() == reflect.String + if !(isElemSliceOfStrings || isElemString || isElemInterface) { + return nil + } for k, v := range data { - val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0])) + if isElemString { + val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0])) + } else { + val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v)) + } } return nil } diff --git a/bind_test.go b/bind_test.go index c35283dcf..c11723303 100644 --- a/bind_test.go +++ b/bind_test.go @@ -429,6 +429,62 @@ func TestBindUnsupportedMediaType(t *testing.T) { testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) } +func TestDefaultBinder_bindDataToMap(t *testing.T) { + exampleData := map[string][]string{ + "multiple": {"1", "2"}, + "single": {"3"}, + } + + t.Run("ok, bind to map[string]string", func(t *testing.T) { + dest := map[string]string{} + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.Equal(t, + map[string]string{ + "multiple": "1", + "single": "3", + }, + dest, + ) + }) + + t.Run("ok, bind to map[string][]string", func(t *testing.T) { + dest := map[string][]string{} + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.Equal(t, + map[string][]string{ + "multiple": {"1", "2"}, + "single": {"3"}, + }, + dest, + ) + }) + + t.Run("ok, bind to map[string]interface", func(t *testing.T) { + dest := map[string]interface{}{} + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.Equal(t, + map[string]interface{}{ + "multiple": []string{"1", "2"}, + "single": []string{"3"}, + }, + dest, + ) + }) + + t.Run("ok, bind to map[string]int skips", func(t *testing.T) { + dest := map[string]int{} + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.Equal(t, map[string]int{}, dest) + }) + + t.Run("ok, bind to map[string][]int skips", func(t *testing.T) { + dest := map[string][]int{} + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.Equal(t, map[string][]int{}, dest) + }) + +} + func TestBindbindData(t *testing.T) { ts := new(bindTestStruct) b := new(DefaultBinder) From d26212069089c65de25653d361932c0dd6f4d379 Mon Sep 17 00:00:00 2001 From: Marcus Kohlberg <78424526+marcuskohlberg@users.noreply.github.com> Date: Tue, 23 Jan 2024 04:26:05 +0100 Subject: [PATCH 325/446] README.md: add Encore as sponsor (#2579) There wasn't a sponsors section so I had to design one, hope you think it makes sense. --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index 18accea75..0a302072d 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,18 @@ For older versions, please use the latest v3 tag. - Automatic TLS via Let’s Encrypt - HTTP/2 support +## Sponsors + + +
+ +Click [here](https://github.com/sponsors/labstack) for more information on sponsorship. + ## Benchmarks Date: 2020/11/11
From b835498241989eea914fb63b774de801e6c16833 Mon Sep 17 00:00:00 2001 From: Martti T Date: Wed, 24 Jan 2024 17:45:40 +0200 Subject: [PATCH 326/446] Reorder paragraphs in README.md (#2581) --- README.md | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 0a302072d..351ba3c55 100644 --- a/README.md +++ b/README.md @@ -9,20 +9,18 @@ [![Twitter](https://img.shields.io/badge/twitter-@labstack-55acee.svg?style=flat-square)](https://twitter.com/labstack) [![License](http://img.shields.io/badge/license-mit-blue.svg?style=flat-square)](https://raw.githubusercontent.com/labstack/echo/master/LICENSE) -## Supported Go versions +## Echo -Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with -older versions. +High performance, extensible, minimalist Go web framework. -As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules). -Therefore a Go version capable of understanding /vN suffixed imports is required: +* [Official website](https://echo.labstack.com) +* [Quick start](https://echo.labstack.com/docs/quick-start) +* [Middlewares](https://echo.labstack.com/docs/category/middleware) -Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended -way of using Echo going forward. +Help and questions: [Github Discussions](https://github.com/labstack/echo/discussions) -For older versions, please use the latest v3 tag. -## Feature Overview +### Feature Overview - Optimized HTTP router which smartly prioritize routes - Build robust and scalable RESTful APIs @@ -69,6 +67,7 @@ The benchmarks above were run on an Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz // go get github.com/labstack/echo/{version} go get github.com/labstack/echo/v4 ``` +Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions. ### Example @@ -129,10 +128,6 @@ of middlewares in this list. Please send a PR to add your own library here. -## Help - -- [Forum](https://github.com/labstack/echo/discussions) - ## Contribute **Use issues for everything** From f12fdb09cd4d7afc749d132f12b60422753d8ecb Mon Sep 17 00:00:00 2001 From: Martti T Date: Sun, 28 Jan 2024 17:16:51 +0200 Subject: [PATCH 327/446] CI: upgrade actions/checkout to v4 and actions/setup-go to v5 (#2584) * CI: upgrade actions/checkout to v4 * CI: upgrade actions/setup-go to v5 --- .github/workflows/checks.yml | 4 ++-- .github/workflows/echo.yml | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 440f0ec52..fbd6d9571 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -21,10 +21,10 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: ${{ env.LATEST_GO_VERSION }} check-latest: true diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index c240dd0c5..5722dcbe9 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -30,10 +30,10 @@ jobs: runs-on: ${{ matrix.os }} steps: - name: Checkout Code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} @@ -53,18 +53,18 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Code (Previous) - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: ref: ${{ github.base_ref }} path: previous - name: Checkout Code (New) - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: path: new - name: Set up Go ${{ matrix.go }} - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: ${{ env.LATEST_GO_VERSION }} From 76994d17d59d25c53c4e333d2a2048410e0748e1 Mon Sep 17 00:00:00 2001 From: Suwon Chae Date: Tue, 6 Feb 2024 14:41:33 +0900 Subject: [PATCH 328/446] Remove default charset from 'application/json' Content-Type header (#2568) Fixes #2567 --- context.go | 4 ++-- context_test.go | 12 ++++++------ echo.go | 7 ++++++- middleware/decompress_test.go | 2 +- 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/context.go b/context.go index 6a1811685..d4cba8447 100644 --- a/context.go +++ b/context.go @@ -489,7 +489,7 @@ func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error } func (c *context) json(code int, i interface{}, indent string) error { - c.writeContentType(MIMEApplicationJSONCharsetUTF8) + c.writeContentType(MIMEApplicationJSON) c.response.Status = code return c.echo.JSONSerializer.Serialize(c, i, indent) } @@ -507,7 +507,7 @@ func (c *context) JSONPretty(code int, i interface{}, indent string) (err error) } func (c *context) JSONBlob(code int, b []byte) (err error) { - return c.Blob(code, MIMEApplicationJSONCharsetUTF8, b) + return c.Blob(code, MIMEApplicationJSON, b) } func (c *context) JSONP(code int, callback string, i interface{}) (err error) { diff --git a/context_test.go b/context_test.go index 01a8784b8..4ca2cc84b 100644 --- a/context_test.go +++ b/context_test.go @@ -154,7 +154,7 @@ func TestContextJSON(t *testing.T) { err := c.JSON(http.StatusOK, user{1, "Jon Snow"}) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) assert.Equal(t, userJSON+"\n", rec.Body.String()) } } @@ -178,7 +178,7 @@ func TestContextJSONPrettyURL(t *testing.T) { err := c.JSON(http.StatusOK, user{1, "Jon Snow"}) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } } @@ -192,7 +192,7 @@ func TestContextJSONPretty(t *testing.T) { err := c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ") if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } } @@ -213,7 +213,7 @@ func TestContextJSONWithEmptyIntent(t *testing.T) { err := c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) assert.Equal(t, buf.String(), rec.Body.String()) } } @@ -244,7 +244,7 @@ func TestContextJSONBlob(t *testing.T) { err = c.JSONBlob(http.StatusOK, data) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) assert.Equal(t, userJSON, rec.Body.String()) } } @@ -533,7 +533,7 @@ func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { if assert.NoError(t, err) { assert.Equal(t, http.StatusCreated, rec.Code) - assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) assert.Equal(t, userJSON+"\n", rec.Body.String()) } } diff --git a/echo.go b/echo.go index 9924ac86d..7b6a0907d 100644 --- a/echo.go +++ b/echo.go @@ -169,7 +169,12 @@ const ( // MIME types const ( - MIMEApplicationJSON = "application/json" + // MIMEApplicationJSON JavaScript Object Notation (JSON) https://www.rfc-editor.org/rfc/rfc8259 + MIMEApplicationJSON = "application/json" + // Deprecated: Please use MIMEApplicationJSON instead. JSON should be encoded using UTF-8 by default. + // No "charset" parameter is defined for this registration. + // Adding one really has no effect on compliant recipients. + // See RFC 8259, section 8.1. https://datatracker.ietf.org/doc/html/rfc8259#section-8.1 MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8 MIMEApplicationJavaScript = "application/javascript" MIMEApplicationJavaScriptCharsetUTF8 = MIMEApplicationJavaScript + "; " + charsetUTF8 diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 2e73ba80e..351e0e708 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -131,7 +131,7 @@ func TestDecompressSkipper(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) e.ServeHTTP(rec, req) - assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSONCharsetUTF8) + assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSON) reqBody, err := io.ReadAll(c.Request().Body) assert.NoError(t, err) assert.Equal(t, body, string(reqBody)) From 51c54f473486a5c6e5a9117aca3a6425d24d2731 Mon Sep 17 00:00:00 2001 From: toim Date: Wed, 7 Feb 2024 07:23:31 +0200 Subject: [PATCH 329/446] CI: Use Go 1.22 --- .github/workflows/checks.yml | 2 +- .github/workflows/echo.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index fbd6d9571..9ae5dbd5a 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -14,7 +14,7 @@ permissions: env: # run static analysis only with the latest Go version - LATEST_GO_VERSION: "1.21" + LATEST_GO_VERSION: "1.22" jobs: check: diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index 5722dcbe9..cb3dc448b 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -14,7 +14,7 @@ permissions: env: # run coverage and benchmarks only with the latest Go version - LATEST_GO_VERSION: "1.21" + LATEST_GO_VERSION: "1.22" jobs: test: @@ -25,7 +25,7 @@ jobs: # Echo tests with last four major releases (unless there are pressing vulnerabilities) # As we depend on `golang.org/x/` libraries which only support last 2 Go releases we could have situations when # we derive from last four major releases promise. - go: ["1.18", "1.19", "1.20", "1.21"] + go: ["1.19", "1.20", "1.21", "1.22"] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: From 29aab274b3810dfd4e1be172d5a569ac3b9efcd6 Mon Sep 17 00:00:00 2001 From: toim Date: Wed, 7 Feb 2024 07:37:19 +0200 Subject: [PATCH 330/446] In Go 1.22 finding name of function with reflection has changed. change tests to work with that. --- echo_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/echo_test.go b/echo_test.go index a352e4026..416479191 100644 --- a/echo_test.go +++ b/echo_test.go @@ -1572,7 +1572,7 @@ func TestEcho_OnAddRouteHandler(t *testing.T) { }) } - e.GET("/static", NotFoundHandler) + e.GET("/static", dummyHandler) e.Host("domain.site").GET("/static/*", dummyHandler, func(next HandlerFunc) HandlerFunc { return func(c Context) error { return next(c) @@ -1582,7 +1582,7 @@ func TestEcho_OnAddRouteHandler(t *testing.T) { assert.Len(t, added, 2) assert.Equal(t, "", added[0].host) - assert.Equal(t, Route{Method: http.MethodGet, Path: "/static", Name: "github.com/labstack/echo/v4.glob..func1"}, added[0].route) + assert.Equal(t, Route{Method: http.MethodGet, Path: "/static", Name: "github.com/labstack/echo/v4.TestEcho_OnAddRouteHandler.func1"}, added[0].route) assert.Len(t, added[0].middleware, 0) assert.Equal(t, "domain.site", added[1].host) From ea529bbab6602db8bd9fc0746405a3687ffbd885 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Georg=20M=C3=BCller?= Date: Tue, 6 Feb 2024 16:18:12 +0100 Subject: [PATCH 331/446] binder: allow binding to a nil map --- bind.go | 3 +++ bind_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/bind.go b/bind.go index 6f41ce587..51f4689e7 100644 --- a/bind.go +++ b/bind.go @@ -145,6 +145,9 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri if !(isElemSliceOfStrings || isElemString || isElemInterface) { return nil } + if val.IsNil() { + val.Set(reflect.MakeMap(typ)) + } for k, v := range data { if isElemString { val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0])) diff --git a/bind_test.go b/bind_test.go index c11723303..cffccfb35 100644 --- a/bind_test.go +++ b/bind_test.go @@ -447,6 +447,18 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { ) }) + t.Run("ok, bind to map[string]string with nil map", func(t *testing.T) { + var dest map[string]string + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.Equal(t, + map[string]string{ + "multiple": "1", + "single": "3", + }, + dest, + ) + }) + t.Run("ok, bind to map[string][]string", func(t *testing.T) { dest := map[string][]string{} assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) @@ -459,6 +471,18 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { ) }) + t.Run("ok, bind to map[string][]string with nil map", func(t *testing.T) { + var dest map[string][]string + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.Equal(t, + map[string][]string{ + "multiple": {"1", "2"}, + "single": {"3"}, + }, + dest, + ) + }) + t.Run("ok, bind to map[string]interface", func(t *testing.T) { dest := map[string]interface{}{} assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) @@ -471,18 +495,41 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { ) }) + t.Run("ok, bind to map[string]interface with nil map", func(t *testing.T) { + var dest map[string]interface{} + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.Equal(t, + map[string]interface{}{ + "multiple": []string{"1", "2"}, + "single": []string{"3"}, + }, + dest, + ) + }) + t.Run("ok, bind to map[string]int skips", func(t *testing.T) { dest := map[string]int{} assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) assert.Equal(t, map[string]int{}, dest) }) + t.Run("ok, bind to map[string]int skips with nil map", func(t *testing.T) { + var dest map[string]int + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.Equal(t, map[string]int(nil), dest) + }) + t.Run("ok, bind to map[string][]int skips", func(t *testing.T) { dest := map[string][]int{} assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) assert.Equal(t, map[string][]int{}, dest) }) + t.Run("ok, bind to map[string][]int skips with nil map", func(t *testing.T) { + var dest map[string][]int + assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param")) + assert.Equal(t, map[string][]int(nil), dest) + }) } func TestBindbindData(t *testing.T) { From fa70db801e3df89c7de8b8da161c3f41a1fe84d7 Mon Sep 17 00:00:00 2001 From: Ryo Kusnadi Date: Sun, 18 Feb 2024 20:47:13 +0700 Subject: [PATCH 332/446] Add Skipper Unit Test In BasicBasicAuthConfig and Add More Detail Explanation regarding BasicAuthValidator (#2461) * Add Skipper Unit Test In BasicBasicAuthConfig and Add More detail explanation regarding BasicAuthValidator * Simplify Skipper Unit Test --- middleware/basic_auth.go | 2 ++ middleware/basic_auth_test.go | 17 ++++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index f9e8caafe..07a5761b8 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -25,6 +25,8 @@ type ( } // BasicAuthValidator defines a function to validate BasicAuth credentials. + // The function should return a boolean indicating whether the credentials are valid, + // and an error if any error occurs during the validation process. BasicAuthValidator func(string, string, echo.Context) (bool, error) ) diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 20e769214..2e133e071 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -32,7 +32,6 @@ func TestBasicAuth(t *testing.T) { assert.NoError(t, h(c)) h = BasicAuthWithConfig(BasicAuthConfig{ - Skipper: nil, Validator: f, Realm: "someRealm", })(func(c echo.Context) error { @@ -72,4 +71,20 @@ func TestBasicAuth(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, auth) he = h(c).(*echo.HTTPError) assert.Equal(t, http.StatusUnauthorized, he.Code) + + h = BasicAuthWithConfig(BasicAuthConfig{ + Validator: f, + Realm: "someRealm", + Skipper: func(c echo.Context) bool { + return true + }, + })(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + // Skipped Request + auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")) + req.Header.Set(echo.HeaderAuthorization, auth) + assert.NoError(t, h(c)) + } From 34717b717df914b4c511610ef44ac0339316875f Mon Sep 17 00:00:00 2001 From: teslaedison <156734008+teslaedison@users.noreply.github.com> Date: Thu, 7 Mar 2024 03:43:59 +0800 Subject: [PATCH 333/446] fix some typos (#2603) Signed-off-by: teslaedison --- context.go | 2 +- echo.go | 2 +- ip.go | 2 +- middleware/request_logger_test.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/context.go b/context.go index d4cba8447..d917f3bc9 100644 --- a/context.go +++ b/context.go @@ -335,7 +335,7 @@ func (c *context) SetParamNames(names ...string) { if len(c.pvalues) < l { // Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them, - // probably those values will be overriden in a Context#SetParamValues + // probably those values will be overridden in a Context#SetParamValues newPvalues := make([]string, l) copy(newPvalues, c.pvalues) c.pvalues = newPvalues diff --git a/echo.go b/echo.go index 7b6a0907d..1599f5cb7 100644 --- a/echo.go +++ b/echo.go @@ -419,7 +419,7 @@ func (e *Echo) Routers() map[string]*Router { // // NOTE: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). // When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from -// handler. Then the error that global error handler received will be ignored because we have already "commited" the +// handler. Then the error that global error handler received will be ignored because we have already "committed" the // response and status code header has been sent to the client. func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { diff --git a/ip.go b/ip.go index 1bcd756ae..905268abf 100644 --- a/ip.go +++ b/ip.go @@ -64,7 +64,7 @@ XFF: "x" "x, a" "x, a, b" ``` In this case, use **first _untrustable_ IP reading from right**. Never use first one reading from left, as it is -configurable by client. Here "trustable" means "you are sure the IP address belongs to your infrastructre". +configurable by client. Here "trustable" means "you are sure the IP address belongs to your infrastructure". In above example, if `b` and `c` are trustable, the IP address of the client is `a` for both cases, never be `x`. In Echo, use `ExtractIPFromXFFHeader(...TrustOption)`. diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go index 51d617abb..f3c5f8425 100644 --- a/middleware/request_logger_test.go +++ b/middleware/request_logger_test.go @@ -194,7 +194,7 @@ func TestRequestLogger_LogValuesFuncError(t *testing.T) { e.ServeHTTP(rec, req) // NOTE: when global error handler received error returned from middleware the status has already - // been written to the client and response has been "commited" therefore global error handler does not do anything + // been written to the client and response has been "committed" therefore global error handler does not do anything // and error that bubbled up in middleware chain will not be reflected in response code. assert.Equal(t, http.StatusTeapot, rec.Code) assert.Equal(t, http.StatusTeapot, expect.Status) From 3e04e3e2f25cb37932e5fcb55574d42138a652ce Mon Sep 17 00:00:00 2001 From: pomadev <45284098+pomadev@users.noreply.github.com> Date: Thu, 7 Mar 2024 04:52:53 +0900 Subject: [PATCH 334/446] fix: some typos (#2596) --- echo.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/echo.go b/echo.go index 1599f5cb7..eb8a79f38 100644 --- a/echo.go +++ b/echo.go @@ -70,7 +70,7 @@ type ( filesystem common // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get - // listener address info (on which interface/port was listener binded) without having data races. + // listener address info (on which interface/port was listener bound) without having data races. startupMutex sync.RWMutex colorer *color.Color From bc1e1904f1f7b641b3c5eca11be634735a3688f9 Mon Sep 17 00:00:00 2001 From: Martti T Date: Sat, 9 Mar 2024 10:50:47 +0200 Subject: [PATCH 335/446] Allow ResponseWriters to unwrap writers when flushing/hijacking (#2595) * Allow ResponseWriters to unwrap writers when flushing/hijacking --- middleware/body_dump.go | 12 +++++-- middleware/body_dump_test.go | 50 +++++++++++++++++++++++++++ middleware/compress.go | 10 +++--- middleware/compress_test.go | 30 ++++++++++++++++ middleware/middleware_test.go | 46 ++++++++++++++++++++++++ middleware/responsecontroller_1.19.go | 41 ++++++++++++++++++++++ middleware/responsecontroller_1.20.go | 17 +++++++++ response.go | 8 +++-- response_test.go | 25 ++++++++++++++ responsecontroller_1.19.go | 41 ++++++++++++++++++++++ responsecontroller_1.20.go | 17 +++++++++ 11 files changed, 289 insertions(+), 8 deletions(-) create mode 100644 middleware/responsecontroller_1.19.go create mode 100644 middleware/responsecontroller_1.20.go create mode 100644 responsecontroller_1.19.go create mode 100644 responsecontroller_1.20.go diff --git a/middleware/body_dump.go b/middleware/body_dump.go index fa7891b16..946ffc58f 100644 --- a/middleware/body_dump.go +++ b/middleware/body_dump.go @@ -3,6 +3,7 @@ package middleware import ( "bufio" "bytes" + "errors" "io" "net" "net/http" @@ -98,9 +99,16 @@ func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) { } func (w *bodyDumpResponseWriter) Flush() { - w.ResponseWriter.(http.Flusher).Flush() + err := responseControllerFlush(w.ResponseWriter) + if err != nil && errors.Is(err, http.ErrNotSupported) { + panic(errors.New("response writer flushing is not supported")) + } } func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return w.ResponseWriter.(http.Hijacker).Hijack() + return responseControllerHijack(w.ResponseWriter) +} + +func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter } diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go index de1de3356..a68930b49 100644 --- a/middleware/body_dump_test.go +++ b/middleware/body_dump_test.go @@ -87,3 +87,53 @@ func TestBodyDumpFails(t *testing.T) { } }) } + +func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) { + bdrw := bodyDumpResponseWriter{ + ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush + } + + assert.PanicsWithError(t, "response writer flushing is not supported", func() { + bdrw.Flush() + }) +} + +func TestBodyDumpResponseWriter_CanFlush(t *testing.T) { + trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}} + bdrw := bodyDumpResponseWriter{ + ResponseWriter: &trwu, + } + + bdrw.Flush() + assert.Equal(t, 1, trwu.unwrapCalled) +} + +func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) { + trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()} + bdrw := bodyDumpResponseWriter{ + ResponseWriter: trwu, + } + + result := bdrw.Unwrap() + assert.Equal(t, trwu, result) +} + +func TestBodyDumpResponseWriter_CanHijack(t *testing.T) { + trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}} + bdrw := bodyDumpResponseWriter{ + ResponseWriter: &trwu, // this RW supports hijacking through unwrapping + } + + _, _, err := bdrw.Hijack() + assert.EqualError(t, err, "can hijack") +} + +func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) { + trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()} + bdrw := bodyDumpResponseWriter{ + ResponseWriter: &trwu, // this RW supports hijacking through unwrapping + } + + _, _, err := bdrw.Hijack() + assert.EqualError(t, err, "feature not supported") +} diff --git a/middleware/compress.go b/middleware/compress.go index 3e9bd3201..c77062d92 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -191,13 +191,15 @@ func (w *gzipResponseWriter) Flush() { } w.Writer.(*gzip.Writer).Flush() - if flusher, ok := w.ResponseWriter.(http.Flusher); ok { - flusher.Flush() - } + _ = responseControllerFlush(w.ResponseWriter) +} + +func (w *gzipResponseWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter } func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return w.ResponseWriter.(http.Hijacker).Hijack() + return responseControllerHijack(w.ResponseWriter) } func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { diff --git a/middleware/compress_test.go b/middleware/compress_test.go index 0ed16c813..6c5ce4123 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -311,6 +311,36 @@ func TestGzipWithStatic(t *testing.T) { } } +func TestGzipResponseWriter_CanUnwrap(t *testing.T) { + trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()} + bdrw := gzipResponseWriter{ + ResponseWriter: trwu, + } + + result := bdrw.Unwrap() + assert.Equal(t, trwu, result) +} + +func TestGzipResponseWriter_CanHijack(t *testing.T) { + trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}} + bdrw := gzipResponseWriter{ + ResponseWriter: &trwu, // this RW supports hijacking through unwrapping + } + + _, _, err := bdrw.Hijack() + assert.EqualError(t, err, "can hijack") +} + +func TestGzipResponseWriter_CanNotHijack(t *testing.T) { + trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()} + bdrw := gzipResponseWriter{ + ResponseWriter: &trwu, // this RW supports hijacking through unwrapping + } + + _, _, err := bdrw.Hijack() + assert.EqualError(t, err, "feature not supported") +} + func BenchmarkGzip(b *testing.B) { e := echo.New() diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 44f44142c..990568d55 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -1,7 +1,10 @@ package middleware import ( + "bufio" + "errors" "github.com/stretchr/testify/assert" + "net" "net/http" "net/http/httptest" "regexp" @@ -90,3 +93,46 @@ func TestRewriteURL(t *testing.T) { }) } } + +type testResponseWriterNoFlushHijack struct { +} + +func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) { +} + +func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) { + return 0, nil +} + +func (w *testResponseWriterNoFlushHijack) Header() http.Header { + return nil +} + +type testResponseWriterUnwrapper struct { + unwrapCalled int + rw http.ResponseWriter +} + +func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) { +} + +func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) { + return 0, nil +} + +func (w *testResponseWriterUnwrapper) Header() http.Header { + return nil +} + +func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter { + w.unwrapCalled++ + return w.rw +} + +type testResponseWriterUnwrapperHijack struct { + testResponseWriterUnwrapper +} + +func (w *testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, errors.New("can hijack") +} diff --git a/middleware/responsecontroller_1.19.go b/middleware/responsecontroller_1.19.go new file mode 100644 index 000000000..104784fd0 --- /dev/null +++ b/middleware/responsecontroller_1.19.go @@ -0,0 +1,41 @@ +//go:build !go1.20 + +package middleware + +import ( + "bufio" + "fmt" + "net" + "net/http" +) + +// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore +func responseControllerFlush(rw http.ResponseWriter) error { + for { + switch t := rw.(type) { + case interface{ FlushError() error }: + return t.FlushError() + case http.Flusher: + t.Flush() + return nil + case interface{ Unwrap() http.ResponseWriter }: + rw = t.Unwrap() + default: + return fmt.Errorf("%w", http.ErrNotSupported) + } + } +} + +// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore +func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) { + for { + switch t := rw.(type) { + case http.Hijacker: + return t.Hijack() + case interface{ Unwrap() http.ResponseWriter }: + rw = t.Unwrap() + default: + return nil, nil, fmt.Errorf("%w", http.ErrNotSupported) + } + } +} diff --git a/middleware/responsecontroller_1.20.go b/middleware/responsecontroller_1.20.go new file mode 100644 index 000000000..02a0cb754 --- /dev/null +++ b/middleware/responsecontroller_1.20.go @@ -0,0 +1,17 @@ +//go:build go1.20 + +package middleware + +import ( + "bufio" + "net" + "net/http" +) + +func responseControllerFlush(rw http.ResponseWriter) error { + return http.NewResponseController(rw).Flush() +} + +func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) { + return http.NewResponseController(rw).Hijack() +} diff --git a/response.go b/response.go index d9c9aa6e0..117881cc6 100644 --- a/response.go +++ b/response.go @@ -2,6 +2,7 @@ package echo import ( "bufio" + "errors" "net" "net/http" ) @@ -84,14 +85,17 @@ func (r *Response) Write(b []byte) (n int, err error) { // buffered data to the client. // See [http.Flusher](https://golang.org/pkg/net/http/#Flusher) func (r *Response) Flush() { - r.Writer.(http.Flusher).Flush() + err := responseControllerFlush(r.Writer) + if err != nil && errors.Is(err, http.ErrNotSupported) { + panic(errors.New("response writer flushing is not supported")) + } } // Hijack implements the http.Hijacker interface to allow an HTTP handler to // take over the connection. // See [http.Hijacker](https://golang.org/pkg/net/http/#Hijacker) func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return r.Writer.(http.Hijacker).Hijack() + return responseControllerHijack(r.Writer) } // Unwrap returns the original http.ResponseWriter. diff --git a/response_test.go b/response_test.go index e4fd636d8..e457a0193 100644 --- a/response_test.go +++ b/response_test.go @@ -57,6 +57,31 @@ func TestResponse_Flush(t *testing.T) { assert.True(t, rec.Flushed) } +type testResponseWriter struct { +} + +func (w *testResponseWriter) WriteHeader(statusCode int) { +} + +func (w *testResponseWriter) Write([]byte) (int, error) { + return 0, nil +} + +func (w *testResponseWriter) Header() http.Header { + return nil +} + +func TestResponse_FlushPanics(t *testing.T) { + e := New() + rw := new(testResponseWriter) + res := &Response{echo: e, Writer: rw} + + // we test that we behave as before unwrapping flushers - flushing writer that does not support it causes panic + assert.PanicsWithError(t, "response writer flushing is not supported", func() { + res.Flush() + }) +} + func TestResponse_ChangeStatusCodeBeforeWrite(t *testing.T) { e := New() rec := httptest.NewRecorder() diff --git a/responsecontroller_1.19.go b/responsecontroller_1.19.go new file mode 100644 index 000000000..75c6e3e58 --- /dev/null +++ b/responsecontroller_1.19.go @@ -0,0 +1,41 @@ +//go:build !go1.20 + +package echo + +import ( + "bufio" + "fmt" + "net" + "net/http" +) + +// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore +func responseControllerFlush(rw http.ResponseWriter) error { + for { + switch t := rw.(type) { + case interface{ FlushError() error }: + return t.FlushError() + case http.Flusher: + t.Flush() + return nil + case interface{ Unwrap() http.ResponseWriter }: + rw = t.Unwrap() + default: + return fmt.Errorf("%w", http.ErrNotSupported) + } + } +} + +// TODO: remove when Go 1.23 is released and we do not support 1.19 anymore +func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) { + for { + switch t := rw.(type) { + case http.Hijacker: + return t.Hijack() + case interface{ Unwrap() http.ResponseWriter }: + rw = t.Unwrap() + default: + return nil, nil, fmt.Errorf("%w", http.ErrNotSupported) + } + } +} diff --git a/responsecontroller_1.20.go b/responsecontroller_1.20.go new file mode 100644 index 000000000..fa2fe8b3f --- /dev/null +++ b/responsecontroller_1.20.go @@ -0,0 +1,17 @@ +//go:build go1.20 + +package echo + +import ( + "bufio" + "net" + "net/http" +) + +func responseControllerFlush(rw http.ResponseWriter) error { + return http.NewResponseController(rw).Flush() +} + +func responseControllerHijack(rw http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) { + return http.NewResponseController(rw).Hijack() +} From a842444e8f8b81cfc72b50e16f8134ecf5eda645 Mon Sep 17 00:00:00 2001 From: Martti T Date: Sat, 9 Mar 2024 11:21:24 +0200 Subject: [PATCH 336/446] Add SPDX licence comments to files. See https://spdx.dev/learn/handling-license-info/ (#2604) --- bind.go | 3 +++ bind_test.go | 3 +++ binder.go | 3 +++ binder_external_test.go | 3 +++ binder_test.go | 3 +++ context.go | 3 +++ context_fs.go | 3 +++ context_fs_test.go | 3 +++ context_test.go | 3 +++ echo.go | 3 +++ echo_fs.go | 3 +++ echo_fs_test.go | 3 +++ echo_test.go | 3 +++ group.go | 3 +++ group_fs.go | 3 +++ group_fs_test.go | 3 +++ group_test.go | 3 +++ ip.go | 3 +++ ip_test.go | 3 +++ json.go | 3 +++ json_test.go | 3 +++ log.go | 3 +++ middleware/basic_auth.go | 3 +++ middleware/basic_auth_test.go | 3 +++ middleware/body_dump.go | 3 +++ middleware/body_dump_test.go | 3 +++ middleware/body_limit.go | 3 +++ middleware/body_limit_test.go | 3 +++ middleware/compress.go | 3 +++ middleware/compress_test.go | 3 +++ middleware/context_timeout.go | 3 +++ middleware/context_timeout_test.go | 3 +++ middleware/cors.go | 3 +++ middleware/cors_test.go | 3 +++ middleware/csrf.go | 3 +++ middleware/csrf_test.go | 3 +++ middleware/decompress.go | 3 +++ middleware/decompress_test.go | 3 +++ middleware/extractor.go | 3 +++ middleware/extractor_test.go | 3 +++ middleware/jwt.go | 3 +++ middleware/jwt_test.go | 3 +++ middleware/key_auth.go | 3 +++ middleware/key_auth_test.go | 3 +++ middleware/logger.go | 3 +++ middleware/logger_test.go | 3 +++ middleware/method_override.go | 3 +++ middleware/method_override_test.go | 3 +++ middleware/middleware.go | 3 +++ middleware/middleware_test.go | 3 +++ middleware/proxy.go | 3 +++ middleware/proxy_test.go | 3 +++ middleware/rate_limiter.go | 3 +++ middleware/rate_limiter_test.go | 3 +++ middleware/recover.go | 3 +++ middleware/recover_test.go | 3 +++ middleware/redirect.go | 3 +++ middleware/redirect_test.go | 3 +++ middleware/request_id.go | 3 +++ middleware/request_id_test.go | 3 +++ middleware/request_logger.go | 3 +++ middleware/request_logger_test.go | 3 +++ middleware/responsecontroller_1.19.go | 3 +++ middleware/responsecontroller_1.20.go | 3 +++ middleware/rewrite.go | 3 +++ middleware/rewrite_test.go | 3 +++ middleware/secure.go | 3 +++ middleware/secure_test.go | 3 +++ middleware/slash.go | 3 +++ middleware/slash_test.go | 3 +++ middleware/static.go | 3 +++ middleware/static_other.go | 3 +++ middleware/static_test.go | 3 +++ middleware/static_windows.go | 3 +++ middleware/timeout.go | 3 +++ middleware/timeout_test.go | 3 +++ middleware/util.go | 3 +++ middleware/util_test.go | 3 +++ response.go | 3 +++ response_test.go | 3 +++ responsecontroller_1.19.go | 3 +++ responsecontroller_1.20.go | 3 +++ router.go | 3 +++ router_test.go | 3 +++ 84 files changed, 252 insertions(+) diff --git a/bind.go b/bind.go index 51f4689e7..353c51325 100644 --- a/bind.go +++ b/bind.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/bind_test.go b/bind_test.go index cffccfb35..c0272e712 100644 --- a/bind_test.go +++ b/bind_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/binder.go b/binder.go index 8e7b81413..ebabeaf96 100644 --- a/binder.go +++ b/binder.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/binder_external_test.go b/binder_external_test.go index f1aecb52b..e44055a23 100644 --- a/binder_external_test.go +++ b/binder_external_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + // run tests as external package to get real feel for API package echo_test diff --git a/binder_test.go b/binder_test.go index 0b27cae64..d552b604d 100644 --- a/binder_test.go +++ b/binder_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/context.go b/context.go index d917f3bc9..2b4acae32 100644 --- a/context.go +++ b/context.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/context_fs.go b/context_fs.go index 1038f892e..1c25baf12 100644 --- a/context_fs.go +++ b/context_fs.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/context_fs_test.go b/context_fs_test.go index 51346c956..83232ea45 100644 --- a/context_fs_test.go +++ b/context_fs_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/context_test.go b/context_test.go index 4ca2cc84b..463e10a60 100644 --- a/context_test.go +++ b/context_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/echo.go b/echo.go index eb8a79f38..4d11af04a 100644 --- a/echo.go +++ b/echo.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + /* Package echo implements high performance, minimalist Go web framework. diff --git a/echo_fs.go b/echo_fs.go index 9f83a0351..a7b231f31 100644 --- a/echo_fs.go +++ b/echo_fs.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/echo_fs_test.go b/echo_fs_test.go index eb072a28d..e882a0682 100644 --- a/echo_fs_test.go +++ b/echo_fs_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/echo_test.go b/echo_test.go index 416479191..f09544127 100644 --- a/echo_test.go +++ b/echo_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/group.go b/group.go index 749a5caab..e69d80b7f 100644 --- a/group.go +++ b/group.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/group_fs.go b/group_fs.go index aedc4c6a9..c1b7ec2d3 100644 --- a/group_fs.go +++ b/group_fs.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/group_fs_test.go b/group_fs_test.go index 958d9efb1..8bcd547d1 100644 --- a/group_fs_test.go +++ b/group_fs_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/group_test.go b/group_test.go index d22f564b0..a97371418 100644 --- a/group_test.go +++ b/group_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/ip.go b/ip.go index 905268abf..5374dc018 100644 --- a/ip.go +++ b/ip.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/ip_test.go b/ip_test.go index 38c4a1cac..20e3127a8 100644 --- a/ip_test.go +++ b/ip_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/json.go b/json.go index 16b2d0577..6da0aaf97 100644 --- a/json.go +++ b/json.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/json_test.go b/json_test.go index 8fb9ebc96..0b15ed1a1 100644 --- a/json_test.go +++ b/json_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/log.go b/log.go index 3f8de5904..b9ec3d561 100644 --- a/log.go +++ b/log.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 07a5761b8..7e809f5f7 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 2e133e071..6e07065bf 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/body_dump.go b/middleware/body_dump.go index 946ffc58f..e7b20981c 100644 --- a/middleware/body_dump.go +++ b/middleware/body_dump.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go index a68930b49..e880af45b 100644 --- a/middleware/body_dump_test.go +++ b/middleware/body_dump_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/body_limit.go b/middleware/body_limit.go index 99e3ac547..81972304e 100644 --- a/middleware/body_limit.go +++ b/middleware/body_limit.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index 0fd66ee0f..d14c2b649 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/compress.go b/middleware/compress.go index c77062d92..681c0346f 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/compress_test.go b/middleware/compress_test.go index 6c5ce4123..4bbdfdbc2 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/context_timeout.go b/middleware/context_timeout.go index 1937693f1..e67173f21 100644 --- a/middleware/context_timeout.go +++ b/middleware/context_timeout.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/context_timeout_test.go b/middleware/context_timeout_test.go index 24c6203e7..e69bcd268 100644 --- a/middleware/context_timeout_test.go +++ b/middleware/context_timeout_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/cors.go b/middleware/cors.go index 7ace2f224..dd7030e56 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/cors_test.go b/middleware/cors_test.go index 797600c5c..64e5c6542 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/csrf.go b/middleware/csrf.go index adf12210b..015473d9f 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 6b20297ee..98e5d04f6 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/decompress.go b/middleware/decompress.go index a73c9738b..3dded53c5 100644 --- a/middleware/decompress.go +++ b/middleware/decompress.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 351e0e708..63b1a68f5 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/extractor.go b/middleware/extractor.go index 5d9cee6d0..3f2741407 100644 --- a/middleware/extractor.go +++ b/middleware/extractor.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go index 428c5563e..42cbcfeab 100644 --- a/middleware/extractor_test.go +++ b/middleware/extractor_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/jwt.go b/middleware/jwt.go index bc318c976..276bdfe39 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + //go:build go1.15 // +build go1.15 diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 90e8cad81..bbe4b8808 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + //go:build go1.15 // +build go1.15 diff --git a/middleware/key_auth.go b/middleware/key_auth.go index f6fcc5d69..f7ce8c18a 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go index ff8968c38..447f0bee8 100644 --- a/middleware/key_auth_test.go +++ b/middleware/key_auth_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/logger.go b/middleware/logger.go index 7958d873b..43fd59ffc 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/logger_test.go b/middleware/logger_test.go index 9f35a70bc..d5236e1ac 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/method_override.go b/middleware/method_override.go index 92b14d2ed..668a57a41 100644 --- a/middleware/method_override.go +++ b/middleware/method_override.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/method_override_test.go b/middleware/method_override_test.go index 5760b1581..0000d1d80 100644 --- a/middleware/method_override_test.go +++ b/middleware/method_override_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/middleware.go b/middleware/middleware.go index 664f71f45..8dfb8dda6 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 990568d55..7f3dc3866 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/proxy.go b/middleware/proxy.go index 16b00d645..ddf4b7f06 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 1c93ba031..e87229ab5 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index 1d24df52a..a58b16491 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go index f66961fe2..1de7b63e5 100644 --- a/middleware/rate_limiter_test.go +++ b/middleware/rate_limiter_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/recover.go b/middleware/recover.go index 0466cfe56..35f38e72c 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/recover_test.go b/middleware/recover_test.go index 3e0d35d79..8fa34fa5c 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/redirect.go b/middleware/redirect.go index 13877db38..b772ac131 100644 --- a/middleware/redirect.go +++ b/middleware/redirect.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go index 9d1b56205..88068ea2e 100644 --- a/middleware/redirect_test.go +++ b/middleware/redirect_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/request_id.go b/middleware/request_id.go index e29c8f50d..411737cb4 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go index 21b777826..4e68b126a 100644 --- a/middleware/request_id_test.go +++ b/middleware/request_id_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/request_logger.go b/middleware/request_logger.go index f82f6b622..7c18200b0 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go index f3c5f8425..c612f5c22 100644 --- a/middleware/request_logger_test.go +++ b/middleware/request_logger_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/responsecontroller_1.19.go b/middleware/responsecontroller_1.19.go index 104784fd0..ddf6b64c0 100644 --- a/middleware/responsecontroller_1.19.go +++ b/middleware/responsecontroller_1.19.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + //go:build !go1.20 package middleware diff --git a/middleware/responsecontroller_1.20.go b/middleware/responsecontroller_1.20.go index 02a0cb754..bc03059bc 100644 --- a/middleware/responsecontroller_1.20.go +++ b/middleware/responsecontroller_1.20.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + //go:build go1.20 package middleware diff --git a/middleware/rewrite.go b/middleware/rewrite.go index 2090eac04..260dbb1f5 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index 47d707c30..d137b2d13 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/secure.go b/middleware/secure.go index 6c4051723..b70854ddc 100644 --- a/middleware/secure.go +++ b/middleware/secure.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/secure_test.go b/middleware/secure_test.go index 79bd172ae..b579a6d21 100644 --- a/middleware/secure_test.go +++ b/middleware/secure_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/slash.go b/middleware/slash.go index a3bf807ec..774cc5582 100644 --- a/middleware/slash.go +++ b/middleware/slash.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/slash_test.go b/middleware/slash_test.go index ddb071045..1b365cfea 100644 --- a/middleware/slash_test.go +++ b/middleware/slash_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/static.go b/middleware/static.go index 24a5f59b9..15a838175 100644 --- a/middleware/static.go +++ b/middleware/static.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/static_other.go b/middleware/static_other.go index 0337b22af..35dbfb38e 100644 --- a/middleware/static_other.go +++ b/middleware/static_other.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + //go:build !windows package middleware diff --git a/middleware/static_test.go b/middleware/static_test.go index f26d97a95..a10ab8000 100644 --- a/middleware/static_test.go +++ b/middleware/static_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/static_windows.go b/middleware/static_windows.go index 0ab119859..e294020a1 100644 --- a/middleware/static_windows.go +++ b/middleware/static_windows.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/timeout.go b/middleware/timeout.go index 4e8836c85..a47bd4b3b 100644 --- a/middleware/timeout.go +++ b/middleware/timeout.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/timeout_test.go b/middleware/timeout_test.go index 98d96baef..e8415d636 100644 --- a/middleware/timeout_test.go +++ b/middleware/timeout_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/util.go b/middleware/util.go index 4d2d172fc..09428eb0b 100644 --- a/middleware/util.go +++ b/middleware/util.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/middleware/util_test.go b/middleware/util_test.go index d0f20bba6..b54f12627 100644 --- a/middleware/util_test.go +++ b/middleware/util_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( diff --git a/response.go b/response.go index 117881cc6..7ca522eb1 100644 --- a/response.go +++ b/response.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/response_test.go b/response_test.go index e457a0193..70cba9776 100644 --- a/response_test.go +++ b/response_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/responsecontroller_1.19.go b/responsecontroller_1.19.go index 75c6e3e58..782dab3a3 100644 --- a/responsecontroller_1.19.go +++ b/responsecontroller_1.19.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + //go:build !go1.20 package echo diff --git a/responsecontroller_1.20.go b/responsecontroller_1.20.go index fa2fe8b3f..6d77c07f8 100644 --- a/responsecontroller_1.20.go +++ b/responsecontroller_1.20.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + //go:build go1.20 package echo diff --git a/router.go b/router.go index ee6f3fa48..0a9b7d267 100644 --- a/router.go +++ b/router.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( diff --git a/router_test.go b/router_test.go index 619cce092..52d9a0abb 100644 --- a/router_test.go +++ b/router_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( From f0966790fb018524dc9ead2898a97e3ee532d135 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 9 Mar 2024 11:23:12 +0200 Subject: [PATCH 337/446] Upgrade deps --- go.mod | 6 +++--- go.sum | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 089ffb140..89a0e86a0 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,8 @@ require ( github.com/labstack/gommon v0.4.2 github.com/stretchr/testify v1.8.4 github.com/valyala/fasttemplate v1.2.2 - golang.org/x/crypto v0.17.0 - golang.org/x/net v0.19.0 + golang.org/x/crypto v0.21.0 + golang.org/x/net v0.22.0 golang.org/x/time v0.5.0 ) @@ -18,7 +18,7 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.15.0 // indirect + golang.org/x/sys v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 0584b7e59..397a22dc5 100644 --- a/go.sum +++ b/go.sum @@ -17,14 +17,14 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= From 5f7bedfb86e10bf0024236adfea544d0f5a82689 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Sat, 9 Mar 2024 11:23:55 +0200 Subject: [PATCH 338/446] update makefile --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 6aff6a89f..f9e5afb09 100644 --- a/Makefile +++ b/Makefile @@ -31,6 +31,6 @@ benchmark: ## Run benchmarks help: ## Display this help screen @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -goversion ?= "1.17" -test_version: ## Run tests inside Docker with given version (defaults to 1.17 oldest supported). Example: make test_version goversion=1.17 +goversion ?= "1.19" +test_version: ## Run tests inside Docker with given version (defaults to 1.19 oldest supported). Example: make test_version goversion=1.19 @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" From 3598f295f95f316bbeb252b7b332fe34e120815c Mon Sep 17 00:00:00 2001 From: Martti T Date: Sat, 9 Mar 2024 17:53:07 +0200 Subject: [PATCH 339/446] Change type definition blocks to single declarations. This helps copy/pasting Echo code in examples. (#2606) --- bind.go | 28 ++-- bind_test.go | 170 ++++++++++---------- context.go | 284 +++++++++++++++++----------------- context_test.go | 8 +- echo.go | 237 ++++++++++++++-------------- echo_test.go | 10 +- group.go | 22 ++- log.go | 67 ++++---- middleware/basic_auth.go | 48 +++--- middleware/body_dump.go | 42 +++-- middleware/body_limit.go | 42 +++-- middleware/compress.go | 80 +++++----- middleware/cors.go | 208 ++++++++++++------------- middleware/csrf.go | 138 ++++++++--------- middleware/decompress.go | 30 ++-- middleware/jwt.go | 244 ++++++++++++++--------------- middleware/key_auth.go | 106 ++++++------- middleware/logger.go | 136 ++++++++-------- middleware/method_override.go | 34 ++-- middleware/middleware.go | 14 +- middleware/proxy.go | 204 ++++++++++++------------ middleware/rate_limiter.go | 90 +++++------ middleware/recover.go | 92 ++++++----- middleware/request_id.go | 40 +++-- middleware/rewrite.go | 52 +++---- middleware/secure.go | 150 +++++++++--------- middleware/slash.go | 30 ++-- middleware/static.go | 78 +++++----- middleware/timeout.go | 14 +- response.go | 26 ++-- router.go | 102 ++++++------ 31 files changed, 1364 insertions(+), 1462 deletions(-) diff --git a/bind.go b/bind.go index 353c51325..5e29be8e5 100644 --- a/bind.go +++ b/bind.go @@ -14,23 +14,21 @@ import ( "strings" ) -type ( - // Binder is the interface that wraps the Bind method. - Binder interface { - Bind(i interface{}, c Context) error - } +// Binder is the interface that wraps the Bind method. +type Binder interface { + Bind(i interface{}, c Context) error +} - // DefaultBinder is the default implementation of the Binder interface. - DefaultBinder struct{} +// DefaultBinder is the default implementation of the Binder interface. +type DefaultBinder struct{} - // BindUnmarshaler is the interface used to wrap the UnmarshalParam method. - // Types that don't implement this, but do implement encoding.TextUnmarshaler - // will use that interface instead. - BindUnmarshaler interface { - // UnmarshalParam decodes and assigns a value from an form or query param. - UnmarshalParam(param string) error - } -) +// BindUnmarshaler is the interface used to wrap the UnmarshalParam method. +// Types that don't implement this, but do implement encoding.TextUnmarshaler +// will use that interface instead. +type BindUnmarshaler interface { + // UnmarshalParam decodes and assigns a value from an form or query param. + UnmarshalParam(param string) error +} // BindPathParams binds path params to bindable object func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { diff --git a/bind_test.go b/bind_test.go index c0272e712..05f8ef43c 100644 --- a/bind_test.go +++ b/bind_test.go @@ -22,91 +22,91 @@ import ( "github.com/stretchr/testify/assert" ) -type ( - bindTestStruct struct { - I int - PtrI *int - I8 int8 - PtrI8 *int8 - I16 int16 - PtrI16 *int16 - I32 int32 - PtrI32 *int32 - I64 int64 - PtrI64 *int64 - UI uint - PtrUI *uint - UI8 uint8 - PtrUI8 *uint8 - UI16 uint16 - PtrUI16 *uint16 - UI32 uint32 - PtrUI32 *uint32 - UI64 uint64 - PtrUI64 *uint64 - B bool - PtrB *bool - F32 float32 - PtrF32 *float32 - F64 float64 - PtrF64 *float64 - S string - PtrS *string - cantSet string - DoesntExist string - GoT time.Time - GoTptr *time.Time - T Timestamp - Tptr *Timestamp - SA StringArray - } - bindTestStructWithTags struct { - I int `json:"I" form:"I"` - PtrI *int `json:"PtrI" form:"PtrI"` - I8 int8 `json:"I8" form:"I8"` - PtrI8 *int8 `json:"PtrI8" form:"PtrI8"` - I16 int16 `json:"I16" form:"I16"` - PtrI16 *int16 `json:"PtrI16" form:"PtrI16"` - I32 int32 `json:"I32" form:"I32"` - PtrI32 *int32 `json:"PtrI32" form:"PtrI32"` - I64 int64 `json:"I64" form:"I64"` - PtrI64 *int64 `json:"PtrI64" form:"PtrI64"` - UI uint `json:"UI" form:"UI"` - PtrUI *uint `json:"PtrUI" form:"PtrUI"` - UI8 uint8 `json:"UI8" form:"UI8"` - PtrUI8 *uint8 `json:"PtrUI8" form:"PtrUI8"` - UI16 uint16 `json:"UI16" form:"UI16"` - PtrUI16 *uint16 `json:"PtrUI16" form:"PtrUI16"` - UI32 uint32 `json:"UI32" form:"UI32"` - PtrUI32 *uint32 `json:"PtrUI32" form:"PtrUI32"` - UI64 uint64 `json:"UI64" form:"UI64"` - PtrUI64 *uint64 `json:"PtrUI64" form:"PtrUI64"` - B bool `json:"B" form:"B"` - PtrB *bool `json:"PtrB" form:"PtrB"` - F32 float32 `json:"F32" form:"F32"` - PtrF32 *float32 `json:"PtrF32" form:"PtrF32"` - F64 float64 `json:"F64" form:"F64"` - PtrF64 *float64 `json:"PtrF64" form:"PtrF64"` - S string `json:"S" form:"S"` - PtrS *string `json:"PtrS" form:"PtrS"` - cantSet string - DoesntExist string `json:"DoesntExist" form:"DoesntExist"` - GoT time.Time `json:"GoT" form:"GoT"` - GoTptr *time.Time `json:"GoTptr" form:"GoTptr"` - T Timestamp `json:"T" form:"T"` - Tptr *Timestamp `json:"Tptr" form:"Tptr"` - SA StringArray `json:"SA" form:"SA"` - } - Timestamp time.Time - TA []Timestamp - StringArray []string - Struct struct { - Foo string - } - Bar struct { - Baz int `json:"baz" query:"baz"` - } -) +type bindTestStruct struct { + I int + PtrI *int + I8 int8 + PtrI8 *int8 + I16 int16 + PtrI16 *int16 + I32 int32 + PtrI32 *int32 + I64 int64 + PtrI64 *int64 + UI uint + PtrUI *uint + UI8 uint8 + PtrUI8 *uint8 + UI16 uint16 + PtrUI16 *uint16 + UI32 uint32 + PtrUI32 *uint32 + UI64 uint64 + PtrUI64 *uint64 + B bool + PtrB *bool + F32 float32 + PtrF32 *float32 + F64 float64 + PtrF64 *float64 + S string + PtrS *string + cantSet string + DoesntExist string + GoT time.Time + GoTptr *time.Time + T Timestamp + Tptr *Timestamp + SA StringArray +} + +type bindTestStructWithTags struct { + I int `json:"I" form:"I"` + PtrI *int `json:"PtrI" form:"PtrI"` + I8 int8 `json:"I8" form:"I8"` + PtrI8 *int8 `json:"PtrI8" form:"PtrI8"` + I16 int16 `json:"I16" form:"I16"` + PtrI16 *int16 `json:"PtrI16" form:"PtrI16"` + I32 int32 `json:"I32" form:"I32"` + PtrI32 *int32 `json:"PtrI32" form:"PtrI32"` + I64 int64 `json:"I64" form:"I64"` + PtrI64 *int64 `json:"PtrI64" form:"PtrI64"` + UI uint `json:"UI" form:"UI"` + PtrUI *uint `json:"PtrUI" form:"PtrUI"` + UI8 uint8 `json:"UI8" form:"UI8"` + PtrUI8 *uint8 `json:"PtrUI8" form:"PtrUI8"` + UI16 uint16 `json:"UI16" form:"UI16"` + PtrUI16 *uint16 `json:"PtrUI16" form:"PtrUI16"` + UI32 uint32 `json:"UI32" form:"UI32"` + PtrUI32 *uint32 `json:"PtrUI32" form:"PtrUI32"` + UI64 uint64 `json:"UI64" form:"UI64"` + PtrUI64 *uint64 `json:"PtrUI64" form:"PtrUI64"` + B bool `json:"B" form:"B"` + PtrB *bool `json:"PtrB" form:"PtrB"` + F32 float32 `json:"F32" form:"F32"` + PtrF32 *float32 `json:"PtrF32" form:"PtrF32"` + F64 float64 `json:"F64" form:"F64"` + PtrF64 *float64 `json:"PtrF64" form:"PtrF64"` + S string `json:"S" form:"S"` + PtrS *string `json:"PtrS" form:"PtrS"` + cantSet string + DoesntExist string `json:"DoesntExist" form:"DoesntExist"` + GoT time.Time `json:"GoT" form:"GoT"` + GoTptr *time.Time `json:"GoTptr" form:"GoTptr"` + T Timestamp `json:"T" form:"T"` + Tptr *Timestamp `json:"Tptr" form:"Tptr"` + SA StringArray `json:"SA" form:"SA"` +} + +type Timestamp time.Time +type TA []Timestamp +type StringArray []string +type Struct struct { + Foo string +} +type Bar struct { + Baz int `json:"baz" query:"baz"` +} func (t *Timestamp) UnmarshalParam(src string) error { ts, err := time.Parse(time.RFC3339, src) diff --git a/context.go b/context.go index 2b4acae32..a5177e884 100644 --- a/context.go +++ b/context.go @@ -16,204 +16,202 @@ import ( "sync" ) -type ( - // Context represents the context of the current HTTP request. It holds request and - // response objects, path, path parameters, data and registered handler. - Context interface { - // Request returns `*http.Request`. - Request() *http.Request +// Context represents the context of the current HTTP request. It holds request and +// response objects, path, path parameters, data and registered handler. +type Context interface { + // Request returns `*http.Request`. + Request() *http.Request - // SetRequest sets `*http.Request`. - SetRequest(r *http.Request) + // SetRequest sets `*http.Request`. + SetRequest(r *http.Request) - // SetResponse sets `*Response`. - SetResponse(r *Response) + // SetResponse sets `*Response`. + SetResponse(r *Response) - // Response returns `*Response`. - Response() *Response + // Response returns `*Response`. + Response() *Response - // IsTLS returns true if HTTP connection is TLS otherwise false. - IsTLS() bool + // IsTLS returns true if HTTP connection is TLS otherwise false. + IsTLS() bool - // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. - IsWebSocket() bool + // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. + IsWebSocket() bool - // Scheme returns the HTTP protocol scheme, `http` or `https`. - Scheme() string + // Scheme returns the HTTP protocol scheme, `http` or `https`. + Scheme() string - // RealIP returns the client's network address based on `X-Forwarded-For` - // or `X-Real-IP` request header. - // The behavior can be configured using `Echo#IPExtractor`. - RealIP() string + // RealIP returns the client's network address based on `X-Forwarded-For` + // or `X-Real-IP` request header. + // The behavior can be configured using `Echo#IPExtractor`. + RealIP() string - // Path returns the registered path for the handler. - Path() string + // Path returns the registered path for the handler. + Path() string - // SetPath sets the registered path for the handler. - SetPath(p string) + // SetPath sets the registered path for the handler. + SetPath(p string) - // Param returns path parameter by name. - Param(name string) string + // Param returns path parameter by name. + Param(name string) string - // ParamNames returns path parameter names. - ParamNames() []string + // ParamNames returns path parameter names. + ParamNames() []string - // SetParamNames sets path parameter names. - SetParamNames(names ...string) + // SetParamNames sets path parameter names. + SetParamNames(names ...string) - // ParamValues returns path parameter values. - ParamValues() []string + // ParamValues returns path parameter values. + ParamValues() []string - // SetParamValues sets path parameter values. - SetParamValues(values ...string) + // SetParamValues sets path parameter values. + SetParamValues(values ...string) - // QueryParam returns the query param for the provided name. - QueryParam(name string) string + // QueryParam returns the query param for the provided name. + QueryParam(name string) string - // QueryParams returns the query parameters as `url.Values`. - QueryParams() url.Values + // QueryParams returns the query parameters as `url.Values`. + QueryParams() url.Values - // QueryString returns the URL query string. - QueryString() string + // QueryString returns the URL query string. + QueryString() string - // FormValue returns the form field value for the provided name. - FormValue(name string) string + // FormValue returns the form field value for the provided name. + FormValue(name string) string - // FormParams returns the form parameters as `url.Values`. - FormParams() (url.Values, error) + // FormParams returns the form parameters as `url.Values`. + FormParams() (url.Values, error) - // FormFile returns the multipart form file for the provided name. - FormFile(name string) (*multipart.FileHeader, error) + // FormFile returns the multipart form file for the provided name. + FormFile(name string) (*multipart.FileHeader, error) - // MultipartForm returns the multipart form. - MultipartForm() (*multipart.Form, error) + // MultipartForm returns the multipart form. + MultipartForm() (*multipart.Form, error) - // Cookie returns the named cookie provided in the request. - Cookie(name string) (*http.Cookie, error) + // Cookie returns the named cookie provided in the request. + Cookie(name string) (*http.Cookie, error) - // SetCookie adds a `Set-Cookie` header in HTTP response. - SetCookie(cookie *http.Cookie) + // SetCookie adds a `Set-Cookie` header in HTTP response. + SetCookie(cookie *http.Cookie) - // Cookies returns the HTTP cookies sent with the request. - Cookies() []*http.Cookie + // Cookies returns the HTTP cookies sent with the request. + Cookies() []*http.Cookie - // Get retrieves data from the context. - Get(key string) interface{} + // Get retrieves data from the context. + Get(key string) interface{} - // Set saves data in the context. - Set(key string, val interface{}) + // Set saves data in the context. + Set(key string, val interface{}) - // Bind binds path params, query params and the request body into provided type `i`. The default binder - // binds body based on Content-Type header. - Bind(i interface{}) error + // Bind binds path params, query params and the request body into provided type `i`. The default binder + // binds body based on Content-Type header. + Bind(i interface{}) error - // Validate validates provided `i`. It is usually called after `Context#Bind()`. - // Validator must be registered using `Echo#Validator`. - Validate(i interface{}) error + // Validate validates provided `i`. It is usually called after `Context#Bind()`. + // Validator must be registered using `Echo#Validator`. + Validate(i interface{}) error - // Render renders a template with data and sends a text/html response with status - // code. Renderer must be registered using `Echo.Renderer`. - Render(code int, name string, data interface{}) error + // Render renders a template with data and sends a text/html response with status + // code. Renderer must be registered using `Echo.Renderer`. + Render(code int, name string, data interface{}) error - // HTML sends an HTTP response with status code. - HTML(code int, html string) error + // HTML sends an HTTP response with status code. + HTML(code int, html string) error - // HTMLBlob sends an HTTP blob response with status code. - HTMLBlob(code int, b []byte) error + // HTMLBlob sends an HTTP blob response with status code. + HTMLBlob(code int, b []byte) error - // String sends a string response with status code. - String(code int, s string) error + // String sends a string response with status code. + String(code int, s string) error - // JSON sends a JSON response with status code. - JSON(code int, i interface{}) error + // JSON sends a JSON response with status code. + JSON(code int, i interface{}) error - // JSONPretty sends a pretty-print JSON with status code. - JSONPretty(code int, i interface{}, indent string) error + // JSONPretty sends a pretty-print JSON with status code. + JSONPretty(code int, i interface{}, indent string) error - // JSONBlob sends a JSON blob response with status code. - JSONBlob(code int, b []byte) error + // JSONBlob sends a JSON blob response with status code. + JSONBlob(code int, b []byte) error - // JSONP sends a JSONP response with status code. It uses `callback` to construct - // the JSONP payload. - JSONP(code int, callback string, i interface{}) error + // JSONP sends a JSONP response with status code. It uses `callback` to construct + // the JSONP payload. + JSONP(code int, callback string, i interface{}) error - // JSONPBlob sends a JSONP blob response with status code. It uses `callback` - // to construct the JSONP payload. - JSONPBlob(code int, callback string, b []byte) error + // JSONPBlob sends a JSONP blob response with status code. It uses `callback` + // to construct the JSONP payload. + JSONPBlob(code int, callback string, b []byte) error - // XML sends an XML response with status code. - XML(code int, i interface{}) error + // XML sends an XML response with status code. + XML(code int, i interface{}) error - // XMLPretty sends a pretty-print XML with status code. - XMLPretty(code int, i interface{}, indent string) error + // XMLPretty sends a pretty-print XML with status code. + XMLPretty(code int, i interface{}, indent string) error - // XMLBlob sends an XML blob response with status code. - XMLBlob(code int, b []byte) error + // XMLBlob sends an XML blob response with status code. + XMLBlob(code int, b []byte) error - // Blob sends a blob response with status code and content type. - Blob(code int, contentType string, b []byte) error + // Blob sends a blob response with status code and content type. + Blob(code int, contentType string, b []byte) error - // Stream sends a streaming response with status code and content type. - Stream(code int, contentType string, r io.Reader) error + // Stream sends a streaming response with status code and content type. + Stream(code int, contentType string, r io.Reader) error - // File sends a response with the content of the file. - File(file string) error + // File sends a response with the content of the file. + File(file string) error - // Attachment sends a response as attachment, prompting client to save the - // file. - Attachment(file string, name string) error + // Attachment sends a response as attachment, prompting client to save the + // file. + Attachment(file string, name string) error - // Inline sends a response as inline, opening the file in the browser. - Inline(file string, name string) error + // Inline sends a response as inline, opening the file in the browser. + Inline(file string, name string) error - // NoContent sends a response with no body and a status code. - NoContent(code int) error + // NoContent sends a response with no body and a status code. + NoContent(code int) error - // Redirect redirects the request to a provided URL with status code. - Redirect(code int, url string) error + // Redirect redirects the request to a provided URL with status code. + Redirect(code int, url string) error - // Error invokes the registered global HTTP error handler. Generally used by middleware. - // A side-effect of calling global error handler is that now Response has been committed (sent to the client) and - // middlewares up in chain can not change Response status code or Response body anymore. - // - // Avoid using this method in handlers as no middleware will be able to effectively handle errors after that. - Error(err error) + // Error invokes the registered global HTTP error handler. Generally used by middleware. + // A side-effect of calling global error handler is that now Response has been committed (sent to the client) and + // middlewares up in chain can not change Response status code or Response body anymore. + // + // Avoid using this method in handlers as no middleware will be able to effectively handle errors after that. + Error(err error) - // Handler returns the matched handler by router. - Handler() HandlerFunc + // Handler returns the matched handler by router. + Handler() HandlerFunc - // SetHandler sets the matched handler by router. - SetHandler(h HandlerFunc) + // SetHandler sets the matched handler by router. + SetHandler(h HandlerFunc) - // Logger returns the `Logger` instance. - Logger() Logger + // Logger returns the `Logger` instance. + Logger() Logger - // SetLogger Set the logger - SetLogger(l Logger) + // SetLogger Set the logger + SetLogger(l Logger) - // Echo returns the `Echo` instance. - Echo() *Echo + // Echo returns the `Echo` instance. + Echo() *Echo - // Reset resets the context after request completes. It must be called along - // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. - // See `Echo#ServeHTTP()` - Reset(r *http.Request, w http.ResponseWriter) - } + // Reset resets the context after request completes. It must be called along + // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. + // See `Echo#ServeHTTP()` + Reset(r *http.Request, w http.ResponseWriter) +} - context struct { - request *http.Request - response *Response - path string - pnames []string - pvalues []string - query url.Values - handler HandlerFunc - store Map - echo *Echo - logger Logger - lock sync.RWMutex - } -) +type context struct { + request *http.Request + response *Response + path string + pnames []string + pvalues []string + query url.Values + handler HandlerFunc + store Map + echo *Echo + logger Logger + lock sync.RWMutex +} const ( // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain. diff --git a/context_test.go b/context_test.go index 463e10a60..e5c4a215a 100644 --- a/context_test.go +++ b/context_test.go @@ -25,11 +25,9 @@ import ( "github.com/stretchr/testify/assert" ) -type ( - Template struct { - templates *template.Template - } -) +type Template struct { + templates *template.Template +} var testUser = user{1, "Jon Snow"} diff --git a/echo.go b/echo.go index 4d11af04a..6e4ed9a8d 100644 --- a/echo.go +++ b/echo.go @@ -63,97 +63,95 @@ import ( "golang.org/x/net/http2/h2c" ) -type ( - // Echo is the top-level framework instance. - // - // Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these - // fields from handlers/middlewares and changing field values at the same time leads to data-races. - // Adding new routes after the server has been started is also not safe! - Echo struct { - filesystem - common - // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get - // listener address info (on which interface/port was listener bound) without having data races. - startupMutex sync.RWMutex - colorer *color.Color - - // premiddleware are middlewares that are run before routing is done. In case a pre-middleware returns - // an error the router is not executed and the request will end up in the global error handler. - premiddleware []MiddlewareFunc - middleware []MiddlewareFunc - maxParam *int - router *Router - routers map[string]*Router - pool sync.Pool - - StdLogger *stdLog.Logger - Server *http.Server - TLSServer *http.Server - Listener net.Listener - TLSListener net.Listener - AutoTLSManager autocert.Manager - DisableHTTP2 bool - Debug bool - HideBanner bool - HidePort bool - HTTPErrorHandler HTTPErrorHandler - Binder Binder - JSONSerializer JSONSerializer - Validator Validator - Renderer Renderer - Logger Logger - IPExtractor IPExtractor - ListenerNetwork string - - // OnAddRouteHandler is called when Echo adds new route to specific host router. - OnAddRouteHandler func(host string, route Route, handler HandlerFunc, middleware []MiddlewareFunc) - } - - // Route contains a handler and information for matching against requests. - Route struct { - Method string `json:"method"` - Path string `json:"path"` - Name string `json:"name"` - } - - // HTTPError represents an error that occurred while handling a request. - HTTPError struct { - Code int `json:"-"` - Message interface{} `json:"message"` - Internal error `json:"-"` // Stores the error returned by an external dependency - } - - // MiddlewareFunc defines a function to process middleware. - MiddlewareFunc func(next HandlerFunc) HandlerFunc - - // HandlerFunc defines a function to serve HTTP requests. - HandlerFunc func(c Context) error - - // HTTPErrorHandler is a centralized HTTP error handler. - HTTPErrorHandler func(err error, c Context) - - // Validator is the interface that wraps the Validate function. - Validator interface { - Validate(i interface{}) error - } - - // JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. - JSONSerializer interface { - Serialize(c Context, i interface{}, indent string) error - Deserialize(c Context, i interface{}) error - } - - // Renderer is the interface that wraps the Render function. - Renderer interface { - Render(io.Writer, string, interface{}, Context) error - } - - // Map defines a generic map of type `map[string]interface{}`. - Map map[string]interface{} - - // Common struct for Echo & Group. - common struct{} -) +// Echo is the top-level framework instance. +// +// Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these +// fields from handlers/middlewares and changing field values at the same time leads to data-races. +// Adding new routes after the server has been started is also not safe! +type Echo struct { + filesystem + common + // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get + // listener address info (on which interface/port was listener bound) without having data races. + startupMutex sync.RWMutex + colorer *color.Color + + // premiddleware are middlewares that are run before routing is done. In case a pre-middleware returns + // an error the router is not executed and the request will end up in the global error handler. + premiddleware []MiddlewareFunc + middleware []MiddlewareFunc + maxParam *int + router *Router + routers map[string]*Router + pool sync.Pool + + StdLogger *stdLog.Logger + Server *http.Server + TLSServer *http.Server + Listener net.Listener + TLSListener net.Listener + AutoTLSManager autocert.Manager + DisableHTTP2 bool + Debug bool + HideBanner bool + HidePort bool + HTTPErrorHandler HTTPErrorHandler + Binder Binder + JSONSerializer JSONSerializer + Validator Validator + Renderer Renderer + Logger Logger + IPExtractor IPExtractor + ListenerNetwork string + + // OnAddRouteHandler is called when Echo adds new route to specific host router. + OnAddRouteHandler func(host string, route Route, handler HandlerFunc, middleware []MiddlewareFunc) +} + +// Route contains a handler and information for matching against requests. +type Route struct { + Method string `json:"method"` + Path string `json:"path"` + Name string `json:"name"` +} + +// HTTPError represents an error that occurred while handling a request. +type HTTPError struct { + Code int `json:"-"` + Message interface{} `json:"message"` + Internal error `json:"-"` // Stores the error returned by an external dependency +} + +// MiddlewareFunc defines a function to process middleware. +type MiddlewareFunc func(next HandlerFunc) HandlerFunc + +// HandlerFunc defines a function to serve HTTP requests. +type HandlerFunc func(c Context) error + +// HTTPErrorHandler is a centralized HTTP error handler. +type HTTPErrorHandler func(err error, c Context) + +// Validator is the interface that wraps the Validate function. +type Validator interface { + Validate(i interface{}) error +} + +// JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. +type JSONSerializer interface { + Serialize(c Context, i interface{}, indent string) error + Deserialize(c Context, i interface{}) error +} + +// Renderer is the interface that wraps the Render function. +type Renderer interface { + Render(io.Writer, string, interface{}, Context) error +} + +// Map defines a generic map of type `map[string]interface{}`. +type Map map[string]interface{} + +// Common struct for Echo & Group. +type common struct{} // HTTP methods // NOTE: Deprecated, please use the stdlib constants directly instead. @@ -282,21 +280,19 @@ ____________________________________O/_______ ` ) -var ( - methods = [...]string{ - http.MethodConnect, - http.MethodDelete, - http.MethodGet, - http.MethodHead, - http.MethodOptions, - http.MethodPatch, - http.MethodPost, - PROPFIND, - http.MethodPut, - http.MethodTrace, - REPORT, - } -) +var methods = [...]string{ + http.MethodConnect, + http.MethodDelete, + http.MethodGet, + http.MethodHead, + http.MethodOptions, + http.MethodPatch, + http.MethodPost, + PROPFIND, + http.MethodPut, + http.MethodTrace, + REPORT, +} // Errors var ( @@ -349,22 +345,23 @@ var ( ErrInvalidListenerNetwork = errors.New("invalid listener network") ) -// Error handlers -var ( - NotFoundHandler = func(c Context) error { - return ErrNotFound - } +// NotFoundHandler is the handler that router uses in case there was no matching route found. Returns an error that results +// HTTP 404 status code. +var NotFoundHandler = func(c Context) error { + return ErrNotFound +} - MethodNotAllowedHandler = func(c Context) error { - // See RFC 7231 section 7.4.1: An origin server MUST generate an Allow field in a 405 (Method Not Allowed) - // response and MAY do so in any other response. For disabled resources an empty Allow header may be returned - routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string) - if ok && routerAllowMethods != "" { - c.Response().Header().Set(HeaderAllow, routerAllowMethods) - } - return ErrMethodNotAllowed +// MethodNotAllowedHandler is the handler thar router uses in case there was no matching route found but there was +// another matching routes for that requested URL. Returns an error that results HTTP 405 Method Not Allowed status code. +var MethodNotAllowedHandler = func(c Context) error { + // See RFC 7231 section 7.4.1: An origin server MUST generate an Allow field in a 405 (Method Not Allowed) + // response and MAY do so in any other response. For disabled resources an empty Allow header may be returned + routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string) + if ok && routerAllowMethods != "" { + c.Response().Header().Set(HeaderAllow, routerAllowMethods) } -) + return ErrMethodNotAllowed +} // New creates an instance of Echo. func New() (e *Echo) { diff --git a/echo_test.go b/echo_test.go index f09544127..57c257b17 100644 --- a/echo_test.go +++ b/echo_test.go @@ -25,12 +25,10 @@ import ( "golang.org/x/net/http2" ) -type ( - user struct { - ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"` - Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"` - } -) +type user struct { + ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"` + Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"` +} const ( userJSON = `{"id":1,"name":"Jon Snow"}` diff --git a/group.go b/group.go index e69d80b7f..eca25c947 100644 --- a/group.go +++ b/group.go @@ -7,18 +7,16 @@ import ( "net/http" ) -type ( - // Group is a set of sub-routes for a specified route. It can be used for inner - // routes that share a common middleware or functionality that should be separate - // from the parent echo instance while still inheriting from it. - Group struct { - common - host string - prefix string - middleware []MiddlewareFunc - echo *Echo - } -) +// Group is a set of sub-routes for a specified route. It can be used for inner +// routes that share a common middleware or functionality that should be separate +// from the parent echo instance while still inheriting from it. +type Group struct { + common + host string + prefix string + middleware []MiddlewareFunc + echo *Echo +} // Use implements `Echo#Use()` for sub-routes within the Group. func (g *Group) Use(middleware ...MiddlewareFunc) { diff --git a/log.go b/log.go index b9ec3d561..0acd9ff03 100644 --- a/log.go +++ b/log.go @@ -4,41 +4,38 @@ package echo import ( - "io" - "github.com/labstack/gommon/log" + "io" ) -type ( - // Logger defines the logging interface. - Logger interface { - Output() io.Writer - SetOutput(w io.Writer) - Prefix() string - SetPrefix(p string) - Level() log.Lvl - SetLevel(v log.Lvl) - SetHeader(h string) - Print(i ...interface{}) - Printf(format string, args ...interface{}) - Printj(j log.JSON) - Debug(i ...interface{}) - Debugf(format string, args ...interface{}) - Debugj(j log.JSON) - Info(i ...interface{}) - Infof(format string, args ...interface{}) - Infoj(j log.JSON) - Warn(i ...interface{}) - Warnf(format string, args ...interface{}) - Warnj(j log.JSON) - Error(i ...interface{}) - Errorf(format string, args ...interface{}) - Errorj(j log.JSON) - Fatal(i ...interface{}) - Fatalj(j log.JSON) - Fatalf(format string, args ...interface{}) - Panic(i ...interface{}) - Panicj(j log.JSON) - Panicf(format string, args ...interface{}) - } -) +// Logger defines the logging interface. +type Logger interface { + Output() io.Writer + SetOutput(w io.Writer) + Prefix() string + SetPrefix(p string) + Level() log.Lvl + SetLevel(v log.Lvl) + SetHeader(h string) + Print(i ...interface{}) + Printf(format string, args ...interface{}) + Printj(j log.JSON) + Debug(i ...interface{}) + Debugf(format string, args ...interface{}) + Debugj(j log.JSON) + Info(i ...interface{}) + Infof(format string, args ...interface{}) + Infoj(j log.JSON) + Warn(i ...interface{}) + Warnf(format string, args ...interface{}) + Warnj(j log.JSON) + Error(i ...interface{}) + Errorf(format string, args ...interface{}) + Errorj(j log.JSON) + Fatal(i ...interface{}) + Fatalj(j log.JSON) + Fatalf(format string, args ...interface{}) + Panic(i ...interface{}) + Panicj(j log.JSON) + Panicf(format string, args ...interface{}) +} diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 7e809f5f7..9285f29fd 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -12,39 +12,35 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // BasicAuthConfig defines the config for BasicAuth middleware. - BasicAuthConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Validator is a function to validate BasicAuth credentials. - // Required. - Validator BasicAuthValidator - - // Realm is a string to define realm attribute of BasicAuth. - // Default value "Restricted". - Realm string - } +// BasicAuthConfig defines the config for BasicAuth middleware. +type BasicAuthConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Validator is a function to validate BasicAuth credentials. + // Required. + Validator BasicAuthValidator + + // Realm is a string to define realm attribute of BasicAuth. + // Default value "Restricted". + Realm string +} - // BasicAuthValidator defines a function to validate BasicAuth credentials. - // The function should return a boolean indicating whether the credentials are valid, - // and an error if any error occurs during the validation process. - BasicAuthValidator func(string, string, echo.Context) (bool, error) -) +// BasicAuthValidator defines a function to validate BasicAuth credentials. +// The function should return a boolean indicating whether the credentials are valid, +// and an error if any error occurs during the validation process. +type BasicAuthValidator func(string, string, echo.Context) (bool, error) const ( basic = "basic" defaultRealm = "Restricted" ) -var ( - // DefaultBasicAuthConfig is the default BasicAuth middleware config. - DefaultBasicAuthConfig = BasicAuthConfig{ - Skipper: DefaultSkipper, - Realm: defaultRealm, - } -) +// DefaultBasicAuthConfig is the default BasicAuth middleware config. +var DefaultBasicAuthConfig = BasicAuthConfig{ + Skipper: DefaultSkipper, + Realm: defaultRealm, +} // BasicAuth returns an BasicAuth middleware. // diff --git a/middleware/body_dump.go b/middleware/body_dump.go index e7b20981c..b06f76202 100644 --- a/middleware/body_dump.go +++ b/middleware/body_dump.go @@ -14,32 +14,28 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // BodyDumpConfig defines the config for BodyDump middleware. - BodyDumpConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Handler receives request and response payload. - // Required. - Handler BodyDumpHandler - } +// BodyDumpConfig defines the config for BodyDump middleware. +type BodyDumpConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Handler receives request and response payload. + // Required. + Handler BodyDumpHandler +} - // BodyDumpHandler receives the request and response payload. - BodyDumpHandler func(echo.Context, []byte, []byte) +// BodyDumpHandler receives the request and response payload. +type BodyDumpHandler func(echo.Context, []byte, []byte) - bodyDumpResponseWriter struct { - io.Writer - http.ResponseWriter - } -) +type bodyDumpResponseWriter struct { + io.Writer + http.ResponseWriter +} -var ( - // DefaultBodyDumpConfig is the default BodyDump middleware config. - DefaultBodyDumpConfig = BodyDumpConfig{ - Skipper: DefaultSkipper, - } -) +// DefaultBodyDumpConfig is the default BodyDump middleware config. +var DefaultBodyDumpConfig = BodyDumpConfig{ + Skipper: DefaultSkipper, +} // BodyDump returns a BodyDump middleware. // diff --git a/middleware/body_limit.go b/middleware/body_limit.go index 81972304e..7d3c665f2 100644 --- a/middleware/body_limit.go +++ b/middleware/body_limit.go @@ -12,31 +12,27 @@ import ( "github.com/labstack/gommon/bytes" ) -type ( - // BodyLimitConfig defines the config for BodyLimit middleware. - BodyLimitConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Maximum allowed size for a request body, it can be specified - // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P. - Limit string `yaml:"limit"` - limit int64 - } +// BodyLimitConfig defines the config for BodyLimit middleware. +type BodyLimitConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Maximum allowed size for a request body, it can be specified + // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P. + Limit string `yaml:"limit"` + limit int64 +} - limitedReader struct { - BodyLimitConfig - reader io.ReadCloser - read int64 - } -) +type limitedReader struct { + BodyLimitConfig + reader io.ReadCloser + read int64 +} -var ( - // DefaultBodyLimitConfig is the default BodyLimit middleware config. - DefaultBodyLimitConfig = BodyLimitConfig{ - Skipper: DefaultSkipper, - } -) +// DefaultBodyLimitConfig is the default BodyLimit middleware config. +var DefaultBodyLimitConfig = BodyLimitConfig{ + Skipper: DefaultSkipper, +} // BodyLimit returns a BodyLimit middleware. // diff --git a/middleware/compress.go b/middleware/compress.go index 681c0346f..557bdc8e2 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -16,54 +16,50 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // GzipConfig defines the config for Gzip middleware. - GzipConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Gzip compression level. - // Optional. Default value -1. - Level int `yaml:"level"` - - // Length threshold before gzip compression is applied. - // Optional. Default value 0. - // - // Most of the time you will not need to change the default. Compressing - // a short response might increase the transmitted data because of the - // gzip format overhead. Compressing the response will also consume CPU - // and time on the server and the client (for decompressing). Depending on - // your use case such a threshold might be useful. - // - // See also: - // https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits - MinLength int - } +// GzipConfig defines the config for Gzip middleware. +type GzipConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Gzip compression level. + // Optional. Default value -1. + Level int `yaml:"level"` + + // Length threshold before gzip compression is applied. + // Optional. Default value 0. + // + // Most of the time you will not need to change the default. Compressing + // a short response might increase the transmitted data because of the + // gzip format overhead. Compressing the response will also consume CPU + // and time on the server and the client (for decompressing). Depending on + // your use case such a threshold might be useful. + // + // See also: + // https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits + MinLength int +} - gzipResponseWriter struct { - io.Writer - http.ResponseWriter - wroteHeader bool - wroteBody bool - minLength int - minLengthExceeded bool - buffer *bytes.Buffer - code int - } -) +type gzipResponseWriter struct { + io.Writer + http.ResponseWriter + wroteHeader bool + wroteBody bool + minLength int + minLengthExceeded bool + buffer *bytes.Buffer + code int +} const ( gzipScheme = "gzip" ) -var ( - // DefaultGzipConfig is the default Gzip middleware config. - DefaultGzipConfig = GzipConfig{ - Skipper: DefaultSkipper, - Level: -1, - MinLength: 0, - } -) +// DefaultGzipConfig is the default Gzip middleware config. +var DefaultGzipConfig = GzipConfig{ + Skipper: DefaultSkipper, + Level: -1, + MinLength: 0, +} // Gzip returns a middleware which compresses HTTP response using gzip compression // scheme. diff --git a/middleware/cors.go b/middleware/cors.go index dd7030e56..7af6a76f3 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -12,113 +12,109 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // CORSConfig defines the config for CORS middleware. - CORSConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // AllowOrigins determines the value of the Access-Control-Allow-Origin - // response header. This header defines a list of origins that may access the - // resource. The wildcard characters '*' and '?' are supported and are - // converted to regex fragments '.*' and '.' accordingly. - // - // Security: use extreme caution when handling the origin, and carefully - // validate any logic. Remember that attackers may register hostile domain names. - // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html - // - // Optional. Default value []string{"*"}. - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin - AllowOrigins []string `yaml:"allow_origins"` - - // AllowOriginFunc is a custom function to validate the origin. It takes the - // origin as an argument and returns true if allowed or false otherwise. If - // an error is returned, it is returned by the handler. If this option is - // set, AllowOrigins is ignored. - // - // Security: use extreme caution when handling the origin, and carefully - // validate any logic. Remember that attackers may register hostile domain names. - // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html - // - // Optional. - AllowOriginFunc func(origin string) (bool, error) `yaml:"-"` - - // AllowMethods determines the value of the Access-Control-Allow-Methods - // response header. This header specified the list of methods allowed when - // accessing the resource. This is used in response to a preflight request. - // - // Optional. Default value DefaultCORSConfig.AllowMethods. - // If `allowMethods` is left empty, this middleware will fill for preflight - // request `Access-Control-Allow-Methods` header value - // from `Allow` header that echo.Router set into context. - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods - AllowMethods []string `yaml:"allow_methods"` - - // AllowHeaders determines the value of the Access-Control-Allow-Headers - // response header. This header is used in response to a preflight request to - // indicate which HTTP headers can be used when making the actual request. - // - // Optional. Default value []string{}. - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers - AllowHeaders []string `yaml:"allow_headers"` - - // AllowCredentials determines the value of the - // Access-Control-Allow-Credentials response header. This header indicates - // whether or not the response to the request can be exposed when the - // credentials mode (Request.credentials) is true. When used as part of a - // response to a preflight request, this indicates whether or not the actual - // request can be made using credentials. See also - // [MDN: Access-Control-Allow-Credentials]. - // - // Optional. Default value false, in which case the header is not set. - // - // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. - // See "Exploiting CORS misconfigurations for Bitcoins and bounties", - // https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials - AllowCredentials bool `yaml:"allow_credentials"` - - // UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials - // flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header. - // - // This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties) - // attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject. - // - // Optional. Default value is false. - UnsafeWildcardOriginWithAllowCredentials bool `yaml:"unsafe_wildcard_origin_with_allow_credentials"` - - // ExposeHeaders determines the value of Access-Control-Expose-Headers, which - // defines a list of headers that clients are allowed to access. - // - // Optional. Default value []string{}, in which case the header is not set. - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header - ExposeHeaders []string `yaml:"expose_headers"` - - // MaxAge determines the value of the Access-Control-Max-Age response header. - // This header indicates how long (in seconds) the results of a preflight - // request can be cached. - // The header is set only if MaxAge != 0, negative value sends "0" which instructs browsers not to cache that response. - // - // Optional. Default value 0 - meaning header is not sent. - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age - MaxAge int `yaml:"max_age"` - } -) +// CORSConfig defines the config for CORS middleware. +type CORSConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // AllowOrigins determines the value of the Access-Control-Allow-Origin + // response header. This header defines a list of origins that may access the + // resource. The wildcard characters '*' and '?' are supported and are + // converted to regex fragments '.*' and '.' accordingly. + // + // Security: use extreme caution when handling the origin, and carefully + // validate any logic. Remember that attackers may register hostile domain names. + // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // + // Optional. Default value []string{"*"}. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin + AllowOrigins []string `yaml:"allow_origins"` + + // AllowOriginFunc is a custom function to validate the origin. It takes the + // origin as an argument and returns true if allowed or false otherwise. If + // an error is returned, it is returned by the handler. If this option is + // set, AllowOrigins is ignored. + // + // Security: use extreme caution when handling the origin, and carefully + // validate any logic. Remember that attackers may register hostile domain names. + // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // + // Optional. + AllowOriginFunc func(origin string) (bool, error) `yaml:"-"` + + // AllowMethods determines the value of the Access-Control-Allow-Methods + // response header. This header specified the list of methods allowed when + // accessing the resource. This is used in response to a preflight request. + // + // Optional. Default value DefaultCORSConfig.AllowMethods. + // If `allowMethods` is left empty, this middleware will fill for preflight + // request `Access-Control-Allow-Methods` header value + // from `Allow` header that echo.Router set into context. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + AllowMethods []string `yaml:"allow_methods"` + + // AllowHeaders determines the value of the Access-Control-Allow-Headers + // response header. This header is used in response to a preflight request to + // indicate which HTTP headers can be used when making the actual request. + // + // Optional. Default value []string{}. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + AllowHeaders []string `yaml:"allow_headers"` + + // AllowCredentials determines the value of the + // Access-Control-Allow-Credentials response header. This header indicates + // whether or not the response to the request can be exposed when the + // credentials mode (Request.credentials) is true. When used as part of a + // response to a preflight request, this indicates whether or not the actual + // request can be made using credentials. See also + // [MDN: Access-Control-Allow-Credentials]. + // + // Optional. Default value false, in which case the header is not set. + // + // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. + // See "Exploiting CORS misconfigurations for Bitcoins and bounties", + // https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials + AllowCredentials bool `yaml:"allow_credentials"` + + // UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials + // flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header. + // + // This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties) + // attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject. + // + // Optional. Default value is false. + UnsafeWildcardOriginWithAllowCredentials bool `yaml:"unsafe_wildcard_origin_with_allow_credentials"` + + // ExposeHeaders determines the value of Access-Control-Expose-Headers, which + // defines a list of headers that clients are allowed to access. + // + // Optional. Default value []string{}, in which case the header is not set. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header + ExposeHeaders []string `yaml:"expose_headers"` + + // MaxAge determines the value of the Access-Control-Max-Age response header. + // This header indicates how long (in seconds) the results of a preflight + // request can be cached. + // The header is set only if MaxAge != 0, negative value sends "0" which instructs browsers not to cache that response. + // + // Optional. Default value 0 - meaning header is not sent. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age + MaxAge int `yaml:"max_age"` +} -var ( - // DefaultCORSConfig is the default CORS middleware config. - DefaultCORSConfig = CORSConfig{ - Skipper: DefaultSkipper, - AllowOrigins: []string{"*"}, - AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, - } -) +// DefaultCORSConfig is the default CORS middleware config. +var DefaultCORSConfig = CORSConfig{ + Skipper: DefaultSkipper, + AllowOrigins: []string{"*"}, + AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, +} // CORS returns a Cross-Origin Resource Sharing (CORS) middleware. // See also [MDN: Cross-Origin Resource Sharing (CORS)]. diff --git a/middleware/csrf.go b/middleware/csrf.go index 015473d9f..92f4019dc 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -11,82 +11,78 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // CSRFConfig defines the config for CSRF middleware. - CSRFConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // TokenLength is the length of the generated token. - TokenLength uint8 `yaml:"token_length"` - // Optional. Default value 32. - - // TokenLookup is a string in the form of ":" or ":,:" that is used - // to extract token from the request. - // Optional. Default value "header:X-CSRF-Token". - // Possible values: - // - "header:" or "header::" - // - "query:" - // - "form:" - // Multiple sources example: - // - "header:X-CSRF-Token,query:csrf" - TokenLookup string `yaml:"token_lookup"` - - // Context key to store generated CSRF token into context. - // Optional. Default value "csrf". - ContextKey string `yaml:"context_key"` - - // Name of the CSRF cookie. This cookie will store CSRF token. - // Optional. Default value "csrf". - CookieName string `yaml:"cookie_name"` - - // Domain of the CSRF cookie. - // Optional. Default value none. - CookieDomain string `yaml:"cookie_domain"` - - // Path of the CSRF cookie. - // Optional. Default value none. - CookiePath string `yaml:"cookie_path"` - - // Max age (in seconds) of the CSRF cookie. - // Optional. Default value 86400 (24hr). - CookieMaxAge int `yaml:"cookie_max_age"` - - // Indicates if CSRF cookie is secure. - // Optional. Default value false. - CookieSecure bool `yaml:"cookie_secure"` - - // Indicates if CSRF cookie is HTTP only. - // Optional. Default value false. - CookieHTTPOnly bool `yaml:"cookie_http_only"` - - // Indicates SameSite mode of the CSRF cookie. - // Optional. Default value SameSiteDefaultMode. - CookieSameSite http.SameSite `yaml:"cookie_same_site"` - - // ErrorHandler defines a function which is executed for returning custom errors. - ErrorHandler CSRFErrorHandler - } +// CSRFConfig defines the config for CSRF middleware. +type CSRFConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // TokenLength is the length of the generated token. + TokenLength uint8 `yaml:"token_length"` + // Optional. Default value 32. + + // TokenLookup is a string in the form of ":" or ":,:" that is used + // to extract token from the request. + // Optional. Default value "header:X-CSRF-Token". + // Possible values: + // - "header:" or "header::" + // - "query:" + // - "form:" + // Multiple sources example: + // - "header:X-CSRF-Token,query:csrf" + TokenLookup string `yaml:"token_lookup"` + + // Context key to store generated CSRF token into context. + // Optional. Default value "csrf". + ContextKey string `yaml:"context_key"` + + // Name of the CSRF cookie. This cookie will store CSRF token. + // Optional. Default value "csrf". + CookieName string `yaml:"cookie_name"` + + // Domain of the CSRF cookie. + // Optional. Default value none. + CookieDomain string `yaml:"cookie_domain"` + + // Path of the CSRF cookie. + // Optional. Default value none. + CookiePath string `yaml:"cookie_path"` + + // Max age (in seconds) of the CSRF cookie. + // Optional. Default value 86400 (24hr). + CookieMaxAge int `yaml:"cookie_max_age"` + + // Indicates if CSRF cookie is secure. + // Optional. Default value false. + CookieSecure bool `yaml:"cookie_secure"` + + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value false. + CookieHTTPOnly bool `yaml:"cookie_http_only"` + + // Indicates SameSite mode of the CSRF cookie. + // Optional. Default value SameSiteDefaultMode. + CookieSameSite http.SameSite `yaml:"cookie_same_site"` + + // ErrorHandler defines a function which is executed for returning custom errors. + ErrorHandler CSRFErrorHandler +} - // CSRFErrorHandler is a function which is executed for creating custom errors. - CSRFErrorHandler func(err error, c echo.Context) error -) +// CSRFErrorHandler is a function which is executed for creating custom errors. +type CSRFErrorHandler func(err error, c echo.Context) error // ErrCSRFInvalid is returned when CSRF check fails var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") -var ( - // DefaultCSRFConfig is the default CSRF middleware config. - DefaultCSRFConfig = CSRFConfig{ - Skipper: DefaultSkipper, - TokenLength: 32, - TokenLookup: "header:" + echo.HeaderXCSRFToken, - ContextKey: "csrf", - CookieName: "_csrf", - CookieMaxAge: 86400, - CookieSameSite: http.SameSiteDefaultMode, - } -) +// DefaultCSRFConfig is the default CSRF middleware config. +var DefaultCSRFConfig = CSRFConfig{ + Skipper: DefaultSkipper, + TokenLength: 32, + TokenLookup: "header:" + echo.HeaderXCSRFToken, + ContextKey: "csrf", + CookieName: "_csrf", + CookieMaxAge: 86400, + CookieSameSite: http.SameSiteDefaultMode, +} // CSRF returns a Cross-Site Request Forgery (CSRF) middleware. // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery diff --git a/middleware/decompress.go b/middleware/decompress.go index 3dded53c5..0c56176ee 100644 --- a/middleware/decompress.go +++ b/middleware/decompress.go @@ -12,16 +12,14 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // DecompressConfig defines the config for Decompress middleware. - DecompressConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers - GzipDecompressPool Decompressor - } -) +// DecompressConfig defines the config for Decompress middleware. +type DecompressConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers + GzipDecompressPool Decompressor +} // GZIPEncoding content-encoding header if set to "gzip", decompress body contents. const GZIPEncoding string = "gzip" @@ -31,13 +29,11 @@ type Decompressor interface { gzipDecompressPool() sync.Pool } -var ( - //DefaultDecompressConfig defines the config for decompress middleware - DefaultDecompressConfig = DecompressConfig{ - Skipper: DefaultSkipper, - GzipDecompressPool: &DefaultGzipDecompressPool{}, - } -) +// DefaultDecompressConfig defines the config for decompress middleware +var DefaultDecompressConfig = DecompressConfig{ + Skipper: DefaultSkipper, + GzipDecompressPool: &DefaultGzipDecompressPool{}, +} // DefaultGzipDecompressPool is the default implementation of Decompressor interface type DefaultGzipDecompressPool struct { diff --git a/middleware/jwt.go b/middleware/jwt.go index 276bdfe39..a6bf16f95 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -15,139 +15,135 @@ import ( "reflect" ) -type ( - // JWTConfig defines the config for JWT middleware. - JWTConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // BeforeFunc defines a function which is executed just before the middleware. - BeforeFunc BeforeFunc - - // SuccessHandler defines a function which is executed for a valid token before middleware chain continues with next - // middleware or handler. - SuccessHandler JWTSuccessHandler - - // ErrorHandler defines a function which is executed for an invalid token. - // It may be used to define a custom JWT error. - ErrorHandler JWTErrorHandler - - // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. - ErrorHandlerWithContext JWTErrorHandlerWithContext - - // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to - // ignore the error (by returning `nil`). - // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. - // In that case you can use ErrorHandlerWithContext to set a default public JWT token value in the request context - // and continue. Some logic down the remaining execution chain needs to check that (public) token value then. - ContinueOnIgnoredError bool - - // Signing key to validate token. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither user-defined KeyFunc nor SigningKeys is provided. - SigningKey interface{} - - // Map of signing keys to validate token with kid field usage. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither user-defined KeyFunc nor SigningKey is provided. - SigningKeys map[string]interface{} - - // Signing method used to check the token's signing algorithm. - // Optional. Default value HS256. - SigningMethod string - - // Context key to store user information from the token into context. - // Optional. Default value "user". - ContextKey string - - // Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation. - // Not used if custom ParseTokenFunc is set. - // Optional. Default value jwt.MapClaims - Claims jwt.Claims - - // TokenLookup is a string in the form of ":" or ":,:" that is used - // to extract token from the request. - // Optional. Default value "header:Authorization". - // Possible values: - // - "header:" or "header::" - // `` is argument value to cut/trim prefix of the extracted value. This is useful if header - // value has static prefix like `Authorization: ` where part that we - // want to cut is ` ` note the space at the end. - // In case of JWT tokens `Authorization: Bearer ` prefix we cut is `Bearer `. - // If prefix is left empty the whole value is returned. - // - "query:" - // - "param:" - // - "cookie:" - // - "form:" - // Multiple sources example: - // - "header:Authorization,cookie:myowncookie" - TokenLookup string - - // TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context. - // This is one of the two options to provide a token extractor. - // The order of precedence is user-defined TokenLookupFuncs, and TokenLookup. - // You can also provide both if you want. - TokenLookupFuncs []ValuesExtractor - - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string - - // KeyFunc defines a user-defined function that supplies the public key for a token validation. - // The function shall take care of verifying the signing algorithm and selecting the proper key. - // A user-defined KeyFunc can be useful if tokens are issued by an external party. - // Used by default ParseTokenFunc implementation. - // - // When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored. - // This is one of the three options to provide a token validation key. - // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. - // Required if neither SigningKeys nor SigningKey is provided. - // Not used if custom ParseTokenFunc is set. - // Default to an internal implementation verifying the signing algorithm and selecting the proper key. - KeyFunc jwt.Keyfunc - - // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token - // parsing fails or parsed token is invalid. - // Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library - ParseTokenFunc func(auth string, c echo.Context) (interface{}, error) - } +// JWTConfig defines the config for JWT middleware. +type JWTConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // BeforeFunc defines a function which is executed just before the middleware. + BeforeFunc BeforeFunc + + // SuccessHandler defines a function which is executed for a valid token before middleware chain continues with next + // middleware or handler. + SuccessHandler JWTSuccessHandler + + // ErrorHandler defines a function which is executed for an invalid token. + // It may be used to define a custom JWT error. + ErrorHandler JWTErrorHandler + + // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. + ErrorHandlerWithContext JWTErrorHandlerWithContext + + // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to + // ignore the error (by returning `nil`). + // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. + // In that case you can use ErrorHandlerWithContext to set a default public JWT token value in the request context + // and continue. Some logic down the remaining execution chain needs to check that (public) token value then. + ContinueOnIgnoredError bool + + // Signing key to validate token. + // This is one of the three options to provide a token validation key. + // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. + // Required if neither user-defined KeyFunc nor SigningKeys is provided. + SigningKey interface{} + + // Map of signing keys to validate token with kid field usage. + // This is one of the three options to provide a token validation key. + // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. + // Required if neither user-defined KeyFunc nor SigningKey is provided. + SigningKeys map[string]interface{} + + // Signing method used to check the token's signing algorithm. + // Optional. Default value HS256. + SigningMethod string + + // Context key to store user information from the token into context. + // Optional. Default value "user". + ContextKey string + + // Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation. + // Not used if custom ParseTokenFunc is set. + // Optional. Default value jwt.MapClaims + Claims jwt.Claims + + // TokenLookup is a string in the form of ":" or ":,:" that is used + // to extract token from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" or "header::" + // `` is argument value to cut/trim prefix of the extracted value. This is useful if header + // value has static prefix like `Authorization: ` where part that we + // want to cut is ` ` note the space at the end. + // In case of JWT tokens `Authorization: Bearer ` prefix we cut is `Bearer `. + // If prefix is left empty the whole value is returned. + // - "query:" + // - "param:" + // - "cookie:" + // - "form:" + // Multiple sources example: + // - "header:Authorization,cookie:myowncookie" + TokenLookup string + + // TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context. + // This is one of the two options to provide a token extractor. + // The order of precedence is user-defined TokenLookupFuncs, and TokenLookup. + // You can also provide both if you want. + TokenLookupFuncs []ValuesExtractor + + // AuthScheme to be used in the Authorization header. + // Optional. Default value "Bearer". + AuthScheme string + + // KeyFunc defines a user-defined function that supplies the public key for a token validation. + // The function shall take care of verifying the signing algorithm and selecting the proper key. + // A user-defined KeyFunc can be useful if tokens are issued by an external party. + // Used by default ParseTokenFunc implementation. + // + // When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored. + // This is one of the three options to provide a token validation key. + // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. + // Required if neither SigningKeys nor SigningKey is provided. + // Not used if custom ParseTokenFunc is set. + // Default to an internal implementation verifying the signing algorithm and selecting the proper key. + KeyFunc jwt.Keyfunc + + // ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token + // parsing fails or parsed token is invalid. + // Defaults to implementation using `github.com/golang-jwt/jwt` as JWT implementation library + ParseTokenFunc func(auth string, c echo.Context) (interface{}, error) +} - // JWTSuccessHandler defines a function which is executed for a valid token. - JWTSuccessHandler func(c echo.Context) +// JWTSuccessHandler defines a function which is executed for a valid token. +type JWTSuccessHandler func(c echo.Context) - // JWTErrorHandler defines a function which is executed for an invalid token. - JWTErrorHandler func(err error) error +// JWTErrorHandler defines a function which is executed for an invalid token. +type JWTErrorHandler func(err error) error - // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. - JWTErrorHandlerWithContext func(err error, c echo.Context) error -) +// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. +type JWTErrorHandlerWithContext func(err error, c echo.Context) error // Algorithms const ( AlgorithmHS256 = "HS256" ) -// Errors -var ( - ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt") - ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") -) - -var ( - // DefaultJWTConfig is the default JWT auth middleware config. - DefaultJWTConfig = JWTConfig{ - Skipper: DefaultSkipper, - SigningMethod: AlgorithmHS256, - ContextKey: "user", - TokenLookup: "header:" + echo.HeaderAuthorization, - TokenLookupFuncs: nil, - AuthScheme: "Bearer", - Claims: jwt.MapClaims{}, - KeyFunc: nil, - } -) +// ErrJWTMissing is error that is returned when no JWToken was extracted from the request. +var ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt") + +// ErrJWTInvalid is error that is returned when middleware could not parse JWT correctly. +var ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") + +// DefaultJWTConfig is the default JWT auth middleware config. +var DefaultJWTConfig = JWTConfig{ + Skipper: DefaultSkipper, + SigningMethod: AlgorithmHS256, + ContextKey: "user", + TokenLookup: "header:" + echo.HeaderAuthorization, + TokenLookupFuncs: nil, + AuthScheme: "Bearer", + Claims: jwt.MapClaims{}, + KeyFunc: nil, +} // JWT returns a JSON Web Token (JWT) auth middleware. // diff --git a/middleware/key_auth.go b/middleware/key_auth.go index f7ce8c18a..79bee207c 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -9,69 +9,65 @@ import ( "net/http" ) -type ( - // KeyAuthConfig defines the config for KeyAuth middleware. - KeyAuthConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // KeyLookup is a string in the form of ":" or ":,:" that is used - // to extract key from the request. - // Optional. Default value "header:Authorization". - // Possible values: - // - "header:" or "header::" - // `` is argument value to cut/trim prefix of the extracted value. This is useful if header - // value has static prefix like `Authorization: ` where part that we - // want to cut is ` ` note the space at the end. - // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. - // - "query:" - // - "form:" - // - "cookie:" - // Multiple sources example: - // - "header:Authorization,header:X-Api-Key" - KeyLookup string - - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string - - // Validator is a function to validate key. - // Required. - Validator KeyAuthValidator - - // ErrorHandler defines a function which is executed for an invalid key. - // It may be used to define a custom error. - ErrorHandler KeyAuthErrorHandler - - // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to - // ignore the error (by returning `nil`). - // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. - // In that case you can use ErrorHandler to set a default public key auth value in the request context - // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. - ContinueOnIgnoredError bool - } - - // KeyAuthValidator defines a function to validate KeyAuth credentials. - KeyAuthValidator func(auth string, c echo.Context) (bool, error) +// KeyAuthConfig defines the config for KeyAuth middleware. +type KeyAuthConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // KeyLookup is a string in the form of ":" or ":,:" that is used + // to extract key from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" or "header::" + // `` is argument value to cut/trim prefix of the extracted value. This is useful if header + // value has static prefix like `Authorization: ` where part that we + // want to cut is ` ` note the space at the end. + // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. + // - "query:" + // - "form:" + // - "cookie:" + // Multiple sources example: + // - "header:Authorization,header:X-Api-Key" + KeyLookup string + + // AuthScheme to be used in the Authorization header. + // Optional. Default value "Bearer". + AuthScheme string + + // Validator is a function to validate key. + // Required. + Validator KeyAuthValidator + + // ErrorHandler defines a function which is executed for an invalid key. + // It may be used to define a custom error. + ErrorHandler KeyAuthErrorHandler + + // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to + // ignore the error (by returning `nil`). + // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. + // In that case you can use ErrorHandler to set a default public key auth value in the request context + // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. + ContinueOnIgnoredError bool +} - // KeyAuthErrorHandler defines a function which is executed for an invalid key. - KeyAuthErrorHandler func(err error, c echo.Context) error -) +// KeyAuthValidator defines a function to validate KeyAuth credentials. +type KeyAuthValidator func(auth string, c echo.Context) (bool, error) -var ( - // DefaultKeyAuthConfig is the default KeyAuth middleware config. - DefaultKeyAuthConfig = KeyAuthConfig{ - Skipper: DefaultSkipper, - KeyLookup: "header:" + echo.HeaderAuthorization, - AuthScheme: "Bearer", - } -) +// KeyAuthErrorHandler defines a function which is executed for an invalid key. +type KeyAuthErrorHandler func(err error, c echo.Context) error // ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups type ErrKeyAuthMissing struct { Err error } +// DefaultKeyAuthConfig is the default KeyAuth middleware config. +var DefaultKeyAuthConfig = KeyAuthConfig{ + Skipper: DefaultSkipper, + KeyLookup: "header:" + echo.HeaderAuthorization, + AuthScheme: "Bearer", +} + // Error returns errors text func (e *ErrKeyAuthMissing) Error() string { return e.Err.Error() diff --git a/middleware/logger.go b/middleware/logger.go index 43fd59ffc..910fce8cf 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -17,77 +17,73 @@ import ( "github.com/valyala/fasttemplate" ) -type ( - // LoggerConfig defines the config for Logger middleware. - LoggerConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Tags to construct the logger format. - // - // - time_unix - // - time_unix_milli - // - time_unix_micro - // - time_unix_nano - // - time_rfc3339 - // - time_rfc3339_nano - // - time_custom - // - id (Request ID) - // - remote_ip - // - uri - // - host - // - method - // - path - // - route - // - protocol - // - referer - // - user_agent - // - status - // - error - // - latency (In nanoseconds) - // - latency_human (Human readable) - // - bytes_in (Bytes received) - // - bytes_out (Bytes sent) - // - header: - // - query: - // - form: - // - custom (see CustomTagFunc field) - // - // Example "${remote_ip} ${status}" - // - // Optional. Default value DefaultLoggerConfig.Format. - Format string `yaml:"format"` - - // Optional. Default value DefaultLoggerConfig.CustomTimeFormat. - CustomTimeFormat string `yaml:"custom_time_format"` - - // CustomTagFunc is function called for `${custom}` tag to output user implemented text by writing it to buf. - // Make sure that outputted text creates valid JSON string with other logged tags. - // Optional. - CustomTagFunc func(c echo.Context, buf *bytes.Buffer) (int, error) - - // Output is a writer where logs in JSON format are written. - // Optional. Default value os.Stdout. - Output io.Writer - - template *fasttemplate.Template - colorer *color.Color - pool *sync.Pool - } -) +// LoggerConfig defines the config for Logger middleware. +type LoggerConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Tags to construct the logger format. + // + // - time_unix + // - time_unix_milli + // - time_unix_micro + // - time_unix_nano + // - time_rfc3339 + // - time_rfc3339_nano + // - time_custom + // - id (Request ID) + // - remote_ip + // - uri + // - host + // - method + // - path + // - route + // - protocol + // - referer + // - user_agent + // - status + // - error + // - latency (In nanoseconds) + // - latency_human (Human readable) + // - bytes_in (Bytes received) + // - bytes_out (Bytes sent) + // - header: + // - query: + // - form: + // - custom (see CustomTagFunc field) + // + // Example "${remote_ip} ${status}" + // + // Optional. Default value DefaultLoggerConfig.Format. + Format string `yaml:"format"` + + // Optional. Default value DefaultLoggerConfig.CustomTimeFormat. + CustomTimeFormat string `yaml:"custom_time_format"` + + // CustomTagFunc is function called for `${custom}` tag to output user implemented text by writing it to buf. + // Make sure that outputted text creates valid JSON string with other logged tags. + // Optional. + CustomTagFunc func(c echo.Context, buf *bytes.Buffer) (int, error) + + // Output is a writer where logs in JSON format are written. + // Optional. Default value os.Stdout. + Output io.Writer + + template *fasttemplate.Template + colorer *color.Color + pool *sync.Pool +} -var ( - // DefaultLoggerConfig is the default Logger middleware config. - DefaultLoggerConfig = LoggerConfig{ - Skipper: DefaultSkipper, - Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` + - `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + - `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + - `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", - CustomTimeFormat: "2006-01-02 15:04:05.00000", - colorer: color.New(), - } -) +// DefaultLoggerConfig is the default Logger middleware config. +var DefaultLoggerConfig = LoggerConfig{ + Skipper: DefaultSkipper, + Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` + + `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + + `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + + `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", + CustomTimeFormat: "2006-01-02 15:04:05.00000", + colorer: color.New(), +} // Logger returns a middleware that logs HTTP requests. func Logger() echo.MiddlewareFunc { diff --git a/middleware/method_override.go b/middleware/method_override.go index 668a57a41..3991e1029 100644 --- a/middleware/method_override.go +++ b/middleware/method_override.go @@ -9,28 +9,24 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // MethodOverrideConfig defines the config for MethodOverride middleware. - MethodOverrideConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// MethodOverrideConfig defines the config for MethodOverride middleware. +type MethodOverrideConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Getter is a function that gets overridden method from the request. - // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). - Getter MethodOverrideGetter - } + // Getter is a function that gets overridden method from the request. + // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). + Getter MethodOverrideGetter +} - // MethodOverrideGetter is a function that gets overridden method from the request - MethodOverrideGetter func(echo.Context) string -) +// MethodOverrideGetter is a function that gets overridden method from the request +type MethodOverrideGetter func(echo.Context) string -var ( - // DefaultMethodOverrideConfig is the default MethodOverride middleware config. - DefaultMethodOverrideConfig = MethodOverrideConfig{ - Skipper: DefaultSkipper, - Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), - } -) +// DefaultMethodOverrideConfig is the default MethodOverride middleware config. +var DefaultMethodOverrideConfig = MethodOverrideConfig{ + Skipper: DefaultSkipper, + Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), +} // MethodOverride returns a MethodOverride middleware. // MethodOverride middleware checks for the overridden method from the request and diff --git a/middleware/middleware.go b/middleware/middleware.go index 8dfb8dda6..6f33cc5c1 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -12,14 +12,12 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // Skipper defines a function to skip middleware. Returning true skips processing - // the middleware. - Skipper func(c echo.Context) bool +// Skipper defines a function to skip middleware. Returning true skips processing +// the middleware. +type Skipper func(c echo.Context) bool - // BeforeFunc defines a function which is executed just before the middleware. - BeforeFunc func(c echo.Context) -) +// BeforeFunc defines a function which is executed just before the middleware. +type BeforeFunc func(c echo.Context) func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { groups := pattern.FindAllStringSubmatch(input, -1) @@ -56,7 +54,7 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error return nil } - // Depending how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path. + // Depending on how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path. // We only want to use path part for rewriting and therefore trim prefix if it exists rawURI := req.RequestURI if rawURI != "" && rawURI[0] != '/' { diff --git a/middleware/proxy.go b/middleware/proxy.go index ddf4b7f06..f6b302af1 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -22,117 +22,113 @@ import ( // TODO: Handle TLS proxy -type ( - // ProxyConfig defines the config for Proxy middleware. - ProxyConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Balancer defines a load balancing technique. - // Required. - Balancer ProxyBalancer - - // RetryCount defines the number of times a failed proxied request should be retried - // using the next available ProxyTarget. Defaults to 0, meaning requests are never retried. - RetryCount int - - // RetryFilter defines a function used to determine if a failed request to a - // ProxyTarget should be retried. The RetryFilter will only be called when the number - // of previous retries is less than RetryCount. If the function returns true, the - // request will be retried. The provided error indicates the reason for the request - // failure. When the ProxyTarget is unavailable, the error will be an instance of - // echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error - // will indicate an internal error in the Proxy middleware. When a RetryFilter is not - // specified, all requests that fail with http.StatusBadGateway will be retried. A custom - // RetryFilter can be provided to only retry specific requests. Note that RetryFilter is - // only called when the request to the target fails, or an internal error in the Proxy - // middleware has occurred. Successful requests that return a non-200 response code cannot - // be retried. - RetryFilter func(c echo.Context, e error) bool - - // ErrorHandler defines a function which can be used to return custom errors from - // the Proxy middleware. ErrorHandler is only invoked when there has been - // either an internal error in the Proxy middleware or the ProxyTarget is - // unavailable. Due to the way requests are proxied, ErrorHandler is not invoked - // when a ProxyTarget returns a non-200 response. In these cases, the response - // is already written so errors cannot be modified. ErrorHandler is only - // invoked after all retry attempts have been exhausted. - ErrorHandler func(c echo.Context, err error) error - - // Rewrite defines URL path rewrite rules. The values captured in asterisk can be - // retrieved by index e.g. $1, $2 and so on. - // Examples: - // "/old": "/new", - // "/api/*": "/$1", - // "/js/*": "/public/javascripts/$1", - // "/users/*/orders/*": "/user/$1/order/$2", - Rewrite map[string]string - - // RegexRewrite defines rewrite rules using regexp.Rexexp with captures - // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. - // Example: - // "^/old/[0.9]+/": "/new", - // "^/api/.+?/(.*)": "/v2/$1", - RegexRewrite map[*regexp.Regexp]string - - // Context key to store selected ProxyTarget into context. - // Optional. Default value "target". - ContextKey string - - // To customize the transport to remote. - // Examples: If custom TLS certificates are required. - Transport http.RoundTripper - - // ModifyResponse defines function to modify response from ProxyTarget. - ModifyResponse func(*http.Response) error - } +// ProxyConfig defines the config for Proxy middleware. +type ProxyConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Balancer defines a load balancing technique. + // Required. + Balancer ProxyBalancer + + // RetryCount defines the number of times a failed proxied request should be retried + // using the next available ProxyTarget. Defaults to 0, meaning requests are never retried. + RetryCount int + + // RetryFilter defines a function used to determine if a failed request to a + // ProxyTarget should be retried. The RetryFilter will only be called when the number + // of previous retries is less than RetryCount. If the function returns true, the + // request will be retried. The provided error indicates the reason for the request + // failure. When the ProxyTarget is unavailable, the error will be an instance of + // echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error + // will indicate an internal error in the Proxy middleware. When a RetryFilter is not + // specified, all requests that fail with http.StatusBadGateway will be retried. A custom + // RetryFilter can be provided to only retry specific requests. Note that RetryFilter is + // only called when the request to the target fails, or an internal error in the Proxy + // middleware has occurred. Successful requests that return a non-200 response code cannot + // be retried. + RetryFilter func(c echo.Context, e error) bool + + // ErrorHandler defines a function which can be used to return custom errors from + // the Proxy middleware. ErrorHandler is only invoked when there has been + // either an internal error in the Proxy middleware or the ProxyTarget is + // unavailable. Due to the way requests are proxied, ErrorHandler is not invoked + // when a ProxyTarget returns a non-200 response. In these cases, the response + // is already written so errors cannot be modified. ErrorHandler is only + // invoked after all retry attempts have been exhausted. + ErrorHandler func(c echo.Context, err error) error + + // Rewrite defines URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Examples: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + Rewrite map[string]string + + // RegexRewrite defines rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRewrite map[*regexp.Regexp]string + + // Context key to store selected ProxyTarget into context. + // Optional. Default value "target". + ContextKey string + + // To customize the transport to remote. + // Examples: If custom TLS certificates are required. + Transport http.RoundTripper + + // ModifyResponse defines function to modify response from ProxyTarget. + ModifyResponse func(*http.Response) error +} - // ProxyTarget defines the upstream target. - ProxyTarget struct { - Name string - URL *url.URL - Meta echo.Map - } +// ProxyTarget defines the upstream target. +type ProxyTarget struct { + Name string + URL *url.URL + Meta echo.Map +} - // ProxyBalancer defines an interface to implement a load balancing technique. - ProxyBalancer interface { - AddTarget(*ProxyTarget) bool - RemoveTarget(string) bool - Next(echo.Context) *ProxyTarget - } +// ProxyBalancer defines an interface to implement a load balancing technique. +type ProxyBalancer interface { + AddTarget(*ProxyTarget) bool + RemoveTarget(string) bool + Next(echo.Context) *ProxyTarget +} - // TargetProvider defines an interface that gives the opportunity for balancer - // to return custom errors when selecting target. - TargetProvider interface { - NextTarget(echo.Context) (*ProxyTarget, error) - } +// TargetProvider defines an interface that gives the opportunity for balancer +// to return custom errors when selecting target. +type TargetProvider interface { + NextTarget(echo.Context) (*ProxyTarget, error) +} - commonBalancer struct { - targets []*ProxyTarget - mutex sync.Mutex - } +type commonBalancer struct { + targets []*ProxyTarget + mutex sync.Mutex +} - // RandomBalancer implements a random load balancing technique. - randomBalancer struct { - commonBalancer - random *rand.Rand - } +// RandomBalancer implements a random load balancing technique. +type randomBalancer struct { + commonBalancer + random *rand.Rand +} - // RoundRobinBalancer implements a round-robin load balancing technique. - roundRobinBalancer struct { - commonBalancer - // tracking the index on `targets` slice for the next `*ProxyTarget` to be used - i int - } -) +// RoundRobinBalancer implements a round-robin load balancing technique. +type roundRobinBalancer struct { + commonBalancer + // tracking the index on `targets` slice for the next `*ProxyTarget` to be used + i int +} -var ( - // DefaultProxyConfig is the default Proxy middleware config. - DefaultProxyConfig = ProxyConfig{ - Skipper: DefaultSkipper, - ContextKey: "target", - } -) +// DefaultProxyConfig is the default Proxy middleware config. +var DefaultProxyConfig = ProxyConfig{ + Skipper: DefaultSkipper, + ContextKey: "target", +} func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index a58b16491..d4724fd2a 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -12,39 +12,34 @@ import ( "golang.org/x/time/rate" ) -type ( - // RateLimiterStore is the interface to be implemented by custom stores. - RateLimiterStore interface { - // Stores for the rate limiter have to implement the Allow method - Allow(identifier string) (bool, error) - } -) +// RateLimiterStore is the interface to be implemented by custom stores. +type RateLimiterStore interface { + // Stores for the rate limiter have to implement the Allow method + Allow(identifier string) (bool, error) +} -type ( - // RateLimiterConfig defines the configuration for the rate limiter - RateLimiterConfig struct { - Skipper Skipper - BeforeFunc BeforeFunc - // IdentifierExtractor uses echo.Context to extract the identifier for a visitor - IdentifierExtractor Extractor - // Store defines a store for the rate limiter - Store RateLimiterStore - // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error - ErrorHandler func(context echo.Context, err error) error - // DenyHandler provides a handler to be called when RateLimiter denies access - DenyHandler func(context echo.Context, identifier string, err error) error - } - // Extractor is used to extract data from echo.Context - Extractor func(context echo.Context) (string, error) -) +// RateLimiterConfig defines the configuration for the rate limiter +type RateLimiterConfig struct { + Skipper Skipper + BeforeFunc BeforeFunc + // IdentifierExtractor uses echo.Context to extract the identifier for a visitor + IdentifierExtractor Extractor + // Store defines a store for the rate limiter + Store RateLimiterStore + // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error + ErrorHandler func(context echo.Context, err error) error + // DenyHandler provides a handler to be called when RateLimiter denies access + DenyHandler func(context echo.Context, identifier string, err error) error +} -// errors -var ( - // ErrRateLimitExceeded denotes an error raised when rate limit is exceeded - ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") - // ErrExtractorError denotes an error raised when extractor function is unsuccessful - ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") -) +// Extractor is used to extract data from echo.Context +type Extractor func(context echo.Context) (string, error) + +// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded +var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") + +// ErrExtractorError denotes an error raised when extractor function is unsuccessful +var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") // DefaultRateLimiterConfig defines default values for RateLimiterConfig var DefaultRateLimiterConfig = RateLimiterConfig{ @@ -153,25 +148,24 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { } } -type ( - // RateLimiterMemoryStore is the built-in store implementation for RateLimiter - RateLimiterMemoryStore struct { - visitors map[string]*Visitor - mutex sync.Mutex - rate rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. +// RateLimiterMemoryStore is the built-in store implementation for RateLimiter +type RateLimiterMemoryStore struct { + visitors map[string]*Visitor + mutex sync.Mutex + rate rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. - burst int - expiresIn time.Duration - lastCleanup time.Time + burst int + expiresIn time.Duration + lastCleanup time.Time - timeNow func() time.Time - } - // Visitor signifies a unique user's limiter details - Visitor struct { - *rate.Limiter - lastSeen time.Time - } -) + timeNow func() time.Time +} + +// Visitor signifies a unique user's limiter details +type Visitor struct { + *rate.Limiter + lastSeen time.Time +} /* NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with diff --git a/middleware/recover.go b/middleware/recover.go index 35f38e72c..e6a5940e4 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -12,56 +12,52 @@ import ( "github.com/labstack/gommon/log" ) -type ( +// LogErrorFunc defines a function for custom logging in the middleware. +type LogErrorFunc func(c echo.Context, err error, stack []byte) error + +// RecoverConfig defines the config for Recover middleware. +type RecoverConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Size of the stack to be printed. + // Optional. Default value 4KB. + StackSize int `yaml:"stack_size"` + + // DisableStackAll disables formatting stack traces of all other goroutines + // into buffer after the trace for the current goroutine. + // Optional. Default value false. + DisableStackAll bool `yaml:"disable_stack_all"` + + // DisablePrintStack disables printing stack trace. + // Optional. Default value as false. + DisablePrintStack bool `yaml:"disable_print_stack"` + + // LogLevel is log level to printing stack trace. + // Optional. Default value 0 (Print). + LogLevel log.Lvl + // LogErrorFunc defines a function for custom logging in the middleware. - LogErrorFunc func(c echo.Context, err error, stack []byte) error - - // RecoverConfig defines the config for Recover middleware. - RecoverConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Size of the stack to be printed. - // Optional. Default value 4KB. - StackSize int `yaml:"stack_size"` - - // DisableStackAll disables formatting stack traces of all other goroutines - // into buffer after the trace for the current goroutine. - // Optional. Default value false. - DisableStackAll bool `yaml:"disable_stack_all"` - - // DisablePrintStack disables printing stack trace. - // Optional. Default value as false. - DisablePrintStack bool `yaml:"disable_print_stack"` - - // LogLevel is log level to printing stack trace. - // Optional. Default value 0 (Print). - LogLevel log.Lvl - - // LogErrorFunc defines a function for custom logging in the middleware. - // If it's set you don't need to provide LogLevel for config. - // If this function returns nil, the centralized HTTPErrorHandler will not be called. - LogErrorFunc LogErrorFunc - - // DisableErrorHandler disables the call to centralized HTTPErrorHandler. - // The recovered error is then passed back to upstream middleware, instead of swallowing the error. - // Optional. Default value false. - DisableErrorHandler bool `yaml:"disable_error_handler"` - } -) + // If it's set you don't need to provide LogLevel for config. + // If this function returns nil, the centralized HTTPErrorHandler will not be called. + LogErrorFunc LogErrorFunc + + // DisableErrorHandler disables the call to centralized HTTPErrorHandler. + // The recovered error is then passed back to upstream middleware, instead of swallowing the error. + // Optional. Default value false. + DisableErrorHandler bool `yaml:"disable_error_handler"` +} -var ( - // DefaultRecoverConfig is the default Recover middleware config. - DefaultRecoverConfig = RecoverConfig{ - Skipper: DefaultSkipper, - StackSize: 4 << 10, // 4 KB - DisableStackAll: false, - DisablePrintStack: false, - LogLevel: 0, - LogErrorFunc: nil, - DisableErrorHandler: false, - } -) +// DefaultRecoverConfig is the default Recover middleware config. +var DefaultRecoverConfig = RecoverConfig{ + Skipper: DefaultSkipper, + StackSize: 4 << 10, // 4 KB + DisableStackAll: false, + DisablePrintStack: false, + LogLevel: 0, + LogErrorFunc: nil, + DisableErrorHandler: false, +} // Recover returns a middleware which recovers from panics anywhere in the chain // and handles the control to the centralized HTTPErrorHandler. diff --git a/middleware/request_id.go b/middleware/request_id.go index 411737cb4..14bd4fd15 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -7,32 +7,28 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // RequestIDConfig defines the config for RequestID middleware. - RequestIDConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// RequestIDConfig defines the config for RequestID middleware. +type RequestIDConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Generator defines a function to generate an ID. - // Optional. Defaults to generator for random string of length 32. - Generator func() string + // Generator defines a function to generate an ID. + // Optional. Defaults to generator for random string of length 32. + Generator func() string - // RequestIDHandler defines a function which is executed for a request id. - RequestIDHandler func(echo.Context, string) + // RequestIDHandler defines a function which is executed for a request id. + RequestIDHandler func(echo.Context, string) - // TargetHeader defines what header to look for to populate the id - TargetHeader string - } -) + // TargetHeader defines what header to look for to populate the id + TargetHeader string +} -var ( - // DefaultRequestIDConfig is the default RequestID middleware config. - DefaultRequestIDConfig = RequestIDConfig{ - Skipper: DefaultSkipper, - Generator: generator, - TargetHeader: echo.HeaderXRequestID, - } -) +// DefaultRequestIDConfig is the default RequestID middleware config. +var DefaultRequestIDConfig = RequestIDConfig{ + Skipper: DefaultSkipper, + Generator: generator, + TargetHeader: echo.HeaderXRequestID, +} // RequestID returns a X-Request-ID middleware. func RequestID() echo.MiddlewareFunc { diff --git a/middleware/rewrite.go b/middleware/rewrite.go index 260dbb1f5..4c19cc1cc 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -9,37 +9,33 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // RewriteConfig defines the config for Rewrite middleware. - RewriteConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// RewriteConfig defines the config for Rewrite middleware. +type RewriteConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Rules defines the URL path rewrite rules. The values captured in asterisk can be - // retrieved by index e.g. $1, $2 and so on. - // Example: - // "/old": "/new", - // "/api/*": "/$1", - // "/js/*": "/public/javascripts/$1", - // "/users/*/orders/*": "/user/$1/order/$2", - // Required. - Rules map[string]string `yaml:"rules"` + // Rules defines the URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Example: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + // Required. + Rules map[string]string `yaml:"rules"` - // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures - // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. - // Example: - // "^/old/[0.9]+/": "/new", - // "^/api/.+?/(.*)": "/v2/$1", - RegexRules map[*regexp.Regexp]string `yaml:"-"` - } -) + // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRules map[*regexp.Regexp]string `yaml:"-"` +} -var ( - // DefaultRewriteConfig is the default Rewrite middleware config. - DefaultRewriteConfig = RewriteConfig{ - Skipper: DefaultSkipper, - } -) +// DefaultRewriteConfig is the default Rewrite middleware config. +var DefaultRewriteConfig = RewriteConfig{ + Skipper: DefaultSkipper, +} // Rewrite returns a Rewrite middleware. // diff --git a/middleware/secure.go b/middleware/secure.go index b70854ddc..c904abf1a 100644 --- a/middleware/secure.go +++ b/middleware/secure.go @@ -9,84 +9,80 @@ import ( "github.com/labstack/echo/v4" ) -type ( - // SecureConfig defines the config for Secure middleware. - SecureConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // XSSProtection provides protection against cross-site scripting attack (XSS) - // by setting the `X-XSS-Protection` header. - // Optional. Default value "1; mode=block". - XSSProtection string `yaml:"xss_protection"` - - // ContentTypeNosniff provides protection against overriding Content-Type - // header by setting the `X-Content-Type-Options` header. - // Optional. Default value "nosniff". - ContentTypeNosniff string `yaml:"content_type_nosniff"` - - // XFrameOptions can be used to indicate whether or not a browser should - // be allowed to render a page in a ,