diff --git a/.gitattributes b/.gitattributes
index 49b63e526..28981b84a 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -13,8 +13,3 @@
*.js text eol=lf
*.json text eol=lf
LICENSE text eol=lf
-
-# Exclude `website` and `cookbook` from GitHub's language statistics
-# https://github.com/github/linguist#using-gitattributes
-cookbook/* linguist-documentation
-website/* linguist-documentation
diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml
new file mode 100644
index 000000000..af410716d
--- /dev/null
+++ b/.github/FUNDING.yml
@@ -0,0 +1,12 @@
+# These are supported funding model platforms
+
+github: [labstack]
+patreon: # Replace with a single Patreon username
+open_collective: # Replace with a single Open Collective username
+ko_fi: # Replace with a single Ko-fi username
+tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
+community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
+liberapay: # Replace with a single Liberapay username
+issuehunt: # Replace with a single IssueHunt username
+otechie: # Replace with a single Otechie username
+custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md
index ee6f33ef8..1a76adca7 100644
--- a/.github/ISSUE_TEMPLATE.md
+++ b/.github/ISSUE_TEMPLATE.md
@@ -1,23 +1,32 @@
### Issue Description
-### Checklist
+### Working code to debug
-- [ ] Dependencies installed
-- [ ] No typos
-- [ ] Searched existing issues and docs
+```go
+package main
-### Expected behaviour
+import (
+ "github.com/labstack/echo/v5"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
-### Actual behaviour
+func TestExample(t *testing.T) {
+ e := echo.New()
-### Steps to reproduce
+ e.GET("/", func(c *echo.Context) error {
+ return c.String(http.StatusOK, "Hello, World!")
+ })
-### Working code to debug
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
-```go
-package main
+ e.ServeHTTP(rec, req)
-func main() {
+ if rec.Code != http.StatusOK {
+ t.Errorf("got %d, want %d", rec.Code, http.StatusOK)
+ }
}
```
diff --git a/.github/stale.yml b/.github/stale.yml
index d9f656321..04dd169cd 100644
--- a/.github/stale.yml
+++ b/.github/stale.yml
@@ -1,17 +1,19 @@
# Number of days of inactivity before an issue becomes stale
daysUntilStale: 60
# Number of days of inactivity before a stale issue is closed
-daysUntilClose: 7
+daysUntilClose: 30
# Issues with these labels will never be considered stale
exemptLabels:
- pinned
- security
+ - bug
+ - enhancement
# Label to use when marking an issue as stale
-staleLabel: wontfix
+staleLabel: stale
# Comment to post when marking an issue as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
- recent activity. It will be closed if no further activity occurs. Thank you
- for your contributions.
+ recent activity. It will be closed within a month if no further activity occurs.
+ Thank you for your contributions.
# Comment to post when closing a stale issue. Set to `false` to disable
-closeComment: false
\ No newline at end of file
+closeComment: false
diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml
new file mode 100644
index 000000000..8f4eff96e
--- /dev/null
+++ b/.github/workflows/checks.yml
@@ -0,0 +1,47 @@
+name: Run checks
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ branches:
+ - master
+ workflow_dispatch:
+
+permissions:
+ contents: read # to fetch code (actions/checkout)
+
+env:
+ # run static analysis only with the latest Go version
+ LATEST_GO_VERSION: "1.26"
+
+jobs:
+ check:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout Code
+ uses: actions/checkout@v5
+
+ - name: Set up Go ${{ matrix.go }}
+ uses: actions/setup-go@v5
+ with:
+ go-version: ${{ env.LATEST_GO_VERSION }}
+ check-latest: true
+
+ - name: Run golint
+ run: |
+ go install golang.org/x/lint/golint@latest
+ golint -set_exit_status ./...
+
+ - name: Run staticcheck
+ run: |
+ go install honnef.co/go/tools/cmd/staticcheck@latest
+ staticcheck ./...
+
+ - name: Run govulncheck
+ run: |
+ go version
+ go install golang.org/x/vuln/cmd/govulncheck@latest
+ govulncheck ./...
+
diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml
new file mode 100644
index 000000000..b92c70c1b
--- /dev/null
+++ b/.github/workflows/echo.yml
@@ -0,0 +1,86 @@
+name: Run Tests
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ branches:
+ - master
+ workflow_dispatch:
+
+permissions:
+ contents: read # to fetch code (actions/checkout)
+
+env:
+ # run coverage and benchmarks only with the latest Go version
+ LATEST_GO_VERSION: "1.26"
+
+jobs:
+ test:
+ strategy:
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-latest]
+ # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy
+ # Echo tests with last four major releases (unless there are pressing vulnerabilities)
+ # As we depend on `golang.org/x/` libraries which only support the last 2 Go releases, we could have situations when
+ # we derive from the last four major releases promise.
+ go: ["1.25", "1.26"]
+ name: ${{ matrix.os }} @ Go ${{ matrix.go }}
+ runs-on: ${{ matrix.os }}
+ steps:
+ - name: Checkout Code
+ uses: actions/checkout@v5
+
+ - name: Set up Go ${{ matrix.go }}
+ uses: actions/setup-go@v5
+ with:
+ go-version: ${{ matrix.go }}
+
+ - name: Run Tests
+ run: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./...
+
+ - name: Upload coverage to Codecov
+ if: success() && matrix.go == env.LATEST_GO_VERSION && matrix.os == 'ubuntu-latest'
+ uses: codecov/codecov-action@v5
+ with:
+ token:
+ fail_ci_if_error: false
+
+ benchmark:
+ needs: test
+ name: Benchmark comparison
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout Code (Previous)
+ uses: actions/checkout@v5
+ with:
+ ref: ${{ github.base_ref }}
+ path: previous
+
+ - name: Checkout Code (New)
+ uses: actions/checkout@v5
+ with:
+ path: new
+
+ - name: Set up Go ${{ matrix.go }}
+ uses: actions/setup-go@v5
+ with:
+ go-version: ${{ env.LATEST_GO_VERSION }}
+
+ - name: Install Dependencies
+ run: go install golang.org/x/perf/cmd/benchstat@latest
+
+ - name: Run Benchmark (Previous)
+ run: |
+ cd previous
+ go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
+
+ - name: Run Benchmark (New)
+ run: |
+ cd new
+ go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
+
+ - name: Run Benchstat
+ run: |
+ benchstat previous/benchmark.txt new/benchmark.txt
diff --git a/.gitignore b/.gitignore
index dd74acca4..dbadf3bd0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,4 @@ vendor
.idea
*.iml
*.out
+.vscode
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index a1fc87684..000000000
--- a/.travis.yml
+++ /dev/null
@@ -1,17 +0,0 @@
-language: go
-go:
- - 1.12.x
- - 1.13.x
- - tip
-env:
- - GO111MODULE=on
-install:
- - go get -v golang.org/x/lint/golint
-script:
- - golint -set_exit_status ./...
- - go test -race -coverprofile=coverage.txt -covermode=atomic ./...
-after_success:
- - bash <(curl -s https://codecov.io/bash)
-matrix:
- allow_failures:
- - go: tip
diff --git a/API_CHANGES_V5.md b/API_CHANGES_V5.md
new file mode 100644
index 000000000..d3ca81560
--- /dev/null
+++ b/API_CHANGES_V5.md
@@ -0,0 +1,1178 @@
+# Echo v5 Public API Changes
+
+**Comparison between `master` (v4.15.0) and `v5` (v5.0.0-alpha) branches**
+
+Generated: 2026-01-01
+
+---
+
+## Executive Summary (by authors)
+
+Echo `v5` is maintenance release with **major breaking changes**
+- `Context` is now struct instead of interface and we can add method to it in the future in minor versions.
+- Adds new `Router` interface for possible new routing implementations.
+- Drops old logging interface and uses moderm `log/slog` instead.
+- Rearranges alot of methods/function signatures to make them more consistent.
+
+## Executive Summary (by LLMs)
+
+Echo v5 represents a **major breaking release** with significant architectural changes focused on:
+- **Updated generic helpers** to take `*Context` and rename form helpers to `FormValue*`
+- **Simplified API surface** by moving Context from interface to concrete struct
+- **Modern Go patterns** including slog.Logger integration
+- **Enhanced routing** with explicit RouteInfo and Routes types
+- **Better error handling** with simplified HTTPError
+- **New test helpers** via the `echotest` package
+
+### Change Statistics
+
+- **Major Breaking Changes**: 15+
+- **New Functions Added**: 30+
+- **Type Signature Changes**: 20+
+- **Removed APIs**: 10+
+- **New Packages Added**: 1 (`echotest`)
+- **Version Change**: `4.15.0` → `5.0.0-alpha`
+
+---
+
+## Critical Breaking Changes
+
+### 1. **Context: Interface → Concrete Struct**
+
+**v4 (master):**
+```go
+type Context interface {
+ Request() *http.Request
+ // ... many methods
+}
+
+// Handler signature
+func handler(c echo.Context) error
+```
+
+**v5:**
+```go
+type Context struct {
+ // Has unexported fields
+}
+
+// Handler signature - NOW USES POINTER!
+func handler(c *echo.Context) error
+```
+
+**Impact:** 🔴 **CRITICAL BREAKING CHANGE**
+- ALL handlers must change from `echo.Context` to `*echo.Context`
+- Context is now a concrete struct, not an interface
+- This affects every single handler function in user code
+
+**Migration:**
+```go
+// Before (v4)
+func MyHandler(c echo.Context) error {
+ return c.JSON(200, map[string]string{"hello": "world"})
+}
+
+// After (v5)
+func MyHandler(c *echo.Context) error {
+ return c.JSON(200, map[string]string{"hello": "world"})
+}
+```
+
+---
+
+### 2. **Logger: Custom Interface → slog.Logger**
+
+**v4:**
+```go
+type Echo struct {
+ Logger Logger // Custom interface with Print, Debug, Info, etc.
+}
+
+type Logger interface {
+ Output() io.Writer
+ SetOutput(w io.Writer)
+ Prefix() string
+ // ... many custom methods
+}
+
+// Context returns Logger interface
+func (c Context) Logger() Logger
+```
+
+**v5:**
+```go
+type Echo struct {
+ Logger *slog.Logger // Standard library structured logger
+}
+
+// Context returns slog.Logger
+func (c *Context) Logger() *slog.Logger
+func (c *Context) SetLogger(logger *slog.Logger)
+```
+
+**Impact:** 🔴 **BREAKING CHANGE**
+- Must use Go's standard `log/slog` package
+- Logger interface completely removed
+- All logging code needs updating
+
+---
+
+### 3. **Router: From Router to DefaultRouter**
+
+**v4:**
+```go
+type Router struct { ... }
+
+func NewRouter(e *Echo) *Router
+func (e *Echo) Router() *Router
+```
+
+**v5:**
+```go
+type DefaultRouter struct { ... }
+
+func NewRouter(config RouterConfig) *DefaultRouter
+func (e *Echo) Router() Router // Returns interface
+```
+
+**Changes:**
+- New `Router` interface introduced
+- `DefaultRouter` is the concrete implementation
+- `NewRouter()` now takes `RouterConfig` instead of `*Echo`
+- Added `NewConcurrentRouter(r Router) Router` for thread-safe routing
+
+---
+
+### 4. **Route Return Types Changed**
+
+**v4:**
+```go
+func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route
+func (e *Echo) Any(path string, h HandlerFunc, m ...MiddlewareFunc) []*Route
+func (e *Echo) Routes() []*Route
+```
+
+**v5:**
+```go
+func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo
+func (e *Echo) Any(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo
+func (e *Echo) Match(...) Routes // Returns Routes type
+func (e *Echo) Router() Router // Returns interface
+```
+
+**New Types:**
+```go
+type RouteInfo struct {
+ Name string
+ Method string
+ Path string
+ Parameters []string
+}
+
+type Routes []RouteInfo // Collection with helper methods
+```
+
+**Impact:** 🔴 **BREAKING CHANGE**
+- Route registration methods return `RouteInfo` instead of `*Route`
+- New `Routes` collection type with filtering methods
+- `Route` struct still exists but used differently
+
+---
+
+### 5. **Response Type Changed**
+
+**v4:**
+```go
+func (c Context) Response() *Response
+type Response struct {
+ Writer http.ResponseWriter
+ Status int
+ Size int64
+ Committed bool
+}
+func NewResponse(w http.ResponseWriter, e *Echo) *Response
+```
+
+**v5:**
+```go
+func (c *Context) Response() http.ResponseWriter
+type Response struct {
+ http.ResponseWriter // Embedded
+ Status int
+ Size int64
+ Committed bool
+}
+func NewResponse(w http.ResponseWriter, logger *slog.Logger) *Response
+func UnwrapResponse(rw http.ResponseWriter) (*Response, error)
+```
+
+**Changes:**
+- Context.Response() returns `http.ResponseWriter` instead of `*Response`
+- Response now embeds `http.ResponseWriter`
+- NewResponse takes `*slog.Logger` instead of `*Echo`
+- New `UnwrapResponse()` helper function
+
+---
+
+### 6. **HTTPError Simplified**
+
+**v4:**
+```go
+type HTTPError struct {
+ Internal error
+ Message interface{} // Can be any type
+ Code int
+}
+
+func NewHTTPError(code int, message ...interface{}) *HTTPError
+```
+
+**v5:**
+```go
+type HTTPError struct {
+ Code int
+ Message string // Now string only
+ // Has unexported fields (Internal moved)
+}
+
+func NewHTTPError(code int, message string) *HTTPError
+func (he HTTPError) Wrap(err error) error // New method
+func (he *HTTPError) StatusCode() int // Implements HTTPStatusCoder
+```
+
+**Changes:**
+- `Message` field changed from `interface{}` to `string`
+- `NewHTTPError()` now takes `string` instead of `...interface{}`
+- Added `HTTPStatusCoder` interface and `StatusCode()` method
+- Added `Wrap(err error)` method for error wrapping
+
+---
+
+### 7. **HTTPErrorHandler Signature Changed**
+
+**v4:**
+```go
+type HTTPErrorHandler func(err error, c Context)
+
+func (e *Echo) DefaultHTTPErrorHandler(err error, c Context)
+```
+
+**v5:**
+```go
+type HTTPErrorHandler func(c *Context, err error) // Parameters swapped!
+
+func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler // Now a factory
+```
+
+**Impact:** 🔴 **BREAKING CHANGE**
+- Parameter order reversed: `(c *Context, err error)` instead of `(err error, c Context)`
+- DefaultHTTPErrorHandler is now a factory function that returns HTTPErrorHandler
+- Takes `exposeError` bool to control error message exposure
+
+---
+
+## Notable API Changes in v5
+
+### 1. **Generic Parameter Extraction Functions (Updated Signatures)**
+
+These helpers keep the same generic API but now accept `*Context`, and the
+form helpers are renamed from `FormParam*` to `FormValue*`:
+
+```go
+// Query Parameters
+func QueryParam[T any](c *Context, key string, opts ...any) (T, error)
+func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error)
+func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error)
+func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error)
+
+// Path Parameters
+func PathParam[T any](c *Context, paramName string, opts ...any) (T, error)
+func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...any) (T, error)
+
+// Form Values
+func FormValue[T any](c *Context, key string, opts ...any) (T, error)
+func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error)
+func FormValues[T any](c *Context, key string, opts ...any) ([]T, error)
+func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error)
+
+// Generic Parsing
+func ParseValue[T any](value string, opts ...any) (T, error)
+func ParseValueOr[T any](value string, defaultValue T, opts ...any) (T, error)
+func ParseValues[T any](values []string, opts ...any) ([]T, error)
+func ParseValuesOr[T any](values []string, defaultValue []T, opts ...any) ([]T, error)
+```
+
+`FormParam*` was renamed to `FormValue*`; the rest keep names but now take `*Context`.
+
+**Supported Types:**
+- bool, string
+- int, int8, int16, int32, int64
+- uint, uint8, uint16, uint32, uint64
+- float32, float64
+- time.Time, time.Duration
+- BindUnmarshaler, encoding.TextUnmarshaler, json.Unmarshaler
+
+**Example Usage:**
+```go
+// v5 - Type-safe parameter binding
+id, err := echo.PathParam[int](c, "id")
+page, err := echo.QueryParamOr[int](c, "page", 1)
+tags, err := echo.QueryParams[string](c, "tags")
+```
+
+---
+
+### 2. **Context Store Helpers Now Use `*Context`**
+
+```go
+// Type-safe context value retrieval
+func ContextGet[T any](c *Context, key string) (T, error)
+func ContextGetOr[T any](c *Context, key string, defaultValue T) (T, error)
+
+// Error types
+var ErrNonExistentKey = errors.New("non existent key")
+var ErrInvalidKeyType = errors.New("invalid key type")
+```
+
+These helpers existed in v4 with `Context` and now accept `*Context`.
+
+**Example:**
+```go
+// v5
+user, err := echo.ContextGet[*User](c, "user")
+count, err := echo.ContextGetOr[int](c, "count", 0)
+```
+
+---
+
+### 3. **PathValues Type**
+
+New structured path parameter handling:
+
+```go
+type PathValue struct {
+ Name string
+ Value string
+}
+
+type PathValues []PathValue
+
+func (p PathValues) Get(name string) (string, bool)
+func (p PathValues) GetOr(name string, defaultValue string) string
+
+// Context methods
+func (c *Context) PathValues() PathValues
+func (c *Context) SetPathValues(pathValues PathValues)
+```
+
+---
+
+### 4. **Time Parsing Options**
+
+```go
+type TimeLayout string
+
+const (
+ TimeLayoutUnixTime = TimeLayout("UnixTime")
+ TimeLayoutUnixTimeMilli = TimeLayout("UnixTimeMilli")
+ TimeLayoutUnixTimeNano = TimeLayout("UnixTimeNano")
+)
+
+type TimeOpts struct {
+ Layout TimeLayout
+ ParseInLocation *time.Location
+ ToInLocation *time.Location
+}
+```
+
+---
+
+### 5. **StartConfig for Server Configuration**
+
+```go
+type StartConfig struct {
+ Address string
+ HideBanner bool
+ HidePort bool
+ CertFilesystem fs.FS
+ TLSConfig *tls.Config
+ ListenerNetwork string
+ ListenerAddrFunc func(addr net.Addr)
+ GracefulTimeout time.Duration
+ OnShutdownError func(err error)
+ BeforeServeFunc func(s *http.Server) error
+}
+
+func (sc StartConfig) Start(ctx context.Context, h http.Handler) error
+func (sc StartConfig) StartTLS(ctx context.Context, h http.Handler, certFile, keyFile any) error
+```
+
+**Example:**
+```go
+// v5 - More control over server startup
+ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+defer cancel()
+
+sc := echo.StartConfig{
+ Address: ":8080",
+ GracefulTimeout: 10 * time.Second,
+}
+if err := sc.Start(ctx, e); err != nil {
+ log.Fatal(err)
+}
+```
+
+---
+
+### 6. **Echo Config and Constructors**
+
+```go
+type Config struct {
+ // Configuration for Echo (logger, binder, renderer, etc.)
+}
+
+func NewWithConfig(config Config) *Echo
+```
+
+This adds a configuration struct for creating an `Echo` instance without
+mutating fields after `New()`.
+
+---
+
+### 7. **Enhanced Routing Features**
+
+```go
+// New route methods
+func (e *Echo) AddRoute(route Route) (RouteInfo, error)
+func (e *Echo) Middlewares() []MiddlewareFunc
+func (e *Echo) PreMiddlewares() []MiddlewareFunc
+type AddRouteError struct{ ... }
+
+// Routes collection with filters
+type Routes []RouteInfo
+
+func (r Routes) Clone() Routes
+func (r Routes) FilterByMethod(method string) (Routes, error)
+func (r Routes) FilterByName(name string) (Routes, error)
+func (r Routes) FilterByPath(path string) (Routes, error)
+func (r Routes) FindByMethodPath(method string, path string) (RouteInfo, error)
+func (r Routes) Reverse(routeName string, pathValues ...any) (string, error)
+
+// RouteInfo operations
+func (r RouteInfo) Clone() RouteInfo
+func (r RouteInfo) Reverse(pathValues ...any) string
+```
+
+---
+
+### 8. **Middleware Configuration Interface**
+
+```go
+type MiddlewareConfigurator interface {
+ ToMiddleware() (MiddlewareFunc, error)
+}
+```
+
+Allows middleware configs to be converted to middleware without panicking.
+
+---
+
+### 9. **New Context Methods**
+
+```go
+// v5 additions
+func (c *Context) FileFS(file string, filesystem fs.FS) error
+func (c *Context) FormValueOr(name, defaultValue string) string
+func (c *Context) InitializeRoute(ri *RouteInfo, pathValues *PathValues)
+func (c *Context) ParamOr(name, defaultValue string) string
+func (c *Context) QueryParamOr(name, defaultValue string) string
+func (c *Context) RouteInfo() RouteInfo
+```
+
+---
+
+### 10. **Virtual Host Support**
+
+```go
+func NewVirtualHostHandler(vhosts map[string]*Echo) *Echo
+```
+
+Creates an Echo instance that routes requests to different Echo instances based on host.
+
+---
+
+### 11. **New Binder Functions**
+
+```go
+func BindBody(c *Context, target any) error
+func BindHeaders(c *Context, target any) error
+func BindPathValues(c *Context, target any) error // Renamed from BindPathParams
+func BindQueryParams(c *Context, target any) error
+```
+
+Top-level binding functions that work with `*Context`.
+
+---
+
+### 12. **New echotest Package**
+
+```go
+package echotest // import "github.com/labstack/echo/v5/echotest"
+
+func LoadBytes(t *testing.T, name string, opts ...loadBytesOpts) []byte
+func TrimNewlineEnd(bytes []byte) []byte
+type ContextConfig struct{ ... }
+type MultipartForm struct{ ... }
+type MultipartFormFile struct{ ... }
+```
+
+Helpers for loading fixtures and constructing test contexts.
+
+---
+
+## Removed APIs in v5
+
+### Constants
+
+```go
+// v4 - Removed in v5
+const CONNECT = http.MethodConnect // Use http.MethodConnect directly
+```
+
+**Reason:** Deprecated in v4, use stdlib `http.Method*` constants instead.
+
+---
+
+### Constants Added in v5
+
+```go
+// v5 additions
+const (
+ NotFoundRouteName = "echo_route_not_found_name"
+)
+```
+
+---
+
+### Error Variable Changes
+
+**v4 exports:**
+```go
+ErrBadRequest
+ErrInvalidKeyType
+ErrNonExistentKey
+```
+
+**v5 exports:**
+```go
+ErrBadRequest // Now backed by unexported httpError type
+ErrValidatorNotRegistered // New
+ErrInvalidKeyType
+ErrNonExistentKey
+```
+
+**Reason:** v5 centralizes on `NewHTTPError(code, message)` rather than a broad set
+of predefined HTTP error variables.
+
+---
+
+### Functions Removed
+
+```go
+// v4 - Removed in v5
+func GetPath(r *http.Request) string // Use r.URL.Path or r.URL.RawPath
+```
+
+### Variables Removed
+
+```go
+// v4 - Removed in v5
+var MethodNotAllowedHandler = func(c Context) error { ... }
+var NotFoundHandler = func(c Context) error { ... }
+```
+
+### Functions Renamed
+
+```go
+// v4
+func FormParam[T any](c Context, key string, opts ...any) (T, error)
+func FormParamOr[T any](c Context, key string, defaultValue T, opts ...any) (T, error)
+func FormParams[T any](c Context, key string, opts ...any) ([]T, error)
+func FormParamsOr[T any](c Context, key string, defaultValue []T, opts ...any) ([]T, error)
+
+// v5
+func FormValue[T any](c *Context, key string, opts ...any) (T, error)
+func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error)
+func FormValues[T any](c *Context, key string, opts ...any) ([]T, error)
+func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error)
+```
+
+---
+
+### Type Methods Removed/Changed
+
+**Echo struct changes:**
+```go
+// v4 fields removed in v5
+type Echo struct {
+ StdLogger *stdLog.Logger // Removed
+ Server *http.Server // Removed (use StartConfig)
+ TLSServer *http.Server // Removed (use StartConfig)
+ Listener net.Listener // Removed (use StartConfig)
+ TLSListener net.Listener // Removed (use StartConfig)
+ AutoTLSManager autocert.Manager // Removed
+ ListenerNetwork string // Removed
+ OnAddRouteHandler func(...) // Changed to OnAddRoute
+ DisableHTTP2 bool // Removed (use StartConfig)
+ Debug bool // Removed
+ HideBanner bool // Removed (use StartConfig)
+ HidePort bool // Removed (use StartConfig)
+}
+
+// v5 Echo struct (simplified)
+type Echo struct {
+ Binder Binder
+ Filesystem fs.FS // NEW
+ Renderer Renderer
+ Validator Validator
+ JSONSerializer JSONSerializer
+ IPExtractor IPExtractor
+ OnAddRoute func(route Route) error // Simplified
+ HTTPErrorHandler HTTPErrorHandler
+ Logger *slog.Logger // Changed from Logger interface
+}
+```
+
+---
+
+**Context interface → struct:**
+```go
+// v4
+type Context interface {
+ // Had: SetResponse(*Response)
+ Response() *Response
+
+ // Had: ParamNames(), SetParamNames(), ParamValues(), SetParamValues()
+ // These are removed in v5 (use PathValues() instead)
+}
+
+// v5
+type Context struct {
+ // Concrete struct with unexported fields
+}
+
+func (c *Context) Response() http.ResponseWriter // Changed return type
+func (c *Context) PathValues() PathValues // Replaces ParamNames/Values
+```
+
+---
+
+**Types removed:**
+```go
+// v4
+type Map map[string]interface{}
+```
+
+**Group changes:**
+```go
+// v4
+func (g *Group) File(path, file string) // No return value
+func (g *Group) Static(pathPrefix, fsRoot string) // No return value
+func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) // No return value
+
+// v5
+func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo
+func (g *Group) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo
+func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo
+```
+
+Now return `RouteInfo` and accept middleware.
+
+---
+
+### Value Binder Factory Name Changes
+
+```go
+// v4
+func PathParamsBinder(c Context) *ValueBinder
+func QueryParamsBinder(c Context) *ValueBinder
+func FormFieldBinder(c Context) *ValueBinder
+
+// v5
+func PathValuesBinder(c *Context) *ValueBinder // Renamed
+func QueryParamsBinder(c *Context) *ValueBinder
+func FormFieldBinder(c *Context) *ValueBinder
+```
+
+---
+
+## Type Signature Changes
+
+### Binder Interface
+
+```go
+// v4
+type Binder interface {
+ Bind(i interface{}, c Context) error
+}
+
+// v5
+type Binder interface {
+ Bind(c *Context, target any) error // Parameters swapped!
+}
+```
+
+---
+
+### DefaultBinder Methods
+
+```go
+// v4
+func (b *DefaultBinder) Bind(i interface{}, c Context) error
+func (b *DefaultBinder) BindBody(c Context, i interface{}) error
+func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error
+
+// v5
+func (b *DefaultBinder) Bind(c *Context, target any) error // Swapped params
+// BindBody, BindPathParams, etc. are now top-level functions
+```
+
+---
+
+### JSONSerializer Interface
+
+```go
+// v4
+type JSONSerializer interface {
+ Serialize(c Context, i interface{}, indent string) error
+ Deserialize(c Context, i interface{}) error
+}
+
+// v5
+type JSONSerializer interface {
+ Serialize(c *Context, target any, indent string) error
+ Deserialize(c *Context, target any) error
+}
+```
+
+---
+
+### Renderer Interface
+
+```go
+// v4
+type Renderer interface {
+ Render(io.Writer, string, interface{}, Context) error
+}
+
+// v5
+type Renderer interface {
+ Render(c *Context, w io.Writer, templateName string, data any) error
+}
+```
+
+Parameters reordered with Context first.
+
+---
+
+### NewBindingError
+
+```go
+// v4
+func NewBindingError(sourceParam string, values []string, message interface{}, internalError error) error
+
+// v5
+func NewBindingError(sourceParam string, values []string, message string, err error) error
+```
+
+Message parameter changed from `interface{}` to `string`.
+
+---
+
+### HandlerName
+
+```go
+// v5 only
+func HandlerName(h HandlerFunc) string
+```
+
+New utility function to get handler function name.
+
+---
+
+## Middleware Package Changes
+
+### Signature and Type Updates
+
+```go
+// CORS now accepts optional allow-origins
+func CORS(allowOrigins ...string) echo.MiddlewareFunc
+
+// BodyLimit now accepts bytes
+func BodyLimit(limitBytes int64) echo.MiddlewareFunc
+
+// DefaultSkipper now uses *echo.Context
+func DefaultSkipper(c *echo.Context) bool
+
+// Trailing slash configs renamed/split
+func AddTrailingSlashWithConfig(config AddTrailingSlashConfig) echo.MiddlewareFunc
+func RemoveTrailingSlashWithConfig(config RemoveTrailingSlashConfig) echo.MiddlewareFunc
+type AddTrailingSlashConfig struct{ ... }
+type RemoveTrailingSlashConfig struct{ ... }
+
+// Auth + extractor signatures now use *echo.Context and add ExtractorSource
+type BasicAuthValidator func(c *echo.Context, user string, password string) (bool, error)
+type Extractor func(c *echo.Context) (string, error)
+type ExtractorSource string
+type KeyAuthValidator func(c *echo.Context, key string, source ExtractorSource) (bool, error)
+type KeyAuthErrorHandler func(c *echo.Context, err error) error
+
+// BodyDump handler now includes err
+type BodyDumpHandler func(c *echo.Context, reqBody []byte, resBody []byte, err error)
+
+// ValuesExtractor now returns extractor source and CreateExtractors takes a limit
+type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error)
+func CreateExtractors(lookups string, limit uint) ([]ValuesExtractor, error)
+type ValueExtractorError struct{ ... }
+
+// New constants
+const KB = 1024
+
+// Rate limiter store now takes a float64 limit
+func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore)
+```
+
+### Added Middleware Exports
+
+```go
+var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key")
+var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key")
+var RedirectHTTPSConfig = RedirectConfig{ ... }
+var RedirectHTTPSWWWConfig = RedirectConfig{ ... }
+var RedirectNonHTTPSWWWConfig = RedirectConfig{ ... }
+var RedirectNonWWWConfig = RedirectConfig{ ... }
+var RedirectWWWConfig = RedirectConfig{ ... }
+```
+
+### Removed/Consolidated Middleware Exports
+
+```go
+// Removed in v5
+func Logger() echo.MiddlewareFunc
+func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc
+func Timeout() echo.MiddlewareFunc
+func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc
+type ErrKeyAuthMissing struct{ ... }
+type CSRFErrorHandler func(err error, c echo.Context) error
+type LoggerConfig struct{ ... }
+type LogErrorFunc func(c echo.Context, err error, stack []byte) error
+type TargetProvider interface{ ... }
+type TrailingSlashConfig struct{ ... }
+type TimeoutConfig struct{ ... }
+```
+
+Also removed defaults: `DefaultBasicAuthConfig`, `DefaultBodyDumpConfig`, `DefaultBodyLimitConfig`,
+`DefaultCORSConfig`, `DefaultDecompressConfig`, `DefaultGzipConfig`, `DefaultLoggerConfig`,
+`DefaultRedirectConfig`, `DefaultRequestIDConfig`, `DefaultRewriteConfig`, `DefaultTimeoutConfig`,
+`DefaultTrailingSlashConfig`.
+
+---
+
+## Router Interface Changes
+
+### v4 Router (Concrete Struct)
+
+```go
+type Router struct { ... }
+
+func NewRouter(e *Echo) *Router
+func (r *Router) Add(method, path string, h HandlerFunc)
+func (r *Router) Find(method, path string, c Context)
+func (r *Router) Reverse(name string, params ...interface{}) string
+func (r *Router) Routes() []*Route
+```
+
+### v5 Router (Interface + DefaultRouter)
+
+```go
+type Router interface {
+ Add(routable Route) (RouteInfo, error)
+ Remove(method string, path string) error
+ Routes() Routes
+ Route(c *Context) HandlerFunc
+}
+
+type DefaultRouter struct { ... }
+
+func NewRouter(config RouterConfig) *DefaultRouter
+func NewConcurrentRouter(r Router) Router // NEW
+
+type RouterConfig struct {
+ NotFoundHandler HandlerFunc
+ MethodNotAllowedHandler HandlerFunc
+ OptionsMethodHandler HandlerFunc
+ AllowOverwritingRoute bool
+ UnescapePathParamValues bool
+ UseEscapedPathForMatching bool
+}
+```
+
+**Key Changes:**
+- Router is now an interface
+- DefaultRouter is the concrete implementation
+- Add() returns `(RouteInfo, error)` instead of being void
+- New `Remove()` method
+- New `Route()` method replaces `Find()`
+- Configuration through `RouterConfig`
+
+---
+
+## Echo Instance Method Changes
+
+### Route Registration
+
+```go
+// v4
+func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route
+
+// v5
+func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo
+func (e *Echo) AddRoute(route Route) (RouteInfo, error) // NEW
+```
+
+### Static File Serving
+
+```go
+// v4
+func (e *Echo) Static(pathPrefix, fsRoot string) *Route
+func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route
+func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route
+func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route
+
+// v5
+func (e *Echo) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo
+func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo
+func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo
+func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo
+```
+
+Return type changed from `*Route` to `RouteInfo`.
+
+### Server Management
+
+```go
+// v4
+func (e *Echo) Start(address string) error
+func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) error
+func (e *Echo) StartAutoTLS(address string) error
+func (e *Echo) StartH2CServer(address string, h2s *http2.Server) error
+func (e *Echo) StartServer(s *http.Server) error
+func (e *Echo) Shutdown(ctx context.Context) error
+func (e *Echo) Close() error
+func (e *Echo) ListenerAddr() net.Addr
+func (e *Echo) TLSListenerAddr() net.Addr
+func (e *Echo) DefaultHTTPErrorHandler(err error, c Context)
+
+// v5
+func (e *Echo) Start(address string) error // Simplified
+func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request)
+
+// Removed: StartTLS, StartAutoTLS, StartH2CServer, StartServer
+// Use StartConfig instead for advanced server configuration
+// Removed: Shutdown, Close, ListenerAddr, TLSListenerAddr
+// Removed: DefaultHTTPErrorHandler (now a top-level factory function)
+```
+
+**v5 provides** `StartConfig` type for all advanced server configuration.
+
+### Router Access
+
+```go
+// v4
+func (e *Echo) Router() *Router
+func (e *Echo) Routers() map[string]*Router // For multi-host
+func (e *Echo) Routes() []*Route
+func (e *Echo) Reverse(name string, params ...interface{}) string
+func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string
+func (e *Echo) URL(h HandlerFunc, params ...interface{}) string
+func (e *Echo) Host(name string, m ...MiddlewareFunc) *Group
+
+// v5
+func (e *Echo) Router() Router // Returns interface
+// Removed: Routers(), Reverse(), URI(), URL(), Host()
+// Use router.Routes() and Routes.Reverse() instead
+```
+
+---
+
+## NewContext Changes
+
+```go
+// v4
+func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context
+func NewResponse(w http.ResponseWriter, e *Echo) *Response
+
+// v5
+func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) *Context
+func NewContext(r *http.Request, w http.ResponseWriter, opts ...any) *Context // Standalone
+func NewResponse(w http.ResponseWriter, logger *slog.Logger) *Response
+```
+
+---
+
+## Migration Guide Summary
+
+If you are using Linux you can migrate easier parts like that:
+```bash
+find . -type f -name "*.go" -exec sed -i 's/ echo.Context/ *echo.Context/g' {} +
+find . -type f -name "*.go" -exec sed -i 's/echo\/v4/echo\/v5/g' {} +
+```
+or in your favorite IDE
+
+Replace all:
+1. ` echo.Context` -> ` *echo.Context`
+2. `echo/v4` -> `echo/v5`
+
+
+### 1. Update All Handler Signatures
+
+```go
+// Before
+func MyHandler(c echo.Context) error { ... }
+
+// After
+func MyHandler(c *echo.Context) error { ... }
+```
+
+### 2. Update Logger Usage
+
+```go
+// Before
+e.Logger.Info("Server started")
+c.Logger().Error("Something went wrong")
+
+// After
+e.Logger.Info("Server started")
+c.Logger().Error("Something went wrong") // Same API, different logger
+```
+
+### 3. Use Type-Safe Parameter Extraction
+
+```go
+// Before
+idStr := c.Param("id")
+id, err := strconv.Atoi(idStr)
+
+// After
+id, err := echo.PathParam[int](c, "id")
+```
+
+### 4. Update Error Handler
+
+```go
+// Before
+e.HTTPErrorHandler = func(err error, c echo.Context) {
+ // handle error
+}
+
+// After
+e.HTTPErrorHandler = func(c *echo.Context, err error) { // Swapped!
+ // handle error
+}
+
+// Or use factory
+e.HTTPErrorHandler = echo.DefaultHTTPErrorHandler(true) // exposeError=true
+```
+
+### 5. Update Server Startup
+
+```go
+// Before
+e.Start(":8080")
+e.StartTLS(":443", "cert.pem", "key.pem")
+
+// After
+// Simple
+e.Start(":8080")
+
+// Advanced with graceful shutdown
+ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
+defer cancel()
+sc := echo.StartConfig{Address: ":8080"}
+sc.Start(ctx, e)
+```
+
+### 6. Update Route Info Access
+
+```go
+// Before
+routes := e.Routes()
+for _, r := range routes {
+ fmt.Println(r.Method, r.Path)
+}
+
+// After
+routes := e.Router().Routes()
+for _, r := range routes {
+ fmt.Println(r.Method, r.Path)
+}
+```
+
+### 7. Update HTTPError Creation
+
+```go
+// Before
+return echo.NewHTTPError(400, "invalid request", someDetail)
+
+// After
+return echo.NewHTTPError(400, "invalid request")
+```
+
+### 8. Update Custom Binder
+
+```go
+// Before
+type MyBinder struct{}
+func (b *MyBinder) Bind(i interface{}, c echo.Context) error { ... }
+
+// After
+type MyBinder struct{}
+func (b *MyBinder) Bind(c *echo.Context, target any) error { ... } // Swapped!
+```
+
+### 9. Path Parameters
+
+```go
+// Before
+names := c.ParamNames()
+values := c.ParamValues()
+
+// After
+pathValues := c.PathValues()
+for _, pv := range pathValues {
+ fmt.Println(pv.Name, pv.Value)
+}
+```
+
+### 10. Response Access
+
+```go
+// Before
+resp := c.Response()
+resp.Header().Set("X-Custom", "value")
+
+// After
+c.Response().Header().Set("X-Custom", "value") // Returns http.ResponseWriter
+
+// To get *echo.Response
+resp, err := echo.UnwrapResponse(c.Response())
+```
+
+### Go Version Requirements
+
+- **v4**: Go 1.24.0 (per `go.mod`)
+- **v5**: Go 1.25.0 (per `go.mod`)
+
+---
+
+**Generated by comparing `go doc` output from master (v4.15.0) and v5 (v5.0.0-alpha) branches**
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 000000000..37d1adb66
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,838 @@
+# Changelog
+
+## v5.0.3 - 2026-02-06
+
+**Security**
+
+* Fix directory traversal vulnerability under Windows in Static middleware when default Echo filesystem is used. Reported by @shblue21.
+
+This applies to cases when:
+- Windows is used as OS
+- `middleware.StaticConfig.Filesystem` is `nil` (default)
+- `echo.Filesystem` is has not been set explicitly (default)
+
+Exposure is restricted to the active process working directory and its subfolders.
+
+
+## v5.0.2 - 2026-02-02
+
+**Security**
+
+* Fix Static middleware with `config.Browse=true` lists all files/subfolders from `config.Filesystem` root and not starting from `config.Root` in https://github.com/labstack/echo/pull/2887
+
+
+## v5.0.1 - 2026-01-28
+
+* Panic MW: will now return a custom PanicStackError with stack trace by @aldas in https://github.com/labstack/echo/pull/2871
+* Docs: add missing err parameter to DenyHandler example by @cgalibern in https://github.com/labstack/echo/pull/2878
+* improve: improve websocket checks in IsWebSocket() [per RFC 6455] by @raju-mechatronics in https://github.com/labstack/echo/pull/2875
+* fix: Context.Json() should not send status code before serialization is complete by @aldas in https://github.com/labstack/echo/pull/2877
+
+
+## v5.0.0 - 2026-01-18
+
+Echo `v5` is maintenance release with **major breaking changes**
+- `Context` is now struct instead of interface and we can add method to it in the future in minor versions.
+- Adds new `Router` interface for possible new routing implementations.
+- Drops old logging interface and uses moderm `log/slog` instead.
+- Rearranges alot of methods/function signatures to make them more consistent.
+
+Upgrade notes and `v4` support:
+- Echo `v4` is supported with **security*** updates and **bug** fixes until **2026-12-31**
+- If you are using Echo in a production environment, it is recommended to wait until after 2026-03-31 before upgrading.
+- Until 2026-03-31, any critical issues requiring breaking `v5` API changes will be addressed, even if this violates semantic versioning.
+
+See [API_CHANGES_V5.md](./API_CHANGES_V5.md) for public API changes between `v4` and `v5`, notes on **upgrading**.
+
+Upgrading TLDR:
+
+If you are using Linux you can migrate easier parts like that:
+```bash
+find . -type f -name "*.go" -exec sed -i 's/ echo.Context/ *echo.Context/g' {} +
+find . -type f -name "*.go" -exec sed -i 's/echo\/v4/echo\/v5/g' {} +
+```
+macOS
+```bash
+find . -type f -name "*.go" -exec sed -i '' 's/ echo.Context/ *echo.Context/g' {} +
+find . -type f -name "*.go" -exec sed -i '' 's/echo\/v4/echo\/v5/g' {} +
+```
+
+or in your favorite IDE
+
+Replace all:
+1. ` echo.Context` -> ` *echo.Context`
+2. `echo/v4` -> `echo/v5`
+
+This should solve most of the issues. Probably the hardest part is updating all the tests.
+
+
+## v4.15.0 - 2026-01-01
+
+
+**Security**
+
+NB: **If your application relies on cross-origin or same-site (same subdomain) requests do not blindly push this version to production**
+
+
+The CSRF middleware now supports the [**Sec-Fetch-Site**](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Sec-Fetch-Site) header as a modern, defense-in-depth approach to [CSRF
+protection](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers), implementing the OWASP-recommended Fetch Metadata API alongside the traditional token-based mechanism.
+
+**How it works:**
+
+Modern browsers automatically send the `Sec-Fetch-Site` header with all requests, indicating the relationship
+between the request origin and the target. The middleware uses this to make security decisions:
+
+- **`same-origin`** or **`none`**: Requests are allowed (exact origin match or direct user navigation)
+- **`same-site`**: Falls back to token validation (e.g., subdomain to main domain)
+- **`cross-site`**: Blocked by default with 403 error for unsafe methods (POST, PUT, DELETE, PATCH)
+
+For browsers that don't send this header (older browsers), the middleware seamlessly falls back to
+traditional token-based CSRF protection.
+
+**New Configuration Options:**
+- `TrustedOrigins []string`: Allowlist specific origins for cross-site requests (useful for OAuth callbacks, webhooks)
+- `AllowSecFetchSiteFunc func(echo.Context) (bool, error)`: Custom logic for same-site/cross-site request validation
+
+**Example:**
+ ```go
+ e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{
+ // Allow OAuth callbacks from trusted provider
+ TrustedOrigins: []string{"https://oauth-provider.com"},
+
+ // Custom validation for same-site requests
+ AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) {
+ // Your custom authorization logic here
+ return validateCustomAuth(c), nil
+ // return true, err // blocks request with error
+ // return true, nil // allows CSRF request through
+ // return false, nil // falls back to legacy token logic
+ },
+ }))
+ ```
+PR: https://github.com/labstack/echo/pull/2858
+
+**Type-Safe Generic Parameter Binding**
+
+* Added generic functions for type-safe parameter extraction and context access by @aldas in https://github.com/labstack/echo/pull/2856
+
+ Echo now provides generic functions for extracting path, query, and form parameters with automatic type conversion,
+ eliminating manual string parsing and type assertions.
+
+ **New Functions:**
+ - Path parameters: `PathParam[T]`, `PathParamOr[T]`
+ - Query parameters: `QueryParam[T]`, `QueryParamOr[T]`, `QueryParams[T]`, `QueryParamsOr[T]`
+ - Form values: `FormParam[T]`, `FormParamOr[T]`, `FormParams[T]`, `FormParamsOr[T]`
+ - Context store: `ContextGet[T]`, `ContextGetOr[T]`
+
+ **Supported Types:**
+ Primitives (`bool`, `string`, `int`/`uint` variants, `float32`/`float64`), `time.Duration`, `time.Time`
+ (with custom layouts and Unix timestamp support), and custom types implementing `BindUnmarshaler`,
+ `TextUnmarshaler`, or `JSONUnmarshaler`.
+
+ **Example:**
+ ```go
+ // Before: Manual parsing
+ idStr := c.Param("id")
+ id, err := strconv.Atoi(idStr)
+
+ // After: Type-safe with automatic parsing
+ id, err := echo.PathParam[int](c, "id")
+
+ // With default values
+ page, err := echo.QueryParamOr[int](c, "page", 1)
+ limit, err := echo.QueryParamOr[int](c, "limit", 20)
+
+ // Type-safe context access (no more panics from type assertions)
+ user, err := echo.ContextGet[*User](c, "user")
+ ```
+
+PR: https://github.com/labstack/echo/pull/2856
+
+
+
+**DEPRECATION NOTICE** Timeout Middleware Deprecated - Use ContextTimeout Instead
+
+The `middleware.Timeout` middleware has been **deprecated** due to fundamental architectural issues that cause
+data races. Use `middleware.ContextTimeout` or `middleware.ContextTimeoutWithConfig` instead.
+
+**Why is this being deprecated?**
+
+The Timeout middleware manipulates response writers across goroutine boundaries, which causes data races that
+cannot be reliably fixed without a complete architectural redesign. The middleware:
+
+- Swaps the response writer using `http.TimeoutHandler`
+- Must be the first middleware in the chain (fragile constraint)
+- Can cause races with other middleware (Logger, metrics, custom middleware)
+- Has been the source of multiple race condition fixes over the years
+
+**What should you use instead?**
+
+The `ContextTimeout` middleware (available since v4.12.0) provides timeout functionality using Go's standard
+context mechanism. It is:
+
+- Race-free by design
+- Can be placed anywhere in the middleware chain
+- Simpler and more maintainable
+- Compatible with all other middleware
+
+**Migration Guide:**
+
+```go
+// Before (deprecated):
+e.Use(middleware.Timeout())
+
+// After (recommended):
+e.Use(middleware.ContextTimeout(30 * time.Second))
+```
+
+**Important Behavioral Differences:**
+
+1. **Handler cooperation required**: With ContextTimeout, your handlers must check `context.Done()` for cooperative
+ cancellation. The old Timeout middleware would send a 503 response regardless of handler cooperation, but had
+ data race issues.
+
+2. **Error handling**: ContextTimeout returns errors through the standard error handling flow. Handlers that receive
+ `context.DeadlineExceeded` should handle it appropriately:
+
+```go
+e.GET("/long-task", func(c echo.Context) error {
+ ctx := c.Request().Context()
+
+ // Example: database query with context
+ result, err := db.QueryContext(ctx, "SELECT * FROM large_table")
+ if err != nil {
+ if errors.Is(err, context.DeadlineExceeded) {
+ // Handle timeout
+ return echo.NewHTTPError(http.StatusServiceUnavailable, "Request timeout")
+ }
+ return err
+ }
+
+ return c.JSON(http.StatusOK, result)
+})
+```
+
+3. **Background tasks**: For long-running background tasks, use goroutines with context:
+
+```go
+e.GET("/async-task", func(c echo.Context) error {
+ ctx := c.Request().Context()
+
+ resultCh := make(chan Result, 1)
+ errCh := make(chan error, 1)
+
+ go func() {
+ result, err := performLongTask(ctx)
+ if err != nil {
+ errCh <- err
+ return
+ }
+ resultCh <- result
+ }()
+
+ select {
+ case result := <-resultCh:
+ return c.JSON(http.StatusOK, result)
+ case err := <-errCh:
+ return err
+ case <-ctx.Done():
+ return echo.NewHTTPError(http.StatusServiceUnavailable, "Request timeout")
+ }
+})
+```
+
+**Enhancements**
+
+* Fixes by @aldas in https://github.com/labstack/echo/pull/2852
+* Generic functions by @aldas in https://github.com/labstack/echo/pull/2856
+* CRSF with Sec-Fetch-Site checks by @aldas in https://github.com/labstack/echo/pull/2858
+
+
+## v4.14.0 - 2025-12-11
+
+`middleware.Logger` has been deprecated. For request logging, use `middleware.RequestLogger` or
+`middleware.RequestLoggerWithConfig`.
+
+`middleware.RequestLogger` replaces `middleware.Logger`, offering comparable configuration while relying on the
+Go standard library’s new `slog` logger.
+
+The previous default output format was JSON. The new default follows the standard `slog` logger settings.
+To continue emitting request logs in JSON, configure `slog` accordingly:
+```go
+slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil)))
+e.Use(middleware.RequestLogger())
+```
+
+
+**Security**
+
+* Logger middleware json string escaping and deprecation by @aldas in https://github.com/labstack/echo/pull/2849
+
+
+
+**Enhancements**
+
+* Update deps by @aldas in https://github.com/labstack/echo/pull/2807
+* refactor to use reflect.TypeFor by @cuiweixie in https://github.com/labstack/echo/pull/2812
+* Use Go 1.25 in CI by @aldas in https://github.com/labstack/echo/pull/2810
+* Modernize context.go by replacing interface{} with any by @vishr in https://github.com/labstack/echo/pull/2822
+* Fix typo in SetParamValues comment by @vishr in https://github.com/labstack/echo/pull/2828
+* Fix typo in ContextTimeout middleware comment by @vishr in https://github.com/labstack/echo/pull/2827
+* Improve BasicAuth middleware: use strings.Cut and RFC compliance by @vishr in https://github.com/labstack/echo/pull/2825
+* Fix duplicate plus operator in router backtracking logic by @yuya-morimoto in https://github.com/labstack/echo/pull/2832
+* Replace custom private IP range check with built-in net.IP.IsPrivate by @kumapower17 in https://github.com/labstack/echo/pull/2835
+* Ensure proxy connection is closed in proxyRaw function(#2837) by @kumapower17 in https://github.com/labstack/echo/pull/2838
+* Update deps by @aldas in https://github.com/labstack/echo/pull/2843
+* Update golang.org/x/* deps by @aldas in https://github.com/labstack/echo/pull/2850
+
+
+
+## v4.13.4 - 2025-05-22
+
+**Enhancements**
+
+* chore: fix some typos in comment by @zhuhaicity in https://github.com/labstack/echo/pull/2735
+* CI: test with Go 1.24 by @aldas in https://github.com/labstack/echo/pull/2748
+* Add support for TLS WebSocket proxy by @t-ibayashi-safie in https://github.com/labstack/echo/pull/2762
+
+**Security**
+
+* Update dependencies for [GO-2025-3487](https://pkg.go.dev/vuln/GO-2025-3487), [GO-2025-3503](https://pkg.go.dev/vuln/GO-2025-3503) and [GO-2025-3595](https://pkg.go.dev/vuln/GO-2025-3595) in https://github.com/labstack/echo/pull/2780
+
+
+## v4.13.3 - 2024-12-19
+
+**Security**
+
+* Update golang.org/x/net dependency [GO-2024-3333](https://pkg.go.dev/vuln/GO-2024-3333) in https://github.com/labstack/echo/pull/2722
+
+
+## v4.13.2 - 2024-12-12
+
+**Security**
+
+* Update dependencies (dependabot reports [GO-2024-3321](https://pkg.go.dev/vuln/GO-2024-3321)) in https://github.com/labstack/echo/pull/2721
+
+
+## v4.13.1 - 2024-12-11
+
+**Fixes**
+
+* Fix BindBody ignoring `Transfer-Encoding: chunked` requests by @178inaba in https://github.com/labstack/echo/pull/2717
+
+
+
+## v4.13.0 - 2024-12-04
+
+**BREAKING CHANGE** JWT Middleware Removed from Core use [labstack/echo-jwt](https://github.com/labstack/echo-jwt) instead
+
+The JWT middleware has been **removed from Echo core** due to another security vulnerability, [CVE-2024-51744](https://nvd.nist.gov/vuln/detail/CVE-2024-51744). For more details, refer to issue [#2699](https://github.com/labstack/echo/issues/2699). A drop-in replacement is available in the [labstack/echo-jwt](https://github.com/labstack/echo-jwt) repository.
+
+**Important**: Direct assignments like `token := c.Get("user").(*jwt.Token)` will now cause a panic due to an invalid cast. Update your code accordingly. Replace the current imports from `"github.com/golang-jwt/jwt"` in your handlers to the new middleware version using `"github.com/golang-jwt/jwt/v5"`.
+
+
+Background:
+
+The version of `golang-jwt/jwt` (v3.2.2) previously used in Echo core has been in an unmaintained state for some time. This is not the first vulnerability affecting this library; earlier issues were addressed in [PR #1946](https://github.com/labstack/echo/pull/1946).
+JWT middleware was marked as deprecated in Echo core as of [v4.10.0](https://github.com/labstack/echo/releases/tag/v4.10.0) on 2022-12-27. If you did not notice that, consider leveraging tools like [Staticcheck](https://staticcheck.dev/) to catch such deprecations earlier in you dev/CI flow. For bonus points - check out [gosec](https://github.com/securego/gosec).
+
+We sincerely apologize for any inconvenience caused by this change. While we strive to maintain backward compatibility within Echo core, recurring security issues with third-party dependencies have forced this decision.
+
+**Enhancements**
+
+* remove jwt middleware by @stevenwhitehead in https://github.com/labstack/echo/pull/2701
+* optimization: struct alignment by @behnambm in https://github.com/labstack/echo/pull/2636
+* bind: Maintain backwards compatibility for map[string]interface{} binding by @thesaltree in https://github.com/labstack/echo/pull/2656
+* Add Go 1.23 to CI by @aldas in https://github.com/labstack/echo/pull/2675
+* improve `MultipartForm` test by @martinyonatann in https://github.com/labstack/echo/pull/2682
+* `bind` : add support of multipart multi files by @martinyonatann in https://github.com/labstack/echo/pull/2684
+* Add TemplateRenderer struct to ease creating renderers for `html/template` and `text/template` packages. by @aldas in https://github.com/labstack/echo/pull/2690
+* Refactor TestBasicAuth to utilize table-driven test format by @ErikOlson in https://github.com/labstack/echo/pull/2688
+* Remove broken header by @aldas in https://github.com/labstack/echo/pull/2705
+* fix(bind body): content-length can be -1 by @phamvinhdat in https://github.com/labstack/echo/pull/2710
+* CORS middleware should compile allowOrigin regexp at creation by @aldas in https://github.com/labstack/echo/pull/2709
+* Shorten Github issue template and add test example by @aldas in https://github.com/labstack/echo/pull/2711
+
+
+## v4.12.0 - 2024-04-15
+
+**Security**
+
+* Update golang.org/x/net dep because of [GO-2024-2687](https://pkg.go.dev/vuln/GO-2024-2687) by @aldas in https://github.com/labstack/echo/pull/2625
+
+
+**Enhancements**
+
+* binder: make binding to Map work better with string destinations by @aldas in https://github.com/labstack/echo/pull/2554
+* README.md: add Encore as sponsor by @marcuskohlberg in https://github.com/labstack/echo/pull/2579
+* Reorder paragraphs in README.md by @aldas in https://github.com/labstack/echo/pull/2581
+* CI: upgrade actions/checkout to v4 by @aldas in https://github.com/labstack/echo/pull/2584
+* Remove default charset from 'application/json' Content-Type header by @doortts in https://github.com/labstack/echo/pull/2568
+* CI: Use Go 1.22 by @aldas in https://github.com/labstack/echo/pull/2588
+* binder: allow binding to a nil map by @georgmu in https://github.com/labstack/echo/pull/2574
+* Add Skipper Unit Test In BasicBasicAuthConfig and Add More Detail Explanation regarding BasicAuthValidator by @RyoKusnadi in https://github.com/labstack/echo/pull/2461
+* fix some typos by @teslaedison in https://github.com/labstack/echo/pull/2603
+* fix: some typos by @pomadev in https://github.com/labstack/echo/pull/2596
+* Allow ResponseWriters to unwrap writers when flushing/hijacking by @aldas in https://github.com/labstack/echo/pull/2595
+* Add SPDX licence comments to files. by @aldas in https://github.com/labstack/echo/pull/2604
+* Upgrade deps by @aldas in https://github.com/labstack/echo/pull/2605
+* Change type definition blocks to single declarations. This helps copy… by @aldas in https://github.com/labstack/echo/pull/2606
+* Fix Real IP logic by @cl-bvl in https://github.com/labstack/echo/pull/2550
+* Default binder can use `UnmarshalParams(params []string) error` inter… by @aldas in https://github.com/labstack/echo/pull/2607
+* Default binder can bind pointer to slice as struct field. For example `*[]string` by @aldas in https://github.com/labstack/echo/pull/2608
+* Remove maxparam dependence from Context by @aldas in https://github.com/labstack/echo/pull/2611
+* When route is registered with empty path it is normalized to `/`. by @aldas in https://github.com/labstack/echo/pull/2616
+* proxy middleware should use httputil.ReverseProxy for SSE requests by @aldas in https://github.com/labstack/echo/pull/2624
+
+
+## v4.11.4 - 2023-12-20
+
+**Security**
+
+* Upgrade golang.org/x/crypto to v0.17.0 to fix vulnerability [issue](https://pkg.go.dev/vuln/GO-2023-2402) [#2562](https://github.com/labstack/echo/pull/2562)
+
+**Enhancements**
+
+* Update deps and mark Go version to 1.18 as this is what golang.org/x/* use [#2563](https://github.com/labstack/echo/pull/2563)
+* Request logger: add example for Slog https://pkg.go.dev/log/slog [#2543](https://github.com/labstack/echo/pull/2543)
+
+
+## v4.11.3 - 2023-11-07
+
+**Security**
+
+* 'c.Attachment' and 'c.Inline' should escape filename in 'Content-Disposition' header to avoid 'Reflect File Download' vulnerability. [#2541](https://github.com/labstack/echo/pull/2541)
+
+**Enhancements**
+
+* Tests: refactor context tests to be separate functions [#2540](https://github.com/labstack/echo/pull/2540)
+* Proxy middleware: reuse echo request context [#2537](https://github.com/labstack/echo/pull/2537)
+* Mark unmarshallable yaml struct tags as ignored [#2536](https://github.com/labstack/echo/pull/2536)
+
+
+## v4.11.2 - 2023-10-11
+
+**Security**
+
+* Bump golang.org/x/net to prevent CVE-2023-39325 / CVE-2023-44487 HTTP/2 Rapid Reset Attack [#2527](https://github.com/labstack/echo/pull/2527)
+* fix(sec): randomString bias introduced by #2490 [#2492](https://github.com/labstack/echo/pull/2492)
+* CSRF/RequestID mw: switch math/random usage to crypto/random [#2490](https://github.com/labstack/echo/pull/2490)
+
+**Enhancements**
+
+* Delete unused context in body_limit.go [#2483](https://github.com/labstack/echo/pull/2483)
+* Use Go 1.21 in CI [#2505](https://github.com/labstack/echo/pull/2505)
+* Fix some typos [#2511](https://github.com/labstack/echo/pull/2511)
+* Allow CORS middleware to send Access-Control-Max-Age: 0 [#2518](https://github.com/labstack/echo/pull/2518)
+* Bump dependancies [#2522](https://github.com/labstack/echo/pull/2522)
+
+## v4.11.1 - 2023-07-16
+
+**Fixes**
+
+* Fix `Gzip` middleware not sending response code for no content responses (404, 301/302 redirects etc) [#2481](https://github.com/labstack/echo/pull/2481)
+
+
+## v4.11.0 - 2023-07-14
+
+
+**Fixes**
+
+* Fixes the proxy middleware concurrency issue of calling the Next() proxy target on Round Robin Balancer [#2409](https://github.com/labstack/echo/pull/2409)
+* Fix `group.RouteNotFound` not working when group has attached middlewares [#2411](https://github.com/labstack/echo/pull/2411)
+* Fix global error handler return error message when message is an error [#2456](https://github.com/labstack/echo/pull/2456)
+* Do not use global timeNow variables [#2477](https://github.com/labstack/echo/pull/2477)
+
+
+**Enhancements**
+
+* Added a optional config variable to disable centralized error handler in recovery middleware [#2410](https://github.com/labstack/echo/pull/2410)
+* refactor: use `strings.ReplaceAll` directly [#2424](https://github.com/labstack/echo/pull/2424)
+* Add support for Go1.20 `http.rwUnwrapper` to Response struct [#2425](https://github.com/labstack/echo/pull/2425)
+* Check whether is nil before invoking centralized error handling [#2429](https://github.com/labstack/echo/pull/2429)
+* Proper colon support in `echo.Reverse` method [#2416](https://github.com/labstack/echo/pull/2416)
+* Fix misuses of a vs an in documentation comments [#2436](https://github.com/labstack/echo/pull/2436)
+* Add link to slog.Handler library for Echo logging into README.md [#2444](https://github.com/labstack/echo/pull/2444)
+* In proxy middleware Support retries of failed proxy requests [#2414](https://github.com/labstack/echo/pull/2414)
+* gofmt fixes to comments [#2452](https://github.com/labstack/echo/pull/2452)
+* gzip response only if it exceeds a minimal length [#2267](https://github.com/labstack/echo/pull/2267)
+* Upgrade packages [#2475](https://github.com/labstack/echo/pull/2475)
+
+
+## v4.10.2 - 2023-02-22
+
+**Security**
+
+* `filepath.Clean` behaviour has changed in Go 1.20 - adapt to it [#2406](https://github.com/labstack/echo/pull/2406)
+* Add `middleware.CORSConfig.UnsafeWildcardOriginWithAllowCredentials` to make UNSAFE usages of wildcard origin + allow cretentials less likely [#2405](https://github.com/labstack/echo/pull/2405)
+
+**Enhancements**
+
+* Add more HTTP error values [#2277](https://github.com/labstack/echo/pull/2277)
+
+
+## v4.10.1 - 2023-02-19
+
+**Security**
+
+* Upgrade deps due to the latest golang.org/x/net vulnerability [#2402](https://github.com/labstack/echo/pull/2402)
+
+
+**Enhancements**
+
+* Add new JWT repository to the README [#2377](https://github.com/labstack/echo/pull/2377)
+* Return an empty string for ctx.path if there is no registered path [#2385](https://github.com/labstack/echo/pull/2385)
+* Add context timeout middleware [#2380](https://github.com/labstack/echo/pull/2380)
+* Update link to jaegertracing [#2394](https://github.com/labstack/echo/pull/2394)
+
+
+## v4.10.0 - 2022-12-27
+
+**Security**
+
+* We are deprecating JWT middleware in this repository. Please use https://github.com/labstack/echo-jwt instead.
+
+ JWT middleware is moved to separate repository to allow us to bump/upgrade version of JWT implementation (`github.com/golang-jwt/jwt`) we are using
+which we can not do in Echo core because this would break backwards compatibility guarantees we try to maintain.
+
+* This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are
+ several vulnerabilities fixed in these libraries.
+
+ Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise.
+
+
+**Enhancements**
+
+* Bump x/text to 0.3.8 [#2305](https://github.com/labstack/echo/pull/2305)
+* Bump dependencies and add notes about Go releases we support [#2336](https://github.com/labstack/echo/pull/2336)
+* Add helper interface for ProxyBalancer interface [#2316](https://github.com/labstack/echo/pull/2316)
+* Expose `middleware.CreateExtractors` function so we can use it from echo-contrib repository [#2338](https://github.com/labstack/echo/pull/2338)
+* Refactor func(Context) error to HandlerFunc [#2315](https://github.com/labstack/echo/pull/2315)
+* Improve function comments [#2329](https://github.com/labstack/echo/pull/2329)
+* Add new method HTTPError.WithInternal [#2340](https://github.com/labstack/echo/pull/2340)
+* Replace io/ioutil package usages [#2342](https://github.com/labstack/echo/pull/2342)
+* Add staticcheck to CI flow [#2343](https://github.com/labstack/echo/pull/2343)
+* Replace relative path determination from proprietary to std [#2345](https://github.com/labstack/echo/pull/2345)
+* Remove square brackets from ipv6 addresses in XFF (X-Forwarded-For header) [#2182](https://github.com/labstack/echo/pull/2182)
+* Add testcases for some BodyLimit middleware configuration options [#2350](https://github.com/labstack/echo/pull/2350)
+* Additional configuration options for RequestLogger and Logger middleware [#2341](https://github.com/labstack/echo/pull/2341)
+* Add route to request log [#2162](https://github.com/labstack/echo/pull/2162)
+* GitHub Workflows security hardening [#2358](https://github.com/labstack/echo/pull/2358)
+* Add govulncheck to CI and bump dependencies [#2362](https://github.com/labstack/echo/pull/2362)
+* Fix rate limiter docs [#2366](https://github.com/labstack/echo/pull/2366)
+* Refactor how `e.Routes()` work and introduce `e.OnAddRouteHandler` callback [#2337](https://github.com/labstack/echo/pull/2337)
+
+
+## v4.9.1 - 2022-10-12
+
+**Fixes**
+
+* Fix logger panicing (when template is set to empty) by bumping dependency version [#2295](https://github.com/labstack/echo/issues/2295)
+
+**Enhancements**
+
+* Improve CORS documentation [#2272](https://github.com/labstack/echo/pull/2272)
+* Update readme about supported Go versions [#2291](https://github.com/labstack/echo/pull/2291)
+* Tests: improve error handling on closing body [#2254](https://github.com/labstack/echo/pull/2254)
+* Tests: refactor some of the assertions in tests [#2275](https://github.com/labstack/echo/pull/2275)
+* Tests: refactor assertions [#2301](https://github.com/labstack/echo/pull/2301)
+
+## v4.9.0 - 2022-09-04
+
+**Security**
+
+* Fix open redirect vulnerability in handlers serving static directories (e.Static, e.StaticFs, echo.StaticDirectoryHandler) [#2260](https://github.com/labstack/echo/pull/2260)
+
+**Enhancements**
+
+* Allow configuring ErrorHandler in CSRF middleware [#2257](https://github.com/labstack/echo/pull/2257)
+* Replace HTTP method constants in tests with stdlib constants [#2247](https://github.com/labstack/echo/pull/2247)
+
+
+## v4.8.0 - 2022-08-10
+
+**Most notable things**
+
+You can now add any arbitrary HTTP method type as a route [#2237](https://github.com/labstack/echo/pull/2237)
+```go
+e.Add("COPY", "/*", func(c echo.Context) error
+ return c.String(http.StatusOK, "OK COPY")
+})
+```
+
+You can add custom 404 handler for specific paths [#2217](https://github.com/labstack/echo/pull/2217)
+```go
+e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })
+
+g := e.Group("/images")
+g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })
+```
+
+**Enhancements**
+
+* Add new value binding methods (UnixTimeMilli,TextUnmarshaler,JSONUnmarshaler) to Valuebinder [#2127](https://github.com/labstack/echo/pull/2127)
+* Refactor: body_limit middleware unit test [#2145](https://github.com/labstack/echo/pull/2145)
+* Refactor: Timeout mw: rework how test waits for timeout. [#2187](https://github.com/labstack/echo/pull/2187)
+* BasicAuth middleware returns 500 InternalServerError on invalid base64 strings but should return 400 [#2191](https://github.com/labstack/echo/pull/2191)
+* Refactor: duplicated findStaticChild process at findChildWithLabel [#2176](https://github.com/labstack/echo/pull/2176)
+* Allow different param names in different methods with same path scheme [#2209](https://github.com/labstack/echo/pull/2209)
+* Add support for registering handlers for different 404 routes [#2217](https://github.com/labstack/echo/pull/2217)
+* Middlewares should use errors.As() instead of type assertion on HTTPError [#2227](https://github.com/labstack/echo/pull/2227)
+* Allow arbitrary HTTP method types to be added as routes [#2237](https://github.com/labstack/echo/pull/2237)
+
+## v4.7.2 - 2022-03-16
+
+**Fixes**
+
+* Fix nil pointer exception when calling Start again after address binding error [#2131](https://github.com/labstack/echo/pull/2131)
+* Fix CSRF middleware not being able to extract token from multipart/form-data form [#2136](https://github.com/labstack/echo/pull/2136)
+* Fix Timeout middleware write race [#2126](https://github.com/labstack/echo/pull/2126)
+
+**Enhancements**
+
+* Recover middleware should not log panic for aborted handler [#2134](https://github.com/labstack/echo/pull/2134)
+
+
+## v4.7.1 - 2022-03-13
+
+**Fixes**
+
+* Fix `e.Static`, `.File()`, `c.Attachment()` being picky with paths starting with `./`, `../` and `/` after 4.7.0 introduced echo.Filesystem support (Go1.16+) [#2123](https://github.com/labstack/echo/pull/2123)
+
+**Enhancements**
+
+* Remove some unused code [#2116](https://github.com/labstack/echo/pull/2116)
+
+
+## v4.7.0 - 2022-03-01
+
+**Enhancements**
+
+* Add JWT, KeyAuth, CSRF multivalue extractors [#2060](https://github.com/labstack/echo/pull/2060)
+* Add LogErrorFunc to recover middleware [#2072](https://github.com/labstack/echo/pull/2072)
+* Add support for HEAD method query params binding [#2027](https://github.com/labstack/echo/pull/2027)
+* Improve filesystem support with echo.FileFS, echo.StaticFS, group.FileFS, group.StaticFS [#2064](https://github.com/labstack/echo/pull/2064)
+
+**Fixes**
+
+* Fix X-Real-IP bug, improve tests [#2007](https://github.com/labstack/echo/pull/2007)
+* Minor syntax fixes [#1994](https://github.com/labstack/echo/pull/1994), [#2102](https://github.com/labstack/echo/pull/2102), [#2102](https://github.com/labstack/echo/pull/2102)
+
+**General**
+
+* Add cache-control and connection headers [#2103](https://github.com/labstack/echo/pull/2103)
+* Add Retry-After header constant [#2078](https://github.com/labstack/echo/pull/2078)
+* Upgrade `go` directive in `go.mod` to 1.17 [#2049](https://github.com/labstack/echo/pull/2049)
+* Add Pagoda [#2077](https://github.com/labstack/echo/pull/2077) and Souin [#2069](https://github.com/labstack/echo/pull/2069) to 3rd-party middlewares in README
+
+## v4.6.3 - 2022-01-10
+
+**Fixes**
+
+* Fixed Echo version number in greeting message which was not incremented to `4.6.2` [#2066](https://github.com/labstack/echo/issues/2066)
+
+
+## v4.6.2 - 2022-01-08
+
+**Fixes**
+
+* Fixed route containing escaped colon should be matchable but is not matched to request path [#2047](https://github.com/labstack/echo/pull/2047)
+* Fixed a problem that returned wrong content-encoding when the gzip compressed content was empty. [#1921](https://github.com/labstack/echo/pull/1921)
+* Update (test) dependencies [#2021](https://github.com/labstack/echo/pull/2021)
+
+
+**Enhancements**
+
+* Add support for configurable target header for the request_id middleware [#2040](https://github.com/labstack/echo/pull/2040)
+* Change decompress middleware to use stream decompression instead of buffering [#2018](https://github.com/labstack/echo/pull/2018)
+* Documentation updates
+
+
+## v4.6.1 - 2021-09-26
+
+**Enhancements**
+
+* Add start time to request logger middleware values [#1991](https://github.com/labstack/echo/pull/1991)
+
+## v4.6.0 - 2021-09-20
+
+Introduced a new [request logger](https://github.com/labstack/echo/blob/master/middleware/request_logger.go) middleware
+to help with cases when you want to use some other logging library in your application.
+
+**Fixes**
+
+* fix timeout middleware warning: superfluous response.WriteHeader [#1905](https://github.com/labstack/echo/issues/1905)
+
+**Enhancements**
+
+* Add Cookie to KeyAuth middleware's KeyLookup [#1929](https://github.com/labstack/echo/pull/1929)
+* JWT middleware should ignore case of auth scheme in request header [#1951](https://github.com/labstack/echo/pull/1951)
+* Refactor default error handler to return first if response is already committed [#1956](https://github.com/labstack/echo/pull/1956)
+* Added request logger middleware which helps to use custom logger library for logging requests. [#1980](https://github.com/labstack/echo/pull/1980)
+* Allow escaping of colon in route path so Google Cloud API "custom methods" could be implemented [#1988](https://github.com/labstack/echo/pull/1988)
+
+## v4.5.0 - 2021-08-01
+
+**Important notes**
+
+A **BREAKING CHANGE** is introduced for JWT middleware users.
+The JWT library used for the JWT middleware had to be changed from [github.com/dgrijalva/jwt-go](https://github.com/dgrijalva/jwt-go) to
+[github.com/golang-jwt/jwt](https://github.com/golang-jwt/jwt) due former library being unmaintained and affected by security
+issues.
+The [github.com/golang-jwt/jwt](https://github.com/golang-jwt/jwt) project is a drop-in replacement, but supports only the latest 2 Go versions.
+So for JWT middleware users Go 1.15+ is required. For detailed information please read [#1940](https://github.com/labstack/echo/discussions/)
+
+To change the library imports in all .go files in your project replace all occurrences of `dgrijalva/jwt-go` with `golang-jwt/jwt`.
+
+For Linux CLI you can use:
+```bash
+find -type f -name "*.go" -exec sed -i "s/dgrijalva\/jwt-go/golang-jwt\/jwt/g" {} \;
+go mod tidy
+```
+
+**Fixes**
+
+* Change JWT library to `github.com/golang-jwt/jwt` [#1946](https://github.com/labstack/echo/pull/1946)
+
+## v4.4.0 - 2021-07-12
+
+**Fixes**
+
+* Split HeaderXForwardedFor header only by comma [#1878](https://github.com/labstack/echo/pull/1878)
+* Fix Timeout middleware Context propagation [#1910](https://github.com/labstack/echo/pull/1910)
+
+**Enhancements**
+
+* Bind data using headers as source [#1866](https://github.com/labstack/echo/pull/1866)
+* Adds JWTConfig.ParseTokenFunc to JWT middleware to allow different libraries implementing JWT parsing. [#1887](https://github.com/labstack/echo/pull/1887)
+* Adding tests for Echo#Host [#1895](https://github.com/labstack/echo/pull/1895)
+* Adds RequestIDHandler function to RequestID middleware [#1898](https://github.com/labstack/echo/pull/1898)
+* Allow for custom JSON encoding implementations [#1880](https://github.com/labstack/echo/pull/1880)
+
+## v4.3.0 - 2021-05-08
+
+**Important notes**
+
+* Route matching has improvements for following cases:
+ 1. Correctly match routes with parameter part as last part of route (with trailing backslash)
+ 2. Considering handlers when resolving routes and search for matching http method handler
+* Echo minimal Go version is now 1.13.
+
+**Fixes**
+
+* When url ends with slash first param route is the match [#1804](https://github.com/labstack/echo/pull/1812)
+* Router should check if node is suitable as matching route by path+method and if not then continue search in tree [#1808](https://github.com/labstack/echo/issues/1808)
+* Fix timeout middleware not writing response correctly when handler panics [#1864](https://github.com/labstack/echo/pull/1864)
+* Fix binder not working with embedded pointer structs [#1861](https://github.com/labstack/echo/pull/1861)
+* Add Go 1.16 to CI and drop 1.12 specific code [#1850](https://github.com/labstack/echo/pull/1850)
+
+**Enhancements**
+
+* Make KeyFunc public in JWT middleware [#1756](https://github.com/labstack/echo/pull/1756)
+* Add support for optional filesystem to the static middleware [#1797](https://github.com/labstack/echo/pull/1797)
+* Add a custom error handler to key-auth middleware [#1847](https://github.com/labstack/echo/pull/1847)
+* Allow JWT token to be looked up from multiple sources [#1845](https://github.com/labstack/echo/pull/1845)
+
+## v4.2.2 - 2021-04-07
+
+**Fixes**
+
+* Allow proxy middleware to use query part in rewrite (#1802)
+* Fix timeout middleware not sending status code when handler returns an error (#1805)
+* Fix Bind() when target is array/slice and path/query params complains bind target not being struct (#1835)
+* Fix panic in redirect middleware on short host name (#1813)
+* Fix timeout middleware docs (#1836)
+
+## v4.2.1 - 2021-03-08
+
+**Important notes**
+
+Due to a datarace the config parameters for the newly added timeout middleware required a change.
+See the [docs](https://echo.labstack.com/middleware/timeout).
+A performance regression has been fixed, even bringing better performance than before for some routing scenarios.
+
+**Fixes**
+
+* Fix performance regression caused by path escaping (#1777, #1798, #1799, aldas)
+* Avoid context canceled errors (#1789, clwluvw)
+* Improve router to use on stack backtracking (#1791, aldas, stffabi)
+* Fix panic in timeout middleware not being not recovered and cause application crash (#1794, aldas)
+* Fix Echo.Serve() not serving on HTTP port correctly when TLSListener is used (#1785, #1793, aldas)
+* Apply go fmt (#1788, Le0tk0k)
+* Uses strings.Equalfold (#1790, rkilingr)
+* Improve code quality (#1792, withshubh)
+
+This release was made possible by our **contributors**:
+aldas, clwluvw, lammel, Le0tk0k, maciej-jezierski, rkilingr, stffabi, withshubh
+
+## v4.2.0 - 2021-02-11
+
+**Important notes**
+
+The behaviour for binding data has been reworked for compatibility with echo before v4.1.11 by
+enforcing `explicit tagging` for processing parameters. This **may break** your code if you
+expect combined handling of query/path/form params.
+Please see the updated documentation for [request](https://echo.labstack.com/guide/request) and [binding](https://echo.labstack.com/guide/request)
+
+The handling for rewrite rules has been slightly adjusted to expand `*` to a non-greedy `(.*?)` capture group. This is only relevant if multiple asterisks are used in your rules.
+Please see [rewrite](https://echo.labstack.com/middleware/rewrite) and [proxy](https://echo.labstack.com/middleware/proxy) for details.
+
+**Security**
+
+* Fix directory traversal vulnerability for Windows (#1718, little-cui)
+* Fix open redirect vulnerability with trailing slash (#1771,#1775 aldas,GeoffreyFrogeye)
+
+**Enhancements**
+
+* Add Echo#ListenerNetwork as configuration (#1667, pafuent)
+* Add ability to change the status code using response beforeFuncs (#1706, RashadAnsari)
+* Echo server startup to allow data race free access to listener address
+* Binder: Restore pre v4.1.11 behaviour for c.Bind() to use query params only for GET or DELETE methods (#1727, aldas)
+* Binder: Add separate methods to bind only query params, path params or request body (#1681, aldas)
+* Binder: New fluent binder for query/path/form parameter binding (#1717, #1736, aldas)
+* Router: Performance improvements for missed routes (#1689, pafuent)
+* Router: Improve performance for Real-IP detection using IndexByte instead of Split (#1640, imxyb)
+* Middleware: Support real regex rules for rewrite and proxy middleware (#1767)
+* Middleware: New rate limiting middleware (#1724, iambenkay)
+* Middleware: New timeout middleware implementation for go1.13+ (#1743, )
+* Middleware: Allow regex pattern for CORS middleware (#1623, KlotzAndrew)
+* Middleware: Add IgnoreBase parameter to static middleware (#1701, lnenad, iambenkay)
+* Middleware: Add an optional custom function to CORS middleware to validate origin (#1651, curvegrid)
+* Middleware: Support form fields in JWT middleware (#1704, rkfg)
+* Middleware: Use sync.Pool for (de)compress middleware to improve performance (#1699, #1672, pafuent)
+* Middleware: Add decompress middleware to support gzip compressed requests (#1687, arun0009)
+* Middleware: Add ErrJWTInvalid for JWT middleware (#1627, juanbelieni)
+* Middleware: Add SameSite mode for CSRF cookies to support iframes (#1524, pr0head)
+
+**Fixes**
+
+* Fix handling of special trailing slash case for partial prefix (#1741, stffabi)
+* Fix handling of static routes with trailing slash (#1747)
+* Fix Static files route not working (#1671, pwli0755, lammel)
+* Fix use of caret(^) in regex for rewrite middleware (#1588, chotow)
+* Fix Echo#Reverse for Any type routes (#1695, pafuent)
+* Fix Router#Find panic with infinite loop (#1661, pafuent)
+* Fix Router#Find panic fails on Param paths (#1659, pafuent)
+* Fix DefaultHTTPErrorHandler with Debug=true (#1477, lammel)
+* Fix incorrect CORS headers (#1669, ulasakdeniz)
+* Fix proxy middleware rewritePath to use url with updated tests (#1630, arun0009)
+* Fix rewritePath for proxy middleware to use escaped path in (#1628, arun0009)
+* Remove unless defer (#1656, imxyb)
+
+**General**
+
+* New maintainers for Echo: Roland Lammel (@lammel) and Pablo Andres Fuente (@pafuent)
+* Add GitHub action to compare benchmarks (#1702, pafuent)
+* Binding query/path params and form fields to struct only works for explicit tags (#1729,#1734, aldas)
+* Add support for Go 1.15 in CI (#1683, asahasrabuddhe)
+* Add test for request id to remain unchanged if provided (#1719, iambenkay)
+* Refactor echo instance listener access and startup to speed up testing (#1735, aldas)
+* Refactor and improve various tests for binding and routing
+* Run test workflow only for relevant changes (#1637, #1636, pofl)
+* Update .travis.yml (#1662, santosh653)
+* Update README.md with an recents framework benchmark (#1679, pafuent)
+
+This release was made possible by **over 100 commits** from more than **20 contributors**:
+asahasrabuddhe, aldas, AndrewKlotz, arun0009, chotow, curvegrid, iambenkay, imxyb,
+juanbelieni, lammel, little-cui, lnenad, pafuent, pofl, pr0head, pwli, RashadAnsari,
+rkfg, santosh653, segfiner, stffabi, ulasakdeniz
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 000000000..decbf0792
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,99 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## About This Project
+
+Echo is a high performance, minimalist Go web framework. This is the main repository for Echo v4, which is available as a Go module at `github.com/labstack/echo/v4`.
+
+## Development Commands
+
+The project uses a Makefile for common development tasks:
+
+- `make check` - Run linting, vetting, and race condition tests (default target)
+- `make init` - Install required linting tools (golint, staticcheck)
+- `make lint` - Run staticcheck and golint
+- `make vet` - Run go vet
+- `make test` - Run short tests
+- `make race` - Run tests with race detector
+- `make benchmark` - Run benchmarks
+
+Example commands for development:
+```bash
+# Setup development environment
+make init
+
+# Run all checks (lint, vet, race)
+make check
+
+# Run specific tests
+go test ./middleware/...
+go test -race ./...
+
+# Run benchmarks
+make benchmark
+```
+
+## Code Architecture
+
+### Core Components
+
+**Echo Instance (`echo.go`)**
+- The `Echo` struct is the top-level framework instance
+- Contains router, middleware stacks, and server configuration
+- Not goroutine-safe for mutations after server start
+
+**Context (`context.go`)**
+- The `Context` interface represents HTTP request/response context
+- Provides methods for request/response handling, path parameters, data binding
+- Core abstraction for request processing
+
+**Router (`router.go`)**
+- Radix tree-based HTTP router with smart route prioritization
+- Supports static routes, parameterized routes (`/users/:id`), and wildcard routes (`/static/*`)
+- Each HTTP method has its own routing tree
+
+**Middleware (`middleware/`)**
+- Extensive middleware system with 50+ built-in middlewares
+- Middleware can be applied at Echo, Group, or individual route level
+- Common middleware: Logger, Recover, CORS, JWT, Rate Limiting, etc.
+
+### Key Patterns
+
+**Middleware Chain**
+- Pre-middleware runs before routing
+- Regular middleware runs after routing but before handlers
+- Middleware functions have signature `func(next echo.HandlerFunc) echo.HandlerFunc`
+
+**Route Groups**
+- Routes can be grouped with common prefixes and middleware
+- Groups support nested sub-groups
+- Defined in `group.go`
+
+**Data Binding**
+- Automatic binding of request data (JSON, XML, form) to Go structs
+- Implemented in `binder.go` with support for custom binders
+
+**Error Handling**
+- Centralized error handling via `HTTPErrorHandler`
+- Automatic panic recovery with stack traces
+
+## File Organization
+
+- Root directory: Core Echo functionality (echo.go, context.go, router.go, etc.)
+- `middleware/`: All built-in middleware implementations
+- `_test/`: Test fixtures and utilities
+- `_fixture/`: Test data files
+
+## Code Style
+
+- Go code uses tabs for indentation (per .editorconfig)
+- Follows standard Go conventions and formatting
+- Uses gofmt, golint, and staticcheck for code quality
+
+## Testing
+
+- Standard Go testing with `testing` package
+- Tests include unit tests, integration tests, and benchmarks
+- Race condition testing is required (`make race`)
+- Test files follow `*_test.go` naming convention
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
index b5b006b4e..2f18411bd 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
The MIT License (MIT)
-Copyright (c) 2017 LabStack
+Copyright (c) 2022 LabStack
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
diff --git a/Makefile b/Makefile
index dfcb6c02b..bd075bbae 100644
--- a/Makefile
+++ b/Makefile
@@ -1,3 +1,32 @@
-tag:
- @git tag `grep -P '^\tversion = ' echo.go|cut -f2 -d'"'`
- @git tag|grep -v ^v
+PKG := "github.com/labstack/echo"
+PKG_LIST := $(shell go list ${PKG}/...)
+
+.DEFAULT_GOAL := check
+check: lint vet race ## Check project
+
+init:
+ @go install golang.org/x/lint/golint@latest
+ @go install honnef.co/go/tools/cmd/staticcheck@latest
+
+lint: ## Lint the files
+ @staticcheck ${PKG_LIST}
+ @golint -set_exit_status ${PKG_LIST}
+
+vet: ## Vet the files
+ @go vet ${PKG_LIST}
+
+test: ## Run tests
+ @go test -short ${PKG_LIST}
+
+race: ## Run tests with data race detector
+ @go test -race ${PKG_LIST}
+
+benchmark: ## Run benchmarks
+ @go test -run="-" -benchmem -bench=".*" ${PKG_LIST}
+
+help: ## Display this help screen
+ @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
+
+goversion ?= "1.25"
+test_version: ## Run tests inside Docker with given version (defaults to 1.25 oldest supported). Example: make test_version goversion=1.25
+ @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check"
diff --git a/README.md b/README.md
index 0da031225..ca6dfbf5d 100644
--- a/README.md
+++ b/README.md
@@ -1,30 +1,24 @@
-
-
[](https://sourcegraph.com/github.com/labstack/echo?badge)
-[](http://godoc.org/github.com/labstack/echo)
+[](https://pkg.go.dev/github.com/labstack/echo/v4)
[](https://goreportcard.com/report/github.com/labstack/echo)
-[](https://travis-ci.org/labstack/echo)
+[](https://github.com/labstack/echo/actions)
[](https://codecov.io/gh/labstack/echo)
-[](https://gitter.im/labstack/echo)
-[](https://forum.labstack.com)
+[](https://github.com/labstack/echo/discussions)
[](https://twitter.com/labstack)
[](https://raw.githubusercontent.com/labstack/echo/master/LICENSE)
-## Supported Go versions
+## Echo
-As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules).
-Therefore a Go version capable of understanding /vN suffixed imports is required:
+High performance, extensible, minimalist Go web framework.
-- 1.9.7+
-- 1.10.3+
-- 1.11+
+* [Official website](https://echo.labstack.com)
+* [Quick start](https://echo.labstack.com/docs/quick-start)
+* [Middlewares](https://echo.labstack.com/docs/category/middleware)
-Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended
-way of using Echo going forward.
+Help and questions: [Github Discussions](https://github.com/labstack/echo/discussions)
-For older versions, please use the latest v3 tag.
-## Feature Overview
+### Feature Overview
- Optimized HTTP router which smartly prioritize routes
- Build robust and scalable RESTful APIs
@@ -40,25 +34,48 @@ For older versions, please use the latest v3 tag.
- Automatic TLS via Let’s Encrypt
- HTTP/2 support
-## Benchmarks
+## Sponsors
-Date: 2018/03/15
-Source: https://github.com/vishr/web-framework-benchmark
-Lower is better!
+
+
-
+Click [here](https://github.com/sponsors/labstack) for more information on sponsorship.
## [Guide](https://echo.labstack.com/guide)
+### Supported Echo versions
+
+- Latest major version of Echo is `v5` as of 2026-01-18.
+ - Until 2026-03-31, any critical issues requiring breaking API changes will be addressed, even if this violates semantic versioning.
+ - See [API_CHANGES_V5.md](./API_CHANGES_V5.md) for public API changes between `v4` and `v5`, notes on upgrading.
+ - If you are using Echo in a production environment, it is recommended to wait until after 2026-03-31 before upgrading.
+- Echo `v4` is supported with **security*** updates and **bug** fixes until **2026-12-31**
+
+
+### Installation
+
+```sh
+// go get github.com/labstack/echo/{version}
+go get github.com/labstack/echo/v5
+```
+Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions.
+
+
### Example
```go
package main
import (
+ "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v5/middleware"
+ "log/slog"
"net/http"
- "github.com/labstack/echo/v4"
- "github.com/labstack/echo/v4/middleware"
)
func main() {
@@ -66,26 +83,50 @@ func main() {
e := echo.New()
// Middleware
- e.Use(middleware.Logger())
- e.Use(middleware.Recover())
+ e.Use(middleware.RequestLogger()) // use the RequestLogger middleware with slog logger
+ e.Use(middleware.Recover()) // recover panics as errors for proper error handling
// Routes
e.GET("/", hello)
// Start server
- e.Logger.Fatal(e.Start(":1323"))
+ if err := e.Start(":8080"); err != nil {
+ slog.Error("failed to start server", "error", err)
+ }
}
// Handler
-func hello(c echo.Context) error {
+func hello(c *echo.Context) error {
return c.String(http.StatusOK, "Hello, World!")
}
```
-## Help
+# Official middleware repositories
+
+Following list of middleware is maintained by Echo team.
+
+| Repository | Description |
+|------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware |
+| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](https://github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares |
+
+# Third-party middleware repositories
+
+Be careful when adding 3rd party middleware. Echo teams does not have time or manpower to guarantee safety and quality
+of middlewares in this list.
+
+| Repository | Description |
+|------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [oapi-codegen/oapi-codegen](https://github.com/oapi-codegen/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator |
+| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. |
+| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. |
+| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. |
+| [github.com/samber/slog-echo](https://github.com/samber/slog-echo) | Go [slog](https://pkg.go.dev/golang.org/x/exp/slog) logging library wrapper for Echo logger interface. |
+| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. |
+| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. |
+| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code |
-- [Forum](https://forum.labstack.com)
-- [Chat](https://gitter.im/labstack/echo)
+Please send a PR to add your own library here.
## Contribute
@@ -104,8 +145,11 @@ func hello(c echo.Context) error {
## Credits
-- [Vishal Rana](https://github.com/vishr) - Author
-- [Nitin Rana](https://github.com/nr17) - Consultant
+- [Vishal Rana](https://github.com/vishr) (Author)
+- [Nitin Rana](https://github.com/nr17) (Consultant)
+- [Roland Lammel](https://github.com/lammel) (Maintainer)
+- [Martti T.](https://github.com/aldas) (Maintainer)
+- [Pablo Andres Fuente](https://github.com/pafuent) (Maintainer)
- [Contributors](https://github.com/labstack/echo/graphs/contributors)
## License
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 000000000..efb618697
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,15 @@
+# Security Policy
+
+## Supported Versions
+
+| Version | Supported |
+|-----------|-------------------------------------|
+| 5.x.x | :white_check_mark: |
+| >= 4.15.x | :white_check_mark: until 2026.12.31 |
+| < 4.15 | :x: |
+
+## Reporting a Vulnerability
+
+https://github.com/labstack/echo/security/advisories/new
+
+or look for maintainers email(s) in commits and email them.
diff --git a/_fixture/_fixture/README.md b/_fixture/_fixture/README.md
new file mode 100644
index 000000000..21a785851
--- /dev/null
+++ b/_fixture/_fixture/README.md
@@ -0,0 +1 @@
+This directory is used for the static middleware test
\ No newline at end of file
diff --git a/_fixture/certs/README.md b/_fixture/certs/README.md
new file mode 100644
index 000000000..e27d4b139
--- /dev/null
+++ b/_fixture/certs/README.md
@@ -0,0 +1,13 @@
+To generate a valid certificate and private key use the following command:
+
+```bash
+# In OpenSSL ≥ 1.1.1
+openssl req -x509 -newkey rsa:4096 -sha256 -days 9999 -nodes \
+ -keyout key.pem -out cert.pem -subj "/CN=localhost" \
+ -addext "subjectAltName=DNS:localhost,IP:127.0.0.1,IP:::1"
+```
+
+To check a certificate use the following command:
+```bash
+openssl x509 -in cert.pem -text
+```
diff --git a/_fixture/certs/cert.pem b/_fixture/certs/cert.pem
index c58f13fa6..d88cf3fec 100644
--- a/_fixture/certs/cert.pem
+++ b/_fixture/certs/cert.pem
@@ -1,18 +1,30 @@
-----BEGIN CERTIFICATE-----
-MIIC+TCCAeGgAwIBAgIQe/dw9alKTWAPhsHoLdkn+TANBgkqhkiG9w0BAQsFADAS
-MRAwDgYDVQQKEwdBY21lIENvMB4XDTE2MDkyNTAwNDcxN1oXDTE3MDkyNTAwNDcx
-N1owEjEQMA4GA1UEChMHQWNtZSBDbzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC
-AQoCggEBAL8WwhLGbK8HkiEDKV0JbjtWp3/EWKhKFW3YtKtPfPOgoZejdNn9VE0B
-IlQ4rwa1wmsM9NDKC0m60oiNOYeyugx9PoFI3RXzuKVX2x7E5LTW0sv0LC9PCggZ
-MZTih1AiYtwJIZl+aK6s4dTb/PUOLDdcRTZTF2egkdAicbUlQT4Kn+A3jHiE+ATC
-h3MlV2BHarhAhWb0FrOg2bEtFrMyFDaLbHI7xbj+vB9CkGB9L5tObP2M9lQCxH8d
-ElWkJjxg7vdkhJ5+sWNaY80utNipUdVO845tIERwRXRRviFYpOcuNfnJYC9kwRjv
-CRanh3epWhG0cFQVV5d45sHf6t5F+jsCAwEAAaNLMEkwDgYDVR0PAQH/BAQDAgWg
-MBMGA1UdJQQMMAoGCCsGAQUFBwMBMAwGA1UdEwEB/wQCMAAwFAYDVR0RBA0wC4IJ
-bG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IBAQAdd3ZW6R4cImmxIzfoz7Ttq862
-oOiyzFnisCxgNdA78epit49zg0CgF7q9guTEArXJLI+/qnjPPObPOlTlsEyomb2F
-UOS+2hn/ZyU5/tUxhkeOBYqdEaryk6zF6vPLUJ5IphJgOg00uIQGL0UvupBLEyIG
-Rsa/lKEtW5Z9PbIi9GeVn51U+9VMCYft/T7SDziKl7OcE/qoVh1G0/tTRkAqOqpZ
-bzc8ssEhJVNZ/DO+uYHNYf/waB6NjfXQuTegU/SyxnawvQ4oBHIzyuWplGCcTlfT
-IXsOQdJo2xuu8807d+rO1FpN8yWi5OF/0sif0RrocSskLAIL/PI1qfWuuPck
+MIIFODCCAyCgAwIBAgIUaTvDluaMf+VJgYHQ0HFTS3yuCHYwDQYJKoZIhvcNAQEL
+BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIxMDIyNzIxMzQ0MVoXDTQ4MDcx
+NDIxMzQ0MVowFDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEF
+AAOCAg8AMIICCgKCAgEAnqyyAAnWFH2TH7Epj5yfZxYrBvizydZe1Wo/1WpGR2IK
+QT+qIul5sEKX/ERqEOXsawSrL3fw9cuSM8Z2vD/57ZZdoSR7XIdVaMDEQenJ968a
+HObu4D27uBQwIwrM5ELgnd+fC4gis64nIu+2GSfHumZXi7lLW7DbNm8oWkMqI6tY
+2s2wx2hwGYNVJrwSn4WGnkzhQ5U5mkcsLELMx7GR0Qnv6P7sNGZVeqMU7awkcSpR
+crKR1OUP7XCJkEq83WLHSx50+QZv7LiyDmGnujHevRbdSHlcFfHZtaufYat+qICe
+S3XADwRQe/0VSsmja6u3DAHy7VmL8PNisAdkopQZrhiI9OvGrpGZffs9zn+s/jeX
+N1bqVDihCMiEjqXMlHx2oj3AXrZTFxb7y7Ap9C07nf70lpxQWW9SjMYRF98JBiHF
+eJbQkNVkmz6T8ielQbX0l46F2SGK98oyFCGNIAZBUdj5CcS1E6w/lk4t58/em0k7
+3wFC5qg0g0wfIbNSmxljBNxnaBYUqyaaAJJhpaEoOebm4RYV58hQ0FbMfpnLnSh4
+dYStsk6i1PumWoa7D45DTtxF3kH7TB3YOB5aWaNGAPQC1m4Qcd23YB5Rd/ABirSp
+ux6/cFGosjSfJ/G+G0RhNUpmcbDJvFSOhD2WCuieVhCTAzp+VPIA9bSqD+InlT0C
+AwEAAaOBgTB/MB0GA1UdDgQWBBQZyM//SvzYKokQZI/0MVGb6PkH+zAfBgNVHSME
+GDAWgBQZyM//SvzYKokQZI/0MVGb6PkH+zAPBgNVHRMBAf8EBTADAQH/MCwGA1Ud
+EQQlMCOCCWxvY2FsaG9zdIcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG
+9w0BAQsFAAOCAgEAKGAJQmQ/KLw8iMb5QsyxxAonVjJ1eDAhNM3GWdHpM0/GFamO
+vVtATLQQldwDiZJvrsCQPEc8ctZ2Utvg/StLQ3+rZpsvt0+gcUlLJK61qguwYqb2
++T7VK5s7V/OyI/tsuboOW50Pka9vQHV+Z0aM06Yu+HNDAq/UTpEOb/3MQvZd6Ooy
+PTpZtFb/+5jIQa1dIsfFWmpBxF0+wUd9GEkX3j7nekwoZfJ8Ze4GWYERZbOFpDAQ
+rIHdthH5VJztnpQJmaKqzgIOF+Rurwlp5ecSC33xNNjDaYtuf/fiWnoKGhHVSBhT
+61+0yxn3rTgh/Dsm95xY00rSX6lmcvI+kRNTUc8GGPz0ajBH6xyY7bNhfMjmnSW/
+C/XTEDbTAhT7ndWC5vvzp7ZU0TvN+WY6A0f2kxSnnrEk6QRUvRtKkjAkmAFz8exi
+ttBBW0I3E5HNIC5CYRimq/9z+3clM/P1KbNblwuC65bL+PZ+nzFnn5hFaK9eLPol
+OwZQXv7IvAw8GfgLTrEUT7eBCQwe1IqesA7NTxF1BVwmNUb2XamvQZ7ly67QybRw
+0uJq80XjpVjBWYTTQy1dsnC2OTKdqGsV9TVIDR+UGfIG9cxL70pEbiSH2AX+IDCy
+i3kNIvpXgBliAyOjW6Hj1fv6dNfAat/hqEfnquWkfvcs3HNrG/InwpwNAUs=
-----END CERTIFICATE-----
diff --git a/_fixture/certs/key.pem b/_fixture/certs/key.pem
index 9c75e7ca8..0276c224e 100644
--- a/_fixture/certs/key.pem
+++ b/_fixture/certs/key.pem
@@ -1,27 +1,52 @@
------BEGIN RSA PRIVATE KEY-----
-MIIEpAIBAAKCAQEAvxbCEsZsrweSIQMpXQluO1anf8RYqEoVbdi0q09886Chl6N0
-2f1UTQEiVDivBrXCawz00MoLSbrSiI05h7K6DH0+gUjdFfO4pVfbHsTktNbSy/Qs
-L08KCBkxlOKHUCJi3AkhmX5orqzh1Nv89Q4sN1xFNlMXZ6CR0CJxtSVBPgqf4DeM
-eIT4BMKHcyVXYEdquECFZvQWs6DZsS0WszIUNotscjvFuP68H0KQYH0vm05s/Yz2
-VALEfx0SVaQmPGDu92SEnn6xY1pjzS602KlR1U7zjm0gRHBFdFG+IVik5y41+clg
-L2TBGO8JFqeHd6laEbRwVBVXl3jmwd/q3kX6OwIDAQABAoIBAQCR69EcAUZxinh+
-mSl3EIKK8atLGCcTrC8dCQU+ZJ7odFuxrnLHHHrJqvoKEpclqprioKw63G8uSGoJ
-OL8b7tHAQ8v9ciTSZKE2Mhb0MirsJbgnYzhykAr7EDIanbny6a9Qk/CChFNwQDjc
-EXnjsIT3aZC44U7YJXfz1rm6OM7Pjn6z8H4vYGRDOsYkhXvPfnPW8C2LFJVr9nvE
-0gIAOVoGejEJrsJVK3Uj/nPcqSQYXmwEmtjtzOw7u6yp1b2VZEK7tR47HwJt6ltG
-Z9zhpwhpvdOuXNMqMOYRf9bLBWnSqIlTHOO0UlAnyRCY1HxluZB7ZSg9VnoJDrD7
-w+JqAGnBAoGBAO5qyIzjldwR004YjepmZfuX3PnGLZhzhmTTC7Pl9gqv1TvxfxvD
-6yBFL2GrN1IcnrX9Qk2xncUAbpM989MF+EC7I4++1t1I6akUKFEDkfvQwQjCXfPS
-Jv2rkwIVSkt8F0X/tOb13OeIiHuFVI/Bb9VoJSP/k4DfPV+/HnwBxvzLAoGBAM0u
-b/rYfm5rb20/PKClUs154s0eKSokVogqiJkf+5qLsV+TD50JVZBVw8s4XM79iwQI
-PyGY9nI1AvqG7yIzxSy5/Qk1+ZVdVYpmWIO5PnJ8TVraDVhCQ3fVz1uWtcyaqPVr
-3QzdyvsEgFUGFItmRdhSvA8RGrpVCHTBzrDj3jpRAoGBAKNaSLS3jkstb3D3w+yR
-YliisYX1cfIdXTyhmUgWTKD/3oLmsSdt8iC3JoKt1AaPk3Kv5ojjJG0BIcIC1ZeF
-ZJW9Yt0vbXpKZcYyCHmRj6lQW6JLwiG3oH133A62VaQojq2oSONiG4wL8S9oqAqj
-B6PZanEiwIaw7hU3FoTylstHAoGAFYvE0pCdZjb98njrgusZcN5VxLhgFj7On2no
-AjxrjWUR8TleMF1kkM2Qy+xVQp85U+kRyBNp/cA3WduFjQ/mqrW1LpxuYxL0Ap6Q
-uPRg7GDFNr8jG5uJvjHDnpiK6rtq9qqnAczgnc9xMnx699B7kSXO/b4MEnkPdENN
-0yF6mqECgYA88UELxbhqMSdG24DX0zHXvkXLIml2JNVb54glFByIIem+acff9oG9
-X5GajlBroPoKk7FgA9ouqcQMH66UnFi6qh07l0J2xb0aXP8yzLAGauVGTTNIQCR4
-VpqyDpjlc1ZqfZWOrvwSrUH1mEkxbeVvQsOUja2Jvu+lc3Zo099ILw==
------END RSA PRIVATE KEY-----
+-----BEGIN PRIVATE KEY-----
+MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQCerLIACdYUfZMf
+sSmPnJ9nFisG+LPJ1l7Vaj/VakZHYgpBP6oi6XmwQpf8RGoQ5exrBKsvd/D1y5Iz
+xna8P/ntll2hJHtch1VowMRB6cn3rxoc5u7gPbu4FDAjCszkQuCd358LiCKzrici
+77YZJ8e6ZleLuUtbsNs2byhaQyojq1jazbDHaHAZg1UmvBKfhYaeTOFDlTmaRyws
+QszHsZHRCe/o/uw0ZlV6oxTtrCRxKlFyspHU5Q/tcImQSrzdYsdLHnT5Bm/suLIO
+Yae6Md69Ft1IeVwV8dm1q59hq36ogJ5LdcAPBFB7/RVKyaNrq7cMAfLtWYvw82Kw
+B2SilBmuGIj068aukZl9+z3Of6z+N5c3VupUOKEIyISOpcyUfHaiPcBetlMXFvvL
+sCn0LTud/vSWnFBZb1KMxhEX3wkGIcV4ltCQ1WSbPpPyJ6VBtfSXjoXZIYr3yjIU
+IY0gBkFR2PkJxLUTrD+WTi3nz96bSTvfAULmqDSDTB8hs1KbGWME3GdoFhSrJpoA
+kmGloSg55ubhFhXnyFDQVsx+mcudKHh1hK2yTqLU+6ZahrsPjkNO3EXeQftMHdg4
+HlpZo0YA9ALWbhBx3bdgHlF38AGKtKm7Hr9wUaiyNJ8n8b4bRGE1SmZxsMm8VI6E
+PZYK6J5WEJMDOn5U8gD1tKoP4ieVPQIDAQABAoICAEHF2CsH6MOpofi7GT08cR7s
+I33KTcxWngzc9ATk/qjMTO/rEf1Sxmx3zkR1n3nNtQhPcR5GG43nin0HwWQbKOCB
+OeJ4GuKp/o9jiHbCEEQpQyvD1jUBofSV+bYs3e2ogy8t6OGA1tGgWPy0XMlkoff0
+QEnczw3864FO5m0z9h2/Ax//r02ZTw5kUEG0KAwT709jEuVO0AfRhM/8CKKmSola
+EyaDtSmrWbdyLlSuzJRUNFrVBno3UTjdM0iqkks6jN3ojBhFwNNhY/1uIXafAXNk
+LOnD1JYMIHCb6X809VWnqvYgozIWWb5rlA3iM2mITmId1LLqMYX5fWj2R5LUzSek
+H+XG+F9FIouTaL1ACoXr0zyeY5N5YJdyXYa1tThdW+axX9ZrnPgeiQrmxzKPIyb7
+LLlVtNBQUg/t5tX80KyYjkNUu4j3oq/uBYPi0m//ovwMyi9bSbbyPT+cDXuXX5Bc
+oY7wyn3evXX0c1R7vdJLZLkLu+ctVex/9hvMjeW/mMasDjLnqY7pF3Skct1SX5N2
+U8YVU9bGvFpLEwM9lmi/T7bcv+zbmGPlfTsZiFrCsixPLn7sX7y5M4L8au8O0jh0
+nHm/8rWVg1Qw0Hobg3tA8FjeMa8Sr2fYmkNLVKFzhuJLxknTJLaUbX5CymNqWP4H
+OctvfSY0nSZ1eQpBkQaJAoIBAQDTb/NhYCfaJBLXHVMy/VYd7kWGZ+I87artcE/l
+8u0pJ8XOP4kp0otFIumpHUFodysAeP6HrI79MuJB40fy91HzWZC+NrPufFFFuZ0z
+Ld1o3Y5nAeoZmMlf1F12Oe3OQZy7nm9eNNkfeoVtKqDv4FhAqk+aoMor86HscKsR
+C6HlZFdGc7kX0ylrQAXPq9KLhcvUU9oAUpbqTbhYK83IebRJgFDG45HkVo9SUHpF
+dmCFSb91eZpRGpdfNLCuLiSu52TebayaUCnceeAt8SyeiChJ/TwWmRRDJS0QUv6h
+s3Wdp+cx9ANoujA4XzAs8Fld5IZ4bcG5jjwD62/tJyWrCC5DAoIBAQDAHfHjrYCK
+GHBrMj+MA7cK7fCJUn/iJLSLGgo2ANYF5oq9gaCwHCtKIyB9DN/KiY0JpJ6PWg+Q
+9Difq23YXiJjNEBS5EFTu9UwWAr1RhSAegrfHxm0sDbcAx31NtDYvBsADCWQYmzc
+KPfBshf5K4g/VCIj2VzC2CE6kNtdhqLU6AV2Pi1Tl1S82xWoAjHy91tDmlFQNWCj
+B2ZnZ7tY9zuwDfeBBOVCPHICgl5Q4PrY1KEWEXiNxgbtkNmOPAsY9WSqgOsP9pWK
+J924gdCCvovINzZtgRisxKth6Fkhra+VCsheg9SWvgR09Deo6CCoSwYxOSb0cjh2
+oyX5Rb1kJ7Z/AoIBAQCX2iNVoBV/GcFeNXV3fXLH9ESCj0FwuNC1zp/TanDhyerK
+gd8k5k2Xzcc66gP73vpHUJ6dGlVni4/r+ivGV9HHkF/f/LGlaiuEhBZel2YY1mZb
+nIhg8dZOuNqW+mvMYlsKdHNPmW0GqpwBF0iWfu1jI+4gA7Kvdj6o7RIvH8eaVEJK
+GvqoHcP1fvmteJ2yDtmhGMfMy4QPqtnmmS8l+CJ/V2SsMuyorXIpkBsAoFAZ6ilT
+WY53CT4F5nWt4v39j7pl9SatfT1TV0SmOjvtb6Rf3zu0jyR6RMzkmHa/839ZRylI
+OxPntzDCi7qxy7yjLmlVPJ6RgZGgzwqHrEHlX+65AoIBAQCEzu6d3x5B2N02LZli
+eFr8MjqbI64GLiulEY5HgNJzZ8k3cjocJI0Ehj36VIEMaYRXSzbVkIO8SCgwsPiR
+n5mUDNX+t441jV62Odbxcc3Qdw226rABieOSupDmKEu92GOt57e8FV5939BOVYhf
+FunsJYQoViXbCEAIVYVgJSfBmNfVwuvgonfQyn8xErtm4/pyRGa71PqGGSKAj2Qi
+/16CuVUFGtZFsLV76JW8wZqHdI4bTF6TW3cEmaLbwcRGL7W0bMSS13rO8/pBh3QW
+PhUxhoGYt6rQHHEBkPa04nXDyZ10QRwgTSGVnBIyMK4KyTpxorm8OI2x7dzdcomX
+iCCPAoIBAETwfr2JKPb/AzrKhhbZgU+sLVn3WH/nb68VheNEmGOzsqXaSHCR2NOq
+/ow7bawjc8yUIhBRzokR4F/7jGolOmfdq0MYFb6/YokssKfv1ugxBhmvOxpZ6F6E
+cERJ8Ex/ffQU053gLR/0ammddVuS1GR5I/jEdP0lJVh0xapoZNUlT5dWYCgo20hY
+ZAmKpU+veyUn+5Li0pmm959vnLK5LJzEA5mpz3w1QPPtVwQs05dwmEV3CRAcCeeh
+8sXp49WNCSW4I3BxuTZzRV845SGIFhZwgVV42PTp2LPKl2p6E7Bk8xpUCCvBpALp
+QmA5yIMx+u2Jpr7fUsXEXEPTEhvjff0=
+-----END PRIVATE KEY-----
diff --git a/_fixture/dist/private.txt b/_fixture/dist/private.txt
new file mode 100644
index 000000000..0f9d2435b
--- /dev/null
+++ b/_fixture/dist/private.txt
@@ -0,0 +1 @@
+private file
diff --git a/_fixture/dist/public/assets/readme.md b/_fixture/dist/public/assets/readme.md
new file mode 100644
index 000000000..50590f554
--- /dev/null
+++ b/_fixture/dist/public/assets/readme.md
@@ -0,0 +1 @@
+readme in assets
diff --git a/_fixture/dist/public/assets/subfolder/subfolder.md b/_fixture/dist/public/assets/subfolder/subfolder.md
new file mode 100644
index 000000000..74c928b2f
--- /dev/null
+++ b/_fixture/dist/public/assets/subfolder/subfolder.md
@@ -0,0 +1 @@
+file inside subfolder
diff --git a/_fixture/dist/public/index.html b/_fixture/dist/public/index.html
new file mode 100644
index 000000000..df6d9015a
--- /dev/null
+++ b/_fixture/dist/public/index.html
@@ -0,0 +1 @@
+Hello from index
diff --git a/_fixture/dist/public/test.txt b/_fixture/dist/public/test.txt
new file mode 100644
index 000000000..dd937160d
--- /dev/null
+++ b/_fixture/dist/public/test.txt
@@ -0,0 +1 @@
+test.txt contents
diff --git a/bind.go b/bind.go
index c8c88bb20..050e8973b 100644
--- a/bind.go
+++ b/bind.go
@@ -1,137 +1,237 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package echo
import (
"encoding"
- "encoding/json"
"encoding/xml"
"errors"
- "fmt"
+ "mime/multipart"
"net/http"
"reflect"
"strconv"
"strings"
+ "time"
)
-type (
- // Binder is the interface that wraps the Bind method.
- Binder interface {
- Bind(i interface{}, c Context) error
- }
+// Binder is the interface that wraps the Bind method.
+type Binder interface {
+ Bind(c *Context, target any) error
+}
- // DefaultBinder is the default implementation of the Binder interface.
- DefaultBinder struct{}
+// DefaultBinder is the default implementation of the Binder interface.
+type DefaultBinder struct{}
- // BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
- // Types that don't implement this, but do implement encoding.TextUnmarshaler
- // will use that interface instead.
- BindUnmarshaler interface {
- // UnmarshalParam decodes and assigns a value from an form or query param.
- UnmarshalParam(param string) error
- }
-)
+// BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
+// Types that don't implement this, but do implement encoding.TextUnmarshaler
+// will use that interface instead.
+type BindUnmarshaler interface {
+ // UnmarshalParam decodes and assigns a value from an form or query param.
+ UnmarshalParam(param string) error
+}
-// Bind implements the `Binder#Bind` function.
-func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) {
- req := c.Request()
+// bindMultipleUnmarshaler is used by binder to unmarshal multiple values from request at once to
+// type implementing this interface. For example request could have multiple query fields `?a=1&a=2&b=test` in that case
+// for `a` following slice `["1", "2"] will be passed to unmarshaller.
+type bindMultipleUnmarshaler interface {
+ UnmarshalParams(params []string) error
+}
- names := c.ParamNames()
- values := c.ParamValues()
+// BindPathValues binds path parameter values to bindable object
+func BindPathValues(c *Context, target any) error {
params := map[string][]string{}
- for i, name := range names {
- params[name] = []string{values[i]}
+ for _, param := range c.PathValues() {
+ params[param.Name] = []string{param.Value}
}
- if err := b.bindData(i, params, "param"); err != nil {
- return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
+ if err := bindData(target, params, "param", nil); err != nil {
+ return ErrBadRequest.Wrap(err)
}
- if err = b.bindData(i, c.QueryParams(), "query"); err != nil {
- return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
+ return nil
+}
+
+// BindQueryParams binds query params to bindable object
+func BindQueryParams(c *Context, target any) error {
+ if err := bindData(target, c.QueryParams(), "query", nil); err != nil {
+ return ErrBadRequest.Wrap(err)
}
+ return nil
+}
+
+// BindBody binds request body contents to bindable object
+// NB: then binding forms take note that this implementation uses standard library form parsing
+// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm
+// See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm
+// See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm
+func BindBody(c *Context, target any) (err error) {
+ req := c.Request()
if req.ContentLength == 0 {
return
}
- ctype := req.Header.Get(HeaderContentType)
- switch {
- case strings.HasPrefix(ctype, MIMEApplicationJSON):
- if err = json.NewDecoder(req.Body).Decode(i); err != nil {
- if ute, ok := err.(*json.UnmarshalTypeError); ok {
- return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err)
- } else if se, ok := err.(*json.SyntaxError); ok {
- return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err)
+
+ // mediatype is found like `mime.ParseMediaType()` does it
+ base, _, _ := strings.Cut(req.Header.Get(HeaderContentType), ";")
+ mediatype := strings.TrimSpace(base)
+
+ switch mediatype {
+ case MIMEApplicationJSON:
+ if err = c.Echo().JSONSerializer.Deserialize(c, target); err != nil {
+ var hErr *HTTPError
+ if errors.As(err, &hErr) {
+ return err
}
- return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
+ return ErrBadRequest.Wrap(err)
}
- case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML):
- if err = xml.NewDecoder(req.Body).Decode(i); err != nil {
- if ute, ok := err.(*xml.UnsupportedTypeError); ok {
- return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err)
- } else if se, ok := err.(*xml.SyntaxError); ok {
- return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err)
- }
- return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
+ case MIMEApplicationXML, MIMETextXML:
+ if err = xml.NewDecoder(req.Body).Decode(target); err != nil {
+ return ErrBadRequest.Wrap(err)
+ }
+ case MIMEApplicationForm:
+ params, err := c.FormValues()
+ if err != nil {
+ return ErrBadRequest.Wrap(err)
+ }
+ if err = bindData(target, params, "form", nil); err != nil {
+ return ErrBadRequest.Wrap(err)
}
- case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
- params, err := c.FormParams()
+ case MIMEMultipartForm:
+ params, err := c.MultipartForm()
if err != nil {
- return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
+ return ErrBadRequest.Wrap(err)
}
- if err = b.bindData(i, params, "form"); err != nil {
- return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
+ if err = bindData(target, params.Value, "form", params.File); err != nil {
+ return ErrBadRequest.Wrap(err)
}
default:
- return ErrUnsupportedMediaType
+ return &HTTPError{Code: http.StatusUnsupportedMediaType}
}
- return
+ return nil
+}
+
+// BindHeaders binds HTTP headers to a bindable object
+func BindHeaders(c *Context, target any) error {
+ if err := bindData(target, c.Request().Header, "header", nil); err != nil {
+ return ErrBadRequest.Wrap(err)
+ }
+ return nil
+}
+
+// Bind implements the `Binder#Bind` function.
+// Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous
+// step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathValues.
+func (b *DefaultBinder) Bind(c *Context, target any) error {
+ if err := BindPathValues(c, target); err != nil {
+ return err
+ }
+ // Only bind query parameters for GET/DELETE/HEAD to avoid unexpected behavior with destination struct binding from body.
+ // For example a request URL `&id=1&lang=en` with body `{"id":100,"lang":"de"}` would lead to precedence issues.
+ // The HTTP method check restores pre-v4.1.11 behavior to avoid these problems (see issue #1670)
+ method := c.Request().Method
+ if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead {
+ if err := BindQueryParams(c, target); err != nil {
+ return err
+ }
+ }
+ return BindBody(c, target)
}
-func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag string) error {
- if ptr == nil || len(data) == 0 {
+// bindData will bind data ONLY fields in destination struct that have EXPLICIT tag
+func bindData(destination any, data map[string][]string, tag string, dataFiles map[string][]*multipart.FileHeader) error {
+ if destination == nil || (len(data) == 0 && len(dataFiles) == 0) {
return nil
}
- typ := reflect.TypeOf(ptr).Elem()
- val := reflect.ValueOf(ptr).Elem()
+ hasFiles := len(dataFiles) > 0
+ typ := reflect.TypeOf(destination).Elem()
+ val := reflect.ValueOf(destination).Elem()
- // Map
- if typ.Kind() == reflect.Map {
+ // Support binding to limited Map destinations:
+ // - map[string][]string,
+ // - map[string]string <-- (binds first value from data slice)
+ // - map[string]any
+ // You are better off binding to struct but there are user who want this map feature. Source of data for these cases are:
+ // params,query,header,form as these sources produce string values, most of the time slice of strings, actually.
+ if typ.Kind() == reflect.Map && typ.Key().Kind() == reflect.String {
+ k := typ.Elem().Kind()
+ isElemInterface := k == reflect.Interface
+ isElemString := k == reflect.String
+ isElemSliceOfStrings := k == reflect.Slice && typ.Elem().Elem().Kind() == reflect.String
+ if !(isElemSliceOfStrings || isElemString || isElemInterface) {
+ return nil
+ }
+ if val.IsNil() {
+ val.Set(reflect.MakeMap(typ))
+ }
for k, v := range data {
- val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0]))
+ if isElemString {
+ val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0]))
+ } else if isElemInterface {
+ // To maintain backward compatibility, we always bind to the first string value
+ // and not the slice of strings when dealing with map[string]any{}
+ val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0]))
+ } else {
+ val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v))
+ }
}
return nil
}
// !struct
if typ.Kind() != reflect.Struct {
+ if tag == "param" || tag == "query" || tag == "header" {
+ // incompatible type, data is probably to be found in the body
+ return nil
+ }
return errors.New("binding element must be a struct")
}
- for i := 0; i < typ.NumField(); i++ {
+ for i := 0; i < typ.NumField(); i++ { // iterate over all destination fields
typeField := typ.Field(i)
structField := val.Field(i)
+ if typeField.Anonymous {
+ if structField.Kind() == reflect.Ptr {
+ structField = structField.Elem()
+ }
+ }
if !structField.CanSet() {
continue
}
structFieldKind := structField.Kind()
inputFieldName := typeField.Tag.Get(tag)
+ if typeField.Anonymous && structFieldKind == reflect.Struct && inputFieldName != "" {
+ // if anonymous struct with query/param/form tags, report an error
+ return errors.New("query/param/form tags are not allowed with anonymous struct field")
+ }
if inputFieldName == "" {
- inputFieldName = typeField.Name
- // If tag is nil, we inspect if the field is a struct.
- if _, ok := bindUnmarshaler(structField); !ok && structFieldKind == reflect.Struct {
- if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil {
+ // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contain fields with tags).
+ // structs that implement BindUnmarshaler are bound only when they have explicit tag
+ if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct {
+ if err := bindData(structField.Addr().Interface(), data, tag, dataFiles); err != nil {
return err
}
- continue
+ }
+ // does not have explicit tag and is not an ordinary struct - so move to next field
+ continue
+ }
+
+ if hasFiles {
+ if ok, err := isFieldMultipartFile(structField.Type()); err != nil {
+ return err
+ } else if ok {
+ if ok := setMultipartFileHeaderTypes(structField, inputFieldName, dataFiles); ok {
+ continue
+ }
}
}
inputValue, exists := data[inputFieldName]
if !exists {
- // Go json.Unmarshal supports case insensitive binding. However the
- // url params are bound case sensitive which is inconsistent. To
+ // Go json.Unmarshal supports case-insensitive binding. However the
+ // url params are bound case-sensitive which is inconsistent. To
// fix this we must check all of the map values in a
// case-insensitive search.
- inputFieldName = strings.ToLower(inputFieldName)
for k, v := range data {
- if strings.ToLower(k) == inputFieldName {
+ if strings.EqualFold(k, inputFieldName) {
inputValue = v
exists = true
break
@@ -143,27 +243,47 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag
continue
}
- // Call this first, in case we're dealing with an alias to an array type
- if ok, err := unmarshalField(typeField.Type.Kind(), inputValue[0], structField); ok {
+ // NOTE: algorithm here is not particularly sophisticated. It probably does not work with absurd types like `**[]*int`
+ // but it is smart enough to handle niche cases like `*int`,`*[]string`,`[]*int` .
+
+ // try unmarshalling first, in case we're dealing with an alias to an array type
+ if ok, err := unmarshalInputsToField(typeField.Type.Kind(), inputValue, structField); ok {
+ if err != nil {
+ return err
+ }
+ continue
+ }
+
+ formatTag := typeField.Tag.Get("format")
+ if ok, err := unmarshalInputToField(typeField.Type.Kind(), inputValue[0], structField, formatTag); ok {
if err != nil {
return err
}
continue
}
- numElems := len(inputValue)
- if structFieldKind == reflect.Slice && numElems > 0 {
+ // we could be dealing with pointer to slice `*[]string` so dereference it. There are weird OpenAPI generators
+ // that could create struct fields like that.
+ if structFieldKind == reflect.Pointer {
+ structFieldKind = structField.Elem().Kind()
+ structField = structField.Elem()
+ }
+
+ if structFieldKind == reflect.Slice {
sliceOf := structField.Type().Elem().Kind()
+ numElems := len(inputValue)
slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
for j := 0; j < numElems; j++ {
if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil {
return err
}
}
- val.Field(i).Set(slice)
- } else if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil {
- return err
+ structField.Set(slice)
+ continue
+ }
+ if err := setWithProperType(structFieldKind, inputValue[0], structField); err != nil {
+ return err
}
}
return nil
@@ -171,7 +291,8 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag
func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
// But also call it here, in case we're dealing with an array of BindUnmarshalers
- if ok, err := unmarshalField(valueKind, val, structField); ok {
+ // Note: format tag not available in this context, so empty string is passed
+ if ok, err := unmarshalInputToField(valueKind, val, structField, ""); ok {
return err
}
@@ -212,62 +333,53 @@ func setWithProperType(valueKind reflect.Kind, val string, structField reflect.V
return nil
}
-func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) {
- switch valueKind {
- case reflect.Ptr:
- return unmarshalFieldPtr(val, field)
- default:
- return unmarshalFieldNonPtr(val, field)
+func unmarshalInputsToField(valueKind reflect.Kind, values []string, field reflect.Value) (bool, error) {
+ if valueKind == reflect.Ptr {
+ if field.IsNil() {
+ field.Set(reflect.New(field.Type().Elem()))
+ }
+ field = field.Elem()
}
-}
-// bindUnmarshaler attempts to unmarshal a reflect.Value into a BindUnmarshaler
-func bindUnmarshaler(field reflect.Value) (BindUnmarshaler, bool) {
- ptr := reflect.New(field.Type())
- if ptr.CanInterface() {
- iface := ptr.Interface()
- if unmarshaler, ok := iface.(BindUnmarshaler); ok {
- return unmarshaler, ok
- }
+ fieldIValue := field.Addr().Interface()
+ unmarshaler, ok := fieldIValue.(bindMultipleUnmarshaler)
+ if !ok {
+ return false, nil
}
- return nil, false
+ return true, unmarshaler.UnmarshalParams(values)
}
-// textUnmarshaler attempts to unmarshal a reflect.Value into a TextUnmarshaler
-func textUnmarshaler(field reflect.Value) (encoding.TextUnmarshaler, bool) {
- ptr := reflect.New(field.Type())
- if ptr.CanInterface() {
- iface := ptr.Interface()
- if unmarshaler, ok := iface.(encoding.TextUnmarshaler); ok {
- return unmarshaler, ok
+func unmarshalInputToField(valueKind reflect.Kind, val string, field reflect.Value, formatTag string) (bool, error) {
+ if valueKind == reflect.Ptr {
+ if field.IsNil() {
+ field.Set(reflect.New(field.Type().Elem()))
}
+ field = field.Elem()
}
- return nil, false
-}
-func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) {
- if unmarshaler, ok := bindUnmarshaler(field); ok {
- err := unmarshaler.UnmarshalParam(value)
- field.Set(reflect.ValueOf(unmarshaler).Elem())
- return true, err
+ fieldIValue := field.Addr().Interface()
+ // Handle time.Time with custom format tag
+ if formatTag != "" {
+ if _, isTime := fieldIValue.(*time.Time); isTime {
+ t, err := time.Parse(formatTag, val)
+ if err != nil {
+ return true, err
+ }
+ field.Set(reflect.ValueOf(t))
+ return true, nil
+ }
}
- if unmarshaler, ok := textUnmarshaler(field); ok {
- err := unmarshaler.UnmarshalText([]byte(value))
- field.Set(reflect.ValueOf(unmarshaler).Elem())
- return true, err
+
+ switch unmarshaler := fieldIValue.(type) {
+ case BindUnmarshaler:
+ return true, unmarshaler.UnmarshalParam(val)
+ case encoding.TextUnmarshaler:
+ return true, unmarshaler.UnmarshalText([]byte(val))
}
return false, nil
}
-func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) {
- if field.IsNil() {
- // Initialize the pointer to a nil value
- field.Set(reflect.New(field.Type().Elem()))
- }
- return unmarshalFieldNonPtr(value, field.Elem())
-}
-
func setIntField(value string, bitSize int, field reflect.Value) error {
if value == "" {
value = "0"
@@ -311,3 +423,50 @@ func setFloatField(value string, bitSize int, field reflect.Value) error {
}
return err
}
+
+var (
+ // NOT supported by bind as you can NOT check easily empty struct being actual file or not
+ multipartFileHeaderType = reflect.TypeFor[multipart.FileHeader]()
+ // supported by bind as you can check by nil value if file existed or not
+ multipartFileHeaderPointerType = reflect.TypeFor[*multipart.FileHeader]()
+ multipartFileHeaderSliceType = reflect.TypeFor[[]multipart.FileHeader]()
+ multipartFileHeaderPointerSliceType = reflect.TypeFor[[]*multipart.FileHeader]()
+)
+
+func isFieldMultipartFile(field reflect.Type) (bool, error) {
+ switch field {
+ case multipartFileHeaderPointerType,
+ multipartFileHeaderSliceType,
+ multipartFileHeaderPointerSliceType:
+ return true, nil
+ case multipartFileHeaderType:
+ return true, errors.New("binding to multipart.FileHeader struct is not supported, use pointer to struct")
+ default:
+ return false, nil
+ }
+}
+
+func setMultipartFileHeaderTypes(structField reflect.Value, inputFieldName string, files map[string][]*multipart.FileHeader) bool {
+ fileHeaders := files[inputFieldName]
+ if len(fileHeaders) == 0 {
+ return false
+ }
+
+ result := true
+ switch structField.Type() {
+ case multipartFileHeaderPointerSliceType:
+ structField.Set(reflect.ValueOf(fileHeaders))
+ case multipartFileHeaderSliceType:
+ headers := make([]multipart.FileHeader, len(fileHeaders))
+ for i, fileHeader := range fileHeaders {
+ headers[i] = *fileHeader
+ }
+ structField.Set(reflect.ValueOf(headers))
+ case multipartFileHeaderPointerType:
+ structField.Set(reflect.ValueOf(fileHeaders[0]))
+ default:
+ result = false
+ }
+
+ return result
+}
diff --git a/bind_test.go b/bind_test.go
index 84ac8710e..1d5f8ca41 100644
--- a/bind_test.go
+++ b/bind_test.go
@@ -1,3 +1,6 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package echo
import (
@@ -5,10 +8,13 @@ import (
"encoding/json"
"encoding/xml"
"errors"
+ "fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
+ "net/http/httputil"
+ "net/url"
"reflect"
"strconv"
"strings"
@@ -18,51 +24,91 @@ import (
"github.com/stretchr/testify/assert"
)
-type (
- bindTestStruct struct {
- I int
- PtrI *int
- I8 int8
- PtrI8 *int8
- I16 int16
- PtrI16 *int16
- I32 int32
- PtrI32 *int32
- I64 int64
- PtrI64 *int64
- UI uint
- PtrUI *uint
- UI8 uint8
- PtrUI8 *uint8
- UI16 uint16
- PtrUI16 *uint16
- UI32 uint32
- PtrUI32 *uint32
- UI64 uint64
- PtrUI64 *uint64
- B bool
- PtrB *bool
- F32 float32
- PtrF32 *float32
- F64 float64
- PtrF64 *float64
- S string
- PtrS *string
- cantSet string
- DoesntExist string
- GoT time.Time
- GoTptr *time.Time
- T Timestamp
- Tptr *Timestamp
- SA StringArray
- }
- Timestamp time.Time
- TA []Timestamp
- StringArray []string
- Struct struct {
- Foo string
- }
-)
+type bindTestStruct struct {
+ T Timestamp
+ GoT time.Time
+ PtrI16 *int16
+ PtrUI *uint
+ Tptr *Timestamp
+ PtrF32 *float32
+ PtrB *bool
+ PtrI32 *int32
+ GoTptr *time.Time
+ PtrI64 *int64
+ PtrI *int
+ PtrI8 *int8
+ PtrF64 *float64
+ PtrUI8 *uint8
+ PtrUI64 *uint64
+ PtrUI16 *uint16
+ PtrS *string
+ PtrUI32 *uint32
+ S string
+ cantSet string
+ DoesntExist string
+ SA StringArray
+ F64 float64
+ I int
+ UI64 uint64
+ UI uint
+ I64 int64
+ F32 float32
+ UI32 uint32
+ I32 int32
+ UI16 uint16
+ I16 int16
+ B bool
+ UI8 uint8
+ I8 int8
+}
+
+type bindTestStructWithTags struct {
+ T Timestamp `json:"T" form:"T"`
+ GoT time.Time `json:"GoT" form:"GoT"`
+ PtrI16 *int16 `json:"PtrI16" form:"PtrI16"`
+ PtrUI *uint `json:"PtrUI" form:"PtrUI"`
+ Tptr *Timestamp `json:"Tptr" form:"Tptr"`
+ PtrF32 *float32 `json:"PtrF32" form:"PtrF32"`
+ PtrB *bool `json:"PtrB" form:"PtrB"`
+ PtrI32 *int32 `json:"PtrI32" form:"PtrI32"`
+ GoTptr *time.Time `json:"GoTptr" form:"GoTptr"`
+ PtrI64 *int64 `json:"PtrI64" form:"PtrI64"`
+ PtrI *int `json:"PtrI" form:"PtrI"`
+ PtrI8 *int8 `json:"PtrI8" form:"PtrI8"`
+ PtrF64 *float64 `json:"PtrF64" form:"PtrF64"`
+ PtrUI8 *uint8 `json:"PtrUI8" form:"PtrUI8"`
+ PtrUI64 *uint64 `json:"PtrUI64" form:"PtrUI64"`
+ PtrUI16 *uint16 `json:"PtrUI16" form:"PtrUI16"`
+ PtrS *string `json:"PtrS" form:"PtrS"`
+ PtrUI32 *uint32 `json:"PtrUI32" form:"PtrUI32"`
+ S string `json:"S" form:"S"`
+ cantSet string
+ DoesntExist string `json:"DoesntExist" form:"DoesntExist"`
+ SA StringArray `json:"SA" form:"SA"`
+ F64 float64 `json:"F64" form:"F64"`
+ I int `json:"I" form:"I"`
+ UI64 uint64 `json:"UI64" form:"UI64"`
+ UI uint `json:"UI" form:"UI"`
+ I64 int64 `json:"I64" form:"I64"`
+ F32 float32 `json:"F32" form:"F32"`
+ UI32 uint32 `json:"UI32" form:"UI32"`
+ I32 int32 `json:"I32" form:"I32"`
+ UI16 uint16 `json:"UI16" form:"UI16"`
+ I16 int16 `json:"I16" form:"I16"`
+ B bool `json:"B" form:"B"`
+ UI8 uint8 `json:"UI8" form:"UI8"`
+ I8 int8 `json:"I8" form:"I8"`
+}
+
+type Timestamp time.Time
+type TA []Timestamp
+type StringArray []string
+type Struct struct {
+ Foo string
+}
+type Bar struct {
+ Baz int `json:"baz" query:"baz"`
+}
func (t *Timestamp) UnmarshalParam(src string) error {
ts, err := time.Parse(time.RFC3339, src)
@@ -123,37 +169,71 @@ var values = map[string][]string{
"ST": {"bar"},
}
+// ptr return pointer to value. This is useful as `v := []*int8{&int8(1)}` will not compile
+func ptr[T any](value T) *T {
+ return &value
+}
+
+func TestToMultipleFields(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/?id=1&ID=2", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ type Root struct {
+ ID int64 `query:"id"`
+ Child2 struct {
+ ID int64
+ }
+ Child1 struct {
+ ID int64 `query:"id"`
+ }
+ }
+
+ u := new(Root)
+ err := c.Bind(u)
+ if assert.NoError(t, err) {
+ assert.Equal(t, int64(1), u.ID) // perfectly reasonable
+ assert.Equal(t, int64(1), u.Child1.ID) // untagged struct containing tagged field gets filled (by tag)
+ assert.Equal(t, int64(0), u.Child2.ID) // untagged struct containing untagged field should not be bind
+ }
+}
+
func TestBindJSON(t *testing.T) {
- assert := assert.New(t)
- testBindOkay(assert, strings.NewReader(userJSON), MIMEApplicationJSON)
- testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{})
- testBindError(assert, strings.NewReader(userJSONInvalidType), MIMEApplicationJSON, &json.UnmarshalTypeError{})
+ testBindOkay(t, strings.NewReader(userJSON), nil, MIMEApplicationJSON)
+ testBindOkay(t, strings.NewReader(userJSON), dummyQuery, MIMEApplicationJSON)
+ testBindArrayOkay(t, strings.NewReader(usersJSON), nil, MIMEApplicationJSON)
+ testBindArrayOkay(t, strings.NewReader(usersJSON), dummyQuery, MIMEApplicationJSON)
+ testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{})
+ testBindError(t, strings.NewReader(userJSONInvalidType), MIMEApplicationJSON, &json.UnmarshalTypeError{})
}
func TestBindXML(t *testing.T) {
- assert := assert.New(t)
-
- testBindOkay(assert, strings.NewReader(userXML), MIMEApplicationXML)
- testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationXML, errors.New(""))
- testBindError(assert, strings.NewReader(userXMLConvertNumberError), MIMEApplicationXML, &strconv.NumError{})
- testBindError(assert, strings.NewReader(userXMLUnsupportedTypeError), MIMEApplicationXML, &xml.SyntaxError{})
- testBindOkay(assert, strings.NewReader(userXML), MIMETextXML)
- testBindError(assert, strings.NewReader(invalidContent), MIMETextXML, errors.New(""))
- testBindError(assert, strings.NewReader(userXMLConvertNumberError), MIMETextXML, &strconv.NumError{})
- testBindError(assert, strings.NewReader(userXMLUnsupportedTypeError), MIMETextXML, &xml.SyntaxError{})
+ testBindOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML)
+ testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML)
+ testBindArrayOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML)
+ testBindArrayOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML)
+ testBindError(t, strings.NewReader(invalidContent), MIMEApplicationXML, errors.New(""))
+ testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMEApplicationXML, &strconv.NumError{})
+ testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMEApplicationXML, &xml.SyntaxError{})
+ testBindOkay(t, strings.NewReader(userXML), nil, MIMETextXML)
+ testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMETextXML)
+ testBindError(t, strings.NewReader(invalidContent), MIMETextXML, errors.New(""))
+ testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMETextXML, &strconv.NumError{})
+ testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMETextXML, &xml.SyntaxError{})
}
func TestBindForm(t *testing.T) {
- assert := assert.New(t)
- testBindOkay(assert, strings.NewReader(userForm), MIMEApplicationForm)
+ testBindOkay(t, strings.NewReader(userForm), nil, MIMEApplicationForm)
+ testBindOkay(t, strings.NewReader(userForm), dummyQuery, MIMEApplicationForm)
e := New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userForm))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(HeaderContentType, MIMEApplicationForm)
err := c.Bind(&[]struct{ Field string }{})
- assert.Error(err)
+ assert.Error(t, err)
}
func TestBindQueryParams(t *testing.T) {
@@ -195,40 +275,74 @@ func TestBindQueryParamsCaseSensitivePrioritized(t *testing.T) {
}
}
+func TestBindHeaderParam(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set("Name", "Jon Doe")
+ req.Header.Set("Id", "2")
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ u := new(user)
+ err := BindHeaders(c, u)
+ if assert.NoError(t, err) {
+ assert.Equal(t, 2, u.ID)
+ assert.Equal(t, "Jon Doe", u.Name)
+ }
+}
+
+func TestBindHeaderParamBadType(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set("Id", "salamander")
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ u := new(user)
+ err := BindHeaders(c, u)
+ assert.Error(t, err)
+
+ httpErr, ok := err.(*HTTPError)
+ if assert.True(t, ok) {
+ assert.Equal(t, http.StatusBadRequest, httpErr.Code)
+ }
+}
+
func TestBindUnmarshalParam(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
- T Timestamp `query:"ts"`
+ T Timestamp `query:"ts"`
+ ST Struct
+ StWithTag struct {
+ Foo string `query:"st"`
+ }
TA []Timestamp `query:"ta"`
SA StringArray `query:"sa"`
- ST Struct
}{}
err := c.Bind(&result)
ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC))
- assert := assert.New(t)
- if assert.NoError(err) {
+ if assert.NoError(t, err) {
// assert.Equal( Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T)
- assert.Equal(ts, result.T)
- assert.Equal(StringArray([]string{"one", "two", "three"}), result.SA)
- assert.Equal([]Timestamp{ts, ts}, result.TA)
- assert.Equal(Struct{"baz"}, result.ST)
+ assert.Equal(t, ts, result.T)
+ assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA)
+ assert.Equal(t, []Timestamp{ts, ts}, result.TA)
+ assert.Equal(t, Struct{""}, result.ST) // child struct does not have a field with matching tag
+ assert.Equal(t, "baz", result.StWithTag.Foo) // child struct has field with matching tag
}
}
func TestBindUnmarshalText(t *testing.T) {
e := New()
- req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
+ req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
- T time.Time `query:"ts"`
+ T time.Time `query:"ts"`
+ ST Struct
TA []time.Time `query:"ta"`
SA StringArray `query:"sa"`
- ST Struct
}{}
err := c.Bind(&result)
ts := time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)
@@ -237,7 +351,7 @@ func TestBindUnmarshalText(t *testing.T) {
assert.Equal(t, ts, result.T)
assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA)
assert.Equal(t, []time.Time{ts, ts}, result.TA)
- assert.Equal(t, Struct{"baz"}, result.ST)
+ assert.Equal(t, Struct{""}, result.ST) // field in child struct does not have tag
}
}
@@ -255,9 +369,49 @@ func TestBindUnmarshalParamPtr(t *testing.T) {
}
}
+func TestBindUnmarshalParamAnonymousFieldPtr(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/?baz=1", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ result := struct {
+ *Bar
+ }{&Bar{}}
+ err := c.Bind(&result)
+ if assert.NoError(t, err) {
+ assert.Equal(t, 1, result.Baz)
+ }
+}
+
+func TestBindUnmarshalParamAnonymousFieldPtrNil(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/?baz=1", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ result := struct {
+ *Bar
+ }{}
+ err := c.Bind(&result)
+ if assert.NoError(t, err) {
+ assert.Nil(t, result.Bar)
+ }
+}
+
+func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, `/?bar={"baz":100}&baz=1`, nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ result := struct {
+ *Bar `json:"bar" query:"bar"`
+ }{&Bar{}}
+ err := c.Bind(&result)
+ assert.Contains(t, err.Error(), "query/param/form tags are not allowed with anonymous struct field")
+}
+
func TestBindUnmarshalTextPtr(t *testing.T) {
e := New()
- req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil)
+ req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
@@ -270,37 +424,158 @@ func TestBindUnmarshalTextPtr(t *testing.T) {
}
func TestBindMultipartForm(t *testing.T) {
- body := new(bytes.Buffer)
- mw := multipart.NewWriter(body)
+ bodyBuffer := new(bytes.Buffer)
+ mw := multipart.NewWriter(bodyBuffer)
mw.WriteField("id", "1")
mw.WriteField("name", "Jon Snow")
mw.Close()
+ body := bodyBuffer.Bytes()
- assert := assert.New(t)
- testBindOkay(assert, body, mw.FormDataContentType())
+ testBindOkay(t, bytes.NewReader(body), nil, mw.FormDataContentType())
+ testBindOkay(t, bytes.NewReader(body), dummyQuery, mw.FormDataContentType())
}
func TestBindUnsupportedMediaType(t *testing.T) {
- assert := assert.New(t)
- testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{})
+ testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{})
+}
+
+func TestDefaultBinder_bindDataToMap(t *testing.T) {
+ exampleData := map[string][]string{
+ "multiple": {"1", "2"},
+ "single": {"3"},
+ }
+
+ t.Run("ok, bind to map[string]string", func(t *testing.T) {
+ dest := map[string]string{}
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t,
+ map[string]string{
+ "multiple": "1",
+ "single": "3",
+ },
+ dest,
+ )
+ })
+
+ t.Run("ok, bind to map[string]string with nil map", func(t *testing.T) {
+ var dest map[string]string
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t,
+ map[string]string{
+ "multiple": "1",
+ "single": "3",
+ },
+ dest,
+ )
+ })
+
+ t.Run("ok, bind to map[string][]string", func(t *testing.T) {
+ dest := map[string][]string{}
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t,
+ map[string][]string{
+ "multiple": {"1", "2"},
+ "single": {"3"},
+ },
+ dest,
+ )
+ })
+
+ t.Run("ok, bind to map[string][]string with nil map", func(t *testing.T) {
+ var dest map[string][]string
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t,
+ map[string][]string{
+ "multiple": {"1", "2"},
+ "single": {"3"},
+ },
+ dest,
+ )
+ })
+
+ t.Run("ok, bind to map[string]interface", func(t *testing.T) {
+ dest := map[string]any{}
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t,
+ map[string]any{
+ "multiple": "1",
+ "single": "3",
+ },
+ dest,
+ )
+ })
+
+ t.Run("ok, bind to map[string]interface with nil map", func(t *testing.T) {
+ var dest map[string]any
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t,
+ map[string]any{
+ "multiple": "1",
+ "single": "3",
+ },
+ dest,
+ )
+ })
+
+ t.Run("ok, bind to map[string]int skips", func(t *testing.T) {
+ dest := map[string]int{}
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t, map[string]int{}, dest)
+ })
+
+ t.Run("ok, bind to map[string]int skips with nil map", func(t *testing.T) {
+ var dest map[string]int
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t, map[string]int(nil), dest)
+ })
+
+ t.Run("ok, bind to map[string][]int skips", func(t *testing.T) {
+ dest := map[string][]int{}
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t, map[string][]int{}, dest)
+ })
+
+ t.Run("ok, bind to map[string][]int skips with nil map", func(t *testing.T) {
+ var dest map[string][]int
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t, map[string][]int(nil), dest)
+ })
}
func TestBindbindData(t *testing.T) {
- assert := assert.New(t)
ts := new(bindTestStruct)
- b := new(DefaultBinder)
- b.bindData(ts, values, "form")
- assertBindTestStruct(assert, ts)
+ err := bindData(ts, values, "form", nil)
+ assert.NoError(t, err)
+
+ assert.Equal(t, 0, ts.I)
+ assert.Equal(t, int8(0), ts.I8)
+ assert.Equal(t, int16(0), ts.I16)
+ assert.Equal(t, int32(0), ts.I32)
+ assert.Equal(t, int64(0), ts.I64)
+ assert.Equal(t, uint(0), ts.UI)
+ assert.Equal(t, uint8(0), ts.UI8)
+ assert.Equal(t, uint16(0), ts.UI16)
+ assert.Equal(t, uint32(0), ts.UI32)
+ assert.Equal(t, uint64(0), ts.UI64)
+ assert.Equal(t, false, ts.B)
+ assert.Equal(t, float32(0), ts.F32)
+ assert.Equal(t, float64(0), ts.F64)
+ assert.Equal(t, "", ts.S)
+ assert.Equal(t, "", ts.cantSet)
}
func TestBindParam(t *testing.T) {
e := New()
- req := httptest.NewRequest(GET, "/", nil)
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- c.SetPath("/users/:id/:name")
- c.SetParamNames("id", "name")
- c.SetParamValues("1", "Jon Snow")
+ c.InitializeRoute(
+ &RouteInfo{Path: "/users/:id/:name"},
+ &PathValues{
+ {Name: "id", Value: "1"},
+ {Name: "name", Value: "Jon Snow"},
+ },
+ )
u := new(user)
err := c.Bind(u)
@@ -311,9 +586,12 @@ func TestBindParam(t *testing.T) {
// Second test for the absence of a param
c2 := e.NewContext(req, rec)
- c2.SetPath("/users/:id")
- c2.SetParamNames("id")
- c2.SetParamValues("1")
+ c2.InitializeRoute(
+ &RouteInfo{Path: "/users/:id"},
+ &PathValues{
+ {Name: "id", Value: "1"},
+ },
+ )
u = new(user)
err = c2.Bind(u)
@@ -325,15 +603,18 @@ func TestBindParam(t *testing.T) {
// Bind something with param and post data payload
body := bytes.NewBufferString(`{ "name": "Jon Snow" }`)
e2 := New()
- req2 := httptest.NewRequest(POST, "/", body)
+ req2 := httptest.NewRequest(http.MethodPost, "/", body)
req2.Header.Set(HeaderContentType, MIMEApplicationJSON)
rec2 := httptest.NewRecorder()
c3 := e2.NewContext(req2, rec2)
- c3.SetPath("/users/:id")
- c3.SetParamNames("id")
- c3.SetParamValues("1")
+ c3.InitializeRoute(
+ &RouteInfo{Path: "/users/:id"},
+ &PathValues{
+ {Name: "id", Value: "1"},
+ },
+ )
u = new(user)
err = c3.Bind(u)
@@ -355,13 +636,10 @@ func TestBindUnmarshalTypeError(t *testing.T) {
err := c.Bind(u)
- he := &HTTPError{Code: http.StatusBadRequest, Message: "Unmarshal type error: expected=int, got=string, field=id, offset=14", Internal: err.(*HTTPError).Internal}
-
- assert.Equal(t, he, err)
+ assert.EqualError(t, err, `code=400, message=Bad Request, err=json: cannot unmarshal string into Go struct field user.id of type int`)
}
func TestBindSetWithProperType(t *testing.T) {
- assert := assert.New(t)
ts := new(bindTestStruct)
typ := reflect.TypeOf(ts).Elem()
val := reflect.ValueOf(ts).Elem()
@@ -376,9 +654,9 @@ func TestBindSetWithProperType(t *testing.T) {
}
val := values[typeField.Name][0]
err := setWithProperType(typeField.Type.Kind(), val, structField)
- assert.NoError(err)
+ assert.NoError(t, err)
}
- assertBindTestStruct(assert, ts)
+ assertBindTestStruct(t, ts)
type foo struct {
Bar bytes.Buffer
@@ -386,86 +664,77 @@ func TestBindSetWithProperType(t *testing.T) {
v := &foo{}
typ = reflect.TypeOf(v).Elem()
val = reflect.ValueOf(v).Elem()
- assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0)))
+ assert.Error(t, setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0)))
}
-func TestBindSetFields(t *testing.T) {
- assert := assert.New(t)
-
- ts := new(bindTestStruct)
- val := reflect.ValueOf(ts).Elem()
- // Int
- if assert.NoError(setIntField("5", 0, val.FieldByName("I"))) {
- assert.Equal(5, ts.I)
- }
- if assert.NoError(setIntField("", 0, val.FieldByName("I"))) {
- assert.Equal(0, ts.I)
- }
-
- // Uint
- if assert.NoError(setUintField("10", 0, val.FieldByName("UI"))) {
- assert.Equal(uint(10), ts.UI)
- }
- if assert.NoError(setUintField("", 0, val.FieldByName("UI"))) {
- assert.Equal(uint(0), ts.UI)
- }
-
- // Float
- if assert.NoError(setFloatField("15.5", 0, val.FieldByName("F32"))) {
- assert.Equal(float32(15.5), ts.F32)
- }
- if assert.NoError(setFloatField("", 0, val.FieldByName("F32"))) {
- assert.Equal(float32(0.0), ts.F32)
- }
-
- // Bool
- if assert.NoError(setBoolField("true", val.FieldByName("B"))) {
- assert.Equal(true, ts.B)
- }
- if assert.NoError(setBoolField("", val.FieldByName("B"))) {
- assert.Equal(false, ts.B)
- }
-
- ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T"))
- if assert.NoError(err) {
- assert.Equal(ok, true)
- assert.Equal(Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T)
+func BenchmarkBindbindDataWithTags(b *testing.B) {
+ b.ReportAllocs()
+ ts := new(bindTestStructWithTags)
+ var err error
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ err = bindData(ts, values, "form", nil)
}
+ assert.NoError(b, err)
+ assertBindTestStruct(b, (*bindTestStruct)(ts))
}
-func assertBindTestStruct(a *assert.Assertions, ts *bindTestStruct) {
- a.Equal(0, ts.I)
- a.Equal(int8(8), ts.I8)
- a.Equal(int16(16), ts.I16)
- a.Equal(int32(32), ts.I32)
- a.Equal(int64(64), ts.I64)
- a.Equal(uint(0), ts.UI)
- a.Equal(uint8(8), ts.UI8)
- a.Equal(uint16(16), ts.UI16)
- a.Equal(uint32(32), ts.UI32)
- a.Equal(uint64(64), ts.UI64)
- a.Equal(true, ts.B)
- a.Equal(float32(32.5), ts.F32)
- a.Equal(float64(64.5), ts.F64)
- a.Equal("test", ts.S)
- a.Equal("", ts.GetCantSet())
+func assertBindTestStruct(tb testing.TB, ts *bindTestStruct) {
+ assert.Equal(tb, 0, ts.I)
+ assert.Equal(tb, int8(8), ts.I8)
+ assert.Equal(tb, int16(16), ts.I16)
+ assert.Equal(tb, int32(32), ts.I32)
+ assert.Equal(tb, int64(64), ts.I64)
+ assert.Equal(tb, uint(0), ts.UI)
+ assert.Equal(tb, uint8(8), ts.UI8)
+ assert.Equal(tb, uint16(16), ts.UI16)
+ assert.Equal(tb, uint32(32), ts.UI32)
+ assert.Equal(tb, uint64(64), ts.UI64)
+ assert.Equal(tb, true, ts.B)
+ assert.Equal(tb, float32(32.5), ts.F32)
+ assert.Equal(tb, float64(64.5), ts.F64)
+ assert.Equal(tb, "test", ts.S)
+ assert.Equal(tb, "", ts.GetCantSet())
}
-func testBindOkay(assert *assert.Assertions, r io.Reader, ctype string) {
+func testBindOkay(t *testing.T, r io.Reader, query url.Values, ctype string) {
e := New()
- req := httptest.NewRequest(http.MethodPost, "/", r)
+ path := "/"
+ if len(query) > 0 {
+ path += "?" + query.Encode()
+ }
+ req := httptest.NewRequest(http.MethodPost, path, r)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(HeaderContentType, ctype)
u := new(user)
err := c.Bind(u)
- if assert.NoError(err) {
- assert.Equal(1, u.ID)
- assert.Equal("Jon Snow", u.Name)
+ if assert.Equal(t, nil, err) {
+ assert.Equal(t, 1, u.ID)
+ assert.Equal(t, "Jon Snow", u.Name)
+ }
+}
+
+func testBindArrayOkay(t *testing.T, r io.Reader, query url.Values, ctype string) {
+ e := New()
+ path := "/"
+ if len(query) > 0 {
+ path += "?" + query.Encode()
+ }
+ req := httptest.NewRequest(http.MethodPost, path, r)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ req.Header.Set(HeaderContentType, ctype)
+ u := []user{}
+ err := c.Bind(&u)
+ if assert.NoError(t, err) {
+ assert.Equal(t, 1, len(u))
+ assert.Equal(t, 1, u[0].ID)
+ assert.Equal(t, "Jon Snow", u[0].Name)
}
}
-func testBindError(assert *assert.Assertions, r io.Reader, ctype string, expectedInternal error) {
+func testBindError(t *testing.T, r io.Reader, ctype string, expectedInternal error) {
e := New()
req := httptest.NewRequest(http.MethodPost, "/", r)
rec := httptest.NewRecorder()
@@ -477,14 +746,948 @@ func testBindError(assert *assert.Assertions, r io.Reader, ctype string, expecte
switch {
case strings.HasPrefix(ctype, MIMEApplicationJSON), strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML),
strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
- if assert.IsType(new(HTTPError), err) {
- assert.Equal(http.StatusBadRequest, err.(*HTTPError).Code)
- assert.IsType(expectedInternal, err.(*HTTPError).Internal)
+ if assert.IsType(t, new(HTTPError), err) {
+ assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).Code)
+ assert.IsType(t, expectedInternal, err.(*HTTPError).Unwrap())
}
default:
- if assert.IsType(new(HTTPError), err) {
- assert.Equal(ErrUnsupportedMediaType, err)
- assert.IsType(expectedInternal, err.(*HTTPError).Internal)
+ if assert.IsType(t, new(HTTPError), err) {
+ assert.Equal(t, ErrUnsupportedMediaType, err)
+ assert.IsType(t, expectedInternal, err.(*HTTPError).Unwrap())
}
}
}
+
+func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
+ // tests to check binding behaviour when multiple sources (path params, query params and request body) are in use
+ // binding is done in steps and one source could overwrite previous source bound data
+ // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed
+
+ type Opts struct {
+ Node string `json:"node" form:"node" query:"node" param:"node"`
+ Lang string
+ ID int `json:"id" form:"id" query:"id"`
+ }
+
+ var testCases = []struct {
+ givenContent io.Reader
+ whenBindTarget any
+ expect any
+ name string
+ givenURL string
+ givenMethod string
+ expectError string
+ whenNoPathValues bool
+ }{
+ {
+ name: "ok, POST bind to struct with: path param + query param + body",
+ givenMethod: http.MethodPost,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`{"id": 1}`),
+ expect: &Opts{ID: 1, Node: "node_from_path"}, // query params are not used, node is filled from path
+ },
+ {
+ name: "ok, PUT bind to struct with: path param + query param + body",
+ givenMethod: http.MethodPut,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`{"id": 1}`),
+ expect: &Opts{ID: 1, Node: "node_from_path"}, // query params are not used
+ },
+ {
+ name: "ok, GET bind to struct with: path param + query param + body",
+ givenMethod: http.MethodGet,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`{"id": 1}`),
+ expect: &Opts{ID: 1, Node: "xxx"}, // query overwrites previous path value
+ },
+ {
+ name: "ok, GET bind to struct with: path param + query param + body",
+ givenMethod: http.MethodGet,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
+ expect: &Opts{ID: 1, Node: "zzz"}, // body is bound last and overwrites previous (path,query) values
+ },
+ {
+ name: "ok, DELETE bind to struct with: path param + query param + body",
+ givenMethod: http.MethodDelete,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
+ expect: &Opts{ID: 1, Node: "zzz"}, // for DELETE body is bound after query params
+ },
+ {
+ name: "ok, POST bind to struct with: path param + body",
+ givenMethod: http.MethodPost,
+ givenURL: "/api/real_node/endpoint",
+ givenContent: strings.NewReader(`{"id": 1}`),
+ expect: &Opts{ID: 1, Node: "node_from_path"},
+ },
+ {
+ name: "ok, POST bind to struct with path + query + body = body has priority",
+ givenMethod: http.MethodPost,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
+ expect: &Opts{ID: 1, Node: "zzz"}, // field value from content has higher priority
+ },
+ {
+ name: "nok, POST body bind failure",
+ givenMethod: http.MethodPost,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`{`),
+ expect: &Opts{ID: 0, Node: "node_from_path"}, // query binding has already modified bind target
+ expectError: "code=400, message=Bad Request, err=unexpected EOF",
+ },
+ {
+ name: "nok, GET with body bind failure when types are not convertible",
+ givenMethod: http.MethodGet,
+ givenURL: "/api/real_node/endpoint?id=nope",
+ givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
+ expect: &Opts{ID: 0, Node: "node_from_path"}, // path params binding has already modified bind target
+ expectError: `code=400, message=Bad Request, err=strconv.ParseInt: parsing "nope": invalid syntax`,
+ },
+ {
+ name: "nok, GET body bind failure - trying to bind json array to struct",
+ givenMethod: http.MethodGet,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`[{"id": 1}]`),
+ expect: &Opts{ID: 0, Node: "xxx"}, // query binding has already modified bind target
+ expectError: `code=400, message=Bad Request, err=json: cannot unmarshal array into Go value of type echo.Opts`,
+ },
+ { // query param is ignored as we do not know where exactly to bind it in slice
+ name: "ok, GET bind to struct slice, ignore query param",
+ givenMethod: http.MethodGet,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`[{"id": 1}]`),
+ whenNoPathValues: true,
+ whenBindTarget: &[]Opts{},
+ expect: &[]Opts{
+ {ID: 1, Node: ""},
+ },
+ },
+ { // binding query params interferes with body. b.BindBody() should be used to bind only body to slice
+ name: "ok, POST binding to slice should not be affected query params types",
+ givenMethod: http.MethodPost,
+ givenURL: "/api/real_node/endpoint?id=nope&node=xxx",
+ givenContent: strings.NewReader(`[{"id": 1}]`),
+ whenNoPathValues: true,
+ whenBindTarget: &[]Opts{},
+ expect: &[]Opts{{ID: 1}},
+ expectError: "",
+ },
+ { // path param is ignored as we do not know where exactly to bind it in slice
+ name: "ok, GET bind to struct slice, ignore path param",
+ givenMethod: http.MethodGet,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`[{"id": 1}]`),
+ whenBindTarget: &[]Opts{},
+ expect: &[]Opts{
+ {ID: 1, Node: ""},
+ },
+ },
+ {
+ name: "ok, GET body bind json array to slice",
+ givenMethod: http.MethodGet,
+ givenURL: "/api/real_node/endpoint",
+ givenContent: strings.NewReader(`[{"id": 1}]`),
+ whenNoPathValues: true,
+ whenBindTarget: &[]Opts{},
+ expect: &[]Opts{{ID: 1, Node: ""}},
+ expectError: "",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ // assume route we are testing is "/api/:node/endpoint?some_query_params=here"
+ req := httptest.NewRequest(tc.givenMethod, tc.givenURL, tc.givenContent)
+ req.Header.Set(HeaderContentType, MIMEApplicationJSON)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ if !tc.whenNoPathValues {
+ c.SetPathValues(PathValues{
+ {Name: "node", Value: "node_from_path"},
+ })
+ }
+
+ var bindTarget any
+ if tc.whenBindTarget != nil {
+ bindTarget = tc.whenBindTarget
+ } else {
+ bindTarget = &Opts{}
+ }
+ b := new(DefaultBinder)
+
+ err := b.Bind(c, bindTarget)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, bindTarget)
+ })
+ }
+}
+
+func TestDefaultBinder_BindBody(t *testing.T) {
+ // tests to check binding behaviour when multiple sources (path params, query params and request body) are in use
+ // generally when binding from request body - URL and path params are ignored - unless form is being bound.
+ // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed
+
+ type Node struct {
+ Node string `json:"node" xml:"node" form:"node" query:"node" param:"node"`
+ ID int `json:"id" xml:"id" form:"id" query:"id"`
+ }
+ type Nodes struct {
+ Nodes []Node `xml:"node" form:"node"`
+ }
+
+ var testCases = []struct {
+ givenContent io.Reader
+ whenBindTarget any
+ expect any
+ name string
+ givenURL string
+ givenMethod string
+ givenContentType string
+ expectError string
+ whenNoPathValues bool
+ whenChunkedBody bool
+ }{
+ {
+ name: "ok, JSON POST bind to struct with: path + query + empty field in body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationJSON,
+ givenContent: strings.NewReader(`{"id": 1}`),
+ expect: &Node{ID: 1, Node: ""}, // path params or query params should not interfere with body
+ },
+ {
+ name: "ok, JSON POST bind to struct with: path + query + body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationJSON,
+ givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
+ expect: &Node{ID: 1, Node: "zzz"}, // field value from content has higher priority
+ },
+ {
+ name: "ok, JSON POST body bind json array to slice (has matching path/query params)",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationJSON,
+ givenContent: strings.NewReader(`[{"id": 1}]`),
+ whenNoPathValues: true,
+ whenBindTarget: &[]Node{},
+ expect: &[]Node{{ID: 1, Node: ""}},
+ expectError: "",
+ },
+ { // rare case as GET is not usually used to send request body
+ name: "ok, JSON GET bind to struct with: path + query + empty field in body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodGet,
+ givenContentType: MIMEApplicationJSON,
+ givenContent: strings.NewReader(`{"id": 1}`),
+ expect: &Node{ID: 1, Node: ""}, // path params or query params should not interfere with body
+ },
+ { // rare case as GET is not usually used to send request body
+ name: "ok, JSON GET bind to struct with: path + query + body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodGet,
+ givenContentType: MIMEApplicationJSON,
+ givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
+ expect: &Node{ID: 1, Node: "zzz"}, // field value from content has higher priority
+ },
+ {
+ name: "nok, JSON POST body bind failure",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationJSON,
+ givenContent: strings.NewReader(`{`),
+ expect: &Node{ID: 0, Node: ""},
+ expectError: "code=400, message=Bad Request, err=unexpected EOF",
+ },
+ {
+ name: "ok, XML POST bind to struct with: path + query + empty body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationXML,
+ givenContent: strings.NewReader(`1yyy`),
+ expect: &Node{ID: 1, Node: "yyy"},
+ },
+ {
+ name: "ok, XML POST bind array to slice with: path + query + body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationXML,
+ givenContent: strings.NewReader(`1yyy`),
+ whenBindTarget: &Nodes{},
+ expect: &Nodes{Nodes: []Node{{ID: 1, Node: "yyy"}}},
+ },
+ {
+ name: "nok, XML POST bind failure",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationXML,
+ givenContent: strings.NewReader(`<`),
+ expect: &Node{ID: 0, Node: ""},
+ expectError: "code=400, message=Bad Request, err=XML syntax error on line 1: unexpected EOF",
+ },
+ {
+ name: "ok, FORM POST bind to struct with: path + query + body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationForm,
+ givenContent: strings.NewReader(`id=1&node=yyy`),
+ expect: &Node{ID: 1, Node: "yyy"},
+ },
+ {
+ // NB: form values are taken from BOTH body and query for POST/PUT/PATCH by standard library implementation
+ // See: https://golang.org/pkg/net/http/#Request.ParseForm
+ name: "ok, FORM POST bind to struct with: path + query + empty field in body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationForm,
+ givenContent: strings.NewReader(`id=1`),
+ expect: &Node{ID: 1, Node: "xxx"},
+ },
+ {
+ // NB: form values are taken from query by standard library implementation
+ // See: https://golang.org/pkg/net/http/#Request.ParseForm
+ name: "ok, FORM GET bind to struct with: path + query + empty field in body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodGet,
+ givenContentType: MIMEApplicationForm,
+ givenContent: strings.NewReader(`id=1`),
+ expect: &Node{ID: 0, Node: "xxx"}, // 'xxx' is taken from URL and body is not used with GET by implementation
+ },
+ {
+ name: "nok, unsupported content type",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMETextPlain,
+ givenContent: strings.NewReader(``),
+ expect: &Node{ID: 0, Node: ""},
+ expectError: "code=415, message=Unsupported Media Type",
+ },
+ // FIXME: REASON in Go 1.24 and earlier http.NoBody would result ContentLength=-1
+ // but as of Go 1.25 http.NoBody would result ContentLength=0
+ // I am too lazy to bother documenting this as 2 version specific tests.
+ //{
+ // name: "nok, JSON POST with http.NoBody",
+ // givenURL: "/api/real_node/endpoint?node=xxx",
+ // givenMethod: http.MethodPost,
+ // givenContentType: MIMEApplicationJSON,
+ // givenContent: http.NoBody,
+ // expect: &Node{ID: 0, Node: ""},
+ // expectError: "code=400, message=EOF, internal=EOF",
+ //},
+ {
+ name: "ok, JSON POST with empty body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationJSON,
+ givenContent: strings.NewReader(""),
+ expect: &Node{ID: 0, Node: ""},
+ },
+ {
+ name: "ok, JSON POST bind to struct with: path + query + chunked body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationJSON,
+ givenContent: httputil.NewChunkedReader(strings.NewReader("18\r\n" + `{"id": 1, "node": "zzz"}` + "\r\n0\r\n")),
+ whenChunkedBody: true,
+ expect: &Node{ID: 1, Node: "zzz"},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ // assume route we are testing is "/api/:node/endpoint?some_query_params=here"
+ req := httptest.NewRequest(tc.givenMethod, tc.givenURL, tc.givenContent)
+ switch tc.givenContentType {
+ case MIMEApplicationXML:
+ req.Header.Set(HeaderContentType, MIMEApplicationXML)
+ case MIMEApplicationForm:
+ req.Header.Set(HeaderContentType, MIMEApplicationForm)
+ case MIMEApplicationJSON:
+ req.Header.Set(HeaderContentType, MIMEApplicationJSON)
+ }
+ if tc.whenChunkedBody {
+ req.ContentLength = -1
+ req.TransferEncoding = append(req.TransferEncoding, "chunked")
+ }
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ if !tc.whenNoPathValues {
+ c.SetPathValues(PathValues{
+ {Name: "node", Value: "real_node"},
+ })
+ }
+
+ var bindTarget any
+ if tc.whenBindTarget != nil {
+ bindTarget = tc.whenBindTarget
+ } else {
+ bindTarget = &Node{}
+ }
+
+ err := BindBody(c, bindTarget)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, bindTarget)
+ })
+ }
+}
+
+func testBindURL(queryString string, target any) error {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, queryString, nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ return c.Bind(target)
+}
+
+type unixTimestamp struct {
+ Time time.Time
+}
+
+func (t *unixTimestamp) UnmarshalParam(param string) error {
+ n, err := strconv.ParseInt(param, 10, 64)
+ if err != nil {
+ return fmt.Errorf("'%s' is not an integer", param)
+ }
+ *t = unixTimestamp{Time: time.Unix(n, 0)}
+ return err
+}
+
+type IntArrayA []int
+
+// UnmarshalParam converts value to *Int64Slice. This allows the API to accept
+// a comma-separated list of integers as a query parameter.
+func (i *IntArrayA) UnmarshalParam(value string) error {
+ var values = strings.Split(value, ",")
+ var numbers = make([]int, 0, len(values))
+
+ for _, v := range values {
+ n, err := strconv.ParseInt(v, 10, 64)
+ if err != nil {
+ return fmt.Errorf("'%s' is not an integer", v)
+ }
+
+ numbers = append(numbers, int(n))
+ }
+
+ *i = append(*i, numbers...)
+ return nil
+}
+
+func TestBindUnmarshalParamExtras(t *testing.T) {
+ // this test documents how bind handles `BindUnmarshaler` interface:
+ // NOTE: BindUnmarshaler chooses first input value to be bound.
+
+ t.Run("nok, unmarshalling fails", func(t *testing.T) {
+ result := struct {
+ V unixTimestamp `query:"t"`
+ }{}
+ err := testBindURL("/?t=xxxx", &result)
+
+ assert.EqualError(t, err, `code=400, message=Bad Request, err='xxxx' is not an integer`)
+ })
+
+ t.Run("ok, target is struct", func(t *testing.T) {
+ result := struct {
+ V unixTimestamp `query:"t"`
+ }{}
+ err := testBindURL("/?t=1710095540&t=1710095541", &result)
+
+ assert.NoError(t, err)
+ expect := unixTimestamp{
+ Time: time.Unix(1710095540, 0),
+ }
+ assert.Equal(t, expect, result.V)
+ })
+
+ t.Run("ok, target is an alias to slice and is nil, append only values from first", func(t *testing.T) {
+ result := struct {
+ V IntArrayA `query:"a"`
+ }{}
+ err := testBindURL("/?a=1,2,3&a=4,5,6", &result)
+
+ assert.NoError(t, err)
+ assert.Equal(t, IntArrayA([]int{1, 2, 3}), result.V)
+ })
+
+ t.Run("ok, target is an alias to slice and is nil, single input", func(t *testing.T) {
+ result := struct {
+ V IntArrayA `query:"a"`
+ }{}
+ err := testBindURL("/?a=1,2", &result)
+
+ assert.NoError(t, err)
+ assert.Equal(t, IntArrayA([]int{1, 2}), result.V)
+ })
+
+ t.Run("ok, target is pointer an alias to slice and is nil", func(t *testing.T) {
+ result := struct {
+ V *IntArrayA `query:"a"`
+ }{}
+ err := testBindURL("/?a=1&a=4,5,6", &result)
+
+ assert.NoError(t, err)
+ var expected = IntArrayA([]int{1})
+ assert.Equal(t, &expected, result.V)
+ })
+
+ t.Run("ok, target is pointer an alias to slice and is NOT nil", func(t *testing.T) {
+ result := struct {
+ V *IntArrayA `query:"a"`
+ }{}
+ result.V = new(IntArrayA) // NOT nil
+
+ err := testBindURL("/?a=1&a=4,5,6", &result)
+
+ assert.NoError(t, err)
+ var expected = IntArrayA([]int{1})
+ assert.Equal(t, &expected, result.V)
+ })
+}
+
+type unixTimestampLast struct {
+ Time time.Time
+}
+
+// this is silly example for `bindMultipleUnmarshaler` for type that uses last input value for unmarshalling
+func (t *unixTimestampLast) UnmarshalParams(params []string) error {
+ lastInput := params[len(params)-1]
+ n, err := strconv.ParseInt(lastInput, 10, 64)
+ if err != nil {
+ return fmt.Errorf("'%s' is not an integer", lastInput)
+ }
+ *t = unixTimestampLast{Time: time.Unix(n, 0)}
+ return err
+}
+
+type IntArrayB []int
+
+func (i *IntArrayB) UnmarshalParams(params []string) error {
+ var numbers = make([]int, 0, len(params))
+
+ for _, param := range params {
+ var values = strings.Split(param, ",")
+ for _, v := range values {
+ n, err := strconv.ParseInt(v, 10, 64)
+ if err != nil {
+ return fmt.Errorf("'%s' is not an integer", v)
+ }
+ numbers = append(numbers, int(n))
+ }
+ }
+
+ *i = append(*i, numbers...)
+ return nil
+}
+
+func TestBindUnmarshalParams(t *testing.T) {
+ // this test documents how bind handles `bindMultipleUnmarshaler` interface:
+
+ t.Run("nok, unmarshalling fails", func(t *testing.T) {
+ result := struct {
+ V unixTimestampLast `query:"t"`
+ }{}
+ err := testBindURL("/?t=xxxx", &result)
+
+ assert.EqualError(t, err, "code=400, message=Bad Request, err='xxxx' is not an integer")
+ })
+
+ t.Run("ok, target is struct", func(t *testing.T) {
+ result := struct {
+ V unixTimestampLast `query:"t"`
+ }{}
+ err := testBindURL("/?t=1710095540&t=1710095541", &result)
+
+ assert.NoError(t, err)
+ expect := unixTimestampLast{
+ Time: time.Unix(1710095541, 0),
+ }
+ assert.Equal(t, expect, result.V)
+ })
+
+ t.Run("ok, target is an alias to slice and is nil, append multiple inputs", func(t *testing.T) {
+ result := struct {
+ V IntArrayB `query:"a"`
+ }{}
+ err := testBindURL("/?a=1,2,3&a=4,5,6", &result)
+
+ assert.NoError(t, err)
+ assert.Equal(t, IntArrayB([]int{1, 2, 3, 4, 5, 6}), result.V)
+ })
+
+ t.Run("ok, target is an alias to slice and is nil, single input", func(t *testing.T) {
+ result := struct {
+ V IntArrayB `query:"a"`
+ }{}
+ err := testBindURL("/?a=1,2", &result)
+
+ assert.NoError(t, err)
+ assert.Equal(t, IntArrayB([]int{1, 2}), result.V)
+ })
+
+ t.Run("ok, target is pointer an alias to slice and is nil", func(t *testing.T) {
+ result := struct {
+ V *IntArrayB `query:"a"`
+ }{}
+ err := testBindURL("/?a=1&a=4,5,6", &result)
+
+ assert.NoError(t, err)
+ var expected = IntArrayB([]int{1, 4, 5, 6})
+ assert.Equal(t, &expected, result.V)
+ })
+
+ t.Run("ok, target is pointer an alias to slice and is NOT nil", func(t *testing.T) {
+ result := struct {
+ V *IntArrayB `query:"a"`
+ }{}
+ result.V = new(IntArrayB) // NOT nil
+
+ err := testBindURL("/?a=1&a=4,5,6", &result)
+ assert.NoError(t, err)
+ var expected = IntArrayB([]int{1, 4, 5, 6})
+ assert.Equal(t, &expected, result.V)
+ })
+}
+
+func TestBindInt8(t *testing.T) {
+ t.Run("nok, binding fails", func(t *testing.T) {
+ type target struct {
+ V int8 `query:"v"`
+ }
+ p := target{}
+ err := testBindURL("/?v=x&v=2", &p)
+ assert.EqualError(t, err, `code=400, message=Bad Request, err=strconv.ParseInt: parsing "x": invalid syntax`)
+ })
+
+ t.Run("nok, int8 embedded in struct", func(t *testing.T) {
+ type target struct {
+ int8 `query:"v"` // embedded field is `Anonymous`. We can only set public fields
+ }
+ p := target{}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+ assert.Equal(t, target{0}, p)
+ })
+
+ t.Run("nok, pointer to int8 embedded in struct", func(t *testing.T) {
+ type target struct {
+ *int8 `query:"v"` // embedded field is `Anonymous`. We can only set public fields
+ }
+ p := target{}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+
+ assert.Equal(t, target{int8: nil}, p)
+ })
+
+ t.Run("ok, bind int8 as struct field", func(t *testing.T) {
+ type target struct {
+ V int8 `query:"v"`
+ }
+ p := target{V: 127}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+ assert.Equal(t, target{V: 1}, p)
+ })
+
+ t.Run("ok, bind pointer to int8 as struct field, value is nil", func(t *testing.T) {
+ type target struct {
+ V *int8 `query:"v"`
+ }
+ p := target{}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+ assert.Equal(t, target{V: ptr(int8(1))}, p)
+ })
+
+ t.Run("ok, bind pointer to int8 as struct field, value is set", func(t *testing.T) {
+ type target struct {
+ V *int8 `query:"v"`
+ }
+ p := target{V: ptr(int8(127))}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+ assert.Equal(t, target{V: ptr(int8(1))}, p)
+ })
+
+ t.Run("ok, bind int8 slice as struct field, value is nil", func(t *testing.T) {
+ type target struct {
+ V []int8 `query:"v"`
+ }
+ p := target{}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+ assert.Equal(t, target{V: []int8{1, 2}}, p)
+ })
+
+ t.Run("ok, bind slice of int8 as struct field, value is set", func(t *testing.T) {
+ type target struct {
+ V []int8 `query:"v"`
+ }
+ p := target{V: []int8{111}}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+ assert.Equal(t, target{V: []int8{1, 2}}, p)
+ })
+
+ t.Run("ok, bind slice of pointer to int8 as struct field, value is set", func(t *testing.T) {
+ type target struct {
+ V []*int8 `query:"v"`
+ }
+ p := target{V: []*int8{ptr(int8(127))}}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+ assert.Equal(t, target{V: []*int8{ptr(int8(1)), ptr(int8(2))}}, p)
+ })
+
+ t.Run("ok, bind pointer to slice of int8 as struct field, value is set", func(t *testing.T) {
+ type target struct {
+ V *[]int8 `query:"v"`
+ }
+ p := target{V: &[]int8{111}}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+ assert.Equal(t, target{V: &[]int8{1, 2}}, p)
+ })
+}
+
+func TestBindMultipartFormFiles(t *testing.T) {
+ file1 := createTestFormFile("file", "file1.txt")
+ file11 := createTestFormFile("file", "file11.txt")
+ file2 := createTestFormFile("file2", "file2.txt")
+ filesA := createTestFormFile("files", "filesA.txt")
+ filesB := createTestFormFile("files", "filesB.txt")
+
+ t.Run("nok, can not bind to multipart file struct", func(t *testing.T) {
+ var target struct {
+ File multipart.FileHeader `form:"file"`
+ }
+ err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored
+
+ assert.EqualError(t, err, `code=400, message=Bad Request, err=binding to multipart.FileHeader struct is not supported, use pointer to struct`)
+ })
+
+ t.Run("ok, bind single multipart file to pointer to multipart file", func(t *testing.T) {
+ var target struct {
+ File *multipart.FileHeader `form:"file"`
+ }
+ err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored
+
+ assert.NoError(t, err)
+ assertMultipartFileHeader(t, target.File, file1)
+ })
+
+ t.Run("ok, bind multiple multipart files to pointer to multipart file", func(t *testing.T) {
+ var target struct {
+ File *multipart.FileHeader `form:"file"`
+ }
+ err := bindMultipartFiles(t, &target, file1, file11)
+
+ assert.NoError(t, err)
+ assertMultipartFileHeader(t, target.File, file1) // should choose first one
+ })
+
+ t.Run("ok, bind multiple multipart files to slice of multipart file", func(t *testing.T) {
+ var target struct {
+ Files []multipart.FileHeader `form:"files"`
+ }
+ err := bindMultipartFiles(t, &target, filesA, filesB, file1)
+
+ assert.NoError(t, err)
+
+ assert.Len(t, target.Files, 2)
+ assertMultipartFileHeader(t, &target.Files[0], filesA)
+ assertMultipartFileHeader(t, &target.Files[1], filesB)
+ })
+
+ t.Run("ok, bind multiple multipart files to slice of pointer to multipart file", func(t *testing.T) {
+ var target struct {
+ Files []*multipart.FileHeader `form:"files"`
+ }
+ err := bindMultipartFiles(t, &target, filesA, filesB, file1)
+
+ assert.NoError(t, err)
+
+ assert.Len(t, target.Files, 2)
+ assertMultipartFileHeader(t, target.Files[0], filesA)
+ assertMultipartFileHeader(t, target.Files[1], filesB)
+ })
+}
+
+type testFormFile struct {
+ Fieldname string
+ Filename string
+ Content []byte
+}
+
+func createTestFormFile(formFieldName string, filename string) testFormFile {
+ return testFormFile{
+ Fieldname: formFieldName,
+ Filename: filename,
+ Content: []byte(strings.Repeat(filename, 10)),
+ }
+}
+
+func bindMultipartFiles(t *testing.T, target any, files ...testFormFile) error {
+ var body bytes.Buffer
+ mw := multipart.NewWriter(&body)
+
+ for _, file := range files {
+ fw, err := mw.CreateFormFile(file.Fieldname, file.Filename)
+ assert.NoError(t, err)
+
+ n, err := fw.Write(file.Content)
+ assert.NoError(t, err)
+ assert.Equal(t, len(file.Content), n)
+ }
+
+ err := mw.Close()
+ assert.NoError(t, err)
+
+ req, err := http.NewRequest(http.MethodPost, "/", &body)
+ assert.NoError(t, err)
+ req.Header.Set("Content-Type", mw.FormDataContentType())
+
+ rec := httptest.NewRecorder()
+
+ e := New()
+ c := e.NewContext(req, rec)
+ return c.Bind(target)
+}
+
+func assertMultipartFileHeader(t *testing.T, fh *multipart.FileHeader, file testFormFile) {
+ assert.Equal(t, file.Filename, fh.Filename)
+ assert.Equal(t, int64(len(file.Content)), fh.Size)
+ fl, err := fh.Open()
+ assert.NoError(t, err)
+ body, err := io.ReadAll(fl)
+ assert.NoError(t, err)
+ assert.Equal(t, string(file.Content), string(body))
+ err = fl.Close()
+ assert.NoError(t, err)
+}
+
+func TestTimeFormatBinding(t *testing.T) {
+ type TestStruct struct {
+ DateTimeLocal time.Time `form:"datetime_local" format:"2006-01-02T15:04"`
+ Date time.Time `query:"date" format:"2006-01-02"`
+ CustomFormat time.Time `form:"custom" format:"01/02/2006 15:04:05"`
+ DefaultTime time.Time `form:"default_time"` // No format tag - should use default parsing
+ PtrTime *time.Time `query:"ptr_time" format:"2006-01-02"`
+ }
+
+ testCases := []struct {
+ name string
+ contentType string
+ data string
+ queryParams string
+ expect TestStruct
+ expectError bool
+ }{
+ {
+ name: "ok, datetime-local format binding",
+ contentType: MIMEApplicationForm,
+ data: "datetime_local=2023-12-25T14:30&default_time=2023-12-25T14:30:45Z",
+ expect: TestStruct{
+ DateTimeLocal: time.Date(2023, 12, 25, 14, 30, 0, 0, time.UTC),
+ DefaultTime: time.Date(2023, 12, 25, 14, 30, 45, 0, time.UTC),
+ },
+ },
+ {
+ name: "ok, date format binding via query params",
+ queryParams: "?date=2023-01-15&ptr_time=2023-02-20",
+ expect: TestStruct{
+ Date: time.Date(2023, 1, 15, 0, 0, 0, 0, time.UTC),
+ PtrTime: &time.Time{},
+ },
+ },
+ {
+ name: "ok, custom format via form data",
+ contentType: MIMEApplicationForm,
+ data: "custom=12/25/2023 14:30:45",
+ expect: TestStruct{
+ CustomFormat: time.Date(2023, 12, 25, 14, 30, 45, 0, time.UTC),
+ },
+ },
+ {
+ name: "nok, invalid format should fail",
+ contentType: MIMEApplicationForm,
+ data: "datetime_local=invalid-date",
+ expectError: true,
+ },
+ {
+ name: "nok, wrong format should fail",
+ contentType: MIMEApplicationForm,
+ data: "datetime_local=2023-12-25", // Missing time part
+ expectError: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ var req *http.Request
+
+ if tc.contentType == MIMEApplicationJSON {
+ req = httptest.NewRequest(http.MethodPost, "/"+tc.queryParams, strings.NewReader(tc.data))
+ req.Header.Set(HeaderContentType, tc.contentType)
+ } else if tc.contentType == MIMEApplicationForm {
+ req = httptest.NewRequest(http.MethodPost, "/"+tc.queryParams, strings.NewReader(tc.data))
+ req.Header.Set(HeaderContentType, tc.contentType)
+ } else {
+ req = httptest.NewRequest(http.MethodGet, "/"+tc.queryParams, nil)
+ }
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ var result TestStruct
+ err := c.Bind(&result)
+
+ if tc.expectError {
+ assert.Error(t, err)
+ return
+ }
+
+ assert.NoError(t, err)
+
+ // Check individual fields since time comparison can be tricky
+ if !tc.expect.DateTimeLocal.IsZero() {
+ assert.True(t, tc.expect.DateTimeLocal.Equal(result.DateTimeLocal),
+ "DateTimeLocal: expected %v, got %v", tc.expect.DateTimeLocal, result.DateTimeLocal)
+ }
+ if !tc.expect.Date.IsZero() {
+ assert.True(t, tc.expect.Date.Equal(result.Date),
+ "Date: expected %v, got %v", tc.expect.Date, result.Date)
+ }
+ if !tc.expect.CustomFormat.IsZero() {
+ assert.True(t, tc.expect.CustomFormat.Equal(result.CustomFormat),
+ "CustomFormat: expected %v, got %v", tc.expect.CustomFormat, result.CustomFormat)
+ }
+ if !tc.expect.DefaultTime.IsZero() {
+ assert.True(t, tc.expect.DefaultTime.Equal(result.DefaultTime),
+ "DefaultTime: expected %v, got %v", tc.expect.DefaultTime, result.DefaultTime)
+ }
+ if tc.expect.PtrTime != nil {
+ assert.NotNil(t, result.PtrTime)
+ if result.PtrTime != nil {
+ expectedPtr := time.Date(2023, 2, 20, 0, 0, 0, 0, time.UTC)
+ assert.True(t, expectedPtr.Equal(*result.PtrTime),
+ "PtrTime: expected %v, got %v", expectedPtr, *result.PtrTime)
+ }
+ }
+ })
+ }
+}
diff --git a/binder.go b/binder.go
new file mode 100644
index 000000000..32029ec0f
--- /dev/null
+++ b/binder.go
@@ -0,0 +1,1329 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "encoding"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+)
+
+/**
+ Following functions provide handful of methods for binding to Go native types from request query or path parameters.
+ * QueryParamsBinder(c) - binds query parameters (source URL)
+ * PathValuesBinder(c) - binds path parameters (source URL)
+ * FormFieldBinder(c) - binds form fields (source URL + body)
+
+ Example:
+ ```go
+ var length int64
+ err := echo.QueryParamsBinder(c).Int64("length", &length).BindError()
+ ```
+
+ For every supported type there are following methods:
+ * ("param", &destination) - if parameter value exists then binds it to given destination of that type i.e Int64(...).
+ * Must("param", &destination) - parameter value is required to exist, binds it to given destination of that type i.e MustInt64(...).
+ * s("param", &destination) - (for slices) if parameter values exists then binds it to given destination of that type i.e Int64s(...).
+ * Musts("param", &destination) - (for slices) parameter value is required to exist, binds it to given destination of that type i.e MustInt64s(...).
+
+ for some slice types `BindWithDelimiter("param", &dest, ",")` supports splitting parameter values before type conversion is done
+ i.e. URL `/api/search?id=1,2,3&id=1` can be bind to `[]int64{1,2,3,1}`
+
+ `FailFast` flags binder to stop binding after first bind error during binder call chain. Enabled by default.
+ `BindError()` returns first bind error from binder and resets errors in binder. Useful along with `FailFast()` method
+ to do binding and returns on first problem
+ `BindErrors()` returns all bind errors from binder and resets errors in binder.
+
+ Types that are supported:
+ * bool
+ * float32
+ * float64
+ * int
+ * int8
+ * int16
+ * int32
+ * int64
+ * uint
+ * uint8/byte (does not support `bytes()`. Use BindUnmarshaler/CustomFunc to convert value from base64 etc to []byte{})
+ * uint16
+ * uint32
+ * uint64
+ * string
+ * time
+ * duration
+ * BindUnmarshaler() interface
+ * TextUnmarshaler() interface
+ * JSONUnmarshaler() interface
+ * UnixTime() - converts unix time (integer) to time.Time
+ * UnixTimeMilli() - converts unix time with millisecond precision (integer) to time.Time
+ * UnixTimeNano() - converts unix time with nanosecond precision (integer) to time.Time
+ * CustomFunc() - callback function for your custom conversion logic. Signature `func(values []string) []error`
+*/
+
+// BindingError represents an error that occurred while binding request data.
+type BindingError struct {
+ // Field is the field name where value binding failed
+ Field string `json:"field"`
+ *HTTPError
+ // Values of parameter that failed to bind.
+ Values []string `json:"-"`
+}
+
+// NewBindingError creates new instance of binding error
+func NewBindingError(sourceParam string, values []string, message string, err error) error {
+ return &BindingError{
+ Field: sourceParam,
+ Values: values,
+ HTTPError: &HTTPError{Code: http.StatusBadRequest, Message: message, err: err},
+ }
+}
+
+// Error returns error message
+func (be *BindingError) Error() string {
+ return fmt.Sprintf("%s, field=%s", be.HTTPError.Error(), be.Field)
+}
+
+// ValueBinder provides utility methods for binding query or path parameter to various Go built-in types
+type ValueBinder struct {
+ // ValueFunc is used to get single parameter (first) value from request
+ ValueFunc func(sourceParam string) string
+ // ValuesFunc is used to get all values for parameter from request. i.e. `/api/search?ids=1&ids=2`
+ ValuesFunc func(sourceParam string) []string
+ // ErrorFunc is used to create errors. Allows you to use your own error type, that for example marshals to your specific json response
+ ErrorFunc func(sourceParam string, values []string, message string, internalError error) error
+ errors []error
+ // failFast is flag for binding methods to return without attempting to bind when previous binding already failed
+ failFast bool
+}
+
+// QueryParamsBinder creates query parameter value binder
+func QueryParamsBinder(c *Context) *ValueBinder {
+ return &ValueBinder{
+ failFast: true,
+ ValueFunc: c.QueryParam,
+ ValuesFunc: func(sourceParam string) []string {
+ values, ok := c.QueryParams()[sourceParam]
+ if !ok {
+ return nil
+ }
+ return values
+ },
+ ErrorFunc: NewBindingError,
+ }
+}
+
+// PathValuesBinder creates path parameter value binder
+func PathValuesBinder(c *Context) *ValueBinder {
+ return &ValueBinder{
+ failFast: true,
+ ValueFunc: c.Param,
+ ValuesFunc: func(sourceParam string) []string {
+ // path parameter should not have multiple values so getting values does not make sense but lets not error out here
+ value := c.Param(sourceParam)
+ if value == "" {
+ return nil
+ }
+ return []string{value}
+ },
+ ErrorFunc: NewBindingError,
+ }
+}
+
+// FormFieldBinder creates form field value binder
+// For all requests, FormFieldBinder parses the raw query from the URL and uses query params as form fields
+//
+// For POST, PUT, and PATCH requests, it also reads the request body, parses it
+// as a form and uses query params as form fields. Request body parameters take precedence over URL query
+// string values in r.Form.
+//
+// NB: when binding forms take note that this implementation uses standard library form parsing
+// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm
+// See https://golang.org/pkg/net/http/#Request.ParseForm
+func FormFieldBinder(c *Context) *ValueBinder {
+ vb := &ValueBinder{
+ failFast: true,
+ ValueFunc: func(sourceParam string) string {
+ return c.Request().FormValue(sourceParam)
+ },
+ ErrorFunc: NewBindingError,
+ }
+ vb.ValuesFunc = func(sourceParam string) []string {
+ if c.Request().Form == nil {
+ // this is same as `Request().FormValue()` does internally
+ _, _ = c.MultipartForm() // we want to trigger c.request.ParseMultipartForm(c.formParseMaxMemory)
+ }
+ values, ok := c.Request().Form[sourceParam]
+ if !ok {
+ return nil
+ }
+ return values
+ }
+
+ return vb
+}
+
+// FailFast set internal flag to indicate if binding methods will return early (without binding) when previous bind failed
+// NB: call this method before any other binding methods as it modifies binding methods behaviour
+func (b *ValueBinder) FailFast(value bool) *ValueBinder {
+ b.failFast = value
+ return b
+}
+
+func (b *ValueBinder) setError(err error) {
+ if b.errors == nil {
+ b.errors = []error{err}
+ return
+ }
+ b.errors = append(b.errors, err)
+}
+
+// BindError returns first seen bind error and resets/empties binder errors for further calls
+func (b *ValueBinder) BindError() error {
+ if b.errors == nil {
+ return nil
+ }
+ err := b.errors[0]
+ b.errors = nil // reset errors so next chain will start from zero
+ return err
+}
+
+// BindErrors returns all bind errors and resets/empties binder errors for further calls
+func (b *ValueBinder) BindErrors() []error {
+ if b.errors == nil {
+ return nil
+ }
+ errors := b.errors
+ b.errors = nil // reset errors so next chain will start from zero
+ return errors
+}
+
+// CustomFunc binds parameter values with Func. Func is called only when parameter values exist.
+func (b *ValueBinder) CustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder {
+ return b.customFunc(sourceParam, customFunc, false)
+}
+
+// MustCustomFunc requires parameter values to exist to bind with Func. Returns error when value does not exist.
+func (b *ValueBinder) MustCustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder {
+ return b.customFunc(sourceParam, customFunc, true)
+}
+
+func (b *ValueBinder) customFunc(sourceParam string, customFunc func(values []string) []error, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ values := b.ValuesFunc(sourceParam)
+ if len(values) == 0 {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+ if errs := customFunc(values); errs != nil {
+ b.errors = append(b.errors, errs...)
+ }
+ return b
+}
+
+// String binds parameter to string variable
+func (b *ValueBinder) String(sourceParam string, dest *string) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ return b
+ }
+ *dest = value
+ return b
+}
+
+// MustString requires parameter value to exist to bind to string variable. Returns error when value does not exist
+func (b *ValueBinder) MustString(sourceParam string, dest *string) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil))
+ return b
+ }
+ *dest = value
+ return b
+}
+
+// Strings binds parameter values to slice of string
+func (b *ValueBinder) Strings(sourceParam string, dest *[]string) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValuesFunc(sourceParam)
+ if value == nil {
+ return b
+ }
+ *dest = value
+ return b
+}
+
+// MustStrings requires parameter values to exist to bind to slice of string variables. Returns error when value does not exist
+func (b *ValueBinder) MustStrings(sourceParam string, dest *[]string) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValuesFunc(sourceParam)
+ if value == nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ return b
+ }
+ *dest = value
+ return b
+}
+
+// BindUnmarshaler binds parameter to destination implementing BindUnmarshaler interface
+func (b *ValueBinder) BindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ tmp := b.ValueFunc(sourceParam)
+ if tmp == "" {
+ return b
+ }
+
+ if err := dest.UnmarshalParam(tmp); err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to BindUnmarshaler interface", err))
+ }
+ return b
+}
+
+// MustBindUnmarshaler requires parameter value to exist to bind to destination implementing BindUnmarshaler interface.
+// Returns error when value does not exist
+func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil))
+ return b
+ }
+
+ if err := dest.UnmarshalParam(value); err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to BindUnmarshaler interface", err))
+ }
+ return b
+}
+
+// JSONUnmarshaler binds parameter to destination implementing json.Unmarshaler interface
+func (b *ValueBinder) JSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ tmp := b.ValueFunc(sourceParam)
+ if tmp == "" {
+ return b
+ }
+
+ if err := dest.UnmarshalJSON([]byte(tmp)); err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err))
+ }
+ return b
+}
+
+// MustJSONUnmarshaler requires parameter value to exist to bind to destination implementing json.Unmarshaler interface.
+// Returns error when value does not exist
+func (b *ValueBinder) MustJSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ tmp := b.ValueFunc(sourceParam)
+ if tmp == "" {
+ b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil))
+ return b
+ }
+
+ if err := dest.UnmarshalJSON([]byte(tmp)); err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err))
+ }
+ return b
+}
+
+// TextUnmarshaler binds parameter to destination implementing encoding.TextUnmarshaler interface
+func (b *ValueBinder) TextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ tmp := b.ValueFunc(sourceParam)
+ if tmp == "" {
+ return b
+ }
+
+ if err := dest.UnmarshalText([]byte(tmp)); err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err))
+ }
+ return b
+}
+
+// MustTextUnmarshaler requires parameter value to exist to bind to destination implementing encoding.TextUnmarshaler interface.
+// Returns error when value does not exist
+func (b *ValueBinder) MustTextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ tmp := b.ValueFunc(sourceParam)
+ if tmp == "" {
+ b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil))
+ return b
+ }
+
+ if err := dest.UnmarshalText([]byte(tmp)); err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err))
+ }
+ return b
+}
+
+// BindWithDelimiter binds parameter to destination by suitable conversion function.
+// Delimiter is used before conversion to split parameter value to separate values
+func (b *ValueBinder) BindWithDelimiter(sourceParam string, dest any, delimiter string) *ValueBinder {
+ return b.bindWithDelimiter(sourceParam, dest, delimiter, false)
+}
+
+// MustBindWithDelimiter requires parameter value to exist to bind destination by suitable conversion function.
+// Delimiter is used before conversion to split parameter value to separate values
+func (b *ValueBinder) MustBindWithDelimiter(sourceParam string, dest any, delimiter string) *ValueBinder {
+ return b.bindWithDelimiter(sourceParam, dest, delimiter, true)
+}
+
+func (b *ValueBinder) bindWithDelimiter(sourceParam string, dest any, delimiter string, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ values := b.ValuesFunc(sourceParam)
+ if len(values) == 0 {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+ tmpValues := make([]string, 0, len(values))
+ for _, v := range values {
+ tmpValues = append(tmpValues, strings.Split(v, delimiter)...)
+ }
+
+ switch d := dest.(type) {
+ case *[]string:
+ *d = tmpValues
+ return b
+ case *[]bool:
+ return b.bools(sourceParam, tmpValues, d)
+ case *[]int64, *[]int32, *[]int16, *[]int8, *[]int:
+ return b.ints(sourceParam, tmpValues, d)
+ case *[]uint64, *[]uint32, *[]uint16, *[]uint8, *[]uint: // *[]byte is same as *[]uint8
+ return b.uints(sourceParam, tmpValues, d)
+ case *[]float64, *[]float32:
+ return b.floats(sourceParam, tmpValues, d)
+ case *[]time.Duration:
+ return b.durations(sourceParam, tmpValues, d)
+ default:
+ // support only cases when destination is slice
+ // does not support time.Time as it needs argument (layout) for parsing or BindUnmarshaler
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "unsupported bind type", nil))
+ return b
+ }
+}
+
+// Int64 binds parameter to int64 variable
+func (b *ValueBinder) Int64(sourceParam string, dest *int64) *ValueBinder {
+ return b.intValue(sourceParam, dest, 64, false)
+}
+
+// MustInt64 requires parameter value to exist to bind to int64 variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt64(sourceParam string, dest *int64) *ValueBinder {
+ return b.intValue(sourceParam, dest, 64, true)
+}
+
+// Int32 binds parameter to int32 variable
+func (b *ValueBinder) Int32(sourceParam string, dest *int32) *ValueBinder {
+ return b.intValue(sourceParam, dest, 32, false)
+}
+
+// MustInt32 requires parameter value to exist to bind to int32 variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt32(sourceParam string, dest *int32) *ValueBinder {
+ return b.intValue(sourceParam, dest, 32, true)
+}
+
+// Int16 binds parameter to int16 variable
+func (b *ValueBinder) Int16(sourceParam string, dest *int16) *ValueBinder {
+ return b.intValue(sourceParam, dest, 16, false)
+}
+
+// MustInt16 requires parameter value to exist to bind to int16 variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt16(sourceParam string, dest *int16) *ValueBinder {
+ return b.intValue(sourceParam, dest, 16, true)
+}
+
+// Int8 binds parameter to int8 variable
+func (b *ValueBinder) Int8(sourceParam string, dest *int8) *ValueBinder {
+ return b.intValue(sourceParam, dest, 8, false)
+}
+
+// MustInt8 requires parameter value to exist to bind to int8 variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt8(sourceParam string, dest *int8) *ValueBinder {
+ return b.intValue(sourceParam, dest, 8, true)
+}
+
+// Int binds parameter to int variable
+func (b *ValueBinder) Int(sourceParam string, dest *int) *ValueBinder {
+ return b.intValue(sourceParam, dest, 0, false)
+}
+
+// MustInt requires parameter value to exist to bind to int variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt(sourceParam string, dest *int) *ValueBinder {
+ return b.intValue(sourceParam, dest, 0, true)
+}
+
+func (b *ValueBinder) intValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+
+ return b.int(sourceParam, value, dest, bitSize)
+}
+
+func (b *ValueBinder) int(sourceParam string, value string, dest any, bitSize int) *ValueBinder {
+ n, err := strconv.ParseInt(value, 10, bitSize)
+ if err != nil {
+ if bitSize == 0 {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to int", err))
+ } else {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to int%v", bitSize), err))
+ }
+ return b
+ }
+
+ switch d := dest.(type) {
+ case *int64:
+ *d = n
+ case *int32:
+ *d = int32(n) // #nosec G115
+ case *int16:
+ *d = int16(n) // #nosec G115
+ case *int8:
+ *d = int8(n) // #nosec G115
+ case *int:
+ *d = int(n)
+ }
+ return b
+}
+
+func (b *ValueBinder) intsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ values := b.ValuesFunc(sourceParam)
+ if len(values) == 0 {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, values, "required field value is empty", nil))
+ }
+ return b
+ }
+ return b.ints(sourceParam, values, dest)
+}
+
+func (b *ValueBinder) ints(sourceParam string, values []string, dest any) *ValueBinder {
+ switch d := dest.(type) {
+ case *[]int64:
+ tmp := make([]int64, len(values))
+ for i, v := range values {
+ b.int(sourceParam, v, &tmp[i], 64)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]int32:
+ tmp := make([]int32, len(values))
+ for i, v := range values {
+ b.int(sourceParam, v, &tmp[i], 32)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]int16:
+ tmp := make([]int16, len(values))
+ for i, v := range values {
+ b.int(sourceParam, v, &tmp[i], 16)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]int8:
+ tmp := make([]int8, len(values))
+ for i, v := range values {
+ b.int(sourceParam, v, &tmp[i], 8)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]int:
+ tmp := make([]int, len(values))
+ for i, v := range values {
+ b.int(sourceParam, v, &tmp[i], 0)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ }
+ return b
+}
+
+// Int64s binds parameter to slice of int64
+func (b *ValueBinder) Int64s(sourceParam string, dest *[]int64) *ValueBinder {
+ return b.intsValue(sourceParam, dest, false)
+}
+
+// MustInt64s requires parameter value to exist to bind to int64 slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt64s(sourceParam string, dest *[]int64) *ValueBinder {
+ return b.intsValue(sourceParam, dest, true)
+}
+
+// Int32s binds parameter to slice of int32
+func (b *ValueBinder) Int32s(sourceParam string, dest *[]int32) *ValueBinder {
+ return b.intsValue(sourceParam, dest, false)
+}
+
+// MustInt32s requires parameter value to exist to bind to int32 slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt32s(sourceParam string, dest *[]int32) *ValueBinder {
+ return b.intsValue(sourceParam, dest, true)
+}
+
+// Int16s binds parameter to slice of int16
+func (b *ValueBinder) Int16s(sourceParam string, dest *[]int16) *ValueBinder {
+ return b.intsValue(sourceParam, dest, false)
+}
+
+// MustInt16s requires parameter value to exist to bind to int16 slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt16s(sourceParam string, dest *[]int16) *ValueBinder {
+ return b.intsValue(sourceParam, dest, true)
+}
+
+// Int8s binds parameter to slice of int8
+func (b *ValueBinder) Int8s(sourceParam string, dest *[]int8) *ValueBinder {
+ return b.intsValue(sourceParam, dest, false)
+}
+
+// MustInt8s requires parameter value to exist to bind to int8 slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt8s(sourceParam string, dest *[]int8) *ValueBinder {
+ return b.intsValue(sourceParam, dest, true)
+}
+
+// Ints binds parameter to slice of int
+func (b *ValueBinder) Ints(sourceParam string, dest *[]int) *ValueBinder {
+ return b.intsValue(sourceParam, dest, false)
+}
+
+// MustInts requires parameter value to exist to bind to int slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustInts(sourceParam string, dest *[]int) *ValueBinder {
+ return b.intsValue(sourceParam, dest, true)
+}
+
+// Uint64 binds parameter to uint64 variable
+func (b *ValueBinder) Uint64(sourceParam string, dest *uint64) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 64, false)
+}
+
+// MustUint64 requires parameter value to exist to bind to uint64 variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint64(sourceParam string, dest *uint64) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 64, true)
+}
+
+// Uint32 binds parameter to uint32 variable
+func (b *ValueBinder) Uint32(sourceParam string, dest *uint32) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 32, false)
+}
+
+// MustUint32 requires parameter value to exist to bind to uint32 variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint32(sourceParam string, dest *uint32) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 32, true)
+}
+
+// Uint16 binds parameter to uint16 variable
+func (b *ValueBinder) Uint16(sourceParam string, dest *uint16) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 16, false)
+}
+
+// MustUint16 requires parameter value to exist to bind to uint16 variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint16(sourceParam string, dest *uint16) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 16, true)
+}
+
+// Uint8 binds parameter to uint8 variable
+func (b *ValueBinder) Uint8(sourceParam string, dest *uint8) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 8, false)
+}
+
+// MustUint8 requires parameter value to exist to bind to uint8 variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint8(sourceParam string, dest *uint8) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 8, true)
+}
+
+// Byte binds parameter to byte variable
+func (b *ValueBinder) Byte(sourceParam string, dest *byte) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 8, false)
+}
+
+// MustByte requires parameter value to exist to bind to byte variable. Returns error when value does not exist
+func (b *ValueBinder) MustByte(sourceParam string, dest *byte) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 8, true)
+}
+
+// Uint binds parameter to uint variable
+func (b *ValueBinder) Uint(sourceParam string, dest *uint) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 0, false)
+}
+
+// MustUint requires parameter value to exist to bind to uint variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint(sourceParam string, dest *uint) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 0, true)
+}
+
+func (b *ValueBinder) uintValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+
+ return b.uint(sourceParam, value, dest, bitSize)
+}
+
+func (b *ValueBinder) uint(sourceParam string, value string, dest any, bitSize int) *ValueBinder {
+ n, err := strconv.ParseUint(value, 10, bitSize)
+ if err != nil {
+ if bitSize == 0 {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to uint", err))
+ } else {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to uint%v", bitSize), err))
+ }
+ return b
+ }
+
+ switch d := dest.(type) {
+ case *uint64:
+ *d = n
+ case *uint32:
+ *d = uint32(n) // #nosec G115
+ case *uint16:
+ *d = uint16(n) // #nosec G115
+ case *uint8: // byte is alias to uint8
+ *d = uint8(n) // #nosec G115
+ case *uint:
+ *d = uint(n) // #nosec G115
+ }
+ return b
+}
+
+func (b *ValueBinder) uintsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ values := b.ValuesFunc(sourceParam)
+ if len(values) == 0 {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, values, "required field value is empty", nil))
+ }
+ return b
+ }
+ return b.uints(sourceParam, values, dest)
+}
+
+func (b *ValueBinder) uints(sourceParam string, values []string, dest any) *ValueBinder {
+ switch d := dest.(type) {
+ case *[]uint64:
+ tmp := make([]uint64, len(values))
+ for i, v := range values {
+ b.uint(sourceParam, v, &tmp[i], 64)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]uint32:
+ tmp := make([]uint32, len(values))
+ for i, v := range values {
+ b.uint(sourceParam, v, &tmp[i], 32)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]uint16:
+ tmp := make([]uint16, len(values))
+ for i, v := range values {
+ b.uint(sourceParam, v, &tmp[i], 16)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]uint8: // byte is alias to uint8
+ tmp := make([]uint8, len(values))
+ for i, v := range values {
+ b.uint(sourceParam, v, &tmp[i], 8)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]uint:
+ tmp := make([]uint, len(values))
+ for i, v := range values {
+ b.uint(sourceParam, v, &tmp[i], 0)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ }
+ return b
+}
+
+// Uint64s binds parameter to slice of uint64
+func (b *ValueBinder) Uint64s(sourceParam string, dest *[]uint64) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, false)
+}
+
+// MustUint64s requires parameter value to exist to bind to uint64 slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint64s(sourceParam string, dest *[]uint64) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, true)
+}
+
+// Uint32s binds parameter to slice of uint32
+func (b *ValueBinder) Uint32s(sourceParam string, dest *[]uint32) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, false)
+}
+
+// MustUint32s requires parameter value to exist to bind to uint32 slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint32s(sourceParam string, dest *[]uint32) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, true)
+}
+
+// Uint16s binds parameter to slice of uint16
+func (b *ValueBinder) Uint16s(sourceParam string, dest *[]uint16) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, false)
+}
+
+// MustUint16s requires parameter value to exist to bind to uint16 slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint16s(sourceParam string, dest *[]uint16) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, true)
+}
+
+// Uint8s binds parameter to slice of uint8
+func (b *ValueBinder) Uint8s(sourceParam string, dest *[]uint8) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, false)
+}
+
+// MustUint8s requires parameter value to exist to bind to uint8 slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint8s(sourceParam string, dest *[]uint8) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, true)
+}
+
+// Uints binds parameter to slice of uint
+func (b *ValueBinder) Uints(sourceParam string, dest *[]uint) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, false)
+}
+
+// MustUints requires parameter value to exist to bind to uint slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustUints(sourceParam string, dest *[]uint) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, true)
+}
+
+// Bool binds parameter to bool variable
+func (b *ValueBinder) Bool(sourceParam string, dest *bool) *ValueBinder {
+ return b.boolValue(sourceParam, dest, false)
+}
+
+// MustBool requires parameter value to exist to bind to bool variable. Returns error when value does not exist
+func (b *ValueBinder) MustBool(sourceParam string, dest *bool) *ValueBinder {
+ return b.boolValue(sourceParam, dest, true)
+}
+
+func (b *ValueBinder) boolValue(sourceParam string, dest *bool, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+ return b.bool(sourceParam, value, dest)
+}
+
+func (b *ValueBinder) bool(sourceParam string, value string, dest *bool) *ValueBinder {
+ n, err := strconv.ParseBool(value)
+ if err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to bool", err))
+ return b
+ }
+
+ *dest = n
+ return b
+}
+
+func (b *ValueBinder) boolsValue(sourceParam string, dest *[]bool, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ values := b.ValuesFunc(sourceParam)
+ if len(values) == 0 {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+ return b.bools(sourceParam, values, dest)
+}
+
+func (b *ValueBinder) bools(sourceParam string, values []string, dest *[]bool) *ValueBinder {
+ tmp := make([]bool, len(values))
+ for i, v := range values {
+ b.bool(sourceParam, v, &tmp[i])
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *dest = tmp
+ }
+ return b
+}
+
+// Bools binds parameter values to slice of bool variables
+func (b *ValueBinder) Bools(sourceParam string, dest *[]bool) *ValueBinder {
+ return b.boolsValue(sourceParam, dest, false)
+}
+
+// MustBools requires parameter values to exist to bind to slice of bool variables. Returns error when values does not exist
+func (b *ValueBinder) MustBools(sourceParam string, dest *[]bool) *ValueBinder {
+ return b.boolsValue(sourceParam, dest, true)
+}
+
+// Float64 binds parameter to float64 variable
+func (b *ValueBinder) Float64(sourceParam string, dest *float64) *ValueBinder {
+ return b.floatValue(sourceParam, dest, 64, false)
+}
+
+// MustFloat64 requires parameter value to exist to bind to float64 variable. Returns error when value does not exist
+func (b *ValueBinder) MustFloat64(sourceParam string, dest *float64) *ValueBinder {
+ return b.floatValue(sourceParam, dest, 64, true)
+}
+
+// Float32 binds parameter to float32 variable
+func (b *ValueBinder) Float32(sourceParam string, dest *float32) *ValueBinder {
+ return b.floatValue(sourceParam, dest, 32, false)
+}
+
+// MustFloat32 requires parameter value to exist to bind to float32 variable. Returns error when value does not exist
+func (b *ValueBinder) MustFloat32(sourceParam string, dest *float32) *ValueBinder {
+ return b.floatValue(sourceParam, dest, 32, true)
+}
+
+func (b *ValueBinder) floatValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+
+ return b.float(sourceParam, value, dest, bitSize)
+}
+
+func (b *ValueBinder) float(sourceParam string, value string, dest any, bitSize int) *ValueBinder {
+ n, err := strconv.ParseFloat(value, bitSize)
+ if err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to float%v", bitSize), err))
+ return b
+ }
+
+ switch d := dest.(type) {
+ case *float64:
+ *d = n
+ case *float32:
+ *d = float32(n)
+ }
+ return b
+}
+
+func (b *ValueBinder) floatsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ values := b.ValuesFunc(sourceParam)
+ if len(values) == 0 {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+ return b.floats(sourceParam, values, dest)
+}
+
+func (b *ValueBinder) floats(sourceParam string, values []string, dest any) *ValueBinder {
+ switch d := dest.(type) {
+ case *[]float64:
+ tmp := make([]float64, len(values))
+ for i, v := range values {
+ b.float(sourceParam, v, &tmp[i], 64)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]float32:
+ tmp := make([]float32, len(values))
+ for i, v := range values {
+ b.float(sourceParam, v, &tmp[i], 32)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ }
+ return b
+}
+
+// Float64s binds parameter values to slice of float64 variables
+func (b *ValueBinder) Float64s(sourceParam string, dest *[]float64) *ValueBinder {
+ return b.floatsValue(sourceParam, dest, false)
+}
+
+// MustFloat64s requires parameter values to exist to bind to slice of float64 variables. Returns error when values does not exist
+func (b *ValueBinder) MustFloat64s(sourceParam string, dest *[]float64) *ValueBinder {
+ return b.floatsValue(sourceParam, dest, true)
+}
+
+// Float32s binds parameter values to slice of float32 variables
+func (b *ValueBinder) Float32s(sourceParam string, dest *[]float32) *ValueBinder {
+ return b.floatsValue(sourceParam, dest, false)
+}
+
+// MustFloat32s requires parameter values to exist to bind to slice of float32 variables. Returns error when values does not exist
+func (b *ValueBinder) MustFloat32s(sourceParam string, dest *[]float32) *ValueBinder {
+ return b.floatsValue(sourceParam, dest, true)
+}
+
+// Time binds parameter to time.Time variable
+func (b *ValueBinder) Time(sourceParam string, dest *time.Time, layout string) *ValueBinder {
+ return b.time(sourceParam, dest, layout, false)
+}
+
+// MustTime requires parameter value to exist to bind to time.Time variable. Returns error when value does not exist
+func (b *ValueBinder) MustTime(sourceParam string, dest *time.Time, layout string) *ValueBinder {
+ return b.time(sourceParam, dest, layout, true)
+}
+
+func (b *ValueBinder) time(sourceParam string, dest *time.Time, layout string, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil))
+ }
+ return b
+ }
+ t, err := time.Parse(layout, value)
+ if err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Time", err))
+ return b
+ }
+ *dest = t
+ return b
+}
+
+// Times binds parameter values to slice of time.Time variables
+func (b *ValueBinder) Times(sourceParam string, dest *[]time.Time, layout string) *ValueBinder {
+ return b.times(sourceParam, dest, layout, false)
+}
+
+// MustTimes requires parameter values to exist to bind to slice of time.Time variables. Returns error when values does not exist
+func (b *ValueBinder) MustTimes(sourceParam string, dest *[]time.Time, layout string) *ValueBinder {
+ return b.times(sourceParam, dest, layout, true)
+}
+
+func (b *ValueBinder) times(sourceParam string, dest *[]time.Time, layout string, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ values := b.ValuesFunc(sourceParam)
+ if len(values) == 0 {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+
+ tmp := make([]time.Time, len(values))
+ for i, v := range values {
+ t, err := time.Parse(layout, v)
+ if err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{v}, "failed to bind field value to Time", err))
+ if b.failFast {
+ return b
+ }
+ continue
+ }
+ tmp[i] = t
+ }
+ if b.errors == nil {
+ *dest = tmp
+ }
+ return b
+}
+
+// Duration binds parameter to time.Duration variable
+func (b *ValueBinder) Duration(sourceParam string, dest *time.Duration) *ValueBinder {
+ return b.duration(sourceParam, dest, false)
+}
+
+// MustDuration requires parameter value to exist to bind to time.Duration variable. Returns error when value does not exist
+func (b *ValueBinder) MustDuration(sourceParam string, dest *time.Duration) *ValueBinder {
+ return b.duration(sourceParam, dest, true)
+}
+
+func (b *ValueBinder) duration(sourceParam string, dest *time.Duration, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil))
+ }
+ return b
+ }
+ t, err := time.ParseDuration(value)
+ if err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Duration", err))
+ return b
+ }
+ *dest = t
+ return b
+}
+
+// Durations binds parameter values to slice of time.Duration variables
+func (b *ValueBinder) Durations(sourceParam string, dest *[]time.Duration) *ValueBinder {
+ return b.durationsValue(sourceParam, dest, false)
+}
+
+// MustDurations requires parameter values to exist to bind to slice of time.Duration variables. Returns error when values does not exist
+func (b *ValueBinder) MustDurations(sourceParam string, dest *[]time.Duration) *ValueBinder {
+ return b.durationsValue(sourceParam, dest, true)
+}
+
+func (b *ValueBinder) durationsValue(sourceParam string, dest *[]time.Duration, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ values := b.ValuesFunc(sourceParam)
+ if len(values) == 0 {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+ return b.durations(sourceParam, values, dest)
+}
+
+func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]time.Duration) *ValueBinder {
+ tmp := make([]time.Duration, len(values))
+ for i, v := range values {
+ t, err := time.ParseDuration(v)
+ if err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{v}, "failed to bind field value to Duration", err))
+ if b.failFast {
+ return b
+ }
+ continue
+ }
+ tmp[i] = t
+ }
+ if b.errors == nil {
+ *dest = tmp
+ }
+ return b
+}
+
+// UnixTime binds parameter to time.Time variable (in local Time corresponding to the given Unix time).
+//
+// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00
+//
+// Note:
+// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
+func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder {
+ return b.unixTime(sourceParam, dest, false, time.Second)
+}
+
+// MustUnixTime requires parameter value to exist to bind to time.Duration variable (in local time corresponding
+// to the given Unix time). Returns error when value does not exist.
+//
+// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00
+//
+// Note:
+// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
+func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder {
+ return b.unixTime(sourceParam, dest, true, time.Second)
+}
+
+// UnixTimeMilli binds parameter to time.Time variable (in local time corresponding to the given Unix time in millisecond precision).
+//
+// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
+//
+// Note:
+// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
+func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
+ return b.unixTime(sourceParam, dest, false, time.Millisecond)
+}
+
+// MustUnixTimeMilli requires parameter value to exist to bind to time.Duration variable (in local time corresponding
+// to the given Unix time in millisecond precision). Returns error when value does not exist.
+//
+// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
+//
+// Note:
+// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
+func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
+ return b.unixTime(sourceParam, dest, true, time.Millisecond)
+}
+
+// UnixTimeNano binds parameter to time.Time variable (in local time corresponding to the given Unix time in nanosecond precision).
+//
+// Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00
+// Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00
+// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00
+//
+// Note:
+// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
+// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
+func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
+ return b.unixTime(sourceParam, dest, false, time.Nanosecond)
+}
+
+// MustUnixTimeNano requires parameter value to exist to bind to time.Duration variable (in local Time corresponding
+// to the given Unix time value in nano second precision). Returns error when value does not exist.
+//
+// Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00
+// Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00
+// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00
+//
+// Note:
+// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
+// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
+func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
+ return b.unixTime(sourceParam, dest, true, time.Nanosecond)
+}
+
+func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, precision time.Duration) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil))
+ }
+ return b
+ }
+
+ n, err := strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Time", err))
+ return b
+ }
+
+ switch precision {
+ case time.Second:
+ *dest = time.Unix(n, 0)
+ case time.Millisecond:
+ *dest = time.UnixMilli(n)
+ case time.Nanosecond:
+ *dest = time.Unix(0, n)
+ }
+ return b
+}
diff --git a/binder_external_test.go b/binder_external_test.go
new file mode 100644
index 000000000..d83c891b3
--- /dev/null
+++ b/binder_external_test.go
@@ -0,0 +1,134 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+// run tests as external package to get real feel for API
+package echo_test
+
+import (
+ "encoding/base64"
+ "fmt"
+ "log"
+ "net/http"
+ "net/http/httptest"
+
+ "github.com/labstack/echo/v5"
+)
+
+func ExampleValueBinder_BindErrors() {
+ // example route function that binds query params to different destinations and returns all bind errors in one go
+ routeFunc := func(c *echo.Context) error {
+ var opts struct {
+ IDs []int64
+ Active bool
+ }
+ length := int64(50) // default length is 50
+
+ b := echo.QueryParamsBinder(c)
+
+ errs := b.Int64("length", &length).
+ Int64s("ids", &opts.IDs).
+ Bool("active", &opts.Active).
+ BindErrors() // returns all errors
+ if errs != nil {
+ for _, err := range errs {
+ bErr := err.(*echo.BindingError)
+ log.Printf("in case you want to access what field: %s values: %v failed", bErr.Field, bErr.Values)
+ }
+ return fmt.Errorf("%v fields failed to bind", len(errs))
+ }
+ fmt.Printf("active = %v, length = %v, ids = %v", opts.Active, length, opts.IDs)
+
+ return c.JSON(http.StatusOK, opts)
+ }
+
+ e := echo.New()
+ c := e.NewContext(
+ httptest.NewRequest(http.MethodGet, "/api/endpoint?active=true&length=25&ids=1&ids=2&ids=3", nil),
+ httptest.NewRecorder(),
+ )
+
+ _ = routeFunc(c)
+
+ // Output: active = true, length = 25, ids = [1 2 3]
+}
+
+func ExampleValueBinder_BindError() {
+ // example route function that binds query params to different destinations and stops binding on first bind error
+ failFastRouteFunc := func(c *echo.Context) error {
+ var opts struct {
+ IDs []int64
+ Active bool
+ }
+ length := int64(50) // default length is 50
+
+ // create binder that stops binding at first error
+ b := echo.QueryParamsBinder(c)
+
+ err := b.Int64("length", &length).
+ Int64s("ids", &opts.IDs).
+ Bool("active", &opts.Active).
+ BindError() // returns first binding error
+ if err != nil {
+ bErr := err.(*echo.BindingError)
+ return fmt.Errorf("my own custom error for field: %s values: %v", bErr.Field, bErr.Values)
+ }
+ fmt.Printf("active = %v, length = %v, ids = %v\n", opts.Active, length, opts.IDs)
+
+ return c.JSON(http.StatusOK, opts)
+ }
+
+ e := echo.New()
+ c := e.NewContext(
+ httptest.NewRequest(http.MethodGet, "/api/endpoint?active=true&length=25&ids=1&ids=2&ids=3", nil),
+ httptest.NewRecorder(),
+ )
+
+ _ = failFastRouteFunc(c)
+
+ // Output: active = true, length = 25, ids = [1 2 3]
+}
+
+func ExampleValueBinder_CustomFunc() {
+ // example route function that binds query params using custom function closure
+ routeFunc := func(c *echo.Context) error {
+ length := int64(50) // default length is 50
+ var binary []byte
+
+ b := echo.QueryParamsBinder(c)
+ errs := b.Int64("length", &length).
+ CustomFunc("base64", func(values []string) []error {
+ if len(values) == 0 {
+ return nil
+ }
+ decoded, err := base64.URLEncoding.DecodeString(values[0])
+ if err != nil {
+ // in this example we use only first param value but url could contain multiple params in reality and
+ // therefore in theory produce multiple binding errors
+ return []error{echo.NewBindingError("base64", values[0:1], "failed to decode base64", err)}
+ }
+ binary = decoded
+ return nil
+ }).
+ BindErrors() // returns all errors
+
+ if errs != nil {
+ for _, err := range errs {
+ bErr := err.(*echo.BindingError)
+ log.Printf("in case you want to access what field: %s values: %v failed", bErr.Field, bErr.Values)
+ }
+ return fmt.Errorf("%v fields failed to bind", len(errs))
+ }
+ fmt.Printf("length = %v, base64 = %s", length, binary)
+
+ return c.JSON(http.StatusOK, "ok")
+ }
+
+ e := echo.New()
+ c := e.NewContext(
+ httptest.NewRequest(http.MethodGet, "/api/endpoint?length=25&base64=SGVsbG8gV29ybGQ%3D", nil),
+ httptest.NewRecorder(),
+ )
+ _ = routeFunc(c)
+
+ // Output: length = 25, base64 = Hello World
+}
diff --git a/binder_generic.go b/binder_generic.go
new file mode 100644
index 000000000..0c0eb9089
--- /dev/null
+++ b/binder_generic.go
@@ -0,0 +1,563 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "encoding"
+ "encoding/json"
+ "fmt"
+ "strconv"
+ "time"
+)
+
+// TimeLayout specifies the format for parsing time values in request parameters.
+// It can be a standard Go time layout string or one of the special Unix time layouts.
+type TimeLayout string
+
+// TimeOpts is options for parsing time.Time values
+type TimeOpts struct {
+ // Layout specifies the format for parsing time values in request parameters.
+ // It can be a standard Go time layout string or one of the special Unix time layouts.
+ //
+ // Parsing layout defaults to: echo.TimeLayout(time.RFC3339Nano)
+ // - To convert to custom layout use `echo.TimeLayout("2006-01-02")`
+ // - To convert unix timestamp (integer) to time.Time use `echo.TimeLayoutUnixTime`
+ // - To convert unix timestamp in milliseconds to time.Time use `echo.TimeLayoutUnixTimeMilli`
+ // - To convert unix timestamp in nanoseconds to time.Time use `echo.TimeLayoutUnixTimeNano`
+ Layout TimeLayout
+
+ // ParseInLocation is location used with time.ParseInLocation for layout that do not contain
+ // timezone information to set output time in given location.
+ // Defaults to time.UTC
+ ParseInLocation *time.Location
+
+ // ToInLocation is location to which parsed time is converted to after parsing.
+ // The parsed time will be converted using time.In(ToInLocation).
+ // Defaults to time.UTC
+ ToInLocation *time.Location
+}
+
+// TimeLayout constants for parsing Unix timestamps in different precisions.
+const (
+ TimeLayoutUnixTime = TimeLayout("UnixTime") // Unix timestamp in seconds
+ TimeLayoutUnixTimeMilli = TimeLayout("UnixTimeMilli") // Unix timestamp in milliseconds
+ TimeLayoutUnixTimeNano = TimeLayout("UnixTimeNano") // Unix timestamp in nanoseconds
+)
+
+// PathParam extracts and parses a path parameter from the context by name.
+// It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found.
+//
+// Empty String Handling:
+// If the parameter exists but has an empty value, the zero value of type T is returned
+// with no error. For example, a path parameter with value "" returns (0, nil) for int types.
+// This differs from standard library behavior where parsing empty strings returns errors.
+// To treat empty values as errors, validate the result separately or check the raw value.
+//
+// See ParseValue for supported types and options
+func PathParam[T any](c *Context, paramName string, opts ...any) (T, error) {
+ for _, pv := range c.PathValues() {
+ if pv.Name == paramName {
+ v, err := ParseValue[T](pv.Value, opts...)
+ if err != nil {
+ return v, NewBindingError(paramName, []string{pv.Value}, "path value", err)
+ }
+ return v, nil
+ }
+ }
+ var zero T
+ return zero, ErrNonExistentKey
+}
+
+// PathParamOr extracts and parses a path parameter from the context by name.
+// Returns defaultValue if the parameter is not found or has an empty value.
+// Returns an error only if parsing fails (e.g., "abc" for int type).
+//
+// Example:
+// id, err := echo.PathParamOr[int](c, "id", 0)
+// // If "id" is missing: returns (0, nil)
+// // If "id" is "123": returns (123, nil)
+// // If "id" is "abc": returns (0, BindingError)
+//
+// See ParseValue for supported types and options
+func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...any) (T, error) {
+ for _, pv := range c.PathValues() {
+ if pv.Name == paramName {
+ v, err := ParseValueOr[T](pv.Value, defaultValue, opts...)
+ if err != nil {
+ return v, NewBindingError(paramName, []string{pv.Value}, "path value", err)
+ }
+ return v, nil
+ }
+ }
+ return defaultValue, nil
+}
+
+// QueryParam extracts and parses a single query parameter from the request by key.
+// It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found.
+//
+// Empty String Handling:
+// If the parameter exists but has an empty value (?key=), the zero value of type T is returned
+// with no error. For example, "?count=" returns (0, nil) for int types.
+// This differs from standard library behavior where parsing empty strings returns errors.
+// To treat empty values as errors, validate the result separately or check the raw value.
+//
+// Behavior Summary:
+// - Missing key (?other=value): returns (zero, ErrNonExistentKey)
+// - Empty value (?key=): returns (zero, nil)
+// - Invalid value (?key=abc for int): returns (zero, BindingError)
+//
+// See ParseValue for supported types and options
+func QueryParam[T any](c *Context, key string, opts ...any) (T, error) {
+ values, ok := c.QueryParams()[key]
+ if !ok {
+ var zero T
+ return zero, ErrNonExistentKey
+ }
+ if len(values) == 0 {
+ var zero T
+ return zero, nil
+ }
+ value := values[0]
+ v, err := ParseValue[T](value, opts...)
+ if err != nil {
+ return v, NewBindingError(key, []string{value}, "query param", err)
+ }
+ return v, nil
+}
+
+// QueryParamOr extracts and parses a single query parameter from the request by key.
+// Returns defaultValue if the parameter is not found or has an empty value.
+// Returns an error only if parsing fails (e.g., "abc" for int type).
+//
+// Example:
+// page, err := echo.QueryParamOr[int](c, "page", 1)
+// // If "page" is missing: returns (1, nil)
+// // If "page" is "5": returns (5, nil)
+// // If "page" is "abc": returns (1, BindingError)
+//
+// See ParseValue for supported types and options
+func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) {
+ values, ok := c.QueryParams()[key]
+ if !ok {
+ return defaultValue, nil
+ }
+ if len(values) == 0 {
+ return defaultValue, nil
+ }
+ value := values[0]
+ v, err := ParseValueOr[T](value, defaultValue, opts...)
+ if err != nil {
+ return v, NewBindingError(key, []string{value}, "query param", err)
+ }
+ return v, nil
+}
+
+// QueryParams extracts and parses all values for a query parameter key as a slice.
+// It returns the typed slice and an error if binding any value fails. Returns ErrNonExistentKey if parameter not found.
+//
+// See ParseValues for supported types and options
+func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error) {
+ values, ok := c.QueryParams()[key]
+ if !ok {
+ return nil, ErrNonExistentKey
+ }
+
+ result, err := ParseValues[T](values, opts...)
+ if err != nil {
+ return nil, NewBindingError(key, values, "query params", err)
+ }
+ return result, nil
+}
+
+// QueryParamsOr extracts and parses all values for a query parameter key as a slice.
+// Returns defaultValue if the parameter is not found.
+// Returns an error only if parsing any value fails.
+//
+// Example:
+// ids, err := echo.QueryParamsOr[int](c, "ids", []int{})
+// // If "ids" is missing: returns ([], nil)
+// // If "ids" is "1&ids=2": returns ([1, 2], nil)
+// // If "ids" contains "abc": returns ([], BindingError)
+//
+// See ParseValues for supported types and options
+func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) {
+ values, ok := c.QueryParams()[key]
+ if !ok {
+ return defaultValue, nil
+ }
+
+ result, err := ParseValuesOr[T](values, defaultValue, opts...)
+ if err != nil {
+ return nil, NewBindingError(key, values, "query params", err)
+ }
+ return result, nil
+}
+
+// FormValue extracts and parses a single form value from the request by key.
+// It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found.
+//
+// Empty String Handling:
+// If the form field exists but has an empty value, the zero value of type T is returned
+// with no error. For example, an empty form field returns (0, nil) for int types.
+// This differs from standard library behavior where parsing empty strings returns errors.
+// To treat empty values as errors, validate the result separately or check the raw value.
+//
+// See ParseValue for supported types and options
+func FormValue[T any](c *Context, key string, opts ...any) (T, error) {
+ formValues, err := c.FormValues()
+ if err != nil {
+ var zero T
+ return zero, fmt.Errorf("failed to parse form value, key: %s, err: %w", key, err)
+ }
+ values, ok := formValues[key]
+ if !ok {
+ var zero T
+ return zero, ErrNonExistentKey
+ }
+ if len(values) == 0 {
+ var zero T
+ return zero, nil
+ }
+ value := values[0]
+ v, err := ParseValue[T](value, opts...)
+ if err != nil {
+ return v, NewBindingError(key, []string{value}, "form value", err)
+ }
+ return v, nil
+}
+
+// FormValueOr extracts and parses a single form value from the request by key.
+// Returns defaultValue if the parameter is not found or has an empty value.
+// Returns an error only if parsing fails or form parsing errors occur.
+//
+// Example:
+// limit, err := echo.FormValueOr[int](c, "limit", 100)
+// // If "limit" is missing: returns (100, nil)
+// // If "limit" is "50": returns (50, nil)
+// // If "limit" is "abc": returns (100, BindingError)
+//
+// See ParseValue for supported types and options
+func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) {
+ formValues, err := c.FormValues()
+ if err != nil {
+ var zero T
+ return zero, fmt.Errorf("failed to parse form value, key: %s, err: %w", key, err)
+ }
+ values, ok := formValues[key]
+ if !ok {
+ return defaultValue, nil
+ }
+ if len(values) == 0 {
+ return defaultValue, nil
+ }
+ value := values[0]
+ v, err := ParseValueOr[T](value, defaultValue, opts...)
+ if err != nil {
+ return v, NewBindingError(key, []string{value}, "form value", err)
+ }
+ return v, nil
+}
+
+// FormValues extracts and parses all values for a form values key as a slice.
+// It returns the typed slice and an error if binding any value fails. Returns ErrNonExistentKey if parameter not found.
+//
+// See ParseValues for supported types and options
+func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) {
+ formValues, err := c.FormValues()
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse form values, key: %s, err: %w", key, err)
+ }
+ values, ok := formValues[key]
+ if !ok {
+ return nil, ErrNonExistentKey
+ }
+ result, err := ParseValues[T](values, opts...)
+ if err != nil {
+ return nil, NewBindingError(key, values, "form values", err)
+ }
+ return result, nil
+}
+
+// FormValuesOr extracts and parses all values for a form values key as a slice.
+// Returns defaultValue if the parameter is not found.
+// Returns an error only if parsing any value fails or form parsing errors occur.
+//
+// Example:
+// tags, err := echo.FormValuesOr[string](c, "tags", []string{})
+// // If "tags" is missing: returns ([], nil)
+// // If form parsing fails: returns (nil, error)
+//
+// See ParseValues for supported types and options
+func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) {
+ formValues, err := c.FormValues()
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse form values, key: %s, err: %w", key, err)
+ }
+ values, ok := formValues[key]
+ if !ok {
+ return defaultValue, nil
+ }
+ result, err := ParseValuesOr[T](values, defaultValue, opts...)
+ if err != nil {
+ return nil, NewBindingError(key, values, "form values", err)
+ }
+ return result, nil
+}
+
+// ParseValues parses value to generic type slice. Same types are supported as ParseValue
+// function but the result type is slice instead of scalar value.
+//
+// See ParseValue for supported types and options
+func ParseValues[T any](values []string, opts ...any) ([]T, error) {
+ var zero []T
+ return ParseValuesOr(values, zero, opts...)
+}
+
+// ParseValuesOr parses value to generic type slice, when value is empty defaultValue is returned.
+// Same types are supported as ParseValue function but the result type is slice instead of scalar value.
+//
+// See ParseValue for supported types and options
+func ParseValuesOr[T any](values []string, defaultValue []T, opts ...any) ([]T, error) {
+ if len(values) == 0 {
+ return defaultValue, nil
+ }
+ result := make([]T, 0, len(values))
+ for _, v := range values {
+ tmp, err := ParseValue[T](v, opts...)
+ if err != nil {
+ return nil, err
+ }
+ result = append(result, tmp)
+ }
+ return result, nil
+}
+
+// ParseValue parses value to generic type
+//
+// Types that are supported:
+// - bool
+// - float32
+// - float64
+// - int
+// - int8
+// - int16
+// - int32
+// - int64
+// - uint
+// - uint8/byte
+// - uint16
+// - uint32
+// - uint64
+// - string
+// - echo.BindUnmarshaler interface
+// - encoding.TextUnmarshaler interface
+// - json.Unmarshaler interface
+// - time.Duration
+// - time.Time use echo.TimeOpts or echo.TimeLayout to set time parsing configuration
+func ParseValue[T any](value string, opts ...any) (T, error) {
+ var zero T
+ return ParseValueOr(value, zero, opts...)
+}
+
+// ParseValueOr parses value to generic type, when value is empty defaultValue is returned.
+//
+// Types that are supported:
+// - bool
+// - float32
+// - float64
+// - int
+// - int8
+// - int16
+// - int32
+// - int64
+// - uint
+// - uint8/byte
+// - uint16
+// - uint32
+// - uint64
+// - string
+// - echo.BindUnmarshaler interface
+// - encoding.TextUnmarshaler interface
+// - json.Unmarshaler interface
+// - time.Duration
+// - time.Time use echo.TimeOpts or echo.TimeLayout to set time parsing configuration
+func ParseValueOr[T any](value string, defaultValue T, opts ...any) (T, error) {
+ if len(value) == 0 {
+ return defaultValue, nil
+ }
+ var tmp T
+ if err := bindValue(value, &tmp, opts...); err != nil {
+ var zero T
+ return zero, fmt.Errorf("failed to parse value, err: %w", err)
+ }
+ return tmp, nil
+}
+
+func bindValue(value string, dest any, opts ...any) error {
+ // NOTE: if this function is ever made public the dest should be checked for nil
+ // values when dealing with interfaces
+ if len(opts) > 0 {
+ if _, isTime := dest.(*time.Time); !isTime {
+ return fmt.Errorf("options are only supported for time.Time, got %T", dest)
+ }
+ }
+
+ switch d := dest.(type) {
+ case *bool:
+ n, err := strconv.ParseBool(value)
+ if err != nil {
+ return err
+ }
+ *d = n
+ case *float32:
+ n, err := strconv.ParseFloat(value, 32)
+ if err != nil {
+ return err
+ }
+ *d = float32(n)
+ case *float64:
+ n, err := strconv.ParseFloat(value, 64)
+ if err != nil {
+ return err
+ }
+ *d = n
+ case *int:
+ n, err := strconv.ParseInt(value, 10, 0)
+ if err != nil {
+ return err
+ }
+ *d = int(n)
+ case *int8:
+ n, err := strconv.ParseInt(value, 10, 8)
+ if err != nil {
+ return err
+ }
+ *d = int8(n)
+ case *int16:
+ n, err := strconv.ParseInt(value, 10, 16)
+ if err != nil {
+ return err
+ }
+ *d = int16(n)
+ case *int32:
+ n, err := strconv.ParseInt(value, 10, 32)
+ if err != nil {
+ return err
+ }
+ *d = int32(n)
+ case *int64:
+ n, err := strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ return err
+ }
+ *d = n
+ case *uint:
+ n, err := strconv.ParseUint(value, 10, 0)
+ if err != nil {
+ return err
+ }
+ *d = uint(n)
+ case *uint8:
+ n, err := strconv.ParseUint(value, 10, 8)
+ if err != nil {
+ return err
+ }
+ *d = uint8(n)
+ case *uint16:
+ n, err := strconv.ParseUint(value, 10, 16)
+ if err != nil {
+ return err
+ }
+ *d = uint16(n)
+ case *uint32:
+ n, err := strconv.ParseUint(value, 10, 32)
+ if err != nil {
+ return err
+ }
+ *d = uint32(n)
+ case *uint64:
+ n, err := strconv.ParseUint(value, 10, 64)
+ if err != nil {
+ return err
+ }
+ *d = n
+ case *string:
+ *d = value
+ case *time.Duration:
+ t, err := time.ParseDuration(value)
+ if err != nil {
+ return err
+ }
+ *d = t
+ case *time.Time:
+ to := TimeOpts{
+ Layout: TimeLayout(time.RFC3339Nano),
+ ParseInLocation: time.UTC,
+ ToInLocation: time.UTC,
+ }
+ for _, o := range opts {
+ switch v := o.(type) {
+ case TimeOpts:
+ if v.Layout != "" {
+ to.Layout = v.Layout
+ }
+ if v.ParseInLocation != nil {
+ to.ParseInLocation = v.ParseInLocation
+ }
+ if v.ToInLocation != nil {
+ to.ToInLocation = v.ToInLocation
+ }
+ case TimeLayout:
+ to.Layout = v
+ }
+ }
+ var t time.Time
+ var err error
+ switch to.Layout {
+ case TimeLayoutUnixTime:
+ n, err := strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ return err
+ }
+ t = time.Unix(n, 0)
+ case TimeLayoutUnixTimeMilli:
+ n, err := strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ return err
+ }
+ t = time.UnixMilli(n)
+ case TimeLayoutUnixTimeNano:
+ n, err := strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ return err
+ }
+ t = time.Unix(0, n)
+ default:
+ if to.ParseInLocation != nil {
+ t, err = time.ParseInLocation(string(to.Layout), value, to.ParseInLocation)
+ } else {
+ t, err = time.Parse(string(to.Layout), value)
+ }
+ if err != nil {
+ return err
+ }
+ }
+ *d = t.In(to.ToInLocation)
+ case BindUnmarshaler:
+ if err := d.UnmarshalParam(value); err != nil {
+ return err
+ }
+ case encoding.TextUnmarshaler:
+ if err := d.UnmarshalText([]byte(value)); err != nil {
+ return err
+ }
+ case json.Unmarshaler:
+ if err := d.UnmarshalJSON([]byte(value)); err != nil {
+ return err
+ }
+ default:
+ return fmt.Errorf("unsupported value type: %T", dest)
+ }
+ return nil
+}
diff --git a/binder_generic_test.go b/binder_generic_test.go
new file mode 100644
index 000000000..849d75962
--- /dev/null
+++ b/binder_generic_test.go
@@ -0,0 +1,1616 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "cmp"
+ "encoding/json"
+ "fmt"
+ "math"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+// TextUnmarshalerType implements encoding.TextUnmarshaler but NOT BindUnmarshaler
+type TextUnmarshalerType struct {
+ Value string
+}
+
+func (t *TextUnmarshalerType) UnmarshalText(data []byte) error {
+ s := string(data)
+ if s == "invalid" {
+ return fmt.Errorf("invalid value: %s", s)
+ }
+ t.Value = strings.ToUpper(s)
+ return nil
+}
+
+// JSONUnmarshalerType implements json.Unmarshaler but NOT BindUnmarshaler or TextUnmarshaler
+type JSONUnmarshalerType struct {
+ Value string
+}
+
+func (j *JSONUnmarshalerType) UnmarshalJSON(data []byte) error {
+ return json.Unmarshal(data, &j.Value)
+}
+
+func TestPathParam(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenKey string
+ givenValue string
+ expect bool
+ expectErr string
+ }{
+ {
+ name: "ok",
+ givenValue: "true",
+ expect: true,
+ },
+ {
+ name: "nok, non existent key",
+ givenKey: "missing",
+ givenValue: "true",
+ expect: false,
+ expectErr: ErrNonExistentKey.Error(),
+ },
+ {
+ name: "nok, invalid value",
+ givenValue: "can_parse_me",
+ expect: false,
+ expectErr: `code=400, message=path value, err=failed to parse value, err: strconv.ParseBool: parsing "can_parse_me": invalid syntax, field=key`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := NewContext(nil, nil)
+ c.SetPathValues(PathValues{{
+ Name: cmp.Or(tc.givenKey, "key"),
+ Value: tc.givenValue,
+ }})
+
+ v, err := PathParam[bool](c, "key")
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestPathParam_UnsupportedType(t *testing.T) {
+ c := NewContext(nil, nil)
+ c.SetPathValues(PathValues{{Name: "key", Value: "true"}})
+
+ v, err := PathParam[[]bool](c, "key")
+
+ expectErr := "code=400, message=path value, err=failed to parse value, err: unsupported value type: *[]bool, field=key"
+ assert.EqualError(t, err, expectErr)
+ assert.Equal(t, []bool(nil), v)
+}
+
+func TestQueryParam(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ expect bool
+ expectErr string
+ }{
+ {
+ name: "ok",
+ givenURL: "/?key=true",
+ expect: true,
+ },
+ {
+ name: "nok, non existent key",
+ givenURL: "/?different=true",
+ expect: false,
+ expectErr: ErrNonExistentKey.Error(),
+ },
+ {
+ name: "nok, invalid value",
+ givenURL: "/?key=invalidbool",
+ expect: false,
+ expectErr: `code=400, message=query param, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
+ c := NewContext(req, nil)
+
+ v, err := QueryParam[bool](c, "key")
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestQueryParam_UnsupportedType(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil)
+ c := NewContext(req, nil)
+
+ v, err := QueryParam[[]bool](c, "key")
+
+ expectErr := "code=400, message=query param, err=failed to parse value, err: unsupported value type: *[]bool, field=key"
+ assert.EqualError(t, err, expectErr)
+ assert.Equal(t, []bool(nil), v)
+}
+
+func TestQueryParams(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ expect []bool
+ expectErr string
+ }{
+ {
+ name: "ok",
+ givenURL: "/?key=true&key=false",
+ expect: []bool{true, false},
+ },
+ {
+ name: "nok, non existent key",
+ givenURL: "/?different=true",
+ expect: []bool(nil),
+ expectErr: ErrNonExistentKey.Error(),
+ },
+ {
+ name: "nok, invalid value",
+ givenURL: "/?key=true&key=invalidbool",
+ expect: []bool(nil),
+ expectErr: `code=400, message=query params, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
+ c := NewContext(req, nil)
+
+ v, err := QueryParams[bool](c, "key")
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestQueryParams_UnsupportedType(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil)
+ c := NewContext(req, nil)
+
+ v, err := QueryParams[[]bool](c, "key")
+
+ expectErr := "code=400, message=query params, err=failed to parse value, err: unsupported value type: *[]bool, field=key"
+ assert.EqualError(t, err, expectErr)
+ assert.Equal(t, [][]bool(nil), v)
+}
+
+func TestFormValue(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ expect bool
+ expectErr string
+ }{
+ {
+ name: "ok",
+ givenURL: "/?key=true",
+ expect: true,
+ },
+ {
+ name: "nok, non existent key",
+ givenURL: "/?different=true",
+ expect: false,
+ expectErr: ErrNonExistentKey.Error(),
+ },
+ {
+ name: "nok, invalid value",
+ givenURL: "/?key=invalidbool",
+ expect: false,
+ expectErr: `code=400, message=form value, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
+ c := NewContext(req, nil)
+
+ v, err := FormValue[bool](c, "key")
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestFormValue_UnsupportedType(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil)
+ c := NewContext(req, nil)
+
+ v, err := FormValue[[]bool](c, "key")
+
+ expectErr := "code=400, message=form value, err=failed to parse value, err: unsupported value type: *[]bool, field=key"
+ assert.EqualError(t, err, expectErr)
+ assert.Equal(t, []bool(nil), v)
+}
+
+func TestFormValues(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ expect []bool
+ expectErr string
+ }{
+ {
+ name: "ok",
+ givenURL: "/?key=true&key=false",
+ expect: []bool{true, false},
+ },
+ {
+ name: "nok, non existent key",
+ givenURL: "/?different=true",
+ expect: []bool(nil),
+ expectErr: ErrNonExistentKey.Error(),
+ },
+ {
+ name: "nok, invalid value",
+ givenURL: "/?key=true&key=invalidbool",
+ expect: []bool(nil),
+ expectErr: `code=400, message=form values, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
+ c := NewContext(req, nil)
+
+ v, err := FormValues[bool](c, "key")
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestFormValues_UnsupportedType(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil)
+ c := NewContext(req, nil)
+
+ v, err := FormValues[[]bool](c, "key")
+
+ expectErr := "code=400, message=form values, err=failed to parse value, err: unsupported value type: *[]bool, field=key"
+ assert.EqualError(t, err, expectErr)
+ assert.Equal(t, [][]bool(nil), v)
+}
+
+func TestParseValue_bool(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect bool
+ expectErr error
+ }{
+ {
+ name: "ok, true",
+ when: "true",
+ expect: true,
+ },
+ {
+ name: "ok, false",
+ when: "false",
+ expect: false,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: true,
+ },
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: false,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[bool](tc.when)
+ if tc.expectErr != nil {
+ assert.ErrorIs(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_float32(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect float32
+ expectErr string
+ }{
+ {
+ name: "ok, 123.345",
+ when: "123.345",
+ expect: 123.345,
+ },
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, Inf",
+ when: "+Inf",
+ expect: float32(math.Inf(1)),
+ },
+ {
+ name: "ok, Inf",
+ when: "-Inf",
+ expect: float32(math.Inf(-1)),
+ },
+ {
+ name: "ok, NaN",
+ when: "NaN",
+ expect: float32(math.NaN()),
+ },
+ {
+ name: "ok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseFloat: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[float32](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ if math.IsNaN(float64(tc.expect)) {
+ if !math.IsNaN(float64(v)) {
+ t.Fatal("expected NaN but got non NaN")
+ }
+ } else {
+ assert.Equal(t, tc.expect, v)
+ }
+ })
+ }
+}
+
+func TestParseValue_float64(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect float64
+ expectErr string
+ }{
+ {
+ name: "ok, 123.345",
+ when: "123.345",
+ expect: 123.345,
+ },
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, Inf",
+ when: "+Inf",
+ expect: math.Inf(1),
+ },
+ {
+ name: "ok, Inf",
+ when: "-Inf",
+ expect: math.Inf(-1),
+ },
+ {
+ name: "ok, NaN",
+ when: "NaN",
+ expect: math.NaN(),
+ },
+ {
+ name: "ok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseFloat: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[float64](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ if math.IsNaN(tc.expect) {
+ if !math.IsNaN(v) {
+ t.Fatal("expected NaN but got non NaN")
+ }
+ } else {
+ assert.Equal(t, tc.expect, v)
+ }
+ })
+ }
+}
+
+func TestParseValue_int(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect int
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, -1",
+ when: "-1",
+ expect: -1,
+ },
+ {
+ name: "ok, max int (64bit)",
+ when: "9223372036854775807",
+ expect: 9223372036854775807,
+ },
+ {
+ name: "ok, min int (64bit)",
+ when: "-9223372036854775808",
+ expect: -9223372036854775808,
+ },
+ {
+ name: "ok, overflow max int (64bit)",
+ when: "9223372036854775808",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "9223372036854775808": value out of range`,
+ },
+ {
+ name: "ok, underflow min int (64bit)",
+ when: "-9223372036854775809",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-9223372036854775809": value out of range`,
+ },
+ {
+ name: "ok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[int](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_uint(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect uint
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, max uint (64bit)",
+ when: "18446744073709551615",
+ expect: 18446744073709551615,
+ },
+ {
+ name: "nok, overflow max uint (64bit)",
+ when: "18446744073709551616",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "18446744073709551616": value out of range`,
+ },
+ {
+ name: "nok, negative value",
+ when: "-1",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[uint](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_int8(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect int8
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, -1",
+ when: "-1",
+ expect: -1,
+ },
+ {
+ name: "ok, max int8",
+ when: "127",
+ expect: 127,
+ },
+ {
+ name: "ok, min int8",
+ when: "-128",
+ expect: -128,
+ },
+ {
+ name: "nok, overflow max int8",
+ when: "128",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "128": value out of range`,
+ },
+ {
+ name: "nok, underflow min int8",
+ when: "-129",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-129": value out of range`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[int8](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_int16(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect int16
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, -1",
+ when: "-1",
+ expect: -1,
+ },
+ {
+ name: "ok, max int16",
+ when: "32767",
+ expect: 32767,
+ },
+ {
+ name: "ok, min int16",
+ when: "-32768",
+ expect: -32768,
+ },
+ {
+ name: "nok, overflow max int16",
+ when: "32768",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "32768": value out of range`,
+ },
+ {
+ name: "nok, underflow min int16",
+ when: "-32769",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-32769": value out of range`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[int16](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_int32(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect int32
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, -1",
+ when: "-1",
+ expect: -1,
+ },
+ {
+ name: "ok, max int32",
+ when: "2147483647",
+ expect: 2147483647,
+ },
+ {
+ name: "ok, min int32",
+ when: "-2147483648",
+ expect: -2147483648,
+ },
+ {
+ name: "nok, overflow max int32",
+ when: "2147483648",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "2147483648": value out of range`,
+ },
+ {
+ name: "nok, underflow min int32",
+ when: "-2147483649",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-2147483649": value out of range`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[int32](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_int64(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect int64
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, -1",
+ when: "-1",
+ expect: -1,
+ },
+ {
+ name: "ok, max int64",
+ when: "9223372036854775807",
+ expect: 9223372036854775807,
+ },
+ {
+ name: "ok, min int64",
+ when: "-9223372036854775808",
+ expect: -9223372036854775808,
+ },
+ {
+ name: "nok, overflow max int64",
+ when: "9223372036854775808",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "9223372036854775808": value out of range`,
+ },
+ {
+ name: "nok, underflow min int64",
+ when: "-9223372036854775809",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-9223372036854775809": value out of range`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[int64](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_uint8(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect uint8
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, max uint8",
+ when: "255",
+ expect: 255,
+ },
+ {
+ name: "nok, overflow max uint8",
+ when: "256",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "256": value out of range`,
+ },
+ {
+ name: "nok, negative value",
+ when: "-1",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[uint8](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_uint16(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect uint16
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, max uint16",
+ when: "65535",
+ expect: 65535,
+ },
+ {
+ name: "nok, overflow max uint16",
+ when: "65536",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "65536": value out of range`,
+ },
+ {
+ name: "nok, negative value",
+ when: "-1",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[uint16](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_uint32(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect uint32
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, max uint32",
+ when: "4294967295",
+ expect: 4294967295,
+ },
+ {
+ name: "nok, overflow max uint32",
+ when: "4294967296",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "4294967296": value out of range`,
+ },
+ {
+ name: "nok, negative value",
+ when: "-1",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[uint32](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_uint64(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect uint64
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, max uint64",
+ when: "18446744073709551615",
+ expect: 18446744073709551615,
+ },
+ {
+ name: "nok, overflow max uint64",
+ when: "18446744073709551616",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "18446744073709551616": value out of range`,
+ },
+ {
+ name: "nok, negative value",
+ when: "-1",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[uint64](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_string(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect string
+ expectErr string
+ }{
+ {
+ name: "ok, my",
+ when: "my",
+ expect: "my",
+ },
+ {
+ name: "ok, empty",
+ when: "",
+ expect: "",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[string](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_Duration(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect time.Duration
+ expectErr string
+ }{
+ {
+ name: "ok, 10h11m01s",
+ when: "10h11m01s",
+ expect: 10*time.Hour + 11*time.Minute + 1*time.Second,
+ },
+ {
+ name: "ok, empty",
+ when: "",
+ expect: 0,
+ },
+ {
+ name: "ok, invalid",
+ when: "0x0",
+ expect: 0,
+ expectErr: `failed to parse value, err: time: unknown unit "x" in duration "0x0"`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[time.Duration](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_Time(t *testing.T) {
+ tallinn, err := time.LoadLocation("Europe/Tallinn")
+ if err != nil {
+ t.Fatal(err)
+ }
+ berlin, err := time.LoadLocation("Europe/Berlin")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ parse := func(t *testing.T, layout string, s string) time.Time {
+ result, err := time.Parse(layout, s)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return result
+ }
+
+ parseInLoc := func(t *testing.T, layout string, s string, loc *time.Location) time.Time {
+ result, err := time.ParseInLocation(layout, s, loc)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return result
+ }
+
+ var testCases = []struct {
+ name string
+ when string
+ whenLayout TimeLayout
+ whenTimeOpts *TimeOpts
+ expect time.Time
+ expectErr string
+ }{
+ {
+ name: "ok, defaults to RFC3339Nano",
+ when: "2006-01-02T15:04:05.999999999Z",
+ expect: parse(t, time.RFC3339Nano, "2006-01-02T15:04:05.999999999Z"),
+ },
+ {
+ name: "ok, custom TimeOpt",
+ when: "2006-01-02",
+ whenTimeOpts: &TimeOpts{
+ Layout: time.DateOnly,
+ ParseInLocation: tallinn,
+ ToInLocation: berlin,
+ },
+ expect: parseInLoc(t, time.DateTime, "2006-01-01 23:00:00", berlin),
+ },
+ {
+ name: "ok, custom layout",
+ when: "2006-01-02",
+ whenLayout: TimeLayout(time.DateOnly),
+ expect: parse(t, time.DateOnly, "2006-01-02"),
+ },
+ {
+ name: "ok, TimeLayoutUnixTime",
+ when: "1766604665",
+ whenLayout: TimeLayoutUnixTime,
+ expect: parse(t, time.RFC3339Nano, "2025-12-24T19:31:05Z"),
+ },
+ {
+ name: "nok, TimeLayoutUnixTime, invalid value",
+ when: "176x6604665",
+ whenLayout: TimeLayoutUnixTime,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "176x6604665": invalid syntax`,
+ },
+ {
+ name: "ok, TimeLayoutUnixTimeMilli",
+ when: "1766604665123",
+ whenLayout: TimeLayoutUnixTimeMilli,
+ expect: parse(t, time.RFC3339Nano, "2025-12-24T19:31:05.123Z"),
+ },
+ {
+ name: "nok, TimeLayoutUnixTimeMilli, invalid value",
+ when: "1x766604665123",
+ whenLayout: TimeLayoutUnixTimeMilli,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "1x766604665123": invalid syntax`,
+ },
+ {
+ name: "ok, TimeLayoutUnixTimeMilli",
+ when: "1766604665999999999",
+ whenLayout: TimeLayoutUnixTimeNano,
+ expect: parse(t, time.RFC3339Nano, "2025-12-24T19:31:05.999999999Z"),
+ },
+ {
+ name: "nok, TimeLayoutUnixTimeMilli, invalid value",
+ when: "1x766604665999999999",
+ whenLayout: TimeLayoutUnixTimeNano,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "1x766604665999999999": invalid syntax`,
+ },
+ {
+ name: "ok, invalid",
+ when: "xx",
+ expect: time.Time{},
+ expectErr: `failed to parse value, err: parsing time "xx" as "2006-01-02T15:04:05.999999999Z07:00": cannot parse "xx" as "2006"`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ var opts []any
+ if tc.whenLayout != "" {
+ opts = append(opts, tc.whenLayout)
+ }
+ if tc.whenTimeOpts != nil {
+ opts = append(opts, *tc.whenTimeOpts)
+ }
+ v, err := ParseValue[time.Time](tc.when, opts...)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_OptionsOnlyForTime(t *testing.T) {
+ _, err := ParseValue[int]("test", TimeLayoutUnixTime)
+ assert.EqualError(t, err, `failed to parse value, err: options are only supported for time.Time, got *int`)
+}
+
+func TestParseValue_BindUnmarshaler(t *testing.T) {
+ exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00")
+
+ var testCases = []struct {
+ name string
+ when string
+ expect Timestamp
+ expectErr string
+ }{
+ {
+ name: "ok",
+ when: "2020-12-23T09:45:31+02:00",
+ expect: Timestamp(exampleTime),
+ },
+ {
+ name: "nok, invalid value",
+ when: "2020-12-23T09:45:3102:00",
+ expect: Timestamp{},
+ expectErr: `failed to parse value, err: parsing time "2020-12-23T09:45:3102:00" as "2006-01-02T15:04:05Z07:00": cannot parse "02:00" as "Z07:00"`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[Timestamp](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_TextUnmarshaler(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect TextUnmarshalerType
+ expectErr string
+ }{
+ {
+ name: "ok, converts to uppercase",
+ when: "hello",
+ expect: TextUnmarshalerType{Value: "HELLO"},
+ },
+ {
+ name: "ok, empty string",
+ when: "",
+ expect: TextUnmarshalerType{Value: ""},
+ },
+ {
+ name: "nok, invalid value",
+ when: "invalid",
+ expect: TextUnmarshalerType{},
+ expectErr: "failed to parse value, err: invalid value: invalid",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[TextUnmarshalerType](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_JSONUnmarshaler(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect JSONUnmarshalerType
+ expectErr string
+ }{
+ {
+ name: "ok, valid JSON string",
+ when: `"hello"`,
+ expect: JSONUnmarshalerType{Value: "hello"},
+ },
+ {
+ name: "ok, empty JSON string",
+ when: `""`,
+ expect: JSONUnmarshalerType{Value: ""},
+ },
+ {
+ name: "nok, invalid JSON",
+ when: "not-json",
+ expect: JSONUnmarshalerType{},
+ expectErr: "failed to parse value, err: invalid character 'o' in literal null (expecting 'u')",
+ },
+ {
+ name: "nok, unquoted string",
+ when: "hello",
+ expect: JSONUnmarshalerType{},
+ expectErr: "failed to parse value, err: invalid character 'h' looking for beginning of value",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[JSONUnmarshalerType](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValues_bools(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when []string
+ expect []bool
+ expectErr string
+ }{
+ {
+ name: "ok",
+ when: []string{"true", "0", "false", "1"},
+ expect: []bool{true, false, false, true},
+ },
+ {
+ name: "nok",
+ when: []string{"true", "10"},
+ expect: nil,
+ expectErr: `failed to parse value, err: strconv.ParseBool: parsing "10": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValues[bool](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestPathParamOr(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenKey string
+ givenValue string
+ defaultValue int
+ expect int
+ expectErr string
+ }{
+ {
+ name: "ok, param exists",
+ givenKey: "id",
+ givenValue: "123",
+ defaultValue: 999,
+ expect: 123,
+ },
+ {
+ name: "ok, param missing - returns default",
+ givenKey: "other",
+ givenValue: "123",
+ defaultValue: 999,
+ expect: 999,
+ },
+ {
+ name: "ok, param exists but empty - returns default",
+ givenKey: "id",
+ givenValue: "",
+ defaultValue: 999,
+ expect: 999,
+ },
+ {
+ name: "nok, invalid value",
+ givenKey: "id",
+ givenValue: "invalid",
+ defaultValue: 999,
+ expectErr: "code=400, message=path value, err=failed to parse value",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := NewContext(nil, nil)
+ c.SetPathValues(PathValues{{Name: tc.givenKey, Value: tc.givenValue}})
+
+ v, err := PathParamOr[int](c, "id", tc.defaultValue)
+ if tc.expectErr != "" {
+ assert.ErrorContains(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestQueryParamOr(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ defaultValue int
+ expect int
+ expectErr string
+ }{
+ {
+ name: "ok, param exists",
+ givenURL: "/?key=42",
+ defaultValue: 999,
+ expect: 42,
+ },
+ {
+ name: "ok, param missing - returns default",
+ givenURL: "/?other=42",
+ defaultValue: 999,
+ expect: 999,
+ },
+ {
+ name: "ok, param exists but empty - returns default",
+ givenURL: "/?key=",
+ defaultValue: 999,
+ expect: 999,
+ },
+ {
+ name: "nok, invalid value",
+ givenURL: "/?key=invalid",
+ defaultValue: 999,
+ expectErr: "code=400, message=query param",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
+ c := NewContext(req, nil)
+
+ v, err := QueryParamOr[int](c, "key", tc.defaultValue)
+ if tc.expectErr != "" {
+ assert.ErrorContains(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestQueryParamsOr(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ defaultValue []int
+ expect []int
+ expectErr string
+ }{
+ {
+ name: "ok, params exist",
+ givenURL: "/?key=1&key=2&key=3",
+ defaultValue: []int{999},
+ expect: []int{1, 2, 3},
+ },
+ {
+ name: "ok, params missing - returns default",
+ givenURL: "/?other=1",
+ defaultValue: []int{7, 8, 9},
+ expect: []int{7, 8, 9},
+ },
+ {
+ name: "nok, invalid value",
+ givenURL: "/?key=1&key=invalid",
+ defaultValue: []int{999},
+ expectErr: "code=400, message=query params",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
+ c := NewContext(req, nil)
+
+ v, err := QueryParamsOr[int](c, "key", tc.defaultValue)
+ if tc.expectErr != "" {
+ assert.ErrorContains(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestFormValueOr(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ defaultValue string
+ expect string
+ expectErr string
+ }{
+ {
+ name: "ok, value exists",
+ givenURL: "/?name=john",
+ defaultValue: "default",
+ expect: "john",
+ },
+ {
+ name: "ok, value missing - returns default",
+ givenURL: "/?other=john",
+ defaultValue: "default",
+ expect: "default",
+ },
+ {
+ name: "ok, value exists but empty - returns default",
+ givenURL: "/?name=",
+ defaultValue: "default",
+ expect: "default",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
+ c := NewContext(req, nil)
+
+ v, err := FormValueOr[string](c, "name", tc.defaultValue)
+ if tc.expectErr != "" {
+ assert.ErrorContains(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestFormValuesOr(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ defaultValue []string
+ expect []string
+ expectErr string
+ }{
+ {
+ name: "ok, values exist",
+ givenURL: "/?tags=go&tags=rust&tags=python",
+ defaultValue: []string{"default"},
+ expect: []string{"go", "rust", "python"},
+ },
+ {
+ name: "ok, values missing - returns default",
+ givenURL: "/?other=value",
+ defaultValue: []string{"a", "b"},
+ expect: []string{"a", "b"},
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
+ c := NewContext(req, nil)
+
+ v, err := FormValuesOr[string](c, "tags", tc.defaultValue)
+ if tc.expectErr != "" {
+ assert.ErrorContains(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
diff --git a/binder_test.go b/binder_test.go
new file mode 100644
index 000000000..8eced8208
--- /dev/null
+++ b/binder_test.go
@@ -0,0 +1,3252 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "github.com/stretchr/testify/assert"
+ "io"
+ "math/big"
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+)
+
+func createTestContext(URL string, body io.Reader, pathValues map[string]string) *Context {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, URL, body)
+ if body != nil {
+ req.Header.Set(HeaderContentType, MIMEApplicationJSON)
+ }
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ if len(pathValues) > 0 {
+ params := make(PathValues, 0)
+ for name, value := range pathValues {
+ params = append(params, PathValue{
+ Name: name,
+ Value: value,
+ })
+ }
+ c.SetPathValues(params)
+ }
+
+ return c
+}
+
+func TestBindingError_Error(t *testing.T) {
+ err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error"))
+ assert.EqualError(t, err, `code=400, message=bind failed, err=internal error, field=id`)
+
+ bErr := err.(*BindingError)
+ assert.Equal(t, 400, bErr.Code)
+ assert.Equal(t, "bind failed", bErr.Message)
+ assert.Equal(t, errors.New("internal error"), bErr.err)
+
+ assert.Equal(t, "id", bErr.Field)
+ assert.Equal(t, []string{"1", "nope"}, bErr.Values)
+}
+
+func TestBindingError_ErrorJSON(t *testing.T) {
+ err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error"))
+
+ resp, _ := json.Marshal(err)
+
+ assert.Equal(t, `{"field":"id","message":"bind failed"}`, string(resp))
+}
+
+func TestPathValuesBinder(t *testing.T) {
+ c := createTestContext("/api/user/999", nil, map[string]string{
+ "id": "1",
+ "nr": "2",
+ "slice": "3",
+ })
+ b := PathValuesBinder(c)
+
+ id := int64(99)
+ nr := int64(88)
+ var slice = make([]int64, 0)
+ var notExisting = make([]int64, 0)
+ err := b.Int64("id", &id).
+ Int64("nr", &nr).
+ Int64s("slice", &slice).
+ Int64s("not_existing", ¬Existing).
+ BindError()
+
+ assert.NoError(t, err)
+ assert.Equal(t, int64(1), id)
+ assert.Equal(t, int64(2), nr)
+ assert.Equal(t, []int64{3}, slice) // binding params to slice does not make sense but it should not panic either
+ assert.Equal(t, []int64{}, notExisting) // binding params to slice does not make sense but it should not panic either
+}
+
+func TestQueryParamsBinder_FailFast(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError []string
+ givenFailFast bool
+ }{
+ {
+ name: "ok, FailFast=true stops at first error",
+ whenURL: "/api/user/999?nr=en&id=nope",
+ givenFailFast: true,
+ expectError: []string{
+ `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "nope": invalid syntax, field=id`,
+ },
+ },
+ {
+ name: "ok, FailFast=false encounters all errors",
+ whenURL: "/api/user/999?nr=en&id=nope",
+ givenFailFast: false,
+ expectError: []string{
+ `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "nope": invalid syntax, field=id`,
+ `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "en": invalid syntax, field=nr`,
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, map[string]string{"id": "999"})
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ id := int64(99)
+ nr := int64(88)
+ errs := b.Int64("id", &id).
+ Int64("nr", &nr).
+ BindErrors()
+
+ assert.Len(t, errs, len(tc.expectError))
+ for _, err := range errs {
+ assert.Contains(t, tc.expectError, err.Error())
+ }
+ })
+ }
+}
+
+func TestFormFieldBinder(t *testing.T) {
+ e := New()
+ body := `texta=foo&slice=5`
+ req := httptest.NewRequest(http.MethodPost, "/api/search?id=1&nr=2&slice=3&slice=4", strings.NewReader(body))
+ req.Header.Set(HeaderContentLength, strconv.Itoa(len(body)))
+ req.Header.Set(HeaderContentType, MIMEApplicationForm)
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ b := FormFieldBinder(c)
+
+ var texta string
+ id := int64(99)
+ nr := int64(88)
+ var slice = make([]int64, 0)
+ var notExisting = make([]int64, 0)
+ err := b.
+ Int64s("slice", &slice).
+ Int64("id", &id).
+ Int64("nr", &nr).
+ String("texta", &texta).
+ Int64s("notExisting", ¬Existing).
+ BindError()
+
+ assert.NoError(t, err)
+ assert.Equal(t, "foo", texta)
+ assert.Equal(t, int64(1), id)
+ assert.Equal(t, int64(2), nr)
+ assert.Equal(t, []int64{5, 3, 4}, slice)
+ assert.Equal(t, []int64{}, notExisting)
+}
+
+func TestValueBinder_errorStopsBinding(t *testing.T) {
+ // this test documents "feature" that binding multiple params can change destination if it was bound before
+ // failing parameter binding
+
+ c := createTestContext("/api/user/999?id=1&nr=nope", nil, nil)
+ b := QueryParamsBinder(c)
+
+ id := int64(99) // will be changed before nr binding fails
+ nr := int64(88) // will not be changed
+ err := b.Int64("id", &id).
+ Int64("nr", &nr).
+ BindError()
+
+ assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=nr")
+ assert.Equal(t, int64(1), id)
+ assert.Equal(t, int64(88), nr)
+}
+
+func TestValueBinder_BindError(t *testing.T) {
+ c := createTestContext("/api/user/999?nr=en&id=nope", nil, nil)
+ b := QueryParamsBinder(c)
+
+ id := int64(99)
+ nr := int64(88)
+ err := b.Int64("id", &id).
+ Int64("nr", &nr).
+ BindError()
+
+ assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=id")
+ assert.Nil(t, b.errors)
+ assert.Nil(t, b.BindError())
+}
+
+func TestValueBinder_GetValues(t *testing.T) {
+ var testCases = []struct {
+ whenValuesFunc func(sourceParam string) []string
+ name string
+ expectError string
+ expect []int64
+ }{
+ {
+ name: "ok, default implementation",
+ expect: []int64{1, 101},
+ },
+ {
+ name: "ok, values returns nil",
+ whenValuesFunc: func(sourceParam string) []string {
+ return nil
+ },
+ expect: []int64(nil),
+ },
+ {
+ name: "ok, values returns empty slice",
+ whenValuesFunc: func(sourceParam string) []string {
+ return []string{}
+ },
+ expect: []int64(nil),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext("/search?nr=en&id=1&id=101", nil, nil)
+ b := QueryParamsBinder(c)
+ if tc.whenValuesFunc != nil {
+ b.ValuesFunc = tc.whenValuesFunc
+ }
+
+ var IDs []int64
+ err := b.Int64s("id", &IDs).BindError()
+
+ assert.Equal(t, tc.expect, IDs)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_CustomFuncWithError(t *testing.T) {
+ c := createTestContext("/search?nr=en&id=1&id=101", nil, nil)
+ b := QueryParamsBinder(c)
+
+ id := int64(99)
+ givenCustomFunc := func(values []string) []error {
+ assert.Equal(t, []string{"1", "101"}, values)
+
+ return []error{
+ errors.New("first error"),
+ errors.New("second error"),
+ }
+ }
+ err := b.CustomFunc("id", givenCustomFunc).BindError()
+
+ assert.Equal(t, int64(99), id)
+ assert.EqualError(t, err, "first error")
+}
+
+func TestValueBinder_CustomFunc(t *testing.T) {
+ var testCases = []struct {
+ expectValue any
+ name string
+ whenURL string
+ givenFuncErrors []error
+ expectParamValues []string
+ expectErrors []string
+ givenFailFast bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectParamValues: []string{"1", "100"},
+ expectValue: int64(1000),
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nr=en",
+ expectParamValues: []string{},
+ expectValue: int64(99),
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectParamValues: []string{"1", "100"},
+ expectValue: int64(99),
+ expectErrors: []string{"previous error"},
+ },
+ {
+ name: "nok, func returns errors",
+ givenFuncErrors: []error{
+ errors.New("first error"),
+ errors.New("second error"),
+ },
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectParamValues: []string{"1", "100"},
+ expectValue: int64(99),
+ expectErrors: []string{"first error", "second error"},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ id := int64(99)
+ givenCustomFunc := func(values []string) []error {
+ assert.Equal(t, tc.expectParamValues, values)
+ if tc.givenFuncErrors == nil {
+ id = 1000 // emulated conversion and setting value
+ return nil
+ }
+ return tc.givenFuncErrors
+ }
+ errs := b.CustomFunc("id", givenCustomFunc).BindErrors()
+
+ assert.Equal(t, tc.expectValue, id)
+ if tc.expectErrors != nil {
+ assert.Len(t, errs, len(tc.expectErrors))
+ for _, err := range errs {
+ assert.Contains(t, tc.expectErrors, err.Error())
+ }
+ } else {
+ assert.Nil(t, errs)
+ }
+ })
+ }
+}
+
+func TestValueBinder_MustCustomFunc(t *testing.T) {
+ var testCases = []struct {
+ expectValue any
+ name string
+ whenURL string
+ givenFuncErrors []error
+ expectParamValues []string
+ expectErrors []string
+ givenFailFast bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectParamValues: []string{"1", "100"},
+ expectValue: int64(1000),
+ },
+ {
+ name: "nok, params values empty, returns error, value is not changed",
+ whenURL: "/search?nr=en",
+ expectParamValues: []string{},
+ expectValue: int64(99),
+ expectErrors: []string{"code=400, message=required field value is empty, field=id"},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectParamValues: []string{"1", "100"},
+ expectValue: int64(99),
+ expectErrors: []string{"previous error"},
+ },
+ {
+ name: "nok, func returns errors",
+ givenFuncErrors: []error{
+ errors.New("first error"),
+ errors.New("second error"),
+ },
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectParamValues: []string{"1", "100"},
+ expectValue: int64(99),
+ expectErrors: []string{"first error", "second error"},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ id := int64(99)
+ givenCustomFunc := func(values []string) []error {
+ assert.Equal(t, tc.expectParamValues, values)
+ if tc.givenFuncErrors == nil {
+ id = 1000 // emulated conversion and setting value
+ return nil
+ }
+ return tc.givenFuncErrors
+ }
+ errs := b.MustCustomFunc("id", givenCustomFunc).BindErrors()
+
+ assert.Equal(t, tc.expectValue, id)
+ if tc.expectErrors != nil {
+ assert.Len(t, errs, len(tc.expectErrors))
+ for _, err := range errs {
+ assert.Contains(t, tc.expectErrors, err.Error())
+ }
+ } else {
+ assert.Nil(t, errs)
+ }
+ })
+ }
+}
+
+func TestValueBinder_String(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectValue string
+ expectError string
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=en¶m=de",
+ expectValue: "en",
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nr=en",
+ expectValue: "default",
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectValue: "default",
+ expectError: "previous error",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=en¶m=de",
+ expectValue: "en",
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nr=en",
+ expectValue: "default",
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectValue: "default",
+ expectError: "previous error",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := "default"
+ var err error
+ if tc.whenMust {
+ err = b.MustString("param", &dest).BindError()
+ } else {
+ err = b.String("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Strings(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []string
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=en¶m=de",
+ expectValue: []string{"en", "de"},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nr=en",
+ expectValue: []string{"default"},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectValue: []string{"default"},
+ expectError: "previous error",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=en¶m=de",
+ expectValue: []string{"en", "de"},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nr=en",
+ expectValue: []string{"default"},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectValue: []string{"default"},
+ expectError: "previous error",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := []string{"default"}
+ var err error
+ if tc.whenMust {
+ err = b.MustStrings("param", &dest).BindError()
+ } else {
+ err = b.Strings("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Int64_intValue(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue int64
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 1,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: 99,
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 99,
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 99,
+ expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 1,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: 99,
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 99,
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 99,
+ expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := int64(99)
+ var err error
+ if tc.whenMust {
+ err = b.MustInt64("param", &dest).BindError()
+ } else {
+ err = b.Int64("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Int_errorMessage(t *testing.T) {
+ // int/uint (without byte size) has a little bit different error message so test these separately
+ c := createTestContext("/search?param=nope", nil, nil)
+ b := QueryParamsBinder(c).FailFast(false)
+
+ destInt := 99
+ destUint := uint(98)
+ errs := b.Int("param", &destInt).Uint("param", &destUint).BindErrors()
+
+ assert.Equal(t, 99, destInt)
+ assert.Equal(t, uint(98), destUint)
+ assert.EqualError(t, errs[0], `code=400, message=failed to bind field value to int, err=strconv.ParseInt: parsing "nope": invalid syntax, field=param`)
+ assert.EqualError(t, errs[1], `code=400, message=failed to bind field value to uint, err=strconv.ParseUint: parsing "nope": invalid syntax, field=param`)
+}
+
+func TestValueBinder_Uint64_uintValue(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue uint64
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 1,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: 99,
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 99,
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 99,
+ expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 1,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: 99,
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 99,
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 99,
+ expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := uint64(99)
+ var err error
+ if tc.whenMust {
+ err = b.MustUint64("param", &dest).BindError()
+ } else {
+ err = b.Uint64("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Int_Types(t *testing.T) {
+ type target struct {
+ int64 int64
+ mustInt64 int64
+ uint64 uint64
+ mustUint64 uint64
+
+ int32 int32
+ mustInt32 int32
+ uint32 uint32
+ mustUint32 uint32
+
+ int16 int16
+ mustInt16 int16
+ uint16 uint16
+ mustUint16 uint16
+
+ int8 int8
+ mustInt8 int8
+ uint8 uint8
+ mustUint8 uint8
+
+ byte byte
+ mustByte byte
+
+ int int
+ mustInt int
+ uint uint
+ mustUint uint
+ }
+ types := []string{
+ "int64=1",
+ "mustInt64=2",
+ "uint64=3",
+ "mustUint64=4",
+
+ "int32=5",
+ "mustInt32=6",
+ "uint32=7",
+ "mustUint32=8",
+
+ "int16=9",
+ "mustInt16=10",
+ "uint16=11",
+ "mustUint16=12",
+
+ "int8=13",
+ "mustInt8=14",
+ "uint8=15",
+ "mustUint8=16",
+
+ "byte=17",
+ "mustByte=18",
+
+ "int=19",
+ "mustInt=20",
+ "uint=21",
+ "mustUint=22",
+ }
+ c := createTestContext("/search?"+strings.Join(types, "&"), nil, nil)
+ b := QueryParamsBinder(c)
+
+ dest := target{}
+ err := b.
+ Int64("int64", &dest.int64).
+ MustInt64("mustInt64", &dest.mustInt64).
+ Uint64("uint64", &dest.uint64).
+ MustUint64("mustUint64", &dest.mustUint64).
+ Int32("int32", &dest.int32).
+ MustInt32("mustInt32", &dest.mustInt32).
+ Uint32("uint32", &dest.uint32).
+ MustUint32("mustUint32", &dest.mustUint32).
+ Int16("int16", &dest.int16).
+ MustInt16("mustInt16", &dest.mustInt16).
+ Uint16("uint16", &dest.uint16).
+ MustUint16("mustUint16", &dest.mustUint16).
+ Int8("int8", &dest.int8).
+ MustInt8("mustInt8", &dest.mustInt8).
+ Uint8("uint8", &dest.uint8).
+ MustUint8("mustUint8", &dest.mustUint8).
+ Byte("byte", &dest.byte).
+ MustByte("mustByte", &dest.mustByte).
+ Int("int", &dest.int).
+ MustInt("mustInt", &dest.mustInt).
+ Uint("uint", &dest.uint).
+ MustUint("mustUint", &dest.mustUint).
+ BindError()
+
+ assert.NoError(t, err)
+ assert.Equal(t, int64(1), dest.int64)
+ assert.Equal(t, int64(2), dest.mustInt64)
+ assert.Equal(t, uint64(3), dest.uint64)
+ assert.Equal(t, uint64(4), dest.mustUint64)
+
+ assert.Equal(t, int32(5), dest.int32)
+ assert.Equal(t, int32(6), dest.mustInt32)
+ assert.Equal(t, uint32(7), dest.uint32)
+ assert.Equal(t, uint32(8), dest.mustUint32)
+
+ assert.Equal(t, int16(9), dest.int16)
+ assert.Equal(t, int16(10), dest.mustInt16)
+ assert.Equal(t, uint16(11), dest.uint16)
+ assert.Equal(t, uint16(12), dest.mustUint16)
+
+ assert.Equal(t, int8(13), dest.int8)
+ assert.Equal(t, int8(14), dest.mustInt8)
+ assert.Equal(t, uint8(15), dest.uint8)
+ assert.Equal(t, uint8(16), dest.mustUint8)
+
+ assert.Equal(t, uint8(17), dest.byte)
+ assert.Equal(t, uint8(18), dest.mustByte)
+
+ assert.Equal(t, 19, dest.int)
+ assert.Equal(t, 20, dest.mustInt)
+ assert.Equal(t, uint(21), dest.uint)
+ assert.Equal(t, uint(22), dest.mustUint)
+}
+
+func TestValueBinder_Int64s_intsValue(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []int64
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=1¶m=2¶m=1",
+ expectValue: []int64{1, 2, 1},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: []int64{99},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []int64{99},
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []int64{99},
+ expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1¶m=2¶m=1",
+ expectValue: []int64{1, 2, 1},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: []int64{99},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []int64{99},
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []int64{99},
+ expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := []int64{99} // when values are set with bind - contents before bind is gone
+ var err error
+ if tc.whenMust {
+ err = b.MustInt64s("param", &dest).BindError()
+ } else {
+ err = b.Int64s("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Uint64s_uintsValue(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []uint64
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=1¶m=2¶m=1",
+ expectValue: []uint64{1, 2, 1},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: []uint64{99},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []uint64{99},
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []uint64{99},
+ expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1¶m=2¶m=1",
+ expectValue: []uint64{1, 2, 1},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: []uint64{99},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []uint64{99},
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []uint64{99},
+ expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := []uint64{99} // when values are set with bind - contents before bind is gone
+ var err error
+ if tc.whenMust {
+ err = b.MustUint64s("param", &dest).BindError()
+ } else {
+ err = b.Uint64s("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Ints_Types(t *testing.T) {
+ type target struct {
+ int64 []int64
+ mustInt64 []int64
+ uint64 []uint64
+ mustUint64 []uint64
+
+ int32 []int32
+ mustInt32 []int32
+ uint32 []uint32
+ mustUint32 []uint32
+
+ int16 []int16
+ mustInt16 []int16
+ uint16 []uint16
+ mustUint16 []uint16
+
+ int8 []int8
+ mustInt8 []int8
+ uint8 []uint8
+ mustUint8 []uint8
+
+ int []int
+ mustInt []int
+ uint []uint
+ mustUint []uint
+ }
+ types := []string{
+ "int64=1",
+ "mustInt64=2",
+ "uint64=3",
+ "mustUint64=4",
+
+ "int32=5",
+ "mustInt32=6",
+ "uint32=7",
+ "mustUint32=8",
+
+ "int16=9",
+ "mustInt16=10",
+ "uint16=11",
+ "mustUint16=12",
+
+ "int8=13",
+ "mustInt8=14",
+ "uint8=15",
+ "mustUint8=16",
+
+ "int=19",
+ "mustInt=20",
+ "uint=21",
+ "mustUint=22",
+ }
+ url := "/search?"
+ for _, v := range types {
+ url = url + "&" + v + "&" + v
+ }
+ c := createTestContext(url, nil, nil)
+ b := QueryParamsBinder(c)
+
+ dest := target{}
+ err := b.
+ Int64s("int64", &dest.int64).
+ MustInt64s("mustInt64", &dest.mustInt64).
+ Uint64s("uint64", &dest.uint64).
+ MustUint64s("mustUint64", &dest.mustUint64).
+ Int32s("int32", &dest.int32).
+ MustInt32s("mustInt32", &dest.mustInt32).
+ Uint32s("uint32", &dest.uint32).
+ MustUint32s("mustUint32", &dest.mustUint32).
+ Int16s("int16", &dest.int16).
+ MustInt16s("mustInt16", &dest.mustInt16).
+ Uint16s("uint16", &dest.uint16).
+ MustUint16s("mustUint16", &dest.mustUint16).
+ Int8s("int8", &dest.int8).
+ MustInt8s("mustInt8", &dest.mustInt8).
+ Uint8s("uint8", &dest.uint8).
+ MustUint8s("mustUint8", &dest.mustUint8).
+ Ints("int", &dest.int).
+ MustInts("mustInt", &dest.mustInt).
+ Uints("uint", &dest.uint).
+ MustUints("mustUint", &dest.mustUint).
+ BindError()
+
+ assert.NoError(t, err)
+ assert.Equal(t, []int64{1, 1}, dest.int64)
+ assert.Equal(t, []int64{2, 2}, dest.mustInt64)
+ assert.Equal(t, []uint64{3, 3}, dest.uint64)
+ assert.Equal(t, []uint64{4, 4}, dest.mustUint64)
+
+ assert.Equal(t, []int32{5, 5}, dest.int32)
+ assert.Equal(t, []int32{6, 6}, dest.mustInt32)
+ assert.Equal(t, []uint32{7, 7}, dest.uint32)
+ assert.Equal(t, []uint32{8, 8}, dest.mustUint32)
+
+ assert.Equal(t, []int16{9, 9}, dest.int16)
+ assert.Equal(t, []int16{10, 10}, dest.mustInt16)
+ assert.Equal(t, []uint16{11, 11}, dest.uint16)
+ assert.Equal(t, []uint16{12, 12}, dest.mustUint16)
+
+ assert.Equal(t, []int8{13, 13}, dest.int8)
+ assert.Equal(t, []int8{14, 14}, dest.mustInt8)
+ assert.Equal(t, []uint8{15, 15}, dest.uint8)
+ assert.Equal(t, []uint8{16, 16}, dest.mustUint8)
+
+ assert.Equal(t, []int{19, 19}, dest.int)
+ assert.Equal(t, []int{20, 20}, dest.mustInt)
+ assert.Equal(t, []uint{21, 21}, dest.uint)
+ assert.Equal(t, []uint{22, 22}, dest.mustUint)
+}
+
+func TestValueBinder_Ints_Types_FailFast(t *testing.T) {
+ // FailFast() should stop parsing and return early
+ errTmpl := "code=400, message=failed to bind field value to %v, err=strconv.Parse%v: parsing \"nope\": invalid syntax, field=param"
+ c := createTestContext("/search?param=1¶m=nope¶m=2", nil, nil)
+
+ var dest64 []int64
+ err := QueryParamsBinder(c).FailFast(true).Int64s("param", &dest64).BindError()
+ assert.Equal(t, []int64(nil), dest64)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int64", "Int"))
+
+ var dest32 []int32
+ err = QueryParamsBinder(c).FailFast(true).Int32s("param", &dest32).BindError()
+ assert.Equal(t, []int32(nil), dest32)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int32", "Int"))
+
+ var dest16 []int16
+ err = QueryParamsBinder(c).FailFast(true).Int16s("param", &dest16).BindError()
+ assert.Equal(t, []int16(nil), dest16)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int16", "Int"))
+
+ var dest8 []int8
+ err = QueryParamsBinder(c).FailFast(true).Int8s("param", &dest8).BindError()
+ assert.Equal(t, []int8(nil), dest8)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int8", "Int"))
+
+ var dest []int
+ err = QueryParamsBinder(c).FailFast(true).Ints("param", &dest).BindError()
+ assert.Equal(t, []int(nil), dest)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int", "Int"))
+
+ var destu64 []uint64
+ err = QueryParamsBinder(c).FailFast(true).Uint64s("param", &destu64).BindError()
+ assert.Equal(t, []uint64(nil), destu64)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint64", "Uint"))
+
+ var destu32 []uint32
+ err = QueryParamsBinder(c).FailFast(true).Uint32s("param", &destu32).BindError()
+ assert.Equal(t, []uint32(nil), destu32)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint32", "Uint"))
+
+ var destu16 []uint16
+ err = QueryParamsBinder(c).FailFast(true).Uint16s("param", &destu16).BindError()
+ assert.Equal(t, []uint16(nil), destu16)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint16", "Uint"))
+
+ var destu8 []uint8
+ err = QueryParamsBinder(c).FailFast(true).Uint8s("param", &destu8).BindError()
+ assert.Equal(t, []uint8(nil), destu8)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint8", "Uint"))
+
+ var destu []uint
+ err = QueryParamsBinder(c).FailFast(true).Uints("param", &destu).BindError()
+ assert.Equal(t, []uint(nil), destu)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint", "Uint"))
+}
+
+func TestValueBinder_Bool(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ expectValue bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=true¶m=1",
+ expectValue: true,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: false,
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: false,
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: false,
+ expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: true,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: false,
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: false,
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: false,
+ expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := false
+ var err error
+ if tc.whenMust {
+ err = b.MustBool("param", &dest).BindError()
+ } else {
+ err = b.Bool("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Bools(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []bool
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=true¶m=false¶m=1¶m=0",
+ expectValue: []bool{true, false, true, false},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: []bool(nil),
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []bool(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=true¶m=nope¶m=100",
+ expectValue: []bool(nil),
+ expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "nok, conversion fails fast, value is not changed",
+ givenFailFast: true,
+ whenURL: "/search?param=true¶m=nope¶m=100",
+ expectValue: []bool(nil),
+ expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=true¶m=false¶m=1¶m=0",
+ expectValue: []bool{true, false, true, false},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: []bool(nil),
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []bool(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []bool(nil),
+ expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ b.errors = tc.givenBindErrors
+
+ var dest []bool
+ var err error
+ if tc.whenMust {
+ err = b.MustBools("param", &dest).BindError()
+ } else {
+ err = b.Bools("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Float64(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue float64
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=4.3¶m=1",
+ expectValue: 4.3,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: 1.123,
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 1.123,
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 1.123,
+ expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=4.3¶m=100",
+ expectValue: 4.3,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: 1.123,
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 1.123,
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 1.123,
+ expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := 1.123
+ var err error
+ if tc.whenMust {
+ err = b.MustFloat64("param", &dest).BindError()
+ } else {
+ err = b.Float64("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Float64s(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []float64
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=4.3¶m=0",
+ expectValue: []float64{4.3, 0},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: []float64(nil),
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []float64(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []float64(nil),
+ expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "nok, conversion fails fast, value is not changed",
+ givenFailFast: true,
+ whenURL: "/search?param=0¶m=nope¶m=100",
+ expectValue: []float64(nil),
+ expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=4.3¶m=0",
+ expectValue: []float64{4.3, 0},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: []float64(nil),
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []float64(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []float64(nil),
+ expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ b.errors = tc.givenBindErrors
+
+ var dest []float64
+ var err error
+ if tc.whenMust {
+ err = b.MustFloat64s("param", &dest).BindError()
+ } else {
+ err = b.Float64s("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Float32(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue float32
+ givenNoFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=4.3¶m=1",
+ expectValue: 4.3,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: 1.123,
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenNoFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 1.123,
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 1.123,
+ expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=4.3¶m=100",
+ expectValue: 4.3,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: 1.123,
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenNoFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 1.123,
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 1.123,
+ expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenNoFailFast)
+ if tc.givenNoFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := float32(1.123)
+ var err error
+ if tc.whenMust {
+ err = b.MustFloat32("param", &dest).BindError()
+ } else {
+ err = b.Float32("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Float32s(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []float32
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=4.3¶m=0",
+ expectValue: []float32{4.3, 0},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: []float32(nil),
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []float32(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []float32(nil),
+ expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "nok, conversion fails fast, value is not changed",
+ givenFailFast: true,
+ whenURL: "/search?param=0¶m=nope¶m=100",
+ expectValue: []float32(nil),
+ expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=4.3¶m=0",
+ expectValue: []float32{4.3, 0},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: []float32(nil),
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []float32(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []float32(nil),
+ expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ b.errors = tc.givenBindErrors
+
+ var dest []float32
+ var err error
+ if tc.whenMust {
+ err = b.MustFloat32s("param", &dest).BindError()
+ } else {
+ err = b.Float32s("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Time(t *testing.T) {
+ exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00")
+ var testCases = []struct {
+ expectValue time.Time
+ name string
+ whenURL string
+ whenLayout string
+ expectError string
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00",
+ whenLayout: time.RFC3339,
+ expectValue: exampleTime,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: time.Time{},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: time.Time{},
+ expectError: "previous error",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00",
+ whenLayout: time.RFC3339,
+ expectValue: exampleTime,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: time.Time{},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: time.Time{},
+ expectError: "previous error",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := time.Time{}
+ var err error
+ if tc.whenMust {
+ err = b.MustTime("param", &dest, tc.whenLayout).BindError()
+ } else {
+ err = b.Time("param", &dest, tc.whenLayout).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Times(t *testing.T) {
+ exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00")
+ exampleTime2, _ := time.Parse(time.RFC3339, "2000-01-02T09:45:31+00:00")
+ var testCases = []struct {
+ name string
+ whenURL string
+ whenLayout string
+ expectError string
+ givenBindErrors []error
+ expectValue []time.Time
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00",
+ whenLayout: time.RFC3339,
+ expectValue: []time.Time{exampleTime, exampleTime2},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: []time.Time(nil),
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []time.Time(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00",
+ whenLayout: time.RFC3339,
+ expectValue: []time.Time{exampleTime, exampleTime2},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: []time.Time(nil),
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []time.Time(nil),
+ expectError: "previous error",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ b.errors = tc.givenBindErrors
+
+ layout := time.RFC3339
+ if tc.whenLayout != "" {
+ layout = tc.whenLayout
+ }
+
+ var dest []time.Time
+ var err error
+ if tc.whenMust {
+ err = b.MustTimes("param", &dest, layout).BindError()
+ } else {
+ err = b.Times("param", &dest, layout).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Duration(t *testing.T) {
+ example := 42 * time.Second
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue time.Duration
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=42s¶m=1ms",
+ expectValue: example,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: 0,
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 0,
+ expectError: "previous error",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=42s¶m=1ms",
+ expectValue: example,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: 0,
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 0,
+ expectError: "previous error",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ var dest time.Duration
+ var err error
+ if tc.whenMust {
+ err = b.MustDuration("param", &dest).BindError()
+ } else {
+ err = b.Duration("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Durations(t *testing.T) {
+ exampleDuration := 42 * time.Second
+ exampleDuration2 := 1 * time.Millisecond
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []time.Duration
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=42s¶m=1ms",
+ expectValue: []time.Duration{exampleDuration, exampleDuration2},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: []time.Duration(nil),
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []time.Duration(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=42s¶m=1ms",
+ expectValue: []time.Duration{exampleDuration, exampleDuration2},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: []time.Duration(nil),
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []time.Duration(nil),
+ expectError: "previous error",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ b.errors = tc.givenBindErrors
+
+ var dest []time.Duration
+ var err error
+ if tc.whenMust {
+ err = b.MustDurations("param", &dest).BindError()
+ } else {
+ err = b.Durations("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_BindUnmarshaler(t *testing.T) {
+ exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00")
+
+ var testCases = []struct {
+ expectValue Timestamp
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00",
+ expectValue: Timestamp(exampleTime),
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: Timestamp{},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: Timestamp{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: Timestamp{},
+ expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00",
+ expectValue: Timestamp(exampleTime),
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: Timestamp{},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: Timestamp{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: Timestamp{},
+ expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ var dest Timestamp
+ var err error
+ if tc.whenMust {
+ err = b.MustBindUnmarshaler("param", &dest).BindError()
+ } else {
+ err = b.BindUnmarshaler("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_JSONUnmarshaler(t *testing.T) {
+ example := big.NewInt(999)
+
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ expectValue big.Int
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=999¶m=998",
+ expectValue: *example,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: big.Int{},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: big.Int{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=xxx",
+ expectValue: big.Int{},
+ expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=999¶m=998",
+ expectValue: *example,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: big.Int{},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=xxx",
+ expectValue: big.Int{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=xxx",
+ expectValue: big.Int{},
+ expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ var dest big.Int
+ var err error
+ if tc.whenMust {
+ err = b.MustJSONUnmarshaler("param", &dest).BindError()
+ } else {
+ err = b.JSONUnmarshaler("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_TextUnmarshaler(t *testing.T) {
+ example := big.NewInt(999)
+
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ expectValue big.Int
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=999¶m=998",
+ expectValue: *example,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: big.Int{},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: big.Int{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=xxx",
+ expectValue: big.Int{},
+ expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=999¶m=998",
+ expectValue: *example,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: big.Int{},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=xxx",
+ expectValue: big.Int{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=xxx",
+ expectValue: big.Int{},
+ expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ var dest big.Int
+ var err error
+ if tc.whenMust {
+ err = b.MustTextUnmarshaler("param", &dest).BindError()
+ } else {
+ err = b.TextUnmarshaler("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_BindWithDelimiter_types(t *testing.T) {
+ var testCases = []struct {
+ expect any
+ name string
+ whenURL string
+ }{
+ {
+ name: "ok, strings",
+ expect: []string{"1", "2", "1"},
+ },
+ {
+ name: "ok, int64",
+ expect: []int64{1, 2, 1},
+ },
+ {
+ name: "ok, int32",
+ expect: []int32{1, 2, 1},
+ },
+ {
+ name: "ok, int16",
+ expect: []int16{1, 2, 1},
+ },
+ {
+ name: "ok, int8",
+ expect: []int8{1, 2, 1},
+ },
+ {
+ name: "ok, int",
+ expect: []int{1, 2, 1},
+ },
+ {
+ name: "ok, uint64",
+ expect: []uint64{1, 2, 1},
+ },
+ {
+ name: "ok, uint32",
+ expect: []uint32{1, 2, 1},
+ },
+ {
+ name: "ok, uint16",
+ expect: []uint16{1, 2, 1},
+ },
+ {
+ name: "ok, uint8",
+ expect: []uint8{1, 2, 1},
+ },
+ {
+ name: "ok, uint",
+ expect: []uint{1, 2, 1},
+ },
+ {
+ name: "ok, float64",
+ expect: []float64{1, 2, 1},
+ },
+ {
+ name: "ok, float32",
+ expect: []float32{1, 2, 1},
+ },
+ {
+ name: "ok, bool",
+ whenURL: "/search?param=1,false¶m=true",
+ expect: []bool{true, false, true},
+ },
+ {
+ name: "ok, Duration",
+ whenURL: "/search?param=1s,42s¶m=1ms",
+ expect: []time.Duration{1 * time.Second, 42 * time.Second, 1 * time.Millisecond},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ URL := "/search?param=1,2¶m=1"
+ if tc.whenURL != "" {
+ URL = tc.whenURL
+ }
+ c := createTestContext(URL, nil, nil)
+ b := QueryParamsBinder(c)
+
+ switch tc.expect.(type) {
+ case []string:
+ var dest []string
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []int64:
+ var dest []int64
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []int32:
+ var dest []int32
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []int16:
+ var dest []int16
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []int8:
+ var dest []int8
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []int:
+ var dest []int
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []uint64:
+ var dest []uint64
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []uint32:
+ var dest []uint32
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []uint16:
+ var dest []uint16
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []uint8:
+ var dest []uint8
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []uint:
+ var dest []uint
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []float64:
+ var dest []float64
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []float32:
+ var dest []float32
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []bool:
+ var dest []bool
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []time.Duration:
+ var dest []time.Duration
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ default:
+ assert.Fail(t, "invalid type")
+ }
+ })
+ }
+}
+
+func TestValueBinder_BindWithDelimiter(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []int64
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=1,2¶m=1",
+ expectValue: []int64{1, 2, 1},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: []int64(nil),
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []int64(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []int64(nil),
+ expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1,2¶m=1",
+ expectValue: []int64{1, 2, 1},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: []int64(nil),
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []int64(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []int64(nil),
+ expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ var dest []int64
+ var err error
+ if tc.whenMust {
+ err = b.MustBindWithDelimiter("param", &dest, ",").BindError()
+ } else {
+ err = b.BindWithDelimiter("param", &dest, ",").BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestBindWithDelimiter_invalidType(t *testing.T) {
+ c := createTestContext("/search?param=1¶m=100", nil, nil)
+ b := QueryParamsBinder(c)
+
+ var dest []BindUnmarshaler
+ err := b.BindWithDelimiter("param", &dest, ",").BindError()
+ assert.Equal(t, []BindUnmarshaler(nil), dest)
+ assert.EqualError(t, err, "code=400, message=unsupported bind type, field=param")
+}
+
+func TestValueBinder_UnixTime(t *testing.T) {
+ exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43+00:00") // => 1609180603
+ var testCases = []struct {
+ expectValue time.Time
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value, unix time in seconds",
+ whenURL: "/search?param=1609180603¶m=1609180604",
+ expectValue: exampleTime,
+ },
+ {
+ name: "ok, binds value, unix time over int32 value",
+ whenURL: "/search?param=2147483648¶m=1609180604",
+ expectValue: time.Unix(2147483648, 0),
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: time.Time{},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: time.Time{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: time.Time{},
+ expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1609180603¶m=1609180604",
+ expectValue: exampleTime,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: time.Time{},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: time.Time{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: time.Time{},
+ expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := time.Time{}
+ var err error
+ if tc.whenMust {
+ err = b.MustUnixTime("param", &dest).BindError()
+ } else {
+ err = b.UnixTime("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano())
+ assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC))
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_UnixTimeMilli(t *testing.T) {
+ exampleTime, _ := time.Parse(time.RFC3339Nano, "2022-03-13T15:13:30.140000000+00:00") // => 1647184410140
+ var testCases = []struct {
+ expectValue time.Time
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value, unix time in milliseconds",
+ whenURL: "/search?param=1647184410140¶m=1647184410199",
+ expectValue: exampleTime,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: time.Time{},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: time.Time{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: time.Time{},
+ expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1647184410140¶m=1647184410199",
+ expectValue: exampleTime,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: time.Time{},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: time.Time{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: time.Time{},
+ expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := time.Time{}
+ var err error
+ if tc.whenMust {
+ err = b.MustUnixTimeMilli("param", &dest).BindError()
+ } else {
+ err = b.UnixTimeMilli("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano())
+ assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC))
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_UnixTimeNano(t *testing.T) {
+ exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43.000000000+00:00") // => 1609180603
+ exampleTimeNano, _ := time.Parse(time.RFC3339Nano, "2020-12-28T18:36:43.123456789+00:00") // => 1609180603123456789
+ exampleTimeNanoBelowSec, _ := time.Parse(time.RFC3339Nano, "1970-01-01T00:00:00.999999999+00:00")
+ var testCases = []struct {
+ expectValue time.Time
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value, unix time in nano seconds (sec precision)",
+ whenURL: "/search?param=1609180603000000000¶m=1609180604",
+ expectValue: exampleTime,
+ },
+ {
+ name: "ok, binds value, unix time in nano seconds",
+ whenURL: "/search?param=1609180603123456789¶m=1609180604",
+ expectValue: exampleTimeNano,
+ },
+ {
+ name: "ok, binds value, unix time in nano seconds (below 1 sec)",
+ whenURL: "/search?param=999999999¶m=1609180604",
+ expectValue: exampleTimeNanoBelowSec,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: time.Time{},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: time.Time{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: time.Time{},
+ expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1609180603000000000¶m=1609180604",
+ expectValue: exampleTime,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: time.Time{},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: time.Time{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: time.Time{},
+ expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := time.Time{}
+ var err error
+ if tc.whenMust {
+ err = b.MustUnixTimeNano("param", &dest).BindError()
+ } else {
+ err = b.UnixTimeNano("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano())
+ assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC))
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) {
+ type Opts struct {
+ Param int64 `query:"param"`
+ }
+ c := createTestContext("/search?param=1¶m=100", nil, nil)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ binder := new(DefaultBinder)
+ for i := 0; i < b.N; i++ {
+ var dest Opts
+ _ = binder.Bind(c, &dest)
+ }
+}
+
+func BenchmarkValueBinder_BindInt64_single(b *testing.B) {
+ c := createTestContext("/search?param=1¶m=100", nil, nil)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ type Opts struct {
+ Param int64
+ }
+ binder := QueryParamsBinder(c)
+ for i := 0; i < b.N; i++ {
+ var dest Opts
+ _ = binder.Int64("param", &dest.Param).BindError()
+ }
+}
+
+func BenchmarkRawFunc_Int64_single(b *testing.B) {
+ c := createTestContext("/search?param=1¶m=100", nil, nil)
+
+ rawFunc := func(input string, defaultValue int64) (int64, bool) {
+ if input == "" {
+ return defaultValue, true
+ }
+ n, err := strconv.Atoi(input)
+ if err != nil {
+ return 0, false
+ }
+ return int64(n), true
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ type Opts struct {
+ Param int64
+ }
+ for i := 0; i < b.N; i++ {
+ var dest Opts
+ if n, ok := rawFunc(c.QueryParam("param"), 1); ok {
+ dest.Param = n
+ }
+ }
+}
+
+func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) {
+ type Opts struct {
+ String string `query:"string"`
+ Strings []string `query:"strings"`
+ Int64 int64 `query:"int64"`
+ Uint64 uint64 `query:"uint64"`
+ Int32 int32 `query:"int32"`
+ Uint32 uint32 `query:"uint32"`
+ Int16 int16 `query:"int16"`
+ Uint16 uint16 `query:"uint16"`
+ Int8 int8 `query:"int8"`
+ Uint8 uint8 `query:"uint8"`
+ }
+ c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ binder := new(DefaultBinder)
+ for i := 0; i < b.N; i++ {
+ var dest Opts
+ _ = binder.Bind(c, &dest)
+ if dest.Int64 != 1 {
+ b.Fatalf("int64!=1")
+ }
+ }
+}
+
+func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) {
+ type Opts struct {
+ String string `query:"string"`
+ Strings []string `query:"strings"`
+ Int64 int64 `query:"int64"`
+ Uint64 uint64 `query:"uint64"`
+ Int32 int32 `query:"int32"`
+ Uint32 uint32 `query:"uint32"`
+ Int16 int16 `query:"int16"`
+ Uint16 uint16 `query:"uint16"`
+ Int8 int8 `query:"int8"`
+ Uint8 uint8 `query:"uint8"`
+ }
+ c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ binder := QueryParamsBinder(c)
+ for i := 0; i < b.N; i++ {
+ var dest Opts
+ _ = binder.
+ Int64("int64", &dest.Int64).
+ Int32("int32", &dest.Int32).
+ Int16("int16", &dest.Int16).
+ Int8("int8", &dest.Int8).
+ String("string", &dest.String).
+ Uint64("int64", &dest.Uint64).
+ Uint32("int32", &dest.Uint32).
+ Uint16("int16", &dest.Uint16).
+ Uint8("int8", &dest.Uint8).
+ Strings("strings", &dest.Strings).
+ BindError()
+ if dest.Int64 != 1 {
+ b.Fatalf("int64!=1")
+ }
+ }
+}
+
+func TestValueBinder_TimeError(t *testing.T) {
+ var testCases = []struct {
+ expectValue time.Time
+ name string
+ whenURL string
+ whenLayout string
+ expectError string
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: time.Time{},
+ expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\": extra text: \"nope\", field=param",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: time.Time{},
+ expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\": extra text: \"nope\", field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := time.Time{}
+ var err error
+ if tc.whenMust {
+ err = b.MustTime("param", &dest, tc.whenLayout).BindError()
+ } else {
+ err = b.Time("param", &dest, tc.whenLayout).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_TimesError(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ whenLayout string
+ expectError string
+ givenBindErrors []error
+ expectValue []time.Time
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "nok, fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []time.Time(nil),
+ expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []time.Time(nil),
+ expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []time.Time(nil),
+ expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ b.errors = tc.givenBindErrors
+
+ layout := time.RFC3339
+ if tc.whenLayout != "" {
+ layout = tc.whenLayout
+ }
+
+ var dest []time.Time
+ var err error
+ if tc.whenMust {
+ err = b.MustTimes("param", &dest, layout).BindError()
+ } else {
+ err = b.Times("param", &dest, layout).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_DurationError(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue time.Duration
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 0,
+ expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 0,
+ expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ var dest time.Duration
+ var err error
+ if tc.whenMust {
+ err = b.MustDuration("param", &dest).BindError()
+ } else {
+ err = b.Duration("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_DurationsError(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []time.Duration
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "nok, fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []time.Duration(nil),
+ expectError: "code=400, message=failed to bind field value to Duration, err=time: missing unit in duration \"1\", field=param",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []time.Duration(nil),
+ expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []time.Duration(nil),
+ expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ b.errors = tc.givenBindErrors
+
+ var dest []time.Duration
+ var err error
+ if tc.whenMust {
+ err = b.MustDurations("param", &dest).BindError()
+ } else {
+ err = b.Durations("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
diff --git a/codecov.yml b/codecov.yml
new file mode 100644
index 000000000..0fa3a3f18
--- /dev/null
+++ b/codecov.yml
@@ -0,0 +1,11 @@
+coverage:
+ status:
+ project:
+ default:
+ threshold: 1%
+ patch:
+ default:
+ threshold: 1%
+
+comment:
+ require_changes: true
\ No newline at end of file
diff --git a/context.go b/context.go
index 27da5ffe3..f91ea7a60 100644
--- a/context.go
+++ b/context.go
@@ -1,254 +1,165 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package echo
import (
"bytes"
- "encoding/json"
"encoding/xml"
+ "errors"
"fmt"
"io"
+ "io/fs"
+ "log/slog"
"mime/multipart"
"net"
"net/http"
"net/url"
- "os"
+ "path"
"path/filepath"
"strings"
"sync"
)
-type (
- // Context represents the context of the current HTTP request. It holds request and
- // response objects, path, path parameters, data and registered handler.
- Context interface {
- // Request returns `*http.Request`.
- Request() *http.Request
-
- // SetRequest sets `*http.Request`.
- SetRequest(r *http.Request)
-
- // SetResponse sets `*Response`.
- SetResponse(r *Response)
-
- // Response returns `*Response`.
- Response() *Response
-
- // IsTLS returns true if HTTP connection is TLS otherwise false.
- IsTLS() bool
-
- // IsWebSocket returns true if HTTP connection is WebSocket otherwise false.
- IsWebSocket() bool
-
- // Scheme returns the HTTP protocol scheme, `http` or `https`.
- Scheme() string
-
- // RealIP returns the client's network address based on `X-Forwarded-For`
- // or `X-Real-IP` request header.
- RealIP() string
-
- // Path returns the registered path for the handler.
- Path() string
-
- // SetPath sets the registered path for the handler.
- SetPath(p string)
-
- // Param returns path parameter by name.
- Param(name string) string
-
- // ParamNames returns path parameter names.
- ParamNames() []string
-
- // SetParamNames sets path parameter names.
- SetParamNames(names ...string)
-
- // ParamValues returns path parameter values.
- ParamValues() []string
-
- // SetParamValues sets path parameter values.
- SetParamValues(values ...string)
-
- // QueryParam returns the query param for the provided name.
- QueryParam(name string) string
-
- // QueryParams returns the query parameters as `url.Values`.
- QueryParams() url.Values
-
- // QueryString returns the URL query string.
- QueryString() string
-
- // FormValue returns the form field value for the provided name.
- FormValue(name string) string
-
- // FormParams returns the form parameters as `url.Values`.
- FormParams() (url.Values, error)
-
- // FormFile returns the multipart form file for the provided name.
- FormFile(name string) (*multipart.FileHeader, error)
-
- // MultipartForm returns the multipart form.
- MultipartForm() (*multipart.Form, error)
-
- // Cookie returns the named cookie provided in the request.
- Cookie(name string) (*http.Cookie, error)
-
- // SetCookie adds a `Set-Cookie` header in HTTP response.
- SetCookie(cookie *http.Cookie)
-
- // Cookies returns the HTTP cookies sent with the request.
- Cookies() []*http.Cookie
-
- // Get retrieves data from the context.
- Get(key string) interface{}
-
- // Set saves data in the context.
- Set(key string, val interface{})
-
- // Bind binds the request body into provided type `i`. The default binder
- // does it based on Content-Type header.
- Bind(i interface{}) error
-
- // Validate validates provided `i`. It is usually called after `Context#Bind()`.
- // Validator must be registered using `Echo#Validator`.
- Validate(i interface{}) error
-
- // Render renders a template with data and sends a text/html response with status
- // code. Renderer must be registered using `Echo.Renderer`.
- Render(code int, name string, data interface{}) error
-
- // HTML sends an HTTP response with status code.
- HTML(code int, html string) error
-
- // HTMLBlob sends an HTTP blob response with status code.
- HTMLBlob(code int, b []byte) error
-
- // String sends a string response with status code.
- String(code int, s string) error
-
- // JSON sends a JSON response with status code.
- JSON(code int, i interface{}) error
-
- // JSONPretty sends a pretty-print JSON with status code.
- JSONPretty(code int, i interface{}, indent string) error
-
- // JSONBlob sends a JSON blob response with status code.
- JSONBlob(code int, b []byte) error
-
- // JSONP sends a JSONP response with status code. It uses `callback` to construct
- // the JSONP payload.
- JSONP(code int, callback string, i interface{}) error
-
- // JSONPBlob sends a JSONP blob response with status code. It uses `callback`
- // to construct the JSONP payload.
- JSONPBlob(code int, callback string, b []byte) error
-
- // XML sends an XML response with status code.
- XML(code int, i interface{}) error
-
- // XMLPretty sends a pretty-print XML with status code.
- XMLPretty(code int, i interface{}, indent string) error
-
- // XMLBlob sends an XML blob response with status code.
- XMLBlob(code int, b []byte) error
-
- // Blob sends a blob response with status code and content type.
- Blob(code int, contentType string, b []byte) error
-
- // Stream sends a streaming response with status code and content type.
- Stream(code int, contentType string, r io.Reader) error
-
- // File sends a response with the content of the file.
- File(file string) error
-
- // Attachment sends a response as attachment, prompting client to save the
- // file.
- Attachment(file string, name string) error
-
- // Inline sends a response as inline, opening the file in the browser.
- Inline(file string, name string) error
-
- // NoContent sends a response with no body and a status code.
- NoContent(code int) error
-
- // Redirect redirects the request to a provided URL with status code.
- Redirect(code int, url string) error
-
- // Error invokes the registered HTTP error handler. Generally used by middleware.
- Error(err error)
-
- // Handler returns the matched handler by router.
- Handler() HandlerFunc
-
- // SetHandler sets the matched handler by router.
- SetHandler(h HandlerFunc)
-
- // Logger returns the `Logger` instance.
- Logger() Logger
-
- // Set the logger
- SetLogger(l Logger)
+const (
+ // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain.
+ // Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests.
+ // It is added to context only when Router does not find matching method handler for request.
+ ContextKeyHeaderAllow = "echo_header_allow"
+)
- // Echo returns the `Echo` instance.
- Echo() *Echo
+const (
+ // defaultMemory is default value for memory limit that is used when
+ // parsing multipart forms (See (*http.Request).ParseMultipartForm)
+ defaultMemory int64 = 32 << 20 // 32 MB
+ indexPage = "index.html"
+)
- // Reset resets the context after request completes. It must be called along
- // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
- // See `Echo#ServeHTTP()`
- Reset(r *http.Request, w http.ResponseWriter)
+// Context represents the context of the current HTTP request. It holds request and
+// response objects, path, path parameters, data and registered handler.
+type Context struct {
+ request *http.Request
+ orgResponse *Response
+ response http.ResponseWriter
+ query url.Values
+
+ // formParseMaxMemory is used for http.Request.ParseMultipartForm
+ formParseMaxMemory int64
+
+ route *RouteInfo
+ pathValues *PathValues
+
+ store map[string]any
+ echo *Echo
+ logger *slog.Logger
+
+ path string
+ lock sync.RWMutex
+}
+
+// NewContext returns a new Context instance.
+//
+// Note: request,response and e can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway
+// these arguments are useful when creating context for tests and cases like that.
+func NewContext(r *http.Request, w http.ResponseWriter, opts ...any) *Context {
+ var e *Echo
+ for _, opt := range opts {
+ switch v := opt.(type) {
+ case *Echo:
+ e = v
+ }
}
+ return newContext(r, w, e)
+}
- context struct {
- request *http.Request
- response *Response
- path string
- pnames []string
- pvalues []string
- query url.Values
- handler HandlerFunc
- store Map
- echo *Echo
- logger Logger
- lock sync.RWMutex
+func newContext(r *http.Request, w http.ResponseWriter, e *Echo) *Context {
+ c := &Context{
+ pathValues: nil,
+ store: make(map[string]any),
+ echo: e,
+ logger: nil,
}
-)
+ var logger *slog.Logger
+ paramLen := int32(0)
+ formParseMaxMemory := defaultMemory
+ if e != nil {
+ paramLen = e.contextPathParamAllocSize.Load()
+ logger = e.Logger
+ formParseMaxMemory = e.formParseMaxMemory
+ }
+ if logger == nil {
+ logger = slog.Default()
+ }
+ c.logger = logger
+ p := make(PathValues, 0, paramLen)
+ c.pathValues = &p
-const (
- defaultMemory = 32 << 20 // 32 MB
- indexPage = "index.html"
- defaultIndent = " "
-)
+ c.SetRequest(r)
+ c.orgResponse = NewResponse(w, logger)
+ c.response = c.orgResponse
+ c.formParseMaxMemory = formParseMaxMemory
+ return c
+}
+
+// Reset resets the context after request completes. It must be called along
+// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
+// See `Echo#ServeHTTP()`
+func (c *Context) Reset(r *http.Request, w http.ResponseWriter) {
+ c.request = r
+ c.orgResponse.reset(w)
+ c.response = c.orgResponse
+ c.query = nil
+ c.store = nil
+ c.logger = c.echo.Logger
+
+ c.route = nil
+ c.path = ""
+ // NOTE: empty by setting length to 0. PathValues has to have capacity of c.echo.contextPathParamAllocSize at all times
+ *c.pathValues = (*c.pathValues)[:0]
+}
-func (c *context) writeContentType(value string) {
- header := c.Response().Header()
+func (c *Context) writeContentType(value string) {
+ header := c.response.Header()
if header.Get(HeaderContentType) == "" {
header.Set(HeaderContentType, value)
}
}
-func (c *context) Request() *http.Request {
+// Request returns `*http.Request`.
+func (c *Context) Request() *http.Request {
return c.request
}
-func (c *context) SetRequest(r *http.Request) {
+// SetRequest sets `*http.Request`.
+func (c *Context) SetRequest(r *http.Request) {
c.request = r
}
-func (c *context) Response() *Response {
+// Response returns `*Response`.
+func (c *Context) Response() http.ResponseWriter {
return c.response
}
-func (c *context) SetResponse(r *Response) {
+// SetResponse sets `*http.ResponseWriter`. Some context methods and/or middleware require that given ResponseWriter implements following
+// method `Unwrap() http.ResponseWriter` which eventually should return *echo.Response instance.
+func (c *Context) SetResponse(r http.ResponseWriter) {
c.response = r
}
-func (c *context) IsTLS() bool {
+// IsTLS returns true if HTTP connection is TLS otherwise false.
+func (c *Context) IsTLS() bool {
return c.request.TLS != nil
}
-func (c *context) IsWebSocket() bool {
+// IsWebSocket returns true if HTTP connection is WebSocket otherwise false.
+func (c *Context) IsWebSocket() bool {
upgrade := c.request.Header.Get(HeaderUpgrade)
- return strings.ToLower(upgrade) == "websocket"
+ connection := c.request.Header.Get(HeaderConnection)
+ return strings.EqualFold(upgrade, "websocket") && strings.Contains(strings.ToLower(connection), "upgrade")
}
-func (c *context) Scheme() string {
+// Scheme returns the HTTP protocol scheme, `http` or `https`.
+func (c *Context) Scheme() string {
// Can't use `r.Request.URL.Scheme`
// See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0
if c.IsTLS() {
@@ -269,77 +180,161 @@ func (c *context) Scheme() string {
return "http"
}
-func (c *context) RealIP() string {
+// RealIP returns the client's network address based on `X-Forwarded-For`
+// or `X-Real-IP` request header.
+// The behavior can be configured using `Echo#IPExtractor`.
+func (c *Context) RealIP() string {
+ if c.echo != nil && c.echo.IPExtractor != nil {
+ return c.echo.IPExtractor(c.request)
+ }
+ // Fall back to legacy behavior
if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" {
- return strings.Split(ip, ", ")[0]
+ i := strings.IndexAny(ip, ",")
+ if i > 0 {
+ xffip := strings.TrimSpace(ip[:i])
+ xffip = strings.TrimPrefix(xffip, "[")
+ xffip = strings.TrimSuffix(xffip, "]")
+ return xffip
+ }
+ return ip
}
if ip := c.request.Header.Get(HeaderXRealIP); ip != "" {
+ ip = strings.TrimPrefix(ip, "[")
+ ip = strings.TrimSuffix(ip, "]")
return ip
}
ra, _, _ := net.SplitHostPort(c.request.RemoteAddr)
return ra
}
-func (c *context) Path() string {
+// Path returns the registered path for the handler.
+func (c *Context) Path() string {
return c.path
}
-func (c *context) SetPath(p string) {
+// SetPath sets the registered path for the handler.
+func (c *Context) SetPath(p string) {
c.path = p
}
-func (c *context) Param(name string) string {
- for i, n := range c.pnames {
- if i < len(c.pvalues) {
- if n == name {
- return c.pvalues[i]
- }
- }
+// RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route.
+//
+// RouteInfo returns generic "empty" struct for these cases:
+// * Context is accessed before Routing is done. For example inside Pre middlewares (`e.Pre()`)
+// * Router did not find matching route - 404 (route not found)
+// * Router did not find matching route with same method - 405 (method not allowed)
+func (c *Context) RouteInfo() RouteInfo {
+ if c.route != nil {
+ return c.route.Clone()
}
- return ""
+ return RouteInfo{}
+}
+
+// Param returns path parameter by name.
+func (c *Context) Param(name string) string {
+ return c.pathValues.GetOr(name, "")
}
-func (c *context) ParamNames() []string {
- return c.pnames
+// ParamOr returns the path parameter or default value for the provided name.
+//
+// Notes for DefaultRouter implementation:
+// Path parameter could be empty for cases like that:
+// * route `/release-:version/bin` and request URL is `/release-/bin`
+// * route `/api/:version/image.jpg` and request URL is `/api//image.jpg`
+// but not when path parameter is last part of route path
+// * route `/download/file.:ext` will not match request `/download/file.`
+func (c *Context) ParamOr(name, defaultValue string) string {
+ return c.pathValues.GetOr(name, defaultValue)
}
-func (c *context) SetParamNames(names ...string) {
- c.pnames = names
+// PathValues returns path parameter values.
+func (c *Context) PathValues() PathValues {
+ return *c.pathValues
}
-func (c *context) ParamValues() []string {
- return c.pvalues[:len(c.pnames)]
+// SetPathValues sets path parameters for current request.
+func (c *Context) SetPathValues(pathValues PathValues) {
+ if pathValues == nil {
+ panic("context SetPathValues called with nil PathValues")
+ }
+ c.setPathValues(&pathValues)
}
-func (c *context) SetParamValues(values ...string) {
- c.pvalues = values
+// InitializeRoute sets the route related variables of this request to the context.
+func (c *Context) InitializeRoute(ri *RouteInfo, pathValues *PathValues) {
+ c.route = ri
+ c.path = ri.Path
+ c.setPathValues(pathValues)
}
-func (c *context) QueryParam(name string) string {
+func (c *Context) setPathValues(pv *PathValues) {
+ // Router accesses c.pathValues by index and may resize it to full capacity during routing
+ // for that to work without going out-of-bounds we must make sure that c.pathValues slice is not replaced with smaller
+ // slice than Router can set when routing Route with maximum amount of parameters.
+ pathValues := c.pathValues
+ if cap(*c.pathValues) < len(*pv) {
+ // normally we should not end up here. pathValues is normally sized to Echo.contextPathParamAllocSize which should not
+ // be smaller than anything router knows as maximum path parameter count to be.
+ tmp := make(PathValues, len(*pv))
+ c.pathValues = &tmp
+ pathValues = c.pathValues
+ } else if len(*c.pathValues) != len(*pv) {
+ *pathValues = (*pathValues)[0:len(*pv)] // resize slice to given params length for copy to work
+ }
+ copy(*pathValues, *pv)
+}
+
+// QueryParam returns the query param for the provided name.
+func (c *Context) QueryParam(name string) string {
if c.query == nil {
c.query = c.request.URL.Query()
}
return c.query.Get(name)
}
-func (c *context) QueryParams() url.Values {
+// QueryParamOr returns the query param or default value for the provided name.
+// Note: QueryParamOr does not distinguish if query had no value by that name or value was empty string
+// This means URLs `/test?search=` and `/test` would both return `1` for `c.QueryParamOr("search", "1")`
+func (c *Context) QueryParamOr(name, defaultValue string) string {
+ value := c.QueryParam(name)
+ if value == "" {
+ value = defaultValue
+ }
+ return value
+}
+
+// QueryParams returns the query parameters as `url.Values`.
+func (c *Context) QueryParams() url.Values {
if c.query == nil {
c.query = c.request.URL.Query()
}
return c.query
}
-func (c *context) QueryString() string {
+// QueryString returns the URL query string.
+func (c *Context) QueryString() string {
return c.request.URL.RawQuery
}
-func (c *context) FormValue(name string) string {
+// FormValue returns the form field value for the provided name.
+func (c *Context) FormValue(name string) string {
return c.request.FormValue(name)
}
-func (c *context) FormParams() (url.Values, error) {
+// FormValueOr returns the form field value or default value for the provided name.
+// Note: FormValueOr does not distinguish if form had no value by that name or value was empty string
+func (c *Context) FormValueOr(name, defaultValue string) string {
+ value := c.FormValue(name)
+ if value == "" {
+ value = defaultValue
+ }
+ return value
+}
+
+// FormValues returns the form field values as `url.Values`.
+func (c *Context) FormValues() (url.Values, error) {
if strings.HasPrefix(c.request.Header.Get(HeaderContentType), MIMEMultipartForm) {
- if err := c.request.ParseMultipartForm(defaultMemory); err != nil {
+ if err := c.request.ParseMultipartForm(c.formParseMaxMemory); err != nil {
return nil, err
}
} else {
@@ -350,91 +345,115 @@ func (c *context) FormParams() (url.Values, error) {
return c.request.Form, nil
}
-func (c *context) FormFile(name string) (*multipart.FileHeader, error) {
+// FormFile returns the multipart form file for the provided name.
+func (c *Context) FormFile(name string) (*multipart.FileHeader, error) {
f, fh, err := c.request.FormFile(name)
- defer f.Close()
- return fh, err
+ if err != nil {
+ return nil, err
+ }
+ _ = f.Close()
+ return fh, nil
}
-func (c *context) MultipartForm() (*multipart.Form, error) {
- err := c.request.ParseMultipartForm(defaultMemory)
+// MultipartForm returns the multipart form.
+func (c *Context) MultipartForm() (*multipart.Form, error) {
+ err := c.request.ParseMultipartForm(c.formParseMaxMemory)
return c.request.MultipartForm, err
}
-func (c *context) Cookie(name string) (*http.Cookie, error) {
+// Cookie returns the named cookie provided in the request.
+func (c *Context) Cookie(name string) (*http.Cookie, error) {
return c.request.Cookie(name)
}
-func (c *context) SetCookie(cookie *http.Cookie) {
+// SetCookie adds a `Set-Cookie` header in HTTP response.
+func (c *Context) SetCookie(cookie *http.Cookie) {
http.SetCookie(c.Response(), cookie)
}
-func (c *context) Cookies() []*http.Cookie {
+// Cookies returns the HTTP cookies sent with the request.
+func (c *Context) Cookies() []*http.Cookie {
return c.request.Cookies()
}
-func (c *context) Get(key string) interface{} {
+// Get retrieves data from the context.
+// Method returns any(nil) when key does not exist which is different from typed nil (eg. []byte(nil)).
+func (c *Context) Get(key string) any {
c.lock.RLock()
defer c.lock.RUnlock()
return c.store[key]
}
-func (c *context) Set(key string, val interface{}) {
+// Set saves data in the context.
+func (c *Context) Set(key string, val any) {
c.lock.Lock()
defer c.lock.Unlock()
if c.store == nil {
- c.store = make(Map)
+ c.store = make(map[string]any)
}
c.store[key] = val
}
-func (c *context) Bind(i interface{}) error {
- return c.echo.Binder.Bind(i, c)
+// Bind binds path params, query params and the request body into provided type `i`. The default binder
+// binds body based on Content-Type header.
+func (c *Context) Bind(i any) error {
+ return c.echo.Binder.Bind(c, i)
}
-func (c *context) Validate(i interface{}) error {
+// Validate validates provided `i`. It is usually called after `Context#Bind()`.
+// Validator must be registered using `Echo#Validator`.
+func (c *Context) Validate(i any) error {
if c.echo.Validator == nil {
return ErrValidatorNotRegistered
}
return c.echo.Validator.Validate(i)
}
-func (c *context) Render(code int, name string, data interface{}) (err error) {
+// Render renders a template with data and sends a text/html response with status
+// code. Renderer must be registered using `Echo.Renderer`.
+func (c *Context) Render(code int, name string, data any) (err error) {
if c.echo.Renderer == nil {
return ErrRendererNotRegistered
}
+ // as Renderer.Render can fail, and in that case we need to delay sending status code to the client until
+ // (global) error handler decides the correct status code for the error to be sent to the client, so we need to write
+ // the rendered template to the buffer first.
+ //
+ // html.Template.ExecuteTemplate() documentations writes:
+ // > If an error occurs executing the template or writing its output,
+ // > execution stops, but partial results may already have been written to
+ // > the output writer.
+
buf := new(bytes.Buffer)
- if err = c.echo.Renderer.Render(buf, name, data, c); err != nil {
+ if err = c.echo.Renderer.Render(c, buf, name, data); err != nil {
return
}
return c.HTMLBlob(code, buf.Bytes())
}
-func (c *context) HTML(code int, html string) (err error) {
+// HTML sends an HTTP response with status code.
+func (c *Context) HTML(code int, html string) (err error) {
return c.HTMLBlob(code, []byte(html))
}
-func (c *context) HTMLBlob(code int, b []byte) (err error) {
+// HTMLBlob sends an HTTP blob response with status code.
+func (c *Context) HTMLBlob(code int, b []byte) (err error) {
return c.Blob(code, MIMETextHTMLCharsetUTF8, b)
}
-func (c *context) String(code int, s string) (err error) {
+// String sends a string response with status code.
+func (c *Context) String(code int, s string) (err error) {
return c.Blob(code, MIMETextPlainCharsetUTF8, []byte(s))
}
-func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error) {
- enc := json.NewEncoder(c.response)
- _, pretty := c.QueryParams()["pretty"]
- if c.echo.Debug || pretty {
- enc.SetIndent("", " ")
- }
+func (c *Context) jsonPBlob(code int, callback string, i any) (err error) {
c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8)
c.response.WriteHeader(code)
if _, err = c.response.Write([]byte(callback + "(")); err != nil {
return
}
- if err = enc.Encode(i); err != nil {
+ if err = c.echo.JSONSerializer.Serialize(c, i, ""); err != nil {
return
}
if _, err = c.response.Write([]byte(");")); err != nil {
@@ -443,37 +462,47 @@ func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error
return
}
-func (c *context) json(code int, i interface{}, indent string) error {
- enc := json.NewEncoder(c.response)
- if indent != "" {
- enc.SetIndent("", indent)
+func (c *Context) json(code int, i any, indent string) error {
+ c.writeContentType(MIMEApplicationJSON)
+
+ // as JSONSerializer.Serialize can fail, and in that case we need to delay sending status code to the client until
+ // (global) error handler decides correct status code for the error to be sent to the client.
+ // For that we need to use writer that can store the proposed status code until the first Write is called.
+ if r, err := UnwrapResponse(c.response); err == nil {
+ r.Status = code
+ } else {
+ resp := c.Response()
+ c.SetResponse(&delayedStatusWriter{ResponseWriter: resp, status: code})
+ defer c.SetResponse(resp)
}
- c.writeContentType(MIMEApplicationJSONCharsetUTF8)
- c.response.Status = code
- return enc.Encode(i)
+
+ return c.echo.JSONSerializer.Serialize(c, i, indent)
}
-func (c *context) JSON(code int, i interface{}) (err error) {
- indent := ""
- if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty {
- indent = defaultIndent
- }
- return c.json(code, i, indent)
+// JSON sends a JSON response with status code.
+func (c *Context) JSON(code int, i any) (err error) {
+ return c.json(code, i, "")
}
-func (c *context) JSONPretty(code int, i interface{}, indent string) (err error) {
+// JSONPretty sends a pretty-print JSON with status code.
+func (c *Context) JSONPretty(code int, i any, indent string) (err error) {
return c.json(code, i, indent)
}
-func (c *context) JSONBlob(code int, b []byte) (err error) {
- return c.Blob(code, MIMEApplicationJSONCharsetUTF8, b)
+// JSONBlob sends a JSON blob response with status code.
+func (c *Context) JSONBlob(code int, b []byte) (err error) {
+ return c.Blob(code, MIMEApplicationJSON, b)
}
-func (c *context) JSONP(code int, callback string, i interface{}) (err error) {
+// JSONP sends a JSONP response with status code. It uses `callback` to construct
+// the JSONP payload.
+func (c *Context) JSONP(code int, callback string, i any) (err error) {
return c.jsonPBlob(code, callback, i)
}
-func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) {
+// JSONPBlob sends a JSONP blob response with status code. It uses `callback`
+// to construct the JSONP payload.
+func (c *Context) JSONPBlob(code int, callback string, b []byte) (err error) {
c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8)
c.response.WriteHeader(code)
if _, err = c.response.Write([]byte(callback + "(")); err != nil {
@@ -486,7 +515,7 @@ func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) {
return
}
-func (c *context) xml(code int, i interface{}, indent string) (err error) {
+func (c *Context) xml(code int, i any, indent string) (err error) {
c.writeContentType(MIMEApplicationXMLCharsetUTF8)
c.response.WriteHeader(code)
enc := xml.NewEncoder(c.response)
@@ -499,19 +528,18 @@ func (c *context) xml(code int, i interface{}, indent string) (err error) {
return enc.Encode(i)
}
-func (c *context) XML(code int, i interface{}) (err error) {
- indent := ""
- if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty {
- indent = defaultIndent
- }
- return c.xml(code, i, indent)
+// XML sends an XML response with status code.
+func (c *Context) XML(code int, i any) (err error) {
+ return c.xml(code, i, "")
}
-func (c *context) XMLPretty(code int, i interface{}, indent string) (err error) {
+// XMLPretty sends a pretty-print XML with status code.
+func (c *Context) XMLPretty(code int, i any, indent string) (err error) {
return c.xml(code, i, indent)
}
-func (c *context) XMLBlob(code int, b []byte) (err error) {
+// XMLBlob sends an XML blob response with status code.
+func (c *Context) XMLBlob(code int, b []byte) (err error) {
c.writeContentType(MIMEApplicationXMLCharsetUTF8)
c.response.WriteHeader(code)
if _, err = c.response.Write([]byte(xml.Header)); err != nil {
@@ -521,62 +549,89 @@ func (c *context) XMLBlob(code int, b []byte) (err error) {
return
}
-func (c *context) Blob(code int, contentType string, b []byte) (err error) {
+// Blob sends a blob response with status code and content type.
+func (c *Context) Blob(code int, contentType string, b []byte) (err error) {
c.writeContentType(contentType)
c.response.WriteHeader(code)
_, err = c.response.Write(b)
return
}
-func (c *context) Stream(code int, contentType string, r io.Reader) (err error) {
+// Stream sends a streaming response with status code and content type.
+func (c *Context) Stream(code int, contentType string, r io.Reader) (err error) {
c.writeContentType(contentType)
c.response.WriteHeader(code)
_, err = io.Copy(c.response, r)
return
}
-func (c *context) File(file string) (err error) {
- f, err := os.Open(file)
+// File sends a response with the content of the file.
+func (c *Context) File(file string) error {
+ return fsFile(c, file, c.echo.Filesystem)
+}
+
+// FileFS serves file from given file system.
+//
+// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
+// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
+// including `assets/images` as their prefix.
+func (c *Context) FileFS(file string, filesystem fs.FS) error {
+ return fsFile(c, file, filesystem)
+}
+
+func fsFile(c *Context, file string, filesystem fs.FS) error {
+ file = path.Clean(file) // `os.Open` and `os.DirFs.Open()` behave differently, later does not like ``, `.`, `..` at all, but we allowed those now need to clean
+ f, err := filesystem.Open(file)
if err != nil {
- return NotFoundHandler(c)
+ return ErrNotFound
}
defer f.Close()
fi, _ := f.Stat()
if fi.IsDir() {
- file = filepath.Join(file, indexPage)
- f, err = os.Open(file)
+ file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect.
+ f, err = filesystem.Open(file)
if err != nil {
- return NotFoundHandler(c)
+ return ErrNotFound
}
defer f.Close()
if fi, err = f.Stat(); err != nil {
- return
+ return err
}
}
- http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f)
- return
+ ff, ok := f.(io.ReadSeeker)
+ if !ok {
+ return errors.New("file does not implement io.ReadSeeker")
+ }
+ http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff)
+ return nil
}
-func (c *context) Attachment(file, name string) error {
+// Attachment sends a response as attachment, prompting client to save the file.
+func (c *Context) Attachment(file, name string) error {
return c.contentDisposition(file, name, "attachment")
}
-func (c *context) Inline(file, name string) error {
+// Inline sends a response as inline, opening the file in the browser.
+func (c *Context) Inline(file, name string) error {
return c.contentDisposition(file, name, "inline")
}
-func (c *context) contentDisposition(file, name, dispositionType string) error {
- c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf("%s; filename=%q", dispositionType, name))
+var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
+
+func (c *Context) contentDisposition(file, name, dispositionType string) error {
+ c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf(`%s; filename="%s"`, dispositionType, quoteEscaper.Replace(name)))
return c.File(file)
}
-func (c *context) NoContent(code int) error {
+// NoContent sends a response with no body and a status code.
+func (c *Context) NoContent(code int) error {
c.response.WriteHeader(code)
return nil
}
-func (c *context) Redirect(code int, url string) error {
+// Redirect redirects the request to a provided URL with status code.
+func (c *Context) Redirect(code int, url string) error {
if code < 300 || code > 308 {
return ErrInvalidRedirectCode
}
@@ -585,45 +640,20 @@ func (c *context) Redirect(code int, url string) error {
return nil
}
-func (c *context) Error(err error) {
- c.echo.HTTPErrorHandler(err, c)
-}
-
-func (c *context) Echo() *Echo {
- return c.echo
-}
-
-func (c *context) Handler() HandlerFunc {
- return c.handler
-}
-
-func (c *context) SetHandler(h HandlerFunc) {
- c.handler = h
-}
-
-func (c *context) Logger() Logger {
- res := c.logger
- if res != nil {
- return res
+// Logger returns logger in Context
+func (c *Context) Logger() *slog.Logger {
+ if c.logger != nil {
+ return c.logger
}
return c.echo.Logger
}
-func (c *context) SetLogger(l Logger) {
- c.logger = l
+// SetLogger sets logger in Context
+func (c *Context) SetLogger(logger *slog.Logger) {
+ c.logger = logger
}
-func (c *context) Reset(r *http.Request, w http.ResponseWriter) {
- c.request = r
- c.response.reset(w)
- c.query = nil
- c.handler = NotFoundHandler
- c.store = nil
- c.path = ""
- c.pnames = nil
- c.logger = nil
- // NOTE: Don't reset because it has to have length c.echo.maxParam at all times
- for i := 0; i < *c.echo.maxParam; i++ {
- c.pvalues[i] = ""
- }
+// Echo returns the `Echo` instance.
+func (c *Context) Echo() *Echo {
+ return c.echo
}
diff --git a/context_generic.go b/context_generic.go
new file mode 100644
index 000000000..7cf8b296c
--- /dev/null
+++ b/context_generic.go
@@ -0,0 +1,43 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import "errors"
+
+// ErrNonExistentKey is error that is returned when key does not exist
+var ErrNonExistentKey = errors.New("non existent key")
+
+// ErrInvalidKeyType is error that is returned when the value is not castable to expected type.
+var ErrInvalidKeyType = errors.New("invalid key type")
+
+// ContextGet retrieves a value from the context store or ErrNonExistentKey error the key is missing.
+// Returns ErrInvalidKeyType error if the value is not castable to type T.
+func ContextGet[T any](c *Context, key string) (T, error) {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+
+ val, ok := c.store[key]
+ if !ok {
+ var zero T
+ return zero, ErrNonExistentKey
+ }
+
+ typed, ok := val.(T)
+ if !ok {
+ var zero T
+ return zero, ErrInvalidKeyType
+ }
+
+ return typed, nil
+}
+
+// ContextGetOr retrieves a value from the context store or returns a default value when the key
+// is missing. Returns ErrInvalidKeyType error if the value is not castable to type T.
+func ContextGetOr[T any](c *Context, key string, defaultValue T) (T, error) {
+ typed, err := ContextGet[T](c, key)
+ if err == ErrNonExistentKey {
+ return defaultValue, nil
+ }
+ return typed, err
+}
diff --git a/context_generic_test.go b/context_generic_test.go
new file mode 100644
index 000000000..ce468ac3e
--- /dev/null
+++ b/context_generic_test.go
@@ -0,0 +1,70 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestContextGetOK(t *testing.T) {
+ c := NewContext(nil, nil)
+
+ c.Set("key", int64(123))
+
+ v, err := ContextGet[int64](c, "key")
+ assert.NoError(t, err)
+ assert.Equal(t, int64(123), v)
+}
+
+func TestContextGetNonExistentKey(t *testing.T) {
+ c := NewContext(nil, nil)
+
+ c.Set("key", int64(123))
+
+ v, err := ContextGet[int64](c, "nope")
+ assert.ErrorIs(t, err, ErrNonExistentKey)
+ assert.Equal(t, int64(0), v)
+}
+
+func TestContextGetInvalidCast(t *testing.T) {
+ c := NewContext(nil, nil)
+
+ c.Set("key", int64(123))
+
+ v, err := ContextGet[bool](c, "key")
+ assert.ErrorIs(t, err, ErrInvalidKeyType)
+ assert.Equal(t, false, v)
+}
+
+func TestContextGetOrOK(t *testing.T) {
+ c := NewContext(nil, nil)
+
+ c.Set("key", int64(123))
+
+ v, err := ContextGetOr[int64](c, "key", 999)
+ assert.NoError(t, err)
+ assert.Equal(t, int64(123), v)
+}
+
+func TestContextGetOrNonExistentKey(t *testing.T) {
+ c := NewContext(nil, nil)
+
+ c.Set("key", int64(123))
+
+ v, err := ContextGetOr[int64](c, "nope", 999)
+ assert.NoError(t, err)
+ assert.Equal(t, int64(999), v)
+}
+
+func TestContextGetOrInvalidCast(t *testing.T) {
+ c := NewContext(nil, nil)
+
+ c.Set("key", int64(123))
+
+ v, err := ContextGetOr[float32](c, "key", float32(999))
+ assert.ErrorIs(t, err, ErrInvalidKeyType)
+ assert.Equal(t, float32(0), v)
+}
diff --git a/context_test.go b/context_test.go
index 47be19cce..5945c9ecc 100644
--- a/context_test.go
+++ b/context_test.go
@@ -1,3 +1,6 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package echo
import (
@@ -5,36 +8,36 @@ import (
"crypto/tls"
"encoding/json"
"encoding/xml"
- "errors"
"fmt"
- "github.com/labstack/gommon/log"
"io"
+ "io/fs"
+ "log/slog"
"math"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/url"
+ "os"
"strings"
"testing"
"text/template"
"time"
- testify "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/assert"
)
-type (
- Template struct {
- templates *template.Template
- }
-)
+type Template struct {
+ templates *template.Template
+}
-var testUser = user{1, "Jon Snow"}
+var testUser = user{ID: 1, Name: "Jon Snow"}
func BenchmarkAllocJSONP(b *testing.B) {
e := New()
- req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON))
+ e.Logger = slog.New(slog.DiscardHandler)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec).(*context)
+ c := e.NewContext(req, rec)
b.ResetTimer()
b.ReportAllocs()
@@ -46,9 +49,10 @@ func BenchmarkAllocJSONP(b *testing.B) {
func BenchmarkAllocJSON(b *testing.B) {
e := New()
- req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON))
+ e.Logger = slog.New(slog.DiscardHandler)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec).(*context)
+ c := e.NewContext(req, rec)
b.ResetTimer()
b.ReportAllocs()
@@ -60,9 +64,10 @@ func BenchmarkAllocJSON(b *testing.B) {
func BenchmarkAllocXML(b *testing.B) {
e := New()
- req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON))
+ e.Logger = slog.New(slog.DiscardHandler)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec).(*context)
+ c := e.NewContext(req, rec)
b.ResetTimer()
b.ReportAllocs()
@@ -72,338 +77,430 @@ func BenchmarkAllocXML(b *testing.B) {
}
}
-func (t *Template) Render(w io.Writer, name string, data interface{}, c Context) error {
- return t.templates.ExecuteTemplate(w, name, data)
+func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) {
+ c := Context{request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}},
+ }}
+ for i := 0; i < b.N; i++ {
+ c.RealIP()
+ }
}
-type responseWriterErr struct {
+func (t *Template) Render(c *Context, w io.Writer, name string, data any) error {
+ return t.templates.ExecuteTemplate(w, name, data)
}
-func (responseWriterErr) Header() http.Header {
- return http.Header{}
-}
+func TestContextEcho(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec)
-func (responseWriterErr) Write([]byte) (int, error) {
- return 0, errors.New("err")
+ assert.Equal(t, e, c.Echo())
}
-func (responseWriterErr) WriteHeader(statusCode int) {
+func TestContextRequest(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ assert.NotNil(t, c.Request())
+ assert.Equal(t, req, c.Request())
}
-func TestContext(t *testing.T) {
+func TestContextResponse(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec).(*context)
- assert := testify.New(t)
-
- // Echo
- assert.Equal(e, c.Echo())
+ c := e.NewContext(req, rec)
- // Request
- assert.NotNil(c.Request())
+ assert.NotNil(t, c.Response())
+}
- // Response
- assert.NotNil(c.Response())
+func TestContextRenderTemplate(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
- //--------
- // Render
- //--------
+ c := e.NewContext(req, rec)
tmpl := &Template{
templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")),
}
- c.echo.Renderer = tmpl
+ c.Echo().Renderer = tmpl
err := c.Render(http.StatusOK, "hello", "Jon Snow")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal("Hello, Jon Snow!", rec.Body.String())
- }
-
- c.echo.Renderer = nil
- err = c.Render(http.StatusOK, "hello", "Jon Snow")
- assert.Error(err)
-
- // JSON
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.JSON(http.StatusOK, user{1, "Jon Snow"})
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(userJSON+"\n", rec.Body.String())
- }
-
- // JSON with "?pretty"
- req = httptest.NewRequest(http.MethodGet, "/?pretty", nil)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.JSON(http.StatusOK, user{1, "Jon Snow"})
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(userJSONPretty+"\n", rec.Body.String())
- }
- req = httptest.NewRequest(http.MethodGet, "/", nil) // reset
-
- // JSONPretty
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(userJSONPretty+"\n", rec.Body.String())
- }
-
- // JSON (error)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.JSON(http.StatusOK, make(chan bool))
- assert.Error(err)
-
- // JSONP
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "Hello, Jon Snow!", rec.Body.String())
+ }
+}
+
+func TestContextRenderTemplateError(t *testing.T) {
+ // we test that when template rendering fails, no response is sent to the client yet, so the global error handler can decide what to do
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ tmpl := &Template{
+ templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")),
+ }
+ c.Echo().Renderer = tmpl
+ err := c.Render(http.StatusOK, "not_existing", "Jon Snow")
+
+ assert.EqualError(t, err, `template: no template "not_existing" associated with template "hello"`)
+ assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
+ assert.Empty(t, rec.Body.String()) // body must not be sent to the client
+}
+
+func TestContextRenderErrorsOnNoRenderer(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec)
+
+ c.Echo().Renderer = nil
+ assert.Error(t, c.Render(http.StatusOK, "hello", "Jon Snow"))
+}
+
+func TestContextStream(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
+ r, w := io.Pipe()
+ go func() {
+ defer w.Close()
+ for i := 0; i < 3; i++ {
+ fmt.Fprintf(w, "data: index %v\n\n", i)
+ time.Sleep(5 * time.Millisecond)
+ }
+ }()
+
+ err := c.Stream(http.StatusOK, "text/event-stream", r)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "text/event-stream", rec.Header().Get(HeaderContentType))
+ assert.Equal(t, "data: index 0\n\ndata: index 1\n\ndata: index 2\n\n", rec.Body.String())
+ }
+}
+
+func TestContextHTML(t *testing.T) {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := NewContext(req, rec)
+
+ err := c.HTML(http.StatusOK, "Hi, Jon Snow")
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, "Hi, Jon Snow", rec.Body.String())
+ }
+}
+
+func TestContextHTMLBlob(t *testing.T) {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := NewContext(req, rec)
+
+ err := c.HTMLBlob(http.StatusOK, []byte("Hi, Jon Snow"))
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, "Hi, Jon Snow", rec.Body.String())
+ }
+}
+
+func TestContextJSON(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ c := e.NewContext(req, rec)
+
+ err := c.JSON(http.StatusOK, user{ID: 1, Name: "Jon Snow"})
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSON+"\n", rec.Body.String())
+ }
+}
+
+func TestContextJSONErrorsOut(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ c := e.NewContext(req, rec)
+
+ err := c.JSON(http.StatusOK, make(chan bool))
+ assert.EqualError(t, err, "json: unsupported type: chan bool")
+
+ assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
+ assert.Empty(t, rec.Body.String()) // body must not be sent to the client
+}
+
+func TestContextJSONWithNotEchoResponse(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ c := e.NewContext(req, rec)
+
+ c.SetResponse(rec)
+
+ err := c.JSON(http.StatusCreated, map[string]float64{"foo": math.NaN()})
+ assert.EqualError(t, err, "json: unsupported value: NaN")
+
+ assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
+ assert.Empty(t, rec.Body.String()) // body must not be sent to the client
+}
+
+func TestContextJSONPretty(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
+ err := c.JSONPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, " ")
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
+ }
+}
+
+func TestContextJSONWithEmptyIntent(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
+ u := user{ID: 1, Name: "Jon Snow"}
+ emptyIndent := ""
+ buf := new(bytes.Buffer)
+
+ enc := json.NewEncoder(buf)
+ enc.SetIndent(emptyIndent, emptyIndent)
+ _ = enc.Encode(u)
+ err := c.JSONPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, emptyIndent)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, buf.String(), rec.Body.String())
+ }
+}
+
+func TestContextJSONP(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
callback := "callback"
- err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"})
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(callback+"("+userJSON+"\n);", rec.Body.String())
- }
-
- // XML
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.XML(http.StatusOK, user{1, "Jon Snow"})
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(xml.Header+userXML, rec.Body.String())
- }
-
- // XML with "?pretty"
- req = httptest.NewRequest(http.MethodGet, "/?pretty", nil)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.XML(http.StatusOK, user{1, "Jon Snow"})
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(xml.Header+userXMLPretty, rec.Body.String())
- }
- req = httptest.NewRequest(http.MethodGet, "/", nil)
-
- // XML (error)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.XML(http.StatusOK, make(chan bool))
- assert.Error(err)
-
- // XML response write error
- c = e.NewContext(req, rec).(*context)
- c.response.Writer = responseWriterErr{}
- err = c.XML(0, 0)
- testify.Error(t, err)
-
- // XMLPretty
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(xml.Header+userXMLPretty, rec.Body.String())
- }
-
- t.Run("empty indent", func(t *testing.T) {
- var (
- u = user{1, "Jon Snow"}
- buf = new(bytes.Buffer)
- emptyIndent = ""
- )
-
- t.Run("json", func(t *testing.T) {
- buf.Reset()
- assert := testify.New(t)
-
- // New JSONBlob with empty indent
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- enc := json.NewEncoder(buf)
- enc.SetIndent(emptyIndent, emptyIndent)
- err = enc.Encode(u)
- err = c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(buf.String(), rec.Body.String())
- }
- })
+ err := c.JSONP(http.StatusOK, callback, user{ID: 1, Name: "Jon Snow"})
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String())
+ }
+}
- t.Run("xml", func(t *testing.T) {
- buf.Reset()
- assert := testify.New(t)
-
- // New XMLBlob with empty indent
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- enc := xml.NewEncoder(buf)
- enc.Indent(emptyIndent, emptyIndent)
- err = enc.Encode(u)
- err = c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(xml.Header+buf.String(), rec.Body.String())
- }
- })
- })
+func TestContextJSONBlob(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
- // Legacy JSONBlob
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- data, err := json.Marshal(user{1, "Jon Snow"})
- assert.NoError(err)
+ data, err := json.Marshal(user{ID: 1, Name: "Jon Snow"})
+ assert.NoError(t, err)
err = c.JSONBlob(http.StatusOK, data)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(userJSON, rec.Body.String())
- }
-
- // Legacy JSONPBlob
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- callback = "callback"
- data, err = json.Marshal(user{1, "Jon Snow"})
- assert.NoError(err)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSON, rec.Body.String())
+ }
+}
+
+func TestContextJSONPBlob(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
+ callback := "callback"
+ data, err := json.Marshal(user{ID: 1, Name: "Jon Snow"})
+ assert.NoError(t, err)
err = c.JSONPBlob(http.StatusOK, callback, data)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(callback+"("+userJSON+");", rec.Body.String())
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, callback+"("+userJSON+");", rec.Body.String())
}
+}
- // Legacy XMLBlob
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- data, err = xml.Marshal(user{1, "Jon Snow"})
- assert.NoError(err)
- err = c.XMLBlob(http.StatusOK, data)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(xml.Header+userXML, rec.Body.String())
- }
-
- // String
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.String(http.StatusOK, "Hello, World!")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal("Hello, World!", rec.Body.String())
- }
-
- // HTML
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.HTML(http.StatusOK, "Hello, World!")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal("Hello, World!", rec.Body.String())
- }
-
- // Stream
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- r := strings.NewReader("response from a stream")
- err = c.Stream(http.StatusOK, "application/octet-stream", r)
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal("application/octet-stream", rec.Header().Get(HeaderContentType))
- assert.Equal("response from a stream", rec.Body.String())
- }
-
- // Attachment
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.Attachment("_fixture/images/walle.png", "walle.png")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal("attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
- assert.Equal(219885, rec.Body.Len())
- }
-
- // Inline
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- err = c.Inline("_fixture/images/walle.png", "walle.png")
- if assert.NoError(err) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal("inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition))
- assert.Equal(219885, rec.Body.Len())
- }
-
- // NoContent
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- c.NoContent(http.StatusOK)
- assert.Equal(http.StatusOK, rec.Code)
-
- // Error
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec).(*context)
- c.Error(errors.New("error"))
- assert.Equal(http.StatusInternalServerError, rec.Code)
-
- // Reset
- c.SetParamNames("foo")
- c.SetParamValues("bar")
- c.Set("foe", "ban")
- c.query = url.Values(map[string][]string{"fon": {"baz"}})
- c.Reset(req, httptest.NewRecorder())
- assert.Equal(0, len(c.ParamValues()))
- assert.Equal(0, len(c.ParamNames()))
- assert.Equal(0, len(c.store))
- assert.Equal("", c.Path())
- assert.Equal(0, len(c.QueryParams()))
+func TestContextXML(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
+ err := c.XML(http.StatusOK, user{ID: 1, Name: "Jon Snow"})
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+userXML, rec.Body.String())
+ }
}
-func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
+func TestContextXMLPretty(t *testing.T) {
e := New()
+ rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
+ err := c.XMLPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, " ")
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String())
+ }
+}
+
+func TestContextXMLBlob(t *testing.T) {
+ e := New()
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec).(*context)
- err := c.JSON(http.StatusCreated, user{1, "Jon Snow"})
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
+ data, err := xml.Marshal(user{ID: 1, Name: "Jon Snow"})
+ assert.NoError(t, err)
+ err = c.XMLBlob(http.StatusOK, data)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+userXML, rec.Body.String())
+ }
+}
+
+func TestContextXMLWithEmptyIntent(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
- assert := testify.New(t)
- if assert.NoError(err) {
- assert.Equal(http.StatusCreated, rec.Code)
- assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
- assert.Equal(userJSON+"\n", rec.Body.String())
+ u := user{ID: 1, Name: "Jon Snow"}
+ emptyIndent := ""
+ buf := new(bytes.Buffer)
+
+ enc := xml.NewEncoder(buf)
+ enc.Indent(emptyIndent, emptyIndent)
+ _ = enc.Encode(u)
+ err := c.XMLPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, emptyIndent)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+buf.String(), rec.Body.String())
}
}
-func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) {
+func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec).(*context)
- err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()})
+ c := e.NewContext(req, rec)
+ err := c.JSON(http.StatusCreated, user{ID: 1, Name: "Jon Snow"})
- assert := testify.New(t)
- if assert.Error(err) {
- assert.False(c.response.Committed)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusCreated, rec.Code)
+ assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSON+"\n", rec.Body.String())
}
}
+func TestContextAttachment(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenName string
+ expectHeader string
+ }{
+ {
+ name: "ok",
+ whenName: "walle.png",
+ expectHeader: `attachment; filename="walle.png"`,
+ },
+ {
+ name: "ok, escape quotes in malicious filename",
+ whenName: `malicious.sh"; \"; dummy=.txt`,
+ expectHeader: `attachment; filename="malicious.sh\"; \\\"; dummy=.txt"`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
+ err := c.Attachment("_fixture/images/walle.png", tc.whenName)
+ if assert.NoError(t, err) {
+ assert.Equal(t, tc.expectHeader, rec.Header().Get(HeaderContentDisposition))
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, 219885, rec.Body.Len())
+ }
+ })
+ }
+}
+
+func TestContextInline(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenName string
+ expectHeader string
+ }{
+ {
+ name: "ok",
+ whenName: "walle.png",
+ expectHeader: `inline; filename="walle.png"`,
+ },
+ {
+ name: "ok, escape quotes in malicious filename",
+ whenName: `malicious.sh"; \"; dummy=.txt`,
+ expectHeader: `inline; filename="malicious.sh\"; \\\"; dummy=.txt"`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
+ err := c.Inline("_fixture/images/walle.png", tc.whenName)
+ if assert.NoError(t, err) {
+ assert.Equal(t, tc.expectHeader, rec.Header().Get(HeaderContentDisposition))
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, 219885, rec.Body.Len())
+ }
+ })
+ }
+}
+
+func TestContextNoContent(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec)
+
+ c.NoContent(http.StatusOK)
+ assert.Equal(t, http.StatusOK, rec.Code)
+}
+
func TestContextCookie(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
@@ -412,24 +509,22 @@ func TestContextCookie(t *testing.T) {
req.Header.Add(HeaderCookie, theme)
req.Header.Add(HeaderCookie, user)
rec := httptest.NewRecorder()
- c := e.NewContext(req, rec).(*context)
-
- assert := testify.New(t)
+ c := e.NewContext(req, rec)
// Read single
cookie, err := c.Cookie("theme")
- if assert.NoError(err) {
- assert.Equal("theme", cookie.Name)
- assert.Equal("light", cookie.Value)
+ if assert.NoError(t, err) {
+ assert.Equal(t, "theme", cookie.Name)
+ assert.Equal(t, "light", cookie.Value)
}
// Read multiple
for _, cookie := range c.Cookies() {
switch cookie.Name {
case "theme":
- assert.Equal("light", cookie.Value)
+ assert.Equal(t, "light", cookie.Value)
case "user":
- assert.Equal("Jon Snow", cookie.Value)
+ assert.Equal(t, "Jon Snow", cookie.Value)
}
}
@@ -444,47 +539,244 @@ func TestContextCookie(t *testing.T) {
HttpOnly: true,
}
c.SetCookie(cookie)
- assert.Contains(rec.Header().Get(HeaderSetCookie), "SSID")
- assert.Contains(rec.Header().Get(HeaderSetCookie), "Ap4PGTEq")
- assert.Contains(rec.Header().Get(HeaderSetCookie), "labstack.com")
- assert.Contains(rec.Header().Get(HeaderSetCookie), "Secure")
- assert.Contains(rec.Header().Get(HeaderSetCookie), "HttpOnly")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "labstack.com")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Secure")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly")
}
-func TestContextPath(t *testing.T) {
- e := New()
- r := e.Router()
+func TestContext_PathValues(t *testing.T) {
+ var testCases = []struct {
+ name string
+ given PathValues
+ expect PathValues
+ }{
+ {
+ name: "param exists",
+ given: PathValues{
+ {Name: "uid", Value: "101"},
+ {Name: "fid", Value: "501"},
+ },
+ expect: PathValues{
+ {Name: "uid", Value: "101"},
+ {Name: "fid", Value: "501"},
+ },
+ },
+ {
+ name: "params is empty",
+ given: PathValues{},
+ expect: PathValues{},
+ },
+ }
- r.Add(http.MethodGet, "/users/:id", nil)
- c := e.NewContext(nil, nil)
- r.Find(http.MethodGet, "/users/1", c)
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, nil)
+
+ c.SetPathValues(tc.given)
+
+ assert.EqualValues(t, tc.expect, c.PathValues())
+ })
+ }
+}
+
+func TestContext_PathParam(t *testing.T) {
+ var testCases = []struct {
+ name string
+ given PathValues
+ whenParamName string
+ expect string
+ }{
+ {
+ name: "param exists",
+ given: PathValues{
+ {Name: "uid", Value: "101"},
+ {Name: "fid", Value: "501"},
+ },
+ whenParamName: "uid",
+ expect: "101",
+ },
+ {
+ name: "multiple same param values exists - return first",
+ given: PathValues{
+ {Name: "uid", Value: "101"},
+ {Name: "uid", Value: "202"},
+ {Name: "fid", Value: "501"},
+ },
+ whenParamName: "uid",
+ expect: "101",
+ },
+ {
+ name: "param does not exists",
+ given: PathValues{
+ {Name: "uid", Value: "101"},
+ },
+ whenParamName: "nope",
+ expect: "",
+ },
+ }
- assert := testify.New(t)
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, nil)
- assert.Equal("/users/:id", c.Path())
+ c.SetPathValues(tc.given)
- r.Add(http.MethodGet, "/users/:uid/files/:fid", nil)
- c = e.NewContext(nil, nil)
- r.Find(http.MethodGet, "/users/1/files/1", c)
- assert.Equal("/users/:uid/files/:fid", c.Path())
+ assert.EqualValues(t, tc.expect, c.Param(tc.whenParamName))
+ })
+ }
}
-func TestContextPathParam(t *testing.T) {
- e := New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- c := e.NewContext(req, nil)
+func TestContext_PathParamDefault(t *testing.T) {
+ var testCases = []struct {
+ name string
+ given PathValues
+ whenParamName string
+ whenDefaultValue string
+ expect string
+ }{
+ {
+ name: "param exists",
+ given: PathValues{
+ {Name: "uid", Value: "101"},
+ {Name: "fid", Value: "501"},
+ },
+ whenParamName: "uid",
+ whenDefaultValue: "999",
+ expect: "101",
+ },
+ {
+ name: "param exists and is empty",
+ given: PathValues{
+ {Name: "uid", Value: ""},
+ {Name: "fid", Value: "501"},
+ },
+ whenParamName: "uid",
+ whenDefaultValue: "999",
+ expect: "", // <-- this is different from QueryParamOr behaviour
+ },
+ {
+ name: "param does not exists",
+ given: PathValues{
+ {Name: "uid", Value: "101"},
+ },
+ whenParamName: "nope",
+ whenDefaultValue: "999",
+ expect: "999",
+ },
+ }
- // ParamNames
- c.SetParamNames("uid", "fid")
- testify.EqualValues(t, []string{"uid", "fid"}, c.ParamNames())
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, nil)
- // ParamValues
- c.SetParamValues("101", "501")
- testify.EqualValues(t, []string{"101", "501"}, c.ParamValues())
+ c.SetPathValues(tc.given)
- // Param
- testify.Equal(t, "501", c.Param("fid"))
- testify.Equal(t, "", c.Param("undefined"))
+ assert.EqualValues(t, tc.expect, c.ParamOr(tc.whenParamName, tc.whenDefaultValue))
+ })
+ }
+}
+
+func TestContextGetAndSetPathValuesMutability(t *testing.T) {
+ t.Run("c.PathValues() does not return copy and modifying raw slice mutates value in context", func(t *testing.T) {
+ e := New()
+ e.contextPathParamAllocSize.Store(1)
+
+ req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
+ c := e.NewContext(req, nil)
+
+ params := PathValues{{Name: "foo", Value: "101"}}
+ c.SetPathValues(params)
+
+ // round-trip param values with modification
+ paramVals := c.PathValues()
+ assert.Equal(t, params, c.PathValues())
+
+ // PathValues() does not return copy and modifying raw slice mutates value in context
+ paramVals[0] = PathValue{Name: "xxx", Value: "yyy"}
+ assert.Equal(t, PathValues{PathValue{Name: "xxx", Value: "yyy"}}, c.PathValues())
+ })
+
+ t.Run("calling SetPathValues with bigger size changes capacity in context", func(t *testing.T) {
+ e := New()
+ e.contextPathParamAllocSize.Store(1)
+
+ req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
+ c := e.NewContext(req, nil)
+ // increase path param capacity in context
+ pathValues := PathValues{
+ {Name: "aaa", Value: "bbb"},
+ {Name: "ccc", Value: "ddd"},
+ }
+ c.SetPathValues(pathValues)
+ assert.Equal(t, pathValues, c.PathValues())
+
+ // shouldn't explode during Reset() afterwards!
+ assert.NotPanics(t, func() {
+ c.Reset(nil, nil)
+ })
+ assert.Equal(t, PathValues{}, c.PathValues())
+ assert.Len(t, *c.pathValues, 0)
+ assert.Equal(t, 2, cap(*c.pathValues))
+ })
+
+ t.Run("calling SetPathValues with smaller size slice does not change capacity in context", func(t *testing.T) {
+ e := New()
+
+ req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
+ c := e.NewContext(req, nil)
+ c.pathValues = &PathValues{
+ {Name: "aaa", Value: "bbb"},
+ {Name: "ccc", Value: "ddd"},
+ }
+
+ pathValues := PathValues{
+ {Name: "aaa", Value: "bbb"},
+ }
+ // given pathValues slice is smaller. this should not decrease c.pathValues capacity
+ c.SetPathValues(pathValues)
+ assert.Equal(t, pathValues, c.PathValues())
+
+ // shouldn't explode during Reset() afterwards!
+ assert.NotPanics(t, func() {
+ c.Reset(nil, nil)
+ })
+ assert.Equal(t, PathValues{}, c.PathValues())
+ assert.Len(t, *c.pathValues, 0)
+ assert.Equal(t, 2, cap(*c.pathValues))
+ })
+
+}
+
+// Issue #1655
+func TestContext_SetParamNamesShouldNotModifyPathValuesCapacity(t *testing.T) {
+ e := New()
+ c := e.NewContext(nil, nil)
+
+ assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load())
+ expectedTwoParams := PathValues{
+ {Name: "1", Value: "one"},
+ {Name: "2", Value: "two"},
+ }
+ c.SetPathValues(expectedTwoParams)
+ assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load())
+ assert.Equal(t, expectedTwoParams, c.PathValues())
+
+ expectedThreeParams := PathValues{
+ {Name: "1", Value: "one"},
+ {Name: "2", Value: "two"},
+ {Name: "3", Value: "three"},
+ }
+ c.SetPathValues(expectedThreeParams)
+ assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load())
+ assert.Equal(t, expectedThreeParams, c.PathValues())
}
func TestContextFormValue(t *testing.T) {
@@ -498,44 +790,154 @@ func TestContextFormValue(t *testing.T) {
c := e.NewContext(req, nil)
// FormValue
- testify.Equal(t, "Jon Snow", c.FormValue("name"))
- testify.Equal(t, "jon@labstack.com", c.FormValue("email"))
+ assert.Equal(t, "Jon Snow", c.FormValue("name"))
+ assert.Equal(t, "jon@labstack.com", c.FormValue("email"))
+
+ // FormValueOr
+ assert.Equal(t, "Jon Snow", c.FormValueOr("name", "nope"))
+ assert.Equal(t, "default", c.FormValueOr("missing", "default"))
- // FormParams
- params, err := c.FormParams()
- if testify.NoError(t, err) {
- testify.Equal(t, url.Values{
+ // FormValues
+ values, err := c.FormValues()
+ if assert.NoError(t, err) {
+ assert.Equal(t, url.Values{
"name": []string{"Jon Snow"},
"email": []string{"jon@labstack.com"},
- }, params)
+ }, values)
}
// Multipart FormParams error
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
req.Header.Add(HeaderContentType, MIMEMultipartForm)
c = e.NewContext(req, nil)
- params, err = c.FormParams()
- testify.Nil(t, params)
- testify.Error(t, err)
+ values, err = c.FormValues()
+ assert.Nil(t, values)
+ assert.Error(t, err)
}
-func TestContextQueryParam(t *testing.T) {
- q := make(url.Values)
- q.Set("name", "Jon Snow")
- q.Set("email", "jon@labstack.com")
- req := httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil)
- e := New()
- c := e.NewContext(req, nil)
+func TestContext_QueryParams(t *testing.T) {
+ var testCases = []struct {
+ expect url.Values
+ name string
+ givenURL string
+ }{
+ {
+ name: "multiple values in url",
+ givenURL: "/?test=1&test=2&email=jon%40labstack.com",
+ expect: url.Values{
+ "test": []string{"1", "2"},
+ "email": []string{"jon@labstack.com"},
+ },
+ },
+ {
+ name: "single value in url",
+ givenURL: "/?nope=1",
+ expect: url.Values{
+ "nope": []string{"1"},
+ },
+ },
+ {
+ name: "no query params in url",
+ givenURL: "/?",
+ expect: url.Values{},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
+ e := New()
+ c := e.NewContext(req, nil)
+
+ assert.Equal(t, tc.expect, c.QueryParams())
+ })
+ }
+}
+
+func TestContext_QueryParam(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ whenParamName string
+ expect string
+ }{
+ {
+ name: "value exists in url",
+ givenURL: "/?test=1",
+ whenParamName: "test",
+ expect: "1",
+ },
+ {
+ name: "multiple values exists in url",
+ givenURL: "/?test=9&test=8",
+ whenParamName: "test",
+ expect: "9", // <-- first value in returned
+ },
+ {
+ name: "value does not exists in url",
+ givenURL: "/?nope=1",
+ whenParamName: "test",
+ expect: "",
+ },
+ {
+ name: "value is empty in url",
+ givenURL: "/?test=",
+ whenParamName: "test",
+ expect: "",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
+ e := New()
+ c := e.NewContext(req, nil)
+
+ assert.Equal(t, tc.expect, c.QueryParam(tc.whenParamName))
+ })
+ }
+}
+
+func TestContext_QueryParamDefault(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ whenParamName string
+ whenDefaultValue string
+ expect string
+ }{
+ {
+ name: "value exists in url",
+ givenURL: "/?test=1",
+ whenParamName: "test",
+ whenDefaultValue: "999",
+ expect: "1",
+ },
+ {
+ name: "value does not exists in url",
+ givenURL: "/?nope=1",
+ whenParamName: "test",
+ whenDefaultValue: "999",
+ expect: "999",
+ },
+ {
+ name: "value is empty in url",
+ givenURL: "/?test=",
+ whenParamName: "test",
+ whenDefaultValue: "999",
+ expect: "999",
+ },
+ }
- // QueryParam
- testify.Equal(t, "Jon Snow", c.QueryParam("name"))
- testify.Equal(t, "jon@labstack.com", c.QueryParam("email"))
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
+ e := New()
+ c := e.NewContext(req, nil)
- // QueryParams
- testify.Equal(t, url.Values{
- "name": []string{"Jon Snow"},
- "email": []string{"jon@labstack.com"},
- }, c.QueryParams())
+ assert.Equal(t, tc.expect, c.QueryParamOr(tc.whenParamName, tc.whenDefaultValue))
+ })
+ }
}
func TestContextFormFile(t *testing.T) {
@@ -543,7 +945,7 @@ func TestContextFormFile(t *testing.T) {
buf := new(bytes.Buffer)
mr := multipart.NewWriter(buf)
w, err := mr.CreateFormFile("file", "test")
- if testify.NoError(t, err) {
+ if assert.NoError(t, err) {
w.Write([]byte("test"))
}
mr.Close()
@@ -552,8 +954,8 @@ func TestContextFormFile(t *testing.T) {
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
f, err := c.FormFile("file")
- if testify.NoError(t, err) {
- testify.Equal(t, "test", f.Filename)
+ if assert.NoError(t, err) {
+ assert.Equal(t, "test", f.Filename)
}
}
@@ -562,14 +964,26 @@ func TestContextMultipartForm(t *testing.T) {
buf := new(bytes.Buffer)
mw := multipart.NewWriter(buf)
mw.WriteField("name", "Jon Snow")
+ fileContent := "This is a test file"
+ w, err := mw.CreateFormFile("file", "test.txt")
+ if assert.NoError(t, err) {
+ w.Write([]byte(fileContent))
+ }
mw.Close()
req := httptest.NewRequest(http.MethodPost, "/", buf)
req.Header.Set(HeaderContentType, mw.FormDataContentType())
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
f, err := c.MultipartForm()
- if testify.NoError(t, err) {
- testify.NotNil(t, f)
+ if assert.NoError(t, err) {
+ assert.NotNil(t, f)
+
+ files := f.File["file"]
+ if assert.Len(t, files, 1) {
+ file := files[0]
+ assert.Equal(t, "test.txt", file.Filename)
+ assert.Equal(t, int64(len(fileContent)), file.Size)
+ }
}
}
@@ -578,23 +992,53 @@ func TestContextRedirect(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- testify.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
- testify.Equal(t, http.StatusMovedPermanently, rec.Code)
- testify.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
- testify.Error(t, c.Redirect(310, "http://labstack.github.io/echo"))
+ assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
+ assert.Equal(t, http.StatusMovedPermanently, rec.Code)
+ assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
+ assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo"))
}
-func TestContextStore(t *testing.T) {
- var c Context
- c = new(context)
- c.Set("name", "Jon Snow")
- testify.Equal(t, "Jon Snow", c.Get("name"))
+func TestContextGet(t *testing.T) {
+ var testCases = []struct {
+ name string
+ given any
+ whenKey string
+ expect any
+ }{
+ {
+ name: "ok, value exist",
+ given: "Jon Snow",
+ whenKey: "key",
+ expect: "Jon Snow",
+ },
+ {
+ name: "ok, value does not exist",
+ given: "Jon Snow",
+ whenKey: "nope",
+ expect: nil,
+ },
+ {
+ name: "ok, value is nil value",
+ given: []byte(nil),
+ whenKey: "key",
+ expect: []byte(nil),
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ var c = new(Context)
+ c.Set("key", tc.given)
+
+ v := c.Get(tc.whenKey)
+ assert.Equal(t, tc.expect, v)
+ })
+ }
}
func BenchmarkContext_Store(b *testing.B) {
e := &Echo{}
- c := &context{
+ c := &Context{
echo: e,
}
@@ -606,47 +1050,9 @@ func BenchmarkContext_Store(b *testing.B) {
}
}
-func TestContextHandler(t *testing.T) {
- e := New()
- r := e.Router()
- b := new(bytes.Buffer)
-
- r.Add(http.MethodGet, "/handler", func(Context) error {
- _, err := b.Write([]byte("handler"))
- return err
- })
- c := e.NewContext(nil, nil)
- r.Find(http.MethodGet, "/handler", c)
- err := c.Handler()(c)
- testify.Equal(t, "handler", b.String())
- testify.NoError(t, err)
-}
-
-func TestContext_SetHandler(t *testing.T) {
- var c Context
- c = new(context)
-
- testify.Nil(t, c.Handler())
-
- c.SetHandler(func(c Context) error {
- return nil
- })
- testify.NotNil(t, c.Handler())
-}
-
-func TestContext_Path(t *testing.T) {
- path := "/pa/th"
-
- var c Context
- c = new(context)
-
- c.SetPath(path)
- testify.Equal(t, path, c.Path())
-}
-
type validator struct{}
-func (*validator) Validate(i interface{}) error {
+func (*validator) Validate(i any) error {
return nil
}
@@ -654,10 +1060,10 @@ func TestContext_Validate(t *testing.T) {
e := New()
c := e.NewContext(nil, nil)
- testify.Error(t, c.Validate(struct{}{}))
+ assert.Error(t, c.Validate(struct{}{}))
e.Validator = &validator{}
- testify.NoError(t, c.Validate(struct{}{}))
+ assert.NoError(t, c.Validate(struct{}{}))
}
func TestContext_QueryString(t *testing.T) {
@@ -665,31 +1071,30 @@ func TestContext_QueryString(t *testing.T) {
queryString := "query=string&var=val"
- req := httptest.NewRequest(GET, "/?"+queryString, nil)
+ req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil)
c := e.NewContext(req, nil)
- testify.Equal(t, queryString, c.QueryString())
+ assert.Equal(t, queryString, c.QueryString())
}
func TestContext_Request(t *testing.T) {
- var c Context
- c = new(context)
+ var c = new(Context)
- testify.Nil(t, c.Request())
+ assert.Nil(t, c.Request())
- req := httptest.NewRequest(GET, "/path", nil)
+ req := httptest.NewRequest(http.MethodGet, "/path", nil)
c.SetRequest(req)
- testify.Equal(t, req, c.Request())
+ assert.Equal(t, req, c.Request())
}
func TestContext_Scheme(t *testing.T) {
tests := []struct {
- c Context
+ c *Context
s string
}{
{
- &context{
+ &Context{
request: &http.Request{
TLS: &tls.ConnectionState{},
},
@@ -697,7 +1102,7 @@ func TestContext_Scheme(t *testing.T) {
"https",
},
{
- &context{
+ &Context{
request: &http.Request{
Header: http.Header{HeaderXForwardedProto: []string{"https"}},
},
@@ -705,7 +1110,7 @@ func TestContext_Scheme(t *testing.T) {
"https",
},
{
- &context{
+ &Context{
request: &http.Request{
Header: http.Header{HeaderXForwardedProtocol: []string{"http"}},
},
@@ -713,7 +1118,7 @@ func TestContext_Scheme(t *testing.T) {
"http",
},
{
- &context{
+ &Context{
request: &http.Request{
Header: http.Header{HeaderXForwardedSsl: []string{"on"}},
},
@@ -721,7 +1126,7 @@ func TestContext_Scheme(t *testing.T) {
"https",
},
{
- &context{
+ &Context{
request: &http.Request{
Header: http.Header{HeaderXUrlScheme: []string{"https"}},
},
@@ -729,7 +1134,7 @@ func TestContext_Scheme(t *testing.T) {
"https",
},
{
- &context{
+ &Context{
request: &http.Request{},
},
"http",
@@ -737,44 +1142,61 @@ func TestContext_Scheme(t *testing.T) {
}
for _, tt := range tests {
- testify.Equal(t, tt.s, tt.c.Scheme())
+ assert.Equal(t, tt.s, tt.c.Scheme())
}
}
func TestContext_IsWebSocket(t *testing.T) {
tests := []struct {
- c Context
- ws testify.BoolAssertionFunc
+ c *Context
+ ws assert.BoolAssertionFunc
}{
{
- &context{
+ &Context{
request: &http.Request{
- Header: http.Header{HeaderUpgrade: []string{"websocket"}},
+ Header: http.Header{
+ HeaderUpgrade: []string{"websocket"},
+ HeaderConnection: []string{"upgrade"},
+ },
},
},
- testify.True,
+ assert.True,
},
{
- &context{
+ &Context{
request: &http.Request{
- Header: http.Header{HeaderUpgrade: []string{"Websocket"}},
+ Header: http.Header{
+ HeaderUpgrade: []string{"Websocket"},
+ HeaderConnection: []string{"Upgrade"},
+ },
},
},
- testify.True,
+ assert.True,
},
{
- &context{
+ &Context{
request: &http.Request{},
},
- testify.False,
+ assert.False,
},
{
- &context{
+ &Context{
request: &http.Request{
Header: http.Header{HeaderUpgrade: []string{"other"}},
},
},
- testify.False,
+ assert.False,
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{
+ HeaderUpgrade: []string{"websocket"},
+ HeaderConnection: []string{"close"},
+ },
+ },
+ },
+ assert.False,
},
}
@@ -787,39 +1209,23 @@ func TestContext_IsWebSocket(t *testing.T) {
func TestContext_Bind(t *testing.T) {
e := New()
- req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON))
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
c := e.NewContext(req, nil)
u := new(user)
req.Header.Add(HeaderContentType, MIMEApplicationJSON)
err := c.Bind(u)
- testify.NoError(t, err)
- testify.Equal(t, &user{1, "Jon Snow"}, u)
-}
-
-func TestContext_Logger(t *testing.T) {
- e := New()
- c := e.NewContext(nil, nil)
-
- log1 := c.Logger()
- testify.NotNil(t, log1)
-
- log2 := log.New("echo2")
- c.SetLogger(log2)
- testify.Equal(t, log2, c.Logger())
-
- // Resetting the context returns the initial logger
- c.Reset(nil, nil)
- testify.Equal(t, log1, c.Logger())
+ assert.NoError(t, err)
+ assert.Equal(t, &user{ID: 1, Name: "Jon Snow"}, u)
}
func TestContext_RealIP(t *testing.T) {
tests := []struct {
- c Context
+ c *Context
s string
}{
{
- &context{
+ &Context{
request: &http.Request{
Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}},
},
@@ -827,7 +1233,47 @@ func TestContext_RealIP(t *testing.T) {
"127.0.0.1",
},
{
- &context{
+ &Context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1,127.0.1.1"}},
+ },
+ },
+ "127.0.0.1",
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1"}},
+ },
+ },
+ "127.0.0.1",
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348], 2001:db8::1, "}},
+ },
+ },
+ "2001:db8:85a3:8d3:1319:8a2e:370:7348",
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348],[2001:db8::1]"}},
+ },
+ },
+ "2001:db8:85a3:8d3:1319:8a2e:370:7348",
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"2001:db8:85a3:8d3:1319:8a2e:370:7348"}},
+ },
+ },
+ "2001:db8:85a3:8d3:1319:8a2e:370:7348",
+ },
+ {
+ &Context{
request: &http.Request{
Header: http.Header{
"X-Real-Ip": []string{"192.168.0.1"},
@@ -837,7 +1283,18 @@ func TestContext_RealIP(t *testing.T) {
"192.168.0.1",
},
{
- &context{
+ &Context{
+ request: &http.Request{
+ Header: http.Header{
+ "X-Real-Ip": []string{"[2001:db8::1]"},
+ },
+ },
+ },
+ "2001:db8::1",
+ },
+
+ {
+ &Context{
request: &http.Request{
RemoteAddr: "89.89.89.89:1654",
},
@@ -847,6 +1304,173 @@ func TestContext_RealIP(t *testing.T) {
}
for _, tt := range tests {
- testify.Equal(t, tt.s, tt.c.RealIP())
+ assert.Equal(t, tt.s, tt.c.RealIP())
+ }
+}
+
+func TestContext_File(t *testing.T) {
+ var testCases = []struct {
+ whenFS fs.FS
+ name string
+ whenFile string
+ expectError string
+ expectStartsWith []byte
+ expectStatus int
+ }{
+ {
+ name: "ok, from default file system",
+ whenFile: "_fixture/images/walle.png",
+ whenFS: nil,
+ expectStatus: http.StatusOK,
+ expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ },
+ {
+ name: "ok, from custom file system",
+ whenFile: "walle.png",
+ whenFS: os.DirFS("_fixture/images"),
+ expectStatus: http.StatusOK,
+ expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ },
+ {
+ name: "nok, not existent file",
+ whenFile: "not.png",
+ whenFS: os.DirFS("_fixture/images"),
+ expectStatus: http.StatusOK,
+ expectStartsWith: nil,
+ expectError: "Not Found",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ if tc.whenFS != nil {
+ e.Filesystem = tc.whenFS
+ }
+
+ handler := func(ec *Context) error {
+ return ec.File(tc.whenFile)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := handler(c)
+
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ body := rec.Body.Bytes()
+ if len(body) > len(tc.expectStartsWith) {
+ body = body[:len(tc.expectStartsWith)]
+ }
+ assert.Equal(t, tc.expectStartsWith, body)
+ })
+ }
+}
+
+func TestContext_FileFS(t *testing.T) {
+ var testCases = []struct {
+ whenFS fs.FS
+ name string
+ whenFile string
+ expectError string
+ expectStartsWith []byte
+ expectStatus int
+ }{
+ {
+ name: "ok",
+ whenFile: "walle.png",
+ whenFS: os.DirFS("_fixture/images"),
+ expectStatus: http.StatusOK,
+ expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ },
+ {
+ name: "nok, not existent file",
+ whenFile: "not.png",
+ whenFS: os.DirFS("_fixture/images"),
+ expectStatus: http.StatusOK,
+ expectStartsWith: nil,
+ expectError: "Not Found",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+
+ handler := func(ec *Context) error {
+ return ec.FileFS(tc.whenFile, tc.whenFS)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := handler(c)
+
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ body := rec.Body.Bytes()
+ if len(body) > len(tc.expectStartsWith) {
+ body = body[:len(tc.expectStartsWith)]
+ }
+ assert.Equal(t, tc.expectStartsWith, body)
+ })
+ }
+}
+
+func TestLogger(t *testing.T) {
+ e := New()
+ c := e.NewContext(nil, nil)
+
+ log1 := c.Logger()
+ assert.NotNil(t, log1)
+ assert.Equal(t, e.Logger, log1)
+
+ customLogger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+ c.SetLogger(customLogger)
+ assert.Equal(t, customLogger, c.Logger())
+
+ // Resetting the context returns the initial Echo logger
+ c.Reset(nil, nil)
+ assert.Equal(t, e.Logger, c.Logger())
+}
+
+func TestRouteInfo(t *testing.T) {
+ e := New()
+ c := e.NewContext(nil, nil)
+
+ orgRI := RouteInfo{
+ Name: "root",
+ Method: http.MethodGet,
+ Path: "/*",
+ Parameters: []string{"*"},
}
+ c.route = &orgRI
+ ri := c.RouteInfo()
+ assert.Equal(t, orgRI, ri)
+
+ // Test mutability when middlewares start to change things
+
+ // RouteInfo inside context will not be affected when returned instance is changed
+ expect := orgRI.Clone()
+ ri.Path = "changed"
+ ri.Parameters[0] = "changed"
+ assert.Equal(t, expect, c.RouteInfo())
+
+ // RouteInfo inside context will not be affected when returned instance is changed
+ expect = c.RouteInfo()
+ orgRI.Name = "changed"
+ assert.NotEqual(t, expect, c.RouteInfo())
}
diff --git a/echo.go b/echo.go
index a6ac0fa80..4855e8429 100644
--- a/echo.go
+++ b/echo.go
@@ -1,153 +1,132 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
/*
Package echo implements high performance, minimalist Go web framework.
Example:
- package main
+ package main
- import (
- "net/http"
+ import (
+ "log/slog"
+ "net/http"
- "github.com/labstack/echo/v4"
- "github.com/labstack/echo/v4/middleware"
- )
+ "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v5/middleware"
+ )
- // Handler
- func hello(c echo.Context) error {
- return c.String(http.StatusOK, "Hello, World!")
- }
+ // Handler
+ func hello(c *echo.Context) error {
+ return c.String(http.StatusOK, "Hello, World!")
+ }
- func main() {
- // Echo instance
- e := echo.New()
+ func main() {
+ // Echo instance
+ e := echo.New()
- // Middleware
- e.Use(middleware.Logger())
- e.Use(middleware.Recover())
+ // Middleware
+ e.Use(middleware.RequestLogger())
+ e.Use(middleware.Recover())
- // Routes
- e.GET("/", hello)
+ // Routes
+ e.GET("/", hello)
- // Start server
- e.Logger.Fatal(e.Start(":1323"))
- }
+ // Start server
+ if err := e.Start(":8080"); err != nil {
+ slog.Error("failed to start server", "error", err)
+ }
+ }
Learn more at https://echo.labstack.com
*/
package echo
import (
- "bytes"
stdContext "context"
- "crypto/tls"
+ "encoding/json"
"errors"
"fmt"
- "io"
- "io/ioutil"
- stdLog "log"
- "net"
+ "io/fs"
+ "log/slog"
"net/http"
"net/url"
- "path"
+ "os"
+ "os/signal"
"path/filepath"
- "reflect"
- "runtime"
+ "strings"
"sync"
- "time"
-
- "github.com/labstack/gommon/color"
- "github.com/labstack/gommon/log"
- "golang.org/x/crypto/acme"
- "golang.org/x/crypto/acme/autocert"
+ "sync/atomic"
+ "syscall"
)
-type (
- // Echo is the top-level framework instance.
- Echo struct {
- common
- StdLogger *stdLog.Logger
- colorer *color.Color
- premiddleware []MiddlewareFunc
- middleware []MiddlewareFunc
- maxParam *int
- router *Router
- routers map[string]*Router
- notFoundHandler HandlerFunc
- pool sync.Pool
- Server *http.Server
- TLSServer *http.Server
- Listener net.Listener
- TLSListener net.Listener
- AutoTLSManager autocert.Manager
- DisableHTTP2 bool
- Debug bool
- HideBanner bool
- HidePort bool
- HTTPErrorHandler HTTPErrorHandler
- Binder Binder
- Validator Validator
- Renderer Renderer
- Logger Logger
- }
-
- // Route contains a handler and information for matching against requests.
- Route struct {
- Method string `json:"method"`
- Path string `json:"path"`
- Name string `json:"name"`
- }
-
- // HTTPError represents an error that occurred while handling a request.
- HTTPError struct {
- Code int `json:"-"`
- Message interface{} `json:"message"`
- Internal error `json:"-"` // Stores the error returned by an external dependency
- }
-
- // MiddlewareFunc defines a function to process middleware.
- MiddlewareFunc func(HandlerFunc) HandlerFunc
-
- // HandlerFunc defines a function to serve HTTP requests.
- HandlerFunc func(Context) error
-
- // HTTPErrorHandler is a centralized HTTP error handler.
- HTTPErrorHandler func(error, Context)
-
- // Validator is the interface that wraps the Validate function.
- Validator interface {
- Validate(i interface{}) error
- }
-
- // Renderer is the interface that wraps the Render function.
- Renderer interface {
- Render(io.Writer, string, interface{}, Context) error
- }
-
- // Map defines a generic map of type `map[string]interface{}`.
- Map map[string]interface{}
-
- // Common struct for Echo & Group.
- common struct{}
-)
+// Echo is the top-level framework instance.
+//
+// Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these
+// fields from handlers/middlewares and changing field values at the same time leads to data-races.
+// Same rule applies to adding new routes after server has been started - Adding a route is not Goroutine safe action.
+type Echo struct {
+ serveHTTPFunc func(http.ResponseWriter, *http.Request)
-// HTTP methods
-// NOTE: Deprecated, please use the stdlib constants directly instead.
-const (
- CONNECT = http.MethodConnect
- DELETE = http.MethodDelete
- GET = http.MethodGet
- HEAD = http.MethodHead
- OPTIONS = http.MethodOptions
- PATCH = http.MethodPatch
- POST = http.MethodPost
- // PROPFIND = "PROPFIND"
- PUT = http.MethodPut
- TRACE = http.MethodTrace
-)
+ Binder Binder
+ Filesystem fs.FS
+ Renderer Renderer
+ Validator Validator
+ JSONSerializer JSONSerializer
+ IPExtractor IPExtractor
+ OnAddRoute func(route Route) error
+ HTTPErrorHandler HTTPErrorHandler
+ Logger *slog.Logger
+
+ contextPool sync.Pool
+
+ router Router
+
+ // premiddleware are middlewares that are called before routing is done
+ premiddleware []MiddlewareFunc
+
+ // middleware are middlewares that are called after routing is done and before handler is called
+ middleware []MiddlewareFunc
+
+ contextPathParamAllocSize atomic.Int32
+
+ // formParseMaxMemory is passed to Context for multipart form parsing (See http.Request.ParseMultipartForm)
+ formParseMaxMemory int64
+}
+
+// JSONSerializer is the interface that encodes and decodes JSON to and from interfaces.
+type JSONSerializer interface {
+ Serialize(c *Context, target any, indent string) error
+ Deserialize(c *Context, target any) error
+}
+
+// HTTPErrorHandler is a centralized HTTP error handler.
+type HTTPErrorHandler func(c *Context, err error)
+
+// HandlerFunc defines a function to serve HTTP requests.
+type HandlerFunc func(c *Context) error
+
+// MiddlewareFunc defines a function to process middleware.
+type MiddlewareFunc func(next HandlerFunc) HandlerFunc
+
+// MiddlewareConfigurator defines interface for creating middleware handlers with possibility to return configuration errors instead of panicking.
+type MiddlewareConfigurator interface {
+ ToMiddleware() (MiddlewareFunc, error)
+}
+
+// Validator is the interface that wraps the Validate function.
+type Validator interface {
+ Validate(i any) error
+}
// MIME types
const (
- MIMEApplicationJSON = "application/json"
+ // MIMEApplicationJSON JavaScript Object Notation (JSON) https://www.rfc-editor.org/rfc/rfc8259
+ MIMEApplicationJSON = "application/json"
+ // Deprecated: Please use MIMEApplicationJSON instead. JSON should be encoded using UTF-8 by default.
+ // No "charset" parameter is defined for this registration.
+ // Adding one really has no effect on compliant recipients.
+ // See RFC 8259, section 8.1. https://datatracker.ietf.org/doc/html/rfc8259#section-8.1n"
MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8
MIMEApplicationJavaScript = "application/javascript"
MIMEApplicationJavaScriptCharsetUTF8 = MIMEApplicationJavaScript + "; " + charsetUTF8
@@ -172,12 +151,21 @@ const (
PROPFIND = "PROPFIND"
// REPORT Method can be used to get information about a resource, see rfc 3253
REPORT = "REPORT"
+ // RouteNotFound is special method type for routes handling "route not found" (404) cases
+ RouteNotFound = "echo_route_not_found"
+ // RouteAny is special method type that matches any HTTP method in request. Any has lower
+ // priority that other methods that have been registered with Router to that path.
+ RouteAny = "echo_route_any"
)
// Headers
const (
- HeaderAccept = "Accept"
- HeaderAcceptEncoding = "Accept-Encoding"
+ HeaderAccept = "Accept"
+ HeaderAcceptEncoding = "Accept-Encoding"
+ // HeaderAllow is the name of the "Allow" header field used to list the set of methods
+ // advertised as supported by the target resource. Returning an Allow header is mandatory
+ // for status 405 (method not found) and useful for the OPTIONS method in responses.
+ // See RFC 7231: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1
HeaderAllow = "Allow"
HeaderAuthorization = "Authorization"
HeaderContentDisposition = "Content-Disposition"
@@ -189,6 +177,7 @@ const (
HeaderIfModifiedSince = "If-Modified-Since"
HeaderLastModified = "Last-Modified"
HeaderLocation = "Location"
+ HeaderRetryAfter = "Retry-After"
HeaderUpgrade = "Upgrade"
HeaderVary = "Vary"
HeaderWWWAuthenticate = "WWW-Authenticate"
@@ -198,11 +187,17 @@ const (
HeaderXForwardedSsl = "X-Forwarded-Ssl"
HeaderXUrlScheme = "X-Url-Scheme"
HeaderXHTTPMethodOverride = "X-HTTP-Method-Override"
- HeaderXRealIP = "X-Real-IP"
- HeaderXRequestID = "X-Request-ID"
+ HeaderXRealIP = "X-Real-Ip"
+ HeaderXRequestID = "X-Request-Id"
+ HeaderXCorrelationID = "X-Correlation-Id"
HeaderXRequestedWith = "X-Requested-With"
HeaderServer = "Server"
- HeaderOrigin = "Origin"
+
+ // HeaderOrigin request header indicates the origin (scheme, hostname, and port) that caused the request.
+ // See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin
+ HeaderOrigin = "Origin"
+ HeaderCacheControl = "Cache-Control"
+ HeaderConnection = "Connection"
// Access control
HeaderAccessControlRequestMethod = "Access-Control-Request-Method"
@@ -221,314 +216,425 @@ const (
HeaderXFrameOptions = "X-Frame-Options"
HeaderContentSecurityPolicy = "Content-Security-Policy"
HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only"
- HeaderXCSRFToken = "X-CSRF-Token"
+ HeaderXCSRFToken = "X-CSRF-Token" // #nosec G101
HeaderReferrerPolicy = "Referrer-Policy"
-)
-const (
- // Version of Echo
- Version = "4.1.13"
- website = "https://echo.labstack.com"
- // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo
- banner = `
- ____ __
- / __/___/ / ___
- / _// __/ _ \/ _ \
-/___/\__/_//_/\___/ %s
-High performance, minimalist Go web framework
-%s
-____________________________________O/_______
- O\
-`
+ // HeaderSecFetchSite fetch metadata request header indicates the relationship between a request initiator's
+ // origin and the origin of the requested resource.
+ // See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Sec-Fetch-Site
+ HeaderSecFetchSite = "Sec-Fetch-Site"
)
-var (
- methods = [...]string{
- http.MethodConnect,
- http.MethodDelete,
- http.MethodGet,
- http.MethodHead,
- http.MethodOptions,
- http.MethodPatch,
- http.MethodPost,
- PROPFIND,
- http.MethodPut,
- http.MethodTrace,
- REPORT,
- }
-)
+// Config is configuration for NewWithConfig function
+type Config struct {
+ // Logger is the slog logger instance used for application-wide structured logging.
+ // If not set, a default TextHandler writing to stdout is created.
+ Logger *slog.Logger
-// Errors
-var (
- ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType)
- ErrNotFound = NewHTTPError(http.StatusNotFound)
- ErrUnauthorized = NewHTTPError(http.StatusUnauthorized)
- ErrForbidden = NewHTTPError(http.StatusForbidden)
- ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
- ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge)
- ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests)
- ErrBadRequest = NewHTTPError(http.StatusBadRequest)
- ErrBadGateway = NewHTTPError(http.StatusBadGateway)
- ErrInternalServerError = NewHTTPError(http.StatusInternalServerError)
- ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout)
- ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable)
- ErrValidatorNotRegistered = errors.New("validator not registered")
- ErrRendererNotRegistered = errors.New("renderer not registered")
- ErrInvalidRedirectCode = errors.New("invalid redirect status code")
- ErrCookieNotFound = errors.New("cookie not found")
- ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
-)
+ // HTTPErrorHandler is the centralized error handler that processes errors returned
+ // by handlers and middleware, converting them to appropriate HTTP responses.
+ // If not set, DefaultHTTPErrorHandler(false) is used.
+ HTTPErrorHandler HTTPErrorHandler
-// Error handlers
-var (
- NotFoundHandler = func(c Context) error {
- return ErrNotFound
- }
+ // Router is the HTTP request router responsible for matching URLs to handlers
+ // using a radix tree-based algorithm.
+ // If not set, NewRouter(RouterConfig{}) is used.
+ Router Router
+
+ // OnAddRoute is an optional callback hook executed when routes are registered.
+ // Useful for route validation, logging, or custom route processing.
+ // If not set, no callback is executed.
+ OnAddRoute func(route Route) error
+
+ // Filesystem is the fs.FS implementation used for serving static files.
+ // Supports os.DirFS, embed.FS, and custom implementations.
+ // If not set, defaults to current working directory.
+ Filesystem fs.FS
+
+ // Binder handles automatic data binding from HTTP requests to Go structs.
+ // Supports JSON, XML, form data, query parameters, and path parameters.
+ // If not set, DefaultBinder is used.
+ Binder Binder
+
+ // Validator provides optional struct validation after data binding.
+ // Commonly used with third-party validation libraries.
+ // If not set, Context.Validate() returns ErrValidatorNotRegistered.
+ Validator Validator
- MethodNotAllowedHandler = func(c Context) error {
- return ErrMethodNotAllowed
+ // Renderer provides template rendering for generating HTML responses.
+ // Requires integration with a template engine like html/template.
+ // If not set, Context.Render() returns ErrRendererNotRegistered.
+ Renderer Renderer
+
+ // JSONSerializer handles JSON encoding and decoding for HTTP requests/responses.
+ // Can be replaced with faster alternatives like jsoniter or sonic.
+ // If not set, DefaultJSONSerializer using encoding/json is used.
+ JSONSerializer JSONSerializer
+
+ // IPExtractor defines the strategy for extracting the real client IP address
+ // from requests, particularly important when behind proxies or load balancers.
+ // Used for rate limiting, access control, and logging.
+ // If not set, falls back to checking X-Forwarded-For and X-Real-IP headers.
+ IPExtractor IPExtractor
+
+ // FormParseMaxMemory is default value for memory limit that is used
+ // when parsing multipart forms (See (*http.Request).ParseMultipartForm)
+ FormParseMaxMemory int64
+}
+
+// NewWithConfig creates an instance of Echo with given configuration.
+func NewWithConfig(config Config) *Echo {
+ e := New()
+ if config.Logger != nil {
+ e.Logger = config.Logger
}
-)
+ if config.HTTPErrorHandler != nil {
+ e.HTTPErrorHandler = config.HTTPErrorHandler
+ }
+ if config.Router != nil {
+ e.router = config.Router
+ }
+ if config.OnAddRoute != nil {
+ e.OnAddRoute = config.OnAddRoute
+ }
+ if config.Filesystem != nil {
+ e.Filesystem = config.Filesystem
+ }
+ if config.Binder != nil {
+ e.Binder = config.Binder
+ }
+ if config.Validator != nil {
+ e.Validator = config.Validator
+ }
+ if config.Renderer != nil {
+ e.Renderer = config.Renderer
+ }
+ if config.JSONSerializer != nil {
+ e.JSONSerializer = config.JSONSerializer
+ }
+ if config.IPExtractor != nil {
+ e.IPExtractor = config.IPExtractor
+ }
+ if config.FormParseMaxMemory > 0 {
+ e.formParseMaxMemory = config.FormParseMaxMemory
+ }
+ return e
+}
// New creates an instance of Echo.
-func New() (e *Echo) {
- e = &Echo{
- Server: new(http.Server),
- TLSServer: new(http.Server),
- AutoTLSManager: autocert.Manager{
- Prompt: autocert.AcceptTOS,
- },
- Logger: log.New("echo"),
- colorer: color.New(),
- maxParam: new(int),
- }
- e.Server.Handler = e
- e.TLSServer.Handler = e
- e.HTTPErrorHandler = e.DefaultHTTPErrorHandler
- e.Binder = &DefaultBinder{}
- e.Logger.SetLevel(log.ERROR)
- e.StdLogger = stdLog.New(e.Logger.Output(), e.Logger.Prefix()+": ", 0)
- e.pool.New = func() interface{} {
- return e.NewContext(nil, nil)
- }
- e.router = NewRouter(e)
- e.routers = map[string]*Router{}
- return
-}
+func New() *Echo {
+ logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
+ e := &Echo{
+ Logger: logger,
+ Filesystem: newDefaultFS(),
+ Binder: &DefaultBinder{},
+ JSONSerializer: &DefaultJSONSerializer{},
+ formParseMaxMemory: defaultMemory,
+ }
-// NewContext returns a Context instance.
-func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context {
- return &context{
- request: r,
- response: NewResponse(w, e),
- store: make(Map),
- echo: e,
- pvalues: make([]string, *e.maxParam),
- handler: NotFoundHandler,
+ e.serveHTTPFunc = e.serveHTTP
+ e.router = NewRouter(RouterConfig{})
+ e.HTTPErrorHandler = DefaultHTTPErrorHandler(false)
+ e.contextPool.New = func() any {
+ return newContext(nil, nil, e)
}
+ return e
+}
+
+// NewContext returns a new Context instance.
+//
+// Note: both request and response can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway
+// these arguments are useful when creating context for tests and cases like that.
+func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) *Context {
+ return newContext(r, w, e)
}
// Router returns the default router.
-func (e *Echo) Router() *Router {
+func (e *Echo) Router() Router {
return e.router
}
-// Routers returns the map of host => router.
-func (e *Echo) Routers() map[string]*Router {
- return e.routers
-}
+// DefaultHTTPErrorHandler creates new default HTTP error handler implementation. It sends a JSON response
+// with status code. `exposeError` parameter decides if returned message will contain also error message or not
+//
+// Note: DefaultHTTPErrorHandler does not log errors. Use middleware for it if errors need to be logged (separately)
+// Note: In case errors happens in middleware call-chain that is returning from handler (which did not return an error).
+// When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from
+// handler. Then the error that global error handler received will be ignored because we have already "committed" the
+// response and status code header has been sent to the client.
+func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler {
+ return func(c *Context, err error) {
+ if r, _ := UnwrapResponse(c.response); r != nil && r.Committed {
+ return
+ }
-// DefaultHTTPErrorHandler is the default HTTP error handler. It sends a JSON response
-// with status code.
-func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) {
- he, ok := err.(*HTTPError)
- if ok {
- if he.Internal != nil {
- if herr, ok := he.Internal.(*HTTPError); ok {
- he = herr
+ code := http.StatusInternalServerError
+ var sc HTTPStatusCoder
+ if errors.As(err, &sc) {
+ if tmp := sc.StatusCode(); tmp != 0 {
+ code = tmp
}
}
- } else {
- he = &HTTPError{
- Code: http.StatusInternalServerError,
- Message: http.StatusText(http.StatusInternalServerError),
- }
- }
- // Issue #1426
- code := he.Code
- message := he.Message
- if e.Debug {
- message = err.Error()
- } else if m, ok := message.(string); ok {
- message = Map{"message": m}
- }
+ var result any
+ switch m := sc.(type) {
+ case json.Marshaler: // this type knows how to format itself to JSON
+ result = m
+ case *HTTPError:
+ sText := m.Message
+ if sText == "" {
+ sText = http.StatusText(code)
+ }
+ msg := map[string]any{"message": sText}
+ if exposeError {
+ if wrappedErr := m.Unwrap(); wrappedErr != nil {
+ msg["error"] = wrappedErr.Error()
+ }
+ }
+ result = msg
+ default:
+ msg := map[string]any{"message": http.StatusText(code)}
+ if exposeError {
+ msg["error"] = err.Error()
+ }
+ result = msg
+ }
- // Send response
- if !c.Response().Committed {
+ var cErr error
if c.Request().Method == http.MethodHead { // Issue #608
- err = c.NoContent(he.Code)
+ cErr = c.NoContent(code)
} else {
- err = c.JSON(code, message)
+ cErr = c.JSON(code, result)
}
- if err != nil {
- e.Logger.Error(err)
+ if cErr != nil {
+ c.Logger().Error("echo default error handler failed to send error to client", "error", cErr) // truly rare case. ala client already disconnected
}
}
}
-// Pre adds middleware to the chain which is run before router.
+// Pre adds middleware to the chain which is run before router tries to find matching route.
+// Meaning middleware is executed even for 404 (not found) cases.
func (e *Echo) Pre(middleware ...MiddlewareFunc) {
e.premiddleware = append(e.premiddleware, middleware...)
}
-// Use adds middleware to the chain which is run after router.
+// Use adds middleware to the chain which is run after router has found matching route and before route/request handler method is executed.
func (e *Echo) Use(middleware ...MiddlewareFunc) {
e.middleware = append(e.middleware, middleware...)
}
// CONNECT registers a new CONNECT route for a path with matching handler in the
-// router with optional route-level middleware.
-func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodConnect, path, h, m...)
}
// DELETE registers a new DELETE route for a path with matching handler in the router
-// with optional route-level middleware.
-func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// with optional route-level middleware. Panics on error.
+func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodDelete, path, h, m...)
}
// GET registers a new GET route for a path with matching handler in the router
-// with optional route-level middleware.
-func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// with optional route-level middleware. Panics on error.
+func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodGet, path, h, m...)
}
// HEAD registers a new HEAD route for a path with matching handler in the
-// router with optional route-level middleware.
-func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodHead, path, h, m...)
}
// OPTIONS registers a new OPTIONS route for a path with matching handler in the
-// router with optional route-level middleware.
-func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodOptions, path, h, m...)
}
// PATCH registers a new PATCH route for a path with matching handler in the
-// router with optional route-level middleware.
-func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodPatch, path, h, m...)
}
// POST registers a new POST route for a path with matching handler in the
-// router with optional route-level middleware.
-func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodPost, path, h, m...)
}
// PUT registers a new PUT route for a path with matching handler in the
-// router with optional route-level middleware.
-func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodPut, path, h, m...)
}
// TRACE registers a new TRACE route for a path with matching handler in the
-// router with optional route-level middleware.
-func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return e.Add(http.MethodTrace, path, h, m...)
}
-// Any registers a new route for all HTTP methods and path with matching handler
+// RouteNotFound registers a special-case route which is executed when no other route is found (i.e. HTTP 404 cases)
+// for current request URL.
+// Path supports static and named/any parameters just like other http method is defined. Generally path is ended with
+// wildcard/match-any character (`/*`, `/download/*` etc).
+//
+// Example: `e.RouteNotFound("/*", func(c *echo.Context) error { return c.NoContent(http.StatusNotFound) })`
+func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return e.Add(RouteNotFound, path, h, m...)
+}
+
+// Any registers a new route for all HTTP methods (supported by Echo) and path with matching handler
// in the router with optional route-level middleware.
-func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
- routes := make([]*Route, len(methods))
- for i, m := range methods {
- routes[i] = e.Add(m, path, handler, middleware...)
- }
- return routes
+//
+// Note: this method only adds specific set of supported HTTP methods as handler and is not true
+// "catch-any-arbitrary-method" way of matching requests.
+func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
+ return e.Add(RouteAny, path, handler, middleware...)
}
// Match registers a new route for multiple HTTP methods and path with matching
-// handler in the router with optional route-level middleware.
-func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
- routes := make([]*Route, len(methods))
- for i, m := range methods {
- routes[i] = e.Add(m, path, handler, middleware...)
+// handler in the router with optional route-level middleware. Panics on error.
+func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
+ errs := make([]error, 0)
+ ris := make(Routes, 0)
+ for _, m := range methods {
+ ri, err := e.AddRoute(Route{
+ Method: m,
+ Path: path,
+ Handler: handler,
+ Middlewares: middleware,
+ })
+ if err != nil {
+ errs = append(errs, err)
+ continue
+ }
+ ris = append(ris, ri)
}
- return routes
-}
-
-// Static registers a new route with path prefix to serve static files from the
-// provided root directory.
-func (e *Echo) Static(prefix, root string) *Route {
- if root == "" {
- root = "." // For security we want to restrict to CWD.
+ if len(errs) > 0 {
+ panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
}
- return e.static(prefix, root, e.GET)
+ return ris
}
-func (common) static(prefix, root string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route) *Route {
- h := func(c Context) error {
- p, err := url.PathUnescape(c.Param("*"))
+// Static registers a new route with path prefix to serve static files from the provided root directory.
+func (e *Echo) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo {
+ subFs := MustSubFS(e.Filesystem, fsRoot)
+ return e.Add(
+ http.MethodGet,
+ pathPrefix+"*",
+ StaticDirectoryHandler(subFs, false),
+ middleware...,
+ )
+}
+
+// StaticFS registers a new route with path prefix to serve static files from the provided file system.
+//
+// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
+// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
+// including `assets/images` as their prefix.
+func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo {
+ return e.Add(
+ http.MethodGet,
+ pathPrefix+"*",
+ StaticDirectoryHandler(filesystem, false),
+ middleware...,
+ )
+}
+
+// StaticDirectoryHandler creates handler function to serve files from provided file system
+// When disablePathUnescaping is set then file name from path is not unescaped and is served as is.
+func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc {
+ return func(c *Context) error {
+ p := c.Param("*")
+ if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice
+ tmpPath, err := url.PathUnescape(p)
+ if err != nil {
+ return fmt.Errorf("failed to unescape path variable: %w", err)
+ }
+ p = tmpPath
+ }
+
+ // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid
+ name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/")))
+ fi, err := fs.Stat(fileSystem, name)
if err != nil {
- return err
+ return ErrNotFound
+ }
+
+ // If the request is for a directory and does not end with "/"
+ p = c.Request().URL.Path // path must not be empty.
+ if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' {
+ // Redirect to ends with "/"
+ return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/"))
}
- name := filepath.Join(root, path.Clean("/"+p)) // "/"+ for security
- return c.File(name)
+ return fsFile(c, name, fileSystem)
}
- if prefix == "/" {
- return get(prefix+"*", h)
+}
+
+// FileFS registers a new route with path to serve file from the provided file system.
+func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo {
+ return e.GET(path, StaticFileHandler(file, filesystem), m...)
+}
+
+// StaticFileHandler creates handler function to serve file from provided file system
+func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc {
+ return func(c *Context) error {
+ return fsFile(c, file, filesystem)
}
- return get(prefix+"/*", h)
}
-func (common) file(path, file string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route,
- m ...MiddlewareFunc) *Route {
- return get(path, func(c Context) error {
+// File registers a new route with path to serve a static file with optional route-level middleware. Panics on error.
+func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo {
+ handler := func(c *Context) error {
return c.File(file)
- }, m...)
+ }
+ return e.Add(http.MethodGet, path, handler, middleware...)
}
-// File registers a new route with path to serve a static file with optional route-level middleware.
-func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route {
- return e.file(path, file, e.GET, m...)
+// AddRoute registers a new Route with default host Router
+func (e *Echo) AddRoute(route Route) (RouteInfo, error) {
+ return e.add(route)
}
-func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
- name := handlerName(handler)
- router := e.findRouter(host)
- router.Add(method, path, func(c Context) error {
- h := handler
- // Chain middleware
- for i := len(middleware) - 1; i >= 0; i-- {
- h = middleware[i](h)
+func (e *Echo) add(route Route) (RouteInfo, error) {
+ if e.OnAddRoute != nil {
+ if err := e.OnAddRoute(route); err != nil {
+ return RouteInfo{}, err
}
- return h(c)
- })
- r := &Route{
- Method: method,
- Path: path,
- Name: name,
}
- e.router.routes[method+path] = r
- return r
+
+ ri, err := e.router.Add(route)
+ if err != nil {
+ return RouteInfo{}, err
+ }
+
+ paramsCount := int32(len(ri.Parameters)) // #nosec G115
+ if paramsCount > e.contextPathParamAllocSize.Load() {
+ e.contextPathParamAllocSize.Store(paramsCount)
+ }
+ return ri, nil
}
// Add registers a new route for an HTTP method and path with matching handler
// in the router with optional route-level middleware.
-func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
- return e.add("", method, path, handler, middleware...)
-}
-
-// Host creates a new router group for the provided host and optional host-level middleware.
-func (e *Echo) Host(name string, m ...MiddlewareFunc) (g *Group) {
- e.routers[name] = NewRouter(e)
- g = &Group{host: name, echo: e}
- g.Use(m...)
- return
+func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
+ ri, err := e.add(
+ Route{
+ Method: method,
+ Path: path,
+ Handler: handler,
+ Middlewares: middleware,
+ Name: "",
+ },
+ )
+ if err != nil {
+ panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
+ }
+ return ri
}
// Group creates a new router group with prefix and optional group-level middleware.
@@ -538,233 +644,102 @@ func (e *Echo) Group(prefix string, m ...MiddlewareFunc) (g *Group) {
return
}
-// URI generates a URI from handler.
-func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string {
- name := handlerName(handler)
- return e.Reverse(name, params...)
-}
-
-// URL is an alias for `URI` function.
-func (e *Echo) URL(h HandlerFunc, params ...interface{}) string {
- return e.URI(h, params...)
+// PreMiddlewares returns registered pre middlewares. These are middleware to the chain
+// which are run before router tries to find matching route.
+// Use this method to build your own ServeHTTP method.
+//
+// NOTE: returned slice is not a copy. Do not mutate.
+func (e *Echo) PreMiddlewares() []MiddlewareFunc {
+ return e.premiddleware
}
-// Reverse generates an URL from route name and provided parameters.
-func (e *Echo) Reverse(name string, params ...interface{}) string {
- uri := new(bytes.Buffer)
- ln := len(params)
- n := 0
- for _, r := range e.router.routes {
- if r.Name == name {
- for i, l := 0, len(r.Path); i < l; i++ {
- if r.Path[i] == ':' && n < ln {
- for ; i < l && r.Path[i] != '/'; i++ {
- }
- uri.WriteString(fmt.Sprintf("%v", params[n]))
- n++
- }
- if i < l {
- uri.WriteByte(r.Path[i])
- }
- }
- break
- }
- }
- return uri.String()
-}
-
-// Routes returns the registered routes.
-func (e *Echo) Routes() []*Route {
- routes := make([]*Route, 0, len(e.router.routes))
- for _, v := range e.router.routes {
- routes = append(routes, v)
- }
- return routes
+// Middlewares returns registered route level middlewares. Does not contain any group level
+// middlewares. Use this method to build your own ServeHTTP method.
+//
+// NOTE: returned slice is not a copy. Do not mutate.
+func (e *Echo) Middlewares() []MiddlewareFunc {
+ return e.middleware
}
// AcquireContext returns an empty `Context` instance from the pool.
// You must return the context by calling `ReleaseContext()`.
-func (e *Echo) AcquireContext() Context {
- return e.pool.Get().(Context)
+func (e *Echo) AcquireContext() *Context {
+ return e.contextPool.Get().(*Context)
}
// ReleaseContext returns the `Context` instance back to the pool.
// You must call it after `AcquireContext()`.
-func (e *Echo) ReleaseContext(c Context) {
- e.pool.Put(c)
+func (e *Echo) ReleaseContext(c *Context) {
+ e.contextPool.Put(c)
}
// ServeHTTP implements `http.Handler` interface, which serves HTTP requests.
func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- // Acquire context
- c := e.pool.Get().(*context)
- c.Reset(r, w)
+ e.serveHTTPFunc(w, r)
+}
+
+// serveHTTP implements `http.Handler` interface, which serves HTTP requests.
+func (e *Echo) serveHTTP(w http.ResponseWriter, r *http.Request) {
+ c := e.contextPool.Get().(*Context)
+ defer e.contextPool.Put(c)
- h := NotFoundHandler
+ c.Reset(r, w)
+ var h HandlerFunc
if e.premiddleware == nil {
- e.findRouter(r.Host).Find(r.Method, getPath(r), c)
- h = c.Handler()
- h = applyMiddleware(h, e.middleware...)
+ h = applyMiddleware(e.router.Route(c), e.middleware...)
} else {
- h = func(c Context) error {
- e.findRouter(r.Host).Find(r.Method, getPath(r), c)
- h := c.Handler()
- h = applyMiddleware(h, e.middleware...)
- return h(c)
+ h = func(cc *Context) error {
+ h1 := applyMiddleware(e.router.Route(cc), e.middleware...)
+ return h1(cc)
}
h = applyMiddleware(h, e.premiddleware...)
}
// Execute chain
if err := h(c); err != nil {
- e.HTTPErrorHandler(err, c)
- }
-
- // Release context
- e.pool.Put(c)
-}
-
-// Start starts an HTTP server.
+ e.HTTPErrorHandler(c, err)
+ }
+}
+
+// Start stars HTTP server on given address with Echo as a handler serving requests. The server can be shutdown by
+// sending os.Interrupt signal with `ctrl+c`. Method returns only errors that are not http.ErrServerClosed.
+//
+// Note: this method is created for use in examples/demos and is deliberately simple without providing configuration
+// options.
+//
+// In need of customization use:
+//
+// ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+// defer cancel()
+// sc := echo.StartConfig{Address: ":8080"}
+// if err := sc.Start(ctx, e); err != nil && !errors.Is(err, http.ErrServerClosed) {
+// slog.Error(err.Error())
+// }
+//
+// // or standard library `http.Server`
+//
+// s := http.Server{Addr: ":8080", Handler: e}
+// if err := s.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+// slog.Error(err.Error())
+// }
func (e *Echo) Start(address string) error {
- e.Server.Addr = address
- return e.StartServer(e.Server)
-}
-
-// StartTLS starts an HTTPS server.
-// If `certFile` or `keyFile` is `string` the values are treated as file paths.
-// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is.
-func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err error) {
- var cert []byte
- if cert, err = filepathOrContent(certFile); err != nil {
- return
- }
-
- var key []byte
- if key, err = filepathOrContent(keyFile); err != nil {
- return
- }
-
- s := e.TLSServer
- s.TLSConfig = new(tls.Config)
- s.TLSConfig.Certificates = make([]tls.Certificate, 1)
- if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil {
- return
- }
-
- return e.startTLS(address)
-}
-
-func filepathOrContent(fileOrContent interface{}) (content []byte, err error) {
- switch v := fileOrContent.(type) {
- case string:
- return ioutil.ReadFile(v)
- case []byte:
- return v, nil
- default:
- return nil, ErrInvalidCertOrKeyType
- }
-}
-
-// StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org.
-func (e *Echo) StartAutoTLS(address string) error {
- s := e.TLSServer
- s.TLSConfig = new(tls.Config)
- s.TLSConfig.GetCertificate = e.AutoTLSManager.GetCertificate
- s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, acme.ALPNProto)
- return e.startTLS(address)
-}
-
-func (e *Echo) startTLS(address string) error {
- s := e.TLSServer
- s.Addr = address
- if !e.DisableHTTP2 {
- s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2")
- }
- return e.StartServer(e.TLSServer)
-}
-
-// StartServer starts a custom http server.
-func (e *Echo) StartServer(s *http.Server) (err error) {
- // Setup
- e.colorer.SetOutput(e.Logger.Output())
- s.ErrorLog = e.StdLogger
- s.Handler = e
- if e.Debug {
- e.Logger.SetLevel(log.DEBUG)
- }
-
- if !e.HideBanner {
- e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website))
- }
-
- if s.TLSConfig == nil {
- if e.Listener == nil {
- e.Listener, err = newListener(s.Addr)
- if err != nil {
- return err
- }
- }
- if !e.HidePort {
- e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr()))
- }
- return s.Serve(e.Listener)
- }
- if e.TLSListener == nil {
- l, err := newListener(s.Addr)
- if err != nil {
- return err
- }
- e.TLSListener = tls.NewListener(l, s.TLSConfig)
- }
- if !e.HidePort {
- e.colorer.Printf("⇨ https server started on %s\n", e.colorer.Green(e.TLSListener.Addr()))
- }
- return s.Serve(e.TLSListener)
-}
-
-// Close immediately stops the server.
-// It internally calls `http.Server#Close()`.
-func (e *Echo) Close() error {
- if err := e.TLSServer.Close(); err != nil {
- return err
- }
- return e.Server.Close()
-}
-
-// Shutdown stops the server gracefully.
-// It internally calls `http.Server#Shutdown()`.
-func (e *Echo) Shutdown(ctx stdContext.Context) error {
- if err := e.TLSServer.Shutdown(ctx); err != nil {
- return err
- }
- return e.Server.Shutdown(ctx)
-}
-
-// NewHTTPError creates a new HTTPError instance.
-func NewHTTPError(code int, message ...interface{}) *HTTPError {
- he := &HTTPError{Code: code, Message: http.StatusText(code)}
- if len(message) > 0 {
- he.Message = message[0]
- }
- return he
-}
-
-// Error makes it compatible with `error` interface.
-func (he *HTTPError) Error() string {
- return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal)
-}
-
-// SetInternal sets error to HTTPError.Internal
-func (he *HTTPError) SetInternal(err error) *HTTPError {
- he.Internal = err
- return he
+ sc := StartConfig{Address: address}
+ ctx, cancel := signal.NotifyContext(stdContext.Background(), os.Interrupt, syscall.SIGTERM) // start shutdown process on ctrl+c
+ defer cancel()
+ return sc.Start(ctx, e)
}
// WrapHandler wraps `http.Handler` into `echo.HandlerFunc`.
func WrapHandler(h http.Handler) HandlerFunc {
- return func(c Context) error {
- h.ServeHTTP(c.Response(), c.Request())
+ return func(c *Context) error {
+ req := c.Request()
+ req.Pattern = c.Path()
+ for _, p := range c.PathValues() {
+ req.SetPathValue(p.Name, p.Value)
+ }
+
+ h.ServeHTTP(c.Response(), req)
return nil
}
}
@@ -772,77 +747,88 @@ func WrapHandler(h http.Handler) HandlerFunc {
// WrapMiddleware wraps `func(http.Handler) http.Handler` into `echo.MiddlewareFunc`
func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc {
return func(next HandlerFunc) HandlerFunc {
- return func(c Context) (err error) {
+ return func(c *Context) (err error) {
+ req := c.Request()
+ req.Pattern = c.Path()
+ for _, p := range c.PathValues() {
+ req.SetPathValue(p.Name, p.Value)
+ }
+
m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c.SetRequest(r)
- c.SetResponse(NewResponse(w, c.Echo()))
+ c.SetResponse(NewResponse(w, c.echo.Logger))
err = next(c)
- })).ServeHTTP(c.Response(), c.Request())
+ })).ServeHTTP(c.Response(), req)
return
}
}
}
-func getPath(r *http.Request) string {
- path := r.URL.RawPath
- if path == "" {
- path = r.URL.Path
+func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc {
+ for i := len(middleware) - 1; i >= 0; i-- {
+ h = middleware[i](h)
}
- return path
+ return h
}
-func (e *Echo) findRouter(host string) *Router {
- if len(e.routers) > 0 {
- if r, ok := e.routers[host]; ok {
- return r
- }
- }
- return e.router
+// defaultFS emulates os.Open behaviour with filesystem opened by `os.DirFs`. Difference between `os.Open` and `fs.Open`
+// is that FS does not allow to open path that start with `..` or `/` etc. For example previously you could have `../images`
+// in your application but `fs := os.DirFS("./")` would not allow you to use `fs.Open("../images")` and this would break
+// all old applications that rely on being able to traverse up from current executable run path.
+// NB: private because you really should use fs.FS implementation instances
+type defaultFS struct {
+ fs fs.FS
+ prefix string
}
-func handlerName(h HandlerFunc) string {
- t := reflect.ValueOf(h).Type()
- if t.Kind() == reflect.Func {
- return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
+func newDefaultFS() *defaultFS {
+ dir, _ := os.Getwd()
+ return &defaultFS{
+ prefix: dir,
+ fs: os.DirFS(dir),
}
- return t.String()
-}
-
-// // PathUnescape is wraps `url.PathUnescape`
-// func PathUnescape(s string) (string, error) {
-// return url.PathUnescape(s)
-// }
-
-// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
-// connections. It's used by ListenAndServe and ListenAndServeTLS so
-// dead TCP connections (e.g. closing laptop mid-download) eventually
-// go away.
-type tcpKeepAliveListener struct {
- *net.TCPListener
}
-func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
- if c, err = ln.AcceptTCP(); err != nil {
- return
- } else if err = c.(*net.TCPConn).SetKeepAlive(true); err != nil {
- return
- } else if err = c.(*net.TCPConn).SetKeepAlivePeriod(3 * time.Minute); err != nil {
- return
- }
- return
+func (fs defaultFS) Open(name string) (fs.File, error) {
+ return fs.fs.Open(name)
}
-func newListener(address string) (*tcpKeepAliveListener, error) {
- l, err := net.Listen("tcp", address)
+func subFS(currentFs fs.FS, root string) (fs.FS, error) {
+ root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows
+ if dFS, ok := currentFs.(*defaultFS); ok {
+ // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS.
+ // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we
+ // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs
+ if !filepath.IsAbs(root) {
+ root = filepath.Join(dFS.prefix, root)
+ }
+ return &defaultFS{
+ prefix: root,
+ fs: os.DirFS(root),
+ }, nil
+ }
+ return fs.Sub(currentFs, root)
+}
+
+// MustSubFS creates sub FS from current filesystem or panic on failure.
+// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules.
+//
+// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with
+// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to
+// create sub fs which uses necessary prefix for directory path.
+func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS {
+ subFs, err := subFS(currentFs, fsRoot)
if err != nil {
- return nil, err
+ panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err))
}
- return &tcpKeepAliveListener{l.(*net.TCPListener)}, nil
+ return subFs
}
-func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc {
- for i := len(middleware) - 1; i >= 0; i-- {
- h = middleware[i](h)
+func sanitizeURI(uri string) string {
+ // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri
+ // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash
+ if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') {
+ uri = "/" + strings.TrimLeft(uri, `/\`)
}
- return h
+ return uri
}
diff --git a/echo_test.go b/echo_test.go
index 3f2e48e51..f26eed8e2 100644
--- a/echo_test.go
+++ b/echo_test.go
@@ -1,30 +1,36 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package echo
import (
"bytes"
stdContext "context"
"errors"
- "io/ioutil"
+ "fmt"
+ "io/fs"
+ "log/slog"
+ "net"
"net/http"
"net/http/httptest"
- "reflect"
+ "net/url"
+ "os"
+ "runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
)
-type (
- user struct {
- ID int `json:"id" xml:"id" form:"id" query:"id" param:"id"`
- Name string `json:"name" xml:"name" form:"name" query:"name" param:"name"`
- }
-)
+type user struct {
+ ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"`
+ Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"`
+}
const (
userJSON = `{"id":1,"name":"Jon Snow"}`
+ usersJSON = `[{"id":1,"name":"Jon Snow"}]`
userXML = `1Jon Snow`
userForm = `id=1&name=Jon Snow`
invalidContent = "invalid content"
@@ -43,6 +49,8 @@ const userXMLPretty = `
Jon Snow
`
+var dummyQuery = url.Values{"dummy": []string{"useless"}}
+
func TestEcho(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
@@ -52,50 +60,354 @@ func TestEcho(t *testing.T) {
// Router
assert.NotNil(t, e.Router())
- // DefaultHTTPErrorHandler
- e.DefaultHTTPErrorHandler(errors.New("error"), c)
+ e.HTTPErrorHandler(c, errors.New("error"))
+
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
-func TestEchoStatic(t *testing.T) {
- e := New()
+func TestNewWithConfig(t *testing.T) {
+ e := NewWithConfig(Config{})
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e.GET("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "Hello, World!")
+ })
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusTeapot, rec.Code)
+ assert.Equal(t, `Hello, World!`, rec.Body.String())
+}
+
+func TestEcho_StaticFS(t *testing.T) {
+ var testCases = []struct {
+ givenFs fs.FS
+ name string
+ givenPrefix string
+ givenFsRoot string
+ whenURL string
+ expectHeaderLocation string
+ expectBodyStartsWith string
+ expectStatus int
+ }{
+ {
+ name: "ok",
+ givenPrefix: "/images",
+ givenFs: os.DirFS("./_fixture/images"),
+ whenURL: "/images/walle.png",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
+ },
+ {
+ name: "ok, from sub fs",
+ givenPrefix: "/images",
+ givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"),
+ whenURL: "/images/walle.png",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
+ },
+ {
+ name: "No file",
+ givenPrefix: "/images",
+ givenFs: os.DirFS("_fixture/scripts"),
+ whenURL: "/images/bolt.png",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Directory",
+ givenPrefix: "/images",
+ givenFs: os.DirFS("_fixture/images"),
+ whenURL: "/images/",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Directory Redirect",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture/"),
+ whenURL: "/folder",
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/folder/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Directory Redirect with non-root path",
+ givenPrefix: "/static",
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/static",
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/static/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory 404 (request URL without slash)",
+ givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/folder", // no trailing slash
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Prefixed directory redirect (without slash redirect to slash)",
+ givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/folder", // no trailing slash
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/folder/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Directory with index.html",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory with index.html (prefix ending with slash)",
+ givenPrefix: "/assets/",
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/assets/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory with index.html (prefix ending without slash)",
+ givenPrefix: "/assets",
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/assets/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Sub-directory with index.html",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/folder/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "do not allow directory traversal (backslash - windows separator)",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture/"),
+ whenURL: `/..\\middleware/basic_auth.go`,
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "do not allow directory traversal (slash - unix separator)",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture/"),
+ whenURL: `/../middleware/basic_auth.go`,
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "open redirect vulnerability",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture/"),
+ whenURL: "/open.redirect.hackercom%2f..",
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/open.redirect.hackercom/../", // location starting with `//open` would be very bad
+ expectBodyStartsWith: "",
+ },
+ }
- assert := assert.New(t)
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
- // OK
- e.Static("/images", "_fixture/images")
- c, b := request(http.MethodGet, "/images/walle.png", e)
- assert.Equal(http.StatusOK, c)
- assert.NotEmpty(b)
+ tmpFs := tc.givenFs
+ if tc.givenFsRoot != "" {
+ tmpFs = MustSubFS(tmpFs, tc.givenFsRoot)
+ }
+ e.StaticFS(tc.givenPrefix, tmpFs)
- // No file
- e.Static("/images", "_fixture/scripts")
- c, _ = request(http.MethodGet, "/images/bolt.png", e)
- assert.Equal(http.StatusNotFound, c)
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
- // Directory
- e.Static("/images", "_fixture/images")
- c, _ = request(http.MethodGet, "/images", e)
- assert.Equal(http.StatusNotFound, c)
+ e.ServeHTTP(rec, req)
- // Directory with index.html
- e.Static("/", "_fixture")
- c, r := request(http.MethodGet, "/", e)
- assert.Equal(http.StatusOK, c)
- assert.Equal(true, strings.HasPrefix(r, ""))
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ body := rec.Body.String()
+ if tc.expectBodyStartsWith != "" {
+ assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
+ } else {
+ assert.Equal(t, "", body)
+ }
- // Sub-directory with index.html
- c, r = request(http.MethodGet, "/folder", e)
- assert.Equal(http.StatusOK, c)
- assert.Equal(true, strings.HasPrefix(r, ""))
+ if tc.expectHeaderLocation != "" {
+ assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0])
+ } else {
+ _, ok := rec.Result().Header["Location"]
+ assert.False(t, ok)
+ }
+ })
+ }
}
-func TestEchoFile(t *testing.T) {
+func TestEcho_FileFS(t *testing.T) {
+ var testCases = []struct {
+ whenFS fs.FS
+ name string
+ whenPath string
+ whenFile string
+ givenURL string
+ expectStartsWith []byte
+ expectCode int
+ }{
+ {
+ name: "ok",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "walle.png",
+ givenURL: "/walle",
+ expectCode: http.StatusOK,
+ expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ },
+ {
+ name: "nok, requesting invalid path",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "walle.png",
+ givenURL: "/walle.png",
+ expectCode: http.StatusNotFound,
+ expectStartsWith: []byte(`{"message":"Not Found"}`),
+ },
+ {
+ name: "nok, serving not existent file from filesystem",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "not-existent.png",
+ givenURL: "/walle",
+ expectCode: http.StatusNotFound,
+ expectStartsWith: []byte(`{"message":"Not Found"}`),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS)
+
+ req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectCode, rec.Code)
+
+ body := rec.Body.Bytes()
+ if len(body) > len(tc.expectStartsWith) {
+ body = body[:len(tc.expectStartsWith)]
+ }
+ assert.Equal(t, tc.expectStartsWith, body)
+ })
+ }
+}
+
+func TestEcho_StaticPanic(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenRoot string
+ }{
+ {
+ name: "panics for ../",
+ givenRoot: "../assets",
+ },
+ {
+ name: "panics for /",
+ givenRoot: "/assets",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ e.Filesystem = os.DirFS("./")
+
+ assert.Panics(t, func() {
+ e.Static("../assets", tc.givenRoot)
+ })
+ })
+ }
+}
+
+func TestEchoStaticRedirectIndex(t *testing.T) {
e := New()
- e.File("/walle", "_fixture/images/walle.png")
- c, b := request(http.MethodGet, "/walle", e)
- assert.Equal(t, http.StatusOK, c)
- assert.NotEmpty(t, b)
+
+ // HandlerFunc
+ ri := e.Static("/static", "_fixture")
+ assert.Equal(t, http.MethodGet, ri.Method)
+ assert.Equal(t, "/static*", ri.Path)
+ assert.Equal(t, "GET:/static*", ri.Name)
+ assert.Equal(t, []string{"*"}, ri.Parameters)
+
+ ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
+ defer cancel()
+ addr, err := startOnRandomPort(ctx, e)
+ if err != nil {
+ assert.Fail(t, err.Error())
+ }
+
+ code, body, err := doGet(fmt.Sprintf("http://%v/static", addr))
+ assert.NoError(t, err)
+ assert.True(t, strings.HasPrefix(body, ""))
+ assert.Equal(t, http.StatusOK, code)
+}
+
+func TestEchoFile(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenPath string
+ givenFile string
+ whenPath string
+ expectStartsWith string
+ expectCode int
+ }{
+ {
+ name: "ok",
+ givenPath: "/walle",
+ givenFile: "_fixture/images/walle.png",
+ whenPath: "/walle",
+ expectCode: http.StatusOK,
+ expectStartsWith: string([]byte{0x89, 0x50, 0x4e}),
+ },
+ {
+ name: "ok with relative path",
+ givenPath: "/",
+ givenFile: "./go.mod",
+ whenPath: "/",
+ expectCode: http.StatusOK,
+ expectStartsWith: "module github.com/labstack/echo/v",
+ },
+ {
+ name: "nok file does not exist",
+ givenPath: "/",
+ givenFile: "./this-file-does-not-exist",
+ whenPath: "/",
+ expectCode: http.StatusNotFound,
+ expectStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New() // we are using echo.defaultFS instance
+ e.File(tc.givenPath, tc.givenFile)
+
+ c, b := request(http.MethodGet, tc.whenPath, e)
+ assert.Equal(t, tc.expectCode, c)
+
+ if len(b) > len(tc.expectStartsWith) {
+ b = b[:len(tc.expectStartsWith)]
+ }
+ assert.Equal(t, tc.expectStartsWith, b)
+ })
+ }
}
func TestEchoMiddleware(t *testing.T) {
@@ -103,36 +415,37 @@ func TestEchoMiddleware(t *testing.T) {
buf := new(bytes.Buffer)
e.Pre(func(next HandlerFunc) HandlerFunc {
- return func(c Context) error {
- assert.Empty(t, c.Path())
+ return func(c *Context) error {
+ // before route match is found RouteInfo does not exist
+ assert.Equal(t, RouteInfo{}, c.RouteInfo())
buf.WriteString("-1")
return next(c)
}
})
e.Use(func(next HandlerFunc) HandlerFunc {
- return func(c Context) error {
+ return func(c *Context) error {
buf.WriteString("1")
return next(c)
}
})
e.Use(func(next HandlerFunc) HandlerFunc {
- return func(c Context) error {
+ return func(c *Context) error {
buf.WriteString("2")
return next(c)
}
})
e.Use(func(next HandlerFunc) HandlerFunc {
- return func(c Context) error {
+ return func(c *Context) error {
buf.WriteString("3")
return next(c)
}
})
// Route
- e.GET("/", func(c Context) error {
+ e.GET("/", func(c *Context) error {
return c.String(http.StatusOK, "OK")
})
@@ -145,11 +458,11 @@ func TestEchoMiddleware(t *testing.T) {
func TestEchoMiddlewareError(t *testing.T) {
e := New()
e.Use(func(next HandlerFunc) HandlerFunc {
- return func(c Context) error {
+ return func(c *Context) error {
return errors.New("error")
}
})
- e.GET("/", NotFoundHandler)
+ e.GET("/", notFoundHandler)
c, _ := request(http.MethodGet, "/", e)
assert.Equal(t, http.StatusInternalServerError, c)
}
@@ -158,7 +471,7 @@ func TestEchoHandler(t *testing.T) {
e := New()
// HandlerFunc
- e.GET("/ok", func(c Context) error {
+ e.GET("/ok", func(c *Context) error {
return c.String(http.StatusOK, "OK")
})
@@ -169,171 +482,302 @@ func TestEchoHandler(t *testing.T) {
func TestEchoWrapHandler(t *testing.T) {
e := New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
- h := WrapHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+
+ var actualID string
+ var actualPattern string
+ e.GET("/:id", WrapHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("test"))
- }))
- if assert.NoError(t, h(c)) {
- assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, "test", rec.Body.String())
- }
+ actualID = r.PathValue("id")
+ actualPattern = r.Pattern
+ })))
+
+ req := httptest.NewRequest(http.MethodGet, "/123", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "test", rec.Body.String())
+ assert.Equal(t, "123", actualID)
+ assert.Equal(t, "/:id", actualPattern)
}
func TestEchoWrapMiddleware(t *testing.T) {
e := New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
- buf := new(bytes.Buffer)
- mw := WrapMiddleware(func(h http.Handler) http.Handler {
+
+ var actualID string
+ var actualPattern string
+ e.Use(WrapMiddleware(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- buf.Write([]byte("mw"))
+ actualID = r.PathValue("id")
+ actualPattern = r.Pattern
h.ServeHTTP(w, r)
})
+ }))
+
+ e.GET("/:id", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
})
- h := mw(func(c Context) error {
- return c.String(http.StatusOK, "OK")
- })
- if assert.NoError(t, h(c)) {
- assert.Equal(t, "mw", buf.String())
- assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, "OK", rec.Body.String())
- }
+
+ req := httptest.NewRequest(http.MethodGet, "/123", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusTeapot, rec.Code)
+ assert.Equal(t, "OK", rec.Body.String())
+ assert.Equal(t, "123", actualID)
+ assert.Equal(t, "/:id", actualPattern)
}
func TestEchoConnect(t *testing.T) {
e := New()
- testMethod(t, http.MethodConnect, "/", e)
+
+ ri := e.CONNECT("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodConnect, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodConnect+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodConnect, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoDelete(t *testing.T) {
e := New()
- testMethod(t, http.MethodDelete, "/", e)
+
+ ri := e.DELETE("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodDelete, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodDelete+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodDelete, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoGet(t *testing.T) {
e := New()
- testMethod(t, http.MethodGet, "/", e)
+
+ ri := e.GET("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodGet, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodGet+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodGet, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoHead(t *testing.T) {
e := New()
- testMethod(t, http.MethodHead, "/", e)
+
+ ri := e.HEAD("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodHead, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodHead+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodHead, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoOptions(t *testing.T) {
e := New()
- testMethod(t, http.MethodOptions, "/", e)
+
+ ri := e.OPTIONS("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodOptions, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodOptions+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodOptions, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoPatch(t *testing.T) {
e := New()
- testMethod(t, http.MethodPatch, "/", e)
+
+ ri := e.PATCH("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPatch, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodPatch+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodPatch, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoPost(t *testing.T) {
e := New()
- testMethod(t, http.MethodPost, "/", e)
+
+ ri := e.POST("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPost, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodPost+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodPost, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoPut(t *testing.T) {
e := New()
- testMethod(t, http.MethodPut, "/", e)
+
+ ri := e.PUT("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPut, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodPut+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodPut, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoTrace(t *testing.T) {
e := New()
- testMethod(t, http.MethodTrace, "/", e)
-}
-func TestEchoAny(t *testing.T) { // JFC
- e := New()
- e.Any("/", func(c Context) error {
- return c.String(http.StatusOK, "Any")
+ ri := e.TRACE("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
})
+
+ assert.Equal(t, http.MethodTrace, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodTrace+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodTrace, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
-func TestEchoMatch(t *testing.T) { // JFC
+func TestEcho_Any(t *testing.T) {
e := New()
- e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error {
- return c.String(http.StatusOK, "Match")
+
+ ri := e.Any("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK from ANY")
})
+
+ assert.Equal(t, RouteAny, ri.Method)
+ assert.Equal(t, "/activate", ri.Path)
+ assert.Equal(t, RouteAny+":/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodTrace, "/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK from ANY`, body)
}
-func TestEchoURL(t *testing.T) {
+func TestEcho_Any_hasLowerPriority(t *testing.T) {
e := New()
- static := func(Context) error { return nil }
- getUser := func(Context) error { return nil }
- getFile := func(Context) error { return nil }
- e.GET("/static/file", static)
- e.GET("/users/:id", getUser)
- g := e.Group("/group")
- g.GET("/users/:uid/files/:fid", getFile)
+ e.Any("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "ANY")
+ })
+ e.GET("/activate", func(c *Context) error {
+ return c.String(http.StatusLocked, "GET")
+ })
- assert := assert.New(t)
+ status, body := request(http.MethodTrace, "/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `ANY`, body)
- assert.Equal("/static/file", e.URL(static))
- assert.Equal("/users/:id", e.URL(getUser))
- assert.Equal("/users/1", e.URL(getUser, "1"))
- assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1"))
- assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1"))
+ status, body = request(http.MethodGet, "/activate", e)
+ assert.Equal(t, http.StatusLocked, status)
+ assert.Equal(t, `GET`, body)
}
-func TestEchoRoutes(t *testing.T) {
+func TestEchoMatch(t *testing.T) { // JFC
e := New()
- routes := []*Route{
- {http.MethodGet, "/users/:user/events", ""},
- {http.MethodGet, "/users/:user/events/public", ""},
- {http.MethodPost, "/repos/:owner/:repo/git/refs", ""},
- {http.MethodPost, "/repos/:owner/:repo/git/tags", ""},
- }
- for _, r := range routes {
- e.Add(r.Method, r.Path, func(c Context) error {
- return c.String(http.StatusOK, "OK")
- })
- }
-
- if assert.Equal(t, len(routes), len(e.Routes())) {
- for _, r := range e.Routes() {
- found := false
- for _, rr := range routes {
- if r.Method == rr.Method && r.Path == rr.Path {
- found = true
- break
- }
- }
- if !found {
- t.Errorf("Route %s %s not found", r.Method, r.Path)
- }
- }
- }
+ ris := e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c *Context) error {
+ return c.String(http.StatusOK, "Match")
+ })
+ assert.Len(t, ris, 2)
}
-func TestEchoEncodedPath(t *testing.T) {
+func TestEchoServeHTTPPathEncoding(t *testing.T) {
e := New()
- e.GET("/:id", func(c Context) error {
- return c.NoContent(http.StatusOK)
+ e.GET("/with/slash", func(c *Context) error {
+ return c.String(http.StatusOK, "/with/slash")
})
- req := httptest.NewRequest(http.MethodGet, "/with%2Fslash", nil)
- rec := httptest.NewRecorder()
- e.ServeHTTP(rec, req)
- assert.Equal(t, http.StatusOK, rec.Code)
+ e.GET("/:id", func(c *Context) error {
+ return c.String(http.StatusOK, c.Param("id"))
+ })
+
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectURL string
+ expectStatus int
+ }{
+ {
+ name: "url with encoding is not decoded for routing",
+ whenURL: "/with%2Fslash",
+ expectURL: "with%2Fslash", // `%2F` is not decoded to `/` for routing
+ expectStatus: http.StatusOK,
+ },
+ {
+ name: "url without encoding is used as is",
+ whenURL: "/with/slash",
+ expectURL: "/with/slash",
+ expectStatus: http.StatusOK,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ assert.Equal(t, tc.expectURL, rec.Body.String())
+ })
+ }
}
func TestEchoGroup(t *testing.T) {
e := New()
buf := new(bytes.Buffer)
e.Use(MiddlewareFunc(func(next HandlerFunc) HandlerFunc {
- return func(c Context) error {
+ return func(c *Context) error {
buf.WriteString("0")
return next(c)
}
}))
- h := func(c Context) error {
+ h := func(c *Context) error {
return c.NoContent(http.StatusOK)
}
@@ -346,7 +790,7 @@ func TestEchoGroup(t *testing.T) {
// Group
g1 := e.Group("/group1")
g1.Use(func(next HandlerFunc) HandlerFunc {
- return func(c Context) error {
+ return func(c *Context) error {
buf.WriteString("1")
return next(c)
}
@@ -356,14 +800,14 @@ func TestEchoGroup(t *testing.T) {
// Nested groups with middleware
g2 := e.Group("/group2")
g2.Use(func(next HandlerFunc) HandlerFunc {
- return func(c Context) error {
+ return func(c *Context) error {
buf.WriteString("2")
return next(c)
}
})
g3 := g2.Group("/group3")
g3.Use(func(next HandlerFunc) HandlerFunc {
- return func(c Context) error {
+ return func(c *Context) error {
buf.WriteString("3")
return next(c)
}
@@ -382,6 +826,70 @@ func TestEchoGroup(t *testing.T) {
assert.Equal(t, "023", buf.String())
}
+func TestEcho_RouteNotFound(t *testing.T) {
+ var testCases = []struct {
+ expectRoute any
+ name string
+ whenURL string
+ expectCode int
+ }{
+ {
+ name: "404, route to static not found handler /a/c/xx",
+ whenURL: "/a/c/xx",
+ expectRoute: "GET /a/c/xx",
+ expectCode: http.StatusNotFound,
+ },
+ {
+ name: "404, route to path param not found handler /a/:file",
+ whenURL: "/a/echo.exe",
+ expectRoute: "GET /a/:file",
+ expectCode: http.StatusNotFound,
+ },
+ {
+ name: "404, route to any not found handler /*",
+ whenURL: "/b/echo.exe",
+ expectRoute: "GET /*",
+ expectCode: http.StatusNotFound,
+ },
+ {
+ name: "200, route /a/c/df to /a/c/df",
+ whenURL: "/a/c/df",
+ expectRoute: "GET /a/c/df",
+ expectCode: http.StatusOK,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+
+ okHandler := func(c *Context) error {
+ return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
+ }
+ notFoundHandler := func(c *Context) error {
+ return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path())
+ }
+
+ e.GET("/", okHandler)
+ e.GET("/a/c/df", okHandler)
+ e.GET("/a/b*", okHandler)
+ e.PUT("/*", okHandler)
+
+ e.RouteNotFound("/a/c/xx", notFoundHandler) // static
+ e.RouteNotFound("/a/:file", notFoundHandler) // param
+ e.RouteNotFound("/*", notFoundHandler) // any
+
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectCode, rec.Code)
+ assert.Equal(t, tc.expectRoute, rec.Body.String())
+ })
+ }
+}
+
func TestEchoNotFound(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/files", nil)
@@ -392,189 +900,328 @@ func TestEchoNotFound(t *testing.T) {
func TestEchoMethodNotAllowed(t *testing.T) {
e := New()
- e.GET("/", func(c Context) error {
+
+ e.GET("/", func(c *Context) error {
return c.String(http.StatusOK, "Echo!")
})
req := httptest.NewRequest(http.MethodPost, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
+
assert.Equal(t, http.StatusMethodNotAllowed, rec.Code)
+ assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow))
+}
+
+func TestEcho_OnAddRoute(t *testing.T) {
+ exampleRoute := Route{
+ Method: http.MethodGet,
+ Path: "/api/files/:id",
+ Handler: notFoundHandler,
+ Middlewares: nil,
+ Name: "x",
+ }
+
+ var testCases = []struct {
+ whenRoute Route
+ whenError error
+ name string
+ expectError string
+ expectAdded []string
+ expectLen int
+ }{
+ {
+ name: "ok",
+ whenRoute: exampleRoute,
+ whenError: nil,
+ expectAdded: []string{"/static", "/api/files/:id"},
+ expectError: "",
+ expectLen: 2,
+ },
+ {
+ name: "nok, error is returned",
+ whenRoute: exampleRoute,
+ whenError: errors.New("nope"),
+ expectAdded: []string{"/static"},
+ expectError: "nope",
+ expectLen: 1,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+
+ e := New()
+
+ added := make([]string, 0)
+ cnt := 0
+ e.OnAddRoute = func(route Route) error {
+ if cnt > 0 && tc.whenError != nil { // we want to GET /static to succeed for nok tests
+ return tc.whenError
+ }
+ cnt++
+ added = append(added, route.Path)
+ return nil
+ }
+
+ e.GET("/static", notFoundHandler)
+
+ var err error
+ _, err = e.AddRoute(tc.whenRoute)
+
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ assert.Len(t, e.Router().Routes(), tc.expectLen)
+ assert.Equal(t, tc.expectAdded, added)
+ })
+ }
}
func TestEchoContext(t *testing.T) {
e := New()
c := e.AcquireContext()
- assert.IsType(t, new(context), c)
+ assert.IsType(t, new(Context), c)
e.ReleaseContext(c)
}
-func TestEchoStart(t *testing.T) {
+func TestPreMiddlewares(t *testing.T) {
e := New()
- go func() {
- assert.NoError(t, e.Start(":0"))
- }()
- time.Sleep(200 * time.Millisecond)
+ assert.Equal(t, 0, len(e.PreMiddlewares()))
+
+ e.Pre(func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ return next(c)
+ }
+ })
+
+ assert.Equal(t, 1, len(e.PreMiddlewares()))
}
-func TestEchoStartTLS(t *testing.T) {
+func TestMiddlewares(t *testing.T) {
e := New()
- go func() {
- err := e.StartTLS(":0", "_fixture/certs/cert.pem", "_fixture/certs/key.pem")
- // Prevent the test to fail after closing the servers
- if err != http.ErrServerClosed {
- assert.NoError(t, err)
+ assert.Equal(t, 0, len(e.Middlewares()))
+
+ e.Use(func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ return next(c)
}
+ })
+
+ assert.Equal(t, 1, len(e.Middlewares()))
+}
+
+func TestEcho_Start(t *testing.T) {
+ e := New()
+ e.GET("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+ rndPort, err := net.Listen("tcp", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer rndPort.Close()
+ errChan := make(chan error, 1)
+ go func() {
+ errChan <- e.Start(rndPort.Addr().String())
}()
- time.Sleep(200 * time.Millisecond)
- e.Close()
+ select {
+ case <-time.After(250 * time.Millisecond):
+ t.Fatal("start did not error out")
+ case err := <-errChan:
+ expectContains := "bind: address already in use"
+ if runtime.GOOS == "windows" {
+ expectContains = "bind: Only one usage of each socket address"
+ }
+ assert.Contains(t, err.Error(), expectContains)
+ }
+}
+
+func request(method, path string, e *Echo) (int, string) {
+ req := httptest.NewRequest(method, path, nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ return rec.Code, rec.Body.String()
+}
+
+type customError struct {
+ Code int
+ Message string
}
-func TestEchoStartTLSByteString(t *testing.T) {
- cert, err := ioutil.ReadFile("_fixture/certs/cert.pem")
- require.NoError(t, err)
- key, err := ioutil.ReadFile("_fixture/certs/key.pem")
- require.NoError(t, err)
+func (ce *customError) StatusCode() int {
+ return ce.Code
+}
- testCases := []struct {
- cert interface{}
- key interface{}
- expectedErr error
- name string
+func (ce *customError) MarshalJSON() ([]byte, error) {
+ return []byte(fmt.Sprintf(`{"x":"%v"}`, ce.Message)), nil
+}
+
+func (ce *customError) Error() string {
+ return ce.Message
+}
+
+func TestDefaultHTTPErrorHandler(t *testing.T) {
+ var testCases = []struct {
+ whenError error
+ name string
+ whenMethod string
+ expectBody string
+ expectLogged string
+ expectStatus int
+ givenExposeError bool
+ givenLoggerFunc bool
}{
{
- cert: "_fixture/certs/cert.pem",
- key: "_fixture/certs/key.pem",
- expectedErr: nil,
- name: `ValidCertAndKeyFilePath`,
+ name: "ok, expose error = true, HTTPError, no wrapped err",
+ givenExposeError: true,
+ whenError: &HTTPError{Code: http.StatusTeapot, Message: "my_error"},
+ expectStatus: http.StatusTeapot,
+ expectBody: `{"message":"my_error"}` + "\n",
+ },
+ {
+ name: "ok, expose error = true, HTTPError + wrapped error",
+ givenExposeError: true,
+ whenError: HTTPError{Code: http.StatusTeapot, Message: "my_error"}.Wrap(errors.New("internal_error")),
+ expectStatus: http.StatusTeapot,
+ expectBody: `{"error":"internal_error","message":"my_error"}` + "\n",
+ },
+ {
+ name: "ok, expose error = true, HTTPError + wrapped HTTPError",
+ givenExposeError: true,
+ whenError: HTTPError{Code: http.StatusTeapot, Message: "my_error"}.Wrap(&HTTPError{Code: http.StatusTeapot, Message: "early_error"}),
+ expectStatus: http.StatusTeapot,
+ expectBody: `{"error":"code=418, message=early_error","message":"my_error"}` + "\n",
+ },
+ {
+ name: "ok, expose error = false, HTTPError",
+ whenError: &HTTPError{Code: http.StatusTeapot, Message: "my_error"},
+ expectStatus: http.StatusTeapot,
+ expectBody: `{"message":"my_error"}` + "\n",
+ },
+ {
+ name: "ok, expose error = false, HTTPError, no message",
+ whenError: &HTTPError{Code: http.StatusTeapot, Message: ""},
+ expectStatus: http.StatusTeapot,
+ expectBody: `{"message":"I'm a teapot"}` + "\n",
+ },
+ {
+ name: "ok, expose error = false, HTTPError + internal HTTPError",
+ whenError: HTTPError{Code: http.StatusTooEarly, Message: "my_error"}.Wrap(&HTTPError{Code: http.StatusTeapot, Message: "early_error"}),
+ expectStatus: http.StatusTooEarly,
+ expectBody: `{"message":"my_error"}` + "\n",
},
{
- cert: cert,
- key: key,
- expectedErr: nil,
- name: `ValidCertAndKeyByteString`,
+ name: "ok, expose error = true, Error",
+ givenExposeError: true,
+ whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")),
+ expectStatus: http.StatusInternalServerError,
+ expectBody: `{"error":"my errors wraps: internal_error","message":"Internal Server Error"}` + "\n",
},
{
- cert: cert,
- key: 1,
- expectedErr: ErrInvalidCertOrKeyType,
- name: `InvalidKeyType`,
+ name: "ok, expose error = false, Error",
+ whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")),
+ expectStatus: http.StatusInternalServerError,
+ expectBody: `{"message":"Internal Server Error"}` + "\n",
},
{
- cert: 0,
- key: key,
- expectedErr: ErrInvalidCertOrKeyType,
- name: `InvalidCertType`,
+ name: "ok, http.HEAD, expose error = true, Error",
+ givenExposeError: true,
+ whenMethod: http.MethodHead,
+ whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")),
+ expectStatus: http.StatusInternalServerError,
+ expectBody: ``,
},
{
- cert: 0,
- key: 1,
- expectedErr: ErrInvalidCertOrKeyType,
- name: `InvalidCertAndKeyTypes`,
+ name: "ok, custom error implement MarshalJSON + HTTPStatusCoder",
+ whenMethod: http.MethodGet,
+ whenError: &customError{Code: http.StatusTeapot, Message: "custom error msg"},
+ expectStatus: http.StatusTeapot,
+ expectBody: `{"x":"custom error msg"}` + "\n",
},
}
- for _, test := range testCases {
- test := test
- t.Run(test.name, func(t *testing.T) {
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ buf := new(bytes.Buffer)
e := New()
- e.HideBanner = true
-
- go func() {
- err := e.StartTLS(":0", test.cert, test.key)
- if test.expectedErr != nil {
- require.EqualError(t, err, test.expectedErr.Error())
- } else if err != http.ErrServerClosed { // Prevent the test to fail after closing the servers
- require.NoError(t, err)
- }
- }()
- time.Sleep(200 * time.Millisecond)
+ e.Logger = slog.New(slog.DiscardHandler)
+ e.Any("/path", func(c *Context) error {
+ return tc.whenError
+ })
+
+ e.HTTPErrorHandler = DefaultHTTPErrorHandler(tc.givenExposeError)
+
+ method := http.MethodGet
+ if tc.whenMethod != "" {
+ method = tc.whenMethod
+ }
+ c, b := request(method, "/path", e)
- require.NoError(t, e.Close())
+ assert.Equal(t, tc.expectStatus, c)
+ assert.Equal(t, tc.expectBody, b)
+ assert.Equal(t, tc.expectLogged, buf.String())
})
}
}
-func TestEchoStartAutoTLS(t *testing.T) {
+func TestDefaultHTTPErrorHandler_CommitedResponse(t *testing.T) {
e := New()
- errChan := make(chan error, 0)
-
- go func() {
- errChan <- e.StartAutoTLS(":0")
- }()
- time.Sleep(200 * time.Millisecond)
-
- select {
- case err := <-errChan:
- assert.NoError(t, err)
- default:
- assert.NoError(t, e.Close())
- }
-}
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ resp := httptest.NewRecorder()
+ c := e.NewContext(req, resp)
-func testMethod(t *testing.T, method, path string, e *Echo) {
- p := reflect.ValueOf(path)
- h := reflect.ValueOf(func(c Context) error {
- return c.String(http.StatusOK, method)
- })
- i := interface{}(e)
- reflect.ValueOf(i).MethodByName(method).Call([]reflect.Value{p, h})
- _, body := request(method, path, e)
- assert.Equal(t, method, body)
-}
+ c.orgResponse.Committed = true
+ errHandler := DefaultHTTPErrorHandler(false)
-func request(method, path string, e *Echo) (int, string) {
- req := httptest.NewRequest(method, path, nil)
- rec := httptest.NewRecorder()
- e.ServeHTTP(rec, req)
- return rec.Code, rec.Body.String()
+ errHandler(c, errors.New("my_error"))
+ assert.Equal(t, http.StatusOK, resp.Code)
}
-func TestHTTPError(t *testing.T) {
- err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{
- "code": 12,
- })
- assert.Equal(t, "code=400, message=map[code:12], internal=", err.Error())
-}
-
-func TestEchoClose(t *testing.T) {
+func benchmarkEchoRoutes(b *testing.B, routes []testRoute) {
e := New()
- errCh := make(chan error)
-
- go func() {
- errCh <- e.Start(":0")
- }()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ u := req.URL
+ w := httptest.NewRecorder()
- time.Sleep(200 * time.Millisecond)
+ b.ReportAllocs()
- if err := e.Close(); err != nil {
- t.Fatal(err)
+ // Add routes
+ for _, route := range routes {
+ e.Add(route.Method, route.Path, func(c *Context) error {
+ return nil
+ })
}
- assert.NoError(t, e.Close())
-
- err := <-errCh
- assert.Equal(t, err.Error(), "http: Server closed")
+ // Find routes
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ for _, route := range routes {
+ req.Method = route.Method
+ u.Path = route.Path
+ e.ServeHTTP(w, req)
+ }
+ }
}
-func TestEchoShutdown(t *testing.T) {
- e := New()
- errCh := make(chan error)
-
- go func() {
- errCh <- e.Start(":0")
- }()
+func BenchmarkEchoStaticRoutes(b *testing.B) {
+ benchmarkEchoRoutes(b, staticRoutes)
+}
- time.Sleep(200 * time.Millisecond)
+func BenchmarkEchoStaticRoutesMisses(b *testing.B) {
+ benchmarkEchoRoutes(b, staticRoutes)
+}
- if err := e.Close(); err != nil {
- t.Fatal(err)
- }
+func BenchmarkEchoGitHubAPI(b *testing.B) {
+ benchmarkEchoRoutes(b, gitHubAPI)
+}
- ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 10*time.Second)
- defer cancel()
- assert.NoError(t, e.Shutdown(ctx))
+func BenchmarkEchoGitHubAPIMisses(b *testing.B) {
+ benchmarkEchoRoutes(b, gitHubAPI)
+}
- err := <-errCh
- assert.Equal(t, err.Error(), "http: Server closed")
+func BenchmarkEchoParseAPI(b *testing.B) {
+ benchmarkEchoRoutes(b, parseAPI)
}
diff --git a/echotest/context.go b/echotest/context.go
new file mode 100644
index 000000000..2f665705d
--- /dev/null
+++ b/echotest/context.go
@@ -0,0 +1,183 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echotest
+
+import (
+ "bytes"
+ "io"
+ "mime/multipart"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+)
+
+// ContextConfig is configuration for creating echo.Context for testing purposes.
+type ContextConfig struct {
+ // Request will be used instead of default `httptest.NewRequest(http.MethodGet, "/", nil)`
+ Request *http.Request
+
+ // Response will be used instead of default `httptest.NewRecorder()`
+ Response *httptest.ResponseRecorder
+
+ // QueryValues wil be set as Request.URL.RawQuery value
+ QueryValues url.Values
+
+ // Headers wil be set as Request.Header value
+ Headers http.Header
+
+ // PathValues initializes context.PathValues with given value.
+ PathValues echo.PathValues
+
+ // RouteInfo initializes context.RouteInfo() with given value
+ RouteInfo *echo.RouteInfo
+
+ // FormValues creates form-urlencoded form out of given values. If there is no
+ // `content-type` header it will be set to `application/x-www-form-urlencoded`
+ // In case Request was not set the Request.Method is set to `POST`
+ //
+ // FormValues, MultipartForm and JSONBody are mutually exclusive.
+ FormValues url.Values
+
+ // MultipartForm creates multipart form out of given value. If there is no
+ // `content-type` header it will be set to `multipart/form-data`
+ // In case Request was not set the Request.Method is set to `POST`
+ //
+ // FormValues, MultipartForm and JSONBody are mutually exclusive.
+ MultipartForm *MultipartForm
+
+ // JSONBody creates JSON body out of given bytes. If there is no
+ // `content-type` header it will be set to `application/json`
+ // In case Request was not set the Request.Method is set to `POST`
+ //
+ // FormValues, MultipartForm and JSONBody are mutually exclusive.
+ JSONBody []byte
+}
+
+// MultipartForm is used to create multipart form out of given value
+type MultipartForm struct {
+ Fields map[string]string
+ Files []MultipartFormFile
+}
+
+// MultipartFormFile is used to create file in multipart form out of given value
+type MultipartFormFile struct {
+ Fieldname string
+ Filename string
+ Content []byte
+}
+
+// ToContext converts ContextConfig to echo.Context
+func (conf ContextConfig) ToContext(t *testing.T) *echo.Context {
+ c, _ := conf.ToContextRecorder(t)
+ return c
+}
+
+// ToContextRecorder converts ContextConfig to echo.Context and httptest.ResponseRecorder
+func (conf ContextConfig) ToContextRecorder(t *testing.T) (*echo.Context, *httptest.ResponseRecorder) {
+ if conf.Response == nil {
+ conf.Response = httptest.NewRecorder()
+ }
+ isDefaultRequest := false
+ if conf.Request == nil {
+ isDefaultRequest = true
+ conf.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+ }
+
+ if len(conf.QueryValues) > 0 {
+ conf.Request.URL.RawQuery = conf.QueryValues.Encode()
+ }
+ if len(conf.Headers) > 0 {
+ conf.Request.Header = conf.Headers
+ }
+ if len(conf.FormValues) > 0 {
+ body := strings.NewReader(url.Values(conf.FormValues).Encode())
+ conf.Request.Body = io.NopCloser(body)
+ conf.Request.ContentLength = int64(body.Len())
+
+ if conf.Request.Header.Get(echo.HeaderContentType) == "" {
+ conf.Request.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
+ }
+ if isDefaultRequest {
+ conf.Request.Method = http.MethodPost
+ }
+ } else if conf.MultipartForm != nil {
+ var body bytes.Buffer
+ mw := multipart.NewWriter(&body)
+ for field, value := range conf.MultipartForm.Fields {
+ if err := mw.WriteField(field, value); err != nil {
+ t.Fatal(err)
+ }
+ }
+ for _, file := range conf.MultipartForm.Files {
+ fw, err := mw.CreateFormFile(file.Fieldname, file.Filename)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err = fw.Write(file.Content); err != nil {
+ t.Fatal(err)
+ }
+ }
+ if err := mw.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ conf.Request.Body = io.NopCloser(&body)
+ conf.Request.ContentLength = int64(body.Len())
+ if conf.Request.Header.Get(echo.HeaderContentType) == "" {
+ conf.Request.Header.Set(echo.HeaderContentType, mw.FormDataContentType())
+ }
+ if isDefaultRequest {
+ conf.Request.Method = http.MethodPost
+ }
+ } else if conf.JSONBody != nil {
+ body := bytes.NewReader(conf.JSONBody)
+ conf.Request.Body = io.NopCloser(body)
+ conf.Request.ContentLength = int64(body.Len())
+
+ if conf.Request.Header.Get(echo.HeaderContentType) == "" {
+ conf.Request.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
+ }
+ if isDefaultRequest {
+ conf.Request.Method = http.MethodPost
+ }
+ }
+
+ ec := echo.NewContext(conf.Request, conf.Response, echo.New())
+ if conf.RouteInfo == nil {
+ conf.RouteInfo = &echo.RouteInfo{
+ Name: "",
+ Method: conf.Request.Method,
+ Path: "/test",
+ Parameters: []string{},
+ }
+ for _, p := range conf.PathValues {
+ conf.RouteInfo.Parameters = append(conf.RouteInfo.Parameters, p.Name)
+ }
+ }
+ ec.InitializeRoute(conf.RouteInfo, &conf.PathValues)
+ return ec, conf.Response
+}
+
+// ServeWithHandler serves ContextConfig with given handler and returns httptest.ResponseRecorder for response checking
+func (conf ContextConfig) ServeWithHandler(t *testing.T, handler echo.HandlerFunc, opts ...any) *httptest.ResponseRecorder {
+ c, rec := conf.ToContextRecorder(t)
+
+ errHandler := echo.DefaultHTTPErrorHandler(false)
+ for _, opt := range opts {
+ switch o := opt.(type) {
+ case echo.HTTPErrorHandler:
+ errHandler = o
+ }
+ }
+
+ err := handler(c)
+ if err != nil {
+ errHandler(c, err)
+ }
+ return rec
+}
diff --git a/echotest/context_external_test.go b/echotest/context_external_test.go
new file mode 100644
index 000000000..d98257148
--- /dev/null
+++ b/echotest/context_external_test.go
@@ -0,0 +1,27 @@
+package echotest_test
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v5/echotest"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestToContext_JSONBody(t *testing.T) {
+ c := echotest.ContextConfig{
+ JSONBody: echotest.LoadBytes(t, "testdata/test.json"),
+ }.ToContext(t)
+
+ payload := struct {
+ Field string `json:"field"`
+ }{}
+ if err := c.Bind(&payload); err != nil {
+ t.Fatal(err)
+ }
+
+ assert.Equal(t, "value", payload.Field)
+ assert.Equal(t, http.MethodPost, c.Request().Method)
+ assert.Equal(t, echo.MIMEApplicationJSON, c.Request().Header.Get(echo.HeaderContentType))
+}
diff --git a/echotest/context_test.go b/echotest/context_test.go
new file mode 100644
index 000000000..66815e4b0
--- /dev/null
+++ b/echotest/context_test.go
@@ -0,0 +1,157 @@
+package echotest
+
+import (
+ "net/http"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestServeWithHandler(t *testing.T) {
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, c.QueryParam("key"))
+ }
+ testConf := ContextConfig{
+ QueryValues: url.Values{"key": []string{"value"}},
+ }
+
+ resp := testConf.ServeWithHandler(t, handler)
+
+ assert.Equal(t, http.StatusOK, resp.Code)
+ assert.Equal(t, "value", resp.Body.String())
+}
+
+func TestServeWithHandler_error(t *testing.T) {
+ handler := func(c *echo.Context) error {
+ return echo.NewHTTPError(http.StatusBadRequest, "something went wrong")
+ }
+ testConf := ContextConfig{
+ QueryValues: url.Values{"key": []string{"value"}},
+ }
+
+ customErrHandler := echo.DefaultHTTPErrorHandler(true)
+
+ resp := testConf.ServeWithHandler(t, handler, customErrHandler)
+
+ assert.Equal(t, http.StatusBadRequest, resp.Code)
+ assert.Equal(t, `{"message":"something went wrong"}`+"\n", resp.Body.String())
+}
+
+func TestToContext_QueryValues(t *testing.T) {
+ testConf := ContextConfig{
+ QueryValues: url.Values{"t": []string{"2006-01-02"}},
+ }
+ c := testConf.ToContext(t)
+
+ v, err := echo.QueryParam[string](c, "t")
+
+ assert.NoError(t, err)
+ assert.Equal(t, "2006-01-02", v)
+}
+
+func TestToContext_Headers(t *testing.T) {
+ testConf := ContextConfig{
+ Headers: http.Header{echo.HeaderXRequestID: []string{"ABC"}},
+ }
+ c := testConf.ToContext(t)
+
+ id := c.Request().Header.Get(echo.HeaderXRequestID)
+
+ assert.Equal(t, "ABC", id)
+}
+
+func TestToContext_PathValues(t *testing.T) {
+ testConf := ContextConfig{
+ PathValues: echo.PathValues{{
+ Name: "key",
+ Value: "value",
+ }},
+ }
+ c := testConf.ToContext(t)
+
+ key := c.Param("key")
+
+ assert.Equal(t, "value", key)
+}
+
+func TestToContext_RouteInfo(t *testing.T) {
+ testConf := ContextConfig{
+ RouteInfo: &echo.RouteInfo{
+ Name: "my_route",
+ Method: http.MethodGet,
+ Path: "/:id",
+ Parameters: []string{"id"},
+ },
+ }
+ c := testConf.ToContext(t)
+
+ ri := c.RouteInfo()
+
+ assert.Equal(t, echo.RouteInfo{
+ Name: "my_route",
+ Method: http.MethodGet,
+ Path: "/:id",
+ Parameters: []string{"id"},
+ }, ri)
+}
+
+func TestToContext_FormValues(t *testing.T) {
+ testConf := ContextConfig{
+ FormValues: url.Values{"key": []string{"value"}},
+ }
+ c := testConf.ToContext(t)
+
+ assert.Equal(t, "value", c.FormValue("key"))
+ assert.Equal(t, http.MethodPost, c.Request().Method)
+ assert.Equal(t, echo.MIMEApplicationForm, c.Request().Header.Get(echo.HeaderContentType))
+}
+
+func TestToContext_MultipartForm(t *testing.T) {
+ testConf := ContextConfig{
+ MultipartForm: &MultipartForm{
+ Fields: map[string]string{
+ "key": "value",
+ },
+ Files: []MultipartFormFile{
+ {
+ Fieldname: "file",
+ Filename: "test.json",
+ Content: LoadBytes(t, "testdata/test.json"),
+ },
+ },
+ },
+ }
+ c := testConf.ToContext(t)
+
+ assert.Equal(t, "value", c.FormValue("key"))
+ assert.Equal(t, http.MethodPost, c.Request().Method)
+ assert.Equal(t, true, strings.HasPrefix(c.Request().Header.Get(echo.HeaderContentType), "multipart/form-data; boundary="))
+
+ fv, err := c.FormFile("file")
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.Equal(t, "test.json", fv.Filename)
+ assert.Equal(t, int64(23), fv.Size)
+}
+
+func TestToContext_JSONBody(t *testing.T) {
+ testConf := ContextConfig{
+ JSONBody: LoadBytes(t, "testdata/test.json"),
+ }
+ c := testConf.ToContext(t)
+
+ payload := struct {
+ Field string `json:"field"`
+ }{}
+ if err := c.Bind(&payload); err != nil {
+ t.Fatal(err)
+ }
+
+ assert.Equal(t, "value", payload.Field)
+ assert.Equal(t, http.MethodPost, c.Request().Method)
+ assert.Equal(t, echo.MIMEApplicationJSON, c.Request().Header.Get(echo.HeaderContentType))
+}
diff --git a/echotest/reader.go b/echotest/reader.go
new file mode 100644
index 000000000..0caceca02
--- /dev/null
+++ b/echotest/reader.go
@@ -0,0 +1,46 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echotest
+
+import (
+ "os"
+ "path/filepath"
+ "runtime"
+ "testing"
+)
+
+type loadBytesOpts func([]byte) []byte
+
+// TrimNewlineEnd instructs LoadBytes to remove `\n` from the end of loaded file.
+func TrimNewlineEnd(bytes []byte) []byte {
+ bLen := len(bytes)
+ if bLen > 1 && bytes[bLen-1] == '\n' {
+ bytes = bytes[:bLen-1]
+ }
+ return bytes
+}
+
+// LoadBytes is helper to load file contents relative to current (where test file is) package
+// directory.
+func LoadBytes(t *testing.T, name string, opts ...loadBytesOpts) []byte {
+ bytes := loadBytes(t, name, 2)
+
+ for _, f := range opts {
+ bytes = f(bytes)
+ }
+
+ return bytes
+}
+
+func loadBytes(t *testing.T, name string, callDepth int) []byte {
+ _, b, _, _ := runtime.Caller(callDepth)
+ basepath := filepath.Dir(b)
+
+ path := filepath.Join(basepath, name) // relative path
+ bytes, err := os.ReadFile(path)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return bytes[:]
+}
diff --git a/echotest/reader_external_test.go b/echotest/reader_external_test.go
new file mode 100644
index 000000000..43fd57416
--- /dev/null
+++ b/echotest/reader_external_test.go
@@ -0,0 +1,25 @@
+package echotest_test
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/labstack/echo/v5/echotest"
+ "github.com/stretchr/testify/assert"
+)
+
+const testJSONContent = `{
+ "field": "value"
+}`
+
+func TestLoadBytesOK(t *testing.T) {
+ data := echotest.LoadBytes(t, "testdata/test.json")
+ assert.Equal(t, []byte(testJSONContent+"\n"), data)
+}
+
+func TestLoadBytes_custom(t *testing.T) {
+ data := echotest.LoadBytes(t, "testdata/test.json", func(bytes []byte) []byte {
+ return []byte(strings.ToUpper(string(bytes)))
+ })
+ assert.Equal(t, []byte(strings.ToUpper(testJSONContent)+"\n"), data)
+}
diff --git a/echotest/reader_test.go b/echotest/reader_test.go
new file mode 100644
index 000000000..23b3c2dd2
--- /dev/null
+++ b/echotest/reader_test.go
@@ -0,0 +1,21 @@
+package echotest
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+const testJSONContent = `{
+ "field": "value"
+}`
+
+func TestLoadBytesOK(t *testing.T) {
+ data := LoadBytes(t, "testdata/test.json")
+ assert.Equal(t, []byte(testJSONContent+"\n"), data)
+}
+
+func TestLoadBytesOK_TrimNewlineEnd(t *testing.T) {
+ data := LoadBytes(t, "testdata/test.json", TrimNewlineEnd)
+ assert.Equal(t, []byte(testJSONContent), data)
+}
diff --git a/echotest/testdata/test.json b/echotest/testdata/test.json
new file mode 100644
index 000000000..94ae65f17
--- /dev/null
+++ b/echotest/testdata/test.json
@@ -0,0 +1,3 @@
+{
+ "field": "value"
+}
diff --git a/go.mod b/go.mod
index c5db2ae1a..a2480a285 100644
--- a/go.mod
+++ b/go.mod
@@ -1,17 +1,16 @@
-module github.com/labstack/echo/v4
+module github.com/labstack/echo/v5
-go 1.12
+go 1.25.0
require (
- github.com/dgrijalva/jwt-go v3.2.0+incompatible
- github.com/labstack/echo v3.3.10+incompatible // indirect
- github.com/labstack/gommon v0.3.0
- github.com/mattn/go-colorable v0.1.4 // indirect
- github.com/mattn/go-isatty v0.0.11 // indirect
- github.com/stretchr/testify v1.4.0
- github.com/valyala/fasttemplate v1.1.0
- golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876
- golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 // indirect
- golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8 // indirect
- golang.org/x/text v0.3.2 // indirect
+ github.com/stretchr/testify v1.11.1
+ golang.org/x/net v0.49.0
+ golang.org/x/time v0.14.0
+)
+
+require (
+ github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/pmezard/go-difflib v1.0.0 // indirect
+ golang.org/x/text v0.33.0 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index 57c79877e..f1e80fc13 100644
--- a/go.sum
+++ b/go.sum
@@ -1,62 +1,16 @@
-github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
-github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=
-github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
-github.com/labstack/echo v3.3.10+incompatible h1:pGRcYk231ExFAyoAjAfD85kQzRJCRI8bbnE7CX5OEgg=
-github.com/labstack/echo v3.3.10+incompatible/go.mod h1:0INS7j/VjnFxD4E2wkz67b8cVwCLbBmJyDaka6Cmk1s=
-github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0=
-github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k=
-github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU=
-github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
-github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
-github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
-github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
-github.com/mattn/go-isatty v0.0.9 h1:d5US/mDsogSGW37IV293h//ZFaeajb69h+EHFsv2xGg=
-github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
-github.com/mattn/go-isatty v0.0.10 h1:qxFzApOv4WsAL965uUPIsXzAKCZxN2p9UqdhFS4ZW10=
-github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84=
-github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM=
-github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
-github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
-github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
-github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
-github.com/valyala/fasttemplate v1.0.1 h1:tY9CJiPnMXf1ERmG2EyK7gNUd+c6RKGD0IfU8WdUSz8=
-github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
-github.com/valyala/fasttemplate v1.1.0 h1:RZqt0yGBsps8NGvLSGW804QQqCUYYLsaOjTVHy1Ocw4=
-github.com/valyala/fasttemplate v1.1.0/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8=
-golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 h1:HuIa8hRrWRSrqYzx1qI49NNxhdi2PrY7gxVSq1JjLDc=
-golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
-golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8=
-golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
-golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 h1:sKJQZMuxjOAR/Uo2LBfU90onWEf1dF4C+0hPJCc9Mpc=
-golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ=
-golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
-golang.org/x/net v0.0.0-20191021144547-ec77196f6094 h1:5O4U9trLjNpuhpynaDsqwCk+Tw6seqJz1EbqbnzHrc8=
-golang.org/x/net v0.0.0-20191021144547-ec77196f6094/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 h1:efeOvDhwQ29Dj3SdAV/MJf8oukgn+8D8WgaCaRMchF8=
-golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a h1:aYOabOQFp6Vj6W1F80affTUvO9UxmJRx8K0gsfABByQ=
-golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191024172528-b4ff53e7a1cb h1:ZxSglHghKPYD8WDeRUzRJrUJtDF0PxsTUSxyqr9/5BI=
-golang.org/x/sys v0.0.0-20191024172528-b4ff53e7a1cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8 h1:JA8d3MPx/IToSyXZG/RhwYEtfrKO1Fxrqe8KrkiLXKM=
-golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
-golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
-golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
-golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
+github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
+golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
+golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
+golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
+golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
+golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
+golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
-gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/group.go b/group.go
index 5d9582535..d81cd9163 100644
--- a/group.go
+++ b/group.go
@@ -1,124 +1,172 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package echo
import (
+ "io/fs"
"net/http"
)
-type (
- // Group is a set of sub-routes for a specified route. It can be used for inner
- // routes that share a common middleware or functionality that should be separate
- // from the parent echo instance while still inheriting from it.
- Group struct {
- common
- host string
- prefix string
- middleware []MiddlewareFunc
- echo *Echo
- }
-)
+// Group is a set of sub-routes for a specified route. It can be used for inner
+// routes that share a common middleware or functionality that should be separate
+// from the parent echo instance while still inheriting from it.
+type Group struct {
+ echo *Echo
+ prefix string
+ middleware []MiddlewareFunc
+}
// Use implements `Echo#Use()` for sub-routes within the Group.
+// Group middlewares are not executed on request when there is no matching route found.
func (g *Group) Use(middleware ...MiddlewareFunc) {
g.middleware = append(g.middleware, middleware...)
- if len(g.middleware) == 0 {
- return
- }
- // Allow all requests to reach the group as they might get dropped if router
- // doesn't find a match, making none of the group middleware process.
- g.Any("", NotFoundHandler)
- g.Any("/*", NotFoundHandler)
}
-// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group.
-func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error.
+func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodConnect, path, h, m...)
}
-// DELETE implements `Echo#DELETE()` for sub-routes within the Group.
-func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error.
+func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodDelete, path, h, m...)
}
-// GET implements `Echo#GET()` for sub-routes within the Group.
-func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// GET implements `Echo#GET()` for sub-routes within the Group. Panics on error.
+func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodGet, path, h, m...)
}
-// HEAD implements `Echo#HEAD()` for sub-routes within the Group.
-func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error.
+func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodHead, path, h, m...)
}
-// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group.
-func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error.
+func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodOptions, path, h, m...)
}
-// PATCH implements `Echo#PATCH()` for sub-routes within the Group.
-func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error.
+func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPatch, path, h, m...)
}
-// POST implements `Echo#POST()` for sub-routes within the Group.
-func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// POST implements `Echo#POST()` for sub-routes within the Group. Panics on error.
+func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPost, path, h, m...)
}
-// PUT implements `Echo#PUT()` for sub-routes within the Group.
-func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error.
+func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodPut, path, h, m...)
}
-// TRACE implements `Echo#TRACE()` for sub-routes within the Group.
-func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route {
+// TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error.
+func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
return g.Add(http.MethodTrace, path, h, m...)
}
-// Any implements `Echo#Any()` for sub-routes within the Group.
-func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
- routes := make([]*Route, len(methods))
- for i, m := range methods {
- routes[i] = g.Add(m, path, handler, middleware...)
+// Any implements `Echo#Any()` for sub-routes within the Group. Panics on error.
+func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
+ return g.Add(RouteAny, path, handler, middleware...)
+}
+
+// Match implements `Echo#Match()` for sub-routes within the Group. Panics on error.
+func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
+ errs := make([]error, 0)
+ ris := make(Routes, 0)
+ for _, m := range methods {
+ ri, err := g.AddRoute(Route{
+ Method: m,
+ Path: path,
+ Handler: handler,
+ Middlewares: middleware,
+ })
+ if err != nil {
+ errs = append(errs, err)
+ continue
+ }
+ ris = append(ris, ri)
}
- return routes
-}
-
-// Match implements `Echo#Match()` for sub-routes within the Group.
-func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route {
- routes := make([]*Route, len(methods))
- for i, m := range methods {
- routes[i] = g.Add(m, path, handler, middleware...)
+ if len(errs) > 0 {
+ panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
}
- return routes
+ return ris
}
// Group creates a new sub-group with prefix and optional sub-group-level middleware.
+// Important! Group middlewares are only executed in case there was exact route match and not
+// for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add
+// a catch-all route `/*` for the group which handler returns always 404
func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
m = append(m, g.middleware...)
m = append(m, middleware...)
sg = g.echo.Group(g.prefix+prefix, m...)
- sg.host = g.host
return
}
// Static implements `Echo#Static()` for sub-routes within the Group.
-func (g *Group) Static(prefix, root string) {
- g.static(prefix, root, g.GET)
-}
-
-// File implements `Echo#File()` for sub-routes within the Group.
-func (g *Group) File(path, file string) {
- g.file(g.prefix+path, file, g.GET)
+func (g *Group) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo {
+ subFs := MustSubFS(g.echo.Filesystem, fsRoot)
+ return g.StaticFS(pathPrefix, subFs, middleware...)
+}
+
+// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group.
+//
+// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
+// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
+// including `assets/images` as their prefix.
+func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo {
+ return g.Add(
+ http.MethodGet,
+ pathPrefix+"*",
+ StaticDirectoryHandler(filesystem, false),
+ middleware...,
+ )
+}
+
+// FileFS implements `Echo#FileFS()` for sub-routes within the Group.
+func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo {
+ return g.GET(path, StaticFileHandler(file, filesystem), m...)
+}
+
+// File implements `Echo#File()` for sub-routes within the Group. Panics on error.
+func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo {
+ handler := func(c *Context) error {
+ return c.File(file)
+ }
+ return g.Add(http.MethodGet, path, handler, middleware...)
+}
+
+// RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group.
+//
+// Example: `g.RouteNotFound("/*", func(c *echo.Context) error { return c.NoContent(http.StatusNotFound) })`
+func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return g.Add(RouteNotFound, path, h, m...)
+}
+
+// Add implements `Echo#Add()` for sub-routes within the Group. Panics on error.
+func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
+ ri, err := g.AddRoute(Route{
+ Method: method,
+ Path: path,
+ Handler: handler,
+ Middlewares: middleware,
+ })
+ if err != nil {
+ panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
+ }
+ return ri
}
-// Add implements `Echo#Add()` for sub-routes within the Group.
-func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route {
- // Combine into a new slice to avoid accidentally passing the same slice for
+// AddRoute registers a new Routable with Router
+func (g *Group) AddRoute(route Route) (RouteInfo, error) {
+ // Combine middleware into a new slice to avoid accidentally passing the same slice for
// multiple routes, which would lead to later add() calls overwriting the
// middleware from earlier calls.
- m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
- m = append(m, g.middleware...)
- m = append(m, middleware...)
- return g.echo.add(g.host, method, g.prefix+path, handler, m...)
+ groupRoute := route.WithPrefix(g.prefix, append([]MiddlewareFunc{}, g.middleware...))
+ return g.echo.add(groupRoute)
}
diff --git a/group_test.go b/group_test.go
index 342cd29e2..7078b6497 100644
--- a/group_test.go
+++ b/group_test.go
@@ -1,58 +1,115 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package echo
import (
+ "io/fs"
"net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
"testing"
"github.com/stretchr/testify/assert"
)
-// TODO: Fix me
-func TestGroup(t *testing.T) {
- g := New().Group("/group")
- h := func(Context) error { return nil }
- g.CONNECT("/", h)
- g.DELETE("/", h)
- g.GET("/", h)
- g.HEAD("/", h)
- g.OPTIONS("/", h)
- g.PATCH("/", h)
- g.POST("/", h)
- g.PUT("/", h)
- g.TRACE("/", h)
- g.Any("/", h)
- g.Match([]string{http.MethodGet, http.MethodPost}, "/", h)
- g.Static("/static", "/tmp")
- g.File("/walle", "_fixture/images//walle.png")
+func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) {
+ e := New()
+
+ called := false
+ mw := func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ called = true
+ return c.NoContent(http.StatusTeapot)
+ }
+ }
+ // even though group has middleware it will not be executed when there are no routes under that group
+ _ = e.Group("/group", mw)
+
+ status, body := request(http.MethodGet, "/group/nope", e)
+ assert.Equal(t, http.StatusNotFound, status)
+ assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
+
+ assert.False(t, called)
+}
+
+func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) {
+ e := New()
+
+ called := false
+ mw := func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ called = true
+ return c.NoContent(http.StatusTeapot)
+ }
+ }
+ // even though group has middleware and routes when we have no match on some route the middlewares for that
+ // group will not be executed
+ g := e.Group("/group", mw)
+ g.GET("/yes", handlerFunc)
+
+ status, body := request(http.MethodGet, "/group/nope", e)
+ assert.Equal(t, http.StatusNotFound, status)
+ assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
+
+ assert.False(t, called)
+}
+
+func TestGroup_multiLevelGroup(t *testing.T) {
+ e := New()
+
+ api := e.Group("/api")
+ users := api.Group("/users")
+ users.GET("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ status, body := request(http.MethodGet, "/api/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroupFile(t *testing.T) {
+ e := New()
+ g := e.Group("/group")
+ g.File("/walle", "_fixture/images/walle.png")
+ expectedData, err := os.ReadFile("_fixture/images/walle.png")
+ assert.Nil(t, err)
+ req := httptest.NewRequest(http.MethodGet, "/group/walle", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, expectedData, rec.Body.Bytes())
}
func TestGroupRouteMiddleware(t *testing.T) {
// Ensure middleware slices are not re-used
e := New()
g := e.Group("/group")
- h := func(Context) error { return nil }
+ h := func(*Context) error { return nil }
m1 := func(next HandlerFunc) HandlerFunc {
- return func(c Context) error {
+ return func(c *Context) error {
return next(c)
}
}
m2 := func(next HandlerFunc) HandlerFunc {
- return func(c Context) error {
+ return func(c *Context) error {
return next(c)
}
}
m3 := func(next HandlerFunc) HandlerFunc {
- return func(c Context) error {
+ return func(c *Context) error {
return next(c)
}
}
m4 := func(next HandlerFunc) HandlerFunc {
- return func(c Context) error {
+ return func(c *Context) error {
return c.NoContent(404)
}
}
m5 := func(next HandlerFunc) HandlerFunc {
- return func(c Context) error {
+ return func(c *Context) error {
return c.NoContent(405)
}
}
@@ -71,17 +128,17 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
e := New()
g := e.Group("/group")
m1 := func(next HandlerFunc) HandlerFunc {
- return func(c Context) error {
+ return func(c *Context) error {
return next(c)
}
}
m2 := func(next HandlerFunc) HandlerFunc {
- return func(c Context) error {
- return c.String(http.StatusOK, c.Path())
+ return func(c *Context) error {
+ return c.String(http.StatusOK, c.RouteInfo().Path)
}
}
- h := func(c Context) error {
- return c.String(http.StatusOK, c.Path())
+ h := func(c *Context) error {
+ return c.String(http.StatusOK, c.RouteInfo().Path)
}
g.Use(m1)
g.GET("/help", h, m2)
@@ -104,3 +161,654 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
assert.Equal(t, "/*", m)
}
+
+func TestGroup_CONNECT(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.CONNECT("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodConnect, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodConnect, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_DELETE(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.DELETE("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodDelete, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodDelete, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_HEAD(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.HEAD("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodHead, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, http.MethodHead+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodHead, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_OPTIONS(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.OPTIONS("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodOptions, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodOptions, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_PATCH(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.PATCH("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPatch, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodPatch, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_POST(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.POST("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPost, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, http.MethodPost+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodPost, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_PUT(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.PUT("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPut, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, http.MethodPut+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodPut, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_TRACE(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.TRACE("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodTrace, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodTrace, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_RouteNotFound(t *testing.T) {
+ var testCases = []struct {
+ expectRoute any
+ name string
+ whenURL string
+ expectCode int
+ }{
+ {
+ name: "404, route to static not found handler /group/a/c/xx",
+ whenURL: "/group/a/c/xx",
+ expectRoute: "GET /group/a/c/xx",
+ expectCode: http.StatusNotFound,
+ },
+ {
+ name: "404, route to path param not found handler /group/a/:file",
+ whenURL: "/group/a/echo.exe",
+ expectRoute: "GET /group/a/:file",
+ expectCode: http.StatusNotFound,
+ },
+ {
+ name: "404, route to any not found handler /group/*",
+ whenURL: "/group/b/echo.exe",
+ expectRoute: "GET /group/*",
+ expectCode: http.StatusNotFound,
+ },
+ {
+ name: "200, route /group/a/c/df to /group/a/c/df",
+ whenURL: "/group/a/c/df",
+ expectRoute: "GET /group/a/c/df",
+ expectCode: http.StatusOK,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ g := e.Group("/group")
+
+ okHandler := func(c *Context) error {
+ return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
+ }
+ notFoundHandler := func(c *Context) error {
+ return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path())
+ }
+
+ g.GET("/", okHandler)
+ g.GET("/a/c/df", okHandler)
+ g.GET("/a/b*", okHandler)
+ g.PUT("/*", okHandler)
+
+ g.RouteNotFound("/a/c/xx", notFoundHandler) // static
+ g.RouteNotFound("/a/:file", notFoundHandler) // param
+ g.RouteNotFound("/*", notFoundHandler) // any
+
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectCode, rec.Code)
+ assert.Equal(t, tc.expectRoute, rec.Body.String())
+ })
+ }
+}
+
+func TestGroup_Any(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.Any("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK from ANY")
+ })
+
+ assert.Equal(t, RouteAny, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, RouteAny+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodTrace, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK from ANY`, body)
+}
+
+func TestGroup_Match(t *testing.T) {
+ e := New()
+
+ myMethods := []string{http.MethodGet, http.MethodPost}
+ users := e.Group("/users")
+ ris := users.Match(myMethods, "/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+ assert.Len(t, ris, 2)
+
+ for _, m := range myMethods {
+ status, body := request(m, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+ }
+}
+
+func TestGroup_MatchWithErrors(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ users.GET("/activate", func(c *Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+ myMethods := []string{http.MethodGet, http.MethodPost}
+
+ errs := func() (errs []error) {
+ defer func() {
+ if r := recover(); r != nil {
+ if tmpErr, ok := r.([]error); ok {
+ errs = tmpErr
+ return
+ }
+ panic(r)
+ }
+ }()
+
+ users.Match(myMethods, "/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+ return nil
+ }()
+ assert.Len(t, errs, 1)
+ assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed")
+
+ for _, m := range myMethods {
+ status, body := request(m, "/users/activate", e)
+
+ expect := http.StatusTeapot
+ if m == http.MethodGet {
+ expect = http.StatusOK
+ }
+ assert.Equal(t, expect, status)
+ assert.Equal(t, `OK`, body)
+ }
+}
+
+func TestGroup_Static(t *testing.T) {
+ e := New()
+
+ g := e.Group("/books")
+ ri := g.Static("/download", "_fixture")
+ assert.Equal(t, http.MethodGet, ri.Method)
+ assert.Equal(t, "/books/download*", ri.Path)
+ assert.Equal(t, "GET:/books/download*", ri.Name)
+ assert.Equal(t, []string{"*"}, ri.Parameters)
+
+ req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+ body := rec.Body.String()
+ assert.True(t, strings.HasPrefix(body, ""))
+}
+
+func TestGroup_StaticMultiTest(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenPrefix string
+ givenRoot string
+ whenURL string
+ expectHeaderLocation string
+ expectBodyStartsWith string
+ expectBodyNotContains string
+ expectStatus int
+ }{
+ {
+ name: "ok",
+ givenPrefix: "/images",
+ givenRoot: "_fixture/images",
+ whenURL: "/test/images/walle.png",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
+ },
+ {
+ name: "ok, without prefix",
+ givenPrefix: "",
+ givenRoot: "_fixture/images",
+ whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png`
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
+ },
+ {
+ name: "nok, without prefix does not serve dir index",
+ givenPrefix: "",
+ givenRoot: "_fixture/images",
+ whenURL: "/test/", // `/test` + `*` creates route `/test*`
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "No file",
+ givenPrefix: "/images",
+ givenRoot: "_fixture/scripts",
+ whenURL: "/test/images/bolt.png",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Directory",
+ givenPrefix: "/images",
+ givenRoot: "_fixture/images",
+ whenURL: "/test/images/",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Directory Redirect",
+ givenPrefix: "/",
+ givenRoot: "_fixture",
+ whenURL: "/test/folder",
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/test/folder/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Directory Redirect with non-root path",
+ givenPrefix: "/static",
+ givenRoot: "_fixture",
+ whenURL: "/test/static",
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/test/static/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory 404 (request URL without slash)",
+ givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
+ givenRoot: "_fixture",
+ whenURL: "/test/folder", // no trailing slash
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Prefixed directory redirect (without slash redirect to slash)",
+ givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
+ givenRoot: "_fixture",
+ whenURL: "/test/folder", // no trailing slash
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/test/folder/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Directory with index.html",
+ givenPrefix: "/",
+ givenRoot: "_fixture",
+ whenURL: "/test/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory with index.html (prefix ending with slash)",
+ givenPrefix: "/assets/",
+ givenRoot: "_fixture",
+ whenURL: "/test/assets/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory with index.html (prefix ending without slash)",
+ givenPrefix: "/assets",
+ givenRoot: "_fixture",
+ whenURL: "/test/assets/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Sub-directory with index.html",
+ givenPrefix: "/",
+ givenRoot: "_fixture",
+ whenURL: "/test/folder/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "nok, URL encoded path traversal (single encoding, slash - unix separator)",
+ givenRoot: "_fixture/dist/public",
+ whenURL: "/%2e%2e%2fprivate.txt",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ expectBodyNotContains: `private file`,
+ },
+ {
+ name: "nok, URL encoded path traversal (single encoding, backslash - windows separator)",
+ givenRoot: "_fixture/dist/public",
+ whenURL: "/%2e%2e%5cprivate.txt",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ expectBodyNotContains: `private file`,
+ },
+ {
+ name: "do not allow directory traversal (backslash - windows separator)",
+ givenPrefix: "/",
+ givenRoot: "_fixture/",
+ whenURL: `/test/..\\middleware/basic_auth.go`,
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "do not allow directory traversal (slash - unix separator)",
+ givenPrefix: "/",
+ givenRoot: "_fixture/",
+ whenURL: `/test/../middleware/basic_auth.go`,
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+
+ g := e.Group("/test")
+ g.Static(tc.givenPrefix, tc.givenRoot)
+
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ body := rec.Body.String()
+ if tc.expectBodyStartsWith != "" {
+ assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
+ } else {
+ assert.Equal(t, "", body)
+ }
+ if tc.expectBodyNotContains != "" {
+ assert.NotContains(t, body, tc.expectBodyNotContains)
+ }
+
+ if tc.expectHeaderLocation != "" {
+ assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0])
+ } else {
+ _, ok := rec.Result().Header["Location"]
+ assert.False(t, ok)
+ }
+ })
+ }
+}
+
+func TestGroup_FileFS(t *testing.T) {
+ var testCases = []struct {
+ whenFS fs.FS
+ name string
+ whenPath string
+ whenFile string
+ givenURL string
+ expectStartsWith []byte
+ expectCode int
+ }{
+ {
+ name: "ok",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "walle.png",
+ givenURL: "/assets/walle",
+ expectCode: http.StatusOK,
+ expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ },
+ {
+ name: "nok, requesting invalid path",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "walle.png",
+ givenURL: "/assets/walle.png",
+ expectCode: http.StatusNotFound,
+ expectStartsWith: []byte(`{"message":"Not Found"}`),
+ },
+ {
+ name: "nok, serving not existent file from filesystem",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "not-existent.png",
+ givenURL: "/assets/walle",
+ expectCode: http.StatusNotFound,
+ expectStartsWith: []byte(`{"message":"Not Found"}`),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ g := e.Group("/assets")
+ g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS)
+
+ req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectCode, rec.Code)
+
+ body := rec.Body.Bytes()
+ if len(body) > len(tc.expectStartsWith) {
+ body = body[:len(tc.expectStartsWith)]
+ }
+ assert.Equal(t, tc.expectStartsWith, body)
+ })
+ }
+}
+
+func TestGroup_StaticPanic(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenRoot string
+ }{
+ {
+ name: "panics for ../",
+ givenRoot: "../images",
+ },
+ {
+ name: "panics for /",
+ givenRoot: "/images",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ e.Filesystem = os.DirFS("./")
+
+ g := e.Group("/assets")
+
+ assert.Panics(t, func() {
+ g.Static("/images", tc.givenRoot)
+ })
+ })
+ }
+}
+
+func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) {
+ var testCases = []struct {
+ expectBody any
+ name string
+ whenURL string
+ expectCode int
+ givenCustom404 bool
+ expectMiddlewareCalled bool
+ }{
+ {
+ name: "ok, custom 404 handler is called with middleware",
+ givenCustom404: true,
+ whenURL: "/group/test3",
+ expectBody: "404 GET /group/*",
+ expectCode: http.StatusNotFound,
+ expectMiddlewareCalled: true, // because RouteNotFound is added after middleware is added
+ },
+ {
+ name: "ok, default group 404 handler is not called with middleware",
+ givenCustom404: false,
+ whenURL: "/group/test3",
+ expectBody: "404 GET /*",
+ expectCode: http.StatusNotFound,
+ expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added
+ },
+ {
+ name: "ok, (no slash) default group 404 handler is called with middleware",
+ givenCustom404: false,
+ whenURL: "/group",
+ expectBody: "404 GET /*",
+ expectCode: http.StatusNotFound,
+ expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+
+ okHandler := func(c *Context) error {
+ return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
+ }
+ notFoundHandler := func(c *Context) error {
+ return c.String(http.StatusNotFound, "404 "+c.Request().Method+" "+c.Path())
+ }
+
+ e := New()
+ e.GET("/test1", okHandler)
+ e.RouteNotFound("/*", notFoundHandler)
+
+ g := e.Group("/group")
+ g.GET("/test1", okHandler)
+
+ middlewareCalled := false
+ g.Use(func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ middlewareCalled = true
+ return next(c)
+ }
+ })
+ if tc.givenCustom404 {
+ g.RouteNotFound("/*", notFoundHandler)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectMiddlewareCalled, middlewareCalled)
+ assert.Equal(t, tc.expectCode, rec.Code)
+ assert.Equal(t, tc.expectBody, rec.Body.String())
+ })
+ }
+}
diff --git a/httperror.go b/httperror.go
new file mode 100644
index 000000000..6e14da3d9
--- /dev/null
+++ b/httperror.go
@@ -0,0 +1,117 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+)
+
+// HTTPStatusCoder is interface that errors can implement to produce status code for HTTP response
+type HTTPStatusCoder interface {
+ StatusCode() int
+}
+
+// StatusCode returns status code from error if it implements HTTPStatusCoder interface.
+// If error does not implement the interface it returns 0.
+func StatusCode(err error) int {
+ var sc HTTPStatusCoder
+ if errors.As(err, &sc) {
+ return sc.StatusCode()
+ }
+ return 0
+}
+
+// Following errors can produce HTTP status code by implementing HTTPStatusCoder interface
+var (
+ ErrBadRequest = &httpError{http.StatusBadRequest} // 400
+ ErrUnauthorized = &httpError{http.StatusUnauthorized} // 401
+ ErrForbidden = &httpError{http.StatusForbidden} // 403
+ ErrNotFound = &httpError{http.StatusNotFound} // 404
+ ErrMethodNotAllowed = &httpError{http.StatusMethodNotAllowed} // 405
+ ErrRequestTimeout = &httpError{http.StatusRequestTimeout} // 408
+ ErrStatusRequestEntityTooLarge = &httpError{http.StatusRequestEntityTooLarge} // 413
+ ErrUnsupportedMediaType = &httpError{http.StatusUnsupportedMediaType} // 415
+ ErrTooManyRequests = &httpError{http.StatusTooManyRequests} // 429
+ ErrInternalServerError = &httpError{http.StatusInternalServerError} // 500
+ ErrBadGateway = &httpError{http.StatusBadGateway} // 502
+ ErrServiceUnavailable = &httpError{http.StatusServiceUnavailable} // 503
+)
+
+// Following errors fall into 500 (InternalServerError) category
+var (
+ ErrValidatorNotRegistered = errors.New("validator not registered")
+ ErrRendererNotRegistered = errors.New("renderer not registered")
+ ErrInvalidRedirectCode = errors.New("invalid redirect status code")
+ ErrCookieNotFound = errors.New("cookie not found")
+ ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
+ ErrInvalidListenerNetwork = errors.New("invalid listener network")
+)
+
+// NewHTTPError creates new instance of HTTPError
+func NewHTTPError(code int, message string) *HTTPError {
+ return &HTTPError{
+ Code: code,
+ Message: message,
+ }
+}
+
+// HTTPError represents an error that occurred while handling a request.
+type HTTPError struct {
+ // Code is status code for HTTP response
+ Code int `json:"-"`
+ Message string `json:"message"`
+ err error
+}
+
+// StatusCode returns status code for HTTP response
+func (he *HTTPError) StatusCode() int {
+ return he.Code
+}
+
+// Error makes it compatible with `error` interface.
+func (he *HTTPError) Error() string {
+ msg := he.Message
+ if msg == "" {
+ msg = http.StatusText(he.Code)
+ }
+ if he.err == nil {
+ return fmt.Sprintf("code=%d, message=%v", he.Code, msg)
+ }
+ return fmt.Sprintf("code=%d, message=%v, err=%v", he.Code, msg, he.err.Error())
+}
+
+// Wrap eturns new HTTPError with given errors wrapped inside
+func (he HTTPError) Wrap(err error) error {
+ return &HTTPError{
+ Code: he.Code,
+ Message: he.Message,
+ err: err,
+ }
+}
+
+func (he *HTTPError) Unwrap() error {
+ return he.err
+}
+
+type httpError struct {
+ code int
+}
+
+func (he httpError) StatusCode() int {
+ return he.code
+}
+
+func (he httpError) Error() string {
+ return http.StatusText(he.code) // does not include status code
+}
+
+func (he httpError) Wrap(err error) error {
+ return &HTTPError{
+ Code: he.code,
+ Message: http.StatusText(he.code),
+ err: err,
+ }
+}
diff --git a/httperror_external_test.go b/httperror_external_test.go
new file mode 100644
index 000000000..91acdca25
--- /dev/null
+++ b/httperror_external_test.go
@@ -0,0 +1,52 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+// run tests as external package to get real feel for API
+package echo_test
+
+import (
+ "encoding/json"
+ "fmt"
+ "github.com/labstack/echo/v5"
+ "net/http"
+ "net/http/httptest"
+)
+
+func ExampleDefaultHTTPErrorHandler() {
+ e := echo.New()
+ e.GET("/api/endpoint", func(c *echo.Context) error {
+ return &apiError{
+ Code: http.StatusBadRequest,
+ Body: map[string]any{"message": "custom error"},
+ }
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/api/endpoint?err=1", nil)
+ resp := httptest.NewRecorder()
+
+ e.ServeHTTP(resp, req)
+
+ fmt.Printf("%d %s", resp.Code, resp.Body.String())
+
+ // Output: 400 {"error":{"message":"custom error"}}
+}
+
+type apiError struct {
+ Code int
+ Body any
+}
+
+func (e *apiError) StatusCode() int {
+ return e.Code
+}
+
+func (e *apiError) MarshalJSON() ([]byte, error) {
+ type body struct {
+ Error any `json:"error"`
+ }
+ return json.Marshal(body{Error: e.Body})
+}
+
+func (e *apiError) Error() string {
+ return http.StatusText(e.Code)
+}
diff --git a/httperror_test.go b/httperror_test.go
new file mode 100644
index 000000000..0a91bbc9c
--- /dev/null
+++ b/httperror_test.go
@@ -0,0 +1,109 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestHTTPError_StatusCode(t *testing.T) {
+ var err error = &HTTPError{Code: http.StatusBadRequest, Message: "my error message"}
+
+ code := 0
+ var sc HTTPStatusCoder
+ if errors.As(err, &sc) {
+ code = sc.StatusCode()
+ }
+ assert.Equal(t, http.StatusBadRequest, code)
+}
+
+func TestHTTPError_Error(t *testing.T) {
+ var testCases = []struct {
+ name string
+ error error
+ expect string
+ }{
+ {
+ name: "ok, without message",
+ error: &HTTPError{Code: http.StatusBadRequest},
+ expect: "code=400, message=Bad Request",
+ },
+ {
+ name: "ok, with message",
+ error: &HTTPError{Code: http.StatusBadRequest, Message: "my error message"},
+ expect: "code=400, message=my error message",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ assert.Equal(t, tc.expect, tc.error.Error())
+ })
+ }
+}
+
+func TestHTTPError_WrapUnwrap(t *testing.T) {
+ err := &HTTPError{Code: http.StatusBadRequest, Message: "bad"}
+ wrapped := err.Wrap(errors.New("my_error")).(*HTTPError)
+
+ err.Code = http.StatusOK
+ err.Message = "changed"
+
+ assert.Equal(t, http.StatusBadRequest, wrapped.Code)
+ assert.Equal(t, "bad", wrapped.Message)
+
+ assert.Equal(t, errors.New("my_error"), wrapped.Unwrap())
+ assert.Equal(t, "code=400, message=bad, err=my_error", wrapped.Error())
+}
+
+func TestNewHTTPError(t *testing.T) {
+ err := NewHTTPError(http.StatusBadRequest, "bad")
+ err2 := &HTTPError{Code: http.StatusBadRequest, Message: "bad"}
+
+ assert.Equal(t, err2, err)
+}
+
+func TestStatusCode(t *testing.T) {
+ var testCases = []struct {
+ name string
+ err error
+ expect int
+ }{
+ {
+ name: "ok, HTTPError",
+ err: &HTTPError{Code: http.StatusNotFound},
+ expect: http.StatusNotFound,
+ },
+ {
+ name: "ok, sentinel error",
+ err: ErrNotFound,
+ expect: http.StatusNotFound,
+ },
+ {
+ name: "ok, wrapped HTTPError",
+ err: fmt.Errorf("wrapped: %w", &HTTPError{Code: http.StatusTeapot}),
+ expect: http.StatusTeapot,
+ },
+ {
+ name: "nok, normal error",
+ err: errors.New("error"),
+ expect: 0,
+ },
+ {
+ name: "nok, nil",
+ err: nil,
+ expect: 0,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ assert.Equal(t, tc.expect, StatusCode(tc.err))
+ })
+ }
+}
diff --git a/ip.go b/ip.go
new file mode 100644
index 000000000..e2b287bfd
--- /dev/null
+++ b/ip.go
@@ -0,0 +1,267 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "net"
+ "net/http"
+ "strings"
+)
+
+/**
+By: https://github.com/tmshn (See: https://github.com/labstack/echo/pull/1478 , https://github.com/labstack/echox/pull/134 )
+Source: https://echo.labstack.com/guide/ip-address/
+
+IP address plays fundamental role in HTTP; it's used for access control, auditing, geo-based access analysis and more.
+Echo provides handy method [`Context#RealIP()`](https://godoc.org/github.com/labstack/echo#Context) for that.
+
+However, it is not trivial to retrieve the _real_ IP address from requests especially when you put L7 proxies before the application.
+In such situation, _real_ IP needs to be relayed on HTTP layer from proxies to your app, but you must not trust HTTP headers unconditionally.
+Otherwise, you might give someone a chance of deceiving you. **A security risk!**
+
+To retrieve IP address reliably/securely, you must let your application be aware of the entire architecture of your infrastructure.
+In Echo, this can be done by configuring `Echo#IPExtractor` appropriately.
+This guides show you why and how.
+
+> Note: if you don't set `Echo#IPExtractor` explicitly, Echo fallback to legacy behavior, which is not a good choice.
+
+Let's start from two questions to know the right direction:
+
+1. Do you put any HTTP (L7) proxy in front of the application?
+ - It includes both cloud solutions (such as AWS ALB or GCP HTTP LB) and OSS ones (such as Nginx, Envoy or Istio ingress gateway).
+2. If yes, what HTTP header do your proxies use to pass client IP to the application?
+
+## Case 1. With no proxy
+
+If you put no proxy (e.g.: directory facing to the internet), all you need to (and have to) see is IP address from network layer.
+Any HTTP header is untrustable because the clients have full control what headers to be set.
+
+In this case, use `echo.ExtractIPDirect()`.
+
+```go
+e.IPExtractor = echo.ExtractIPDirect()
+```
+
+## Case 2. With proxies using `X-Forwarded-For` header
+
+[`X-Forwared-For` (XFF)](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For) is the popular header
+to relay clients' IP addresses.
+At each hop on the proxies, they append the request IP address at the end of the header.
+
+Following example diagram illustrates this behavior.
+
+```text
+┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐
+│ "Origin" │───────────>│ Proxy 1 │───────────>│ Proxy 2 │───────────>│ Your app │
+│ (IP: a) │ │ (IP: b) │ │ (IP: c) │ │ │
+└──────────┘ └──────────┘ └──────────┘ └──────────┘
+
+Case 1.
+XFF: "" "a" "a, b"
+ ~~~~~~
+Case 2.
+XFF: "x" "x, a" "x, a, b"
+ ~~~~~~~~~
+ ↑ What your app will see
+```
+
+In this case, use **first _untrustable_ IP reading from right**. Never use first one reading from left, as it is
+configurable by client. Here "trustable" means "you are sure the IP address belongs to your infrastructure".
+In above example, if `b` and `c` are trustable, the IP address of the client is `a` for both cases, never be `x`.
+
+In Echo, use `ExtractIPFromXFFHeader(...TrustOption)`.
+
+```go
+e.IPExtractor = echo.ExtractIPFromXFFHeader()
+```
+
+By default, it trusts internal IP addresses (loopback, link-local unicast, private-use and unique local address
+from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and
+[RFC4193](https://tools.ietf.org/html/rfc4193)).
+To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s.
+
+E.g.:
+
+```go
+e.IPExtractor = echo.ExtractIPFromXFFHeader(
+ TrustLinkLocal(false),
+ TrustIPRanges(lbIPRange),
+)
+```
+
+- Ref: https://godoc.org/github.com/labstack/echo#TrustOption
+
+## Case 3. With proxies using `X-Real-IP` header
+
+`X-Real-IP` is another HTTP header to relay clients' IP addresses, but it carries only one address unlike XFF.
+
+If your proxies set this header, use `ExtractIPFromRealIPHeader(...TrustOption)`.
+
+```go
+e.IPExtractor = echo.ExtractIPFromRealIPHeader()
+```
+
+Again, it trusts internal IP addresses by default (loopback, link-local unicast, private-use and unique local address
+from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and
+[RFC4193](https://tools.ietf.org/html/rfc4193)).
+To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s.
+
+- Ref: https://godoc.org/github.com/labstack/echo#TrustOption
+
+> **Never forget** to configure the outermost proxy (i.e.; at the edge of your infrastructure) **not to pass through incoming headers**.
+> Otherwise there is a chance of fraud, as it is what clients can control.
+
+## About default behavior
+
+In default behavior, Echo sees all of first XFF header, X-Real-IP header and IP from network layer.
+
+As you might already notice, after reading this article, this is not good.
+Sole reason this is default is just backward compatibility.
+
+## Private IP ranges
+
+See: https://en.wikipedia.org/wiki/Private_network
+
+Private IPv4 address ranges (RFC 1918):
+* 10.0.0.0 – 10.255.255.255 (24-bit block)
+* 172.16.0.0 – 172.31.255.255 (20-bit block)
+* 192.168.0.0 – 192.168.255.255 (16-bit block)
+
+Private IPv6 address ranges:
+* fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA)
+
+*/
+
+type ipChecker struct {
+ trustExtraRanges []*net.IPNet
+ trustLoopback bool
+ trustLinkLocal bool
+ trustPrivateNet bool
+}
+
+// TrustOption is config for which IP address to trust
+type TrustOption func(*ipChecker)
+
+// TrustLoopback configures if you trust loopback address (default: true).
+func TrustLoopback(v bool) TrustOption {
+ return func(c *ipChecker) {
+ c.trustLoopback = v
+ }
+}
+
+// TrustLinkLocal configures if you trust link-local address (default: true).
+func TrustLinkLocal(v bool) TrustOption {
+ return func(c *ipChecker) {
+ c.trustLinkLocal = v
+ }
+}
+
+// TrustPrivateNet configures if you trust private network address (default: true).
+func TrustPrivateNet(v bool) TrustOption {
+ return func(c *ipChecker) {
+ c.trustPrivateNet = v
+ }
+}
+
+// TrustIPRange add trustable IP ranges using CIDR notation.
+func TrustIPRange(ipRange *net.IPNet) TrustOption {
+ return func(c *ipChecker) {
+ c.trustExtraRanges = append(c.trustExtraRanges, ipRange)
+ }
+}
+
+func newIPChecker(configs []TrustOption) *ipChecker {
+ checker := &ipChecker{trustLoopback: true, trustLinkLocal: true, trustPrivateNet: true}
+ for _, configure := range configs {
+ configure(checker)
+ }
+ return checker
+}
+
+func (c *ipChecker) trust(ip net.IP) bool {
+ if c.trustLoopback && ip.IsLoopback() {
+ return true
+ }
+ if c.trustLinkLocal && ip.IsLinkLocalUnicast() {
+ return true
+ }
+ if c.trustPrivateNet && ip.IsPrivate() {
+ return true
+ }
+ for _, trustedRange := range c.trustExtraRanges {
+ if trustedRange.Contains(ip) {
+ return true
+ }
+ }
+ return false
+}
+
+// IPExtractor is a function to extract IP addr from http.Request.
+// Set appropriate one to Echo#IPExtractor.
+// See https://echo.labstack.com/guide/ip-address for more details.
+type IPExtractor func(*http.Request) string
+
+// ExtractIPDirect extracts IP address using actual IP address.
+// Use this if your server faces to internet directory (i.e.: uses no proxy).
+func ExtractIPDirect() IPExtractor {
+ return extractIP
+}
+
+func extractIP(req *http.Request) string {
+ host, _, err := net.SplitHostPort(req.RemoteAddr)
+ if err != nil {
+ if net.ParseIP(req.RemoteAddr) != nil {
+ return req.RemoteAddr
+ }
+ return ""
+ }
+ return host
+}
+
+// ExtractIPFromRealIPHeader extracts IP address using x-real-ip header.
+// Use this if you put proxy which uses this header.
+func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor {
+ checker := newIPChecker(options)
+ return func(req *http.Request) string {
+ realIP := req.Header.Get(HeaderXRealIP)
+ if realIP != "" {
+ realIP = strings.TrimPrefix(realIP, "[")
+ realIP = strings.TrimSuffix(realIP, "]")
+ if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) {
+ return realIP
+ }
+ }
+ return extractIP(req)
+ }
+}
+
+// ExtractIPFromXFFHeader extracts IP address using x-forwarded-for header.
+// Use this if you put proxy which uses this header.
+// This returns nearest untrustable IP. If all IPs are trustable, returns furthest one (i.e.: XFF[0]).
+func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor {
+ checker := newIPChecker(options)
+ return func(req *http.Request) string {
+ directIP := extractIP(req)
+ xffs := req.Header[HeaderXForwardedFor]
+ if len(xffs) == 0 {
+ return directIP
+ }
+ ips := append(strings.Split(strings.Join(xffs, ","), ","), directIP)
+ for i := len(ips) - 1; i >= 0; i-- {
+ ips[i] = strings.TrimSpace(ips[i])
+ ips[i] = strings.TrimPrefix(ips[i], "[")
+ ips[i] = strings.TrimSuffix(ips[i], "]")
+ ip := net.ParseIP(ips[i])
+ if ip == nil {
+ // Unable to parse IP; cannot trust entire records
+ return directIP
+ }
+ if !checker.trust(ip) {
+ return ip.String()
+ }
+ }
+ // All of the IPs are trusted; return first element because it is furthest from server (best effort strategy).
+ return strings.TrimSpace(ips[0])
+ }
+}
diff --git a/ip_test.go b/ip_test.go
new file mode 100644
index 000000000..29bf6afde
--- /dev/null
+++ b/ip_test.go
@@ -0,0 +1,716 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "net"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func mustParseCIDR(s string) *net.IPNet {
+ _, IPNet, err := net.ParseCIDR(s)
+ if err != nil {
+ panic(err)
+ }
+ return IPNet
+}
+
+func TestIPChecker_TrustOption(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenIP string
+ givenOptions []TrustOption
+ expect bool
+ }{
+ {
+ name: "ip is within trust range, trusts additional private IPV6 network",
+ givenOptions: []TrustOption{
+ TrustLoopback(false),
+ TrustLinkLocal(false),
+ TrustPrivateNet(false),
+ // this is private IPv6 ip
+ // CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48
+ // Address: 2001:0db8:0000:0000:0000:0000:0000:0103
+ // Range start: 2001:0db8:0000:0000:0000:0000:0000:0000
+ // Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff
+ TrustIPRange(mustParseCIDR("2001:db8::103/48")),
+ },
+ whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103",
+ expect: true,
+ },
+ {
+ name: "ip is within trust range, trusts additional private IPV6 network",
+ givenOptions: []TrustOption{
+ TrustIPRange(mustParseCIDR("2001:db8::103/48")),
+ },
+ whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103",
+ expect: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ checker := newIPChecker(tc.givenOptions)
+
+ result := checker.trust(net.ParseIP(tc.whenIP))
+ assert.Equal(t, tc.expect, result)
+ })
+ }
+}
+
+func TestTrustIPRange(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenRange string
+ whenIP string
+ expect bool
+ }{
+ {
+ name: "ip is within trust range, IPV6 network range",
+ // CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48
+ // Address: 2001:0db8:0000:0000:0000:0000:0000:0103
+ // Range start: 2001:0db8:0000:0000:0000:0000:0000:0000
+ // Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff
+ givenRange: "2001:db8::103/48",
+ whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103",
+ expect: true,
+ },
+ {
+ name: "ip is outside (upper bounds) of trust range, IPV6 network range",
+ givenRange: "2001:db8::103/48",
+ whenIP: "2001:0db8:0001:0000:0000:0000:0000:0000",
+ expect: false,
+ },
+ {
+ name: "ip is outside (lower bounds) of trust range, IPV6 network range",
+ givenRange: "2001:db8::103/48",
+ whenIP: "2001:0db7:ffff:ffff:ffff:ffff:ffff:ffff",
+ expect: false,
+ },
+ {
+ name: "ip is within trust range, IPV4 network range",
+ // CIDR Notation: 8.8.8.8/24
+ // Address: 8.8.8.8
+ // Range start: 8.8.8.0
+ // Range end: 8.8.8.255
+ givenRange: "8.8.8.0/24",
+ whenIP: "8.8.8.8",
+ expect: true,
+ },
+ {
+ name: "ip is within trust range, IPV4 network range",
+ // CIDR Notation: 8.8.8.8/24
+ // Address: 8.8.8.8
+ // Range start: 8.8.8.0
+ // Range end: 8.8.8.255
+ givenRange: "8.8.8.0/24",
+ whenIP: "8.8.8.8",
+ expect: true,
+ },
+ {
+ name: "ip is outside (upper bounds) of trust range, IPV4 network range",
+ givenRange: "8.8.8.0/24",
+ whenIP: "8.8.9.0",
+ expect: false,
+ },
+ {
+ name: "ip is outside (lower bounds) of trust range, IPV4 network range",
+ givenRange: "8.8.8.0/24",
+ whenIP: "8.8.7.255",
+ expect: false,
+ },
+ {
+ name: "public ip, trust everything in IPV4 network range",
+ givenRange: "0.0.0.0/0",
+ whenIP: "8.8.8.8",
+ expect: true,
+ },
+ {
+ name: "internal ip, trust everything in IPV4 network range",
+ givenRange: "0.0.0.0/0",
+ whenIP: "127.0.10.1",
+ expect: true,
+ },
+ {
+ name: "public ip, trust everything in IPV6 network range",
+ givenRange: "::/0",
+ whenIP: "2a00:1450:4026:805::200e",
+ expect: true,
+ },
+ {
+ name: "internal ip, trust everything in IPV6 network range",
+ givenRange: "::/0",
+ whenIP: "0:0:0:0:0:0:0:1",
+ expect: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ cidr := mustParseCIDR(tc.givenRange)
+ checker := newIPChecker([]TrustOption{
+ TrustLoopback(false), // disable to avoid interference
+ TrustLinkLocal(false), // disable to avoid interference
+ TrustPrivateNet(false), // disable to avoid interference
+
+ TrustIPRange(cidr),
+ })
+
+ result := checker.trust(net.ParseIP(tc.whenIP))
+ assert.Equal(t, tc.expect, result)
+ })
+ }
+}
+
+func TestTrustPrivateNet(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenIP string
+ expect bool
+ }{
+ {
+ name: "do not trust public IPv4 address",
+ whenIP: "8.8.8.8",
+ expect: false,
+ },
+ {
+ name: "do not trust public IPv6 address",
+ whenIP: "2a00:1450:4026:805::200e",
+ expect: false,
+ },
+
+ { // Class A: 10.0.0.0 — 10.255.255.255
+ name: "do not trust IPv4 just outside of class A (lower bounds)",
+ whenIP: "9.255.255.255",
+ expect: false,
+ },
+ {
+ name: "do not trust IPv4 just outside of class A (upper bounds)",
+ whenIP: "11.0.0.0",
+ expect: false,
+ },
+ {
+ name: "trust IPv4 of class A (lower bounds)",
+ whenIP: "10.0.0.0",
+ expect: true,
+ },
+ {
+ name: "trust IPv4 of class A (upper bounds)",
+ whenIP: "10.255.255.255",
+ expect: true,
+ },
+
+ { // Class B: 172.16.0.0 — 172.31.255.255
+ name: "do not trust IPv4 just outside of class B (lower bounds)",
+ whenIP: "172.15.255.255",
+ expect: false,
+ },
+ {
+ name: "do not trust IPv4 just outside of class B (upper bounds)",
+ whenIP: "172.32.0.0",
+ expect: false,
+ },
+ {
+ name: "trust IPv4 of class B (lower bounds)",
+ whenIP: "172.16.0.0",
+ expect: true,
+ },
+ {
+ name: "trust IPv4 of class B (upper bounds)",
+ whenIP: "172.31.255.255",
+ expect: true,
+ },
+
+ { // Class C: 192.168.0.0 — 192.168.255.255
+ name: "do not trust IPv4 just outside of class C (lower bounds)",
+ whenIP: "192.167.255.255",
+ expect: false,
+ },
+ {
+ name: "do not trust IPv4 just outside of class C (upper bounds)",
+ whenIP: "192.169.0.0",
+ expect: false,
+ },
+ {
+ name: "trust IPv4 of class C (lower bounds)",
+ whenIP: "192.168.0.0",
+ expect: true,
+ },
+ {
+ name: "trust IPv4 of class C (upper bounds)",
+ whenIP: "192.168.255.255",
+ expect: true,
+ },
+
+ { // fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA)
+ // splits the address block in two equally sized halves, fc00::/8 and fd00::/8.
+ // https://en.wikipedia.org/wiki/Unique_local_address
+ name: "trust IPv6 private address",
+ whenIP: "fdfc:3514:2cb3:4bd5::",
+ expect: true,
+ },
+ {
+ name: "do not trust IPv6 just out of /fd (upper bounds)",
+ whenIP: "/fe00:0000:0000:0000:0000",
+ expect: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ checker := newIPChecker([]TrustOption{
+ TrustLoopback(false), // disable to avoid interference
+ TrustLinkLocal(false), // disable to avoid interference
+
+ TrustPrivateNet(true),
+ })
+
+ result := checker.trust(net.ParseIP(tc.whenIP))
+ assert.Equal(t, tc.expect, result)
+ })
+ }
+}
+
+func TestTrustLinkLocal(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenIP string
+ expect bool
+ }{
+ {
+ name: "trust link local IPv4 address (lower bounds)",
+ whenIP: "169.254.0.0",
+ expect: true,
+ },
+ {
+ name: "trust link local IPv4 address (upper bounds)",
+ whenIP: "169.254.255.255",
+ expect: true,
+ },
+ {
+ name: "do not trust link local IPv4 address (outside of lower bounds)",
+ whenIP: "169.253.255.255",
+ expect: false,
+ },
+ {
+ name: "do not trust link local IPv4 address (outside of upper bounds)",
+ whenIP: "169.255.0.0",
+ expect: false,
+ },
+ {
+ name: "trust link local IPv6 address ",
+ whenIP: "fe80::1",
+ expect: true,
+ },
+ {
+ name: "do not trust link local IPv6 address ",
+ whenIP: "fec0::1",
+ expect: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ checker := newIPChecker([]TrustOption{
+ TrustLoopback(false), // disable to avoid interference
+ TrustPrivateNet(false), // disable to avoid interference
+
+ TrustLinkLocal(true),
+ })
+
+ result := checker.trust(net.ParseIP(tc.whenIP))
+ assert.Equal(t, tc.expect, result)
+ })
+ }
+}
+
+func TestTrustLoopback(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenIP string
+ expect bool
+ }{
+ {
+ name: "trust IPv4 as localhost",
+ whenIP: "127.0.0.1",
+ expect: true,
+ },
+ {
+ name: "trust IPv6 as localhost",
+ whenIP: "::1",
+ expect: true,
+ },
+ {
+ name: "do not trust public ip as localhost",
+ whenIP: "8.8.8.8",
+ expect: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ checker := newIPChecker([]TrustOption{
+ TrustLinkLocal(false), // disable to avoid interference
+ TrustPrivateNet(false), // disable to avoid interference
+
+ TrustLoopback(true),
+ })
+
+ result := checker.trust(net.ParseIP(tc.whenIP))
+ assert.Equal(t, tc.expect, result)
+ })
+ }
+}
+
+func TestExtractIPDirect(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenRequest http.Request
+ expectIP string
+ }{
+ {
+ name: "request has no headers, extracts IP from request remote addr",
+ whenRequest: http.Request{
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "remote addr is IP without port, extracts IP directly",
+ whenRequest: http.Request{
+ RemoteAddr: "203.0.113.1",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "remote addr is IPv6 without port, extracts IP directly",
+ whenRequest: http.Request{
+ RemoteAddr: "2001:db8::1",
+ },
+ expectIP: "2001:db8::1",
+ },
+ {
+ name: "remote addr is IPv6 with port",
+ whenRequest: http.Request{
+ RemoteAddr: "[2001:db8::1]:8080",
+ },
+ expectIP: "2001:db8::1",
+ },
+ {
+ name: "remote addr is invalid, returns empty string",
+ whenRequest: http.Request{
+ RemoteAddr: "invalid-ip-format",
+ },
+ expectIP: "",
+ },
+ {
+ name: "request is from external IP has X-Real-Ip header, extractor still extracts IP from request remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"203.0.113.10"},
+ },
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "request is from internal IP and has Real-IP header, extractor still extracts internal IP from request remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"203.0.113.10"},
+ },
+ RemoteAddr: "127.0.0.1:8080",
+ },
+ expectIP: "127.0.0.1",
+ },
+ {
+ name: "request is from external IP and has XFF + Real-IP header, extractor still extracts external IP from request remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"203.0.113.10"},
+ HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
+ },
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "request is from internal IP and has XFF + Real-IP header, extractor still extracts internal IP from request remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"127.0.0.1"},
+ HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
+ },
+ RemoteAddr: "127.0.0.1:8080",
+ },
+ expectIP: "127.0.0.1",
+ },
+ {
+ name: "request is from external IP and has XFF header, extractor still extracts external IP from request remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
+ },
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "request is from internal IP and has XFF header, extractor still extracts internal IP from request remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
+ },
+ RemoteAddr: "127.0.0.1:8080",
+ },
+ expectIP: "127.0.0.1",
+ },
+ {
+ name: "request is from internal IP and has INVALID XFF header, extractor still extracts internal IP from request remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"this.is.broken.lol, 169.254.0.101"},
+ },
+ RemoteAddr: "127.0.0.1:8080",
+ },
+ expectIP: "127.0.0.1",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ extractedIP := ExtractIPDirect()(&tc.whenRequest)
+ assert.Equal(t, tc.expectIP, extractedIP)
+ })
+ }
+}
+
+func TestExtractIPFromRealIPHeader(t *testing.T) {
+ _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24")
+ _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64")
+
+ var testCases = []struct {
+ whenRequest http.Request
+ name string
+ expectIP string
+ givenTrustOptions []TrustOption
+ }{
+ {
+ name: "request has no headers, extracts IP from request remote addr",
+ whenRequest: http.Request{
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "request is from external IP has INVALID external X-Real-Ip header, extract IP from remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"xxx.yyy.zzz.ccc"}, // <-- this is invalid
+ },
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"203.0.113.199"}, // <-- this is untrusted
+ },
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"[2001:db8::113:199]"}, // <-- this is untrusted
+ },
+ RemoteAddr: "[2001:db8::113:1]:8080",
+ },
+ expectIP: "2001:db8::113:1",
+ },
+ {
+ name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
+ givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
+ TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
+ },
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"203.0.113.199"},
+ },
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.199",
+ },
+ {
+ name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
+ givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
+ TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64"
+ },
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"[2001:db8::113:199]"},
+ },
+ RemoteAddr: "[2001:db8::113:1]:8080",
+ },
+ expectIP: "2001:db8::113:199",
+ },
+ {
+ name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
+ givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
+ TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
+ },
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"203.0.113.199"},
+ HeaderXForwardedFor: []string{"203.0.113.198, 203.0.113.197"}, // <-- should not affect anything
+ },
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.199",
+ },
+ {
+ name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
+ givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
+ TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64"
+ },
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"[2001:db8::113:199]"},
+ HeaderXForwardedFor: []string{"[2001:db8::113:198], [2001:db8::113:197]"}, // <-- should not affect anything
+ },
+ RemoteAddr: "[2001:db8::113:1]:8080",
+ },
+ expectIP: "2001:db8::113:199",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ extractedIP := ExtractIPFromRealIPHeader(tc.givenTrustOptions...)(&tc.whenRequest)
+ assert.Equal(t, tc.expectIP, extractedIP)
+ })
+ }
+}
+
+func TestExtractIPFromXFFHeader(t *testing.T) {
+ _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24")
+ _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64")
+
+ var testCases = []struct {
+ whenRequest http.Request
+ name string
+ expectIP string
+ givenTrustOptions []TrustOption
+ }{
+ {
+ name: "request has no headers, extracts IP from request remote addr",
+ whenRequest: http.Request{
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "request has INVALID external XFF header, extract IP from remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"xxx.yyy.zzz.ccc, 127.0.0.2"}, // <-- this is invalid
+ },
+ RemoteAddr: "127.0.0.1:8080",
+ },
+ expectIP: "127.0.0.1",
+ },
+ {
+ name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"127.0.0.3, 127.0.0.2, 127.0.0.1"},
+ },
+ RemoteAddr: "127.0.0.1:8080",
+ },
+ expectIP: "127.0.0.3",
+ },
+ {
+ name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"[fe80::3], [fe80::2], [fe80::1]"},
+ },
+ RemoteAddr: "[fe80::1]:8080",
+ },
+ expectIP: "fe80::3",
+ },
+ {
+ name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"203.0.113.199"}, // <-- this is untrusted
+ },
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"[2001:db8::1]"}, // <-- this is untrusted
+ },
+ RemoteAddr: "[2001:db8::2]:8080",
+ },
+ expectIP: "2001:db8::2",
+ },
+ {
+ name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header",
+ givenTrustOptions: []TrustOption{
+ TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
+ },
+ // from request its seems that request has been proxied through 6 servers.
+ // 1) 203.0.1.100 (this is external IP set by 203.0.100.100 which we do not trust - could be spoofed)
+ // 2) 203.0.100.100 (this is outside of our network but set by 203.0.113.199 which we trust to set correct IPs)
+ // 3) 203.0.113.199 (we trust, for example maybe our proxy from some other office)
+ // 4) 192.168.1.100 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products)
+ // 5) 127.0.0.1 (is proxy on localhost. maybe we have Nginx in front of our Echo instance doing some routing)
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"203.0.1.100, 203.0.100.100, 203.0.113.199, 192.168.1.100"},
+ },
+ RemoteAddr: "127.0.0.1:8080", // IP of proxy upstream of our APP
+ },
+ expectIP: "203.0.100.100", // this is first trusted IP in XFF chain
+ },
+ {
+ name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header",
+ givenTrustOptions: []TrustOption{
+ TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64"
+ },
+ // from request its seems that request has been proxied through 6 servers.
+ // 1) 2001:db8:1::1:100 (this is external IP set by 2001:db8:2::100:100 which we do not trust - could be spoofed)
+ // 2) 2001:db8:2::100:100 (this is outside of our network but set by 2001:db8::113:199 which we trust to set correct IPs)
+ // 3) 2001:db8::113:199 (we trust, for example maybe our proxy from some other office)
+ // 4) fd12:3456:789a:1::1 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products)
+ // 5) fe80::1 (is proxy on localhost. maybe we have Nginx in front of our Echo instance doing some routing)
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"[2001:db8:1::1:100], [2001:db8:2::100:100], [2001:db8::113:199], [fd12:3456:789a:1::1]"},
+ },
+ RemoteAddr: "[fe80::1]:8080", // IP of proxy upstream of our APP
+ },
+ expectIP: "2001:db8:2::100:100", // this is first trusted IP in XFF chain
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ extractedIP := ExtractIPFromXFFHeader(tc.givenTrustOptions...)(&tc.whenRequest)
+ assert.Equal(t, tc.expectIP, extractedIP)
+ })
+ }
+}
diff --git a/json.go b/json.go
new file mode 100644
index 000000000..a969ccb8c
--- /dev/null
+++ b/json.go
@@ -0,0 +1,29 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "encoding/json"
+)
+
+// DefaultJSONSerializer implements JSON encoding using encoding/json.
+type DefaultJSONSerializer struct{}
+
+// Serialize converts an interface into a json and writes it to the response.
+// You can optionally use the indent parameter to produce pretty JSONs.
+func (d DefaultJSONSerializer) Serialize(c *Context, target any, indent string) error {
+ enc := json.NewEncoder(c.Response())
+ if indent != "" {
+ enc.SetIndent("", indent)
+ }
+ return enc.Encode(target)
+}
+
+// Deserialize reads a JSON from a request body and converts it into an interface.
+func (d DefaultJSONSerializer) Deserialize(c *Context, target any) error {
+ if err := json.NewDecoder(c.Request().Body).Decode(target); err != nil {
+ return ErrBadRequest.Wrap(err)
+ }
+ return nil
+}
diff --git a/json_test.go b/json_test.go
new file mode 100644
index 000000000..1804b3e82
--- /dev/null
+++ b/json_test.go
@@ -0,0 +1,100 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "github.com/stretchr/testify/assert"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+// Note this test is deliberately simple as there's not a lot to test.
+// Just need to ensure it writes JSONs. The heavy work is done by the context methods.
+func TestDefaultJSONCodec_Encode(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ // Echo
+ assert.Equal(t, e, c.Echo())
+
+ // Request
+ assert.NotNil(t, c.Request())
+
+ // Response
+ assert.NotNil(t, c.Response())
+
+ //--------
+ // Default JSON encoder
+ //--------
+
+ enc := new(DefaultJSONSerializer)
+
+ err := enc.Serialize(c, user{ID: 1, Name: "Jon Snow"}, "")
+ if assert.NoError(t, err) {
+ assert.Equal(t, userJSON+"\n", rec.Body.String())
+ }
+
+ req = httptest.NewRequest(http.MethodPost, "/", nil)
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+ err = enc.Serialize(c, user{ID: 1, Name: "Jon Snow"}, " ")
+ if assert.NoError(t, err) {
+ assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
+ }
+}
+
+// Note this test is deliberately simple as there's not a lot to test.
+// Just need to ensure it writes JSONs. The heavy work is done by the context methods.
+func TestDefaultJSONCodec_Decode(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ // Echo
+ assert.Equal(t, e, c.Echo())
+
+ // Request
+ assert.NotNil(t, c.Request())
+
+ // Response
+ assert.NotNil(t, c.Response())
+
+ //--------
+ // Default JSON encoder
+ //--------
+
+ enc := new(DefaultJSONSerializer)
+
+ var u = user{}
+ err := enc.Deserialize(c, &u)
+ if assert.NoError(t, err) {
+ assert.Equal(t, u, user{ID: 1, Name: "Jon Snow"})
+ }
+
+ var userUnmarshalSyntaxError = user{}
+ req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(invalidContent))
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+ err = enc.Deserialize(c, &userUnmarshalSyntaxError)
+ assert.IsType(t, &HTTPError{}, err)
+ assert.EqualError(t, err, "code=400, message=Bad Request, err=invalid character 'i' looking for beginning of value")
+
+ var userUnmarshalTypeError = struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ }{}
+
+ req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+ err = enc.Deserialize(c, &userUnmarshalTypeError)
+ assert.IsType(t, &HTTPError{}, err)
+ assert.EqualError(t, err, "code=400, message=Bad Request, err=json: cannot unmarshal number into Go struct field .id of type string")
+
+}
diff --git a/log.go b/log.go
deleted file mode 100644
index 3f8de5904..000000000
--- a/log.go
+++ /dev/null
@@ -1,41 +0,0 @@
-package echo
-
-import (
- "io"
-
- "github.com/labstack/gommon/log"
-)
-
-type (
- // Logger defines the logging interface.
- Logger interface {
- Output() io.Writer
- SetOutput(w io.Writer)
- Prefix() string
- SetPrefix(p string)
- Level() log.Lvl
- SetLevel(v log.Lvl)
- SetHeader(h string)
- Print(i ...interface{})
- Printf(format string, args ...interface{})
- Printj(j log.JSON)
- Debug(i ...interface{})
- Debugf(format string, args ...interface{})
- Debugj(j log.JSON)
- Info(i ...interface{})
- Infof(format string, args ...interface{})
- Infoj(j log.JSON)
- Warn(i ...interface{})
- Warnf(format string, args ...interface{})
- Warnj(j log.JSON)
- Error(i ...interface{})
- Errorf(format string, args ...interface{})
- Errorj(j log.JSON)
- Fatal(i ...interface{})
- Fatalj(j log.JSON)
- Fatalf(format string, args ...interface{})
- Panic(i ...interface{})
- Panicj(j log.JSON)
- Panicf(format string, args ...interface{})
- }
-)
diff --git a/middleware/DEVELOPMENT.md b/middleware/DEVELOPMENT.md
new file mode 100644
index 000000000..77cb226dd
--- /dev/null
+++ b/middleware/DEVELOPMENT.md
@@ -0,0 +1,11 @@
+# Development Guidelines for middlewares
+
+## Best practices:
+
+* Do not use `panic` in middleware creator functions in case of invalid configuration.
+* In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead
+ because previous middlewares up in call chain could have logic for dealing with returned errors.
+* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they
+ want to create middleware with panics or with returning errors on configuration errors.
+* When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same.
+
diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go
index 76ba24206..e0a284c67 100644
--- a/middleware/basic_auth.go
+++ b/middleware/basic_auth.go
@@ -1,106 +1,156 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
+ "bytes"
+ "cmp"
"encoding/base64"
+ "errors"
"strconv"
"strings"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
)
-type (
- // BasicAuthConfig defines the config for BasicAuth middleware.
- BasicAuthConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
+// BasicAuthConfig defines the config for BasicAuthWithConfig middleware.
+//
+// SECURITY: The Validator function is responsible for securely comparing credentials.
+// See BasicAuthValidator documentation for guidance on preventing timing attacks.
+type BasicAuthConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
- // Validator is a function to validate BasicAuth credentials.
- // Required.
- Validator BasicAuthValidator
+ // Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers
+ // this function would be called once for each header until first valid result is returned
+ // Required.
+ Validator BasicAuthValidator
- // Realm is a string to define realm attribute of BasicAuth.
- // Default value "Restricted".
- Realm string
- }
+ // Realm is a string to define realm attribute of BasicAuthWithConfig.
+ // Default value "Restricted".
+ Realm string
- // BasicAuthValidator defines a function to validate BasicAuth credentials.
- BasicAuthValidator func(string, string, echo.Context) (bool, error)
-)
+ // AllowedCheckLimit set how many headers are allowed to be checked. This is useful
+ // environments like corporate test environments with application proxies restricting
+ // access to environment with their own auth scheme.
+ // Defaults to 1.
+ AllowedCheckLimit uint
+}
+
+// BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials.
+//
+// SECURITY WARNING: To prevent timing attacks that could allow attackers to enumerate
+// valid usernames or passwords, validator implementations MUST use constant-time
+// comparison for credential checking. Use crypto/subtle.ConstantTimeCompare instead
+// of standard string equality (==) or switch statements.
+//
+// Example of SECURE implementation:
+//
+// import "crypto/subtle"
+//
+// validator := func(c *echo.Context, username, password string) (bool, error) {
+// // Fetch expected credentials from database/config
+// expectedUser := "admin"
+// expectedPass := "secretpassword"
+//
+// // Use constant-time comparison to prevent timing attacks
+// userMatch := subtle.ConstantTimeCompare([]byte(username), []byte(expectedUser)) == 1
+// passMatch := subtle.ConstantTimeCompare([]byte(password), []byte(expectedPass)) == 1
+//
+// if userMatch && passMatch {
+// return true, nil
+// }
+// return false, nil
+// }
+//
+// Example of INSECURE implementation (DO NOT USE):
+//
+// // VULNERABLE TO TIMING ATTACKS - DO NOT USE
+// validator := func(c *echo.Context, username, password string) (bool, error) {
+// if username == "admin" && password == "secret" { // Timing leak!
+// return true, nil
+// }
+// return false, nil
+// }
+type BasicAuthValidator func(c *echo.Context, user string, password string) (bool, error)
const (
basic = "basic"
defaultRealm = "Restricted"
)
-var (
- // DefaultBasicAuthConfig is the default BasicAuth middleware config.
- DefaultBasicAuthConfig = BasicAuthConfig{
- Skipper: DefaultSkipper,
- Realm: defaultRealm,
- }
-)
-
// BasicAuth returns an BasicAuth middleware.
//
// For valid credentials it calls the next handler.
// For missing or invalid credentials, it sends "401 - Unauthorized" response.
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
- c := DefaultBasicAuthConfig
- c.Validator = fn
- return BasicAuthWithConfig(c)
+ return BasicAuthWithConfig(BasicAuthConfig{Validator: fn})
}
-// BasicAuthWithConfig returns an BasicAuth middleware with config.
-// See `BasicAuth()`.
+// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config.
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
- // Defaults
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration
+func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Validator == nil {
- panic("echo: basic-auth middleware requires a validator function")
+ return nil, errors.New("echo basic-auth middleware requires a validator function")
}
if config.Skipper == nil {
- config.Skipper = DefaultBasicAuthConfig.Skipper
+ config.Skipper = DefaultSkipper
}
- if config.Realm == "" {
- config.Realm = defaultRealm
+ realm := defaultRealm
+ if config.Realm != "" {
+ realm = config.Realm
}
+ realm = strconv.Quote(realm)
+ limit := cmp.Or(config.AllowedCheckLimit, 1)
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) error {
+ return func(c *echo.Context) error {
if config.Skipper(c) {
return next(c)
}
- auth := c.Request().Header.Get(echo.HeaderAuthorization)
+ var lastError error
l := len(basic)
+ i := uint(0)
+ for _, auth := range c.Request().Header[echo.HeaderAuthorization] {
+ if i >= limit {
+ break
+ }
+ if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) {
+ continue
+ }
+ i++
- if len(auth) > l+1 && strings.ToLower(auth[:l]) == basic {
- b, err := base64.StdEncoding.DecodeString(auth[l+1:])
- if err != nil {
- return err
+ // Invalid base64 shouldn't be treated as error
+ // instead should be treated as invalid client input
+ b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
+ if errDecode != nil {
+ lastError = echo.ErrBadRequest.Wrap(errDecode)
+ continue
}
- cred := string(b)
- for i := 0; i < len(cred); i++ {
- if cred[i] == ':' {
- // Verify credentials
- valid, err := config.Validator(cred[:i], cred[i+1:], c)
- if err != nil {
- return err
- } else if valid {
- return next(c)
- }
- break
+ idx := bytes.IndexByte(b, ':')
+ if idx >= 0 {
+ valid, errValidate := config.Validator(c, string(b[:idx]), string(b[idx+1:]))
+ if errValidate != nil {
+ lastError = errValidate
+ } else if valid {
+ return next(c)
}
}
}
- realm := defaultRealm
- if config.Realm != defaultRealm {
- realm = strconv.Quote(config.Realm)
+ if lastError != nil {
+ return lastError
}
// Need to return `401` for browsers to pop-up login box.
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm)
return echo.ErrUnauthorized
}
- }
+ }, nil
}
diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go
index 76039db0a..42386354f 100644
--- a/middleware/basic_auth_test.go
+++ b/middleware/basic_auth_test.go
@@ -1,71 +1,240 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
+ "crypto/subtle"
"encoding/base64"
+ "errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
func TestBasicAuth(t *testing.T) {
- e := echo.New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- res := httptest.NewRecorder()
- c := e.NewContext(req, res)
- f := func(u, p string, c echo.Context) (bool, error) {
- if u == "joe" && p == "secret" {
+ validatorFunc := func(c *echo.Context, u, p string) (bool, error) {
+ // Use constant-time comparison to prevent timing attacks
+ userMatch := subtle.ConstantTimeCompare([]byte(u), []byte("joe")) == 1
+ passMatch := subtle.ConstantTimeCompare([]byte(p), []byte("secret")) == 1
+
+ if userMatch && passMatch {
return true, nil
}
+
+ // Special case for testing error handling
+ if u == "error" {
+ return false, errors.New(p)
+ }
+
return false, nil
}
- h := BasicAuth(f)(func(c echo.Context) error {
- return c.String(http.StatusOK, "test")
- })
+ defaultConfig := BasicAuthConfig{Validator: validatorFunc}
+
+ var testCases = []struct {
+ name string
+ givenConfig BasicAuthConfig
+ whenAuth []string
+ expectHeader string
+ expectErr string
+ }{
+ {
+ name: "ok",
+ givenConfig: defaultConfig,
+ whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
+ },
+ {
+ name: "ok, multiple",
+ givenConfig: BasicAuthConfig{Validator: validatorFunc, AllowedCheckLimit: 2},
+ whenAuth: []string{
+ "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")),
+ basic + " NOT_BASE64",
+ basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
+ },
+ },
+ {
+ name: "nok, multiple, valid out of limit",
+ givenConfig: BasicAuthConfig{Validator: validatorFunc, AllowedCheckLimit: 1},
+ whenAuth: []string{
+ "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")),
+ basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid_password")),
+ // limit only check first and should not check auth below
+ basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
+ },
+ expectHeader: basic + ` realm="Restricted"`,
+ expectErr: "Unauthorized",
+ },
+ {
+ name: "nok, invalid Authorization header",
+ givenConfig: defaultConfig,
+ whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
+ expectHeader: basic + ` realm="Restricted"`,
+ expectErr: "Unauthorized",
+ },
+ {
+ name: "nok, not base64 Authorization header",
+ givenConfig: defaultConfig,
+ whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
+ expectErr: "code=400, message=Bad Request, err=illegal base64 data at input byte 3",
+ },
+ {
+ name: "nok, missing Authorization header",
+ givenConfig: defaultConfig,
+ expectHeader: basic + ` realm="Restricted"`,
+ expectErr: "Unauthorized",
+ },
+ {
+ name: "ok, realm",
+ givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
+ whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
+ },
+ {
+ name: "ok, realm, case-insensitive header scheme",
+ givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
+ whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
+ },
+ {
+ name: "nok, realm, invalid Authorization header",
+ givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
+ whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
+ expectHeader: basic + ` realm="someRealm"`,
+ expectErr: "Unauthorized",
+ },
+ {
+ name: "nok, validator func returns an error",
+ givenConfig: defaultConfig,
+ whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))},
+ expectErr: "my_error",
+ },
+ {
+ name: "ok, skipped",
+ givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c *echo.Context) bool {
+ return true
+ }},
+ whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ res := httptest.NewRecorder()
+ c := e.NewContext(req, res)
+
+ config := tc.givenConfig
+
+ mw, err := config.ToMiddleware()
+ assert.NoError(t, err)
- assert := assert.New(t)
+ h := mw(func(c *echo.Context) error {
+ return c.String(http.StatusTeapot, "test")
+ })
- // Valid credentials
- auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
- req.Header.Set(echo.HeaderAuthorization, auth)
- assert.NoError(h(c))
+ if len(tc.whenAuth) != 0 {
+ for _, a := range tc.whenAuth {
+ req.Header.Add(echo.HeaderAuthorization, a)
+ }
+ }
+ err = h(c)
- h = BasicAuthWithConfig(BasicAuthConfig{
- Skipper: nil,
- Validator: f,
- Realm: "someRealm",
- })(func(c echo.Context) error {
- return c.String(http.StatusOK, "test")
+ if tc.expectErr != "" {
+ assert.Equal(t, http.StatusOK, res.Code)
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.Equal(t, http.StatusTeapot, res.Code)
+ assert.NoError(t, err)
+ }
+ if tc.expectHeader != "" {
+ assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate))
+ }
+ })
+ }
+}
+
+func TestBasicAuth_panic(t *testing.T) {
+ assert.Panics(t, func() {
+ mw := BasicAuth(nil)
+ assert.NotNil(t, mw)
})
- // Valid credentials
- auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
- req.Header.Set(echo.HeaderAuthorization, auth)
- assert.NoError(h(c))
-
- // Case-insensitive header scheme
- auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
- req.Header.Set(echo.HeaderAuthorization, auth)
- assert.NoError(h(c))
-
- // Invalid credentials
- auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
- req.Header.Set(echo.HeaderAuthorization, auth)
- he := h(c).(*echo.HTTPError)
- assert.Equal(http.StatusUnauthorized, he.Code)
- assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
-
- // Missing Authorization header
- req.Header.Del(echo.HeaderAuthorization)
- he = h(c).(*echo.HTTPError)
- assert.Equal(http.StatusUnauthorized, he.Code)
-
- // Invalid Authorization header
- auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
- req.Header.Set(echo.HeaderAuthorization, auth)
- he = h(c).(*echo.HTTPError)
- assert.Equal(http.StatusUnauthorized, he.Code)
+ mw := BasicAuth(func(c *echo.Context, user string, password string) (bool, error) {
+ return true, nil
+ })
+ assert.NotNil(t, mw)
+}
+
+func TestBasicAuthWithConfig_panic(t *testing.T) {
+ assert.Panics(t, func() {
+ mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil})
+ assert.NotNil(t, mw)
+ })
+
+ mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c *echo.Context, user string, password string) (bool, error) {
+ return true, nil
+ }})
+ assert.NotNil(t, mw)
+}
+
+func TestBasicAuthRealm(t *testing.T) {
+ e := echo.New()
+ mockValidator := func(c *echo.Context, u, p string) (bool, error) {
+ return false, nil // Always fail to trigger WWW-Authenticate header
+ }
+
+ tests := []struct {
+ name string
+ realm string
+ expectedAuth string
+ }{
+ {
+ name: "Default realm",
+ realm: "Restricted",
+ expectedAuth: `basic realm="Restricted"`,
+ },
+ {
+ name: "Custom realm",
+ realm: "My API",
+ expectedAuth: `basic realm="My API"`,
+ },
+ {
+ name: "Realm with special characters",
+ realm: `Realm with "quotes" and \backslashes`,
+ expectedAuth: `basic realm="Realm with \"quotes\" and \\backslashes"`,
+ },
+ {
+ name: "Empty realm (falls back to default)",
+ realm: "",
+ expectedAuth: `basic realm="Restricted"`,
+ },
+ {
+ name: "Realm with unicode",
+ realm: "测试领域",
+ expectedAuth: `basic realm="测试领域"`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ res := httptest.NewRecorder()
+ c := e.NewContext(req, res)
+
+ h := BasicAuthWithConfig(BasicAuthConfig{
+ Validator: mockValidator,
+ Realm: tt.realm,
+ })(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ err := h(c)
+
+ assert.Equal(t, echo.ErrUnauthorized, err)
+ assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate))
+ })
+ }
}
diff --git a/middleware/body_dump.go b/middleware/body_dump.go
index ebd0d0ab2..d5c823c9b 100644
--- a/middleware/body_dump.go
+++ b/middleware/body_dump.go
@@ -1,93 +1,146 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
"bufio"
"bytes"
+ "errors"
"io"
- "io/ioutil"
"net"
"net/http"
+ "sync"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
)
-type (
- // BodyDumpConfig defines the config for BodyDump middleware.
- BodyDumpConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
-
- // Handler receives request and response payload.
- // Required.
- Handler BodyDumpHandler
- }
-
- // BodyDumpHandler receives the request and response payload.
- BodyDumpHandler func(echo.Context, []byte, []byte)
+// BodyDumpConfig defines the config for BodyDump middleware.
+type BodyDumpConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // Handler receives request, response payloads and handler error if there are any.
+ // Required.
+ Handler BodyDumpHandler
+
+ // MaxRequestBytes limits how much of the request body to dump.
+ // If the request body exceeds this limit, only the first MaxRequestBytes
+ // are dumped. The handler callback receives truncated data.
+ // Default: 5 * MB (5,242,880 bytes)
+ // Set to -1 to disable limits (not recommended in production).
+ MaxRequestBytes int64
+
+ // MaxResponseBytes limits how much of the response body to dump.
+ // If the response body exceeds this limit, only the first MaxResponseBytes
+ // are dumped. The handler callback receives truncated data.
+ // Default: 5 * MB (5,242,880 bytes)
+ // Set to -1 to disable limits (not recommended in production).
+ MaxResponseBytes int64
+}
- bodyDumpResponseWriter struct {
- io.Writer
- http.ResponseWriter
- }
-)
+// BodyDumpHandler receives the request and response payload.
+type BodyDumpHandler func(c *echo.Context, reqBody []byte, resBody []byte, err error)
-var (
- // DefaultBodyDumpConfig is the default BodyDump middleware config.
- DefaultBodyDumpConfig = BodyDumpConfig{
- Skipper: DefaultSkipper,
- }
-)
+type bodyDumpResponseWriter struct {
+ io.Writer
+ http.ResponseWriter
+}
// BodyDump returns a BodyDump middleware.
//
// BodyDump middleware captures the request and response payload and calls the
// registered handler.
+//
+// SECURITY: By default, this limits dumped bodies to 5MB to prevent memory exhaustion
+// attacks. To customize limits, use BodyDumpWithConfig. To disable limits (not recommended
+// in production), explicitly set MaxRequestBytes and MaxResponseBytes to -1.
func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc {
- c := DefaultBodyDumpConfig
- c.Handler = handler
- return BodyDumpWithConfig(c)
+ return BodyDumpWithConfig(BodyDumpConfig{Handler: handler})
}
// BodyDumpWithConfig returns a BodyDump middleware with config.
// See: `BodyDump()`.
+//
+// SECURITY: If MaxRequestBytes and MaxResponseBytes are not set (zero values), they default
+// to 5MB each to prevent DoS attacks via large payloads. Set them explicitly to -1 to disable
+// limits if needed for your use case.
func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
- // Defaults
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration
+func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Handler == nil {
- panic("echo: body-dump middleware requires a handler function")
+ return nil, errors.New("echo body-dump middleware requires a handler function")
}
if config.Skipper == nil {
- config.Skipper = DefaultBodyDumpConfig.Skipper
+ config.Skipper = DefaultSkipper
+ }
+ if config.MaxRequestBytes == 0 {
+ config.MaxRequestBytes = 5 * MB
+ }
+ if config.MaxResponseBytes == 0 {
+ config.MaxResponseBytes = 5 * MB
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) (err error) {
+ return func(c *echo.Context) error {
if config.Skipper(c) {
return next(c)
}
- // Request
- reqBody := []byte{}
- if c.Request().Body != nil { // Read
- reqBody, _ = ioutil.ReadAll(c.Request().Body)
- }
- c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset
+ reqBuf := bodyDumpBufferPool.Get().(*bytes.Buffer)
+ reqBuf.Reset()
+ defer bodyDumpBufferPool.Put(reqBuf)
- // Response
- resBody := new(bytes.Buffer)
- mw := io.MultiWriter(c.Response().Writer, resBody)
- writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer}
- c.Response().Writer = writer
+ var bodyReader io.Reader = c.Request().Body
+ if config.MaxRequestBytes > 0 {
+ bodyReader = io.LimitReader(c.Request().Body, config.MaxRequestBytes)
+ }
+ _, readErr := io.Copy(reqBuf, bodyReader)
+ if readErr != nil && readErr != io.EOF {
+ return readErr
+ }
+ if config.MaxRequestBytes > 0 {
+ // Drain any remaining body data to prevent connection issues
+ _, _ = io.Copy(io.Discard, c.Request().Body)
+ _ = c.Request().Body.Close()
+ }
- if err = next(c); err != nil {
- c.Error(err)
+ reqBody := make([]byte, reqBuf.Len())
+ copy(reqBody, reqBuf.Bytes())
+ c.Request().Body = io.NopCloser(bytes.NewReader(reqBody))
+
+ // response part
+ resBuf := bodyDumpBufferPool.Get().(*bytes.Buffer)
+ resBuf.Reset()
+ defer bodyDumpBufferPool.Put(resBuf)
+
+ var respWriter io.Writer
+ if config.MaxResponseBytes > 0 {
+ respWriter = &limitedWriter{
+ response: c.Response(),
+ dumpBuf: resBuf,
+ limit: config.MaxResponseBytes,
+ }
+ } else {
+ respWriter = io.MultiWriter(c.Response(), resBuf)
+ }
+ writer := &bodyDumpResponseWriter{
+ Writer: respWriter,
+ ResponseWriter: c.Response(),
}
+ c.SetResponse(writer)
+
+ err := next(c)
// Callback
- config.Handler(c, reqBody, resBody.Bytes())
+ config.Handler(c, reqBody, resBuf.Bytes(), err)
- return
+ return err
}
- }
+ }, nil
}
func (w *bodyDumpResponseWriter) WriteHeader(code int) {
@@ -99,9 +152,50 @@ func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) {
}
func (w *bodyDumpResponseWriter) Flush() {
- w.ResponseWriter.(http.Flusher).Flush()
+ err := http.NewResponseController(w.ResponseWriter).Flush()
+ if err != nil && errors.Is(err, http.ErrNotSupported) {
+ panic(errors.New("response writer flushing is not supported"))
+ }
}
func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
- return w.ResponseWriter.(http.Hijacker).Hijack()
+ return http.NewResponseController(w.ResponseWriter).Hijack()
+}
+
+func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter {
+ return w.ResponseWriter
+}
+
+var bodyDumpBufferPool = sync.Pool{
+ New: func() any {
+ return new(bytes.Buffer)
+ },
+}
+
+type limitedWriter struct {
+ response http.ResponseWriter
+ dumpBuf *bytes.Buffer
+ dumped int64
+ limit int64
+}
+
+func (w *limitedWriter) Write(b []byte) (n int, err error) {
+ // Always write full data to actual response (don't truncate client response)
+ n, err = w.response.Write(b)
+ if err != nil {
+ return n, err
+ }
+
+ // Write to dump buffer only up to limit
+ if w.dumped < w.limit {
+ remaining := w.limit - w.dumped
+ toDump := int64(n)
+ if toDump > remaining {
+ toDump = remaining
+ }
+ w.dumpBuf.Write(b[:toDump])
+ w.dumped += toDump
+ }
+
+ return n, nil
}
diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go
index e6e00f726..f493e75c8 100644
--- a/middleware/body_dump_test.go
+++ b/middleware/body_dump_test.go
@@ -1,14 +1,17 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
"errors"
- "io/ioutil"
+ "io"
"net/http"
"net/http/httptest"
"strings"
"testing"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
@@ -18,8 +21,8 @@ func TestBodyDump(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- h := func(c echo.Context) error {
- body, err := ioutil.ReadAll(c.Request().Body)
+ h := func(c *echo.Context) error {
+ body, err := io.ReadAll(c.Request().Body)
if err != nil {
return err
}
@@ -28,64 +31,551 @@ func TestBodyDump(t *testing.T) {
requestBody := ""
responseBody := ""
- mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {
+ mw, err := BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
requestBody = string(reqBody)
responseBody = string(resBody)
- })
-
- assert := assert.New(t)
+ }}.ToMiddleware()
+ assert.NoError(t, err)
- if assert.NoError(mw(h)(c)) {
- assert.Equal(requestBody, hw)
- assert.Equal(responseBody, hw)
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(hw, rec.Body.String())
+ if assert.NoError(t, mw(h)(c)) {
+ assert.Equal(t, requestBody, hw)
+ assert.Equal(t, responseBody, hw)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, hw, rec.Body.String())
}
- // Must set default skipper
- BodyDumpWithConfig(BodyDumpConfig{
- Skipper: nil,
- Handler: func(c echo.Context, reqBody, resBody []byte) {
- requestBody = string(reqBody)
- responseBody = string(resBody)
+}
+
+func TestBodyDump_skipper(t *testing.T) {
+ e := echo.New()
+
+ isCalled := false
+ mw, err := BodyDumpConfig{
+ Skipper: func(c *echo.Context) bool {
+ return true
},
- })
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ isCalled = true
+ },
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}"))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := func(c *echo.Context) error {
+ return errors.New("some error")
+ }
+
+ err = mw(h)(c)
+ assert.EqualError(t, err, "some error")
+ assert.False(t, isCalled)
}
-func TestBodyDumpFails(t *testing.T) {
+func TestBodyDump_fails(t *testing.T) {
e := echo.New()
hw := "Hello, World!"
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- h := func(c echo.Context) error {
+ h := func(c *echo.Context) error {
return errors.New("some error")
}
- mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {})
+ mw, err := BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}}.ToMiddleware()
+ assert.NoError(t, err)
- if !assert.Error(t, mw(h)(c)) {
- t.FailNow()
- }
+ err = mw(h)(c)
+ assert.EqualError(t, err, "some error")
+ assert.Equal(t, http.StatusOK, rec.Code)
+
+}
+func TestBodyDumpWithConfig_panic(t *testing.T) {
assert.Panics(t, func() {
- mw = BodyDumpWithConfig(BodyDumpConfig{
+ mw := BodyDumpWithConfig(BodyDumpConfig{
Skipper: nil,
Handler: nil,
})
+ assert.NotNil(t, mw)
})
assert.NotPanics(t, func() {
- mw = BodyDumpWithConfig(BodyDumpConfig{
- Skipper: func(c echo.Context) bool {
- return true
- },
- Handler: func(c echo.Context, reqBody, resBody []byte) {
- },
- })
+ mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}})
+ assert.NotNil(t, mw)
+ })
+}
- if !assert.Error(t, mw(h)(c)) {
- t.FailNow()
- }
+func TestBodyDump_panic(t *testing.T) {
+ assert.Panics(t, func() {
+ mw := BodyDump(nil)
+ assert.NotNil(t, mw)
+ })
+
+ assert.NotPanics(t, func() {
+ BodyDump(func(c *echo.Context, reqBody, resBody []byte, err error) {})
+ })
+}
+
+func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) {
+ bdrw := bodyDumpResponseWriter{
+ ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush
+ }
+ assert.PanicsWithError(t, "response writer flushing is not supported", func() {
+ bdrw.Flush()
+ })
+}
+
+func TestBodyDumpResponseWriter_CanFlush(t *testing.T) {
+ trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
+ bdrw := bodyDumpResponseWriter{
+ ResponseWriter: &trwu,
+ }
+ bdrw.Flush()
+ assert.Equal(t, 1, trwu.unwrapCalled)
+}
+
+func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) {
+ trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
+ bdrw := bodyDumpResponseWriter{
+ ResponseWriter: trwu,
+ }
+ result := bdrw.Unwrap()
+ assert.Equal(t, trwu, result)
+}
+
+func TestBodyDumpResponseWriter_CanHijack(t *testing.T) {
+ trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
+ bdrw := bodyDumpResponseWriter{
+ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
+ }
+ _, _, err := bdrw.Hijack()
+ assert.EqualError(t, err, "can hijack")
+}
+
+func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) {
+ trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
+ bdrw := bodyDumpResponseWriter{
+ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
+ }
+ _, _, err := bdrw.Hijack()
+ assert.EqualError(t, err, "feature not supported")
+}
+
+func TestBodyDump_ReadError(t *testing.T) {
+ e := echo.New()
+
+ // Create a reader that fails during read
+ failingReader := &failingReadCloser{
+ data: []byte("partial data"),
+ failAt: 7, // Fail after 7 bytes
+ failWith: errors.New("connection reset"),
+ }
+
+ req := httptest.NewRequest(http.MethodPost, "/", failingReader)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ // This handler should not be reached if body read fails
+ body, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(body))
+ }
+
+ requestBodyReceived := ""
+ mw := BodyDump(func(c *echo.Context, reqBody, resBody []byte, err error) {
+ requestBodyReceived = string(reqBody)
})
+
+ err := mw(h)(c)
+
+ // Verify error is propagated
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "connection reset")
+
+ // Verify handler was not executed (callback wouldn't have received data)
+ assert.Empty(t, requestBodyReceived)
+}
+
+// failingReadCloser is a helper type for testing read errors
+type failingReadCloser struct {
+ data []byte
+ pos int
+ failAt int
+ failWith error
+}
+
+func (f *failingReadCloser) Read(p []byte) (n int, err error) {
+ if f.pos >= f.failAt {
+ return 0, f.failWith
+ }
+
+ n = copy(p, f.data[f.pos:])
+ f.pos += n
+
+ if f.pos >= f.failAt {
+ return n, f.failWith
+ }
+
+ return n, nil
+}
+
+func (f *failingReadCloser) Close() error {
+ return nil
+}
+
+func TestBodyDump_RequestWithinLimit(t *testing.T) {
+ e := echo.New()
+ requestData := "Hello, World!"
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ body, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(body))
+ }
+
+ requestBodyDumped := ""
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ requestBodyDumped = string(reqBody)
+ },
+ MaxRequestBytes: 1 * MB, // 1MB limit
+ MaxResponseBytes: 1 * MB,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, requestData, requestBodyDumped, "Small request should be fully dumped")
+ assert.Equal(t, requestData, rec.Body.String(), "Handler should receive full request")
+}
+
+func TestBodyDump_RequestExceedsLimit(t *testing.T) {
+ e := echo.New()
+ // Create 2KB of data but limit to 1KB
+ largeData := strings.Repeat("A", 2*1024)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeData))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ body, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(body))
+ }
+
+ requestBodyDumped := ""
+ limit := int64(1024) // 1KB limit
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ requestBodyDumped = string(reqBody)
+ },
+ MaxRequestBytes: limit,
+ MaxResponseBytes: 1 * MB,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, int(limit), len(requestBodyDumped), "Dumped request should be truncated to limit")
+ assert.Equal(t, strings.Repeat("A", 1024), requestBodyDumped, "Dumped data should match first N bytes")
+ // Handler should receive truncated data (what was dumped)
+ assert.Equal(t, strings.Repeat("A", 1024), rec.Body.String())
+}
+
+func TestBodyDump_RequestAtExactLimit(t *testing.T) {
+ e := echo.New()
+ exactData := strings.Repeat("B", 1024)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(exactData))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ body, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(body))
+ }
+
+ requestBodyDumped := ""
+ limit := int64(1024)
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ requestBodyDumped = string(reqBody)
+ },
+ MaxRequestBytes: limit,
+ MaxResponseBytes: 1 * MB,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, int(limit), len(requestBodyDumped), "Exact limit should dump full data")
+ assert.Equal(t, exactData, requestBodyDumped)
+}
+
+func TestBodyDump_ResponseWithinLimit(t *testing.T) {
+ e := echo.New()
+ responseData := "Response data"
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ return c.String(http.StatusOK, responseData)
+ }
+
+ responseBodyDumped := ""
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ responseBodyDumped = string(resBody)
+ },
+ MaxRequestBytes: 1 * MB,
+ MaxResponseBytes: 1 * MB,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, responseData, responseBodyDumped, "Small response should be fully dumped")
+ assert.Equal(t, responseData, rec.Body.String(), "Client should receive full response")
+}
+
+func TestBodyDump_ResponseExceedsLimit(t *testing.T) {
+ e := echo.New()
+ largeResponse := strings.Repeat("X", 2*1024) // 2KB
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ return c.String(http.StatusOK, largeResponse)
+ }
+
+ responseBodyDumped := ""
+ limit := int64(1024) // 1KB limit
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ responseBodyDumped = string(resBody)
+ },
+ MaxRequestBytes: 1 * MB,
+ MaxResponseBytes: limit,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ // Dump should be truncated
+ assert.Equal(t, int(limit), len(responseBodyDumped), "Dumped response should be truncated to limit")
+ assert.Equal(t, strings.Repeat("X", 1024), responseBodyDumped)
+ // Client should still receive full response!
+ assert.Equal(t, largeResponse, rec.Body.String(), "Client must receive full response despite dump truncation")
+}
+
+func TestBodyDump_ClientGetsFullResponse(t *testing.T) {
+ e := echo.New()
+ // This is critical - even when dump is limited, client gets everything
+ largeResponse := strings.Repeat("DATA", 500) // 2KB
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ // Write response in chunks to test incremental writes
+ for i := 0; i < 4; i++ {
+ c.Response().Write([]byte(strings.Repeat("DATA", 125)))
+ }
+ return nil
+ }
+
+ responseBodyDumped := ""
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ responseBodyDumped = string(resBody)
+ },
+ MaxRequestBytes: 1 * MB,
+ MaxResponseBytes: 512, // Very small limit
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, 512, len(responseBodyDumped), "Dump should be limited")
+ assert.Equal(t, largeResponse, rec.Body.String(), "Client must get full response")
+}
+
+func TestBodyDump_BothLimitsSimultaneous(t *testing.T) {
+ e := echo.New()
+ largeRequest := strings.Repeat("Q", 2*1024)
+ largeResponse := strings.Repeat("R", 2*1024)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeRequest))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ io.ReadAll(c.Request().Body) // Consume request
+ return c.String(http.StatusOK, largeResponse)
+ }
+
+ requestBodyDumped := ""
+ responseBodyDumped := ""
+ limit := int64(1024)
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ requestBodyDumped = string(reqBody)
+ responseBodyDumped = string(resBody)
+ },
+ MaxRequestBytes: limit,
+ MaxResponseBytes: limit,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, int(limit), len(requestBodyDumped), "Request dump should be limited")
+ assert.Equal(t, int(limit), len(responseBodyDumped), "Response dump should be limited")
+}
+
+func TestBodyDump_DefaultConfig(t *testing.T) {
+ e := echo.New()
+ smallData := "test"
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(smallData))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ body, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(body))
+ }
+
+ requestBodyDumped := ""
+ // Use default config which should have 1MB limits
+ config := BodyDumpConfig{}
+ config.Handler = func(c *echo.Context, reqBody, resBody []byte, err error) {
+ requestBodyDumped = string(reqBody)
+ }
+ mw, err := config.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, smallData, requestBodyDumped)
+}
+
+func TestBodyDump_LargeRequestDosPrevention(t *testing.T) {
+ e := echo.New()
+ // Simulate a very large request (10MB) that could cause OOM
+ largeSize := 10 * 1024 * 1024 // 10MB
+ largeData := strings.Repeat("M", largeSize)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeData))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ body, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(body))
+ }
+
+ requestBodyDumped := ""
+ limit := int64(1 * MB) // Only dump 1MB max
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ requestBodyDumped = string(reqBody)
+ },
+ MaxRequestBytes: limit,
+ MaxResponseBytes: 1 * MB,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ // Verify only 1MB was dumped, not 10MB
+ assert.Equal(t, int(limit), len(requestBodyDumped), "Should only dump up to limit, preventing DoS")
+ assert.Less(t, len(requestBodyDumped), largeSize, "Dump should be much smaller than full request")
+}
+
+func TestBodyDump_LargeResponseDosPrevention(t *testing.T) {
+ e := echo.New()
+ // Simulate a very large response (10MB)
+ largeSize := 10 * 1024 * 1024 // 10MB
+ largeResponse := strings.Repeat("R", largeSize)
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ return c.String(http.StatusOK, largeResponse)
+ }
+
+ responseBodyDumped := ""
+ limit := int64(1 * MB) // Only dump 1MB max
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ responseBodyDumped = string(resBody)
+ },
+ MaxRequestBytes: 1 * MB,
+ MaxResponseBytes: limit,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ // Verify only 1MB was dumped, not 10MB
+ assert.Equal(t, int(limit), len(responseBodyDumped), "Should only dump up to limit, preventing DoS")
+ assert.Less(t, len(responseBodyDumped), largeSize, "Dump should be much smaller than full response")
+ // Client still gets full response
+ assert.Equal(t, largeSize, rec.Body.Len(), "Client must receive full response")
+}
+
+func BenchmarkBodyDump_WithLimit(b *testing.B) {
+ e := echo.New()
+ requestData := strings.Repeat("data", 256) // 1KB
+ responseData := strings.Repeat("resp", 256) // 1KB
+
+ h := func(c *echo.Context) error {
+ io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, responseData)
+ }
+
+ mw, _ := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ // Simulate logging
+ _ = len(reqBody) + len(resBody)
+ },
+ MaxRequestBytes: 1 * MB,
+ MaxResponseBytes: 1 * MB,
+ }.ToMiddleware()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ mw(h)(c)
+ }
+}
+
+func BenchmarkBodyDump_BufferPooling(b *testing.B) {
+ e := echo.New()
+ requestData := strings.Repeat("x", 1024)
+ responseData := "response"
+
+ h := func(c *echo.Context) error {
+ io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, responseData)
+ }
+
+ mw, _ := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {},
+ MaxRequestBytes: 1 * MB,
+ MaxResponseBytes: 1 * MB,
+ }.ToMiddleware()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ mw(h)(c)
+ }
}
diff --git a/middleware/body_limit.go b/middleware/body_limit.go
index b436bd595..4f1963e18 100644
--- a/middleware/body_limit.go
+++ b/middleware/body_limit.go
@@ -1,98 +1,89 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
- "fmt"
"io"
+ "net/http"
"sync"
- "github.com/labstack/echo/v4"
- "github.com/labstack/gommon/bytes"
+ "github.com/labstack/echo/v5"
)
-type (
- // BodyLimitConfig defines the config for BodyLimit middleware.
- BodyLimitConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
-
- // Maximum allowed size for a request body, it can be specified
- // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P.
- Limit string `yaml:"limit"`
- limit int64
- }
+// BodyLimitConfig defines the config for BodyLimitWithConfig middleware.
+type BodyLimitConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
- limitedReader struct {
- BodyLimitConfig
- reader io.ReadCloser
- read int64
- context echo.Context
- }
-)
+ // LimitBytes is maximum allowed size in bytes for a request body
+ LimitBytes int64
+}
-var (
- // DefaultBodyLimitConfig is the default BodyLimit middleware config.
- DefaultBodyLimitConfig = BodyLimitConfig{
- Skipper: DefaultSkipper,
- }
-)
+type limitedReader struct {
+ BodyLimitConfig
+ reader io.ReadCloser
+ read int64
+}
// BodyLimit returns a BodyLimit middleware.
//
-// BodyLimit middleware sets the maximum allowed size for a request body, if the
-// size exceeds the configured limit, it sends "413 - Request Entity Too Large"
-// response. The BodyLimit is determined based on both `Content-Length` request
+// BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it
+// sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request
// header and actual content read, which makes it super secure.
-// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M,
-// G, T or P.
-func BodyLimit(limit string) echo.MiddlewareFunc {
- c := DefaultBodyLimitConfig
- c.Limit = limit
- return BodyLimitWithConfig(c)
+func BodyLimit(limitBytes int64) echo.MiddlewareFunc {
+ return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes})
}
-// BodyLimitWithConfig returns a BodyLimit middleware with config.
-// See: `BodyLimit()`.
+// BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for
+// a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response.
+// The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which
+// makes it super secure.
func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
- // Defaults
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration
+func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
- config.Skipper = DefaultBodyLimitConfig.Skipper
+ config.Skipper = DefaultSkipper
}
-
- limit, err := bytes.Parse(config.Limit)
- if err != nil {
- panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit))
+ pool := sync.Pool{
+ New: func() any {
+ return &limitedReader{BodyLimitConfig: config}
+ },
}
- config.limit = limit
- pool := limitedReaderPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) error {
+ return func(c *echo.Context) error {
if config.Skipper(c) {
return next(c)
}
-
req := c.Request()
// Based on content length
- if req.ContentLength > config.limit {
+ if req.ContentLength > config.LimitBytes {
return echo.ErrStatusRequestEntityTooLarge
}
// Based on content read
- r := pool.Get().(*limitedReader)
- r.Reset(req.Body, c)
+ r, ok := pool.Get().(*limitedReader)
+ if !ok {
+ return echo.NewHTTPError(http.StatusInternalServerError, "invalid pool object")
+ }
+ r.Reset(req.Body)
defer pool.Put(r)
req.Body = r
return next(c)
}
- }
+ }, nil
}
func (r *limitedReader) Read(b []byte) (n int, err error) {
n, err = r.reader.Read(b)
r.read += int64(n)
- if r.read > r.limit {
+ if r.read > r.LimitBytes {
return n, echo.ErrStatusRequestEntityTooLarge
}
return
@@ -102,16 +93,7 @@ func (r *limitedReader) Close() error {
return r.reader.Close()
}
-func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) {
+func (r *limitedReader) Reset(reader io.ReadCloser) {
r.reader = reader
- r.context = context
r.read = 0
}
-
-func limitedReaderPool(c BodyLimitConfig) sync.Pool {
- return sync.Pool{
- New: func() interface{} {
- return &limitedReader{BodyLimitConfig: c}
- },
- }
-}
diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go
index 0e8642a06..5529f5d84 100644
--- a/middleware/body_limit_test.go
+++ b/middleware/body_limit_test.go
@@ -1,85 +1,166 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
"bytes"
- "io/ioutil"
+ "io"
"net/http"
"net/http/httptest"
"testing"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
-func TestBodyLimit(t *testing.T) {
+func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
e := echo.New()
hw := []byte("Hello, World!")
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- h := func(c echo.Context) error {
- body, err := ioutil.ReadAll(c.Request().Body)
+ h := func(c *echo.Context) error {
+ body, err := io.ReadAll(c.Request().Body)
if err != nil {
return err
}
return c.String(http.StatusOK, string(body))
}
- assert := assert.New(t)
-
// Based on content length (within limit)
- if assert.NoError(BodyLimit("2M")(h)(c)) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal(hw, rec.Body.Bytes())
+ mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, hw, rec.Body.Bytes())
}
// Based on content read (overlimit)
- he := BodyLimit("2B")(h)(c).(*echo.HTTPError)
- assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
+ mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
+ assert.NoError(t, err)
+ he := mw(h)(c).(echo.HTTPStatusCoder)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
// Based on content read (within limit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ req.ContentLength = -1
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
- if assert.NoError(BodyLimit("2M")(h)(c)) {
- assert.Equal(http.StatusOK, rec.Code)
- assert.Equal("Hello, World!", rec.Body.String())
- }
+
+ mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
+ assert.NoError(t, err)
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "Hello, World!", rec.Body.String())
// Based on content read (overlimit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ req.ContentLength = -1
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
- he = BodyLimit("2B")(h)(c).(*echo.HTTPError)
- assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
+ mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
+ assert.NoError(t, err)
+ he = mw(h)(c).(echo.HTTPStatusCoder)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
}
func TestBodyLimitReader(t *testing.T) {
hw := []byte("Hello, World!")
- e := echo.New()
- req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
- rec := httptest.NewRecorder()
config := BodyLimitConfig{
- Skipper: DefaultSkipper,
- Limit: "2B",
- limit: 2,
+ Skipper: DefaultSkipper,
+ LimitBytes: 2,
}
reader := &limitedReader{
BodyLimitConfig: config,
- reader: ioutil.NopCloser(bytes.NewReader(hw)),
- context: e.NewContext(req, rec),
+ reader: io.NopCloser(bytes.NewReader(hw)),
}
// read all should return ErrStatusRequestEntityTooLarge
- _, err := ioutil.ReadAll(reader)
- he := err.(*echo.HTTPError)
- assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
+ _, err := io.ReadAll(reader)
+ he := err.(echo.HTTPStatusCoder)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
// reset reader and read two bytes must succeed
bt := make([]byte, 2)
- reader.Reset(ioutil.NopCloser(bytes.NewReader(hw)), e.NewContext(req, rec))
+ reader.Reset(io.NopCloser(bytes.NewReader(hw)))
n, err := reader.Read(bt)
assert.Equal(t, 2, n)
assert.Equal(t, nil, err)
}
+
+func TestBodyLimit_skipper(t *testing.T) {
+ e := echo.New()
+ h := func(c *echo.Context) error {
+ body, err := io.ReadAll(c.Request().Body)
+ if err != nil {
+ return err
+ }
+ return c.String(http.StatusOK, string(body))
+ }
+ mw, err := BodyLimitConfig{
+ Skipper: func(c *echo.Context) bool {
+ return true
+ },
+ LimitBytes: 2,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ hw := []byte("Hello, World!")
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, hw, rec.Body.Bytes())
+}
+
+func TestBodyLimitWithConfig(t *testing.T) {
+ e := echo.New()
+ hw := []byte("Hello, World!")
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := func(c *echo.Context) error {
+ body, err := io.ReadAll(c.Request().Body)
+ if err != nil {
+ return err
+ }
+ return c.String(http.StatusOK, string(body))
+ }
+
+ mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB})
+
+ err := mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, hw, rec.Body.Bytes())
+}
+
+func TestBodyLimit(t *testing.T) {
+ e := echo.New()
+ hw := []byte("Hello, World!")
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := func(c *echo.Context) error {
+ body, err := io.ReadAll(c.Request().Body)
+ if err != nil {
+ return err
+ }
+ return c.String(http.StatusOK, string(body))
+ }
+
+ mw := BodyLimit(2 * MB)
+
+ err := mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, hw, rec.Body.Bytes())
+}
diff --git a/middleware/compress.go b/middleware/compress.go
index 89da16efe..7754d5db8 100644
--- a/middleware/compress.go
+++ b/middleware/compress.go
@@ -1,65 +1,90 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
"bufio"
+ "bytes"
"compress/gzip"
+ "errors"
"io"
- "io/ioutil"
"net"
"net/http"
"strings"
+ "sync"
- "github.com/labstack/echo/v4"
-)
-
-type (
- // GzipConfig defines the config for Gzip middleware.
- GzipConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
-
- // Gzip compression level.
- // Optional. Default value -1.
- Level int `yaml:"level"`
- }
-
- gzipResponseWriter struct {
- io.Writer
- http.ResponseWriter
- }
+ "github.com/labstack/echo/v5"
)
const (
gzipScheme = "gzip"
)
-var (
- // DefaultGzipConfig is the default Gzip middleware config.
- DefaultGzipConfig = GzipConfig{
- Skipper: DefaultSkipper,
- Level: -1,
- }
-)
+// GzipConfig defines the config for Gzip middleware.
+type GzipConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // Gzip compression level.
+ // Optional. Default value -1.
+ Level int
+
+ // Length threshold before gzip compression is applied.
+ // Optional. Default value 0.
+ //
+ // Most of the time you will not need to change the default. Compressing
+ // a short response might increase the transmitted data because of the
+ // gzip format overhead. Compressing the response will also consume CPU
+ // and time on the server and the client (for decompressing). Depending on
+ // your use case such a threshold might be useful.
+ //
+ // See also:
+ // https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits
+ MinLength int
+}
+
+type gzipResponseWriter struct {
+ io.Writer
+ http.ResponseWriter
+ wroteHeader bool
+ wroteBody bool
+ minLength int
+ minLengthExceeded bool
+ buffer *bytes.Buffer
+ code int
+}
-// Gzip returns a middleware which compresses HTTP response using gzip compression
-// scheme.
+// Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
func Gzip() echo.MiddlewareFunc {
- return GzipWithConfig(DefaultGzipConfig)
+ return GzipWithConfig(GzipConfig{})
}
-// GzipWithConfig return Gzip middleware with config.
-// See: `Gzip()`.
+// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
- // Defaults
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration
+func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
- config.Skipper = DefaultGzipConfig.Skipper
+ config.Skipper = DefaultSkipper
+ }
+ if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
+ return nil, errors.New("invalid gzip level")
}
if config.Level == 0 {
- config.Level = DefaultGzipConfig.Level
+ config.Level = -1
}
+ if config.MinLength < 0 {
+ config.MinLength = 0
+ }
+
+ pool := gzipCompressPool(config)
+ bpool := bufferPool()
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) error {
+ return func(c *echo.Context) error {
if config.Skipper(c) {
return next(c)
}
@@ -67,55 +92,144 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
res := c.Response()
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
- res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
- rw := res.Writer
- w, err := gzip.NewWriterLevel(rw, config.Level)
- if err != nil {
- return err
+ i := pool.Get()
+ w, ok := i.(*gzip.Writer)
+ if !ok {
+ return echo.NewHTTPError(http.StatusInternalServerError, "invalid pool object")
}
+ rw := res
+ w.Reset(rw)
+ buf := bpool.Get().(*bytes.Buffer)
+ buf.Reset()
+
+ grw := &gzipResponseWriter{
+ Writer: w,
+ ResponseWriter: rw,
+ minLength: config.MinLength,
+ buffer: buf,
+ }
+ c.SetResponse(grw)
defer func() {
- if res.Size == 0 {
+ // There are different reasons for cases when we have not yet written response to the client and now need to do so.
+ // a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now.
+ // b) body is shorter than our minimum length threshold and being buffered currently and needs to be written
+ if !grw.wroteBody {
if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
res.Header().Del(echo.HeaderContentEncoding)
}
+ if grw.wroteHeader {
+ rw.WriteHeader(grw.code)
+ }
// We have to reset response to it's pristine state when
// nothing is written to body or error is returned.
// See issue #424, #407.
- res.Writer = rw
- w.Reset(ioutil.Discard)
+ c.SetResponse(rw)
+ w.Reset(io.Discard)
+ } else if !grw.minLengthExceeded {
+ // Write uncompressed response
+ c.SetResponse(rw)
+ if grw.wroteHeader {
+ grw.ResponseWriter.WriteHeader(grw.code)
+ }
+ _, _ = grw.buffer.WriteTo(rw)
+ w.Reset(io.Discard)
}
- w.Close()
+ _ = w.Close()
+ bpool.Put(buf)
+ pool.Put(w)
}()
- grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw}
- res.Writer = grw
}
return next(c)
}
- }
+ }, nil
}
func (w *gzipResponseWriter) WriteHeader(code int) {
- if code == http.StatusNoContent { // Issue #489
- w.ResponseWriter.Header().Del(echo.HeaderContentEncoding)
- }
w.Header().Del(echo.HeaderContentLength) // Issue #444
- w.ResponseWriter.WriteHeader(code)
+
+ w.wroteHeader = true
+
+ // Delay writing of the header until we know if we'll actually compress the response
+ w.code = code
}
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
if w.Header().Get(echo.HeaderContentType) == "" {
w.Header().Set(echo.HeaderContentType, http.DetectContentType(b))
}
+ w.wroteBody = true
+
+ if !w.minLengthExceeded {
+ n, err := w.buffer.Write(b)
+
+ if w.buffer.Len() >= w.minLength {
+ w.minLengthExceeded = true
+
+ // The minimum length is exceeded, add Content-Encoding header and write the header
+ w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
+ if w.wroteHeader {
+ w.ResponseWriter.WriteHeader(w.code)
+ }
+
+ return w.Writer.Write(w.buffer.Bytes())
+ }
+
+ return n, err
+ }
+
return w.Writer.Write(b)
}
func (w *gzipResponseWriter) Flush() {
- w.Writer.(*gzip.Writer).Flush()
- if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
- flusher.Flush()
+ if !w.minLengthExceeded {
+ // Enforce compression because we will not know how much more data will come
+ w.minLengthExceeded = true
+ w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
+ if w.wroteHeader {
+ w.ResponseWriter.WriteHeader(w.code)
+ }
+
+ _, _ = w.Writer.Write(w.buffer.Bytes())
+ }
+
+ if gw, ok := w.Writer.(*gzip.Writer); ok {
+ gw.Flush()
}
+ _ = http.NewResponseController(w.ResponseWriter).Flush()
}
func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
- return w.ResponseWriter.(http.Hijacker).Hijack()
+ return http.NewResponseController(w.ResponseWriter).Hijack()
+}
+
+func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
+ return w.ResponseWriter
+}
+
+func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
+ if p, ok := w.ResponseWriter.(http.Pusher); ok {
+ return p.Push(target, opts)
+ }
+ return http.ErrNotSupported
+}
+
+func gzipCompressPool(config GzipConfig) sync.Pool {
+ return sync.Pool{
+ New: func() any {
+ w, err := gzip.NewWriterLevel(io.Discard, config.Level)
+ if err != nil {
+ return err
+ }
+ return w
+ },
+ }
+}
+
+func bufferPool() sync.Pool {
+ return sync.Pool{
+ New: func() any {
+ b := &bytes.Buffer{}
+ return b
+ },
+ }
}
diff --git a/middleware/compress_test.go b/middleware/compress_test.go
index ac5b6c3bb..084ffc9c7 100644
--- a/middleware/compress_test.go
+++ b/middleware/compress_test.go
@@ -1,102 +1,140 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
"bytes"
"compress/gzip"
"io"
- "io/ioutil"
"net/http"
"net/http/httptest"
+ "os"
"testing"
+ "time"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
-func TestGzip(t *testing.T) {
+func TestGzip_NoAcceptEncodingHeader(t *testing.T) {
+ // Skip if no Accept-Encoding header
+ h := Gzip()(func(c *echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })
+
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- // Skip if no Accept-Encoding header
- h := Gzip()(func(c echo.Context) error {
+ err := h(c)
+ assert.NoError(t, err)
+
+ assert.Equal(t, "test", rec.Body.String())
+}
+
+func TestMustGzipWithConfig_panics(t *testing.T) {
+ assert.Panics(t, func() {
+ GzipWithConfig(GzipConfig{Level: 999})
+ })
+}
+
+func TestGzip_AcceptEncodingHeader(t *testing.T) {
+ h := Gzip()(func(c *echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
- h(c)
- assert := assert.New(t)
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
- assert.Equal("test", rec.Body.String())
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
- // Gzip
- req = httptest.NewRequest(http.MethodGet, "/", nil)
- req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
- h(c)
- assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
- assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
- r, err := gzip.NewReader(rec.Body)
- if assert.NoError(err) {
- buf := new(bytes.Buffer)
- defer r.Close()
- buf.ReadFrom(r)
- assert.Equal("test", buf.String())
- }
+ err := h(c)
+ assert.NoError(t, err)
- chunkBuf := make([]byte, 5)
+ assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
+ assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
- // Gzip chunked
- req = httptest.NewRequest(http.MethodGet, "/", nil)
+ r, err := gzip.NewReader(rec.Body)
+ assert.NoError(t, err)
+ buf := new(bytes.Buffer)
+ defer r.Close()
+ buf.ReadFrom(r)
+ assert.Equal(t, "test", buf.String())
+}
+
+func TestGzip_chunked(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
- rec = httptest.NewRecorder()
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
- c = e.NewContext(req, rec)
- Gzip()(func(c echo.Context) error {
+ chunkChan := make(chan struct{})
+ waitChan := make(chan struct{})
+ h := Gzip()(func(c *echo.Context) error {
+ rc := http.NewResponseController(c.Response())
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Transfer-Encoding", "chunked")
// Write and flush the first part of the data
- c.Response().Write([]byte("test\n"))
- c.Response().Flush()
-
- // Read the first part of the data
- assert.True(rec.Flushed)
- assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
- r.Reset(rec.Body)
+ c.Response().Write([]byte("first\n"))
+ rc.Flush()
- _, err = io.ReadFull(r, chunkBuf)
- assert.NoError(err)
- assert.Equal("test\n", string(chunkBuf))
+ chunkChan <- struct{}{}
+ <-waitChan
// Write and flush the second part of the data
- c.Response().Write([]byte("test\n"))
- c.Response().Flush()
+ c.Response().Write([]byte("second\n"))
+ rc.Flush()
- _, err = io.ReadFull(r, chunkBuf)
- assert.NoError(err)
- assert.Equal("test\n", string(chunkBuf))
+ chunkChan <- struct{}{}
+ <-waitChan
// Write the final part of the data and return
- c.Response().Write([]byte("test"))
+ c.Response().Write([]byte("third"))
+
+ chunkChan <- struct{}{}
return nil
- })(c)
+ })
+
+ go func() {
+ err := h(c)
+ chunkChan <- struct{}{}
+ assert.NoError(t, err)
+ }()
+ <-chunkChan // wait for first write
+ waitChan <- struct{}{}
+
+ <-chunkChan // wait for second write
+ waitChan <- struct{}{}
+
+ <-chunkChan // wait for final write in handler
+ <-chunkChan // wait for return from handler
+ time.Sleep(5 * time.Millisecond) // to have time for flushing
+
+ assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
+
+ r, err := gzip.NewReader(rec.Body)
+ assert.NoError(t, err)
buf := new(bytes.Buffer)
- defer r.Close()
buf.ReadFrom(r)
- assert.Equal("test", buf.String())
+ assert.Equal(t, "first\nsecond\nthird", buf.String())
}
-func TestGzipNoContent(t *testing.T) {
+func TestGzip_NoContent(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- h := Gzip()(func(c echo.Context) error {
+ h := Gzip()(func(c *echo.Context) error {
return c.NoContent(http.StatusNoContent)
})
if assert.NoError(t, h(c)) {
@@ -106,10 +144,31 @@ func TestGzipNoContent(t *testing.T) {
}
}
-func TestGzipErrorReturned(t *testing.T) {
+func TestGzip_Empty(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := Gzip()(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "")
+ })
+ if assert.NoError(t, h(c)) {
+ assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
+ assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get(echo.HeaderContentType))
+ r, err := gzip.NewReader(rec.Body)
+ if assert.NoError(t, err) {
+ var buf bytes.Buffer
+ buf.ReadFrom(r)
+ assert.Equal(t, "", buf.String())
+ }
+ }
+}
+
+func TestGzip_ErrorReturned(t *testing.T) {
e := echo.New()
e.Use(Gzip())
- e.GET("/", func(c echo.Context) error {
+ e.GET("/", func(c *echo.Context) error {
return echo.ErrNotFound
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
@@ -120,15 +179,25 @@ func TestGzipErrorReturned(t *testing.T) {
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
}
+func TestGzipWithConfig_invalidLevel(t *testing.T) {
+ mw, err := GzipConfig{Level: 12}.ToMiddleware()
+ assert.EqualError(t, err, "invalid gzip level")
+ assert.Nil(t, mw)
+}
+
// Issue #806
func TestGzipWithStatic(t *testing.T) {
e := echo.New()
+ e.Filesystem = os.DirFS("../")
+
e.Use(Gzip())
- e.Static("/test", "../_fixture/images")
+ e.Static("/test", "_fixture/images")
req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
+
e.ServeHTTP(rec, req)
+
assert.Equal(t, http.StatusOK, rec.Code)
// Data is written out in chunks when Content-Length == "", so only
// validate the content length if it's not set.
@@ -138,7 +207,7 @@ func TestGzipWithStatic(t *testing.T) {
r, err := gzip.NewReader(rec.Body)
if assert.NoError(t, err) {
defer r.Close()
- want, err := ioutil.ReadFile("../_fixture/images/walle.png")
+ want, err := os.ReadFile("../_fixture/images/walle.png")
if assert.NoError(t, err) {
buf := new(bytes.Buffer)
buf.ReadFrom(r)
@@ -146,3 +215,184 @@ func TestGzipWithStatic(t *testing.T) {
}
}
}
+
+func TestGzipWithMinLength(t *testing.T) {
+ e := echo.New()
+ // Minimal response length
+ e.Use(GzipWithConfig(GzipConfig{MinLength: 10}))
+ e.GET("/", func(c *echo.Context) error {
+ c.Response().Write([]byte("foobarfoobar"))
+ return nil
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
+ r, err := gzip.NewReader(rec.Body)
+ if assert.NoError(t, err) {
+ buf := new(bytes.Buffer)
+ defer r.Close()
+ buf.ReadFrom(r)
+ assert.Equal(t, "foobarfoobar", buf.String())
+ }
+}
+
+func TestGzipWithMinLengthTooShort(t *testing.T) {
+ e := echo.New()
+ // Minimal response length
+ e.Use(GzipWithConfig(GzipConfig{MinLength: 10}))
+ e.GET("/", func(c *echo.Context) error {
+ c.Response().Write([]byte("test"))
+ return nil
+ })
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding))
+ assert.Contains(t, rec.Body.String(), "test")
+}
+
+func TestGzipWithResponseWithoutBody(t *testing.T) {
+ e := echo.New()
+
+ e.Use(Gzip())
+ e.GET("/", func(c *echo.Context) error {
+ return c.Redirect(http.StatusMovedPermanently, "http://localhost")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusMovedPermanently, rec.Code)
+ assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding))
+}
+
+func TestGzipWithMinLengthChunked(t *testing.T) {
+ e := echo.New()
+
+ // Gzip chunked
+ chunkBuf := make([]byte, 5)
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+
+ var r *gzip.Reader = nil
+
+ c := e.NewContext(req, rec)
+ next := func(c *echo.Context) error {
+ rc := http.NewResponseController(c.Response())
+ c.Response().Header().Set("Content-Type", "text/event-stream")
+ c.Response().Header().Set("Transfer-Encoding", "chunked")
+
+ // Write and flush the first part of the data
+ c.Response().Write([]byte("test\n"))
+ rc.Flush()
+
+ // Read the first part of the data
+ assert.True(t, rec.Flushed)
+ assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
+
+ var err error
+ r, err = gzip.NewReader(rec.Body)
+ assert.NoError(t, err)
+
+ _, err = io.ReadFull(r, chunkBuf)
+ assert.NoError(t, err)
+ assert.Equal(t, "test\n", string(chunkBuf))
+
+ // Write and flush the second part of the data
+ c.Response().Write([]byte("test\n"))
+ rc.Flush()
+
+ _, err = io.ReadFull(r, chunkBuf)
+ assert.NoError(t, err)
+ assert.Equal(t, "test\n", string(chunkBuf))
+
+ // Write the final part of the data and return
+ c.Response().Write([]byte("test"))
+ return nil
+ }
+ err := GzipWithConfig(GzipConfig{MinLength: 10})(next)(c)
+
+ assert.NoError(t, err)
+ assert.NotNil(t, r)
+
+ buf := new(bytes.Buffer)
+
+ buf.ReadFrom(r)
+ assert.Equal(t, "test", buf.String())
+
+ r.Close()
+}
+
+func TestGzipWithMinLengthNoContent(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c *echo.Context) error {
+ return c.NoContent(http.StatusNoContent)
+ })
+ if assert.NoError(t, h(c)) {
+ assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
+ assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
+ assert.Equal(t, 0, len(rec.Body.Bytes()))
+ }
+}
+
+func TestGzipResponseWriter_CanUnwrap(t *testing.T) {
+ trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
+ bdrw := gzipResponseWriter{
+ ResponseWriter: trwu,
+ }
+ result := bdrw.Unwrap()
+ assert.Equal(t, trwu, result)
+}
+
+func TestGzipResponseWriter_CanHijack(t *testing.T) {
+ trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
+ bdrw := gzipResponseWriter{
+ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
+ }
+ _, _, err := bdrw.Hijack()
+ assert.EqualError(t, err, "can hijack")
+}
+
+func TestGzipResponseWriter_CanNotHijack(t *testing.T) {
+ trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
+ bdrw := gzipResponseWriter{
+ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
+ }
+ _, _, err := bdrw.Hijack()
+ assert.EqualError(t, err, "feature not supported")
+}
+
+func BenchmarkGzip(b *testing.B) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+
+ h := Gzip()(func(c *echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ // Gzip
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h(c)
+ }
+}
diff --git a/middleware/context_timeout.go b/middleware/context_timeout.go
new file mode 100644
index 000000000..68465199a
--- /dev/null
+++ b/middleware/context_timeout.go
@@ -0,0 +1,71 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "context"
+ "errors"
+ "time"
+
+ "github.com/labstack/echo/v5"
+)
+
+// ContextTimeoutConfig defines the config for ContextTimeout middleware.
+type ContextTimeoutConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // ErrorHandler is a function when error arises in middeware execution.
+ ErrorHandler func(c *echo.Context, err error) error
+
+ // Timeout configures a timeout for the middleware
+ Timeout time.Duration
+}
+
+// ContextTimeout returns a middleware which returns error (503 Service Unavailable error) to client
+// when underlying method returns context.DeadlineExceeded error.
+func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc {
+ return ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: timeout})
+}
+
+// ContextTimeoutWithConfig returns a Timeout middleware with config.
+func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts Config to middleware.
+func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Timeout == 0 {
+ return nil, errors.New("timeout must be set")
+ }
+ if config.Skipper == nil {
+ config.Skipper = DefaultSkipper
+ }
+ if config.ErrorHandler == nil {
+ config.ErrorHandler = func(c *echo.Context, err error) error {
+ if err != nil && errors.Is(err, context.DeadlineExceeded) {
+ return echo.ErrServiceUnavailable.Wrap(err)
+ }
+ return err
+ }
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ timeoutContext, cancel := context.WithTimeout(c.Request().Context(), config.Timeout)
+ defer cancel()
+
+ c.SetRequest(c.Request().WithContext(timeoutContext))
+
+ if err := next(c); err != nil {
+ return config.ErrorHandler(c, err)
+ }
+ return nil
+ }
+ }, nil
+}
diff --git a/middleware/context_timeout_test.go b/middleware/context_timeout_test.go
new file mode 100644
index 000000000..c7ba76beb
--- /dev/null
+++ b/middleware/context_timeout_test.go
@@ -0,0 +1,229 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "context"
+ "errors"
+ "github.com/labstack/echo/v5"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestContextTimeoutSkipper(t *testing.T) {
+ t.Parallel()
+ m := ContextTimeoutWithConfig(ContextTimeoutConfig{
+ Skipper: func(context *echo.Context) bool {
+ return true
+ },
+ Timeout: 10 * time.Millisecond,
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e := echo.New()
+ c := e.NewContext(req, rec)
+
+ err := m(func(c *echo.Context) error {
+ if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil {
+ return err
+ }
+
+ return errors.New("response from handler")
+ })(c)
+
+ // if not skipped we would have not returned error due context timeout logic
+ assert.EqualError(t, err, "response from handler")
+}
+
+func TestContextTimeoutWithTimeout0(t *testing.T) {
+ t.Parallel()
+ assert.Panics(t, func() {
+ ContextTimeout(time.Duration(0))
+ })
+}
+
+func TestContextTimeoutErrorOutInHandler(t *testing.T) {
+ t.Parallel()
+ m := ContextTimeoutWithConfig(ContextTimeoutConfig{
+ // Timeout has to be defined or the whole flow for timeout middleware will be skipped
+ Timeout: 10 * time.Millisecond,
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e := echo.New()
+ c := e.NewContext(req, rec)
+
+ rec.Code = 1 // we want to be sure that even 200 will not be sent
+ err := m(func(c *echo.Context) error {
+ // this error must not be written to the client response. Middlewares upstream of timeout middleware must be able
+ // to handle returned error and this can be done only then handler has not yet committed (written status code)
+ // the response.
+ return echo.NewHTTPError(http.StatusTeapot, "err")
+ })(c)
+
+ assert.Error(t, err)
+ assert.EqualError(t, err, "code=418, message=err")
+ assert.Equal(t, 1, rec.Code)
+ assert.Equal(t, "", rec.Body.String())
+}
+
+func TestContextTimeoutSuccessfulRequest(t *testing.T) {
+ t.Parallel()
+ m := ContextTimeoutWithConfig(ContextTimeoutConfig{
+ // Timeout has to be defined or the whole flow for timeout middleware will be skipped
+ Timeout: 10 * time.Millisecond,
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e := echo.New()
+ c := e.NewContext(req, rec)
+
+ err := m(func(c *echo.Context) error {
+ return c.JSON(http.StatusCreated, map[string]string{"data": "ok"})
+ })(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusCreated, rec.Code)
+ assert.Equal(t, "{\"data\":\"ok\"}\n", rec.Body.String())
+}
+
+func TestContextTimeoutTestRequestClone(t *testing.T) {
+ t.Parallel()
+ req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode()))
+ req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"})
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rec := httptest.NewRecorder()
+
+ m := ContextTimeoutWithConfig(ContextTimeoutConfig{
+ // Timeout has to be defined or the whole flow for timeout middleware will be skipped
+ Timeout: 1 * time.Second,
+ })
+
+ e := echo.New()
+ c := e.NewContext(req, rec)
+
+ err := m(func(c *echo.Context) error {
+ // Cookie test
+ cookie, err := c.Request().Cookie("cookie")
+ if assert.NoError(t, err) {
+ assert.EqualValues(t, "cookie", cookie.Name)
+ assert.EqualValues(t, "value", cookie.Value)
+ }
+
+ // Form values
+ if assert.NoError(t, c.Request().ParseForm()) {
+ assert.EqualValues(t, "value", c.Request().FormValue("form"))
+ }
+
+ // Query string
+ assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0])
+ return nil
+ })(c)
+
+ assert.NoError(t, err)
+}
+
+func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) {
+ t.Parallel()
+
+ timeout := 10 * time.Millisecond
+ m := ContextTimeoutWithConfig(ContextTimeoutConfig{
+ Timeout: timeout,
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e := echo.New()
+ c := e.NewContext(req, rec)
+
+ err := m(func(c *echo.Context) error {
+ if err := sleepWithContext(c.Request().Context(), time.Duration(80*time.Millisecond)); err != nil {
+ return err
+ }
+ return c.String(http.StatusOK, "Hello, World!")
+ })(c)
+
+ assert.Error(t, err)
+ if assert.IsType(t, &echo.HTTPError{}, err) {
+ assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
+ assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message)
+ }
+}
+
+func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) {
+ t.Parallel()
+
+ timeoutErrorHandler := func(c *echo.Context, err error) error {
+ if err != nil {
+ if errors.Is(err, context.DeadlineExceeded) {
+ return &echo.HTTPError{
+ Code: http.StatusServiceUnavailable,
+ Message: "Timeout! change me",
+ }
+ }
+ return err
+ }
+ return nil
+ }
+
+ timeout := 50 * time.Millisecond
+ m := ContextTimeoutWithConfig(ContextTimeoutConfig{
+ Timeout: timeout,
+ ErrorHandler: timeoutErrorHandler,
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e := echo.New()
+ c := e.NewContext(req, rec)
+
+ err := m(func(c *echo.Context) error {
+ // NOTE: Very short periods are not reliable for tests due to Go routine scheduling and the unpredictable order
+ // for 1) request and 2) time goroutine. For most OS this works as expected, but MacOS seems most flaky.
+
+ if err := sleepWithContext(c.Request().Context(), 100*time.Millisecond); err != nil {
+ return err
+ }
+
+ // The Request Context should have a Deadline set by http.ContextTimeoutHandler
+ if _, ok := c.Request().Context().Deadline(); !ok {
+ assert.Fail(t, "No timeout set on Request Context")
+ }
+ return c.String(http.StatusOK, "Hello, World!")
+ })(c)
+
+ assert.IsType(t, &echo.HTTPError{}, err)
+ assert.Error(t, err)
+ assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
+ assert.Equal(t, "Timeout! change me", err.(*echo.HTTPError).Message)
+}
+
+func sleepWithContext(ctx context.Context, d time.Duration) error {
+ timer := time.NewTimer(d)
+
+ defer func() {
+ _ = timer.Stop()
+ }()
+
+ select {
+ case <-ctx.Done():
+ return context.DeadlineExceeded
+ case <-timer.C:
+ return nil
+ }
+}
diff --git a/middleware/cors.go b/middleware/cors.go
index 5dfe31f95..96ed16985 100644
--- a/middleware/cors.go
+++ b/middleware/cors.go
@@ -1,88 +1,190 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
+ "errors"
+ "fmt"
"net/http"
"strconv"
"strings"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
)
-type (
- // CORSConfig defines the config for CORS middleware.
- CORSConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
-
- // AllowOrigin defines a list of origins that may access the resource.
- // Optional. Default value []string{"*"}.
- AllowOrigins []string `yaml:"allow_origins"`
-
- // AllowMethods defines a list methods allowed when accessing the resource.
- // This is used in response to a preflight request.
- // Optional. Default value DefaultCORSConfig.AllowMethods.
- AllowMethods []string `yaml:"allow_methods"`
-
- // AllowHeaders defines a list of request headers that can be used when
- // making the actual request. This is in response to a preflight request.
- // Optional. Default value []string{}.
- AllowHeaders []string `yaml:"allow_headers"`
-
- // AllowCredentials indicates whether or not the response to the request
- // can be exposed when the credentials flag is true. When used as part of
- // a response to a preflight request, this indicates whether or not the
- // actual request can be made using credentials.
- // Optional. Default value false.
- AllowCredentials bool `yaml:"allow_credentials"`
-
- // ExposeHeaders defines a whitelist headers that clients are allowed to
- // access.
- // Optional. Default value []string{}.
- ExposeHeaders []string `yaml:"expose_headers"`
-
- // MaxAge indicates how long (in seconds) the results of a preflight request
- // can be cached.
- // Optional. Default value 0.
- MaxAge int `yaml:"max_age"`
- }
-)
+// CORSConfig defines the config for CORS middleware.
+type CORSConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
-var (
- // DefaultCORSConfig is the default CORS middleware config.
- DefaultCORSConfig = CORSConfig{
- Skipper: DefaultSkipper,
- AllowOrigins: []string{"*"},
- AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
- }
-)
+ // AllowOrigins determines the value of the Access-Control-Allow-Origin
+ // response header. This header defines a list of origins that may access the
+ // resource.
+ //
+ // Origin consist of following parts: `scheme + "://" + host + optional ":" + port`
+ // Wildcard can be used, but has to be set explicitly []string{"*"}
+ // Example: `https://example.com`, `http://example.com:8080`, `*`
+ //
+ // Security: use extreme caution when handling the origin, and carefully
+ // validate any logic. Remember that attackers may register hostile domain names.
+ // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
+ // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
+ //
+ // Mandatory.
+ AllowOrigins []string
+
+ // UnsafeAllowOriginFunc is an optional custom function to validate the origin. It takes the
+ // origin as an argument and returns
+ // - string, allowed origin
+ // - bool, true if allowed or false otherwise.
+ // - error, if an error is returned, it is returned immediately by the handler.
+ // If this option is set, AllowOrigins is ignored.
+ //
+ // Security: use extreme caution when handling the origin, and carefully
+ // validate any logic. Remember that attackers may register hostile (sub)domain names.
+ // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
+ //
+ // Sub-domain checks example:
+ // UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) {
+ // if strings.HasSuffix(origin, ".example.com") {
+ // return origin, true, nil
+ // }
+ // return "", false, nil
+ // },
+ //
+ // Optional.
+ UnsafeAllowOriginFunc func(c *echo.Context, origin string) (allowedOrigin string, allowed bool, err error)
+
+ // AllowMethods determines the value of the Access-Control-Allow-Methods
+ // response header. This header specified the list of methods allowed when
+ // accessing the resource. This is used in response to a preflight request.
+ //
+ // Optional. Default value DefaultCORSConfig.AllowMethods.
+ // If `allowMethods` is left empty, this middleware will fill for preflight
+ // request `Access-Control-Allow-Methods` header value
+ // from `Allow` header that echo.Router set into context.
+ //
+ // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
+ AllowMethods []string
+
+ // AllowHeaders determines the value of the Access-Control-Allow-Headers
+ // response header. This header is used in response to a preflight request to
+ // indicate which HTTP headers can be used when making the actual request.
+ //
+ // Optional. Defaults to empty list. No domains allowed for CORS.
+ //
+ // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
+ AllowHeaders []string
+
+ // AllowCredentials determines the value of the
+ // Access-Control-Allow-Credentials response header. This header indicates
+ // whether or not the response to the request can be exposed when the
+ // credentials mode (Request.credentials) is true. When used as part of a
+ // response to a preflight request, this indicates whether or not the actual
+ // request can be made using credentials. See also
+ // [MDN: Access-Control-Allow-Credentials].
+ //
+ // Optional. Default value false, in which case the header is not set.
+ //
+ // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`.
+ // See "Exploiting CORS misconfigurations for Bitcoins and bounties",
+ // https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
+ //
+ // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
+ AllowCredentials bool
+
+ // ExposeHeaders determines the value of Access-Control-Expose-Headers, which
+ // defines a list of headers that clients are allowed to access.
+ //
+ // Optional. Default value []string{}, in which case the header is not set.
+ //
+ // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header
+ ExposeHeaders []string
+
+ // MaxAge determines the value of the Access-Control-Max-Age response header.
+ // This header indicates how long (in seconds) the results of a preflight
+ // request can be cached.
+ // The header is set only if MaxAge != 0, negative value sends "0" which instructs browsers not to cache that response.
+ //
+ // Optional. Default value 0 - meaning header is not sent.
+ //
+ // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
+ MaxAge int
+}
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
-// See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
-func CORS() echo.MiddlewareFunc {
- return CORSWithConfig(DefaultCORSConfig)
+// See also [MDN: Cross-Origin Resource Sharing (CORS)].
+//
+// Origin consist of following parts: `scheme + "://" + host + optional ":" + port`
+// Wildcard `*` can be used, but has to be set explicitly.
+// Example: `https://example.com`, `http://example.com:8080`, `*`
+//
+// Security: Poorly configured CORS can compromise security because it allows
+// relaxation of the browser's Same-Origin policy. See [Exploiting CORS
+// misconfigurations for Bitcoins and bounties] and [Portswigger: Cross-origin
+// resource sharing (CORS)] for more details.
+//
+// [MDN: Cross-Origin Resource Sharing (CORS)]: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
+// [Exploiting CORS misconfigurations for Bitcoins and bounties]: https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
+// [Portswigger: Cross-origin resource sharing (CORS)]: https://portswigger.net/web-security/cors
+func CORS(allowOrigins ...string) echo.MiddlewareFunc {
+ c := CORSConfig{
+ AllowOrigins: allowOrigins,
+ }
+ return CORSWithConfig(c)
}
-// CORSWithConfig returns a CORS middleware with config.
-// See: `CORS()`.
+// CORSWithConfig returns a CORS middleware with config or panics on invalid configuration.
+// See: [CORS].
func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration
+func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
- config.Skipper = DefaultCORSConfig.Skipper
- }
- if len(config.AllowOrigins) == 0 {
- config.AllowOrigins = DefaultCORSConfig.AllowOrigins
+ config.Skipper = DefaultSkipper
}
+ hasCustomAllowMethods := true
if len(config.AllowMethods) == 0 {
- config.AllowMethods = DefaultCORSConfig.AllowMethods
+ hasCustomAllowMethods = false
+ config.AllowMethods = []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}
}
allowMethods := strings.Join(config.AllowMethods, ",")
allowHeaders := strings.Join(config.AllowHeaders, ",")
exposeHeaders := strings.Join(config.ExposeHeaders, ",")
- maxAge := strconv.Itoa(config.MaxAge)
+
+ maxAge := "0"
+ if config.MaxAge > 0 {
+ maxAge = strconv.Itoa(config.MaxAge)
+ }
+
+ allowOriginFunc := config.UnsafeAllowOriginFunc
+ if config.UnsafeAllowOriginFunc == nil {
+ if len(config.AllowOrigins) == 0 {
+ return nil, errors.New("at least one AllowOrigins is required or UnsafeAllowOriginFunc must be provided")
+ }
+ allowOriginFunc = config.defaultAllowOriginFunc
+ for _, origin := range config.AllowOrigins {
+ if origin == "*" {
+ if config.AllowCredentials {
+ return nil, fmt.Errorf("* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc")
+ }
+ allowOriginFunc = config.starAllowOriginFunc
+ break
+ }
+ if err := validateOrigin(origin, "allow origin"); err != nil {
+ return nil, err
+ }
+ }
+ config.AllowOrigins = append([]string(nil), config.AllowOrigins...)
+ }
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) error {
+ return func(c *echo.Context) error {
if config.Skipper(c) {
return next(c)
}
@@ -90,46 +192,84 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
req := c.Request()
res := c.Response()
origin := req.Header.Get(echo.HeaderOrigin)
- allowOrigin := ""
- // Check allowed origins
- for _, o := range config.AllowOrigins {
- if o == "*" && config.AllowCredentials {
- allowOrigin = origin
- break
- }
- if o == "*" || o == origin {
- allowOrigin = o
- break
+ res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
+
+ // Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
+ // Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
+ // For simplicity we just consider method type and later `Origin` header.
+ preflight := req.Method == http.MethodOptions
+
+ // Although router adds special handler in case of OPTIONS method we avoid calling next for OPTIONS in this middleware
+ // as CORS requests do not have cookies / authentication headers by default, so we could get stuck in auth
+ // middlewares by calling next(c).
+ // But we still want to send `Allow` header as response in case of Non-CORS OPTIONS request as router default
+ // handler does.
+ routerAllowMethods := ""
+ if preflight {
+ tmpAllowMethods, ok := c.Get(echo.ContextKeyHeaderAllow).(string)
+ if ok && tmpAllowMethods != "" {
+ routerAllowMethods = tmpAllowMethods
+ c.Response().Header().Set(echo.HeaderAllow, routerAllowMethods)
}
- if matchSubdomain(origin, o) {
- allowOrigin = origin
- break
+ }
+
+ // No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
+ if origin == "" {
+ if preflight { // req.Method=OPTIONS
+ return c.NoContent(http.StatusNoContent)
}
+ return next(c) // let non-browser calls through
}
- // Simple request
- if req.Method != http.MethodOptions {
- res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
- res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
- if config.AllowCredentials {
- res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
+ allowedOrigin, allowed, err := allowOriginFunc(c, origin)
+ if err != nil {
+ return err
+ }
+ if !allowed {
+ // Origin existed and was NOT allowed
+ if preflight {
+ // From: https://github.com/labstack/echo/issues/2767
+ // If the request's origin isn't allowed by the CORS configuration,
+ // the middleware should simply omit the relevant CORS headers from the response
+ // and let the browser fail the CORS check (if any).
+ return c.NoContent(http.StatusNoContent)
}
+ // From: https://github.com/labstack/echo/issues/2767
+ // no CORS middleware should block non-preflight requests;
+ // such requests should be let through. One reason is that not all requests that
+ // carry an Origin header participate in the CORS protocol.
+ return next(c)
+ }
+
+ // Origin existed and was allowed
+
+ res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
+ if config.AllowCredentials {
+ res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
+ }
+
+ // Simple request will be let though
+ if !preflight {
if exposeHeaders != "" {
res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
}
return next(c)
}
-
- // Preflight request
- res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
+ // Below code is for Preflight (OPTIONS) request
+ //
+ // Preflight will end with c.NoContent(http.StatusNoContent) as we do not know if
+ // at the end of handler chain is actual OPTIONS route or 404/405 route which
+ // response code will confuse browsers
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
- res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
- res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods)
- if config.AllowCredentials {
- res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
+
+ if !hasCustomAllowMethods && routerAllowMethods != "" {
+ res.Header().Set(echo.HeaderAccessControlAllowMethods, routerAllowMethods)
+ } else {
+ res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods)
}
+
if allowHeaders != "" {
res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders)
} else {
@@ -138,10 +278,23 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
res.Header().Set(echo.HeaderAccessControlAllowHeaders, h)
}
}
- if config.MaxAge > 0 {
+ if config.MaxAge != 0 {
res.Header().Set(echo.HeaderAccessControlMaxAge, maxAge)
}
return c.NoContent(http.StatusNoContent)
}
+ }, nil
+}
+
+func (config CORSConfig) starAllowOriginFunc(c *echo.Context, origin string) (string, bool, error) {
+ return "*", true, nil
+}
+
+func (config CORSConfig) defaultAllowOriginFunc(c *echo.Context, origin string) (string, bool, error) {
+ for _, allowedOrigin := range config.AllowOrigins {
+ if strings.EqualFold(allowedOrigin, origin) {
+ return allowedOrigin, true, nil
+ }
}
+ return "", false, nil
}
diff --git a/middleware/cors_test.go b/middleware/cors_test.go
index acfdf47bc..5de4ca063 100644
--- a/middleware/cors_test.go
+++ b/middleware/cors_test.go
@@ -1,85 +1,628 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
+ "cmp"
+ "errors"
"net/http"
"net/http/httptest"
+ "strings"
"testing"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
func TestCORS(t *testing.T) {
e := echo.New()
-
- // Wildcard origin
- req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req := httptest.NewRequest(http.MethodOptions, "/", nil) // Preflight request
+ req.Header.Set(echo.HeaderOrigin, "http://example.com")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- h := CORS()(echo.NotFoundHandler)
- h(c)
- assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
- // Allow origins
- req = httptest.NewRequest(http.MethodGet, "/", nil)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
- h = CORSWithConfig(CORSConfig{
- AllowOrigins: []string{"localhost"},
- })(echo.NotFoundHandler)
- req.Header.Set(echo.HeaderOrigin, "localhost")
- h(c)
- assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
-
- // Preflight request
- req = httptest.NewRequest(http.MethodOptions, "/", nil)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
- req.Header.Set(echo.HeaderOrigin, "localhost")
- req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
- cors := CORSWithConfig(CORSConfig{
- AllowOrigins: []string{"localhost"},
- AllowCredentials: true,
- MaxAge: 3600,
+ mw := CORS("*")
+ handler := mw(func(c *echo.Context) error {
+ return nil
})
- h = cors(echo.NotFoundHandler)
- h(c)
- assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
- assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
- assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
- assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
-
- // Preflight request with `AllowOrigins` *
- req = httptest.NewRequest(http.MethodOptions, "/", nil)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
- req.Header.Set(echo.HeaderOrigin, "localhost")
- req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
- cors = CORSWithConfig(CORSConfig{
- AllowOrigins: []string{"*"},
- AllowCredentials: true,
- MaxAge: 3600,
- })
- h = cors(echo.NotFoundHandler)
- h(c)
- assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
- assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
- assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
- assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
-
- // Preflight request with `AllowOrigins` which allow all subdomains with *
- req = httptest.NewRequest(http.MethodOptions, "/", nil)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
- req.Header.Set(echo.HeaderOrigin, "http://aaa.example.com")
- cors = CORSWithConfig(CORSConfig{
- AllowOrigins: []string{"http://*.example.com"},
- })
- h = cors(echo.NotFoundHandler)
- h(c)
- assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
- req.Header.Set(echo.HeaderOrigin, "http://bbb.example.com")
- h(c)
- assert.Equal(t, "http://bbb.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ err := handler(c)
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusNoContent, rec.Code)
+ assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+}
+
+func TestCORSConfig(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenConfig *CORSConfig
+ whenMethod string
+ whenHeaders map[string]string
+ expectHeaders map[string]string
+ notExpectHeaders map[string]string
+ expectErr string
+ }{
+ {
+ name: "ok, wildcard origin",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"*"},
+ },
+ whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"},
+ expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "*"},
+ },
+ {
+ name: "ok, wildcard AllowedOrigin with no Origin header in request",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"*"},
+ },
+ notExpectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: ""},
+ },
+ {
+ name: "ok, specific AllowOrigins and AllowCredentials",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"http://localhost", "http://localhost:8080"},
+ AllowCredentials: true,
+ MaxAge: 3600,
+ },
+ whenHeaders: map[string]string{echo.HeaderOrigin: "http://localhost"},
+ expectHeaders: map[string]string{
+ echo.HeaderAccessControlAllowOrigin: "http://localhost",
+ echo.HeaderAccessControlAllowCredentials: "true",
+ },
+ },
+ {
+ name: "ok, preflight request with matching origin for `AllowOrigins`",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"http://localhost"},
+ AllowCredentials: true,
+ MaxAge: 3600,
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{
+ echo.HeaderOrigin: "http://localhost",
+ echo.HeaderContentType: echo.MIMEApplicationJSON,
+ },
+ expectHeaders: map[string]string{
+ echo.HeaderAccessControlAllowOrigin: "http://localhost",
+ echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
+ echo.HeaderAccessControlAllowCredentials: "true",
+ echo.HeaderAccessControlMaxAge: "3600",
+ },
+ },
+ {
+ name: "ok, preflight request when `Access-Control-Max-Age` is set",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"http://localhost"},
+ AllowCredentials: true,
+ MaxAge: 1,
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{
+ echo.HeaderOrigin: "http://localhost",
+ echo.HeaderContentType: echo.MIMEApplicationJSON,
+ },
+ expectHeaders: map[string]string{
+ echo.HeaderAccessControlMaxAge: "1",
+ },
+ },
+ {
+ name: "ok, preflight request when `Access-Control-Max-Age` is set to 0 - not to cache response",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"http://localhost"},
+ AllowCredentials: true,
+ MaxAge: -1, // forces `Access-Control-Max-Age: 0`
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{
+ echo.HeaderOrigin: "http://localhost",
+ echo.HeaderContentType: echo.MIMEApplicationJSON,
+ },
+ expectHeaders: map[string]string{
+ echo.HeaderAccessControlMaxAge: "0",
+ },
+ },
+ {
+ name: "ok, CORS check are skipped",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"http://localhost"},
+ AllowCredentials: true,
+ Skipper: func(c *echo.Context) bool {
+ return true
+ },
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{
+ echo.HeaderOrigin: "http://localhost",
+ echo.HeaderContentType: echo.MIMEApplicationJSON,
+ },
+ notExpectHeaders: map[string]string{
+ echo.HeaderAccessControlAllowOrigin: "localhost",
+ echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
+ echo.HeaderAccessControlAllowCredentials: "true",
+ echo.HeaderAccessControlMaxAge: "3600",
+ },
+ },
+ {
+ name: "nok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"*"},
+ AllowCredentials: true,
+ MaxAge: 3600,
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{
+ echo.HeaderOrigin: "localhost",
+ echo.HeaderContentType: echo.MIMEApplicationJSON,
+ },
+ expectErr: `* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc`,
+ },
+ {
+ name: "nok, preflight request with invalid `AllowOrigins` value",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"http://server", "missing-scheme"},
+ },
+ expectErr: `allow origin is missing scheme or host: missing-scheme`,
+ },
+ {
+ name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` false",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"*"},
+ AllowCredentials: false, // important for this testcase
+ MaxAge: 3600,
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{
+ echo.HeaderOrigin: "localhost",
+ echo.HeaderContentType: echo.MIMEApplicationJSON,
+ },
+ expectHeaders: map[string]string{
+ echo.HeaderAccessControlAllowOrigin: "*",
+ echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
+ echo.HeaderAccessControlMaxAge: "3600",
+ },
+ notExpectHeaders: map[string]string{
+ echo.HeaderAccessControlAllowCredentials: "",
+ },
+ },
+ {
+ name: "ok, INSECURE preflight request with wildcard `AllowOrigins` and `AllowCredentials` true",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"*"},
+ AllowCredentials: true,
+ MaxAge: 3600,
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{
+ echo.HeaderOrigin: "localhost",
+ echo.HeaderContentType: echo.MIMEApplicationJSON,
+ },
+ expectErr: `* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc`,
+ },
+ {
+ name: "ok, preflight request with Access-Control-Request-Headers",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"*"},
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{
+ echo.HeaderOrigin: "localhost",
+ echo.HeaderContentType: echo.MIMEApplicationJSON,
+ echo.HeaderAccessControlRequestHeaders: "Special-Request-Header",
+ },
+ expectHeaders: map[string]string{
+ echo.HeaderAccessControlAllowOrigin: "*",
+ echo.HeaderAccessControlAllowHeaders: "Special-Request-Header",
+ echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
+ },
+ },
+ {
+ name: "ok, preflight request with `AllowOrigins` which allow all subdomains aaa with *",
+ givenConfig: &CORSConfig{
+ UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (allowedOrigin string, allowed bool, err error) {
+ if strings.HasSuffix(origin, ".example.com") {
+ allowed = true
+ }
+ return origin, allowed, nil
+ },
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{echo.HeaderOrigin: "http://aaa.example.com"},
+ expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://aaa.example.com"},
+ },
+ {
+ name: "ok, preflight request with `AllowOrigins` which allow all subdomains bbb with *",
+ givenConfig: &CORSConfig{
+ UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) {
+ if strings.HasSuffix(origin, ".example.com") {
+ return origin, true, nil
+ }
+ return "", false, nil
+ },
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{echo.HeaderOrigin: "http://bbb.example.com"},
+ expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://bbb.example.com"},
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ var mw echo.MiddlewareFunc
+ var err error
+ if tc.givenConfig != nil {
+ mw, err = tc.givenConfig.ToMiddleware()
+ } else {
+ mw, err = CORSConfig{}.ToMiddleware()
+ }
+ if err != nil {
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ return
+ }
+ t.Fatal(err)
+ }
+
+ h := mw(func(c *echo.Context) error {
+ return nil
+ })
+
+ method := cmp.Or(tc.whenMethod, http.MethodGet)
+ req := httptest.NewRequest(method, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ for k, v := range tc.whenHeaders {
+ req.Header.Set(k, v)
+ }
+
+ err = h(c)
+
+ assert.NoError(t, err)
+ header := rec.Header()
+ for k, v := range tc.expectHeaders {
+ assert.Equal(t, v, header.Get(k), "header: `%v` should be `%v`", k, v)
+ }
+ for k, v := range tc.notExpectHeaders {
+ if v == "" {
+ assert.Len(t, header.Values(k), 0, "header: `%v` should not be set", k)
+ } else {
+ assert.NotEqual(t, v, header.Get(k), "header: `%v` should not be `%v`", k, v)
+ }
+ }
+ })
+ }
+}
+
+func Test_allowOriginScheme(t *testing.T) {
+ tests := []struct {
+ domain, pattern string
+ expected bool
+ }{
+ {
+ domain: "http://example.com",
+ pattern: "http://example.com",
+ expected: true,
+ },
+ {
+ domain: "https://example.com",
+ pattern: "https://example.com",
+ expected: true,
+ },
+ {
+ domain: "http://example.com",
+ pattern: "https://example.com",
+ expected: false,
+ },
+ {
+ domain: "https://example.com",
+ pattern: "http://example.com",
+ expected: false,
+ },
+ }
+
+ e := echo.New()
+ for _, tt := range tests {
+ req := httptest.NewRequest(http.MethodOptions, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ req.Header.Set(echo.HeaderOrigin, tt.domain)
+ cors := CORSWithConfig(CORSConfig{
+ AllowOrigins: []string{tt.pattern},
+ })
+ h := cors(func(c *echo.Context) error { return echo.ErrNotFound })
+ h(c)
+
+ if tt.expected {
+ assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ } else {
+ assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
+ }
+ }
+}
+
+func TestCORSWithConfig_AllowMethods(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenAllowOrigins []string
+ givenAllowMethods []string
+ whenAllowContextKey string
+ whenOrigin string
+ expectAllow string
+ expectAccessControlAllowMethods string
+ }{
+ {
+ name: "custom AllowMethods, preflight, no origin, sets only allow header from context key",
+ givenAllowOrigins: []string{"*"},
+ givenAllowMethods: []string{http.MethodGet, http.MethodHead},
+ whenAllowContextKey: "OPTIONS, GET",
+ whenOrigin: "",
+ expectAllow: "OPTIONS, GET",
+ },
+ {
+ name: "default AllowMethods, preflight, no origin, no allow header in context key and in response",
+ givenAllowOrigins: []string{"*"},
+ givenAllowMethods: nil,
+ whenAllowContextKey: "",
+ whenOrigin: "",
+ expectAllow: "",
+ },
+ {
+ name: "custom AllowMethods, preflight, existing origin, sets both headers different values",
+ givenAllowOrigins: []string{"*"},
+ givenAllowMethods: []string{http.MethodGet, http.MethodHead},
+ whenAllowContextKey: "OPTIONS, GET",
+ whenOrigin: "http://google.com",
+ expectAllow: "OPTIONS, GET",
+ expectAccessControlAllowMethods: "GET,HEAD",
+ },
+ {
+ name: "default AllowMethods, preflight, existing origin, sets both headers",
+ givenAllowOrigins: []string{"*"},
+ givenAllowMethods: nil,
+ whenAllowContextKey: "OPTIONS, GET",
+ whenOrigin: "http://google.com",
+ expectAllow: "OPTIONS, GET",
+ expectAccessControlAllowMethods: "OPTIONS, GET",
+ },
+ {
+ name: "default AllowMethods, preflight, existing origin, no allows, sets only CORS allow methods",
+ givenAllowOrigins: []string{"*"},
+ givenAllowMethods: nil,
+ whenAllowContextKey: "",
+ whenOrigin: "http://google.com",
+ expectAllow: "",
+ expectAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+ e.GET("/test", func(c *echo.Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+
+ cors := CORSWithConfig(CORSConfig{
+ AllowOrigins: tc.givenAllowOrigins,
+ AllowMethods: tc.givenAllowMethods,
+ })
+
+ req := httptest.NewRequest(http.MethodOptions, "/test", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ req.Header.Set(echo.HeaderOrigin, tc.whenOrigin)
+ if tc.whenAllowContextKey != "" {
+ c.Set(echo.ContextKeyHeaderAllow, tc.whenAllowContextKey)
+ }
+
+ h := cors(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+ h(c)
+
+ assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow))
+ assert.Equal(t, tc.expectAccessControlAllowMethods, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
+ })
+ }
+}
+
+func TestCorsHeaders(t *testing.T) {
+ tests := []struct {
+ name string
+ originDomain string
+ method string
+ allowedOrigin string
+ expected bool
+ expectStatus int
+ expectAllowHeader string
+ }{
+ {
+ name: "non-preflight request, allow any origin, missing origin header = no CORS logic done",
+ originDomain: "",
+ allowedOrigin: "*",
+ method: http.MethodGet,
+ expected: false,
+ expectStatus: http.StatusOK,
+ },
+ {
+ name: "non-preflight request, allow any origin, specific origin domain",
+ originDomain: "http://example.com",
+ allowedOrigin: "*",
+ method: http.MethodGet,
+ expected: true,
+ expectStatus: http.StatusOK,
+ },
+ {
+ name: "non-preflight request, allow specific origin, missing origin header = no CORS logic done",
+ originDomain: "", // Request does not have Origin header
+ allowedOrigin: "http://example.com",
+ method: http.MethodGet,
+ expected: false,
+ expectStatus: http.StatusOK,
+ },
+ {
+ name: "non-preflight request, allow specific origin, different origin header = CORS logic failure",
+ originDomain: "http://bar.com",
+ allowedOrigin: "http://example.com",
+ method: http.MethodGet,
+ expected: false,
+ expectStatus: http.StatusOK,
+ },
+ {
+ name: "non-preflight request, allow specific origin, matching origin header = CORS logic done",
+ originDomain: "http://example.com",
+ allowedOrigin: "http://example.com",
+ method: http.MethodGet,
+ expected: true,
+ expectStatus: http.StatusOK,
+ },
+ {
+ name: "preflight, allow any origin, missing origin header = no CORS logic done",
+ originDomain: "", // Request does not have Origin header
+ allowedOrigin: "*",
+ method: http.MethodOptions,
+ expected: false,
+ expectStatus: http.StatusNoContent,
+ expectAllowHeader: "OPTIONS, GET, POST",
+ },
+ {
+ name: "preflight, allow any origin, existing origin header = CORS logic done",
+ originDomain: "http://example.com",
+ allowedOrigin: "*",
+ method: http.MethodOptions,
+ expected: true,
+ expectStatus: http.StatusNoContent,
+ expectAllowHeader: "OPTIONS, GET, POST",
+ },
+ {
+ name: "preflight, allow any origin, missing origin header = no CORS logic done",
+ originDomain: "", // Request does not have Origin header
+ allowedOrigin: "http://example.com",
+ method: http.MethodOptions,
+ expected: false,
+ expectStatus: http.StatusNoContent,
+ expectAllowHeader: "OPTIONS, GET, POST",
+ },
+ {
+ name: "preflight, allow specific origin, different origin header = no CORS logic done",
+ originDomain: "http://bar.com",
+ allowedOrigin: "http://example.com",
+ method: http.MethodOptions,
+ expected: false,
+ expectStatus: http.StatusNoContent,
+ expectAllowHeader: "OPTIONS, GET, POST",
+ },
+ {
+ name: "preflight, allow specific origin, matching origin header = CORS logic done",
+ originDomain: "http://example.com",
+ allowedOrigin: "http://example.com",
+ method: http.MethodOptions,
+ expected: true,
+ expectStatus: http.StatusNoContent,
+ expectAllowHeader: "OPTIONS, GET, POST",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ e.Use(CORSWithConfig(CORSConfig{
+ AllowOrigins: []string{tc.allowedOrigin},
+ //AllowCredentials: true,
+ //MaxAge: 3600,
+ }))
+
+ e.GET("/", func(c *echo.Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+ e.POST("/", func(c *echo.Context) error {
+ return c.String(http.StatusCreated, "OK")
+ })
+
+ req := httptest.NewRequest(tc.method, "/", nil)
+ rec := httptest.NewRecorder()
+
+ if tc.originDomain != "" {
+ req.Header.Set(echo.HeaderOrigin, tc.originDomain)
+ }
+
+ // we run through whole Echo handler chain to see how CORS works with Router OPTIONS handler
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
+ assert.Equal(t, tc.expectAllowHeader, rec.Header().Get(echo.HeaderAllow))
+ assert.Equal(t, tc.expectStatus, rec.Code)
+
+ expectedAllowOrigin := ""
+ if tc.allowedOrigin == "*" {
+ expectedAllowOrigin = "*"
+ } else {
+ expectedAllowOrigin = tc.originDomain
+ }
+ switch {
+ case tc.expected && tc.method == http.MethodOptions:
+ assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
+ assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+
+ assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
+
+ case tc.expected && tc.method == http.MethodGet:
+ assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
+ default:
+ assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
+ assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
+ }
+ })
+
+ }
+}
+
+func Test_allowOriginFunc(t *testing.T) {
+ returnTrue := func(c *echo.Context, origin string) (string, bool, error) {
+ return origin, true, nil
+ }
+ returnFalse := func(c *echo.Context, origin string) (string, bool, error) {
+ return origin, false, nil
+ }
+ returnError := func(c *echo.Context, origin string) (string, bool, error) {
+ return origin, true, errors.New("this is a test error")
+ }
+
+ allowOriginFuncs := []func(c *echo.Context, origin string) (string, bool, error){
+ returnTrue,
+ returnFalse,
+ returnError,
+ }
+
+ const origin = "http://example.com"
+
+ e := echo.New()
+ for _, allowOriginFunc := range allowOriginFuncs {
+ req := httptest.NewRequest(http.MethodOptions, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ req.Header.Set(echo.HeaderOrigin, origin)
+ cors, err := CORSConfig{UnsafeAllowOriginFunc: allowOriginFunc}.ToMiddleware()
+ assert.NoError(t, err)
+
+ h := cors(func(c *echo.Context) error { return echo.ErrNotFound })
+ err = h(c)
+
+ allowedOrigin, allowed, expectedErr := allowOriginFunc(c, origin)
+ if expectedErr != nil {
+ assert.Equal(t, expectedErr, err)
+ assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ continue
+ }
+
+ if allowed {
+ assert.Equal(t, allowedOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ } else {
+ assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ }
+ }
}
diff --git a/middleware/csrf.go b/middleware/csrf.go
index 09a66bb64..33757b760 100644
--- a/middleware/csrf.go
+++ b/middleware/csrf.go
@@ -1,91 +1,120 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
"crypto/subtle"
- "errors"
"net/http"
+ "slices"
"strings"
"time"
- "github.com/labstack/echo/v4"
- "github.com/labstack/gommon/random"
+ "github.com/labstack/echo/v5"
)
-type (
- // CSRFConfig defines the config for CSRF middleware.
- CSRFConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
-
- // TokenLength is the length of the generated token.
- TokenLength uint8 `yaml:"token_length"`
- // Optional. Default value 32.
-
- // TokenLookup is a string in the form of ":" that is used
- // to extract token from the request.
- // Optional. Default value "header:X-CSRF-Token".
- // Possible values:
- // - "header:"
- // - "form:"
- // - "query:"
- TokenLookup string `yaml:"token_lookup"`
-
- // Context key to store generated CSRF token into context.
- // Optional. Default value "csrf".
- ContextKey string `yaml:"context_key"`
-
- // Name of the CSRF cookie. This cookie will store CSRF token.
- // Optional. Default value "csrf".
- CookieName string `yaml:"cookie_name"`
-
- // Domain of the CSRF cookie.
- // Optional. Default value none.
- CookieDomain string `yaml:"cookie_domain"`
-
- // Path of the CSRF cookie.
- // Optional. Default value none.
- CookiePath string `yaml:"cookie_path"`
-
- // Max age (in seconds) of the CSRF cookie.
- // Optional. Default value 86400 (24hr).
- CookieMaxAge int `yaml:"cookie_max_age"`
-
- // Indicates if CSRF cookie is secure.
- // Optional. Default value false.
- CookieSecure bool `yaml:"cookie_secure"`
-
- // Indicates if CSRF cookie is HTTP only.
- // Optional. Default value false.
- CookieHTTPOnly bool `yaml:"cookie_http_only"`
- }
-
- // csrfTokenExtractor defines a function that takes `echo.Context` and returns
- // either a token or an error.
- csrfTokenExtractor func(echo.Context) (string, error)
-)
+// CSRFConfig defines the config for CSRF middleware.
+type CSRFConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+ // TrustedOrigin permits any request with `Sec-Fetch-Site` header whose `Origin` header
+ // exactly matches the specified value.
+ // Values should be formated as Origin header "scheme://host[:port]".
+ //
+ // See [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin
+ // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers
+ TrustedOrigins []string
-var (
- // DefaultCSRFConfig is the default CSRF middleware config.
- DefaultCSRFConfig = CSRFConfig{
- Skipper: DefaultSkipper,
- TokenLength: 32,
- TokenLookup: "header:" + echo.HeaderXCSRFToken,
- ContextKey: "csrf",
- CookieName: "_csrf",
- CookieMaxAge: 86400,
- }
-)
+ // AllowSecFetchSameSite allows custom behaviour for `Sec-Fetch-Site` requests that are about to
+ // fail with CRSF error, to be allowed or replaced with custom error.
+ // This function applies to `Sec-Fetch-Site` values:
+ // - `same-site` same registrable domain (subdomain and/or different port)
+ // - `cross-site` request originates from different site
+ // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers
+ AllowSecFetchSiteFunc func(c *echo.Context) (bool, error)
+
+ // TokenLength is the length of the generated token.
+ TokenLength uint8
+ // Optional. Default value 32.
+
+ // TokenLookup is a string in the form of ":" or ":,:" that is used
+ // to extract token from the request.
+ // Optional. Default value "header:X-CSRF-Token".
+ // Possible values:
+ // - "header:" or "header::"
+ // - "query:"
+ // - "form:"
+ // Multiple sources example:
+ // - "header:X-CSRF-Token,query:csrf"
+ TokenLookup string `yaml:"token_lookup"`
+
+ // Generator defines a function to generate token.
+ // Optional. Defaults tp randomString(TokenLength).
+ Generator func() string
+
+ // Context key to store generated CSRF token into context.
+ // Optional. Default value "csrf".
+ ContextKey string
+
+ // Name of the CSRF cookie. This cookie will store CSRF token.
+ // Optional. Default value "csrf".
+ CookieName string
+
+ // Domain of the CSRF cookie.
+ // Optional. Default value none.
+ CookieDomain string
+
+ // Path of the CSRF cookie.
+ // Optional. Default value none.
+ CookiePath string
+
+ // Max age (in seconds) of the CSRF cookie.
+ // Optional. Default value 86400 (24hr).
+ CookieMaxAge int
+
+ // Indicates if CSRF cookie is secure.
+ // Optional. Default value false.
+ CookieSecure bool
+
+ // Indicates if CSRF cookie is HTTP only.
+ // Optional. Default value false.
+ CookieHTTPOnly bool
+
+ // Indicates SameSite mode of the CSRF cookie.
+ // Optional. Default value SameSiteDefaultMode.
+ CookieSameSite http.SameSite
+
+ // ErrorHandler defines a function which is executed for returning custom errors.
+ ErrorHandler func(c *echo.Context, err error) error
+}
+
+// ErrCSRFInvalid is returned when CSRF check fails
+var ErrCSRFInvalid = &echo.HTTPError{Code: http.StatusForbidden, Message: "invalid csrf token"}
+
+// DefaultCSRFConfig is the default CSRF middleware config.
+var DefaultCSRFConfig = CSRFConfig{
+ Skipper: DefaultSkipper,
+ TokenLength: 32,
+ TokenLookup: "header:" + echo.HeaderXCSRFToken,
+ ContextKey: "csrf",
+ CookieName: "_csrf",
+ CookieMaxAge: 86400,
+ CookieSameSite: http.SameSiteDefaultMode,
+}
// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
func CSRF() echo.MiddlewareFunc {
- c := DefaultCSRFConfig
- return CSRFWithConfig(c)
+ return CSRFWithConfig(DefaultCSRFConfig)
}
-// CSRFWithConfig returns a CSRF middleware with config.
-// See `CSRF()`.
+// CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration.
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration
+func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultCSRFConfig.Skipper
@@ -93,6 +122,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
if config.TokenLength == 0 {
config.TokenLength = DefaultCSRFConfig.TokenLength
}
+ if config.Generator == nil {
+ config.Generator = createRandomStringGenerator(config.TokenLength)
+ }
if config.TokenLookup == "" {
config.TokenLookup = DefaultCSRFConfig.TokenLookup
}
@@ -105,45 +137,79 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
if config.CookieMaxAge == 0 {
config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge
}
+ if config.CookieSameSite == http.SameSiteNoneMode {
+ config.CookieSecure = true
+ }
+ if len(config.TrustedOrigins) > 0 {
+ if err := validateOrigins(config.TrustedOrigins, "trusted origin"); err != nil {
+ return nil, err
+ }
+ config.TrustedOrigins = append([]string(nil), config.TrustedOrigins...)
+ }
- // Initialize
- parts := strings.Split(config.TokenLookup, ":")
- extractor := csrfTokenFromHeader(parts[1])
- switch parts[0] {
- case "form":
- extractor = csrfTokenFromForm(parts[1])
- case "query":
- extractor = csrfTokenFromQuery(parts[1])
+ extractors, cErr := createExtractors(config.TokenLookup, 1)
+ if cErr != nil {
+ return nil, cErr
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) error {
+ return func(c *echo.Context) error {
if config.Skipper(c) {
return next(c)
}
- req := c.Request()
- k, err := c.Cookie(config.CookieName)
- token := ""
-
- // Generate token
+ // use the `Sec-Fetch-Site` header as part of a modern approach to CSRF protection
+ allow, err := config.checkSecFetchSiteRequest(c)
if err != nil {
- token = random.String(config.TokenLength)
+ return err
+ }
+ if allow {
+ return next(c)
+ }
+
+ // Fallback to legacy token based CSRF protection
+
+ token := ""
+ if k, err := c.Cookie(config.CookieName); err != nil {
+ token = config.Generator() // Generate token
} else {
- // Reuse token
- token = k.Value
+ token = k.Value // Reuse token
}
- switch req.Method {
+ switch c.Request().Method {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
default:
// Validate token only for requests which are not defined as 'safe' by RFC7231
- clientToken, err := extractor(c)
- if err != nil {
- return echo.NewHTTPError(http.StatusBadRequest, err.Error())
+ var lastExtractorErr error
+ var lastTokenErr error
+ outer:
+ for _, extractor := range extractors {
+ clientTokens, _, err := extractor(c)
+ if err != nil {
+ lastExtractorErr = err
+ continue
+ }
+
+ for _, clientToken := range clientTokens {
+ if validateCSRFToken(token, clientToken) {
+ lastTokenErr = nil
+ lastExtractorErr = nil
+ break outer
+ }
+ lastTokenErr = ErrCSRFInvalid
+ }
}
- if !validateCSRFToken(token, clientToken) {
- return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
+ var finalErr error
+ if lastTokenErr != nil {
+ finalErr = lastTokenErr
+ } else if lastExtractorErr != nil {
+ finalErr = echo.ErrBadRequest.Wrap(lastExtractorErr)
+ }
+ if finalErr != nil {
+ if config.ErrorHandler != nil {
+ return config.ErrorHandler(c, finalErr)
+ }
+ return finalErr
}
}
@@ -157,6 +223,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
if config.CookieDomain != "" {
cookie.Domain = config.CookieDomain
}
+ if config.CookieSameSite != http.SameSiteDefaultMode {
+ cookie.SameSite = config.CookieSameSite
+ }
cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
cookie.Secure = config.CookieSecure
cookie.HttpOnly = config.CookieHTTPOnly
@@ -170,41 +239,55 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return next(c)
}
- }
+ }, nil
}
-// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
-// provided request header.
-func csrfTokenFromHeader(header string) csrfTokenExtractor {
- return func(c echo.Context) (string, error) {
- return c.Request().Header.Get(header), nil
- }
+func validateCSRFToken(token, clientToken string) bool {
+ return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
}
-// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
-// provided form parameter.
-func csrfTokenFromForm(param string) csrfTokenExtractor {
- return func(c echo.Context) (string, error) {
- token := c.FormValue(param)
- if token == "" {
- return "", errors.New("missing csrf token in the form parameter")
- }
- return token, nil
+var safeMethods = []string{http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace}
+
+func (config CSRFConfig) checkSecFetchSiteRequest(c *echo.Context) (bool, error) {
+ // https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers
+ // Sec-Fetch-Site values are:
+ // - `same-origin` exact origin match - allow always
+ // - `same-site` same registrable domain (subdomain and/or different port) - block, unless explicitly trusted
+ // - `cross-site` request originates from different site - block, unless explicitly trusted
+ // - `none` direct navigation (URL bar, bookmark) - allow always
+ secFetchSite := c.Request().Header.Get(echo.HeaderSecFetchSite)
+ if secFetchSite == "" {
+ return false, nil
}
-}
-// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
-// provided query parameter.
-func csrfTokenFromQuery(param string) csrfTokenExtractor {
- return func(c echo.Context) (string, error) {
- token := c.QueryParam(param)
- if token == "" {
- return "", errors.New("missing csrf token in the query string")
+ if len(config.TrustedOrigins) > 0 {
+ // trusted sites ala OAuth callbacks etc. should be let through
+ origin := c.Request().Header.Get(echo.HeaderOrigin)
+ if origin != "" {
+ for _, trustedOrigin := range config.TrustedOrigins {
+ if strings.EqualFold(origin, trustedOrigin) {
+ return true, nil
+ }
+ }
}
- return token, nil
}
-}
+ isSafe := slices.Contains(safeMethods, c.Request().Method)
+ if !isSafe { // for state-changing request check SecFetchSite value
+ isSafe = secFetchSite == "same-origin" || secFetchSite == "none"
+ }
-func validateCSRFToken(token, clientToken string) bool {
- return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
+ if isSafe {
+ return true, nil
+ }
+ // we are here when request is state-changing and `cross-site` or `same-site`
+
+ // Note: if you want to allow `same-site` use config.TrustedOrigins or `config.AllowSecFetchSiteFunc`
+ if config.AllowSecFetchSiteFunc != nil {
+ return config.AllowSecFetchSiteFunc(c)
+ }
+
+ if secFetchSite == "same-site" {
+ return false, echo.NewHTTPError(http.StatusForbidden, "same-site request blocked by CSRF")
+ }
+ return false, echo.NewHTTPError(http.StatusForbidden, "cross-site request blocked by CSRF")
}
diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go
index efb4dd1d2..ddecc10e3 100644
--- a/middleware/csrf_test.go
+++ b/middleware/csrf_test.go
@@ -1,26 +1,358 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
+ "cmp"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
- "github.com/labstack/echo/v4"
- "github.com/labstack/gommon/random"
+ "github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
+func TestCSRF_tokenExtractors(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenTokenLookup string
+ whenCookieName string
+ givenCSRFCookie string
+ givenMethod string
+ givenQueryTokens map[string][]string
+ givenFormTokens map[string][]string
+ givenHeaderTokens map[string][]string
+ expectError string
+ expectToMiddlewareError string
+ }{
+ {
+ name: "ok, multiple token lookups sources, succeeds on last one",
+ whenTokenLookup: "header:X-CSRF-Token,form:csrf",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenHeaderTokens: map[string][]string{
+ echo.HeaderXCSRFToken: {"invalid_token"},
+ },
+ givenFormTokens: map[string][]string{
+ "csrf": {"token"},
+ },
+ },
+ {
+ name: "ok, token from POST form",
+ whenTokenLookup: "form:csrf",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenFormTokens: map[string][]string{
+ "csrf": {"token"},
+ },
+ },
+ {
+ name: "ok, token from POST form, second token passes",
+ whenTokenLookup: "form:csrf",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenFormTokens: map[string][]string{
+ "csrf": {"invalid", "token"},
+ },
+ expectError: "code=403, message=invalid csrf token",
+ },
+ {
+ name: "nok, invalid token from POST form",
+ whenTokenLookup: "form:csrf",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenFormTokens: map[string][]string{
+ "csrf": {"invalid_token"},
+ },
+ expectError: "code=403, message=invalid csrf token",
+ },
+ {
+ name: "nok, missing token from POST form",
+ whenTokenLookup: "form:csrf",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenFormTokens: map[string][]string{},
+ expectError: "code=400, message=Bad Request, err=missing value in the form",
+ },
+ {
+ name: "ok, token from POST header",
+ whenTokenLookup: "", // will use defaults
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenHeaderTokens: map[string][]string{
+ echo.HeaderXCSRFToken: {"token"},
+ },
+ },
+ {
+ name: "nok, token from POST header, tokens limited to 1, second token would pass",
+ whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenHeaderTokens: map[string][]string{
+ echo.HeaderXCSRFToken: {"invalid", "token"},
+ },
+ expectError: "code=403, message=invalid csrf token",
+ },
+ {
+ name: "nok, invalid token from POST header",
+ whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenHeaderTokens: map[string][]string{
+ echo.HeaderXCSRFToken: {"invalid_token"},
+ },
+ expectError: "code=403, message=invalid csrf token",
+ },
+ {
+ name: "nok, missing token from POST header",
+ whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenHeaderTokens: map[string][]string{},
+ expectError: "code=400, message=Bad Request, err=missing value in request header",
+ },
+ {
+ name: "ok, token from PUT query param",
+ whenTokenLookup: "query:csrf-param",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPut,
+ givenQueryTokens: map[string][]string{
+ "csrf-param": {"token"},
+ },
+ },
+ {
+ name: "nok, token from PUT query form, second token would pass",
+ whenTokenLookup: "query:csrf",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPut,
+ givenQueryTokens: map[string][]string{
+ "csrf": {"invalid", "token"},
+ },
+ expectError: "code=403, message=invalid csrf token",
+ },
+ {
+ name: "nok, invalid token from PUT query form",
+ whenTokenLookup: "query:csrf",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPut,
+ givenQueryTokens: map[string][]string{
+ "csrf": {"invalid_token"},
+ },
+ expectError: "code=403, message=invalid csrf token",
+ },
+ {
+ name: "nok, missing token from PUT query form",
+ whenTokenLookup: "query:csrf",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPut,
+ givenQueryTokens: map[string][]string{},
+ expectError: "code=400, message=Bad Request, err=missing value in the query string",
+ },
+ {
+ name: "nok, invalid TokenLookup",
+ whenTokenLookup: "q",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPut,
+ givenQueryTokens: map[string][]string{},
+ expectToMiddlewareError: "extractor source for lookup could not be split into needed parts: q",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ q := make(url.Values)
+ for queryParam, values := range tc.givenQueryTokens {
+ for _, v := range values {
+ q.Add(queryParam, v)
+ }
+ }
+
+ f := make(url.Values)
+ for formKey, values := range tc.givenFormTokens {
+ for _, v := range values {
+ f.Add(formKey, v)
+ }
+ }
+
+ var req *http.Request
+ switch tc.givenMethod {
+ case http.MethodGet:
+ req = httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil)
+ case http.MethodPost, http.MethodPut:
+ req = httptest.NewRequest(http.MethodPost, "/?"+q.Encode(), strings.NewReader(f.Encode()))
+ req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
+ }
+
+ for header, values := range tc.givenHeaderTokens {
+ for _, v := range values {
+ req.Header.Add(header, v)
+ }
+ }
+
+ if tc.givenCSRFCookie != "" {
+ req.Header.Set(echo.HeaderCookie, "_csrf="+tc.givenCSRFCookie)
+ }
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ config := CSRFConfig{
+ TokenLookup: tc.whenTokenLookup,
+ CookieName: tc.whenCookieName,
+ }
+ csrf, err := config.ToMiddleware()
+ if tc.expectToMiddlewareError != "" {
+ assert.EqualError(t, err, tc.expectToMiddlewareError)
+ return
+ } else if err != nil {
+ assert.NoError(t, err)
+ }
+
+ h := csrf(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ err = h(c)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestCSRFWithConfig(t *testing.T) {
+ token := randomString(16)
+
+ var testCases = []struct {
+ name string
+ givenConfig *CSRFConfig
+ whenMethod string
+ whenHeaders map[string]string
+ expectEmptyBody bool
+ expectMWError string
+ expectCookieContains string
+ expectErr string
+ }{
+ {
+ name: "ok, GET",
+ whenMethod: http.MethodGet,
+ expectCookieContains: "_csrf",
+ },
+ {
+ name: "ok, POST valid token",
+ whenHeaders: map[string]string{
+ echo.HeaderCookie: "_csrf=" + token,
+ echo.HeaderXCSRFToken: token,
+ },
+ whenMethod: http.MethodPost,
+ expectCookieContains: "_csrf",
+ },
+ {
+ name: "nok, POST without token",
+ whenMethod: http.MethodPost,
+ expectEmptyBody: true,
+ expectErr: `code=400, message=Bad Request, err=missing value in request header`,
+ },
+ {
+ name: "nok, POST empty token",
+ whenHeaders: map[string]string{echo.HeaderXCSRFToken: ""},
+ whenMethod: http.MethodPost,
+ expectEmptyBody: true,
+ expectErr: `code=403, message=invalid csrf token`,
+ },
+ {
+ name: "nok, invalid trusted origin in Config",
+ givenConfig: &CSRFConfig{
+ TrustedOrigins: []string{"http://example.com", "invalid"},
+ },
+ expectMWError: `trusted origin is missing scheme or host: invalid`,
+ },
+ {
+ name: "ok, TokenLength",
+ givenConfig: &CSRFConfig{
+ TokenLength: 16,
+ },
+ whenMethod: http.MethodGet,
+ expectCookieContains: "_csrf",
+ },
+ {
+ name: "ok, unsafe method + SecFetchSite=same-origin passes",
+ whenHeaders: map[string]string{
+ echo.HeaderSecFetchSite: "same-origin",
+ },
+ whenMethod: http.MethodPost,
+ },
+ {
+ name: "nok, unsafe method + SecFetchSite=same-cross blocked",
+ whenHeaders: map[string]string{
+ echo.HeaderSecFetchSite: "same-cross",
+ },
+ whenMethod: http.MethodPost,
+ expectEmptyBody: true,
+ expectErr: `code=403, message=cross-site request blocked by CSRF`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(cmp.Or(tc.whenMethod, http.MethodPost), "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ for key, value := range tc.whenHeaders {
+ req.Header.Set(key, value)
+ }
+
+ config := CSRFConfig{}
+ if tc.givenConfig != nil {
+ config = *tc.givenConfig
+ }
+ mw, err := config.ToMiddleware()
+ if tc.expectMWError != "" {
+ assert.EqualError(t, err, tc.expectMWError)
+ return
+ }
+ assert.NoError(t, err)
+
+ h := mw(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ err = h(c)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ expect := "test"
+ if tc.expectEmptyBody {
+ expect = ""
+ }
+ assert.Equal(t, expect, rec.Body.String())
+
+ if tc.expectCookieContains != "" {
+ assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), tc.expectCookieContains)
+ }
+ })
+ }
+}
+
func TestCSRF(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- csrf := CSRFWithConfig(CSRFConfig{
- TokenLength: 16,
- })
- h := csrf(func(c echo.Context) error {
+ csrf := CSRF()
+ h := csrf(func(c *echo.Context) error {
return c.String(http.StatusOK, "test")
})
@@ -28,56 +360,495 @@ func TestCSRF(t *testing.T) {
h(c)
assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf")
- // Without CSRF cookie
- req = httptest.NewRequest(http.MethodPost, "/", nil)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
- assert.Error(t, h(c))
-
- // Empty/invalid CSRF token
- req = httptest.NewRequest(http.MethodPost, "/", nil)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
- req.Header.Set(echo.HeaderXCSRFToken, "")
- assert.Error(t, h(c))
-
- // Valid CSRF token
- token := random.String(16)
- req.Header.Set(echo.HeaderCookie, "_csrf="+token)
- req.Header.Set(echo.HeaderXCSRFToken, token)
- if assert.NoError(t, h(c)) {
- assert.Equal(t, http.StatusOK, rec.Code)
- }
}
-func TestCSRFTokenFromForm(t *testing.T) {
- f := make(url.Values)
- f.Set("csrf", "token")
+func TestCSRFSetSameSiteMode(t *testing.T) {
e := echo.New()
- req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
- req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
- c := e.NewContext(req, nil)
- token, err := csrfTokenFromForm("csrf")(c)
- if assert.NoError(t, err) {
- assert.Equal(t, "token", token)
- }
- _, err = csrfTokenFromForm("invalid")(c)
- assert.Error(t, err)
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ csrf := CSRFWithConfig(CSRFConfig{
+ CookieSameSite: http.SameSiteStrictMode,
+ })
+
+ h := csrf(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ r := h(c)
+ assert.NoError(t, r)
+ assert.Regexp(t, "SameSite=Strict", rec.Header()["Set-Cookie"])
}
-func TestCSRFTokenFromQuery(t *testing.T) {
- q := make(url.Values)
- q.Set("csrf", "token")
+func TestCSRFWithoutSameSiteMode(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
- req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
- req.URL.RawQuery = q.Encode()
- c := e.NewContext(req, nil)
- token, err := csrfTokenFromQuery("csrf")(c)
- if assert.NoError(t, err) {
- assert.Equal(t, "token", token)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ csrf := CSRFWithConfig(CSRFConfig{})
+
+ h := csrf(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ r := h(c)
+ assert.NoError(t, r)
+ assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"])
+}
+
+func TestCSRFWithSameSiteDefaultMode(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ csrf := CSRFWithConfig(CSRFConfig{
+ CookieSameSite: http.SameSiteDefaultMode,
+ })
+
+ h := csrf(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ r := h(c)
+ assert.NoError(t, r)
+ assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"])
+}
+
+func TestCSRFWithSameSiteModeNone(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ csrf, err := CSRFConfig{
+ CookieSameSite: http.SameSiteNoneMode,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ h := csrf(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ r := h(c)
+ assert.NoError(t, r)
+ assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"])
+ assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"])
+}
+
+func TestCSRFConfig_skipper(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenSkip bool
+ expectCookies int
+ }{
+ {
+ name: "do skip",
+ whenSkip: true,
+ expectCookies: 0,
+ },
+ {
+ name: "do not skip",
+ whenSkip: false,
+ expectCookies: 1,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ csrf := CSRFWithConfig(CSRFConfig{
+ Skipper: func(c *echo.Context) bool {
+ return tc.whenSkip
+ },
+ })
+
+ h := csrf(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ r := h(c)
+ assert.NoError(t, r)
+ cookie := rec.Header()["Set-Cookie"]
+ assert.Len(t, cookie, tc.expectCookies)
+ })
+ }
+}
+
+func TestCSRFErrorHandling(t *testing.T) {
+ cfg := CSRFConfig{
+ ErrorHandler: func(c *echo.Context, err error) error {
+ return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed")
+ },
+ }
+
+ e := echo.New()
+ e.POST("/", func(c *echo.Context) error {
+ return c.String(http.StatusNotImplemented, "should not end up here")
+ })
+
+ e.Use(CSRFWithConfig(cfg))
+
+ req := httptest.NewRequest(http.MethodPost, "/", nil)
+ res := httptest.NewRecorder()
+ e.ServeHTTP(res, req)
+
+ assert.Equal(t, http.StatusTeapot, res.Code)
+ assert.Equal(t, "{\"message\":\"error_handler_executed\"}\n", res.Body.String())
+}
+
+func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenConfig CSRFConfig
+ whenMethod string
+ whenSecFetchSite string
+ whenOrigin string
+ expectAllow bool
+ expectErr string
+ }{
+ {
+ name: "ok, unsafe POST, no SecFetchSite is not blocked",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "",
+ expectAllow: false, // should fall back to token CSRF
+ },
+ {
+ name: "ok, safe GET + same-origin passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodGet,
+ whenSecFetchSite: "same-origin",
+ expectAllow: true,
+ },
+ {
+ name: "ok, safe GET + none passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodGet,
+ whenSecFetchSite: "none",
+ expectAllow: true,
+ },
+ {
+ name: "ok, safe GET + same-site passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodGet,
+ whenSecFetchSite: "same-site",
+ expectAllow: true,
+ },
+ {
+ name: "ok, safe GET + cross-site passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodGet,
+ whenSecFetchSite: "cross-site",
+ expectAllow: true,
+ },
+ {
+ name: "nok, unsafe POST + cross-site is blocked",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ expectAllow: false,
+ expectErr: `code=403, message=cross-site request blocked by CSRF`,
+ },
+ {
+ name: "nok, unsafe POST + same-site is blocked",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "same-site",
+ expectAllow: false,
+ expectErr: `code=403, message=same-site request blocked by CSRF`,
+ },
+ {
+ name: "ok, unsafe POST + same-origin passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "same-origin",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe POST + none passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "none",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe PUT + same-origin passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPut,
+ whenSecFetchSite: "same-origin",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe PUT + none passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPut,
+ whenSecFetchSite: "none",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe DELETE + same-origin passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodDelete,
+ whenSecFetchSite: "same-origin",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe PATCH + same-origin passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPatch,
+ whenSecFetchSite: "same-origin",
+ expectAllow: true,
+ },
+ {
+ name: "nok, unsafe PUT + cross-site is blocked",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPut,
+ whenSecFetchSite: "cross-site",
+ expectAllow: false,
+ expectErr: `code=403, message=cross-site request blocked by CSRF`,
+ },
+ {
+ name: "nok, unsafe PUT + same-site is blocked",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPut,
+ whenSecFetchSite: "same-site",
+ expectAllow: false,
+ expectErr: `code=403, message=same-site request blocked by CSRF`,
+ },
+ {
+ name: "nok, unsafe DELETE + cross-site is blocked",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodDelete,
+ whenSecFetchSite: "cross-site",
+ expectAllow: false,
+ expectErr: `code=403, message=cross-site request blocked by CSRF`,
+ },
+ {
+ name: "nok, unsafe DELETE + same-site is blocked",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodDelete,
+ whenSecFetchSite: "same-site",
+ expectAllow: false,
+ expectErr: `code=403, message=same-site request blocked by CSRF`,
+ },
+ {
+ name: "nok, unsafe PATCH + cross-site is blocked",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPatch,
+ whenSecFetchSite: "cross-site",
+ expectAllow: false,
+ expectErr: `code=403, message=cross-site request blocked by CSRF`,
+ },
+ {
+ name: "ok, safe HEAD + same-origin passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodHead,
+ whenSecFetchSite: "same-origin",
+ expectAllow: true,
+ },
+ {
+ name: "ok, safe HEAD + cross-site passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodHead,
+ whenSecFetchSite: "cross-site",
+ expectAllow: true,
+ },
+ {
+ name: "ok, safe OPTIONS + cross-site passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodOptions,
+ whenSecFetchSite: "cross-site",
+ expectAllow: true,
+ },
+ {
+ name: "ok, safe TRACE + cross-site passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodTrace,
+ whenSecFetchSite: "cross-site",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe POST + cross-site + matching trusted origin passes",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://trusted.example.com"},
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ whenOrigin: "https://trusted.example.com",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe POST + same-site + matching trusted origin passes",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://trusted.example.com"},
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "same-site",
+ whenOrigin: "https://trusted.example.com",
+ expectAllow: true,
+ },
+ {
+ name: "nok, unsafe POST + cross-site + non-matching origin is blocked",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://trusted.example.com"},
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ whenOrigin: "https://evil.example.com",
+ expectAllow: false,
+ expectErr: `code=403, message=cross-site request blocked by CSRF`,
+ },
+ {
+ name: "ok, unsafe POST + cross-site + case-insensitive trusted origin match passes",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://trusted.example.com"},
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ whenOrigin: "https://TRUSTED.example.com",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe POST + same-origin + trusted origins configured but not matched passes",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://trusted.example.com"},
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "same-origin",
+ whenOrigin: "https://different.example.com",
+ expectAllow: true,
+ },
+ {
+ name: "nok, unsafe POST + cross-site + empty origin + trusted origins configured is blocked",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://trusted.example.com"},
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ whenOrigin: "",
+ expectAllow: false,
+ expectErr: `code=403, message=cross-site request blocked by CSRF`,
+ },
+ {
+ name: "ok, unsafe POST + cross-site + multiple trusted origins, second one matches",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://first.example.com", "https://second.example.com"},
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ whenOrigin: "https://second.example.com",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe POST + same-site + custom func allows",
+ givenConfig: CSRFConfig{
+ AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ return true, nil
+ },
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "same-site",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe POST + cross-site + custom func allows",
+ givenConfig: CSRFConfig{
+ AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ return true, nil
+ },
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ expectAllow: true,
+ },
+ {
+ name: "nok, unsafe POST + same-site + custom func returns custom error",
+ givenConfig: CSRFConfig{
+ AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ return false, echo.NewHTTPError(http.StatusTeapot, "custom error from func")
+ },
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "same-site",
+ expectAllow: false,
+ expectErr: `code=418, message=custom error from func`,
+ },
+ {
+ name: "nok, unsafe POST + cross-site + custom func returns false with nil error",
+ givenConfig: CSRFConfig{
+ AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ return false, nil
+ },
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ expectAllow: false,
+ expectErr: "", // custom func returns nil error, so no error expected
+ },
+ {
+ name: "nok, unsafe POST + invalid Sec-Fetch-Site value treated as cross-site",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "invalid-value",
+ expectAllow: false,
+ expectErr: `code=403, message=cross-site request blocked by CSRF`,
+ },
+ {
+ name: "ok, unsafe POST + cross-site + trusted origin takes precedence over custom func",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://trusted.example.com"},
+ AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ return false, echo.NewHTTPError(http.StatusTeapot, "should not be called")
+ },
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ whenOrigin: "https://trusted.example.com",
+ expectAllow: true,
+ },
+ {
+ name: "nok, unsafe POST + cross-site + trusted origin not matched, custom func blocks",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://trusted.example.com"},
+ AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ return false, echo.NewHTTPError(http.StatusTeapot, "custom block")
+ },
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ whenOrigin: "https://evil.example.com",
+ expectAllow: false,
+ expectErr: `code=418, message=custom block`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(tc.whenMethod, "/", nil)
+ if tc.whenSecFetchSite != "" {
+ req.Header.Set(echo.HeaderSecFetchSite, tc.whenSecFetchSite)
+ }
+ if tc.whenOrigin != "" {
+ req.Header.Set(echo.HeaderOrigin, tc.whenOrigin)
+ }
+
+ res := httptest.NewRecorder()
+ c := echo.NewContext(req, res)
+
+ allow, err := tc.givenConfig.checkSecFetchSiteRequest(c)
+
+ assert.Equal(t, tc.expectAllow, allow)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
}
- _, err = csrfTokenFromQuery("invalid")(c)
- assert.Error(t, err)
- csrfTokenFromQuery("csrf")
}
diff --git a/middleware/decompress.go b/middleware/decompress.go
new file mode 100644
index 000000000..a384af2ea
--- /dev/null
+++ b/middleware/decompress.go
@@ -0,0 +1,155 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "compress/gzip"
+ "io"
+ "net/http"
+ "sync"
+
+ "github.com/labstack/echo/v5"
+)
+
+// DecompressConfig defines the config for Decompress middleware.
+type DecompressConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
+ GzipDecompressPool Decompressor
+
+ // MaxDecompressedSize limits the maximum size of decompressed request body in bytes.
+ // If the decompressed body exceeds this limit, the middleware returns HTTP 413 error.
+ // This prevents zip bomb attacks where small compressed payloads decompress to huge sizes.
+ // Default: 100 * MB (104,857,600 bytes)
+ // Set to -1 to disable limits (not recommended in production).
+ MaxDecompressedSize int64
+}
+
+// GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
+const GZIPEncoding string = "gzip"
+
+// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers
+type Decompressor interface {
+ gzipDecompressPool() sync.Pool
+}
+
+// DefaultGzipDecompressPool is the default implementation of Decompressor interface
+type DefaultGzipDecompressPool struct {
+}
+
+func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
+ return sync.Pool{New: func() any { return new(gzip.Reader) }}
+}
+
+// Decompress decompresses request body based if content encoding type is set to "gzip" with default config
+//
+// SECURITY: By default, this limits decompressed data to 100MB to prevent zip bomb attacks.
+// To customize the limit, use DecompressWithConfig. To disable limits (not recommended in production),
+// set MaxDecompressedSize to -1.
+func Decompress() echo.MiddlewareFunc {
+ return DecompressWithConfig(DecompressConfig{})
+}
+
+// DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration.
+//
+// SECURITY: If MaxDecompressedSize is not set (zero value), it defaults to 100MB to prevent
+// DoS attacks via zip bombs. Set to -1 to explicitly disable limits if needed for your use case.
+func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration
+func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Skipper == nil {
+ config.Skipper = DefaultSkipper
+ }
+ if config.GzipDecompressPool == nil {
+ config.GzipDecompressPool = &DefaultGzipDecompressPool{}
+ }
+ // Apply secure default for decompression limit
+ if config.MaxDecompressedSize == 0 {
+ config.MaxDecompressedSize = 100 * MB
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ pool := config.GzipDecompressPool.gzipDecompressPool()
+
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ if c.Request().Header.Get(echo.HeaderContentEncoding) != GZIPEncoding {
+ return next(c)
+ }
+
+ i := pool.Get()
+ gr, ok := i.(*gzip.Reader)
+ if !ok || gr == nil {
+ if err, isErr := i.(error); isErr {
+ return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
+ }
+ return echo.NewHTTPError(http.StatusInternalServerError, "unexpected type from gzip decompression pool")
+ }
+ defer pool.Put(gr)
+
+ b := c.Request().Body
+ defer b.Close()
+
+ if err := gr.Reset(b); err != nil {
+ if err == io.EOF { //ignore if body is empty
+ return next(c)
+ }
+ return err
+ }
+
+ // only Close gzip reader if it was set to a proper gzip source otherwise it will panic on close.
+ defer gr.Close()
+
+ // Apply decompression size limit to prevent zip bombs
+ if config.MaxDecompressedSize > 0 {
+ c.Request().Body = &limitedGzipReader{
+ Reader: gr,
+ remaining: config.MaxDecompressedSize,
+ limit: config.MaxDecompressedSize,
+ }
+ } else {
+ // -1 means explicitly unlimited (not recommended)
+ c.Request().Body = gr
+ }
+
+ return next(c)
+ }
+ }, nil
+}
+
+// limitedGzipReader wraps a gzip reader with size limiting to prevent zip bombs
+type limitedGzipReader struct {
+ *gzip.Reader
+ remaining int64
+ limit int64
+}
+
+func (r *limitedGzipReader) Read(p []byte) (n int, err error) {
+ if r.remaining <= 0 {
+ // Limit exceeded - return 413 error
+ return 0, echo.ErrStatusRequestEntityTooLarge
+ }
+
+ // Limit the read to remaining bytes
+ if int64(len(p)) > r.remaining {
+ p = p[:r.remaining]
+ }
+
+ n, err = r.Reader.Read(p)
+ r.remaining -= int64(n)
+
+ return n, err
+}
+
+func (r *limitedGzipReader) Close() error {
+ return r.Reader.Close()
+}
diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go
new file mode 100644
index 000000000..1823e94bb
--- /dev/null
+++ b/middleware/decompress_test.go
@@ -0,0 +1,508 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "bytes"
+ "compress/gzip"
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestDecompress(t *testing.T) {
+ e := echo.New()
+
+ h := Decompress()(func(c *echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })
+
+ // Decompress request body
+ body := `{"name": "echo"}`
+ gz, _ := gzipString(body)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := h(c)
+ assert.NoError(t, err)
+
+ assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
+ b, err := io.ReadAll(req.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, body, string(b))
+}
+
+func TestDecompress_skippedIfNoHeader(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ // Skip if no Content-Encoding header
+ h := Decompress()(func(c *echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })
+
+ err := h(c)
+ assert.NoError(t, err)
+ assert.Equal(t, "test", rec.Body.String())
+
+}
+
+func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{}.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })(c)
+ assert.NoError(t, err)
+
+ assert.Equal(t, "test", rec.Body.String())
+
+}
+
+func TestDecompressWithConfig_DefaultConfig(t *testing.T) {
+ e := echo.New()
+
+ h := Decompress()(func(c *echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })
+
+ // Decompress
+ body := `{"name": "echo"}`
+ gz, _ := gzipString(body)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := h(c)
+ assert.NoError(t, err)
+
+ assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
+ b, err := io.ReadAll(req.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, body, string(b))
+}
+
+func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
+ e := echo.New()
+ body := `{"name":"echo"}`
+ gz, _ := gzipString(body)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ e.NewContext(req, rec)
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
+ b, err := io.ReadAll(req.Body)
+ assert.NoError(t, err)
+ assert.NotEqual(t, b, body)
+ assert.Equal(t, b, gz)
+}
+
+func TestDecompressNoContent(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := Decompress()(func(c *echo.Context) error {
+ return c.NoContent(http.StatusNoContent)
+ })
+
+ err := h(c)
+
+ if assert.NoError(t, err) {
+ assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
+ assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
+ assert.Equal(t, 0, len(rec.Body.Bytes()))
+ }
+}
+
+func TestDecompressErrorReturned(t *testing.T) {
+ e := echo.New()
+ e.Use(Decompress())
+ e.GET("/", func(c *echo.Context) error {
+ return echo.ErrNotFound
+ })
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusNotFound, rec.Code)
+ assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
+}
+
+func TestDecompressSkipper(t *testing.T) {
+ e := echo.New()
+ e.Use(DecompressWithConfig(DecompressConfig{
+ Skipper: func(c *echo.Context) bool {
+ return c.Request().URL.Path == "/skip"
+ },
+ }))
+ body := `{"name": "echo"}`
+ req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSON)
+ reqBody, err := io.ReadAll(c.Request().Body)
+ assert.NoError(t, err)
+ assert.Equal(t, body, string(reqBody))
+}
+
+type TestDecompressPoolWithError struct {
+}
+
+func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool {
+ return sync.Pool{
+ New: func() any {
+ return errors.New("pool error")
+ },
+ }
+}
+
+func TestDecompressPoolError(t *testing.T) {
+ e := echo.New()
+ e.Use(DecompressWithConfig(DecompressConfig{
+ Skipper: DefaultSkipper,
+ GzipDecompressPool: &TestDecompressPoolWithError{},
+ }))
+ body := `{"name": "echo"}`
+ req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
+ reqBody, err := io.ReadAll(c.Request().Body)
+ assert.NoError(t, err)
+ assert.Equal(t, body, string(reqBody))
+ assert.Equal(t, rec.Code, http.StatusInternalServerError)
+}
+
+func BenchmarkDecompress(b *testing.B) {
+ e := echo.New()
+ body := `{"name": "echo"}`
+ gz, _ := gzipString(body)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+
+ h := Decompress()(func(c *echo.Context) error {
+ c.Response().Write([]byte(body)) // For Content-Type sniffing
+ return nil
+ })
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ // Decompress
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h(c)
+ }
+}
+
+func gzipString(body string) ([]byte, error) {
+ var buf bytes.Buffer
+ gz := gzip.NewWriter(&buf)
+
+ _, err := gz.Write([]byte(body))
+ if err != nil {
+ return nil, err
+ }
+
+ if err := gz.Close(); err != nil {
+ return nil, err
+ }
+
+ return buf.Bytes(), nil
+}
+
+func TestDecompress_WithinLimit(t *testing.T) {
+ e := echo.New()
+ body := strings.Repeat("test data ", 100) // Small payload ~1KB
+ gz, _ := gzipString(body)
+
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{MaxDecompressedSize: 100 * MB}.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ b, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(b))
+ })(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, body, rec.Body.String())
+}
+
+func TestDecompress_ExceedsLimit(t *testing.T) {
+ e := echo.New()
+ // Create 2KB of data but limit to 1KB
+ largeBody := strings.Repeat("A", 2*1024)
+ gz, _ := gzipString(largeBody)
+
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ _, readErr := io.ReadAll(c.Request().Body)
+ return readErr
+ })(c)
+
+ // Should return 413 error
+ assert.Error(t, err)
+ he, ok := err.(echo.HTTPStatusCoder)
+ assert.True(t, ok)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
+}
+
+func TestDecompress_AtExactLimit(t *testing.T) {
+ e := echo.New()
+ exactBody := strings.Repeat("B", 1024) // Exactly 1KB
+ gz, _ := gzipString(exactBody)
+
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ b, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(b))
+ })(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, exactBody, rec.Body.String())
+}
+
+func TestDecompress_ZipBomb(t *testing.T) {
+ e := echo.New()
+ // Create highly compressed data that expands to 2MB
+ // but limit is 1MB
+ largeBody := bytes.Repeat([]byte("A"), 2*1024*1024) // 2MB
+ var buf bytes.Buffer
+ gzWriter := gzip.NewWriter(&buf)
+ gzWriter.Write(largeBody)
+ gzWriter.Close()
+
+ req := httptest.NewRequest(http.MethodPost, "/", &buf)
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{MaxDecompressedSize: 1 * MB}.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ _, readErr := io.ReadAll(c.Request().Body)
+ return readErr
+ })(c)
+
+ // Should return 413 error
+ assert.Error(t, err)
+ he, ok := err.(echo.HTTPStatusCoder)
+ assert.True(t, ok)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
+}
+
+func TestDecompress_UnlimitedExplicit(t *testing.T) {
+ e := echo.New()
+ largeBody := strings.Repeat("X", 10*1024) // 10KB
+ gz, _ := gzipString(largeBody)
+
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{MaxDecompressedSize: -1}.ToMiddleware() // Unlimited
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ b, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(b))
+ })(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, largeBody, rec.Body.String())
+}
+
+func TestDecompress_DefaultLimit(t *testing.T) {
+ e := echo.New()
+ smallBody := "test"
+ gz, _ := gzipString(smallBody)
+
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ // Use zero value which should apply 100MB default
+ h, err := DecompressConfig{}.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ b, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(b))
+ })(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, smallBody, rec.Body.String())
+}
+
+func TestDecompress_SmallCustomLimit(t *testing.T) {
+ e := echo.New()
+ body := strings.Repeat("D", 512) // 512 bytes
+ gz, _ := gzipString(body)
+
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ b, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(b))
+ })(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, body, rec.Body.String())
+}
+
+func TestDecompress_MultipleReads(t *testing.T) {
+ e := echo.New()
+ // Test that limit is enforced across multiple Read() calls
+ largeBody := strings.Repeat("M", 2*1024) // 2KB
+ gz, _ := gzipString(largeBody)
+
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ // Read in small chunks
+ buf := make([]byte, 256)
+ total := 0
+ for {
+ n, readErr := c.Request().Body.Read(buf)
+ total += n
+ if readErr != nil {
+ if readErr == io.EOF {
+ return nil
+ }
+ return readErr
+ }
+ }
+ })(c)
+
+ // Should return 413 error from cumulative reads
+ assert.Error(t, err)
+ he, ok := err.(echo.HTTPStatusCoder)
+ assert.True(t, ok)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
+}
+
+func TestDecompress_LargePayloadDosPrevention(t *testing.T) {
+ e := echo.New()
+ // Simulate a DoS attack with highly compressed large payload
+ largeSize := 10 * 1024 * 1024 // 10MB decompressed
+ largeBody := bytes.Repeat([]byte("Z"), largeSize)
+ var buf bytes.Buffer
+ gzWriter := gzip.NewWriter(&buf)
+ gzWriter.Write(largeBody)
+ gzWriter.Close()
+
+ req := httptest.NewRequest(http.MethodPost, "/", &buf)
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{MaxDecompressedSize: 1 * MB}.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ _, readErr := io.ReadAll(c.Request().Body)
+ return readErr
+ })(c)
+
+ // Should prevent DoS by returning 413
+ assert.Error(t, err)
+ he, ok := err.(echo.HTTPStatusCoder)
+ assert.True(t, ok)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
+}
+
+func BenchmarkDecompress_WithLimit(b *testing.B) {
+ e := echo.New()
+ body := strings.Repeat("benchmark data ", 1000) // ~15KB
+ gz, _ := gzipString(body)
+
+ h, _ := DecompressConfig{MaxDecompressedSize: 100 * MB}.ToMiddleware()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h(func(c *echo.Context) error {
+ io.ReadAll(c.Request().Body)
+ return nil
+ })(c)
+ }
+}
diff --git a/middleware/extractor.go b/middleware/extractor.go
new file mode 100644
index 000000000..abb603186
--- /dev/null
+++ b/middleware/extractor.go
@@ -0,0 +1,253 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "fmt"
+ "net/textproto"
+ "strings"
+
+ "github.com/labstack/echo/v5"
+)
+
+const (
+ // extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource exhaustion
+ // attack vector
+ extractorLimit = 20
+)
+
+// ExtractorSource is type to indicate source for extracted value
+type ExtractorSource string
+
+const (
+ // ExtractorSourceHeader means value was extracted from request header
+ ExtractorSourceHeader ExtractorSource = "header"
+ // ExtractorSourceQuery means value was extracted from request query parameters
+ ExtractorSourceQuery ExtractorSource = "query"
+ // ExtractorSourcePathParam means value was extracted from route path parameters
+ ExtractorSourcePathParam ExtractorSource = "param"
+ // ExtractorSourceCookie means value was extracted from request cookies
+ ExtractorSourceCookie ExtractorSource = "cookie"
+ // ExtractorSourceForm means value was extracted from request form values
+ ExtractorSourceForm ExtractorSource = "form"
+)
+
+// ValueExtractorError is error type when middleware extractor is unable to extract value from lookups
+type ValueExtractorError struct {
+ message string
+}
+
+// Error returns errors text
+func (e *ValueExtractorError) Error() string {
+ return e.message
+}
+
+var errHeaderExtractorValueMissing = &ValueExtractorError{message: "missing value in request header"}
+var errHeaderExtractorValueInvalid = &ValueExtractorError{message: "invalid value in request header"}
+var errQueryExtractorValueMissing = &ValueExtractorError{message: "missing value in the query string"}
+var errParamExtractorValueMissing = &ValueExtractorError{message: "missing value in path params"}
+var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing value in cookies"}
+var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"}
+
+// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context.
+type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error)
+
+// CreateExtractors creates ValuesExtractors from given lookups.
+// lookups is a string in the form of ":" or ":,:" that is used
+// to extract key from the request.
+// Possible values:
+// - "header:" or "header::"
+// `` is argument value to cut/trim prefix of the extracted value. This is useful if header
+// value has static prefix like `Authorization: ` where part that we
+// want to cut is ` ` note the space at the end.
+// In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `.
+// - "query:"
+// - "param:"
+// - "form:"
+// - "cookie:"
+//
+// Multiple sources example:
+// - "header:Authorization,header:X-Api-Key"
+//
+// limit sets the maximum amount how many lookups can be returned.
+func CreateExtractors(lookups string, limit uint) ([]ValuesExtractor, error) {
+ return createExtractors(lookups, limit)
+}
+
+func createExtractors(lookups string, limit uint) ([]ValuesExtractor, error) {
+ if lookups == "" {
+ return nil, nil
+ }
+ if limit == 0 {
+ limit = 1
+ } else if limit > extractorLimit {
+ limit = extractorLimit
+ }
+
+ sources := strings.Split(lookups, ",")
+ var extractors = make([]ValuesExtractor, 0)
+ for _, source := range sources {
+ parts := strings.Split(source, ":")
+ if len(parts) < 2 {
+ return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source)
+ }
+
+ switch parts[0] {
+ case "query":
+ extractors = append(extractors, valuesFromQuery(parts[1], limit))
+ case "param":
+ extractors = append(extractors, valuesFromParam(parts[1], limit))
+ case "cookie":
+ extractors = append(extractors, valuesFromCookie(parts[1], limit))
+ case "form":
+ extractors = append(extractors, valuesFromForm(parts[1], limit))
+ case "header":
+ prefix := ""
+ if len(parts) > 2 {
+ prefix = parts[2]
+ }
+ extractors = append(extractors, valuesFromHeader(parts[1], prefix, limit))
+ }
+ }
+ return extractors, nil
+}
+
+// valuesFromHeader returns a functions that extracts values from the request header.
+// valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static
+// prefix like `Authorization: ` where part that we want to remove is ` `
+// note the space at the end. In case of basic authentication `Authorization: Basic ` prefix we want to remove
+// is `Basic `. In case of JWT tokens `Authorization: Bearer ` prefix is `Bearer `.
+// If prefix is left empty the whole value is returned.
+func valuesFromHeader(header string, valuePrefix string, limit uint) ValuesExtractor {
+ prefixLen := len(valuePrefix)
+ // standard library parses http.Request header keys in canonical form but we may provide something else so fix this
+ header = textproto.CanonicalMIMEHeaderKey(header)
+ if limit == 0 {
+ limit = 1
+ }
+ return func(c *echo.Context) ([]string, ExtractorSource, error) {
+ values := c.Request().Header.Values(header)
+ if len(values) == 0 {
+ return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing
+ }
+
+ i := uint(0)
+ result := make([]string, 0)
+ for _, value := range values {
+ if prefixLen == 0 {
+ result = append(result, value)
+ i++
+ if i >= limit {
+ break
+ }
+ } else if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) {
+ result = append(result, value[prefixLen:])
+ i++
+ if i >= limit {
+ break
+ }
+ }
+ }
+
+ if len(result) == 0 {
+ if prefixLen > 0 {
+ return nil, ExtractorSourceHeader, errHeaderExtractorValueInvalid
+ }
+ return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing
+ }
+ return result, ExtractorSourceHeader, nil
+ }
+}
+
+// valuesFromQuery returns a function that extracts values from the query string.
+func valuesFromQuery(param string, limit uint) ValuesExtractor {
+ if limit == 0 {
+ limit = 1
+ }
+ return func(c *echo.Context) ([]string, ExtractorSource, error) {
+ result := c.QueryParams()[param]
+ if len(result) == 0 {
+ return nil, ExtractorSourceQuery, errQueryExtractorValueMissing
+ } else if len(result) > int(limit)-1 {
+ result = result[:limit]
+ }
+ return result, ExtractorSourceQuery, nil
+ }
+}
+
+// valuesFromParam returns a function that extracts values from the url param string.
+func valuesFromParam(param string, limit uint) ValuesExtractor {
+ if limit == 0 {
+ limit = 1
+ }
+ return func(c *echo.Context) ([]string, ExtractorSource, error) {
+ result := make([]string, 0)
+ i := uint(0)
+ for _, p := range c.PathValues() {
+ if param != p.Name {
+ continue
+ }
+ result = append(result, p.Value)
+ i++
+ if i >= limit {
+ break
+ }
+ }
+ if len(result) == 0 {
+ return nil, ExtractorSourcePathParam, errParamExtractorValueMissing
+ }
+ return result, ExtractorSourcePathParam, nil
+ }
+}
+
+// valuesFromCookie returns a function that extracts values from the named cookie.
+func valuesFromCookie(name string, limit uint) ValuesExtractor {
+ if limit == 0 {
+ limit = 1
+ }
+ return func(c *echo.Context) ([]string, ExtractorSource, error) {
+ cookies := c.Cookies()
+ if len(cookies) == 0 {
+ return nil, ExtractorSourceCookie, errCookieExtractorValueMissing
+ }
+
+ i := uint(0)
+ result := make([]string, 0)
+ for _, cookie := range cookies {
+ if name != cookie.Name {
+ continue
+ }
+ result = append(result, cookie.Value)
+ i++
+ if i >= limit {
+ break
+ }
+ }
+ if len(result) == 0 {
+ return nil, ExtractorSourceCookie, errCookieExtractorValueMissing
+ }
+ return result, ExtractorSourceCookie, nil
+ }
+}
+
+// valuesFromForm returns a function that extracts values from the form field.
+func valuesFromForm(name string, limit uint) ValuesExtractor {
+ if limit == 0 {
+ limit = 1
+ }
+ return func(c *echo.Context) ([]string, ExtractorSource, error) {
+ if c.Request().Form == nil {
+ _, _ = c.MultipartForm() // we want to trigger c.request.ParseMultipartForm(c.formParseMaxMemory)
+ }
+ values := c.Request().Form[name]
+ if len(values) == 0 {
+ return nil, ExtractorSourceForm, errFormExtractorValueMissing
+ }
+ if len(values) > int(limit)-1 {
+ values = values[:limit]
+ }
+ result := append([]string{}, values...)
+ return result, ExtractorSourceForm, nil
+ }
+}
diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go
new file mode 100644
index 000000000..04cc7b829
--- /dev/null
+++ b/middleware/extractor_test.go
@@ -0,0 +1,625 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "bytes"
+ "fmt"
+ "mime/multipart"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestCreateExtractors(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenRequest func() *http.Request
+ givenPathValues echo.PathValues
+ whenLookups string
+ whenLimit uint
+ expectValues []string
+ expectSource ExtractorSource
+ expectCreateError string
+ expectError string
+ }{
+ {
+ name: "ok, header",
+ givenRequest: func() *http.Request {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAuthorization, "Bearer token")
+ return req
+ },
+ whenLookups: "header:Authorization:Bearer ",
+ expectValues: []string{"token"},
+ expectSource: ExtractorSourceHeader,
+ },
+ {
+ name: "ok, form",
+ givenRequest: func() *http.Request {
+ f := make(url.Values)
+ f.Set("name", "Jon Snow")
+
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
+ req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
+ return req
+ },
+ whenLookups: "form:name",
+ expectValues: []string{"Jon Snow"},
+ expectSource: ExtractorSourceForm,
+ },
+ {
+ name: "ok, cookie",
+ givenRequest: func() *http.Request {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderCookie, "_csrf=token")
+ return req
+ },
+ whenLookups: "cookie:_csrf",
+ expectValues: []string{"token"},
+ expectSource: ExtractorSourceCookie,
+ },
+ {
+ name: "ok, param",
+ givenPathValues: echo.PathValues{
+ {Name: "id", Value: "123"},
+ },
+ whenLookups: "param:id",
+ expectValues: []string{"123"},
+ expectSource: ExtractorSourcePathParam,
+ },
+ {
+ name: "ok, query",
+ givenRequest: func() *http.Request {
+ req := httptest.NewRequest(http.MethodGet, "/?id=999", nil)
+ return req
+ },
+ whenLookups: "query:id",
+ expectValues: []string{"999"},
+ expectSource: ExtractorSourceQuery,
+ },
+ {
+ name: "nok, invalid lookup",
+ whenLookups: "query",
+ expectCreateError: "extractor source for lookup could not be split into needed parts: query",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ if tc.givenRequest != nil {
+ req = tc.givenRequest()
+ }
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ if tc.givenPathValues != nil {
+ c.SetPathValues(tc.givenPathValues)
+ }
+
+ extractors, err := CreateExtractors(tc.whenLookups, tc.whenLimit)
+ if tc.expectCreateError != "" {
+ assert.EqualError(t, err, tc.expectCreateError)
+ return
+ }
+ assert.NoError(t, err)
+
+ for _, e := range extractors {
+ values, source, eErr := e(c)
+ assert.Equal(t, tc.expectValues, values)
+ assert.Equal(t, tc.expectSource, source)
+ if tc.expectError != "" {
+ assert.EqualError(t, eErr, tc.expectError)
+ return
+ }
+ assert.NoError(t, eErr)
+ }
+ })
+ }
+}
+
+func TestValuesFromHeader(t *testing.T) {
+ exampleRequest := func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
+ }
+
+ var testCases = []struct {
+ name string
+ givenRequest func(req *http.Request)
+ whenName string
+ whenValuePrefix string
+ whenLimit uint
+ expectValues []string
+ expectError string
+ }{
+ {
+ name: "ok, single value",
+ givenRequest: exampleRequest,
+ whenName: echo.HeaderAuthorization,
+ whenValuePrefix: "basic ",
+ expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
+ },
+ {
+ name: "ok, single value, case insensitive",
+ givenRequest: exampleRequest,
+ whenName: echo.HeaderAuthorization,
+ whenValuePrefix: "Basic ",
+ expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
+ },
+ {
+ name: "ok, multiple value",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
+ req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
+ },
+ whenName: echo.HeaderAuthorization,
+ whenValuePrefix: "basic ",
+ whenLimit: 2,
+ expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"},
+ },
+ {
+ name: "ok, empty prefix",
+ givenRequest: exampleRequest,
+ whenName: echo.HeaderAuthorization,
+ whenValuePrefix: "",
+ expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="},
+ },
+ {
+ name: "nok, no matching due different prefix",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
+ req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
+ },
+ whenName: echo.HeaderAuthorization,
+ whenValuePrefix: "Bearer ",
+ expectError: errHeaderExtractorValueInvalid.Error(),
+ },
+ {
+ name: "nok, no matching due different prefix",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
+ req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
+ },
+ whenName: echo.HeaderWWWAuthenticate,
+ whenValuePrefix: "",
+ expectError: errHeaderExtractorValueMissing.Error(),
+ },
+ {
+ name: "nok, no headers",
+ givenRequest: nil,
+ whenName: echo.HeaderAuthorization,
+ whenValuePrefix: "basic ",
+ expectError: errHeaderExtractorValueMissing.Error(),
+ },
+ {
+ name: "ok, prefix, cut values over extractorLimit",
+ givenRequest: func(req *http.Request) {
+ for i := 1; i <= 25; i++ {
+ req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("basic %v", i))
+ }
+ },
+ whenName: echo.HeaderAuthorization,
+ whenValuePrefix: "basic ",
+ whenLimit: extractorLimit,
+ expectValues: []string{
+ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
+ "11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
+ },
+ },
+ {
+ name: "ok, cut values over extractorLimit",
+ givenRequest: func(req *http.Request) {
+ for i := 1; i <= 25; i++ {
+ req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("%v", i))
+ }
+ },
+ whenName: echo.HeaderAuthorization,
+ whenValuePrefix: "",
+ whenLimit: extractorLimit,
+ expectValues: []string{
+ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
+ "11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ if tc.givenRequest != nil {
+ tc.givenRequest(req)
+ }
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix, tc.whenLimit)
+
+ values, source, err := extractor(c)
+ assert.Equal(t, tc.expectValues, values)
+ assert.Equal(t, ExtractorSourceHeader, source)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValuesFromQuery(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenQueryPart string
+ whenName string
+ whenLimit uint
+ expectValues []string
+ expectError string
+ }{
+ {
+ name: "ok, single value",
+ givenQueryPart: "?id=123&name=test",
+ whenName: "id",
+ expectValues: []string{"123"},
+ },
+ {
+ name: "ok, multiple value",
+ givenQueryPart: "?id=123&id=456&name=test",
+ whenName: "id",
+ whenLimit: 2,
+ expectValues: []string{"123", "456"},
+ },
+ {
+ name: "nok, missing value",
+ givenQueryPart: "?id=123&name=test",
+ whenName: "nope",
+ expectError: errQueryExtractorValueMissing.Error(),
+ },
+ {
+ name: "ok, cut values over extractorLimit",
+ givenQueryPart: "?name=test" +
+ "&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" +
+ "&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" +
+ "&id=21&id=22&id=23&id=24&id=25",
+ whenName: "id",
+ whenLimit: extractorLimit,
+ expectValues: []string{
+ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
+ "11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ extractor := valuesFromQuery(tc.whenName, tc.whenLimit)
+
+ values, source, err := extractor(c)
+ assert.Equal(t, tc.expectValues, values)
+ assert.Equal(t, ExtractorSourceQuery, source)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValuesFromParam(t *testing.T) {
+ examplePathValues := echo.PathValues{
+ {Name: "id", Value: "123"},
+ {Name: "gid", Value: "456"},
+ {Name: "gid", Value: "789"},
+ }
+ examplePathValues20 := make(echo.PathValues, 0)
+ for i := 1; i < 25; i++ {
+ examplePathValues20 = append(examplePathValues20, echo.PathValue{Name: "id", Value: fmt.Sprintf("%v", i)})
+ }
+
+ var testCases = []struct {
+ name string
+ givenPathValues echo.PathValues
+ whenName string
+ whenLimit uint
+ expectValues []string
+ expectError string
+ }{
+ {
+ name: "ok, single value",
+ givenPathValues: examplePathValues,
+ whenName: "id",
+ expectValues: []string{"123"},
+ },
+ {
+ name: "ok, multiple value",
+ givenPathValues: examplePathValues,
+ whenName: "gid",
+ whenLimit: 2,
+ expectValues: []string{"456", "789"},
+ },
+ {
+ name: "nok, no values",
+ givenPathValues: nil,
+ whenName: "nope",
+ expectValues: nil,
+ expectError: errParamExtractorValueMissing.Error(),
+ },
+ {
+ name: "nok, no matching value",
+ givenPathValues: examplePathValues,
+ whenName: "nope",
+ expectValues: nil,
+ expectError: errParamExtractorValueMissing.Error(),
+ },
+ {
+ name: "ok, cut values over extractorLimit",
+ givenPathValues: examplePathValues20,
+ whenName: "id",
+ whenLimit: extractorLimit,
+ expectValues: []string{
+ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
+ "11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ if tc.givenPathValues != nil {
+ c.SetPathValues(tc.givenPathValues)
+ }
+
+ extractor := valuesFromParam(tc.whenName, tc.whenLimit)
+
+ values, source, err := extractor(c)
+ assert.Equal(t, tc.expectValues, values)
+ assert.Equal(t, ExtractorSourcePathParam, source)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValuesFromCookie(t *testing.T) {
+ exampleRequest := func(req *http.Request) {
+ req.Header.Set(echo.HeaderCookie, "_csrf=token")
+ }
+
+ var testCases = []struct {
+ name string
+ givenRequest func(req *http.Request)
+ whenName string
+ whenLimit uint
+ expectValues []string
+ expectError string
+ }{
+ {
+ name: "ok, single value",
+ givenRequest: exampleRequest,
+ whenName: "_csrf",
+ expectValues: []string{"token"},
+ },
+ {
+ name: "ok, multiple value",
+ givenRequest: func(req *http.Request) {
+ req.Header.Add(echo.HeaderCookie, "_csrf=token")
+ req.Header.Add(echo.HeaderCookie, "_csrf=token2")
+ },
+ whenName: "_csrf",
+ whenLimit: 2,
+ expectValues: []string{"token", "token2"},
+ },
+ {
+ name: "nok, no matching cookie",
+ givenRequest: exampleRequest,
+ whenName: "xxx",
+ expectValues: nil,
+ expectError: errCookieExtractorValueMissing.Error(),
+ },
+ {
+ name: "nok, no cookies at all",
+ givenRequest: nil,
+ whenName: "xxx",
+ expectValues: nil,
+ expectError: errCookieExtractorValueMissing.Error(),
+ },
+ {
+ name: "ok, cut values over extractorLimit",
+ givenRequest: func(req *http.Request) {
+ for i := 1; i < 25; i++ {
+ req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", i))
+ }
+ },
+ whenName: "_csrf",
+ whenLimit: extractorLimit,
+ expectValues: []string{
+ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
+ "11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ if tc.givenRequest != nil {
+ tc.givenRequest(req)
+ }
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ extractor := valuesFromCookie(tc.whenName, tc.whenLimit)
+
+ values, source, err := extractor(c)
+ assert.Equal(t, tc.expectValues, values)
+ assert.Equal(t, ExtractorSourceCookie, source)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValuesFromForm(t *testing.T) {
+ examplePostFormRequest := func(mod func(v *url.Values)) *http.Request {
+ f := make(url.Values)
+ f.Set("name", "Jon Snow")
+ f.Set("emails[]", "jon@labstack.com")
+ if mod != nil {
+ mod(&f)
+ }
+
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
+ req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
+
+ return req
+ }
+ exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request {
+ f := make(url.Values)
+ f.Set("name", "Jon Snow")
+ f.Set("emails[]", "jon@labstack.com")
+ if mod != nil {
+ mod(&f)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil)
+ return req
+ }
+
+ exampleMultiPartFormRequest := func(mod func(w *multipart.Writer)) *http.Request {
+ var b bytes.Buffer
+ w := multipart.NewWriter(&b)
+ w.WriteField("name", "Jon Snow")
+ w.WriteField("emails[]", "jon@labstack.com")
+ if mod != nil {
+ mod(w)
+ }
+
+ fw, _ := w.CreateFormFile("upload", "my.file")
+ fw.Write([]byte(`hi
`))
+ w.Close()
+
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(b.String()))
+ req.Header.Add(echo.HeaderContentType, w.FormDataContentType())
+
+ return req
+ }
+
+ var testCases = []struct {
+ name string
+ givenRequest *http.Request
+ whenName string
+ whenLimit uint
+ expectValues []string
+ expectError string
+ }{
+ {
+ name: "ok, POST form, single value",
+ givenRequest: examplePostFormRequest(nil),
+ whenName: "emails[]",
+ expectValues: []string{"jon@labstack.com"},
+ },
+ {
+ name: "ok, POST form, multiple value",
+ givenRequest: examplePostFormRequest(func(v *url.Values) {
+ v.Add("emails[]", "snow@labstack.com")
+ }),
+ whenName: "emails[]",
+ whenLimit: 2,
+ expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
+ },
+ {
+ name: "ok, POST multipart/form, multiple value",
+ givenRequest: exampleMultiPartFormRequest(func(w *multipart.Writer) {
+ w.WriteField("emails[]", "snow@labstack.com")
+ }),
+ whenName: "emails[]",
+ whenLimit: 2,
+ expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
+ },
+ {
+ name: "ok, GET form, single value",
+ givenRequest: exampleGetFormRequest(nil),
+ whenName: "emails[]",
+ expectValues: []string{"jon@labstack.com"},
+ },
+ {
+ name: "ok, GET form, multiple value",
+ givenRequest: examplePostFormRequest(func(v *url.Values) {
+ v.Add("emails[]", "snow@labstack.com")
+ }),
+ whenName: "emails[]",
+ whenLimit: 2,
+ expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
+ },
+ {
+ name: "nok, POST form, value missing",
+ givenRequest: examplePostFormRequest(nil),
+ whenName: "nope",
+ expectError: errFormExtractorValueMissing.Error(),
+ },
+ {
+ name: "ok, cut values over extractorLimit",
+ givenRequest: examplePostFormRequest(func(v *url.Values) {
+ for i := 1; i < 25; i++ {
+ v.Add("id[]", fmt.Sprintf("%v", i))
+ }
+ }),
+ whenName: "id[]",
+ whenLimit: extractorLimit,
+ expectValues: []string{
+ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
+ "11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := tc.givenRequest
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ extractor := valuesFromForm(tc.whenName, tc.whenLimit)
+
+ values, source, err := extractor(c)
+ assert.Equal(t, tc.expectValues, values)
+ assert.Equal(t, ExtractorSourceForm, source)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
diff --git a/middleware/jwt.go b/middleware/jwt.go
deleted file mode 100644
index 55a986327..000000000
--- a/middleware/jwt.go
+++ /dev/null
@@ -1,267 +0,0 @@
-package middleware
-
-import (
- "fmt"
- "net/http"
- "reflect"
- "strings"
-
- "github.com/dgrijalva/jwt-go"
- "github.com/labstack/echo/v4"
-)
-
-type (
- // JWTConfig defines the config for JWT middleware.
- JWTConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
-
- // BeforeFunc defines a function which is executed just before the middleware.
- BeforeFunc BeforeFunc
-
- // SuccessHandler defines a function which is executed for a valid token.
- SuccessHandler JWTSuccessHandler
-
- // ErrorHandler defines a function which is executed for an invalid token.
- // It may be used to define a custom JWT error.
- ErrorHandler JWTErrorHandler
-
- // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context.
- ErrorHandlerWithContext JWTErrorHandlerWithContext
-
- // Signing key to validate token. Used as fallback if SigningKeys has length 0.
- // Required. This or SigningKeys.
- SigningKey interface{}
-
- // Map of signing keys to validate token with kid field usage.
- // Required. This or SigningKey.
- SigningKeys map[string]interface{}
-
- // Signing method, used to check token signing method.
- // Optional. Default value HS256.
- SigningMethod string
-
- // Context key to store user information from the token into context.
- // Optional. Default value "user".
- ContextKey string
-
- // Claims are extendable claims data defining token content.
- // Optional. Default value jwt.MapClaims
- Claims jwt.Claims
-
- // TokenLookup is a string in the form of ":" that is used
- // to extract token from the request.
- // Optional. Default value "header:Authorization".
- // Possible values:
- // - "header:"
- // - "query:"
- // - "param:"
- // - "cookie:"
- TokenLookup string
-
- // AuthScheme to be used in the Authorization header.
- // Optional. Default value "Bearer".
- AuthScheme string
-
- keyFunc jwt.Keyfunc
- }
-
- // JWTSuccessHandler defines a function which is executed for a valid token.
- JWTSuccessHandler func(echo.Context)
-
- // JWTErrorHandler defines a function which is executed for an invalid token.
- JWTErrorHandler func(error) error
-
- // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
- JWTErrorHandlerWithContext func(error, echo.Context) error
-
- jwtExtractor func(echo.Context) (string, error)
-)
-
-// Algorithms
-const (
- AlgorithmHS256 = "HS256"
-)
-
-// Errors
-var (
- ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt")
-)
-
-var (
- // DefaultJWTConfig is the default JWT auth middleware config.
- DefaultJWTConfig = JWTConfig{
- Skipper: DefaultSkipper,
- SigningMethod: AlgorithmHS256,
- ContextKey: "user",
- TokenLookup: "header:" + echo.HeaderAuthorization,
- AuthScheme: "Bearer",
- Claims: jwt.MapClaims{},
- }
-)
-
-// JWT returns a JSON Web Token (JWT) auth middleware.
-//
-// For valid token, it sets the user in context and calls next handler.
-// For invalid token, it returns "401 - Unauthorized" error.
-// For missing token, it returns "400 - Bad Request" error.
-//
-// See: https://jwt.io/introduction
-// See `JWTConfig.TokenLookup`
-func JWT(key interface{}) echo.MiddlewareFunc {
- c := DefaultJWTConfig
- c.SigningKey = key
- return JWTWithConfig(c)
-}
-
-// JWTWithConfig returns a JWT auth middleware with config.
-// See: `JWT()`.
-func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
- // Defaults
- if config.Skipper == nil {
- config.Skipper = DefaultJWTConfig.Skipper
- }
- if config.SigningKey == nil && len(config.SigningKeys) == 0 {
- panic("echo: jwt middleware requires signing key")
- }
- if config.SigningMethod == "" {
- config.SigningMethod = DefaultJWTConfig.SigningMethod
- }
- if config.ContextKey == "" {
- config.ContextKey = DefaultJWTConfig.ContextKey
- }
- if config.Claims == nil {
- config.Claims = DefaultJWTConfig.Claims
- }
- if config.TokenLookup == "" {
- config.TokenLookup = DefaultJWTConfig.TokenLookup
- }
- if config.AuthScheme == "" {
- config.AuthScheme = DefaultJWTConfig.AuthScheme
- }
- config.keyFunc = func(t *jwt.Token) (interface{}, error) {
- // Check the signing method
- if t.Method.Alg() != config.SigningMethod {
- return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
- }
- if len(config.SigningKeys) > 0 {
- if kid, ok := t.Header["kid"].(string); ok {
- if key, ok := config.SigningKeys[kid]; ok {
- return key, nil
- }
- }
- return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
- }
-
- return config.SigningKey, nil
- }
-
- // Initialize
- parts := strings.Split(config.TokenLookup, ":")
- extractor := jwtFromHeader(parts[1], config.AuthScheme)
- switch parts[0] {
- case "query":
- extractor = jwtFromQuery(parts[1])
- case "param":
- extractor = jwtFromParam(parts[1])
- case "cookie":
- extractor = jwtFromCookie(parts[1])
- }
-
- return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) error {
- if config.Skipper(c) {
- return next(c)
- }
-
- if config.BeforeFunc != nil {
- config.BeforeFunc(c)
- }
-
- auth, err := extractor(c)
- if err != nil {
- if config.ErrorHandler != nil {
- return config.ErrorHandler(err)
- }
-
- if config.ErrorHandlerWithContext != nil {
- return config.ErrorHandlerWithContext(err, c)
- }
- return err
- }
- token := new(jwt.Token)
- // Issue #647, #656
- if _, ok := config.Claims.(jwt.MapClaims); ok {
- token, err = jwt.Parse(auth, config.keyFunc)
- } else {
- t := reflect.ValueOf(config.Claims).Type().Elem()
- claims := reflect.New(t).Interface().(jwt.Claims)
- token, err = jwt.ParseWithClaims(auth, claims, config.keyFunc)
- }
- if err == nil && token.Valid {
- // Store user information from token into context.
- c.Set(config.ContextKey, token)
- if config.SuccessHandler != nil {
- config.SuccessHandler(c)
- }
- return next(c)
- }
- if config.ErrorHandler != nil {
- return config.ErrorHandler(err)
- }
- if config.ErrorHandlerWithContext != nil {
- return config.ErrorHandlerWithContext(err, c)
- }
- return &echo.HTTPError{
- Code: http.StatusUnauthorized,
- Message: "invalid or expired jwt",
- Internal: err,
- }
- }
- }
-}
-
-// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header.
-func jwtFromHeader(header string, authScheme string) jwtExtractor {
- return func(c echo.Context) (string, error) {
- auth := c.Request().Header.Get(header)
- l := len(authScheme)
- if len(auth) > l+1 && auth[:l] == authScheme {
- return auth[l+1:], nil
- }
- return "", ErrJWTMissing
- }
-}
-
-// jwtFromQuery returns a `jwtExtractor` that extracts token from the query string.
-func jwtFromQuery(param string) jwtExtractor {
- return func(c echo.Context) (string, error) {
- token := c.QueryParam(param)
- if token == "" {
- return "", ErrJWTMissing
- }
- return token, nil
- }
-}
-
-// jwtFromParam returns a `jwtExtractor` that extracts token from the url param string.
-func jwtFromParam(param string) jwtExtractor {
- return func(c echo.Context) (string, error) {
- token := c.Param(param)
- if token == "" {
- return "", ErrJWTMissing
- }
- return token, nil
- }
-}
-
-// jwtFromCookie returns a `jwtExtractor` that extracts token from the named cookie.
-func jwtFromCookie(name string) jwtExtractor {
- return func(c echo.Context) (string, error) {
- cookie, err := c.Cookie(name)
- if err != nil {
- return "", ErrJWTMissing
- }
- return cookie.Value, nil
- }
-}
diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go
deleted file mode 100644
index 7f15bd467..000000000
--- a/middleware/jwt_test.go
+++ /dev/null
@@ -1,329 +0,0 @@
-package middleware
-
-import (
- "net/http"
- "net/http/httptest"
- "testing"
-
- "github.com/dgrijalva/jwt-go"
- "github.com/labstack/echo/v4"
- "github.com/stretchr/testify/assert"
-)
-
-// jwtCustomInfo defines some custom types we're going to use within our tokens.
-type jwtCustomInfo struct {
- Name string `json:"name"`
- Admin bool `json:"admin"`
-}
-
-// jwtCustomClaims are custom claims expanding default ones.
-type jwtCustomClaims struct {
- *jwt.StandardClaims
- jwtCustomInfo
-}
-
-func TestJWTRace(t *testing.T) {
- e := echo.New()
- handler := func(c echo.Context) error {
- return c.String(http.StatusOK, "test")
- }
- initialToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
- raceToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IlJhY2UgQ29uZGl0aW9uIiwiYWRtaW4iOmZhbHNlfQ.Xzkx9mcgGqYMTkuxSCbJ67lsDyk5J2aB7hu65cEE-Ss"
- validKey := []byte("secret")
-
- h := JWTWithConfig(JWTConfig{
- Claims: &jwtCustomClaims{},
- SigningKey: validKey,
- })(handler)
-
- makeReq := func(token string) echo.Context {
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- res := httptest.NewRecorder()
- req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" "+token)
- c := e.NewContext(req, res)
- assert.NoError(t, h(c))
- return c
- }
-
- c := makeReq(initialToken)
- user := c.Get("user").(*jwt.Token)
- claims := user.Claims.(*jwtCustomClaims)
- assert.Equal(t, claims.Name, "John Doe")
-
- makeReq(raceToken)
- user = c.Get("user").(*jwt.Token)
- claims = user.Claims.(*jwtCustomClaims)
- // Initial context should still be "John Doe", not "Race Condition"
- assert.Equal(t, claims.Name, "John Doe")
- assert.Equal(t, claims.Admin, true)
-}
-
-func TestJWT(t *testing.T) {
- e := echo.New()
- handler := func(c echo.Context) error {
- return c.String(http.StatusOK, "test")
- }
- token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
- validKey := []byte("secret")
- invalidKey := []byte("invalid-key")
- validAuth := DefaultJWTConfig.AuthScheme + " " + token
-
- for _, tc := range []struct {
- expPanic bool
- expErrCode int // 0 for Success
- config JWTConfig
- reqURL string // "/" if empty
- hdrAuth string
- hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
- info string
- }{
- {
- expPanic: true,
- info: "No signing key provided",
- },
- {
- expErrCode: http.StatusBadRequest,
- config: JWTConfig{
- SigningKey: validKey,
- SigningMethod: "RS256",
- },
- info: "Unexpected signing method",
- },
- {
- expErrCode: http.StatusUnauthorized,
- hdrAuth: validAuth,
- config: JWTConfig{SigningKey: invalidKey},
- info: "Invalid key",
- },
- {
- hdrAuth: validAuth,
- config: JWTConfig{SigningKey: validKey},
- info: "Valid JWT",
- },
- {
- hdrAuth: "Token" + " " + token,
- config: JWTConfig{AuthScheme: "Token", SigningKey: validKey},
- info: "Valid JWT with custom AuthScheme",
- },
- {
- hdrAuth: validAuth,
- config: JWTConfig{
- Claims: &jwtCustomClaims{},
- SigningKey: []byte("secret"),
- },
- info: "Valid JWT with custom claims",
- },
- {
- hdrAuth: "invalid-auth",
- expErrCode: http.StatusBadRequest,
- config: JWTConfig{SigningKey: validKey},
- info: "Invalid Authorization header",
- },
- {
- config: JWTConfig{SigningKey: validKey},
- expErrCode: http.StatusBadRequest,
- info: "Empty header auth field",
- },
- {
- config: JWTConfig{
- SigningKey: validKey,
- TokenLookup: "query:jwt",
- },
- reqURL: "/?a=b&jwt=" + token,
- info: "Valid query method",
- },
- {
- config: JWTConfig{
- SigningKey: validKey,
- TokenLookup: "query:jwt",
- },
- reqURL: "/?a=b&jwtxyz=" + token,
- expErrCode: http.StatusBadRequest,
- info: "Invalid query param name",
- },
- {
- config: JWTConfig{
- SigningKey: validKey,
- TokenLookup: "query:jwt",
- },
- reqURL: "/?a=b&jwt=invalid-token",
- expErrCode: http.StatusUnauthorized,
- info: "Invalid query param value",
- },
- {
- config: JWTConfig{
- SigningKey: validKey,
- TokenLookup: "query:jwt",
- },
- reqURL: "/?a=b",
- expErrCode: http.StatusBadRequest,
- info: "Empty query",
- },
- {
- config: JWTConfig{
- SigningKey: validKey,
- TokenLookup: "param:jwt",
- },
- reqURL: "/" + token,
- info: "Valid param method",
- },
- {
- config: JWTConfig{
- SigningKey: validKey,
- TokenLookup: "cookie:jwt",
- },
- hdrCookie: "jwt=" + token,
- info: "Valid cookie method",
- },
- {
- config: JWTConfig{
- SigningKey: validKey,
- TokenLookup: "cookie:jwt",
- },
- expErrCode: http.StatusUnauthorized,
- hdrCookie: "jwt=invalid",
- info: "Invalid token with cookie method",
- },
- {
- config: JWTConfig{
- SigningKey: validKey,
- TokenLookup: "cookie:jwt",
- },
- expErrCode: http.StatusBadRequest,
- info: "Empty cookie",
- },
- } {
- if tc.reqURL == "" {
- tc.reqURL = "/"
- }
-
- req := httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
- res := httptest.NewRecorder()
- req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
- req.Header.Set(echo.HeaderCookie, tc.hdrCookie)
- c := e.NewContext(req, res)
-
- if tc.reqURL == "/" + token {
- c.SetParamNames("jwt")
- c.SetParamValues(token)
- }
-
- if tc.expPanic {
- assert.Panics(t, func() {
- JWTWithConfig(tc.config)
- }, tc.info)
- continue
- }
-
- if tc.expErrCode != 0 {
- h := JWTWithConfig(tc.config)(handler)
- he := h(c).(*echo.HTTPError)
- assert.Equal(t, tc.expErrCode, he.Code, tc.info)
- continue
- }
-
- h := JWTWithConfig(tc.config)(handler)
- if assert.NoError(t, h(c), tc.info) {
- user := c.Get("user").(*jwt.Token)
- switch claims := user.Claims.(type) {
- case jwt.MapClaims:
- assert.Equal(t, claims["name"], "John Doe", tc.info)
- case *jwtCustomClaims:
- assert.Equal(t, claims.Name, "John Doe", tc.info)
- assert.Equal(t, claims.Admin, true, tc.info)
- default:
- panic("unexpected type of claims")
- }
- }
- }
-}
-
-func TestJWTwithKID(t *testing.T) {
- test := assert.New(t)
-
- e := echo.New()
- handler := func(c echo.Context) error {
- return c.String(http.StatusOK, "test")
- }
- firstToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImZpcnN0T25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.w5VGpHOe0jlNgf7jMVLHzIYH_XULmpUlreJnilwSkWk"
- secondToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.sdghDYQ85jdh0hgQ6bKbMguLI_NSPYWjkhVJkee-yZM"
- wrongToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.RyhLybtVLpoewF6nz9YN79oXo32kAtgUxp8FNwTkb90"
- staticToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.1_-XFYUPpJfgsaGwYhgZEt7hfySMg-a3GN-nfZmbW7o"
- validKeys := map[string]interface{}{"firstOne": []byte("first_secret"), "secondOne": []byte("second_secret")}
- invalidKeys := map[string]interface{}{"thirdOne": []byte("third_secret")}
- staticSecret := []byte("static_secret")
- invalidStaticSecret := []byte("invalid_secret")
-
- for _, tc := range []struct {
- expErrCode int // 0 for Success
- config JWTConfig
- hdrAuth string
- info string
- }{
- {
- hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken,
- config: JWTConfig{SigningKeys: validKeys},
- info: "First token valid",
- },
- {
- hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken,
- config: JWTConfig{SigningKeys: validKeys},
- info: "Second token valid",
- },
- {
- expErrCode: http.StatusUnauthorized,
- hdrAuth: DefaultJWTConfig.AuthScheme + " " + wrongToken,
- config: JWTConfig{SigningKeys: validKeys},
- info: "Wrong key id token",
- },
- {
- hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken,
- config: JWTConfig{SigningKey: staticSecret},
- info: "Valid static secret token",
- },
- {
- expErrCode: http.StatusUnauthorized,
- hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken,
- config: JWTConfig{SigningKey: invalidStaticSecret},
- info: "Invalid static secret",
- },
- {
- expErrCode: http.StatusUnauthorized,
- hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken,
- config: JWTConfig{SigningKeys: invalidKeys},
- info: "Invalid keys first token",
- },
- {
- expErrCode: http.StatusUnauthorized,
- hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken,
- config: JWTConfig{SigningKeys: invalidKeys},
- info: "Invalid keys second token",
- },
- } {
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- res := httptest.NewRecorder()
- req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
- c := e.NewContext(req, res)
-
- if tc.expErrCode != 0 {
- h := JWTWithConfig(tc.config)(handler)
- he := h(c).(*echo.HTTPError)
- test.Equal(tc.expErrCode, he.Code, tc.info)
- continue
- }
-
- h := JWTWithConfig(tc.config)(handler)
- if test.NoError(h(c), tc.info) {
- user := c.Get("user").(*jwt.Token)
- switch claims := user.Claims.(type) {
- case jwt.MapClaims:
- test.Equal(claims["name"], "John Doe", tc.info)
- case *jwtCustomClaims:
- test.Equal(claims.Name, "John Doe", tc.info)
- test.Equal(claims.Admin, true, tc.info)
- default:
- panic("unexpected type of claims")
- }
- }
- }
-}
diff --git a/middleware/key_auth.go b/middleware/key_auth.go
index 94cfd1429..e14bd9e2e 100644
--- a/middleware/key_auth.go
+++ b/middleware/key_auth.go
@@ -1,51 +1,118 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
+ "cmp"
"errors"
+ "fmt"
"net/http"
- "strings"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
)
-type (
- // KeyAuthConfig defines the config for KeyAuth middleware.
- KeyAuthConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
-
- // KeyLookup is a string in the form of ":" that is used
- // to extract key from the request.
- // Optional. Default value "header:Authorization".
- // Possible values:
- // - "header:"
- // - "query:"
- // - "form:"
- KeyLookup string `yaml:"key_lookup"`
-
- // AuthScheme to be used in the Authorization header.
- // Optional. Default value "Bearer".
- AuthScheme string
-
- // Validator is a function to validate key.
- // Required.
- Validator KeyAuthValidator
- }
-
- // KeyAuthValidator defines a function to validate KeyAuth credentials.
- KeyAuthValidator func(string, echo.Context) (bool, error)
-
- keyExtractor func(echo.Context) (string, error)
-)
+// KeyAuthConfig defines the config for KeyAuth middleware.
+//
+// SECURITY: The Validator function is responsible for securely comparing API keys.
+// See KeyAuthValidator documentation for guidance on preventing timing attacks.
+type KeyAuthConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // KeyLookup is a string in the form of ":" or ":,:" that is used
+ // to extract key from the request.
+ // Optional. Default value "header:Authorization".
+ // Possible values:
+ // - "header:" or "header::"
+ // `` is argument value to cut/trim prefix of the extracted value. This is useful if header
+ // value has static prefix like `Authorization: ` where part that we
+ // want to cut is ` ` note the space at the end.
+ // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `.
+ // - "query:"
+ // - "form:"
+ // - "cookie:"
+ // Multiple sources example:
+ // - "header:Authorization,header:X-Api-Key"
+ KeyLookup string
+
+ // AllowedCheckLimit set how many KeyLookup values are allowed to be checked. This is
+ // useful environments like corporate test environments with application proxies restricting
+ // access to environment with their own auth scheme.
+ AllowedCheckLimit uint
+
+ // Validator is a function to validate key.
+ // Required.
+ Validator KeyAuthValidator
+
+ // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator
+ // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key.
+ // It may be used to define a custom error.
+ //
+ // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler.
+ // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
+ // In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain.
+ ErrorHandler KeyAuthErrorHandler
+
+ // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to
+ // ignore the error (by returning `nil`).
+ // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality.
+ // In that case you can use ErrorHandler to set a default public key auth value in the request context
+ // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then.
+ ContinueOnIgnoredError bool
+}
-var (
- // DefaultKeyAuthConfig is the default KeyAuth middleware config.
- DefaultKeyAuthConfig = KeyAuthConfig{
- Skipper: DefaultSkipper,
- KeyLookup: "header:" + echo.HeaderAuthorization,
- AuthScheme: "Bearer",
- }
-)
+// KeyAuthValidator defines a function to validate KeyAuth credentials.
+//
+// SECURITY WARNING: To prevent timing attacks that could allow attackers to enumerate
+// valid API keys, validator implementations MUST use constant-time comparison.
+// Use crypto/subtle.ConstantTimeCompare instead of standard string equality (==)
+// or switch statements.
+//
+// Example of SECURE implementation:
+//
+// import "crypto/subtle"
+//
+// validator := func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
+// // Fetch valid keys from database/config
+// validKeys := []string{"key1", "key2", "key3"}
+//
+// for _, validKey := range validKeys {
+// // Use constant-time comparison to prevent timing attacks
+// if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 {
+// return true, nil
+// }
+// }
+// return false, nil
+// }
+//
+// Example of INSECURE implementation (DO NOT USE):
+//
+// // VULNERABLE TO TIMING ATTACKS - DO NOT USE
+// validator := func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
+// switch key { // Timing leak!
+// case "valid-key":
+// return true, nil
+// default:
+// return false, nil
+// }
+// }
+type KeyAuthValidator func(c *echo.Context, key string, source ExtractorSource) (bool, error)
+
+// KeyAuthErrorHandler defines a function which is executed for an invalid key.
+type KeyAuthErrorHandler func(c *echo.Context, err error) error
+
+// ErrKeyMissing denotes an error raised when key value could not be extracted from request
+var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key")
+
+// ErrInvalidKey denotes an error raised when key value is invalid by validator
+var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key")
+
+// DefaultKeyAuthConfig is the default KeyAuth middleware config.
+var DefaultKeyAuthConfig = KeyAuthConfig{
+ Skipper: DefaultSkipper,
+ KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ",
+}
// KeyAuth returns an KeyAuth middleware.
//
@@ -58,96 +125,81 @@ func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc {
return KeyAuthWithConfig(c)
}
-// KeyAuthWithConfig returns an KeyAuth middleware with config.
-// See `KeyAuth()`.
+// KeyAuthWithConfig returns an KeyAuth middleware or panics if configuration is invalid.
+//
+// For first valid key it calls the next handler.
+// For invalid key, it sends "401 - Unauthorized" response.
+// For missing key, it sends "400 - Bad Request" response.
func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
- // Defaults
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration
+func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultKeyAuthConfig.Skipper
}
- // Defaults
- if config.AuthScheme == "" {
- config.AuthScheme = DefaultKeyAuthConfig.AuthScheme
- }
if config.KeyLookup == "" {
config.KeyLookup = DefaultKeyAuthConfig.KeyLookup
}
if config.Validator == nil {
- panic("echo: key-auth middleware requires a validator function")
+ return nil, errors.New("echo key-auth middleware requires a validator function")
}
- // Initialize
- parts := strings.Split(config.KeyLookup, ":")
- extractor := keyFromHeader(parts[1], config.AuthScheme)
- switch parts[0] {
- case "query":
- extractor = keyFromQuery(parts[1])
- case "form":
- extractor = keyFromForm(parts[1])
+ limit := cmp.Or(config.AllowedCheckLimit, 1)
+
+ extractors, cErr := createExtractors(config.KeyLookup, limit)
+ if cErr != nil {
+ return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", cErr)
+ }
+ if len(extractors) == 0 {
+ return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string")
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) error {
+ return func(c *echo.Context) error {
if config.Skipper(c) {
return next(c)
}
- // Extract and verify key
- key, err := extractor(c)
- if err != nil {
- return echo.NewHTTPError(http.StatusBadRequest, err.Error())
- }
- valid, err := config.Validator(key, c)
- if err != nil {
- return &echo.HTTPError{
- Code: http.StatusUnauthorized,
- Message: "invalid key",
- Internal: err,
+ var lastExtractorErr error
+ var lastValidatorErr error
+ for _, extractor := range extractors {
+ keys, source, extrErr := extractor(c)
+ if extrErr != nil {
+ lastExtractorErr = extrErr
+ continue
+ }
+ for _, key := range keys {
+ valid, err := config.Validator(c, key, source)
+ if err != nil {
+ lastValidatorErr = err
+ continue
+ }
+ if !valid {
+ lastValidatorErr = ErrInvalidKey
+ continue
+ }
+ return next(c)
}
- } else if valid {
- return next(c)
}
- return echo.ErrUnauthorized
- }
- }
-}
-// keyFromHeader returns a `keyExtractor` that extracts key from the request header.
-func keyFromHeader(header string, authScheme string) keyExtractor {
- return func(c echo.Context) (string, error) {
- auth := c.Request().Header.Get(header)
- if auth == "" {
- return "", errors.New("missing key in request header")
- }
- if header == echo.HeaderAuthorization {
- l := len(authScheme)
- if len(auth) > l+1 && auth[:l] == authScheme {
- return auth[l+1:], nil
+ // prioritize validator errors over extracting errors
+ err := lastValidatorErr
+ if err == nil {
+ err = lastExtractorErr
}
- return "", errors.New("invalid key in the request header")
- }
- return auth, nil
- }
-}
-
-// keyFromQuery returns a `keyExtractor` that extracts key from the query string.
-func keyFromQuery(param string) keyExtractor {
- return func(c echo.Context) (string, error) {
- key := c.QueryParam(param)
- if key == "" {
- return "", errors.New("missing key in the query string")
- }
- return key, nil
- }
-}
-
-// keyFromForm returns a `keyExtractor` that extracts key from the form.
-func keyFromForm(param string) keyExtractor {
- return func(c echo.Context) (string, error) {
- key := c.FormValue(param)
- if key == "" {
- return "", errors.New("missing key in the form")
+ if config.ErrorHandler != nil {
+ tmpErr := config.ErrorHandler(c, err)
+ if config.ContinueOnIgnoredError && tmpErr == nil {
+ return next(c)
+ }
+ return tmpErr
+ }
+ if lastValidatorErr == nil {
+ return ErrKeyMissing.Wrap(err)
+ }
+ return echo.ErrUnauthorized.Wrap(err)
}
- return key, nil
- }
+ }, nil
}
diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go
index b874898c8..49a917ed3 100644
--- a/middleware/key_auth_test.go
+++ b/middleware/key_auth_test.go
@@ -1,75 +1,375 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
+ "crypto/subtle"
+ "errors"
"net/http"
"net/http/httptest"
- "net/url"
"strings"
"testing"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
+func testKeyValidator(c *echo.Context, key string, source ExtractorSource) (bool, error) {
+ // Use constant-time comparison to prevent timing attacks
+ if subtle.ConstantTimeCompare([]byte(key), []byte("valid-key")) == 1 {
+ return true, nil
+ }
+
+ // Special case for testing error handling
+ if key == "error-key" { // Error path doesn't need constant-time
+ return false, errors.New("some user defined error")
+ }
+
+ return false, nil
+}
+
func TestKeyAuth(t *testing.T) {
+ handlerCalled := false
+ handler := func(c *echo.Context) error {
+ handlerCalled = true
+ return c.String(http.StatusOK, "test")
+ }
+ middlewareChain := KeyAuth(testKeyValidator)(handler)
+
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- config := KeyAuthConfig{
- Validator: func(key string, c echo.Context) (bool, error) {
- return key == "valid-key", nil
+
+ err := middlewareChain(c)
+
+ assert.NoError(t, err)
+ assert.True(t, handlerCalled)
+}
+
+func TestKeyAuthWithConfig(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenRequestFunc func() *http.Request
+ givenRequest func(req *http.Request)
+ whenConfig func(conf *KeyAuthConfig)
+ expectHandlerCalled bool
+ expectError string
+ }{
+ {
+ name: "ok, defaults, key from header",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
+ },
+ expectHandlerCalled: true,
+ },
+ {
+ name: "ok, custom skipper",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.Skipper = func(context *echo.Context) bool {
+ return true
+ }
+ },
+ expectHandlerCalled: true,
+ },
+ {
+ name: "nok, defaults, invalid key from header, Authorization: Bearer",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
+ },
+ expectHandlerCalled: false,
+ expectError: "code=401, message=Unauthorized, err=code=401, message=invalid key",
+ },
+ {
+ name: "nok, defaults, invalid scheme in header",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "Bear valid-key")
+ },
+ expectHandlerCalled: false,
+ expectError: "code=401, message=missing key, err=invalid value in request header",
+ },
+ {
+ name: "nok, defaults, missing header",
+ givenRequest: func(req *http.Request) {},
+ expectHandlerCalled: false,
+ expectError: "code=401, message=missing key, err=missing value in request header",
+ },
+ {
+ name: "ok, custom key lookup, header",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set("API-Key", "valid-key")
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "header:API-Key"
+ },
+ expectHandlerCalled: true,
+ },
+ {
+ name: "nok, custom key lookup, missing header",
+ givenRequest: func(req *http.Request) {
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "header:API-Key"
+ },
+ expectHandlerCalled: false,
+ expectError: "code=401, message=missing key, err=missing value in request header",
+ },
+ {
+ name: "ok, custom key lookup, query",
+ givenRequest: func(req *http.Request) {
+ q := req.URL.Query()
+ q.Add("key", "valid-key")
+ req.URL.RawQuery = q.Encode()
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "query:key"
+ },
+ expectHandlerCalled: true,
+ },
+ {
+ name: "nok, custom key lookup, missing query param",
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "query:key"
+ },
+ expectHandlerCalled: false,
+ expectError: "code=401, message=missing key, err=missing value in the query string",
+ },
+ {
+ name: "ok, custom key lookup, form",
+ givenRequestFunc: func() *http.Request {
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("key=valid-key"))
+ req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
+ return req
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "form:key"
+ },
+ expectHandlerCalled: true,
+ },
+ {
+ name: "nok, custom key lookup, missing key in form",
+ givenRequestFunc: func() *http.Request {
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("xxx=valid-key"))
+ req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
+ return req
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "form:key"
+ },
+ expectHandlerCalled: false,
+ expectError: "code=401, message=missing key, err=missing value in the form",
+ },
+ {
+ name: "ok, custom key lookup, cookie",
+ givenRequest: func(req *http.Request) {
+ req.AddCookie(&http.Cookie{
+ Name: "key",
+ Value: "valid-key",
+ })
+ q := req.URL.Query()
+ q.Add("key", "valid-key")
+ req.URL.RawQuery = q.Encode()
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "cookie:key"
+ },
+ expectHandlerCalled: true,
+ },
+ {
+ name: "nok, custom key lookup, missing cookie param",
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "cookie:key"
+ },
+ expectHandlerCalled: false,
+ expectError: "code=401, message=missing key, err=missing value in cookies",
+ },
+ {
+ name: "nok, custom errorHandler, error from extractor",
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "header:token"
+ conf.ErrorHandler = func(c *echo.Context, err error) error {
+ return echo.NewHTTPError(http.StatusTeapot, "custom").Wrap(err)
+ }
+ },
+ expectHandlerCalled: false,
+ expectError: "code=418, message=custom, err=missing value in request header",
+ },
+ {
+ name: "nok, custom errorHandler, error from validator",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.ErrorHandler = func(c *echo.Context, err error) error {
+ return echo.NewHTTPError(http.StatusTeapot, "custom").Wrap(err)
+ }
+ },
+ expectHandlerCalled: false,
+ expectError: "code=418, message=custom, err=some user defined error",
+ },
+ {
+ name: "nok, defaults, error from validator",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
+ },
+ whenConfig: func(conf *KeyAuthConfig) {},
+ expectHandlerCalled: false,
+ expectError: "code=401, message=Unauthorized, err=some user defined error",
+ },
+ {
+ name: "ok, custom validator checks source",
+ givenRequest: func(req *http.Request) {
+ q := req.URL.Query()
+ q.Add("key", "valid-key")
+ req.URL.RawQuery = q.Encode()
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "query:key"
+ conf.Validator = func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
+ if source == ExtractorSourceQuery {
+ return true, nil
+ }
+ return false, errors.New("invalid source")
+ }
+
+ },
+ expectHandlerCalled: true,
},
}
- h := KeyAuthWithConfig(config)(func(c echo.Context) error {
- return c.String(http.StatusOK, "test")
- })
- assert := assert.New(t)
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ handlerCalled := false
+ handler := func(c *echo.Context) error {
+ handlerCalled = true
+ return c.String(http.StatusOK, "test")
+ }
+ config := KeyAuthConfig{
+ Validator: testKeyValidator,
+ }
+ if tc.whenConfig != nil {
+ tc.whenConfig(&config)
+ }
+ middlewareChain := KeyAuthWithConfig(config)(handler)
- // Valid key
- auth := DefaultKeyAuthConfig.AuthScheme + " " + "valid-key"
- req.Header.Set(echo.HeaderAuthorization, auth)
- assert.NoError(h(c))
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ if tc.givenRequestFunc != nil {
+ req = tc.givenRequestFunc()
+ }
+ if tc.givenRequest != nil {
+ tc.givenRequest(req)
+ }
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
- // Invalid key
- auth = DefaultKeyAuthConfig.AuthScheme + " " + "invalid-key"
- req.Header.Set(echo.HeaderAuthorization, auth)
- he := h(c).(*echo.HTTPError)
- assert.Equal(http.StatusUnauthorized, he.Code)
+ err := middlewareChain(c)
- // Missing Authorization header
- req.Header.Del(echo.HeaderAuthorization)
- he = h(c).(*echo.HTTPError)
- assert.Equal(http.StatusBadRequest, he.Code)
+ assert.Equal(t, tc.expectHandlerCalled, handlerCalled)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
- // Key from custom header
- config.KeyLookup = "header:API-Key"
- h = KeyAuthWithConfig(config)(func(c echo.Context) error {
- return c.String(http.StatusOK, "test")
- })
- req.Header.Set("API-Key", "valid-key")
- assert.NoError(h(c))
+func TestKeyAuthWithConfig_errors(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenConfig KeyAuthConfig
+ expectError string
+ }{
+ {
+ name: "ok, no error",
+ whenConfig: KeyAuthConfig{
+ Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
+ return false, nil
+ },
+ },
+ },
+ {
+ name: "ok, missing validator func",
+ whenConfig: KeyAuthConfig{
+ Validator: nil,
+ },
+ expectError: "echo key-auth middleware requires a validator function",
+ },
+ {
+ name: "ok, extractor source can not be split",
+ whenConfig: KeyAuthConfig{
+ KeyLookup: "nope",
+ Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
+ return false, nil
+ },
+ },
+ expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope",
+ },
+ {
+ name: "ok, no extractors",
+ whenConfig: KeyAuthConfig{
+ KeyLookup: "nope:nope",
+ Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
+ return false, nil
+ },
+ },
+ expectError: "echo key-auth middleware could not create extractors from KeyLookup string",
+ },
+ }
- // Key from query string
- config.KeyLookup = "query:key"
- h = KeyAuthWithConfig(config)(func(c echo.Context) error {
- return c.String(http.StatusOK, "test")
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ mw, err := tc.whenConfig.ToMiddleware()
+ if tc.expectError != "" {
+ assert.Nil(t, mw)
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NotNil(t, mw)
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestMustKeyAuthWithConfig_panic(t *testing.T) {
+ assert.Panics(t, func() {
+ KeyAuthWithConfig(KeyAuthConfig{})
})
- q := req.URL.Query()
- q.Add("key", "valid-key")
- req.URL.RawQuery = q.Encode()
- assert.NoError(h(c))
-
- // Key from form
- config.KeyLookup = "form:key"
- h = KeyAuthWithConfig(config)(func(c echo.Context) error {
+}
+
+func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) {
+ handlerCalled := false
+ var authValue string
+ handler := func(c *echo.Context) error {
+ handlerCalled = true
+ authValue = c.Get("auth").(string)
return c.String(http.StatusOK, "test")
- })
- f := make(url.Values)
- f.Set("key", "valid-key")
- req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
- req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
- c = e.NewContext(req, rec)
- assert.NoError(h(c))
+ }
+ middlewareChain := KeyAuthWithConfig(KeyAuthConfig{
+ Validator: testKeyValidator,
+ ErrorHandler: func(c *echo.Context, err error) error {
+ // could check error to decide if we can swallow the error
+ c.Set("auth", "public")
+ return nil
+ },
+ ContinueOnIgnoredError: true,
+ })(handler)
+
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ // no auth header this time
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := middlewareChain(c)
+
+ assert.NoError(t, err)
+ assert.True(t, handlerCalled)
+ assert.Equal(t, "public", authValue)
}
diff --git a/middleware/logger.go b/middleware/logger.go
deleted file mode 100644
index 9baac4769..000000000
--- a/middleware/logger.go
+++ /dev/null
@@ -1,223 +0,0 @@
-package middleware
-
-import (
- "bytes"
- "encoding/json"
- "io"
- "strconv"
- "strings"
- "sync"
- "time"
-
- "github.com/labstack/echo/v4"
- "github.com/labstack/gommon/color"
- "github.com/valyala/fasttemplate"
-)
-
-type (
- // LoggerConfig defines the config for Logger middleware.
- LoggerConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
-
- // Tags to construct the logger format.
- //
- // - time_unix
- // - time_unix_nano
- // - time_rfc3339
- // - time_rfc3339_nano
- // - time_custom
- // - id (Request ID)
- // - remote_ip
- // - uri
- // - host
- // - method
- // - path
- // - protocol
- // - referer
- // - user_agent
- // - status
- // - error
- // - latency (In nanoseconds)
- // - latency_human (Human readable)
- // - bytes_in (Bytes received)
- // - bytes_out (Bytes sent)
- // - header:
- // - query:
- // - form:
- //
- // Example "${remote_ip} ${status}"
- //
- // Optional. Default value DefaultLoggerConfig.Format.
- Format string `yaml:"format"`
-
- // Optional. Default value DefaultLoggerConfig.CustomTimeFormat.
- CustomTimeFormat string `yaml:"custom_time_format"`
-
- // Output is a writer where logs in JSON format are written.
- // Optional. Default value os.Stdout.
- Output io.Writer
-
- template *fasttemplate.Template
- colorer *color.Color
- pool *sync.Pool
- }
-)
-
-var (
- // DefaultLoggerConfig is the default Logger middleware config.
- DefaultLoggerConfig = LoggerConfig{
- Skipper: DefaultSkipper,
- Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` +
- `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` +
- `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` +
- `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n",
- CustomTimeFormat: "2006-01-02 15:04:05.00000",
- colorer: color.New(),
- }
-)
-
-// Logger returns a middleware that logs HTTP requests.
-func Logger() echo.MiddlewareFunc {
- return LoggerWithConfig(DefaultLoggerConfig)
-}
-
-// LoggerWithConfig returns a Logger middleware with config.
-// See: `Logger()`.
-func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
- // Defaults
- if config.Skipper == nil {
- config.Skipper = DefaultLoggerConfig.Skipper
- }
- if config.Format == "" {
- config.Format = DefaultLoggerConfig.Format
- }
- if config.Output == nil {
- config.Output = DefaultLoggerConfig.Output
- }
-
- config.template = fasttemplate.New(config.Format, "${", "}")
- config.colorer = color.New()
- config.colorer.SetOutput(config.Output)
- config.pool = &sync.Pool{
- New: func() interface{} {
- return bytes.NewBuffer(make([]byte, 256))
- },
- }
-
- return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) (err error) {
- if config.Skipper(c) {
- return next(c)
- }
-
- req := c.Request()
- res := c.Response()
- start := time.Now()
- if err = next(c); err != nil {
- c.Error(err)
- }
- stop := time.Now()
- buf := config.pool.Get().(*bytes.Buffer)
- buf.Reset()
- defer config.pool.Put(buf)
-
- if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) {
- switch tag {
- case "time_unix":
- return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10))
- case "time_unix_nano":
- return buf.WriteString(strconv.FormatInt(time.Now().UnixNano(), 10))
- case "time_rfc3339":
- return buf.WriteString(time.Now().Format(time.RFC3339))
- case "time_rfc3339_nano":
- return buf.WriteString(time.Now().Format(time.RFC3339Nano))
- case "time_custom":
- return buf.WriteString(time.Now().Format(config.CustomTimeFormat))
- case "id":
- id := req.Header.Get(echo.HeaderXRequestID)
- if id == "" {
- id = res.Header().Get(echo.HeaderXRequestID)
- }
- return buf.WriteString(id)
- case "remote_ip":
- return buf.WriteString(c.RealIP())
- case "host":
- return buf.WriteString(req.Host)
- case "uri":
- return buf.WriteString(req.RequestURI)
- case "method":
- return buf.WriteString(req.Method)
- case "path":
- p := req.URL.Path
- if p == "" {
- p = "/"
- }
- return buf.WriteString(p)
- case "protocol":
- return buf.WriteString(req.Proto)
- case "referer":
- return buf.WriteString(req.Referer())
- case "user_agent":
- return buf.WriteString(req.UserAgent())
- case "status":
- n := res.Status
- s := config.colorer.Green(n)
- switch {
- case n >= 500:
- s = config.colorer.Red(n)
- case n >= 400:
- s = config.colorer.Yellow(n)
- case n >= 300:
- s = config.colorer.Cyan(n)
- }
- return buf.WriteString(s)
- case "error":
- if err != nil {
- // Error may contain invalid JSON e.g. `"`
- b, _ := json.Marshal(err.Error())
- b = b[1 : len(b)-1]
- return buf.Write(b)
- }
- case "latency":
- l := stop.Sub(start)
- return buf.WriteString(strconv.FormatInt(int64(l), 10))
- case "latency_human":
- return buf.WriteString(stop.Sub(start).String())
- case "bytes_in":
- cl := req.Header.Get(echo.HeaderContentLength)
- if cl == "" {
- cl = "0"
- }
- return buf.WriteString(cl)
- case "bytes_out":
- return buf.WriteString(strconv.FormatInt(res.Size, 10))
- default:
- switch {
- case strings.HasPrefix(tag, "header:"):
- return buf.Write([]byte(c.Request().Header.Get(tag[7:])))
- case strings.HasPrefix(tag, "query:"):
- return buf.Write([]byte(c.QueryParam(tag[6:])))
- case strings.HasPrefix(tag, "form:"):
- return buf.Write([]byte(c.FormValue(tag[5:])))
- case strings.HasPrefix(tag, "cookie:"):
- cookie, err := c.Cookie(tag[7:])
- if err == nil {
- return buf.Write([]byte(cookie.Value))
- }
- }
- }
- return 0, nil
- }); err != nil {
- return
- }
-
- if config.Output == nil {
- _, err = c.Logger().Output().Write(buf.Bytes())
- return
- }
- _, err = config.Output.Write(buf.Bytes())
- return
- }
- }
-}
diff --git a/middleware/logger_test.go b/middleware/logger_test.go
deleted file mode 100644
index b196bc6c8..000000000
--- a/middleware/logger_test.go
+++ /dev/null
@@ -1,173 +0,0 @@
-package middleware
-
-import (
- "bytes"
- "encoding/json"
- "errors"
- "net/http"
- "net/http/httptest"
- "net/url"
- "strings"
- "testing"
- "time"
- "unsafe"
-
- "github.com/labstack/echo/v4"
- "github.com/stretchr/testify/assert"
-)
-
-func TestLogger(t *testing.T) {
- // Note: Just for the test coverage, not a real test.
- e := echo.New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
- h := Logger()(func(c echo.Context) error {
- return c.String(http.StatusOK, "test")
- })
-
- // Status 2xx
- h(c)
-
- // Status 3xx
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
- h = Logger()(func(c echo.Context) error {
- return c.String(http.StatusTemporaryRedirect, "test")
- })
- h(c)
-
- // Status 4xx
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
- h = Logger()(func(c echo.Context) error {
- return c.String(http.StatusNotFound, "test")
- })
- h(c)
-
- // Status 5xx with empty path
- req = httptest.NewRequest(http.MethodGet, "/", nil)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
- h = Logger()(func(c echo.Context) error {
- return errors.New("error")
- })
- h(c)
-}
-
-func TestLoggerIPAddress(t *testing.T) {
- e := echo.New()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- c := e.NewContext(req, rec)
- buf := new(bytes.Buffer)
- e.Logger.SetOutput(buf)
- ip := "127.0.0.1"
- h := Logger()(func(c echo.Context) error {
- return c.String(http.StatusOK, "test")
- })
-
- // With X-Real-IP
- req.Header.Add(echo.HeaderXRealIP, ip)
- h(c)
- assert.Contains(t, buf.String(), ip)
-
- // With X-Forwarded-For
- buf.Reset()
- req.Header.Del(echo.HeaderXRealIP)
- req.Header.Add(echo.HeaderXForwardedFor, ip)
- h(c)
- assert.Contains(t, buf.String(), ip)
-
- buf.Reset()
- h(c)
- assert.Contains(t, buf.String(), ip)
-}
-
-func TestLoggerTemplate(t *testing.T) {
- buf := new(bytes.Buffer)
-
- e := echo.New()
- e.Use(LoggerWithConfig(LoggerConfig{
- Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` +
- `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` +
- `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` +
- `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}", "protocol":"${protocol}"` +
- `"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n",
- Output: buf,
- }))
-
- e.GET("/", func(c echo.Context) error {
- return c.String(http.StatusOK, "Header Logged")
- })
-
- req := httptest.NewRequest(http.MethodGet, "/?username=apagano-param&password=secret", nil)
- req.RequestURI = "/"
- req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
- req.Header.Add("Referer", "google.com")
- req.Header.Add("User-Agent", "echo-tests-agent")
- req.Header.Add("X-Custom-Header", "AAA-CUSTOM-VALUE")
- req.Header.Add("X-Request-ID", "6ba7b810-9dad-11d1-80b4-00c04fd430c8")
- req.Header.Add("Cookie", "_ga=GA1.2.000000000.0000000000; session=ac08034cd216a647fc2eb62f2bcf7b810")
- req.Form = url.Values{
- "username": []string{"apagano-form"},
- "password": []string{"secret-form"},
- }
-
- rec := httptest.NewRecorder()
- e.ServeHTTP(rec, req)
-
- cases := map[string]bool{
- "apagano-param": true,
- "apagano-form": true,
- "AAA-CUSTOM-VALUE": true,
- "BBB-CUSTOM-VALUE": false,
- "secret-form": false,
- "hexvalue": false,
- "GET": true,
- "127.0.0.1": true,
- "\"path\":\"/\"": true,
- "\"uri\":\"/\"": true,
- "\"status\":200": true,
- "\"bytes_in\":0": true,
- "google.com": true,
- "echo-tests-agent": true,
- "6ba7b810-9dad-11d1-80b4-00c04fd430c8": true,
- "ac08034cd216a647fc2eb62f2bcf7b810": true,
- }
-
- for token, present := range cases {
- assert.True(t, strings.Contains(buf.String(), token) == present, "Case: "+token)
- }
-}
-
-func TestLoggerCustomTimestamp(t *testing.T) {
- buf := new(bytes.Buffer)
- customTimeFormat := "2006-01-02 15:04:05.00000"
- e := echo.New()
- e.Use(LoggerWithConfig(LoggerConfig{
- Format: `{"time":"${time_custom}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` +
- `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` +
- `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` +
- `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}",` +
- `"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n",
- CustomTimeFormat: customTimeFormat,
- Output: buf,
- }))
-
- e.GET("/", func(c echo.Context) error {
- return c.String(http.StatusOK, "custom time stamp test")
- })
-
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- e.ServeHTTP(rec, req)
-
- var objs map[string]*json.RawMessage
- if err := json.Unmarshal([]byte(buf.String()), &objs); err != nil {
- panic(err)
- }
- loggedTime := *(*string)(unsafe.Pointer(objs["time"]))
- _, err := time.Parse(customTimeFormat, loggedTime)
- assert.Error(t, err)
-}
diff --git a/middleware/method_override.go b/middleware/method_override.go
index 92b14d2ed..25ec1f935 100644
--- a/middleware/method_override.go
+++ b/middleware/method_override.go
@@ -1,33 +1,32 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
"net/http"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
)
-type (
- // MethodOverrideConfig defines the config for MethodOverride middleware.
- MethodOverrideConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
+// MethodOverrideConfig defines the config for MethodOverride middleware.
+type MethodOverrideConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
- // Getter is a function that gets overridden method from the request.
- // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride).
- Getter MethodOverrideGetter
- }
+ // Getter is a function that gets overridden method from the request.
+ // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride).
+ Getter MethodOverrideGetter
+}
- // MethodOverrideGetter is a function that gets overridden method from the request
- MethodOverrideGetter func(echo.Context) string
-)
+// MethodOverrideGetter is a function that gets overridden method from the request
+type MethodOverrideGetter func(c *echo.Context) string
-var (
- // DefaultMethodOverrideConfig is the default MethodOverride middleware config.
- DefaultMethodOverrideConfig = MethodOverrideConfig{
- Skipper: DefaultSkipper,
- Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
- }
-)
+// DefaultMethodOverrideConfig is the default MethodOverride middleware config.
+var DefaultMethodOverrideConfig = MethodOverrideConfig{
+ Skipper: DefaultSkipper,
+ Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
+}
// MethodOverride returns a MethodOverride middleware.
// MethodOverride middleware checks for the overridden method from the request and
@@ -38,9 +37,13 @@ func MethodOverride() echo.MiddlewareFunc {
return MethodOverrideWithConfig(DefaultMethodOverrideConfig)
}
-// MethodOverrideWithConfig returns a MethodOverride middleware with config.
-// See: `MethodOverride()`.
+// MethodOverrideWithConfig returns a Method Override middleware with config or panics on invalid configuration.
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration
+func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultMethodOverrideConfig.Skipper
@@ -50,7 +53,7 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) error {
+ return func(c *echo.Context) error {
if config.Skipper(c) {
return next(c)
}
@@ -64,13 +67,13 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
}
return next(c)
}
- }
+ }, nil
}
// MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from
// the request header.
func MethodFromHeader(header string) MethodOverrideGetter {
- return func(c echo.Context) string {
+ return func(c *echo.Context) string {
return c.Request().Header.Get(header)
}
}
@@ -78,7 +81,7 @@ func MethodFromHeader(header string) MethodOverrideGetter {
// MethodFromForm is a `MethodOverrideGetter` that gets overridden method from the
// form parameter.
func MethodFromForm(param string) MethodOverrideGetter {
- return func(c echo.Context) string {
+ return func(c *echo.Context) string {
return c.FormValue(param)
}
}
@@ -86,7 +89,7 @@ func MethodFromForm(param string) MethodOverrideGetter {
// MethodFromQuery is a `MethodOverrideGetter` that gets overridden method from
// the query parameter.
func MethodFromQuery(param string) MethodOverrideGetter {
- return func(c echo.Context) string {
+ return func(c *echo.Context) string {
return c.QueryParam(param)
}
}
diff --git a/middleware/method_override_test.go b/middleware/method_override_test.go
index 5760b1581..525ad10ba 100644
--- a/middleware/method_override_test.go
+++ b/middleware/method_override_test.go
@@ -1,3 +1,6 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
@@ -6,14 +9,14 @@ import (
"net/http/httptest"
"testing"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
func TestMethodOverride(t *testing.T) {
e := echo.New()
m := MethodOverride()
- h := func(c echo.Context) error {
+ h := func(c *echo.Context) error {
return c.String(http.StatusOK, "test")
}
@@ -22,28 +25,68 @@ func TestMethodOverride(t *testing.T) {
rec := httptest.NewRecorder()
req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
c := e.NewContext(req, rec)
- m(h)(c)
+
+ err := m(h)(c)
+ assert.NoError(t, err)
+
assert.Equal(t, http.MethodDelete, req.Method)
+}
+
+func TestMethodOverride_formParam(t *testing.T) {
+ e := echo.New()
+ h := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
// Override with form parameter
- m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")})
- req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete)))
- rec = httptest.NewRecorder()
+ m, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware()
+ assert.NoError(t, err)
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete)))
+ rec := httptest.NewRecorder()
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
- c = e.NewContext(req, rec)
- m(h)(c)
+ c := e.NewContext(req, rec)
+
+ err = m(h)(c)
+ assert.NoError(t, err)
+
assert.Equal(t, http.MethodDelete, req.Method)
+}
+
+func TestMethodOverride_queryParam(t *testing.T) {
+ e := echo.New()
+ h := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
// Override with query parameter
- m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")})
- req = httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil)
- rec = httptest.NewRecorder()
- c = e.NewContext(req, rec)
- m(h)(c)
+ m, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware()
+ assert.NoError(t, err)
+ req := httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err = m(h)(c)
+ assert.NoError(t, err)
+
assert.Equal(t, http.MethodDelete, req.Method)
+}
+
+func TestMethodOverride_ignoreGet(t *testing.T) {
+ e := echo.New()
+ m := MethodOverride()
+ h := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
// Ignore `GET`
- req = httptest.NewRequest(http.MethodGet, "/", nil)
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := m(h)(c)
+ assert.NoError(t, err)
+
assert.Equal(t, http.MethodGet, req.Method)
}
diff --git a/middleware/middleware.go b/middleware/middleware.go
index d0b7153cb..4562d03b5 100644
--- a/middleware/middleware.go
+++ b/middleware/middleware.go
@@ -1,21 +1,22 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
+ "net/http"
"regexp"
"strconv"
"strings"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
)
-type (
- // Skipper defines a function to skip middleware. Returning true skips processing
- // the middleware.
- Skipper func(echo.Context) bool
+// Skipper defines a function to skip middleware. Returning true skips processing the middleware.
+type Skipper func(c *echo.Context) bool
- // BeforeFunc defines a function which is executed just before the middleware.
- BeforeFunc func(echo.Context)
-)
+// BeforeFunc defines a function which is executed just before the middleware.
+type BeforeFunc func(c *echo.Context)
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
groups := pattern.FindAllStringSubmatch(input, -1)
@@ -32,7 +33,65 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
return strings.NewReplacer(replace...)
}
+func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
+ // Initialize
+ rulesRegex := map[*regexp.Regexp]string{}
+ for k, v := range rewrite {
+ k = regexp.QuoteMeta(k)
+ k = strings.ReplaceAll(k, `\*`, "(.*?)")
+ if strings.HasPrefix(k, `\^`) {
+ k = strings.ReplaceAll(k, `\^`, "^")
+ }
+ k = k + "$"
+ rulesRegex[regexp.MustCompile(k)] = v
+ }
+ return rulesRegex
+}
+
+func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error {
+ if len(rewriteRegex) == 0 {
+ return nil
+ }
+
+ // Depending how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path.
+ // We only want to use path part for rewriting and therefore trim prefix if it exists
+ rawURI := req.RequestURI
+ if rawURI != "" && rawURI[0] != '/' {
+ prefix := ""
+ if req.URL.Scheme != "" {
+ prefix = req.URL.Scheme + "://"
+ }
+ if req.URL.Host != "" {
+ prefix += req.URL.Host // host or host:port
+ }
+ if prefix != "" {
+ rawURI = strings.TrimPrefix(rawURI, prefix)
+ }
+ }
+
+ for k, v := range rewriteRegex {
+ if replacer := captureTokens(k, rawURI); replacer != nil {
+ url, err := req.URL.Parse(replacer.Replace(v))
+ if err != nil {
+ return err
+ }
+ req.URL = url
+
+ return nil // rewrite only once
+ }
+ }
+ return nil
+}
+
// DefaultSkipper returns false which processes the middleware.
-func DefaultSkipper(echo.Context) bool {
+func DefaultSkipper(c *echo.Context) bool {
return false
}
+
+func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc {
+ mw, err := config.ToMiddleware()
+ if err != nil {
+ panic(err)
+ }
+ return mw
+}
diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go
new file mode 100644
index 000000000..28407ed5c
--- /dev/null
+++ b/middleware/middleware_test.go
@@ -0,0 +1,136 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "bufio"
+ "errors"
+ "github.com/stretchr/testify/assert"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "regexp"
+ "testing"
+)
+
+func TestRewriteURL(t *testing.T) {
+ var testCases = []struct {
+ whenURL string
+ expectPath string
+ expectRawPath string
+ expectQuery string
+ expectErr string
+ }{
+ {
+ whenURL: "http://localhost:8080/old",
+ expectPath: "/new",
+ expectRawPath: "",
+ },
+ { // encoded `ol%64` (decoded `old`) should not be rewritten to `/new`
+ whenURL: "/ol%64", // `%64` is decoded `d`
+ expectPath: "/old",
+ expectRawPath: "/ol%64",
+ },
+ {
+ whenURL: "http://localhost:8080/users/+_+/orders/___++++?test=1",
+ expectPath: "/user/+_+/order/___++++",
+ expectRawPath: "",
+ expectQuery: "test=1",
+ },
+ {
+ whenURL: "http://localhost:8080/users/%20a/orders/%20aa",
+ expectPath: "/user/ a/order/ aa",
+ expectRawPath: "",
+ },
+ {
+ whenURL: "http://localhost:8080/%47%6f%2f?test=1",
+ expectPath: "/Go/",
+ expectRawPath: "/%47%6f%2f",
+ expectQuery: "test=1",
+ },
+ {
+ whenURL: "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F",
+ expectPath: "/user/jill/order/T/cO4lW/t/Vp/",
+ expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ },
+ { // do nothing, replace nothing
+ whenURL: "http://localhost:8080/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ expectPath: "/user/jill/order/T/cO4lW/t/Vp/",
+ expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ },
+ {
+ whenURL: "http://localhost:8080/static",
+ expectPath: "/static/path",
+ expectRawPath: "",
+ expectQuery: "role=AUTHOR&limit=1000",
+ },
+ {
+ whenURL: "/static",
+ expectPath: "/static/path",
+ expectRawPath: "",
+ expectQuery: "role=AUTHOR&limit=1000",
+ },
+ }
+
+ rules := map[*regexp.Regexp]string{
+ regexp.MustCompile("^/old$"): "/new",
+ regexp.MustCompile("^/users/(.*?)/orders/(.*?)$"): "/user/$1/order/$2",
+ regexp.MustCompile("^/static$"): "/static/path?role=AUTHOR&limit=1000",
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.whenURL, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+
+ err := rewriteURL(rules, req)
+
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expectPath, req.URL.Path) // Path field is stored in decoded form: /%47%6f%2f becomes /Go/.
+ assert.Equal(t, tc.expectRawPath, req.URL.RawPath) // RawPath, an optional field which only gets set if the default encoding is different from Path.
+ assert.Equal(t, tc.expectQuery, req.URL.RawQuery)
+ })
+ }
+}
+
+type testResponseWriterNoFlushHijack struct {
+}
+
+func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) {
+}
+func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) {
+ return 0, nil
+}
+func (w *testResponseWriterNoFlushHijack) Header() http.Header {
+ return nil
+}
+
+type testResponseWriterUnwrapper struct {
+ unwrapCalled int
+ rw http.ResponseWriter
+}
+
+func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) {
+}
+func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) {
+ return 0, nil
+}
+func (w *testResponseWriterUnwrapper) Header() http.Header {
+ return nil
+}
+func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter {
+ w.unwrapCalled++
+ return w.rw
+}
+
+type testResponseWriterUnwrapperHijack struct {
+ testResponseWriterUnwrapper
+}
+
+func (w *testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ return nil, nil, errors.New("can hijack")
+}
diff --git a/middleware/proxy.go b/middleware/proxy.go
index ef5602bd6..1996032f7 100644
--- a/middleware/proxy.go
+++ b/middleware/proxy.go
@@ -1,105 +1,157 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
+ "context"
+ "crypto/tls"
+ "errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
+ "net/http/httputil"
"net/url"
"regexp"
"strings"
"sync"
- "sync/atomic"
"time"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
)
// TODO: Handle TLS proxy
-type (
- // ProxyConfig defines the config for Proxy middleware.
- ProxyConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
-
- // Balancer defines a load balancing technique.
- // Required.
- Balancer ProxyBalancer
-
- // Rewrite defines URL path rewrite rules. The values captured in asterisk can be
- // retrieved by index e.g. $1, $2 and so on.
- // Examples:
- // "/old": "/new",
- // "/api/*": "/$1",
- // "/js/*": "/public/javascripts/$1",
- // "/users/*/orders/*": "/user/$1/order/$2",
- Rewrite map[string]string
-
- // Context key to store selected ProxyTarget into context.
- // Optional. Default value "target".
- ContextKey string
-
- // To customize the transport to remote.
- // Examples: If custom TLS certificates are required.
- Transport http.RoundTripper
-
- rewriteRegex map[*regexp.Regexp]string
- }
+// ProxyConfig defines the config for Proxy middleware.
+type ProxyConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // Balancer defines a load balancing technique.
+ // Required.
+ Balancer ProxyBalancer
+
+ // RetryCount defines the number of times a failed proxied request should be retried
+ // using the next available ProxyTarget. Defaults to 0, meaning requests are never retried.
+ RetryCount int
+
+ // RetryFilter defines a function used to determine if a failed request to a
+ // ProxyTarget should be retried. The RetryFilter will only be called when the number
+ // of previous retries is less than RetryCount. If the function returns true, the
+ // request will be retried. The provided error indicates the reason for the request
+ // failure. When the ProxyTarget is unavailable, the error will be an instance of
+ // echo.HTTPError with a code of http.StatusBadGateway. In all other cases, the error
+ // will indicate an internal error in the Proxy middleware. When a RetryFilter is not
+ // specified, all requests that fail with http.StatusBadGateway will be retried. A custom
+ // RetryFilter can be provided to only retry specific requests. Note that RetryFilter is
+ // only called when the request to the target fails, or an internal error in the Proxy
+ // middleware has occurred. Successful requests that return a non-200 response code cannot
+ // be retried.
+ RetryFilter func(c *echo.Context, e error) bool
+
+ // ErrorHandler defines a function which can be used to return custom errors from
+ // the Proxy middleware. ErrorHandler is only invoked when there has been
+ // either an internal error in the Proxy middleware or the ProxyTarget is
+ // unavailable. Due to the way requests are proxied, ErrorHandler is not invoked
+ // when a ProxyTarget returns a non-200 response. In these cases, the response
+ // is already written so errors cannot be modified. ErrorHandler is only
+ // invoked after all retry attempts have been exhausted.
+ ErrorHandler func(c *echo.Context, err error) error
+
+ // Rewrite defines URL path rewrite rules. The values captured in asterisk can be
+ // retrieved by index e.g. $1, $2 and so on.
+ // Examples:
+ // "/old": "/new",
+ // "/api/*": "/$1",
+ // "/js/*": "/public/javascripts/$1",
+ // "/users/*/orders/*": "/user/$1/order/$2",
+ Rewrite map[string]string
+
+ // RegexRewrite defines rewrite rules using regexp.Rexexp with captures
+ // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
+ // Example:
+ // "^/old/[0.9]+/": "/new",
+ // "^/api/.+?/(.*)": "/v2/$1",
+ RegexRewrite map[*regexp.Regexp]string
+
+ // Context key to store selected ProxyTarget into context.
+ // Optional. Default value "target".
+ ContextKey string
+
+ // To customize the transport to remote.
+ // Examples: If custom TLS certificates are required.
+ Transport http.RoundTripper
+
+ // ModifyResponse defines function to modify response from ProxyTarget.
+ ModifyResponse func(*http.Response) error
+}
- // ProxyTarget defines the upstream target.
- ProxyTarget struct {
- Name string
- URL *url.URL
- Meta echo.Map
- }
+// ProxyTarget defines the upstream target.
+type ProxyTarget struct {
+ Name string
+ URL *url.URL
+ Meta map[string]any
+}
- // ProxyBalancer defines an interface to implement a load balancing technique.
- ProxyBalancer interface {
- AddTarget(*ProxyTarget) bool
- RemoveTarget(string) bool
- Next(echo.Context) *ProxyTarget
- }
+// ProxyBalancer defines an interface to implement a load balancing technique.
+type ProxyBalancer interface {
+ AddTarget(target *ProxyTarget) bool
+ RemoveTarget(targetName string) bool
+ Next(c *echo.Context) (*ProxyTarget, error)
+}
- commonBalancer struct {
- targets []*ProxyTarget
- mutex sync.RWMutex
- }
+type commonBalancer struct {
+ targets []*ProxyTarget
+ mutex sync.Mutex
+}
- // RandomBalancer implements a random load balancing technique.
- randomBalancer struct {
- *commonBalancer
- random *rand.Rand
- }
+// RandomBalancer implements a random load balancing technique.
+type randomBalancer struct {
+ commonBalancer
+ random *rand.Rand
+}
- // RoundRobinBalancer implements a round-robin load balancing technique.
- roundRobinBalancer struct {
- *commonBalancer
- i uint32
- }
-)
+// RoundRobinBalancer implements a round-robin load balancing technique.
+type roundRobinBalancer struct {
+ commonBalancer
+ // tracking the index on `targets` slice for the next `*ProxyTarget` to be used
+ i int
+}
-var (
- // DefaultProxyConfig is the default Proxy middleware config.
- DefaultProxyConfig = ProxyConfig{
- Skipper: DefaultSkipper,
- ContextKey: "target",
+// DefaultProxyConfig is the default Proxy middleware config.
+var DefaultProxyConfig = ProxyConfig{
+ Skipper: DefaultSkipper,
+ ContextKey: "target",
+}
+
+func proxyRaw(c *echo.Context, t *ProxyTarget, config ProxyConfig) http.Handler {
+ var dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
+ if transport, ok := config.Transport.(*http.Transport); ok {
+ if transport.TLSClientConfig != nil {
+ d := tls.Dialer{
+ Config: transport.TLSClientConfig,
+ }
+ dialFunc = d.DialContext
+ }
+ }
+ if dialFunc == nil {
+ var d net.Dialer
+ dialFunc = d.DialContext
}
-)
-func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- in, _, err := c.Response().Hijack()
+ in, _, err := http.NewResponseController(w).Hijack()
if err != nil {
- c.Set("_error", fmt.Sprintf("proxy raw, hijack error=%v, url=%s", t.URL, err))
+ c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL))
return
}
defer in.Close()
- out, err := net.Dial("tcp", t.URL.Host)
+ out, err := dialFunc(c.Request().Context(), "tcp", t.URL.Host)
if err != nil {
- c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err)))
+ c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
return
}
defer out.Close()
@@ -107,53 +159,66 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
// Write header
err = r.Write(out)
if err != nil {
- c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err)))
+ c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", err, t.URL)))
return
}
errCh := make(chan error, 2)
cp := func(dst io.Writer, src io.Reader) {
- _, err = io.Copy(dst, src)
- errCh <- err
+ _, copyErr := io.Copy(dst, src)
+ errCh <- copyErr
}
go cp(out, in)
go cp(in, out)
- err = <-errCh
- if err != nil && err != io.EOF {
- c.Set("_error", fmt.Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err))
+
+ // Wait for BOTH goroutines to complete
+ err1 := <-errCh
+ err2 := <-errCh
+
+ if err1 != nil && err1 != io.EOF {
+ c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err1, t.URL))
+ } else if err2 != nil && err2 != io.EOF {
+ c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err2, t.URL))
}
})
}
// NewRandomBalancer returns a random proxy balancer.
func NewRandomBalancer(targets []*ProxyTarget) ProxyBalancer {
- b := &randomBalancer{commonBalancer: new(commonBalancer)}
+ b := randomBalancer{}
b.targets = targets
- return b
+ // G404 (CWE-338): Use of weak random number generator (math/rand or math/rand/v2 instead of crypto/rand)
+ // this random is used to select next target. I can not think of reason this must be cryptographically safe. If you can - please open PR.
+ b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) // #nosec G404
+ return &b
}
// NewRoundRobinBalancer returns a round-robin proxy balancer.
func NewRoundRobinBalancer(targets []*ProxyTarget) ProxyBalancer {
- b := &roundRobinBalancer{commonBalancer: new(commonBalancer)}
+ b := roundRobinBalancer{}
b.targets = targets
- return b
+ return &b
}
-// AddTarget adds an upstream target to the list.
+// AddTarget adds an upstream target to the list and returns `true`.
+//
+// However, if a target with the same name already exists then the operation is aborted returning `false`.
func (b *commonBalancer) AddTarget(target *ProxyTarget) bool {
+ b.mutex.Lock()
+ defer b.mutex.Unlock()
for _, t := range b.targets {
if t.Name == target.Name {
return false
}
}
- b.mutex.Lock()
- defer b.mutex.Unlock()
b.targets = append(b.targets, target)
return true
}
-// RemoveTarget removes an upstream target from the list.
+// RemoveTarget removes an upstream target from the list by name.
+//
+// Returns `true` on success, `false` if no target with the name is found.
func (b *commonBalancer) RemoveTarget(name string) bool {
b.mutex.Lock()
defer b.mutex.Unlock()
@@ -167,21 +232,57 @@ func (b *commonBalancer) RemoveTarget(name string) bool {
}
// Next randomly returns an upstream target.
-func (b *randomBalancer) Next(c echo.Context) *ProxyTarget {
- if b.random == nil {
- b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
+//
+// Note: `nil` is returned in case upstream target list is empty.
+func (b *randomBalancer) Next(c *echo.Context) (*ProxyTarget, error) {
+ b.mutex.Lock()
+ defer b.mutex.Unlock()
+ if len(b.targets) == 0 {
+ return nil, nil
+ } else if len(b.targets) == 1 {
+ return b.targets[0], nil
}
- b.mutex.RLock()
- defer b.mutex.RUnlock()
- return b.targets[b.random.Intn(len(b.targets))]
+ return b.targets[b.random.Intn(len(b.targets))], nil
}
-// Next returns an upstream target using round-robin technique.
-func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget {
- b.i = b.i % uint32(len(b.targets))
- t := b.targets[b.i]
- atomic.AddUint32(&b.i, 1)
- return t
+// Next returns an upstream target using round-robin technique. In the case
+// where a previously failed request is being retried, the round-robin
+// balancer will attempt to use the next target relative to the original
+// request. If the list of targets held by the balancer is modified while a
+// failed request is being retried, it is possible that the balancer will
+// return the original failed target.
+//
+// Note: `nil` is returned in case upstream target list is empty.
+func (b *roundRobinBalancer) Next(c *echo.Context) (*ProxyTarget, error) {
+ b.mutex.Lock()
+ defer b.mutex.Unlock()
+ if len(b.targets) == 0 {
+ return nil, nil
+ } else if len(b.targets) == 1 {
+ return b.targets[0], nil
+ }
+
+ var i int
+ const lastIdxKey = "_round_robin_last_index"
+ // This request is a retry, start from the index of the previous
+ // target to ensure we don't attempt to retry the request with
+ // the same failed target
+ if c.Get(lastIdxKey) != nil {
+ i = c.Get(lastIdxKey).(int)
+ i++
+ if i >= len(b.targets) {
+ i = 0
+ }
+ } else {
+ // This is a first time request, use the global index
+ if b.i >= len(b.targets) {
+ b.i = 0
+ }
+ i = b.i
+ b.i++
+ }
+ c.Set(lastIdxKey, i)
+ return b.targets[i], nil
}
// Proxy returns a Proxy middleware.
@@ -193,45 +294,63 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {
return ProxyWithConfig(c)
}
-// ProxyWithConfig returns a Proxy middleware with config.
-// See: `Proxy()`
+// ProxyWithConfig returns a Proxy middleware or panics if configuration is invalid.
+//
+// Proxy middleware forwards the request to upstream server using a configured load balancing technique.
func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
- // Defaults
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration
+func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
config.Skipper = DefaultProxyConfig.Skipper
}
+ if config.ContextKey == "" {
+ config.ContextKey = DefaultProxyConfig.ContextKey
+ }
if config.Balancer == nil {
- panic("echo: proxy middleware requires balancer")
+ return nil, errors.New("echo proxy middleware requires balancer")
+ }
+ if config.RetryFilter == nil {
+ config.RetryFilter = func(c *echo.Context, e error) bool {
+ if httpErr, ok := e.(*echo.HTTPError); ok {
+ return httpErr.Code == http.StatusBadGateway
+ }
+ return false
+ }
+ }
+ if config.ErrorHandler == nil {
+ config.ErrorHandler = func(c *echo.Context, err error) error {
+ return err
+ }
}
- config.rewriteRegex = map[*regexp.Regexp]string{}
- // Initialize
- for k, v := range config.Rewrite {
- k = strings.Replace(k, "*", "(\\S*)", -1)
- config.rewriteRegex[regexp.MustCompile(k)] = v
+ if config.Rewrite != nil {
+ if config.RegexRewrite == nil {
+ config.RegexRewrite = make(map[*regexp.Regexp]string)
+ }
+ for k, v := range rewriteRulesRegex(config.Rewrite) {
+ config.RegexRewrite[k] = v
+ }
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) (err error) {
+ return func(c *echo.Context) (err error) {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
res := c.Response()
- tgt := config.Balancer.Next(c)
- c.Set(config.ContextKey, tgt)
-
- // Rewrite
- for k, v := range config.rewriteRegex {
- replacer := captureTokens(k, req.URL.Path)
- if replacer != nil {
- req.URL.Path = replacer.Replace(v)
- }
+ if err := rewriteURL(config.RegexRewrite, req); err != nil {
+ return config.ErrorHandler(c, err)
}
// Fix header
- if req.Header.Get(echo.HeaderXRealIP) == "" {
+ // Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
+ // However, for backward compatibility, legacy behavior is preserved unless you configure Echo#IPExtractor.
+ if req.Header.Get(echo.HeaderXRealIP) == "" || c.Echo().IPExtractor != nil {
req.Header.Set(echo.HeaderXRealIP, c.RealIP())
}
if req.Header.Get(echo.HeaderXForwardedProto) == "" {
@@ -241,19 +360,82 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
}
- // Proxy
- switch {
- case c.IsWebSocket():
- proxyRaw(tgt, c).ServeHTTP(res, req)
- case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
- default:
- proxyHTTP(tgt, c, config).ServeHTTP(res, req)
- }
- if e, ok := c.Get("_error").(error); ok {
- err = e
+ retries := config.RetryCount
+ for {
+ tgt, err := config.Balancer.Next(c)
+ if err != nil {
+ return config.ErrorHandler(c, err)
+ }
+
+ c.Set(config.ContextKey, tgt)
+
+ //If retrying a failed request, clear any previous errors from
+ //context here so that balancers have the option to check for
+ //errors that occurred using previous target
+ if retries < config.RetryCount {
+ c.Set("_error", nil)
+ }
+
+ // This is needed for ProxyConfig.ModifyResponse and/or ProxyConfig.Transport to be able to process the Request
+ // that Balancer may have replaced with c.SetRequest.
+ req = c.Request()
+
+ // Proxy
+ switch {
+ case c.IsWebSocket():
+ proxyRaw(c, tgt, config).ServeHTTP(res, req)
+ default: // even SSE requests
+ proxyHTTP(c, tgt, config).ServeHTTP(res, req)
+ }
+
+ err, hasError := c.Get("_error").(error)
+ if !hasError {
+ return nil
+ }
+
+ retry := retries > 0 && config.RetryFilter(c, err)
+ if !retry {
+ return config.ErrorHandler(c, err)
+ }
+
+ retries--
}
+ }
+ }, nil
+}
- return
+// StatusCodeContextCanceled is a custom HTTP status code for situations
+// where a client unexpectedly closed the connection to the server.
+// As there is no standard error code for "client closed connection", but
+// various well-known HTTP clients and server implement this HTTP code we use
+// 499 too instead of the more problematic 5xx, which does not allow to detect this situation
+const StatusCodeContextCanceled = 499
+
+func proxyHTTP(c *echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler {
+ proxy := httputil.NewSingleHostReverseProxy(tgt.URL)
+ proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) {
+ desc := tgt.URL.String()
+ if tgt.Name != "" {
+ desc = fmt.Sprintf("%s(%s)", tgt.Name, tgt.URL.String())
+ }
+ // If the client canceled the request (usually by closing the connection), we can report a
+ // client error (4xx) instead of a server error (5xx) to correctly identify the situation.
+ // The Go standard library (at of late 2020) wraps the exported, standard
+ // context. Canceled error with unexported garbage value requiring a substring check, see
+ // https://github.com/golang/go/blob/6965b01ea248cabb70c3749fd218b36089a21efb/src/net/net.go#L416-L430
+ // From Caddy https://github.com/caddyserver/caddy/blob/afa778ae05503f563af0d1015cdf7e5e78b1eeec/modules/caddyhttp/reverseproxy/reverseproxy.go#L1352
+ if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "operation was canceled") {
+ httpError := echo.NewHTTPError(StatusCodeContextCanceled, "client closed connection").Wrap(err)
+ c.Set("_error", httpError)
+ } else {
+ httpError := echo.NewHTTPError(
+ http.StatusBadGateway,
+ "remote server unreachable, could not proxy request",
+ ).Wrap(fmt.Errorf("server: %s, err: %w", desc, err))
+ c.Set("_error", httpError)
}
}
+ proxy.Transport = config.Transport
+ proxy.ModifyResponse = config.ModifyResponse
+ return proxy
}
diff --git a/middleware/proxy_1_11.go b/middleware/proxy_1_11.go
deleted file mode 100644
index 12b7568bf..000000000
--- a/middleware/proxy_1_11.go
+++ /dev/null
@@ -1,24 +0,0 @@
-// +build go1.11
-
-package middleware
-
-import (
- "fmt"
- "net/http"
- "net/http/httputil"
-
- "github.com/labstack/echo/v4"
-)
-
-func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
- proxy := httputil.NewSingleHostReverseProxy(tgt.URL)
- proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) {
- desc := tgt.URL.String()
- if tgt.Name != "" {
- desc = fmt.Sprintf("%s(%s)", tgt.Name, tgt.URL.String())
- }
- c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err)))
- }
- proxy.Transport = config.Transport
- return proxy
-}
diff --git a/middleware/proxy_1_11_n.go b/middleware/proxy_1_11_n.go
deleted file mode 100644
index 9a78929fe..000000000
--- a/middleware/proxy_1_11_n.go
+++ /dev/null
@@ -1,14 +0,0 @@
-// +build !go1.11
-
-package middleware
-
-import (
- "net/http"
- "net/http/httputil"
-
- "github.com/labstack/echo/v4"
-)
-
-func proxyHTTP(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
- return httputil.NewSingleHostReverseProxy(t.URL)
-}
diff --git a/middleware/proxy_1_11_test.go b/middleware/proxy_1_11_test.go
deleted file mode 100644
index 26feaabaa..000000000
--- a/middleware/proxy_1_11_test.go
+++ /dev/null
@@ -1,53 +0,0 @@
-// +build go1.11
-
-package middleware
-
-import (
- "net/http"
- "net/http/httptest"
- "net/url"
- "testing"
-
- "github.com/labstack/echo/v4"
- "github.com/stretchr/testify/assert"
-)
-
-func TestProxy_1_11(t *testing.T) {
- // Setup
- url1, _ := url.Parse("http://127.0.0.1:27121")
- url2, _ := url.Parse("http://127.0.0.1:27122")
-
- targets := []*ProxyTarget{
- {
- Name: "target 1",
- URL: url1,
- },
- {
- Name: "target 2",
- URL: url2,
- },
- }
- rb := NewRandomBalancer(nil)
- // must add targets:
- for _, target := range targets {
- assert.True(t, rb.AddTarget(target))
- }
-
- // must ignore duplicates:
- for _, target := range targets {
- assert.False(t, rb.AddTarget(target))
- }
-
- // Random
- e := echo.New()
- e.Use(Proxy(rb))
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
-
- // Remote unreachable
- rec = httptest.NewRecorder()
- req.URL.Path = "/api/users"
- e.ServeHTTP(rec, req)
- assert.Equal(t, "/api/users", req.URL.Path)
- assert.Equal(t, http.StatusBadGateway, rec.Code)
-}
diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go
index 1a375db86..420be3240 100644
--- a/middleware/proxy_test.go
+++ b/middleware/proxy_test.go
@@ -1,16 +1,30 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "errors"
"fmt"
+ "io"
+ "net"
"net/http"
"net/http/httptest"
"net/url"
+ "regexp"
+ "sync"
"testing"
+ "time"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
+ "golang.org/x/net/websocket"
)
+// Assert expected with url.EscapedPath method to obtain the path.
func TestProxy(t *testing.T) {
// Setup
t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -47,7 +61,7 @@ func TestProxy(t *testing.T) {
// Random
e := echo.New()
- e.Use(Proxy(rb))
+ e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
@@ -69,53 +83,973 @@ func TestProxy(t *testing.T) {
// Round-robin
rrb := NewRoundRobinBalancer(targets)
e = echo.New()
- e.Use(Proxy(rrb))
+ e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
+
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
body = rec.Body.String()
assert.Equal(t, "target 1", body)
+
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
body = rec.Body.String()
assert.Equal(t, "target 2", body)
- // Rewrite
+ // ModifyResponse
e = echo.New()
e.Use(ProxyWithConfig(ProxyConfig{
Balancer: rrb,
- Rewrite: map[string]string{
- "/old": "/new",
- "/api/*": "/$1",
- "/js/*": "/public/javascripts/$1",
- "/users/*/orders/*": "/user/$1/order/$2",
+ ModifyResponse: func(res *http.Response) error {
+ res.Body = io.NopCloser(bytes.NewBuffer([]byte("modified")))
+ res.Header.Set("X-Modified", "1")
+ return nil
},
}))
- req.URL.Path = "/api/users"
- e.ServeHTTP(rec, req)
- assert.Equal(t, "/users", req.URL.Path)
- req.URL.Path = "/js/main.js"
- e.ServeHTTP(rec, req)
- assert.Equal(t, "/public/javascripts/main.js", req.URL.Path)
- req.URL.Path = "/old"
- e.ServeHTTP(rec, req)
- assert.Equal(t, "/new", req.URL.Path)
- req.URL.Path = "/users/jack/orders/1"
+
+ rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
- assert.Equal(t, "/user/jack/order/1", req.URL.Path)
- assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "modified", rec.Body.String())
+ assert.Equal(t, "1", rec.Header().Get("X-Modified"))
// ProxyTarget is set in context
contextObserver := func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) (err error) {
+ return func(c *echo.Context) (err error) {
next(c)
assert.Contains(t, targets, c.Get("target"), "target is not set in context")
return nil
}
}
- rrb1 := NewRoundRobinBalancer(targets)
+
e = echo.New()
e.Use(contextObserver)
- e.Use(Proxy(rrb1))
+ e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)}))
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
}
+
+func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) {
+ assert.Panics(t, func() {
+ ProxyWithConfig(ProxyConfig{Balancer: nil})
+ })
+}
+
+func TestProxyRealIPHeader(t *testing.T) {
+ // Setup
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
+ defer upstream.Close()
+ url, _ := url.Parse(upstream.URL)
+ rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}})
+ e := echo.New()
+ e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ remoteAddrIP, _, _ := net.SplitHostPort(req.RemoteAddr)
+ realIPHeaderIP := "203.0.113.1"
+ extractedRealIP := "203.0.113.10"
+ tests := []*struct {
+ hasRealIPheader bool
+ hasIPExtractor bool
+ expectedXRealIP string
+ }{
+ {false, false, remoteAddrIP},
+ {false, true, extractedRealIP},
+ {true, false, realIPHeaderIP},
+ {true, true, extractedRealIP},
+ }
+
+ for _, tt := range tests {
+ if tt.hasRealIPheader {
+ req.Header.Set(echo.HeaderXRealIP, realIPHeaderIP)
+ } else {
+ req.Header.Del(echo.HeaderXRealIP)
+ }
+ if tt.hasIPExtractor {
+ e.IPExtractor = func(*http.Request) string {
+ return extractedRealIP
+ }
+ } else {
+ e.IPExtractor = nil
+ }
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, tt.expectedXRealIP, req.Header.Get(echo.HeaderXRealIP), "hasRealIPheader: %t / hasIPExtractor: %t", tt.hasRealIPheader, tt.hasIPExtractor)
+ }
+}
+
+func TestProxyRewrite(t *testing.T) {
+ var testCases = []struct {
+ whenPath string
+ expectProxiedURI string
+ expectStatus int
+ }{
+ {
+ whenPath: "/api/users",
+ expectProxiedURI: "/users",
+ expectStatus: http.StatusOK,
+ },
+ {
+ whenPath: "/js/main.js",
+ expectProxiedURI: "/public/javascripts/main.js",
+ expectStatus: http.StatusOK,
+ },
+ {
+ whenPath: "/old",
+ expectProxiedURI: "/new",
+ expectStatus: http.StatusOK,
+ },
+ {
+ whenPath: "/users/jack/orders/1",
+ expectProxiedURI: "/user/jack/order/1",
+ expectStatus: http.StatusOK,
+ },
+ {
+ whenPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ expectProxiedURI: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ expectStatus: http.StatusOK,
+ },
+ { // ` ` (space) is encoded by httpClient to `%20` when doing request to Echo. `%20` should not be double escaped when proxying request
+ whenPath: "/api/new users",
+ expectProxiedURI: "/new%20users",
+ expectStatus: http.StatusOK,
+ },
+ { // query params should be proxied and not be modified
+ whenPath: "/api/users?limit=10",
+ expectProxiedURI: "/users?limit=10",
+ expectStatus: http.StatusOK,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.whenPath, func(t *testing.T) {
+ receivedRequestURI := make(chan string, 1)
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server
+ // we need unmodified target to see if we are encoding/decoding the url in addition to rewrite/replace logic
+ // if original request had `%2F` we should not magically decode it to `/` as it would change what was requested
+ receivedRequestURI <- r.RequestURI
+ }))
+ defer upstream.Close()
+ serverURL, _ := url.Parse(upstream.URL)
+ rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: serverURL}})
+
+ // Rewrite
+ e := echo.New()
+ e.Use(ProxyWithConfig(ProxyConfig{
+ Balancer: rrb,
+ Rewrite: map[string]string{
+ "/old": "/new",
+ "/api/*": "/$1",
+ "/js/*": "/public/javascripts/$1",
+ "/users/*/orders/*": "/user/$1/order/$2",
+ },
+ }))
+
+ targetURL, _ := serverURL.Parse(tc.whenPath)
+ req := httptest.NewRequest(http.MethodGet, targetURL.String(), nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ actualRequestURI := <-receivedRequestURI
+ assert.Equal(t, tc.expectProxiedURI, actualRequestURI)
+ })
+ }
+}
+
+func TestProxyRewriteRegex(t *testing.T) {
+ // Setup
+ receivedRequestURI := make(chan string, 1)
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server
+ // we need unmodified target to see if we are encoding/decoding the url in addition to rewrite/replace logic
+ // if original request had `%2F` we should not magically decode it to `/` as it would change what was requested
+ receivedRequestURI <- r.RequestURI
+ }))
+ defer upstream.Close()
+ tmpUrL, _ := url.Parse(upstream.URL)
+ rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: tmpUrL}})
+
+ // Rewrite
+ e := echo.New()
+ e.Use(ProxyWithConfig(ProxyConfig{
+ Balancer: rrb,
+ Rewrite: map[string]string{
+ "^/a/*": "/v1/$1",
+ "^/b/*/c/*": "/v2/$2/$1",
+ "^/c/*/*": "/v3/$2",
+ },
+ RegexRewrite: map[*regexp.Regexp]string{
+ regexp.MustCompile("^/x/.+?/(.*)"): "/v4/$1",
+ regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1",
+ },
+ }))
+
+ testCases := []struct {
+ requestPath string
+ statusCode int
+ expectPath string
+ }{
+ {"/unmatched", http.StatusOK, "/unmatched"},
+ {"/a/test", http.StatusOK, "/v1/test"},
+ {"/b/foo/c/bar/baz", http.StatusOK, "/v2/bar/baz/foo"},
+ {"/c/ignore/test", http.StatusOK, "/v3/test"},
+ {"/c/ignore1/test/this", http.StatusOK, "/v3/test/this"},
+ {"/x/ignore/test", http.StatusOK, "/v4/test"},
+ {"/y/foo/bar", http.StatusOK, "/v5/bar/foo"},
+ // NB: fragment is not added by golang httputil.NewSingleHostReverseProxy implementation
+ // $2 = `bar?q=1#frag`, $1 = `foo`. replaced uri = `/v5/bar?q=1#frag/foo` but httputil.NewSingleHostReverseProxy does not send `#frag/foo` (currently)
+ {"/y/foo/bar?q=1#frag", http.StatusOK, "/v5/bar?q=1"},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.requestPath, func(t *testing.T) {
+ targetURL, _ := url.Parse(tc.requestPath)
+ req := httptest.NewRequest(http.MethodGet, targetURL.String(), nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ actualRequestURI := <-receivedRequestURI
+ assert.Equal(t, tc.expectPath, actualRequestURI)
+ assert.Equal(t, tc.statusCode, rec.Code)
+ })
+ }
+}
+
+func TestProxyError(t *testing.T) {
+ // Setup
+ url1, _ := url.Parse("http://127.0.0.1:27121")
+ url2, _ := url.Parse("http://127.0.0.1:27122")
+
+ targets := []*ProxyTarget{
+ {
+ Name: "target 1",
+ URL: url1,
+ },
+ {
+ Name: "target 2",
+ URL: url2,
+ },
+ }
+ rb := NewRandomBalancer(nil)
+ // must add targets:
+ for _, target := range targets {
+ assert.True(t, rb.AddTarget(target))
+ }
+
+ // must ignore duplicates:
+ for _, target := range targets {
+ assert.False(t, rb.AddTarget(target))
+ }
+
+ // Random
+ e := echo.New()
+ e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+
+ // Remote unreachable
+ rec := httptest.NewRecorder()
+ req.URL.Path = "/api/users"
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, "/api/users", req.URL.Path)
+ assert.Equal(t, http.StatusBadGateway, rec.Code)
+}
+
+func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) {
+ var timeoutStop sync.WaitGroup
+ timeoutStop.Add(1)
+ HTTPTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ timeoutStop.Wait() // wait until we have canceled the request
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer HTTPTarget.Close()
+ targetURL, _ := url.Parse(HTTPTarget.URL)
+ target := &ProxyTarget{
+ Name: "target",
+ URL: targetURL,
+ }
+ rb := NewRandomBalancer(nil)
+ assert.True(t, rb.AddTarget(target))
+ e := echo.New()
+ e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ ctx, cancel := context.WithCancel(req.Context())
+ req = req.WithContext(ctx)
+ go func() {
+ time.Sleep(10 * time.Millisecond)
+ cancel()
+ }()
+ e.ServeHTTP(rec, req)
+ timeoutStop.Done()
+ assert.Equal(t, 499, rec.Code)
+}
+
+type testProvider struct {
+ commonBalancer
+ target *ProxyTarget
+ err error
+}
+
+func (p *testProvider) Next(c *echo.Context) (*ProxyTarget, error) {
+ return p.target, p.err
+}
+
+func TestTargetProvider(t *testing.T) {
+ t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, "target 1")
+ }))
+ defer t1.Close()
+ url1, _ := url.Parse(t1.URL)
+
+ e := echo.New()
+ tp := &testProvider{}
+ tp.target = &ProxyTarget{Name: "target 1", URL: url1}
+ e.Use(Proxy(tp))
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ e.ServeHTTP(rec, req)
+ body := rec.Body.String()
+ assert.Equal(t, "target 1", body)
+}
+
+func TestFailNextTarget(t *testing.T) {
+ url1, err := url.Parse("http://dummy:8080")
+ assert.Nil(t, err)
+
+ e := echo.New()
+ tp := &testProvider{}
+ tp.target = &ProxyTarget{Name: "target 1", URL: url1}
+ tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target")
+
+ e.Use(Proxy(tp))
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ e.ServeHTTP(rec, req)
+ body := rec.Body.String()
+ assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body)
+}
+
+func TestRandomBalancerWithNoTargets(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ // Assert balancer with empty targets does return `nil` on `Next()`
+ rb := NewRandomBalancer(nil)
+ target, err := rb.Next(c)
+ assert.Nil(t, target)
+ assert.NoError(t, err)
+}
+
+func TestRoundRobinBalancerWithNoTargets(t *testing.T) {
+ // Assert balancer with empty targets does return `nil` on `Next()`
+ rrb := NewRoundRobinBalancer([]*ProxyTarget{})
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ target, err := rrb.Next(c)
+ assert.Nil(t, target)
+ assert.NoError(t, err)
+}
+
+func TestProxyRetries(t *testing.T) {
+ newServer := func(res int) (*url.URL, *httptest.Server) {
+ server := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(res)
+ }),
+ )
+ targetURL, _ := url.Parse(server.URL)
+ return targetURL, server
+ }
+
+ targetURL, server := newServer(http.StatusOK)
+ defer server.Close()
+ goodTarget := &ProxyTarget{
+ Name: "Good",
+ URL: targetURL,
+ }
+
+ targetURL, server = newServer(http.StatusBadRequest)
+ defer server.Close()
+ goodTargetWith40X := &ProxyTarget{
+ Name: "Good with 40X",
+ URL: targetURL,
+ }
+
+ targetURL, _ = url.Parse("http://127.0.0.1:27121")
+ badTarget := &ProxyTarget{
+ Name: "Bad",
+ URL: targetURL,
+ }
+
+ alwaysRetryFilter := func(c *echo.Context, e error) bool { return true }
+ neverRetryFilter := func(c *echo.Context, e error) bool { return false }
+
+ testCases := []struct {
+ name string
+ retryCount int
+ retryFilters []func(c *echo.Context, e error) bool
+ targets []*ProxyTarget
+ expectedResponse int
+ }{
+ {
+ name: "retry count 0 does not attempt retry on fail",
+ targets: []*ProxyTarget{
+ badTarget,
+ goodTarget,
+ },
+ expectedResponse: http.StatusBadGateway,
+ },
+ {
+ name: "retry count 1 does not attempt retry on success",
+ retryCount: 1,
+ targets: []*ProxyTarget{
+ goodTarget,
+ },
+ expectedResponse: http.StatusOK,
+ },
+ {
+ name: "retry count 1 does retry on handler return true",
+ retryCount: 1,
+ retryFilters: []func(c *echo.Context, e error) bool{
+ alwaysRetryFilter,
+ },
+ targets: []*ProxyTarget{
+ badTarget,
+ goodTarget,
+ },
+ expectedResponse: http.StatusOK,
+ },
+ {
+ name: "retry count 1 does not retry on handler return false",
+ retryCount: 1,
+ retryFilters: []func(c *echo.Context, e error) bool{
+ neverRetryFilter,
+ },
+ targets: []*ProxyTarget{
+ badTarget,
+ goodTarget,
+ },
+ expectedResponse: http.StatusBadGateway,
+ },
+ {
+ name: "retry count 2 returns error when no more retries left",
+ retryCount: 2,
+ retryFilters: []func(c *echo.Context, e error) bool{
+ alwaysRetryFilter,
+ alwaysRetryFilter,
+ },
+ targets: []*ProxyTarget{
+ badTarget,
+ badTarget,
+ badTarget,
+ goodTarget, //Should never be reached as only 2 retries
+ },
+ expectedResponse: http.StatusBadGateway,
+ },
+ {
+ name: "retry count 2 returns error when retries left but handler returns false",
+ retryCount: 3,
+ retryFilters: []func(c *echo.Context, e error) bool{
+ alwaysRetryFilter,
+ alwaysRetryFilter,
+ neverRetryFilter,
+ },
+ targets: []*ProxyTarget{
+ badTarget,
+ badTarget,
+ badTarget,
+ goodTarget, //Should never be reached as retry handler returns false on 2nd check
+ },
+ expectedResponse: http.StatusBadGateway,
+ },
+ {
+ name: "retry count 3 succeeds",
+ retryCount: 3,
+ retryFilters: []func(c *echo.Context, e error) bool{
+ alwaysRetryFilter,
+ alwaysRetryFilter,
+ alwaysRetryFilter,
+ },
+ targets: []*ProxyTarget{
+ badTarget,
+ badTarget,
+ badTarget,
+ goodTarget,
+ },
+ expectedResponse: http.StatusOK,
+ },
+ {
+ name: "40x responses are not retried",
+ retryCount: 1,
+ targets: []*ProxyTarget{
+ goodTargetWith40X,
+ goodTarget,
+ },
+ expectedResponse: http.StatusBadRequest,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+
+ retryFilterCall := 0
+ retryFilter := func(c *echo.Context, e error) bool {
+ if len(tc.retryFilters) == 0 {
+ assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall))
+ }
+
+ retryFilterCall++
+
+ nextRetryFilter := tc.retryFilters[0]
+ tc.retryFilters = tc.retryFilters[1:]
+
+ return nextRetryFilter(c, e)
+ }
+
+ e := echo.New()
+ e.Use(ProxyWithConfig(
+ ProxyConfig{
+ Balancer: NewRoundRobinBalancer(tc.targets),
+ RetryCount: tc.retryCount,
+ RetryFilter: retryFilter,
+ },
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectedResponse, rec.Code)
+ if len(tc.retryFilters) > 0 {
+ assert.FailNow(t, fmt.Sprintf("expected %d more retry handler calls", len(tc.retryFilters)))
+ }
+ })
+ }
+}
+
+func TestProxyRetryWithBackendTimeout(t *testing.T) {
+
+ transport := http.DefaultTransport.(*http.Transport).Clone()
+ transport.ResponseHeaderTimeout = time.Millisecond * 500
+
+ timeoutBackend := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ time.Sleep(1 * time.Second)
+ w.WriteHeader(404)
+ }),
+ )
+ defer timeoutBackend.Close()
+
+ timeoutTargetURL, _ := url.Parse(timeoutBackend.URL)
+ goodBackend := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(200)
+ }),
+ )
+ defer goodBackend.Close()
+
+ goodTargetURL, _ := url.Parse(goodBackend.URL)
+ e := echo.New()
+ e.Use(ProxyWithConfig(
+ ProxyConfig{
+ Transport: transport,
+ Balancer: NewRoundRobinBalancer([]*ProxyTarget{
+ {
+ Name: "Timeout",
+ URL: timeoutTargetURL,
+ },
+ {
+ Name: "Good",
+ URL: goodTargetURL,
+ },
+ }),
+ RetryCount: 1,
+ },
+ ))
+
+ var wg sync.WaitGroup
+ for i := 0; i < 20; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, 200, rec.Code)
+ }()
+ }
+
+ wg.Wait()
+
+}
+
+func TestProxyErrorHandler(t *testing.T) {
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ goodURL, _ := url.Parse(server.URL)
+ defer server.Close()
+ goodTarget := &ProxyTarget{
+ Name: "Good",
+ URL: goodURL,
+ }
+
+ badURL, _ := url.Parse("http://127.0.0.1:27121")
+ badTarget := &ProxyTarget{
+ Name: "Bad",
+ URL: badURL,
+ }
+
+ transformedError := errors.New("a new error")
+
+ testCases := []struct {
+ name string
+ target *ProxyTarget
+ errorHandler func(c *echo.Context, e error) error
+ expectFinalError func(t *testing.T, err error)
+ }{
+ {
+ name: "Error handler not invoked when request success",
+ target: goodTarget,
+ errorHandler: func(c *echo.Context, e error) error {
+ assert.FailNow(t, "error handler should not be invoked")
+ return e
+ },
+ },
+ {
+ name: "Error handler invoked when request fails",
+ target: badTarget,
+ errorHandler: func(c *echo.Context, e error) error {
+ httpErr, ok := e.(*echo.HTTPError)
+ assert.True(t, ok, "expected http error to be passed to handler")
+ assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler")
+ return transformedError
+ },
+ expectFinalError: func(t *testing.T, err error) {
+ assert.Equal(t, transformedError, err, "transformed error not returned from proxy")
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+ e.Use(ProxyWithConfig(
+ ProxyConfig{
+ Balancer: NewRoundRobinBalancer([]*ProxyTarget{tc.target}),
+ ErrorHandler: tc.errorHandler,
+ },
+ ))
+
+ errorHandlerCalled := false
+ dheh := echo.DefaultHTTPErrorHandler(false)
+ e.HTTPErrorHandler = func(c *echo.Context, err error) {
+ errorHandlerCalled = true
+ tc.expectFinalError(t, err)
+ dheh(c, err)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ if !errorHandlerCalled && tc.expectFinalError != nil {
+ t.Fatalf("error handler was not called")
+ }
+
+ })
+ }
+}
+
+type testContextKey string
+type customBalancer struct {
+ target *ProxyTarget
+}
+
+func (b *customBalancer) AddTarget(target *ProxyTarget) bool {
+ return false
+}
+func (b *customBalancer) RemoveTarget(name string) bool {
+ return false
+}
+
+func (b *customBalancer) Next(c *echo.Context) (*ProxyTarget, error) {
+ ctx := context.WithValue(c.Request().Context(), testContextKey("FROM_BALANCER"), "CUSTOM_BALANCER")
+ c.SetRequest(c.Request().WithContext(ctx))
+ return b.target, nil
+}
+
+func TestModifyResponseUseContext(t *testing.T) {
+ server := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte("OK"))
+ }),
+ )
+ defer server.Close()
+ targetURL, _ := url.Parse(server.URL)
+ e := echo.New()
+ e.Use(ProxyWithConfig(
+ ProxyConfig{
+ Balancer: &customBalancer{
+ target: &ProxyTarget{
+ Name: "tst",
+ URL: targetURL,
+ },
+ },
+ RetryCount: 1,
+ ModifyResponse: func(res *http.Response) error {
+ val := res.Request.Context().Value(testContextKey("FROM_BALANCER"))
+ if valStr, ok := val.(string); ok {
+ res.Header.Set("FROM_BALANCER", valStr)
+ }
+ return nil
+ },
+ },
+ ))
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "OK", rec.Body.String())
+ assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
+}
+
+func createSimpleWebSocketServer(serveTLS bool) *httptest.Server {
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ wsHandler := func(conn *websocket.Conn) {
+ defer conn.Close()
+ for {
+ var msg string
+ err := websocket.Message.Receive(conn, &msg)
+ if err != nil {
+ return
+ }
+ // message back to the client
+ websocket.Message.Send(conn, msg)
+ }
+ }
+ websocket.Server{Handler: wsHandler}.ServeHTTP(w, r)
+ })
+ if serveTLS {
+ return httptest.NewTLSServer(handler)
+ }
+ return httptest.NewServer(handler)
+}
+
+func createSimpleProxyServer(t *testing.T, srv *httptest.Server, serveTLS bool, toTLS bool) *httptest.Server {
+ e := echo.New()
+
+ if toTLS {
+ // proxy to tls target
+ tgtURL, _ := url.Parse(srv.URL)
+ tgtURL.Scheme = "wss"
+ balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
+
+ defaultTransport, ok := http.DefaultTransport.(*http.Transport)
+ if !ok {
+ t.Fatal("Default transport is not of type *http.Transport")
+ }
+ transport := defaultTransport.Clone()
+ transport.TLSClientConfig = &tls.Config{
+ InsecureSkipVerify: true,
+ }
+ e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer, Transport: transport}))
+ } else {
+ // proxy to non-TLS target
+ tgtURL, _ := url.Parse(srv.URL)
+ balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
+ e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer}))
+ }
+
+ if serveTLS {
+ // serve proxy server with TLS
+ ts := httptest.NewTLSServer(e)
+ return ts
+ }
+ // serve proxy server without TLS
+ ts := httptest.NewServer(e)
+ return ts
+}
+
+// TestProxyWithConfigWebSocketNonTLS2NonTLS tests the proxy with non-TLS to non-TLS WebSocket connection.
+func TestProxyWithConfigWebSocketNonTLS2NonTLS(t *testing.T) {
+ /*
+ Arrange
+ */
+ // Create a WebSocket test server (non-TLS)
+ srv := createSimpleWebSocketServer(false)
+ defer srv.Close()
+
+ // create proxy server (non-TLS to non-TLS)
+ ts := createSimpleProxyServer(t, srv, false, false)
+ defer ts.Close()
+
+ tsURL, _ := url.Parse(ts.URL)
+ tsURL.Scheme = "ws"
+ tsURL.Path = "/"
+
+ /*
+ Act
+ */
+
+ // Connect to the proxy WebSocket
+ wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
+ assert.NoError(t, err)
+ defer wsConn.Close()
+
+ // Send message
+ sendMsg := "Hello, Non TLS WebSocket!"
+ err = websocket.Message.Send(wsConn, sendMsg)
+ assert.NoError(t, err)
+
+ /*
+ Assert
+ */
+ // Read response
+ var recvMsg string
+ err = websocket.Message.Receive(wsConn, &recvMsg)
+ assert.NoError(t, err)
+ assert.Equal(t, sendMsg, recvMsg)
+}
+
+// TestProxyWithConfigWebSocketTLS2TLS tests the proxy with TLS to TLS WebSocket connection.
+func TestProxyWithConfigWebSocketTLS2TLS(t *testing.T) {
+ /*
+ Arrange
+ */
+ // Create a WebSocket test server (TLS)
+ srv := createSimpleWebSocketServer(true)
+ defer srv.Close()
+
+ // create proxy server (TLS to TLS)
+ ts := createSimpleProxyServer(t, srv, true, true)
+ defer ts.Close()
+
+ tsURL, _ := url.Parse(ts.URL)
+ tsURL.Scheme = "wss"
+ tsURL.Path = "/"
+
+ /*
+ Act
+ */
+ origin, err := url.Parse(ts.URL)
+ assert.NoError(t, err)
+ config := &websocket.Config{
+ Location: tsURL,
+ Origin: origin,
+ TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
+ Version: websocket.ProtocolVersionHybi13,
+ }
+ wsConn, err := websocket.DialConfig(config)
+ assert.NoError(t, err)
+ defer wsConn.Close()
+
+ // Send message
+ sendMsg := "Hello, TLS to TLS WebSocket!"
+ err = websocket.Message.Send(wsConn, sendMsg)
+ assert.NoError(t, err)
+
+ // Read response
+ var recvMsg string
+ err = websocket.Message.Receive(wsConn, &recvMsg)
+ assert.NoError(t, err)
+ assert.Equal(t, sendMsg, recvMsg)
+}
+
+// TestProxyWithConfigWebSocketNonTLS2TLS tests the proxy with non-TLS to TLS WebSocket connection.
+func TestProxyWithConfigWebSocketNonTLS2TLS(t *testing.T) {
+ /*
+ Arrange
+ */
+
+ // Create a WebSocket test server (TLS)
+ srv := createSimpleWebSocketServer(true)
+ defer srv.Close()
+
+ // create proxy server (Non-TLS to TLS)
+ ts := createSimpleProxyServer(t, srv, false, true)
+ defer ts.Close()
+
+ tsURL, _ := url.Parse(ts.URL)
+ tsURL.Scheme = "ws"
+ tsURL.Path = "/"
+
+ /*
+ Act
+ */
+ // Connect to the proxy WebSocket
+ wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
+ assert.NoError(t, err)
+ defer wsConn.Close()
+
+ // Send message
+ sendMsg := "Hello, Non TLS to TLS WebSocket!"
+ err = websocket.Message.Send(wsConn, sendMsg)
+ assert.NoError(t, err)
+
+ /*
+ Assert
+ */
+ // Read response
+ var recvMsg string
+ err = websocket.Message.Receive(wsConn, &recvMsg)
+ assert.NoError(t, err)
+ assert.Equal(t, sendMsg, recvMsg)
+}
+
+// TestProxyWithConfigWebSocketTLSToNoneTLS tests the proxy with TLS to non-TLS WebSocket connection. (TLS termination)
+func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) {
+ /*
+ Arrange
+ */
+
+ // Create a WebSocket test server (non-TLS)
+ srv := createSimpleWebSocketServer(false)
+ defer srv.Close()
+
+ // create proxy server (TLS to non-TLS)
+ ts := createSimpleProxyServer(t, srv, true, false)
+ defer ts.Close()
+
+ tsURL, _ := url.Parse(ts.URL)
+ tsURL.Scheme = "wss"
+ tsURL.Path = "/"
+
+ /*
+ Act
+ */
+ origin, err := url.Parse(ts.URL)
+ assert.NoError(t, err)
+ config := &websocket.Config{
+ Location: tsURL,
+ Origin: origin,
+ TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
+ Version: websocket.ProtocolVersionHybi13,
+ }
+ wsConn, err := websocket.DialConfig(config)
+ assert.NoError(t, err)
+ defer wsConn.Close()
+
+ // Send message
+ sendMsg := "Hello, TLS to NoneTLS WebSocket!"
+ err = websocket.Message.Send(wsConn, sendMsg)
+ assert.NoError(t, err)
+
+ // Read response
+ var recvMsg string
+ err = websocket.Message.Receive(wsConn, &recvMsg)
+ assert.NoError(t, err)
+ assert.Equal(t, sendMsg, recvMsg)
+}
diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go
new file mode 100644
index 000000000..c04ae157d
--- /dev/null
+++ b/middleware/rate_limiter.go
@@ -0,0 +1,263 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "errors"
+ "math"
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/labstack/echo/v5"
+ "golang.org/x/time/rate"
+)
+
+// RateLimiterStore is the interface to be implemented by custom stores.
+type RateLimiterStore interface {
+ Allow(identifier string) (bool, error)
+}
+
+// RateLimiterConfig defines the configuration for the rate limiter
+type RateLimiterConfig struct {
+ Skipper Skipper
+ BeforeFunc BeforeFunc
+ // IdentifierExtractor uses *echo.Context to extract the identifier for a visitor
+ IdentifierExtractor Extractor
+ // Store defines a store for the rate limiter
+ Store RateLimiterStore
+ // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error
+ ErrorHandler func(c *echo.Context, err error) error
+ // DenyHandler provides a handler to be called when RateLimiter denies access
+ DenyHandler func(c *echo.Context, identifier string, err error) error
+}
+
+// Extractor is used to extract data from *echo.Context
+type Extractor func(c *echo.Context) (string, error)
+
+// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
+var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
+
+// ErrExtractorError denotes an error raised when extractor function is unsuccessful
+var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
+
+// DefaultRateLimiterConfig defines default values for RateLimiterConfig
+var DefaultRateLimiterConfig = RateLimiterConfig{
+ Skipper: DefaultSkipper,
+ IdentifierExtractor: func(ctx *echo.Context) (string, error) {
+ id := ctx.RealIP()
+ return id, nil
+ },
+ ErrorHandler: func(c *echo.Context, err error) error {
+ return ErrExtractorError.Wrap(err)
+ },
+ DenyHandler: func(c *echo.Context, identifier string, err error) error {
+ return ErrRateLimitExceeded.Wrap(err)
+ },
+}
+
+/*
+RateLimiter returns a rate limiting middleware
+
+ e := echo.New()
+
+ limiterStore := middleware.NewRateLimiterMemoryStore(20)
+
+ e.GET("/rate-limited", func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }, RateLimiter(limiterStore))
+*/
+func RateLimiter(store RateLimiterStore) echo.MiddlewareFunc {
+ config := DefaultRateLimiterConfig
+ config.Store = store
+
+ return RateLimiterWithConfig(config)
+}
+
+/*
+RateLimiterWithConfig returns a rate limiting middleware
+
+ e := echo.New()
+
+ config := middleware.RateLimiterConfig{
+ Skipper: DefaultSkipper,
+ Store: middleware.NewRateLimiterMemoryStore(
+ middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute}
+ )
+ IdentifierExtractor: func(ctx *echo.Context) (string, error) {
+ id := ctx.RealIP()
+ return id, nil
+ },
+ ErrorHandler: func(ctx *echo.Context, err error) error {
+ return context.JSON(http.StatusTooManyRequests, nil)
+ },
+ DenyHandler: func(ctx *echo.Context, identifier string, err error) error {
+ return context.JSON(http.StatusForbidden, nil)
+ },
+ }
+
+ e.GET("/rate-limited", func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }, middleware.RateLimiterWithConfig(config))
+*/
+func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid configuration
+func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Skipper == nil {
+ config.Skipper = DefaultRateLimiterConfig.Skipper
+ }
+ if config.IdentifierExtractor == nil {
+ config.IdentifierExtractor = DefaultRateLimiterConfig.IdentifierExtractor
+ }
+ if config.ErrorHandler == nil {
+ config.ErrorHandler = DefaultRateLimiterConfig.ErrorHandler
+ }
+ if config.DenyHandler == nil {
+ config.DenyHandler = DefaultRateLimiterConfig.DenyHandler
+ }
+ if config.Store == nil {
+ return nil, errors.New("echo rate limiter store configuration must be provided")
+ }
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+ if config.BeforeFunc != nil {
+ config.BeforeFunc(c)
+ }
+
+ identifier, err := config.IdentifierExtractor(c)
+ if err != nil {
+ return config.ErrorHandler(c, err)
+ }
+
+ if allow, allowErr := config.Store.Allow(identifier); !allow {
+ return config.DenyHandler(c, identifier, allowErr)
+ }
+ return next(c)
+ }
+ }, nil
+}
+
+// RateLimiterMemoryStore is the built-in store implementation for RateLimiter
+type RateLimiterMemoryStore struct {
+ visitors map[string]*Visitor
+ mutex sync.Mutex
+ rate float64 // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit
+ burst int
+ expiresIn time.Duration
+ lastCleanup time.Time
+
+ timeNow func() time.Time
+}
+
+// Visitor signifies a unique user's limiter details
+type Visitor struct {
+ *rate.Limiter
+ lastSeen time.Time
+}
+
+/*
+NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with
+the provided rate (as req/s).
+for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
+
+Burst and ExpiresIn will be set to default values.
+
+Note that if the provided rate is a float number and Burst is zero, Burst will be treated as the rounded down value of the rate.
+
+Example (with 20 requests/sec):
+
+ limiterStore := middleware.NewRateLimiterMemoryStore(20)
+*/
+func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore) {
+ return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
+ Rate: rateLimit,
+ })
+}
+
+/*
+NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore
+with the provided configuration. Rate must be provided. Burst will be set to the rounded down value of
+the configured rate if not provided or set to 0.
+
+The built-in memory store is usually capable for modest loads. For higher loads other
+store implementations should be considered.
+
+Characteristics:
+* Concurrency above 100 parallel requests may causes measurable lock contention
+* A high number of different IP addresses (above 16000) may be impacted by the internally used Go map
+* A high number of requests from a single IP address may cause lock contention
+
+Example:
+
+ limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig(
+ middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minute},
+ )
+*/
+func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) {
+ store = &RateLimiterMemoryStore{}
+
+ store.rate = config.Rate
+ store.burst = config.Burst
+ store.expiresIn = config.ExpiresIn
+ if config.ExpiresIn == 0 {
+ store.expiresIn = DefaultRateLimiterMemoryStoreConfig.ExpiresIn
+ }
+ if config.Burst == 0 {
+ store.burst = int(math.Max(1, math.Ceil(float64(config.Rate))))
+ }
+ store.visitors = make(map[string]*Visitor)
+ store.timeNow = time.Now
+ store.lastCleanup = store.timeNow()
+ return
+}
+
+// RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore
+type RateLimiterMemoryStoreConfig struct {
+ Rate float64 // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
+ Burst int // Burst is maximum number of requests to pass at the same moment. It additionally allows a number of requests to pass when rate limit is reached.
+ ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up
+}
+
+// DefaultRateLimiterMemoryStoreConfig provides default configuration values for RateLimiterMemoryStore
+var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{
+ ExpiresIn: 3 * time.Minute,
+}
+
+// Allow implements RateLimiterStore.Allow
+func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
+ store.mutex.Lock()
+ limiter, exists := store.visitors[identifier]
+ if !exists {
+ limiter = new(Visitor)
+ limiter.Limiter = rate.NewLimiter(rate.Limit(store.rate), store.burst)
+ store.visitors[identifier] = limiter
+ }
+ now := store.timeNow()
+ limiter.lastSeen = now
+ if now.Sub(store.lastCleanup) > store.expiresIn {
+ store.cleanupStaleVisitors(now)
+ }
+ allowed := limiter.AllowN(now, 1)
+ store.mutex.Unlock()
+ return allowed, nil
+}
+
+/*
+cleanupStaleVisitors helps manage the size of the visitors map by removing stale records
+of users who haven't visited again after the configured expiry time has elapsed
+*/
+func (store *RateLimiterMemoryStore) cleanupStaleVisitors(now time.Time) {
+ for id, visitor := range store.visitors {
+ if now.Sub(visitor.lastSeen) > store.expiresIn {
+ delete(store.visitors, id)
+ }
+ }
+ store.lastCleanup = now
+}
diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go
new file mode 100644
index 000000000..c591d2b19
--- /dev/null
+++ b/middleware/rate_limiter_test.go
@@ -0,0 +1,648 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "errors"
+ "math/rand"
+ "net/http"
+ "net/http/httptest"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+ "golang.org/x/time/rate"
+)
+
+func TestRateLimiter(t *testing.T) {
+ e := echo.New()
+
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
+
+ mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
+
+ testCases := []struct {
+ id string
+ expectErr string
+ }{
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ }
+
+ for _, tc := range testCases {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRealIP, tc.id)
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := mw(handler)(c)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, http.StatusOK, rec.Code)
+ }
+}
+
+func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) {
+ var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
+
+ assert.Panics(t, func() {
+ RateLimiterWithConfig(RateLimiterConfig{})
+ })
+
+ assert.NotPanics(t, func() {
+ RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
+ })
+}
+
+func TestRateLimiterWithConfig(t *testing.T) {
+ var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
+
+ e := echo.New()
+
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ mw, err := RateLimiterConfig{
+ IdentifierExtractor: func(c *echo.Context) (string, error) {
+ id := c.Request().Header.Get(echo.HeaderXRealIP)
+ if id == "" {
+ return "", errors.New("invalid identifier")
+ }
+ return id, nil
+ },
+ DenyHandler: func(ctx *echo.Context, identifier string, err error) error {
+ return ctx.JSON(http.StatusForbidden, nil)
+ },
+ ErrorHandler: func(ctx *echo.Context, err error) error {
+ return ctx.JSON(http.StatusBadRequest, nil)
+ },
+ Store: inMemoryStore,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ testCases := []struct {
+ id string
+ code int
+ }{
+ {"127.0.0.1", http.StatusOK},
+ {"127.0.0.1", http.StatusOK},
+ {"127.0.0.1", http.StatusOK},
+ {"127.0.0.1", http.StatusForbidden},
+ {"", http.StatusBadRequest},
+ {"127.0.0.1", http.StatusForbidden},
+ {"127.0.0.1", http.StatusForbidden},
+ }
+
+ for _, tc := range testCases {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRealIP, tc.id)
+
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec)
+
+ err := mw(handler)(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, tc.code, rec.Code)
+ }
+}
+
+func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
+ var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
+
+ e := echo.New()
+
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ mw, err := RateLimiterConfig{
+ IdentifierExtractor: func(c *echo.Context) (string, error) {
+ id := c.Request().Header.Get(echo.HeaderXRealIP)
+ if id == "" {
+ return "", errors.New("invalid identifier")
+ }
+ return id, nil
+ },
+ Store: inMemoryStore,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ testCases := []struct {
+ id string
+ expectErr string
+ }{
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {expectErr: "code=403, message=error while extracting identifier, err=invalid identifier"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ }
+
+ for _, tc := range testCases {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRealIP, tc.id)
+
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec)
+
+ err := mw(handler)(c)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, http.StatusOK, rec.Code)
+ }
+}
+
+func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
+ {
+ var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
+
+ e := echo.New()
+
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ mw, err := RateLimiterConfig{
+ Store: inMemoryStore,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ testCases := []struct {
+ id string
+ expectErr string
+ }{
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ }
+
+ for _, tc := range testCases {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRealIP, tc.id)
+
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec)
+
+ err := mw(handler)(c)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, http.StatusOK, rec.Code)
+ }
+ }
+}
+
+func TestRateLimiterWithConfig_skipper(t *testing.T) {
+ e := echo.New()
+
+ var beforeFuncRan bool
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+ var inMemoryStore = NewRateLimiterMemoryStore(5)
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
+
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec)
+
+ mw, err := RateLimiterConfig{
+ Skipper: func(c *echo.Context) bool {
+ return true
+ },
+ BeforeFunc: func(c *echo.Context) {
+ beforeFuncRan = true
+ },
+ Store: inMemoryStore,
+ IdentifierExtractor: func(ctx *echo.Context) (string, error) {
+ return "127.0.0.1", nil
+ },
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(handler)(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, false, beforeFuncRan)
+}
+
+func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
+ e := echo.New()
+
+ var beforeFuncRan bool
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+ var inMemoryStore = NewRateLimiterMemoryStore(5)
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
+
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec)
+
+ mw, err := RateLimiterConfig{
+ Skipper: func(c *echo.Context) bool {
+ return false
+ },
+ BeforeFunc: func(c *echo.Context) {
+ beforeFuncRan = true
+ },
+ Store: inMemoryStore,
+ IdentifierExtractor: func(ctx *echo.Context) (string, error) {
+ return "127.0.0.1", nil
+ },
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ _ = mw(handler)(c)
+
+ assert.Equal(t, true, beforeFuncRan)
+}
+
+func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
+ e := echo.New()
+
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ var beforeRan bool
+ var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
+
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec)
+
+ mw, err := RateLimiterConfig{
+ BeforeFunc: func(c *echo.Context) {
+ beforeRan = true
+ },
+ Store: inMemoryStore,
+ IdentifierExtractor: func(ctx *echo.Context) (string, error) {
+ return "127.0.0.1", nil
+ },
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(handler)(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, true, beforeRan)
+}
+
+func TestRateLimiterMemoryStore_Allow(t *testing.T) {
+ var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3, ExpiresIn: 2 * time.Second})
+ testCases := []struct {
+ id string
+ allowed bool
+ }{
+ {"127.0.0.1", true}, // 0 ms
+ {"127.0.0.1", true}, // 220 ms burst #2
+ {"127.0.0.1", true}, // 440 ms burst #3
+ {"127.0.0.1", false}, // 660 ms block
+ {"127.0.0.1", false}, // 880 ms block
+ {"127.0.0.1", true}, // 1100 ms next second #1
+ {"127.0.0.2", true}, // 1320 ms allow other ip
+ {"127.0.0.1", false}, // 1540 ms no burst
+ {"127.0.0.1", false}, // 1760 ms no burst
+ {"127.0.0.1", false}, // 1980 ms no burst
+ {"127.0.0.1", true}, // 2200 ms no burst
+ {"127.0.0.1", false}, // 2420 ms no burst
+ {"127.0.0.1", false}, // 2640 ms no burst
+ {"127.0.0.1", false}, // 2860 ms no burst
+ {"127.0.0.1", true}, // 3080 ms no burst
+ {"127.0.0.1", false}, // 3300 ms no burst
+ {"127.0.0.1", false}, // 3520 ms no burst
+ {"127.0.0.1", false}, // 3740 ms no burst
+ {"127.0.0.1", false}, // 3960 ms no burst
+ {"127.0.0.1", true}, // 4180 ms no burst
+ {"127.0.0.1", false}, // 4400 ms no burst
+ {"127.0.0.1", false}, // 4620 ms no burst
+ {"127.0.0.1", false}, // 4840 ms no burst
+ {"127.0.0.1", true}, // 5060 ms no burst
+ }
+
+ for i, tc := range testCases {
+ t.Logf("Running testcase #%d => %v", i, time.Duration(i)*220*time.Millisecond)
+ inMemoryStore.timeNow = func() time.Time {
+ return time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).Add(time.Duration(i) * 220 * time.Millisecond)
+ }
+ allowed, _ := inMemoryStore.Allow(tc.id)
+ assert.Equal(t, tc.allowed, allowed)
+ }
+}
+
+func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) {
+ var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
+ inMemoryStore.visitors = map[string]*Visitor{
+ "A": {
+ Limiter: rate.NewLimiter(1, 3),
+ lastSeen: time.Now(),
+ },
+ "B": {
+ Limiter: rate.NewLimiter(1, 3),
+ lastSeen: time.Now().Add(-1 * time.Minute),
+ },
+ "C": {
+ Limiter: rate.NewLimiter(1, 3),
+ lastSeen: time.Now().Add(-5 * time.Minute),
+ },
+ "D": {
+ Limiter: rate.NewLimiter(1, 3),
+ lastSeen: time.Now().Add(-10 * time.Minute),
+ },
+ }
+
+ inMemoryStore.Allow("D")
+ inMemoryStore.cleanupStaleVisitors(time.Now())
+
+ var exists bool
+
+ _, exists = inMemoryStore.visitors["A"]
+ assert.Equal(t, true, exists)
+
+ _, exists = inMemoryStore.visitors["B"]
+ assert.Equal(t, true, exists)
+
+ _, exists = inMemoryStore.visitors["C"]
+ assert.Equal(t, false, exists)
+
+ _, exists = inMemoryStore.visitors["D"]
+ assert.Equal(t, true, exists)
+}
+
+func TestNewRateLimiterMemoryStore(t *testing.T) {
+ testCases := []struct {
+ rate float64
+ burst int
+ expiresIn time.Duration
+ expectedExpiresIn time.Duration
+ }{
+ {1, 3, 5 * time.Second, 5 * time.Second},
+ {2, 4, 0, 3 * time.Minute},
+ {1, 5, 10 * time.Minute, 10 * time.Minute},
+ {3, 7, 0, 3 * time.Minute},
+ }
+
+ for _, tc := range testCases {
+ store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: tc.rate, Burst: tc.burst, ExpiresIn: tc.expiresIn})
+ assert.Equal(t, tc.rate, store.rate)
+ assert.Equal(t, tc.burst, store.burst)
+ assert.Equal(t, tc.expectedExpiresIn, store.expiresIn)
+ }
+}
+
+func TestRateLimiterMemoryStore_FractionalRateDefaultBurst(t *testing.T) {
+ store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
+ Rate: 0.5, // fractional rate should get a burst of at least 1
+ })
+
+ base := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
+ store.timeNow = func() time.Time {
+ return base
+ }
+
+ allowed, err := store.Allow("user")
+ assert.NoError(t, err)
+ assert.True(t, allowed, "first request should not be blocked")
+
+ allowed, err = store.Allow("user")
+ assert.NoError(t, err)
+ assert.False(t, allowed, "burst token should be consumed immediately")
+
+ store.timeNow = func() time.Time {
+ return base.Add(2 * time.Second)
+ }
+
+ allowed, err = store.Allow("user")
+ assert.NoError(t, err)
+ assert.True(t, allowed, "token should refill for fractional rate after time passes")
+}
+
+func generateAddressList(count int) []string {
+ addrs := make([]string, count)
+ for i := 0; i < count; i++ {
+ addrs[i] = randomString(15)
+ }
+ return addrs
+}
+
+func run(wg *sync.WaitGroup, store RateLimiterStore, addrs []string, max int, b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ store.Allow(addrs[rand.Intn(max)])
+ }
+ wg.Done()
+}
+
+func benchmarkStore(store RateLimiterStore, parallel int, max int, b *testing.B) {
+ addrs := generateAddressList(max)
+ wg := &sync.WaitGroup{}
+ for i := 0; i < parallel; i++ {
+ wg.Add(1)
+ go run(wg, store, addrs, max, b)
+ }
+ wg.Wait()
+}
+
+const (
+ testExpiresIn = 1000 * time.Millisecond
+)
+
+func BenchmarkRateLimiterMemoryStore_1000(b *testing.B) {
+ var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
+ benchmarkStore(store, 10, 1000, b)
+}
+
+func BenchmarkRateLimiterMemoryStore_10000(b *testing.B) {
+ var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
+ benchmarkStore(store, 10, 10000, b)
+}
+
+func BenchmarkRateLimiterMemoryStore_100000(b *testing.B) {
+ var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
+ benchmarkStore(store, 10, 100000, b)
+}
+
+func BenchmarkRateLimiterMemoryStore_conc100_10000(b *testing.B) {
+ var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
+ benchmarkStore(store, 100, 10000, b)
+}
+
+// TestRateLimiterMemoryStore_TOCTOUFix verifies that the TOCTOU race condition is fixed
+// by ensuring timeNow() is only called once per Allow() call
+func TestRateLimiterMemoryStore_TOCTOUFix(t *testing.T) {
+ t.Parallel()
+
+ store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
+ Rate: 1,
+ Burst: 1,
+ ExpiresIn: 2 * time.Second,
+ })
+
+ // Track time calls to verify we use the same time value
+ timeCallCount := 0
+ baseTime := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
+
+ store.timeNow = func() time.Time {
+ timeCallCount++
+ return baseTime
+ }
+
+ // First request - should succeed
+ allowed, err := store.Allow("127.0.0.1")
+ assert.NoError(t, err)
+ assert.True(t, allowed, "First request should be allowed")
+
+ // Verify timeNow() was only called once
+ assert.Equal(t, 1, timeCallCount, "timeNow() should only be called once per Allow()")
+}
+
+// TestRateLimiterMemoryStore_ConcurrentAccess verifies rate limiting correctness under concurrent load
+func TestRateLimiterMemoryStore_ConcurrentAccess(t *testing.T) {
+ t.Parallel()
+
+ store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
+ Rate: 10,
+ Burst: 5,
+ ExpiresIn: 5 * time.Second,
+ })
+
+ const goroutines = 50
+ const requestsPerGoroutine = 20
+
+ var wg sync.WaitGroup
+ var allowedCount, deniedCount int32
+
+ for i := 0; i < goroutines; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for j := 0; j < requestsPerGoroutine; j++ {
+ allowed, err := store.Allow("test-user")
+ assert.NoError(t, err)
+ if allowed {
+ atomic.AddInt32(&allowedCount, 1)
+ } else {
+ atomic.AddInt32(&deniedCount, 1)
+ }
+ time.Sleep(time.Millisecond)
+ }
+ }()
+ }
+
+ wg.Wait()
+
+ totalRequests := goroutines * requestsPerGoroutine
+ allowed := int(allowedCount)
+ denied := int(deniedCount)
+
+ assert.Equal(t, totalRequests, allowed+denied, "All requests should be processed")
+ assert.Greater(t, denied, 0, "Some requests should be denied due to rate limiting")
+ assert.Greater(t, allowed, 0, "Some requests should be allowed")
+}
+
+// TestRateLimiterMemoryStore_RaceDetection verifies no data races with high concurrency
+// Run with: go test -race ./middleware -run TestRateLimiterMemoryStore_RaceDetection
+func TestRateLimiterMemoryStore_RaceDetection(t *testing.T) {
+ t.Parallel()
+
+ store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
+ Rate: 100,
+ Burst: 200,
+ ExpiresIn: 1 * time.Second,
+ })
+
+ const goroutines = 100
+ const requestsPerGoroutine = 100
+
+ var wg sync.WaitGroup
+ identifiers := []string{"user1", "user2", "user3", "user4", "user5"}
+
+ for i := 0; i < goroutines; i++ {
+ wg.Add(1)
+ go func(routineID int) {
+ defer wg.Done()
+ for j := 0; j < requestsPerGoroutine; j++ {
+ identifier := identifiers[routineID%len(identifiers)]
+ _, err := store.Allow(identifier)
+ assert.NoError(t, err)
+ }
+ }(i)
+ }
+
+ wg.Wait()
+}
+
+// TestRateLimiterMemoryStore_TimeOrdering verifies time ordering consistency in rate limiting decisions
+func TestRateLimiterMemoryStore_TimeOrdering(t *testing.T) {
+ t.Parallel()
+
+ store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
+ Rate: 1,
+ Burst: 2,
+ ExpiresIn: 5 * time.Second,
+ })
+
+ currentTime := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
+ store.timeNow = func() time.Time {
+ return currentTime
+ }
+
+ // First two requests should succeed (burst=2)
+ allowed1, _ := store.Allow("user1")
+ assert.True(t, allowed1, "Request 1 should be allowed (burst)")
+
+ allowed2, _ := store.Allow("user1")
+ assert.True(t, allowed2, "Request 2 should be allowed (burst)")
+
+ // Third request should be denied
+ allowed3, _ := store.Allow("user1")
+ assert.False(t, allowed3, "Request 3 should be denied (burst exhausted)")
+
+ // Advance time by 1 second
+ currentTime = currentTime.Add(1 * time.Second)
+
+ // Fourth request should succeed
+ allowed4, _ := store.Allow("user1")
+ assert.True(t, allowed4, "Request 4 should be allowed (1 token available)")
+}
diff --git a/middleware/recover.go b/middleware/recover.go
index e87aaf321..01fde5152 100644
--- a/middleware/recover.go
+++ b/middleware/recover.go
@@ -1,42 +1,42 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
"fmt"
+ "net/http"
"runtime"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
)
-type (
- // RecoverConfig defines the config for Recover middleware.
- RecoverConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
+// RecoverConfig defines the config for Recover middleware.
+type RecoverConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
- // Size of the stack to be printed.
- // Optional. Default value 4KB.
- StackSize int `yaml:"stack_size"`
+ // Size of the stack to be printed.
+ // Optional. Default value 4KB.
+ StackSize int
- // DisableStackAll disables formatting stack traces of all other goroutines
- // into buffer after the trace for the current goroutine.
- // Optional. Default value false.
- DisableStackAll bool `yaml:"disable_stack_all"`
+ // DisableStackAll disables formatting stack traces of all other goroutines
+ // into buffer after the trace for the current goroutine.
+ // Optional. Default value false.
+ DisableStackAll bool
- // DisablePrintStack disables printing stack trace.
- // Optional. Default value as false.
- DisablePrintStack bool `yaml:"disable_print_stack"`
- }
-)
+ // DisablePrintStack disables printing stack trace.
+ // Optional. Default value as false.
+ DisablePrintStack bool
+}
-var (
- // DefaultRecoverConfig is the default Recover middleware config.
- DefaultRecoverConfig = RecoverConfig{
- Skipper: DefaultSkipper,
- StackSize: 4 << 10, // 4 KB
- DisableStackAll: false,
- DisablePrintStack: false,
- }
-)
+// DefaultRecoverConfig is the default Recover middleware config.
+var DefaultRecoverConfig = RecoverConfig{
+ Skipper: DefaultSkipper,
+ StackSize: 4 << 10, // 4 KB
+ DisableStackAll: false,
+ DisablePrintStack: false,
+}
// Recover returns a middleware which recovers from panics anywhere in the chain
// and handles the control to the centralized HTTPErrorHandler.
@@ -44,9 +44,13 @@ func Recover() echo.MiddlewareFunc {
return RecoverWithConfig(DefaultRecoverConfig)
}
-// RecoverWithConfig returns a Recover middleware with config.
-// See: `Recover()`.
+// RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration.
func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration
+func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultRecoverConfig.Skipper
@@ -56,26 +60,44 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) error {
+ return func(c *echo.Context) (err error) {
if config.Skipper(c) {
return next(c)
}
defer func() {
if r := recover(); r != nil {
- err, ok := r.(error)
+ if r == http.ErrAbortHandler {
+ panic(r)
+ }
+ tmpErr, ok := r.(error)
if !ok {
- err = fmt.Errorf("%v", r)
+ tmpErr = fmt.Errorf("%v", r)
}
- stack := make([]byte, config.StackSize)
- length := runtime.Stack(stack, !config.DisableStackAll)
if !config.DisablePrintStack {
- c.Logger().Printf("[PANIC RECOVER] %v %s\n", err, stack[:length])
+ stack := make([]byte, config.StackSize)
+ length := runtime.Stack(stack, !config.DisableStackAll)
+ tmpErr = &PanicStackError{Stack: stack[:length], Err: tmpErr}
}
- c.Error(err)
+ err = tmpErr
}
}()
return next(c)
}
- }
+ }, nil
+}
+
+// PanicStackError is an error type that wraps an error along with its stack trace.
+// It is returned when config.DisablePrintStack is set to false.
+type PanicStackError struct {
+ Stack []byte
+ Err error
+}
+
+func (e *PanicStackError) Error() string {
+ return fmt.Sprintf("[PANIC RECOVER] %s %s", e.Err.Error(), e.Stack)
+}
+
+func (e *PanicStackError) Unwrap() error {
+ return e.Err
}
diff --git a/middleware/recover_test.go b/middleware/recover_test.go
index 37707c5c1..719e0cc3d 100644
--- a/middleware/recover_test.go
+++ b/middleware/recover_test.go
@@ -1,26 +1,150 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
"bytes"
+ "errors"
+ "log/slog"
"net/http"
"net/http/httptest"
"testing"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
func TestRecover(t *testing.T) {
e := echo.New()
buf := new(bytes.Buffer)
- e.Logger.SetOutput(buf)
+ e.Logger = slog.New(&discardHandler{})
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- h := Recover()(echo.HandlerFunc(func(c echo.Context) error {
+ h := Recover()(func(c *echo.Context) error {
panic("test")
- }))
- h(c)
+ })
+ err := h(c)
+ assert.Contains(t, err.Error(), "[PANIC RECOVER] test goroutine")
+
+ var pse *PanicStackError
+ if errors.As(err, &pse) {
+ assert.Contains(t, string(pse.Stack), "middleware/recover.go")
+ } else {
+ assert.Fail(t, "not of type PanicStackError")
+ }
+
+ assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
+ assert.Contains(t, buf.String(), "") // nothing is logged
+}
+
+func TestRecover_skipper(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ config := RecoverConfig{
+ Skipper: func(c *echo.Context) bool {
+ return true
+ },
+ }
+ h := RecoverWithConfig(config)(func(c *echo.Context) error {
+ panic("testPANIC")
+ })
+
+ var err error
+ assert.Panics(t, func() {
+ err = h(c)
+ })
+
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
+}
+
+func TestRecoverErrAbortHandler(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := Recover()(func(c *echo.Context) error {
+ panic(http.ErrAbortHandler)
+ })
+ defer func() {
+ r := recover()
+ if r == nil {
+ assert.Fail(t, "expecting `http.ErrAbortHandler`, got `nil`")
+ } else {
+ if err, ok := r.(error); ok {
+ assert.ErrorIs(t, err, http.ErrAbortHandler)
+ } else {
+ assert.Fail(t, "not of error type")
+ }
+ }
+ }()
+
+ hErr := h(c)
+
assert.Equal(t, http.StatusInternalServerError, rec.Code)
- assert.Contains(t, buf.String(), "PANIC RECOVER")
+ assert.NotContains(t, hErr.Error(), "PANIC RECOVER")
+}
+
+func TestRecoverWithConfig(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenNoPanic bool
+ whenConfig RecoverConfig
+ expectErrContain string
+ expectErr string
+ }{
+ {
+ name: "ok, default config",
+ whenConfig: DefaultRecoverConfig,
+ expectErrContain: "[PANIC RECOVER] testPANIC goroutine",
+ },
+ {
+ name: "ok, no panic",
+ givenNoPanic: true,
+ whenConfig: DefaultRecoverConfig,
+ expectErrContain: "",
+ },
+ {
+ name: "ok, DisablePrintStack",
+ whenConfig: RecoverConfig{
+ DisablePrintStack: true,
+ },
+ expectErr: "testPANIC",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ config := tc.whenConfig
+ h := RecoverWithConfig(config)(func(c *echo.Context) error {
+ if tc.givenNoPanic {
+ return nil
+ }
+ panic("testPANIC")
+ })
+
+ err := h(c)
+
+ if tc.expectErrContain != "" {
+ assert.Contains(t, err.Error(), tc.expectErrContain)
+ } else if tc.expectErr != "" {
+ assert.Contains(t, err.Error(), tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
+ })
+ }
}
diff --git a/middleware/redirect.go b/middleware/redirect.go
index 813e5b856..bb7045cfe 100644
--- a/middleware/redirect.go
+++ b/middleware/redirect.go
@@ -1,9 +1,14 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
+ "errors"
"net/http"
+ "strings"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
)
// RedirectConfig defines the config for Redirect middleware.
@@ -13,7 +18,9 @@ type RedirectConfig struct {
// Status code to be used when redirecting the request.
// Optional. Default value http.StatusMovedPermanently.
- Code int `yaml:"code"`
+ Code int
+
+ redirect redirectLogic
}
// redirectLogic represents a function that given a scheme, host and uri
@@ -23,29 +30,33 @@ type redirectLogic func(scheme, host, uri string) (ok bool, url string)
const www = "www."
-// DefaultRedirectConfig is the default Redirect middleware config.
-var DefaultRedirectConfig = RedirectConfig{
- Skipper: DefaultSkipper,
- Code: http.StatusMovedPermanently,
-}
+// RedirectHTTPSConfig is the HTTPS Redirect middleware config.
+var RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS}
+
+// RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config.
+var RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW}
+
+// RedirectNonHTTPSWWWConfig is the non HTTPS WWW Redirect middleware config.
+var RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW}
+
+// RedirectWWWConfig is the WWW Redirect middleware config.
+var RedirectWWWConfig = RedirectConfig{redirect: redirectWWW}
+
+// RedirectNonWWWConfig is the non WWW Redirect middleware config.
+var RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW}
// HTTPSRedirect redirects http requests to https.
// For example, http://labstack.com will be redirect to https://labstack.com.
//
// Usage `Echo#Pre(HTTPSRedirect())`
func HTTPSRedirect() echo.MiddlewareFunc {
- return HTTPSRedirectWithConfig(DefaultRedirectConfig)
+ return HTTPSRedirectWithConfig(RedirectHTTPSConfig)
}
-// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config.
-// See `HTTPSRedirect()`.
+// HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration.
func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
- return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
- if ok = scheme != "https"; ok {
- url = "https://" + host + uri
- }
- return
- })
+ config.redirect = redirectHTTPS
+ return toMiddlewareOrPanic(config)
}
// HTTPSWWWRedirect redirects http requests to https www.
@@ -53,18 +64,13 @@ func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
//
// Usage `Echo#Pre(HTTPSWWWRedirect())`
func HTTPSWWWRedirect() echo.MiddlewareFunc {
- return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig)
+ return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig)
}
-// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
-// See `HTTPSWWWRedirect()`.
+// HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration.
func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
- return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
- if ok = scheme != "https" && host[:4] != www; ok {
- url = "https://www." + host + uri
- }
- return
- })
+ config.redirect = redirectHTTPSWWW
+ return toMiddlewareOrPanic(config)
}
// HTTPSNonWWWRedirect redirects http requests to https non www.
@@ -72,21 +78,13 @@ func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
//
// Usage `Echo#Pre(HTTPSNonWWWRedirect())`
func HTTPSNonWWWRedirect() echo.MiddlewareFunc {
- return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig)
+ return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig)
}
-// HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
-// See `HTTPSNonWWWRedirect()`.
+// HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration.
func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
- return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
- if ok = scheme != "https"; ok {
- if host[:4] == www {
- host = host[4:]
- }
- url = "https://" + host + uri
- }
- return
- })
+ config.redirect = redirectNonHTTPSWWW
+ return toMiddlewareOrPanic(config)
}
// WWWRedirect redirects non www requests to www.
@@ -94,18 +92,13 @@ func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
//
// Usage `Echo#Pre(WWWRedirect())`
func WWWRedirect() echo.MiddlewareFunc {
- return WWWRedirectWithConfig(DefaultRedirectConfig)
+ return WWWRedirectWithConfig(RedirectWWWConfig)
}
-// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
-// See `WWWRedirect()`.
+// WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration.
func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
- return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
- if ok = host[:4] != www; ok {
- url = scheme + "://www." + host + uri
- }
- return
- })
+ config.redirect = redirectWWW
+ return toMiddlewareOrPanic(config)
}
// NonWWWRedirect redirects www requests to non www.
@@ -113,41 +106,79 @@ func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
//
// Usage `Echo#Pre(NonWWWRedirect())`
func NonWWWRedirect() echo.MiddlewareFunc {
- return NonWWWRedirectWithConfig(DefaultRedirectConfig)
+ return NonWWWRedirectWithConfig(RedirectNonWWWConfig)
}
-// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config.
-// See `NonWWWRedirect()`.
+// NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration.
func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
- return redirect(config, func(scheme, host, uri string) (ok bool, url string) {
- if ok = host[:4] == www; ok {
- url = scheme + "://" + host[4:] + uri
- }
- return
- })
+ config.redirect = redirectNonWWW
+ return toMiddlewareOrPanic(config)
}
-func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc {
+// ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration
+func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
- config.Skipper = DefaultTrailingSlashConfig.Skipper
+ config.Skipper = DefaultSkipper
}
if config.Code == 0 {
- config.Code = DefaultRedirectConfig.Code
+ config.Code = http.StatusMovedPermanently
+ }
+ if config.redirect == nil {
+ return nil, errors.New("redirectConfig is missing redirect function")
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) error {
+ return func(c *echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req, scheme := c.Request(), c.Scheme()
host := req.Host
- if ok, url := cb(scheme, host, req.RequestURI); ok {
+ if ok, url := config.redirect(scheme, host, req.RequestURI); ok {
return c.Redirect(config.Code, url)
}
return next(c)
}
+ }, nil
+}
+
+var redirectHTTPS = func(scheme, host, uri string) (bool, string) {
+ if scheme != "https" {
+ return true, "https://" + host + uri
+ }
+ return false, ""
+}
+
+var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) {
+ // Redirect if not HTTPS OR missing www prefix (needs either fix)
+ if scheme != "https" || !strings.HasPrefix(host, www) {
+ host = strings.TrimPrefix(host, www) // Remove www if present to avoid duplication
+ return true, "https://www." + host + uri
+ }
+ return false, ""
+}
+
+var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) {
+ // Redirect if not HTTPS OR has www prefix (needs either fix)
+ if scheme != "https" || strings.HasPrefix(host, www) {
+ host = strings.TrimPrefix(host, www)
+ return true, "https://" + host + uri
+ }
+ return false, ""
+}
+
+var redirectWWW = func(scheme, host, uri string) (bool, string) {
+ if !strings.HasPrefix(host, www) {
+ return true, scheme + "://www." + host + uri
+ }
+ return false, ""
+}
+
+var redirectNonWWW = func(scheme, host, uri string) (bool, string) {
+ if strings.HasPrefix(host, www) {
+ return true, scheme + "://" + host[4:] + uri
}
+ return false, ""
}
diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go
index 082609574..a127ca40c 100644
--- a/middleware/redirect_test.go
+++ b/middleware/redirect_test.go
@@ -1,3 +1,6 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
@@ -5,74 +8,277 @@ import (
"net/http/httptest"
"testing"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
type middlewareGenerator func() echo.MiddlewareFunc
func TestRedirectHTTPSRedirect(t *testing.T) {
- res := redirectTest(HTTPSRedirect, "labstack.com", nil)
-
- assert.Equal(t, http.StatusMovedPermanently, res.Code)
- assert.Equal(t, "https://labstack.com/", res.Header().Get(echo.HeaderLocation))
-}
+ var testCases = []struct {
+ whenHost string
+ whenHeader http.Header
+ expectLocation string
+ expectStatusCode int
+ }{
+ {
+ whenHost: "labstack.com",
+ expectLocation: "https://labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "labstack.com",
+ whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
+ expectLocation: "",
+ expectStatusCode: http.StatusOK,
+ },
+ }
-func TestHTTPSRedirectBehindTLSTerminationProxy(t *testing.T) {
- header := http.Header{}
- header.Set(echo.HeaderXForwardedProto, "https")
- res := redirectTest(HTTPSRedirect, "labstack.com", header)
+ for _, tc := range testCases {
+ t.Run(tc.whenHost, func(t *testing.T) {
+ res := redirectTest(HTTPSRedirect, tc.whenHost, tc.whenHeader)
- assert.Equal(t, http.StatusOK, res.Code)
+ assert.Equal(t, tc.expectStatusCode, res.Code)
+ assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
+ })
+ }
}
func TestRedirectHTTPSWWWRedirect(t *testing.T) {
- res := redirectTest(HTTPSWWWRedirect, "labstack.com", nil)
-
- assert.Equal(t, http.StatusMovedPermanently, res.Code)
- assert.Equal(t, "https://www.labstack.com/", res.Header().Get(echo.HeaderLocation))
-}
+ var testCases = []struct {
+ whenHost string
+ whenHeader http.Header
+ expectLocation string
+ expectStatusCode int
+ }{
+ {
+ whenHost: "labstack.com",
+ expectLocation: "https://www.labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "www.labstack.com",
+ expectLocation: "https://www.labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "a.com",
+ expectLocation: "https://www.a.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "ip",
+ expectLocation: "https://www.ip/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "labstack.com",
+ whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
+ expectLocation: "https://www.labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "www.labstack.com",
+ whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
+ expectLocation: "",
+ expectStatusCode: http.StatusOK,
+ },
+ }
-func TestRedirectHTTPSWWWRedirectBehindTLSTerminationProxy(t *testing.T) {
- header := http.Header{}
- header.Set(echo.HeaderXForwardedProto, "https")
- res := redirectTest(HTTPSWWWRedirect, "labstack.com", header)
+ for _, tc := range testCases {
+ t.Run(tc.whenHost, func(t *testing.T) {
+ res := redirectTest(HTTPSWWWRedirect, tc.whenHost, tc.whenHeader)
- assert.Equal(t, http.StatusOK, res.Code)
+ assert.Equal(t, tc.expectStatusCode, res.Code)
+ assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
+ })
+ }
}
func TestRedirectHTTPSNonWWWRedirect(t *testing.T) {
- res := redirectTest(HTTPSNonWWWRedirect, "www.labstack.com", nil)
-
- assert.Equal(t, http.StatusMovedPermanently, res.Code)
- assert.Equal(t, "https://labstack.com/", res.Header().Get(echo.HeaderLocation))
-}
+ var testCases = []struct {
+ whenHost string
+ whenHeader http.Header
+ expectLocation string
+ expectStatusCode int
+ }{
+ {
+ whenHost: "www.labstack.com",
+ expectLocation: "https://labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "a.com",
+ expectLocation: "https://a.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "ip",
+ expectLocation: "https://ip/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "www.labstack.com",
+ whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
+ expectLocation: "https://labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "labstack.com",
+ whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
+ expectLocation: "",
+ expectStatusCode: http.StatusOK,
+ },
+ }
-func TestRedirectHTTPSNonWWWRedirectBehindTLSTerminationProxy(t *testing.T) {
- header := http.Header{}
- header.Set(echo.HeaderXForwardedProto, "https")
- res := redirectTest(HTTPSNonWWWRedirect, "www.labstack.com", header)
+ for _, tc := range testCases {
+ t.Run(tc.whenHost, func(t *testing.T) {
+ res := redirectTest(HTTPSNonWWWRedirect, tc.whenHost, tc.whenHeader)
- assert.Equal(t, http.StatusOK, res.Code)
+ assert.Equal(t, tc.expectStatusCode, res.Code)
+ assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
+ })
+ }
}
func TestRedirectWWWRedirect(t *testing.T) {
- res := redirectTest(WWWRedirect, "labstack.com", nil)
+ var testCases = []struct {
+ whenHost string
+ whenHeader http.Header
+ expectLocation string
+ expectStatusCode int
+ }{
+ {
+ whenHost: "labstack.com",
+ expectLocation: "http://www.labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "a.com",
+ expectLocation: "http://www.a.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "ip",
+ expectLocation: "http://www.ip/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "a.com",
+ whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
+ expectLocation: "https://www.a.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "www.ip",
+ expectLocation: "",
+ expectStatusCode: http.StatusOK,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.whenHost, func(t *testing.T) {
+ res := redirectTest(WWWRedirect, tc.whenHost, tc.whenHeader)
- assert.Equal(t, http.StatusMovedPermanently, res.Code)
- assert.Equal(t, "http://www.labstack.com/", res.Header().Get(echo.HeaderLocation))
+ assert.Equal(t, tc.expectStatusCode, res.Code)
+ assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
+ })
+ }
}
func TestRedirectNonWWWRedirect(t *testing.T) {
- res := redirectTest(NonWWWRedirect, "www.labstack.com", nil)
+ var testCases = []struct {
+ whenHost string
+ whenHeader http.Header
+ expectLocation string
+ expectStatusCode int
+ }{
+ {
+ whenHost: "www.labstack.com",
+ expectLocation: "http://labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "www.a.com",
+ expectLocation: "http://a.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "www.a.com",
+ whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
+ expectLocation: "https://a.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "ip",
+ expectLocation: "",
+ expectStatusCode: http.StatusOK,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.whenHost, func(t *testing.T) {
+ res := redirectTest(NonWWWRedirect, tc.whenHost, tc.whenHeader)
+
+ assert.Equal(t, tc.expectStatusCode, res.Code)
+ assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
+ })
+ }
+}
+
+func TestNonWWWRedirectWithConfig(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenCode int
+ givenSkipFunc func(c *echo.Context) bool
+ whenHost string
+ whenHeader http.Header
+ expectLocation string
+ expectStatusCode int
+ }{
+ {
+ name: "usual redirect",
+ whenHost: "www.labstack.com",
+ expectLocation: "http://labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ name: "redirect is skipped",
+ givenSkipFunc: func(c *echo.Context) bool {
+ return true // skip always
+ },
+ whenHost: "www.labstack.com",
+ expectLocation: "",
+ expectStatusCode: http.StatusOK,
+ },
+ {
+ name: "redirect with custom status code",
+ givenCode: http.StatusSeeOther,
+ whenHost: "www.labstack.com",
+ expectLocation: "http://labstack.com/",
+ expectStatusCode: http.StatusSeeOther,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.whenHost, func(t *testing.T) {
+ middleware := func() echo.MiddlewareFunc {
+ return NonWWWRedirectWithConfig(RedirectConfig{
+ Skipper: tc.givenSkipFunc,
+ Code: tc.givenCode,
+ })
+ }
+ res := redirectTest(middleware, tc.whenHost, tc.whenHeader)
- assert.Equal(t, http.StatusMovedPermanently, res.Code)
- assert.Equal(t, "http://labstack.com/", res.Header().Get(echo.HeaderLocation))
+ assert.Equal(t, tc.expectStatusCode, res.Code)
+ assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
+ })
+ }
}
func redirectTest(fn middlewareGenerator, host string, header http.Header) *httptest.ResponseRecorder {
e := echo.New()
- next := func(c echo.Context) (err error) {
+ next := func(c *echo.Context) (err error) {
return c.NoContent(http.StatusOK)
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
diff --git a/middleware/request_id.go b/middleware/request_id.go
index 21f801f3b..b3de40d19 100644
--- a/middleware/request_id.go
+++ b/middleware/request_id.go
@@ -1,64 +1,73 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
- "github.com/labstack/echo/v4"
- "github.com/labstack/gommon/random"
+ "github.com/labstack/echo/v5"
)
-type (
- // RequestIDConfig defines the config for RequestID middleware.
- RequestIDConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
+// RequestIDConfig defines the config for RequestID middleware.
+type RequestIDConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
- // Generator defines a function to generate an ID.
- // Optional. Default value random.String(32).
- Generator func() string
- }
-)
+ // Generator defines a function to generate an ID.
+ // Optional. Default value random.String(32).
+ Generator func() string
-var (
- // DefaultRequestIDConfig is the default RequestID middleware config.
- DefaultRequestIDConfig = RequestIDConfig{
- Skipper: DefaultSkipper,
- Generator: generator,
- }
-)
+ // RequestIDHandler defines a function which is executed for a request id.
+ RequestIDHandler func(c *echo.Context, requestID string)
+
+ // TargetHeader defines what header to look for to populate the id.
+ // Optional. Default value is `X-Request-Id`
+ TargetHeader string
+}
-// RequestID returns a X-Request-ID middleware.
+// RequestID returns a middleware that reads RequestIDConfig.TargetHeader (`X-Request-ID`) header value or when
+// the header value is empty, generates that value and sets request ID to response
+// as RequestIDConfig.TargetHeader (`X-Request-Id`) value.
func RequestID() echo.MiddlewareFunc {
- return RequestIDWithConfig(DefaultRequestIDConfig)
+ return RequestIDWithConfig(RequestIDConfig{})
}
-// RequestIDWithConfig returns a X-Request-ID middleware with config.
+// RequestIDWithConfig returns a middleware with given valid config or panics on invalid configuration.
+// The middleware reads RequestIDConfig.TargetHeader (`X-Request-ID`) header value or when the header value is empty,
+// generates that value and sets request ID to response as RequestIDConfig.TargetHeader (`X-Request-Id`) value.
func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
- // Defaults
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration
+func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
- config.Skipper = DefaultRequestIDConfig.Skipper
+ config.Skipper = DefaultSkipper
}
if config.Generator == nil {
- config.Generator = generator
+ config.Generator = createRandomStringGenerator(32)
+ }
+ if config.TargetHeader == "" {
+ config.TargetHeader = echo.HeaderXRequestID
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) error {
+ return func(c *echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
res := c.Response()
- rid := req.Header.Get(echo.HeaderXRequestID)
+ rid := req.Header.Get(config.TargetHeader)
if rid == "" {
rid = config.Generator()
}
- res.Header().Set(echo.HeaderXRequestID, rid)
+ res.Header().Set(config.TargetHeader, rid)
+ if config.RequestIDHandler != nil {
+ config.RequestIDHandler(c, rid)
+ }
return next(c)
}
- }
-}
-
-func generator() string {
- return random.String(32)
+ }, nil
}
diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go
index 30eecdef9..465e6fc42 100644
--- a/middleware/request_id_test.go
+++ b/middleware/request_id_test.go
@@ -1,3 +1,6 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
@@ -5,7 +8,7 @@ import (
"net/http/httptest"
"testing"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
@@ -14,11 +17,97 @@ func TestRequestID(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
- handler := func(c echo.Context) error {
+ handler := func(c *echo.Context) error {
return c.String(http.StatusOK, "test")
}
- rid := RequestIDWithConfig(RequestIDConfig{})
+ rid := RequestID()
+ h := rid(handler)
+ err := h(c)
+ assert.NoError(t, err)
+ assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
+}
+
+func TestMustRequestIDWithConfig_skipper(t *testing.T) {
+ e := echo.New()
+ e.GET("/", func(c *echo.Context) error {
+ return c.String(http.StatusTeapot, "test")
+ })
+
+ generatorCalled := false
+ e.Use(RequestIDWithConfig(RequestIDConfig{
+ Skipper: func(c *echo.Context) bool {
+ return true
+ },
+ Generator: func() string {
+ generatorCalled = true
+ return "customGenerator"
+ },
+ }))
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ res := httptest.NewRecorder()
+ e.ServeHTTP(res, req)
+
+ assert.Equal(t, http.StatusTeapot, res.Code)
+ assert.Equal(t, "test", res.Body.String())
+
+ assert.Equal(t, res.Header().Get(echo.HeaderXRequestID), "")
+ assert.False(t, generatorCalled)
+}
+
+func TestMustRequestIDWithConfig_customGenerator(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ rid := RequestIDWithConfig(RequestIDConfig{
+ Generator: func() string { return "customGenerator" },
+ })
+ h := rid(handler)
+ err := h(c)
+ assert.NoError(t, err)
+ assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
+}
+
+func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ called := false
+ rid := RequestIDWithConfig(RequestIDConfig{
+ Generator: func() string { return "customGenerator" },
+ RequestIDHandler: func(c *echo.Context, s string) {
+ called = true
+ },
+ })
+ h := rid(handler)
+ err := h(c)
+ assert.NoError(t, err)
+ assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
+ assert.True(t, called)
+}
+
+func TestRequestIDWithConfig(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ rid, err := RequestIDConfig{}.ToMiddleware()
+ assert.NoError(t, err)
h := rid(handler)
h(c)
assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
@@ -31,3 +120,51 @@ func TestRequestID(t *testing.T) {
h(c)
assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
}
+
+func TestRequestID_IDNotAltered(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRequestID, "")
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ rid := RequestIDWithConfig(RequestIDConfig{})
+ h := rid(handler)
+ _ = h(c)
+ assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "")
+}
+
+func TestRequestIDConfigDifferentHeader(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ rid := RequestIDWithConfig(RequestIDConfig{TargetHeader: echo.HeaderXCorrelationID})
+ h := rid(handler)
+ h(c)
+ assert.Len(t, rec.Header().Get(echo.HeaderXCorrelationID), 32)
+
+ // Custom generator and handler
+ customID := "customGenerator"
+ calledHandler := false
+ rid = RequestIDWithConfig(RequestIDConfig{
+ Generator: func() string { return customID },
+ TargetHeader: echo.HeaderXCorrelationID,
+ RequestIDHandler: func(_ *echo.Context, id string) {
+ calledHandler = true
+ assert.Equal(t, customID, id)
+ },
+ })
+ h = rid(handler)
+ h(c)
+ assert.Equal(t, rec.Header().Get(echo.HeaderXCorrelationID), "customGenerator")
+ assert.True(t, calledHandler)
+}
diff --git a/middleware/request_logger.go b/middleware/request_logger.go
new file mode 100644
index 000000000..76903c62a
--- /dev/null
+++ b/middleware/request_logger.go
@@ -0,0 +1,462 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "context"
+ "errors"
+ "log/slog"
+ "net/http"
+ "time"
+
+ "github.com/labstack/echo/v5"
+)
+
+// Example for `slog` https://pkg.go.dev/log/slog
+// logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
+// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
+// LogStatus: true,
+// LogURI: true,
+// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
+// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
+// if v.Error == nil {
+// logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST",
+// slog.String("uri", v.URI),
+// slog.Int("status", v.Status),
+// )
+// } else {
+// logger.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR",
+// slog.String("uri", v.URI),
+// slog.Int("status", v.Status),
+// slog.String("err", v.Error.Error()),
+// )
+// }
+// return nil
+// },
+// }))
+//
+// Example for `fmt.Printf`
+// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
+// LogStatus: true,
+// LogURI: true,
+// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
+// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
+// if v.Error == nil {
+// fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status)
+// } else {
+// fmt.Printf("REQUEST_ERROR: uri: %v, status: %v, err: %v\n", v.URI, v.Status, v.Error)
+// }
+// return nil
+// },
+// }))
+//
+// Example for Zerolog (https://github.com/rs/zerolog)
+// logger := zerolog.New(os.Stdout)
+// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
+// LogURI: true,
+// LogStatus: true,
+// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
+// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
+// if v.Error == nil {
+// logger.Info().
+// Str("URI", v.URI).
+// Int("status", v.Status).
+// Msg("request")
+// } else {
+// logger.Error().
+// Err(v.Error).
+// Str("URI", v.URI).
+// Int("status", v.Status).
+// Msg("request error")
+// }
+// return nil
+// },
+// }))
+//
+// Example for Zap (https://github.com/uber-go/zap)
+// logger, _ := zap.NewProduction()
+// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
+// LogURI: true,
+// LogStatus: true,
+// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
+// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
+// if v.Error == nil {
+// logger.Info("request",
+// zap.String("URI", v.URI),
+// zap.Int("status", v.Status),
+// )
+// } else {
+// logger.Error("request error",
+// zap.String("URI", v.URI),
+// zap.Int("status", v.Status),
+// zap.Error(v.Error),
+// )
+// }
+// return nil
+// },
+// }))
+//
+// Example for Logrus (https://github.com/sirupsen/logrus)
+// log := logrus.New()
+// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
+// LogURI: true,
+// LogStatus: true,
+// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
+// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
+// if v.Error == nil {
+// log.WithFields(logrus.Fields{
+// "URI": v.URI,
+// "status": v.Status,
+// }).Info("request")
+// } else {
+// log.WithFields(logrus.Fields{
+// "URI": v.URI,
+// "status": v.Status,
+// "error": v.Error,
+// }).Error("request error")
+// }
+// return nil
+// },
+// }))
+
+// RequestLoggerConfig is configuration for Request Logger middleware.
+type RequestLoggerConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // BeforeNextFunc defines a function that is called before next middleware or handler is called in chain.
+ BeforeNextFunc func(c *echo.Context)
+ // LogValuesFunc defines a function that is called with values extracted by logger from request/response.
+ // Mandatory.
+ LogValuesFunc func(c *echo.Context, v RequestLoggerValues) error
+
+ // HandleError instructs logger to call global error handler when next middleware/handler returns an error.
+ // This is useful when you have custom error handler that can decide to use different status codes.
+ //
+ // A side-effect of calling global error handler is that now Response has been committed and sent to the client
+ // and middlewares up in chain can not change Response status code or response body.
+ HandleError bool
+
+ // LogLatency instructs logger to record duration it took to execute rest of the handler chain (next(c) call).
+ LogLatency bool
+ // LogProtocol instructs logger to extract request protocol (i.e. `HTTP/1.1` or `HTTP/2`)
+ LogProtocol bool
+ // LogRemoteIP instructs logger to extract request remote IP. See `echo.Context.RealIP()` for implementation details.
+ LogRemoteIP bool
+ // LogHost instructs logger to extract request host value (i.e. `example.com`)
+ LogHost bool
+ // LogMethod instructs logger to extract request method value (i.e. `GET` etc)
+ LogMethod bool
+ // LogURI instructs logger to extract request URI (i.e. `/list?lang=en&page=1`)
+ LogURI bool
+ // LogURIPath instructs logger to extract request URI path part (i.e. `/list`)
+ LogURIPath bool
+ // LogRoutePath instructs logger to extract route path part to which request was matched to (i.e. `/user/:id`)
+ LogRoutePath bool
+ // LogRequestID instructs logger to extract request ID from request `X-Request-ID` header or response if request did not have value.
+ LogRequestID bool
+ // LogReferer instructs logger to extract request referer values.
+ LogReferer bool
+ // LogUserAgent instructs logger to extract request user agent values.
+ LogUserAgent bool
+ // LogStatus instructs logger to extract response status code. If handler chain returns an echo.HTTPError,
+ // the status code is extracted from the echo.HTTPError returned
+ LogStatus bool
+ // LogContentLength instructs logger to extract content length header value. Note: this value could be different from
+ // actual request body size as it could be spoofed etc.
+ LogContentLength bool
+ // LogResponseSize instructs logger to extract response content length value. Note: when used with Gzip middleware
+ // this value may not be always correct.
+ LogResponseSize bool
+ // LogHeaders instructs logger to extract given list of headers from request. Note: request can contain more than
+ // one header with same value so slice of values is been logger for each given header.
+ //
+ // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header
+ // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding".
+ LogHeaders []string
+ // LogQueryParams instructs logger to extract given list of query parameters from request URI. Note: request can
+ // contain more than one query parameter with same name so slice of values is been logger for each given query param name.
+ LogQueryParams []string
+ // LogFormValues instructs logger to extract given list of form values from request body+URI. Note: request can
+ // contain more than one form value with same name so slice of values is been logger for each given form value name.
+ LogFormValues []string
+
+ timeNow func() time.Time
+}
+
+// RequestLoggerValues contains extracted values from logger.
+type RequestLoggerValues struct {
+ // StartTime is time recorded before next middleware/handler is executed.
+ StartTime time.Time
+ // Latency is duration it took to execute rest of the handler chain (next(c) call).
+ Latency time.Duration
+ // Protocol is request protocol (i.e. `HTTP/1.1` or `HTTP/2`)
+ Protocol string
+ // RemoteIP is request remote IP. See `echo.Context.RealIP()` for implementation details.
+ RemoteIP string
+ // Host is request host value (i.e. `example.com`)
+ Host string
+ // Method is request method value (i.e. `GET` etc)
+ Method string
+ // URI is request URI (i.e. `/list?lang=en&page=1`)
+ URI string
+ // URIPath is request URI path part (i.e. `/list`)
+ URIPath string
+ // RoutePath is route path part to which request was matched to (i.e. `/user/:id`)
+ RoutePath string
+ // RequestID is request ID from request `X-Request-ID` header or response if request did not have value.
+ RequestID string
+ // Referer is request referer values.
+ Referer string
+ // UserAgent is request user agent values.
+ UserAgent string
+ // Status is response status code. Then handler returns an echo.HTTPError then code from there.
+ Status int
+ // Error is error returned from executed handler chain.
+ Error error
+ // ContentLength is content length header value. Note: this value could be different from actual request body size
+ // as it could be spoofed etc.
+ ContentLength string
+ // ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct.
+ ResponseSize int64
+ // Headers are list of headers from request. Note: request can contain more than one header with same value so slice
+ // of values is what will be returned/logged for each given header.
+ // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header
+ // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding".
+ Headers map[string][]string
+ // QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter
+ // with same name so slice of values is what will be returned/logged for each given query param name.
+ QueryParams map[string][]string
+ // FormValues are list of form values from request body+URI. Note: request can contain more than one form value with
+ // same name so slice of values is what will be returned/logged for each given form value name.
+ FormValues map[string][]string
+}
+
+// RequestLoggerWithConfig returns a RequestLogger middleware with config.
+func RequestLoggerWithConfig(config RequestLoggerConfig) echo.MiddlewareFunc {
+ mw, err := config.ToMiddleware()
+ if err != nil {
+ panic(err)
+ }
+ return mw
+}
+
+// ToMiddleware converts RequestLoggerConfig into middleware or returns an error for invalid configuration.
+func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Skipper == nil {
+ config.Skipper = DefaultSkipper
+ }
+ now := time.Now
+ if config.timeNow != nil {
+ now = config.timeNow
+ }
+
+ if config.LogValuesFunc == nil {
+ return nil, errors.New("missing LogValuesFunc callback function for request logger middleware")
+ }
+
+ logHeaders := len(config.LogHeaders) > 0
+ headers := append([]string(nil), config.LogHeaders...)
+ for i, v := range headers {
+ headers[i] = http.CanonicalHeaderKey(v)
+ }
+
+ logQueryParams := len(config.LogQueryParams) > 0
+ logFormValues := len(config.LogFormValues) > 0
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ req := c.Request()
+ res := c.Response()
+ start := now()
+
+ if config.BeforeNextFunc != nil {
+ config.BeforeNextFunc(c)
+ }
+ err := next(c)
+ if err != nil && config.HandleError {
+ // When global error handler writes the error to the client the Response gets "committed". This state can be
+ // checked with `c.Response().Committed` field.
+ c.Echo().HTTPErrorHandler(c, err)
+ }
+
+ v := RequestLoggerValues{
+ StartTime: start,
+ }
+ if config.LogLatency {
+ v.Latency = now().Sub(start)
+ }
+ if config.LogProtocol {
+ v.Protocol = req.Proto
+ }
+ if config.LogRemoteIP {
+ v.RemoteIP = c.RealIP()
+ }
+ if config.LogHost {
+ v.Host = req.Host
+ }
+ if config.LogMethod {
+ v.Method = req.Method
+ }
+ if config.LogURI {
+ v.URI = req.RequestURI
+ }
+ if config.LogURIPath {
+ p := req.URL.Path
+ if p == "" {
+ p = "/"
+ }
+ v.URIPath = p
+ }
+ if config.LogRoutePath {
+ v.RoutePath = c.Path()
+ }
+ if config.LogRequestID {
+ id := req.Header.Get(echo.HeaderXRequestID)
+ if id == "" {
+ id = res.Header().Get(echo.HeaderXRequestID)
+ }
+ v.RequestID = id
+ }
+ if config.LogReferer {
+ v.Referer = req.Referer()
+ }
+ if config.LogUserAgent {
+ v.UserAgent = req.UserAgent()
+ }
+
+ var resp *echo.Response
+ if config.LogStatus || config.LogResponseSize {
+ if r, err := echo.UnwrapResponse(res); err != nil {
+ c.Logger().Error("can not determine response status and/or size. ResponseWriter in context does not implement unwrapper interface")
+ } else {
+ resp = r
+ }
+ }
+
+ if config.LogStatus {
+ v.Status = -1
+ if resp != nil {
+ v.Status = resp.Status
+ }
+ if err != nil && !config.HandleError {
+ // this block should not be executed in case of HandleError=true as the global error handler will decide
+ // the status code. In that case status code could be different from what err contains.
+ var hsc echo.HTTPStatusCoder
+ if errors.As(err, &hsc) {
+ v.Status = hsc.StatusCode()
+ }
+ }
+ }
+ if err != nil {
+ v.Error = err
+ }
+ if config.LogContentLength {
+ v.ContentLength = req.Header.Get(echo.HeaderContentLength)
+ }
+ if config.LogResponseSize {
+ v.ResponseSize = -1
+ if resp != nil {
+ v.ResponseSize = resp.Size
+ }
+ }
+ if logHeaders {
+ v.Headers = map[string][]string{}
+ for _, header := range headers {
+ if values, ok := req.Header[header]; ok {
+ v.Headers[header] = values
+ }
+ }
+ }
+ if logQueryParams {
+ queryParams := c.QueryParams()
+ v.QueryParams = map[string][]string{}
+ for _, param := range config.LogQueryParams {
+ if values, ok := queryParams[param]; ok {
+ v.QueryParams[param] = values
+ }
+ }
+ }
+ if logFormValues {
+ v.FormValues = map[string][]string{}
+ for _, formValue := range config.LogFormValues {
+ if values, ok := req.Form[formValue]; ok {
+ v.FormValues[formValue] = values
+ }
+ }
+ }
+
+ if errOnLog := config.LogValuesFunc(c, v); errOnLog != nil {
+ return errOnLog
+ }
+ // in case of HandleError=true we are returning the error that we already have handled with global error handler
+ // this is deliberate as this error could be useful for upstream middlewares and default global error handler
+ // will ignore that error when it bubbles up in middleware chain.
+ // Committed response can be checked in custom error handler with following logic
+ //
+ // if r, _ := echo.UnwrapResponse(c.Response()); r != nil && r.Committed {
+ // return
+ // }
+ return err
+ }
+ }, nil
+}
+
+// RequestLogger creates Request Logger middleware with Echo default settings that uses Context.Logger() as logger.
+func RequestLogger() echo.MiddlewareFunc {
+ return RequestLoggerWithConfig(RequestLoggerConfig{
+ LogLatency: true,
+ LogRemoteIP: true,
+ LogHost: true,
+ LogMethod: true,
+ LogURI: true,
+ LogRequestID: true,
+ LogUserAgent: true,
+ LogStatus: true,
+ LogContentLength: true,
+ LogResponseSize: true,
+ // forwards error to the global error handler, so it can decide appropriate status code.
+ // NB: side-effect of that is - request is now "commited" written to the client. Middlewares up in chain can not
+ // change Response status code or response body.
+ HandleError: true,
+ LogValuesFunc: func(c *echo.Context, v RequestLoggerValues) error {
+ logger := c.Logger()
+ if v.Error == nil {
+ logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST",
+ slog.String("method", v.Method),
+ slog.String("uri", v.URI),
+ slog.Int("status", v.Status),
+ slog.Duration("latency", v.Latency),
+ slog.String("host", v.Host),
+ slog.String("bytes_in", v.ContentLength),
+ slog.Int64("bytes_out", v.ResponseSize),
+ slog.String("user_agent", v.UserAgent),
+ slog.String("remote_ip", v.RemoteIP),
+ slog.String("request_id", v.RequestID),
+ )
+ return nil
+ }
+
+ logger.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR",
+ slog.String("method", v.Method),
+ slog.String("uri", v.URI),
+ slog.Int("status", v.Status),
+ slog.Duration("latency", v.Latency),
+ slog.String("host", v.Host),
+ slog.String("bytes_in", v.ContentLength),
+ slog.Int64("bytes_out", v.ResponseSize),
+ slog.String("user_agent", v.UserAgent),
+ slog.String("remote_ip", v.RemoteIP),
+ slog.String("request_id", v.RequestID),
+
+ slog.String("error", v.Error.Error()),
+ )
+ return nil
+ },
+ })
+}
diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go
new file mode 100644
index 000000000..af39eb32a
--- /dev/null
+++ b/middleware/request_logger_test.go
@@ -0,0 +1,630 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "log/slog"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestRequestLoggerOK(t *testing.T) {
+ old := slog.Default()
+ t.Cleanup(func() {
+ slog.SetDefault(old)
+ })
+
+ e := echo.New()
+ buf := new(bytes.Buffer)
+ e.Logger = slog.New(slog.NewJSONHandler(buf, nil))
+ e.Use(RequestLogger())
+
+ e.POST("/test", func(c *echo.Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ reader := strings.NewReader(`{"foo":"bar"}`)
+ req := httptest.NewRequest(http.MethodPost, "/test", reader)
+ req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size())))
+ req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
+ req.Header.Set(echo.HeaderXRealIP, "8.8.8.8")
+ req.Header.Set("User-Agent", "curl/7.68.0")
+
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ logAttrs := map[string]interface{}{}
+ assert.NoError(t, json.Unmarshal(buf.Bytes(), &logAttrs))
+ logAttrs["latency"] = 123
+ logAttrs["time"] = "x"
+
+ expect := map[string]interface{}{
+ "level": "INFO",
+ "msg": "REQUEST",
+ "method": "POST",
+ "uri": "/test",
+ "status": float64(418),
+ "bytes_in": "13",
+ "host": "example.com",
+ "bytes_out": float64(2),
+ "user_agent": "curl/7.68.0",
+ "remote_ip": "8.8.8.8",
+ "request_id": "",
+
+ "time": "x",
+ "latency": 123,
+ }
+ assert.Equal(t, expect, logAttrs)
+}
+
+func TestRequestLoggerError(t *testing.T) {
+ old := slog.Default()
+ t.Cleanup(func() {
+ slog.SetDefault(old)
+ })
+
+ e := echo.New()
+ buf := new(bytes.Buffer)
+ e.Logger = slog.New(slog.NewJSONHandler(buf, nil))
+ e.Use(RequestLogger())
+
+ e.GET("/test", func(c *echo.Context) error {
+ return errors.New("nope")
+ })
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ logAttrs := map[string]interface{}{}
+ assert.NoError(t, json.Unmarshal(buf.Bytes(), &logAttrs))
+ logAttrs["latency"] = 123
+ logAttrs["time"] = "x"
+
+ expect := map[string]interface{}{
+ "level": "ERROR",
+ "msg": "REQUEST_ERROR",
+ "method": "GET",
+ "uri": "/test",
+ "status": float64(500),
+ "bytes_in": "",
+ "host": "example.com",
+ "bytes_out": float64(36.0),
+ "user_agent": "",
+ "remote_ip": "192.0.2.1",
+ "request_id": "",
+ "error": "nope",
+
+ "latency": 123,
+ "time": "x",
+ }
+ assert.Equal(t, expect, logAttrs)
+}
+
+func TestRequestLoggerWithConfig(t *testing.T) {
+ e := echo.New()
+
+ var expect RequestLoggerValues
+ e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
+ LogRoutePath: true,
+ LogURI: true,
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ expect = values
+ return nil
+ },
+ }))
+
+ e.GET("/test", func(c *echo.Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusTeapot, rec.Code)
+ assert.Equal(t, "/test", expect.RoutePath)
+}
+
+func TestRequestLoggerWithConfig_missingOnLogValuesPanics(t *testing.T) {
+ assert.Panics(t, func() {
+ RequestLoggerWithConfig(RequestLoggerConfig{
+ LogValuesFunc: nil,
+ })
+ })
+}
+
+func TestRequestLogger_skipper(t *testing.T) {
+ e := echo.New()
+
+ loggerCalled := false
+ e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
+ Skipper: func(c *echo.Context) bool {
+ return true
+ },
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ loggerCalled = true
+ return nil
+ },
+ }))
+
+ e.GET("/test", func(c *echo.Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusTeapot, rec.Code)
+ assert.False(t, loggerCalled)
+}
+
+func TestRequestLogger_beforeNextFunc(t *testing.T) {
+ e := echo.New()
+
+ var myLoggerInstance int
+ e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
+ BeforeNextFunc: func(c *echo.Context) {
+ c.Set("myLoggerInstance", 42)
+ },
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ myLoggerInstance = c.Get("myLoggerInstance").(int)
+ return nil
+ },
+ }))
+
+ e.GET("/test", func(c *echo.Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusTeapot, rec.Code)
+ assert.Equal(t, 42, myLoggerInstance)
+}
+
+func TestRequestLogger_logError(t *testing.T) {
+ e := echo.New()
+
+ var actual RequestLoggerValues
+ e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
+ LogStatus: true,
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ actual = values
+ return nil
+ },
+ }))
+
+ e.GET("/test", func(c *echo.Context) error {
+ return echo.NewHTTPError(http.StatusNotAcceptable, "nope")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusNotAcceptable, rec.Code)
+ assert.Equal(t, http.StatusNotAcceptable, actual.Status)
+ assert.EqualError(t, actual.Error, "code=406, message=nope")
+}
+
+func TestRequestLogger_HandleError(t *testing.T) {
+ e := echo.New()
+
+ var actual RequestLoggerValues
+ e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
+ timeNow: func() time.Time {
+ return time.Unix(1631045377, 0).UTC()
+ },
+ HandleError: true,
+ LogStatus: true,
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ actual = values
+ return nil
+ },
+ }))
+
+ // to see if "HandleError" works we create custom error handler that uses its own status codes
+ e.HTTPErrorHandler = func(c *echo.Context, err error) {
+ if r, _ := echo.UnwrapResponse(c.Response()); r != nil && r.Committed {
+ return
+ }
+ c.JSON(http.StatusTeapot, "custom error handler")
+ }
+
+ e.GET("/test", func(c *echo.Context) error {
+ return echo.NewHTTPError(http.StatusForbidden, "nope")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusTeapot, rec.Code)
+
+ expect := RequestLoggerValues{
+ StartTime: time.Unix(1631045377, 0).UTC(),
+ Status: http.StatusTeapot,
+ Error: echo.NewHTTPError(http.StatusForbidden, "nope"),
+ }
+ assert.Equal(t, expect, actual)
+}
+
+func TestRequestLogger_LogValuesFuncError(t *testing.T) {
+ e := echo.New()
+
+ var expect RequestLoggerValues
+ e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
+ LogStatus: true,
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ expect = values
+ return echo.NewHTTPError(http.StatusNotAcceptable, "LogValuesFuncError")
+ },
+ }))
+
+ e.GET("/test", func(c *echo.Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ // NOTE: when global error handler received error returned from middleware the status has already
+ // been written to the client and response has been "committed" therefore global error handler does not do anything
+ // and error that bubbled up in middleware chain will not be reflected in response code.
+ assert.Equal(t, http.StatusTeapot, rec.Code)
+ assert.Equal(t, http.StatusTeapot, expect.Status)
+}
+
+func TestRequestLogger_ID(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenFromRequest bool
+ expect string
+ }{
+ {
+ name: "ok, ID is provided from request headers",
+ whenFromRequest: true,
+ expect: "123",
+ },
+ {
+ name: "ok, ID is from response headers",
+ whenFromRequest: false,
+ expect: "321",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ var expect RequestLoggerValues
+ e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
+ LogRequestID: true,
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ expect = values
+ return nil
+ },
+ }))
+
+ e.GET("/test", func(c *echo.Context) error {
+ c.Response().Header().Set(echo.HeaderXRequestID, "321")
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ if tc.whenFromRequest {
+ req.Header.Set(echo.HeaderXRequestID, "123")
+ }
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusTeapot, rec.Code)
+ assert.Equal(t, tc.expect, expect.RequestID)
+ })
+ }
+}
+
+func TestRequestLogger_headerIsCaseInsensitive(t *testing.T) {
+ e := echo.New()
+
+ var expect RequestLoggerValues
+ mw := RequestLoggerWithConfig(RequestLoggerConfig{
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ expect = values
+ return nil
+ },
+ LogHeaders: []string{"referer", "User-Agent"},
+ })(func(c *echo.Context) error {
+ c.Request().Header.Set(echo.HeaderXRequestID, "123")
+ c.FormValue("to force parse form")
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test?lang=en&checked=1&checked=2", nil)
+ req.Header.Set("referer", "https://echo.labstack.com/")
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := mw(c)
+
+ assert.NoError(t, err)
+ assert.Len(t, expect.Headers, 1)
+ assert.Equal(t, []string{"https://echo.labstack.com/"}, expect.Headers["Referer"])
+}
+
+func TestRequestLogger_allFields(t *testing.T) {
+ e := echo.New()
+
+ isFirstNowCall := true
+ var expect RequestLoggerValues
+ mw := RequestLoggerWithConfig(RequestLoggerConfig{
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ expect = values
+ return nil
+ },
+ LogLatency: true,
+ LogProtocol: true,
+ LogRemoteIP: true,
+ LogHost: true,
+ LogMethod: true,
+ LogURI: true,
+ LogURIPath: true,
+ LogRoutePath: true,
+ LogRequestID: true,
+ LogReferer: true,
+ LogUserAgent: true,
+ LogStatus: true,
+ LogContentLength: true,
+ LogResponseSize: true,
+ LogHeaders: []string{"accept-encoding", "User-Agent"},
+ LogQueryParams: []string{"lang", "checked"},
+ LogFormValues: []string{"csrf", "multiple"},
+ timeNow: func() time.Time {
+ if isFirstNowCall {
+ isFirstNowCall = false
+ return time.Unix(1631045377, 0)
+ }
+ return time.Unix(1631045377+10, 0)
+ },
+ })(func(c *echo.Context) error {
+ c.Request().Header.Set(echo.HeaderXRequestID, "123")
+ c.FormValue("to force parse form")
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ f := make(url.Values)
+ f.Set("csrf", "token")
+ f.Set("multiple", "1")
+ f.Add("multiple", "2")
+ reader := strings.NewReader(f.Encode())
+ req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", reader)
+ req.Header.Set("Referer", "https://echo.labstack.com/")
+ req.Header.Set("User-Agent", "curl/7.68.0")
+ req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size())))
+ req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
+ req.Header.Set(echo.HeaderXRealIP, "8.8.8.8")
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ c.SetPath("/test*")
+
+ err := mw(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, time.Unix(1631045377, 0), expect.StartTime)
+ assert.Equal(t, 10*time.Second, expect.Latency)
+ assert.Equal(t, "HTTP/1.1", expect.Protocol)
+ assert.Equal(t, "8.8.8.8", expect.RemoteIP)
+ assert.Equal(t, "example.com", expect.Host)
+ assert.Equal(t, http.MethodPost, expect.Method)
+ assert.Equal(t, "/test?lang=en&checked=1&checked=2", expect.URI)
+ assert.Equal(t, "/test", expect.URIPath)
+ assert.Equal(t, "/test*", expect.RoutePath)
+ assert.Equal(t, "123", expect.RequestID)
+ assert.Equal(t, "https://echo.labstack.com/", expect.Referer)
+ assert.Equal(t, "curl/7.68.0", expect.UserAgent)
+ assert.Equal(t, 418, expect.Status)
+ assert.Equal(t, nil, expect.Error)
+ assert.Equal(t, "32", expect.ContentLength)
+ assert.Equal(t, int64(2), expect.ResponseSize)
+
+ assert.Len(t, expect.Headers, 1)
+ assert.Equal(t, []string{"curl/7.68.0"}, expect.Headers["User-Agent"])
+
+ assert.Len(t, expect.QueryParams, 2)
+ assert.Equal(t, []string{"en"}, expect.QueryParams["lang"])
+ assert.Equal(t, []string{"1", "2"}, expect.QueryParams["checked"])
+
+ assert.Len(t, expect.FormValues, 2)
+ assert.Equal(t, []string{"token"}, expect.FormValues["csrf"])
+ assert.Equal(t, []string{"1", "2"}, expect.FormValues["multiple"])
+}
+
+func TestTestRequestLogger(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenStatus int
+ whenError error
+ expectStatus string
+ expectError string
+ }{
+ {
+ name: "ok",
+ whenStatus: http.StatusTeapot,
+ expectStatus: "418",
+ },
+ {
+ name: "error",
+ whenError: echo.NewHTTPError(http.StatusBadGateway, "bad gw"),
+ expectStatus: "502",
+ expectError: `"error":"code=502, message=bad gw"`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+ buf := new(bytes.Buffer)
+ e.Logger = slog.New(slog.NewJSONHandler(buf, nil))
+
+ e.Use(RequestLogger())
+ e.POST("/test", func(c *echo.Context) error {
+ if tc.whenError != nil {
+ return tc.whenError
+ }
+ return c.String(tc.whenStatus, "OK")
+ })
+
+ f := make(url.Values)
+ f.Set("csrf", "token")
+ f.Set("multiple", "1")
+ f.Add("multiple", "2")
+ reader := strings.NewReader(f.Encode())
+ req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", reader)
+ req.Header.Set("Referer", "https://echo.labstack.com/")
+ req.Header.Set("User-Agent", "curl/7.68.0")
+ req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size())))
+ req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
+ req.Header.Set(echo.HeaderXRealIP, "8.8.8.8")
+ req.Header.Set(echo.HeaderXRequestID, "MY_ID")
+
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ rawlog := buf.Bytes()
+ if tc.expectError != "" {
+ assert.Contains(t, string(rawlog), `"level":"ERROR"`)
+ assert.Contains(t, string(rawlog), `"msg":"REQUEST_ERROR"`)
+ assert.Contains(t, string(rawlog), tc.expectError)
+ } else {
+ assert.Contains(t, string(rawlog), `"level":"INFO"`)
+ assert.Contains(t, string(rawlog), `"msg":"REQUEST"`)
+ }
+ assert.Contains(t, string(rawlog), `"status":`+tc.expectStatus)
+ assert.Contains(t, string(rawlog), `"method":"POST"`)
+ assert.Contains(t, string(rawlog), `"uri":"/test?lang=en&checked=1&checked=2"`)
+ assert.Contains(t, string(rawlog), `"latency":`) // this value varies
+ assert.Contains(t, string(rawlog), `"request_id":"MY_ID"`)
+ assert.Contains(t, string(rawlog), `"remote_ip":"8.8.8.8"`)
+ assert.Contains(t, string(rawlog), `"host":"example.com"`)
+ assert.Contains(t, string(rawlog), `"user_agent":"curl/7.68.0"`)
+ assert.Contains(t, string(rawlog), `"bytes_in":"32"`)
+ assert.Contains(t, string(rawlog), `"bytes_out":2`)
+ })
+ }
+}
+
+func BenchmarkRequestLogger_withoutMapFields(b *testing.B) {
+ e := echo.New()
+
+ mw := RequestLoggerWithConfig(RequestLoggerConfig{
+ Skipper: nil,
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ return nil
+ },
+ LogLatency: true,
+ LogProtocol: true,
+ LogRemoteIP: true,
+ LogHost: true,
+ LogMethod: true,
+ LogURI: true,
+ LogURIPath: true,
+ LogRoutePath: true,
+ LogRequestID: true,
+ LogReferer: true,
+ LogUserAgent: true,
+ LogStatus: true,
+ LogContentLength: true,
+ LogResponseSize: true,
+ })(func(c *echo.Context) error {
+ c.Request().Header.Set(echo.HeaderXRequestID, "123")
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test?lang=en", nil)
+ req.Header.Set("Referer", "https://echo.labstack.com/")
+ req.Header.Set("User-Agent", "curl/7.68.0")
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ mw(c)
+ }
+}
+
+func BenchmarkRequestLogger_withMapFields(b *testing.B) {
+ e := echo.New()
+
+ mw := RequestLoggerWithConfig(RequestLoggerConfig{
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ return nil
+ },
+ LogLatency: true,
+ LogProtocol: true,
+ LogRemoteIP: true,
+ LogHost: true,
+ LogMethod: true,
+ LogURI: true,
+ LogURIPath: true,
+ LogRoutePath: true,
+ LogRequestID: true,
+ LogReferer: true,
+ LogUserAgent: true,
+ LogStatus: true,
+ LogContentLength: true,
+ LogResponseSize: true,
+ LogHeaders: []string{"accept-encoding", "User-Agent"},
+ LogQueryParams: []string{"lang", "checked"},
+ LogFormValues: []string{"csrf", "multiple"},
+ })(func(c *echo.Context) error {
+ c.Request().Header.Set(echo.HeaderXRequestID, "123")
+ c.FormValue("to force parse form")
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ f := make(url.Values)
+ f.Set("csrf", "token")
+ f.Add("multiple", "1")
+ f.Add("multiple", "2")
+ req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", strings.NewReader(f.Encode()))
+ req.Header.Set("Referer", "https://echo.labstack.com/")
+ req.Header.Set("User-Agent", "curl/7.68.0")
+ req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ mw(c)
+ }
+}
diff --git a/middleware/rewrite.go b/middleware/rewrite.go
index a64e10bb3..ea58091b0 100644
--- a/middleware/rewrite.go
+++ b/middleware/rewrite.go
@@ -1,84 +1,80 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
+ "errors"
"regexp"
- "strings"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
)
-type (
- // RewriteConfig defines the config for Rewrite middleware.
- RewriteConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
-
- // Rules defines the URL path rewrite rules. The values captured in asterisk can be
- // retrieved by index e.g. $1, $2 and so on.
- // Example:
- // "/old": "/new",
- // "/api/*": "/$1",
- // "/js/*": "/public/javascripts/$1",
- // "/users/*/orders/*": "/user/$1/order/$2",
- // Required.
- Rules map[string]string `yaml:"rules"`
+// RewriteConfig defines the config for Rewrite middleware.
+type RewriteConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
- rulesRegex map[*regexp.Regexp]string
- }
-)
+ // Rules defines the URL path rewrite rules. The values captured in asterisk can be
+ // retrieved by index e.g. $1, $2 and so on.
+ // Example:
+ // "/old": "/new",
+ // "/api/*": "/$1",
+ // "/js/*": "/public/javascripts/$1",
+ // "/users/*/orders/*": "/user/$1/order/$2",
+ // Required.
+ Rules map[string]string
-var (
- // DefaultRewriteConfig is the default Rewrite middleware config.
- DefaultRewriteConfig = RewriteConfig{
- Skipper: DefaultSkipper,
- }
-)
+ // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures
+ // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
+ // Example:
+ // "^/old/[0.9]+/": "/new",
+ // "^/api/.+?/(.*)": "/v2/$1",
+ RegexRules map[*regexp.Regexp]string
+}
// Rewrite returns a Rewrite middleware.
//
// Rewrite middleware rewrites the URL path based on the provided rules.
func Rewrite(rules map[string]string) echo.MiddlewareFunc {
- c := DefaultRewriteConfig
+ c := RewriteConfig{}
c.Rules = rules
return RewriteWithConfig(c)
}
-// RewriteWithConfig returns a Rewrite middleware with config.
-// See: `Rewrite()`.
+// RewriteWithConfig returns a Rewrite middleware or panics on invalid configuration.
+//
+// Rewrite middleware rewrites the URL path based on the provided rules.
func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
- // Defaults
- if config.Rules == nil {
- panic("echo: rewrite middleware requires url path rewrite rules")
- }
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration
+func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Skipper == nil {
- config.Skipper = DefaultBodyDumpConfig.Skipper
+ config.Skipper = DefaultSkipper
+ }
+ if config.Rules == nil && config.RegexRules == nil {
+ return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules")
}
- config.rulesRegex = map[*regexp.Regexp]string{}
- // Initialize
- for k, v := range config.Rules {
- k = strings.Replace(k, "*", "(.*)", -1)
- k = k + "$"
- config.rulesRegex[regexp.MustCompile(k)] = v
+ if config.RegexRules == nil {
+ config.RegexRules = make(map[*regexp.Regexp]string)
+ }
+ for k, v := range rewriteRulesRegex(config.Rules) {
+ config.RegexRules[k] = v
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
- return func(c echo.Context) (err error) {
+ return func(c *echo.Context) (err error) {
if config.Skipper(c) {
return next(c)
}
- req := c.Request()
-
- // Rewrite
- for k, v := range config.rulesRegex {
- replacer := captureTokens(k, req.URL.Path)
- if replacer != nil {
- req.URL.Path = replacer.Replace(v)
- break
- }
+ if err := rewriteURL(config.RegexRules, c.Request()); err != nil {
+ return err
}
return next(c)
}
- }
+ }, nil
}
diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go
index eb5a46d89..f45b8d98a 100644
--- a/middleware/rewrite_test.go
+++ b/middleware/rewrite_test.go
@@ -1,17 +1,23 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
- "io/ioutil"
+ "io"
"net/http"
"net/http/httptest"
+ "net/url"
+ "regexp"
"testing"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
-func TestRewrite(t *testing.T) {
+func TestRewriteAfterRouting(t *testing.T) {
e := echo.New()
+ // middlewares added with `Use()` are executed after routing is done and do not affect which route handler is matched
e.Use(RewriteWithConfig(RewriteConfig{
Rules: map[string]string{
"/old": "/new",
@@ -20,54 +26,156 @@ func TestRewrite(t *testing.T) {
"/users/*/orders/*": "/user/$1/order/$2",
},
}))
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- rec := httptest.NewRecorder()
- req.URL.Path = "/api/users"
- e.ServeHTTP(rec, req)
- assert.Equal(t, "/users", req.URL.Path)
- req.URL.Path = "/js/main.js"
- e.ServeHTTP(rec, req)
- assert.Equal(t, "/public/javascripts/main.js", req.URL.Path)
- req.URL.Path = "/old"
- e.ServeHTTP(rec, req)
- assert.Equal(t, "/new", req.URL.Path)
- req.URL.Path = "/users/jack/orders/1"
- e.ServeHTTP(rec, req)
- assert.Equal(t, "/user/jack/order/1", req.URL.Path)
- req.URL.Path = "/api/new users"
- e.ServeHTTP(rec, req)
- assert.Equal(t, "/new users", req.URL.Path)
+ e.GET("/public/*", func(c *echo.Context) error {
+ return c.String(http.StatusOK, c.Param("*"))
+ })
+ e.GET("/*", func(c *echo.Context) error {
+ return c.String(http.StatusOK, c.Param("*"))
+ })
+
+ var testCases = []struct {
+ whenPath string
+ expectRoutePath string
+ expectRequestPath string
+ expectRequestRawPath string
+ }{
+ {
+ whenPath: "/api/users",
+ expectRoutePath: "api/users",
+ expectRequestPath: "/users",
+ expectRequestRawPath: "",
+ },
+ {
+ whenPath: "/js/main.js",
+ expectRoutePath: "js/main.js",
+ expectRequestPath: "/public/javascripts/main.js",
+ expectRequestRawPath: "",
+ },
+ {
+ whenPath: "/users/jack/orders/1",
+ expectRoutePath: "users/jack/orders/1",
+ expectRequestPath: "/user/jack/order/1",
+ expectRequestRawPath: "",
+ },
+ { // no rewrite rule matched. already encoded URL should not be double encoded or changed in any way
+ whenPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ expectRoutePath: "user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ expectRequestPath: "/user/jill/order/T/cO4lW/t/Vp/", // this is equal to `url.Parse(tc.whenPath)` result
+ expectRequestRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ },
+ { // just rewrite but do not touch encoding. already encoded URL should not be double encoded
+ whenPath: "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F",
+ expectRoutePath: "users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F",
+ expectRequestPath: "/user/jill/order/T/cO4lW/t/Vp/", // this is equal to `url.Parse(tc.whenPath)` result
+ expectRequestRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ },
+ { // ` ` (space) is encoded by httpClient to `%20` when doing request to Echo. `%20` should not be double escaped or changed in any way when rewriting request
+ whenPath: "/api/new users",
+ expectRoutePath: "api/new users",
+ expectRequestPath: "/new users",
+ expectRequestRawPath: "",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.whenPath, func(t *testing.T) {
+ target, _ := url.Parse(tc.whenPath)
+ req := httptest.NewRequest(http.MethodGet, target.String(), nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, tc.expectRoutePath, rec.Body.String())
+ assert.Equal(t, tc.expectRequestPath, req.URL.Path)
+ assert.Equal(t, tc.expectRequestRawPath, req.URL.RawPath)
+ })
+ }
+}
+
+func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) {
+ assert.Panics(t, func() {
+ RewriteWithConfig(RewriteConfig{})
+ })
+}
+
+func TestMustRewriteWithConfig_skipper(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenSkipper func(c *echo.Context) bool
+ whenURL string
+ expectURL string
+ expectStatus int
+ }{
+ {
+ name: "not skipped",
+ whenURL: "/old",
+ expectURL: "/new",
+ expectStatus: http.StatusOK,
+ },
+ {
+ name: "skipped",
+ givenSkipper: func(c *echo.Context) bool {
+ return true
+ },
+ whenURL: "/old",
+ expectURL: "/old",
+ expectStatus: http.StatusNotFound,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ e.Pre(RewriteWithConfig(
+ RewriteConfig{
+ Skipper: tc.givenSkipper,
+ Rules: map[string]string{"/old": "/new"}},
+ ))
+
+ e.GET("/new", func(c *echo.Context) error {
+ return c.NoContent(http.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectURL, req.URL.EscapedPath())
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ })
+ }
}
// Issue #1086
func TestEchoRewritePreMiddleware(t *testing.T) {
e := echo.New()
- r := e.Router()
// Rewrite old url to new one
+ // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
e.Pre(RewriteWithConfig(RewriteConfig{
- Rules: map[string]string{
- "/old": "/new",
- },
- }))
+ Rules: map[string]string{"/old": "/new"}}),
+ )
// Route
- r.Add(http.MethodGet, "/new", func(c echo.Context) error {
- return c.NoContent(200)
+ e.Add(http.MethodGet, "/new", func(c *echo.Context) error {
+ return c.NoContent(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, "/old", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
- assert.Equal(t, "/new", req.URL.Path)
- assert.Equal(t, 200, rec.Code)
+ assert.Equal(t, "/new", req.URL.EscapedPath())
+ assert.Equal(t, http.StatusOK, rec.Code)
}
// Issue #1143
func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
e := echo.New()
- r := e.Router()
+ // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
e.Pre(RewriteWithConfig(RewriteConfig{
Rules: map[string]string{
"/api/*/mgmt/proj/*/agt": "/api/$1/hosts/$2",
@@ -75,22 +183,135 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
},
}))
- r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error {
- return c.String(200, "hosts")
+ e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c *echo.Context) error {
+ return c.String(http.StatusOK, "hosts")
})
- r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error {
- return c.String(200, "eng")
+ e.Add(http.MethodGet, "/api/:version/eng", func(c *echo.Context) error {
+ return c.String(http.StatusOK, "eng")
})
for i := 0; i < 100; i++ {
req := httptest.NewRequest(http.MethodGet, "/api/v1/mgmt/proj/test/agt", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
- assert.Equal(t, "/api/v1/hosts/test", req.URL.Path)
- assert.Equal(t, 200, rec.Code)
+ assert.Equal(t, "/api/v1/hosts/test", req.URL.EscapedPath())
+ assert.Equal(t, http.StatusOK, rec.Code)
defer rec.Result().Body.Close()
- bodyBytes, _ := ioutil.ReadAll(rec.Result().Body)
+ bodyBytes, _ := io.ReadAll(rec.Result().Body)
assert.Equal(t, "hosts", string(bodyBytes))
}
}
+
+// Issue #1573
+func TestEchoRewriteWithCaret(t *testing.T) {
+ e := echo.New()
+
+ e.Pre(RewriteWithConfig(RewriteConfig{
+ Rules: map[string]string{
+ "^/abc/*": "/v1/abc/$1",
+ },
+ }))
+
+ rec := httptest.NewRecorder()
+
+ var req *http.Request
+
+ req = httptest.NewRequest(http.MethodGet, "/abc/test", nil)
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, "/v1/abc/test", req.URL.Path)
+
+ req = httptest.NewRequest(http.MethodGet, "/v1/abc/test", nil)
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, "/v1/abc/test", req.URL.Path)
+
+ req = httptest.NewRequest(http.MethodGet, "/v2/abc/test", nil)
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, "/v2/abc/test", req.URL.Path)
+}
+
+// Verify regex used with rewrite
+func TestEchoRewriteWithRegexRules(t *testing.T) {
+ e := echo.New()
+
+ e.Pre(RewriteWithConfig(RewriteConfig{
+ Rules: map[string]string{
+ "^/a/*": "/v1/$1",
+ "^/b/*/c/*": "/v2/$2/$1",
+ "^/c/*/*": "/v3/$2",
+ },
+ RegexRules: map[*regexp.Regexp]string{
+ regexp.MustCompile("^/x/.+?/(.*)"): "/v4/$1",
+ regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1",
+ },
+ }))
+
+ var rec *httptest.ResponseRecorder
+ var req *http.Request
+
+ testCases := []struct {
+ requestPath string
+ expectPath string
+ }{
+ {"/unmatched", "/unmatched"},
+ {"/a/test", "/v1/test"},
+ {"/b/foo/c/bar/baz", "/v2/bar/baz/foo"},
+ {"/c/ignore/test", "/v3/test"},
+ {"/c/ignore1/test/this", "/v3/test/this"},
+ {"/x/ignore/test", "/v4/test"},
+ {"/y/foo/bar", "/v5/bar/foo"},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.requestPath, func(t *testing.T) {
+ req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil)
+ rec = httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, tc.expectPath, req.URL.EscapedPath())
+ })
+ }
+}
+
+// Ensure correct escaping as defined in replacement (issue #1798)
+func TestEchoRewriteReplacementEscaping(t *testing.T) {
+ e := echo.New()
+
+ // NOTE: these are incorrect regexps as they do not factor in that URI we are replacing could contain ? (query) and # (fragment) parts
+ // so in reality they append query and fragment part as `$1` matches everything after that prefix
+ e.Pre(RewriteWithConfig(RewriteConfig{
+ Rules: map[string]string{
+ "^/a/*": "/$1?query=param",
+ "^/b/*": "/$1;part#one",
+ },
+ RegexRules: map[*regexp.Regexp]string{
+ regexp.MustCompile("^/x/(.*)"): "/$1?query=param",
+ regexp.MustCompile("^/y/(.*)"): "/$1;part#one",
+ regexp.MustCompile("^/z/(.*)"): "/$1?test=1#escaped%20test",
+ },
+ }))
+
+ var rec *httptest.ResponseRecorder
+ var req *http.Request
+
+ testCases := []struct {
+ requestPath string
+ expect string
+ }{
+ {"/unmatched", "/unmatched"},
+ {"/a/test", "/test?query=param"},
+ {"/b/foo/bar", "/foo/bar;part#one"},
+ {"/x/test", "/test?query=param"},
+ {"/y/foo/bar", "/foo/bar;part#one"},
+ {"/z/foo/b%20ar", "/foo/b%20ar?test=1#escaped%20test"},
+ {"/z/foo/b%20ar?nope=1#yes", "/foo/b%20ar?nope=1#yes?test=1%23escaped%20test"}, // example of appending
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.requestPath, func(t *testing.T) {
+ req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil)
+ rec = httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, tc.expect, req.URL.String())
+ })
+ }
+}
diff --git a/middleware/secure.go b/middleware/secure.go
index 6c4051723..bd389f7ae 100644
--- a/middleware/secure.go
+++ b/middleware/secure.go
@@ -1,89 +1,88 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
"fmt"
- "github.com/labstack/echo/v4"
+ "github.com/labstack/echo/v5"
)
-type (
- // SecureConfig defines the config for Secure middleware.
- SecureConfig struct {
- // Skipper defines a function to skip middleware.
- Skipper Skipper
-
- // XSSProtection provides protection against cross-site scripting attack (XSS)
- // by setting the `X-XSS-Protection` header.
- // Optional. Default value "1; mode=block".
- XSSProtection string `yaml:"xss_protection"`
-
- // ContentTypeNosniff provides protection against overriding Content-Type
- // header by setting the `X-Content-Type-Options` header.
- // Optional. Default value "nosniff".
- ContentTypeNosniff string `yaml:"content_type_nosniff"`
-
- // XFrameOptions can be used to indicate whether or not a browser should
- // be allowed to render a page in a ,