diff --git a/.editorconfig b/.editorconfig
index d7f05924e..17ae50dd0 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -13,7 +13,7 @@ insert_final_newline = true
trim_trailing_whitespace = true
indent_style = space
-indent_size = 4
+indent_size = 2
[Makefile]
indent_style = tab
diff --git a/.gitattributes b/.gitattributes
index a9609ad01..28981b84a 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -13,9 +13,3 @@
*.js text eol=lf
*.json text eol=lf
LICENSE text eol=lf
-
-# Exclude `website` and `examples` from Github's language statistics
-# https://github.com/github/linguist#using-gitattributes
-examples/* linguist-documentation
-recipes/* linguist-documentation
-website/* linguist-documentation
diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml
new file mode 100644
index 000000000..af410716d
--- /dev/null
+++ b/.github/FUNDING.yml
@@ -0,0 +1,12 @@
+# These are supported funding model platforms
+
+github: [labstack]
+patreon: # Replace with a single Patreon username
+open_collective: # Replace with a single Open Collective username
+ko_fi: # Replace with a single Ko-fi username
+tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
+community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
+liberapay: # Replace with a single Liberapay username
+issuehunt: # Replace with a single IssueHunt username
+otechie: # Replace with a single Otechie username
+custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md
new file mode 100644
index 000000000..1a76adca7
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE.md
@@ -0,0 +1,33 @@
+### Issue Description
+
+### Working code to debug
+
+```go
+package main
+
+import (
+ "github.com/labstack/echo/v5"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func TestExample(t *testing.T) {
+ e := echo.New()
+
+ e.GET("/", func(c *echo.Context) error {
+ return c.String(http.StatusOK, "Hello, World!")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Errorf("got %d, want %d", rec.Code, http.StatusOK)
+ }
+}
+```
+
+### Version/commit
diff --git a/.github/stale.yml b/.github/stale.yml
new file mode 100644
index 000000000..04dd169cd
--- /dev/null
+++ b/.github/stale.yml
@@ -0,0 +1,19 @@
+# Number of days of inactivity before an issue becomes stale
+daysUntilStale: 60
+# Number of days of inactivity before a stale issue is closed
+daysUntilClose: 30
+# Issues with these labels will never be considered stale
+exemptLabels:
+ - pinned
+ - security
+ - bug
+ - enhancement
+# Label to use when marking an issue as stale
+staleLabel: stale
+# Comment to post when marking an issue as stale. Set to `false` to disable
+markComment: >
+ This issue has been automatically marked as stale because it has not had
+ recent activity. It will be closed within a month if no further activity occurs.
+ Thank you for your contributions.
+# Comment to post when closing a stale issue. Set to `false` to disable
+closeComment: false
diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml
new file mode 100644
index 000000000..8f4eff96e
--- /dev/null
+++ b/.github/workflows/checks.yml
@@ -0,0 +1,47 @@
+name: Run checks
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ branches:
+ - master
+ workflow_dispatch:
+
+permissions:
+ contents: read # to fetch code (actions/checkout)
+
+env:
+ # run static analysis only with the latest Go version
+ LATEST_GO_VERSION: "1.26"
+
+jobs:
+ check:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout Code
+ uses: actions/checkout@v5
+
+ - name: Set up Go ${{ matrix.go }}
+ uses: actions/setup-go@v5
+ with:
+ go-version: ${{ env.LATEST_GO_VERSION }}
+ check-latest: true
+
+ - name: Run golint
+ run: |
+ go install golang.org/x/lint/golint@latest
+ golint -set_exit_status ./...
+
+ - name: Run staticcheck
+ run: |
+ go install honnef.co/go/tools/cmd/staticcheck@latest
+ staticcheck ./...
+
+ - name: Run govulncheck
+ run: |
+ go version
+ go install golang.org/x/vuln/cmd/govulncheck@latest
+ govulncheck ./...
+
diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml
new file mode 100644
index 000000000..b92c70c1b
--- /dev/null
+++ b/.github/workflows/echo.yml
@@ -0,0 +1,86 @@
+name: Run Tests
+
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ branches:
+ - master
+ workflow_dispatch:
+
+permissions:
+ contents: read # to fetch code (actions/checkout)
+
+env:
+ # run coverage and benchmarks only with the latest Go version
+ LATEST_GO_VERSION: "1.26"
+
+jobs:
+ test:
+ strategy:
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-latest]
+ # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy
+ # Echo tests with last four major releases (unless there are pressing vulnerabilities)
+ # As we depend on `golang.org/x/` libraries which only support the last 2 Go releases, we could have situations when
+ # we derive from the last four major releases promise.
+ go: ["1.25", "1.26"]
+ name: ${{ matrix.os }} @ Go ${{ matrix.go }}
+ runs-on: ${{ matrix.os }}
+ steps:
+ - name: Checkout Code
+ uses: actions/checkout@v5
+
+ - name: Set up Go ${{ matrix.go }}
+ uses: actions/setup-go@v5
+ with:
+ go-version: ${{ matrix.go }}
+
+ - name: Run Tests
+ run: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./...
+
+ - name: Upload coverage to Codecov
+ if: success() && matrix.go == env.LATEST_GO_VERSION && matrix.os == 'ubuntu-latest'
+ uses: codecov/codecov-action@v5
+ with:
+ token:
+ fail_ci_if_error: false
+
+ benchmark:
+ needs: test
+ name: Benchmark comparison
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout Code (Previous)
+ uses: actions/checkout@v5
+ with:
+ ref: ${{ github.base_ref }}
+ path: previous
+
+ - name: Checkout Code (New)
+ uses: actions/checkout@v5
+ with:
+ path: new
+
+ - name: Set up Go ${{ matrix.go }}
+ uses: actions/setup-go@v5
+ with:
+ go-version: ${{ env.LATEST_GO_VERSION }}
+
+ - name: Install Dependencies
+ run: go install golang.org/x/perf/cmd/benchstat@latest
+
+ - name: Run Benchmark (Previous)
+ run: |
+ cd previous
+ go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
+
+ - name: Run Benchmark (New)
+ run: |
+ cd new
+ go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt
+
+ - name: Run Benchstat
+ run: |
+ benchstat previous/benchmark.txt new/benchmark.txt
diff --git a/.gitignore b/.gitignore
index 01f796c2e..dbadf3bd0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,10 +1,8 @@
-# Website
-site/
-.publish/
-
-# Node.js
-node_modules/
-
-# IntelliJ
-.idea/
+.DS_Store
+coverage.txt
+_test
+vendor
+.idea
*.iml
+*.out
+.vscode
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index 27f892a5a..000000000
--- a/.travis.yml
+++ /dev/null
@@ -1,12 +0,0 @@
-language: go
-go:
- - tip
-before_install:
- - go get github.com/modocache/gover
- - go get github.com/mattn/goveralls
- - go get golang.org/x/tools/cmd/cover
-script:
- - go test -coverprofile=echo.coverprofile
- - go test -coverprofile=middleware.coverprofile ./middleware
- - $HOME/gopath/bin/gover
- - $HOME/gopath/bin/goveralls -coverprofile=gover.coverprofile -service=travis-ci
diff --git a/API_CHANGES_V5.md b/API_CHANGES_V5.md
new file mode 100644
index 000000000..d3ca81560
--- /dev/null
+++ b/API_CHANGES_V5.md
@@ -0,0 +1,1178 @@
+# Echo v5 Public API Changes
+
+**Comparison between `master` (v4.15.0) and `v5` (v5.0.0-alpha) branches**
+
+Generated: 2026-01-01
+
+---
+
+## Executive Summary (by authors)
+
+Echo `v5` is maintenance release with **major breaking changes**
+- `Context` is now struct instead of interface and we can add method to it in the future in minor versions.
+- Adds new `Router` interface for possible new routing implementations.
+- Drops old logging interface and uses moderm `log/slog` instead.
+- Rearranges alot of methods/function signatures to make them more consistent.
+
+## Executive Summary (by LLMs)
+
+Echo v5 represents a **major breaking release** with significant architectural changes focused on:
+- **Updated generic helpers** to take `*Context` and rename form helpers to `FormValue*`
+- **Simplified API surface** by moving Context from interface to concrete struct
+- **Modern Go patterns** including slog.Logger integration
+- **Enhanced routing** with explicit RouteInfo and Routes types
+- **Better error handling** with simplified HTTPError
+- **New test helpers** via the `echotest` package
+
+### Change Statistics
+
+- **Major Breaking Changes**: 15+
+- **New Functions Added**: 30+
+- **Type Signature Changes**: 20+
+- **Removed APIs**: 10+
+- **New Packages Added**: 1 (`echotest`)
+- **Version Change**: `4.15.0` → `5.0.0-alpha`
+
+---
+
+## Critical Breaking Changes
+
+### 1. **Context: Interface → Concrete Struct**
+
+**v4 (master):**
+```go
+type Context interface {
+ Request() *http.Request
+ // ... many methods
+}
+
+// Handler signature
+func handler(c echo.Context) error
+```
+
+**v5:**
+```go
+type Context struct {
+ // Has unexported fields
+}
+
+// Handler signature - NOW USES POINTER!
+func handler(c *echo.Context) error
+```
+
+**Impact:** 🔴 **CRITICAL BREAKING CHANGE**
+- ALL handlers must change from `echo.Context` to `*echo.Context`
+- Context is now a concrete struct, not an interface
+- This affects every single handler function in user code
+
+**Migration:**
+```go
+// Before (v4)
+func MyHandler(c echo.Context) error {
+ return c.JSON(200, map[string]string{"hello": "world"})
+}
+
+// After (v5)
+func MyHandler(c *echo.Context) error {
+ return c.JSON(200, map[string]string{"hello": "world"})
+}
+```
+
+---
+
+### 2. **Logger: Custom Interface → slog.Logger**
+
+**v4:**
+```go
+type Echo struct {
+ Logger Logger // Custom interface with Print, Debug, Info, etc.
+}
+
+type Logger interface {
+ Output() io.Writer
+ SetOutput(w io.Writer)
+ Prefix() string
+ // ... many custom methods
+}
+
+// Context returns Logger interface
+func (c Context) Logger() Logger
+```
+
+**v5:**
+```go
+type Echo struct {
+ Logger *slog.Logger // Standard library structured logger
+}
+
+// Context returns slog.Logger
+func (c *Context) Logger() *slog.Logger
+func (c *Context) SetLogger(logger *slog.Logger)
+```
+
+**Impact:** 🔴 **BREAKING CHANGE**
+- Must use Go's standard `log/slog` package
+- Logger interface completely removed
+- All logging code needs updating
+
+---
+
+### 3. **Router: From Router to DefaultRouter**
+
+**v4:**
+```go
+type Router struct { ... }
+
+func NewRouter(e *Echo) *Router
+func (e *Echo) Router() *Router
+```
+
+**v5:**
+```go
+type DefaultRouter struct { ... }
+
+func NewRouter(config RouterConfig) *DefaultRouter
+func (e *Echo) Router() Router // Returns interface
+```
+
+**Changes:**
+- New `Router` interface introduced
+- `DefaultRouter` is the concrete implementation
+- `NewRouter()` now takes `RouterConfig` instead of `*Echo`
+- Added `NewConcurrentRouter(r Router) Router` for thread-safe routing
+
+---
+
+### 4. **Route Return Types Changed**
+
+**v4:**
+```go
+func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route
+func (e *Echo) Any(path string, h HandlerFunc, m ...MiddlewareFunc) []*Route
+func (e *Echo) Routes() []*Route
+```
+
+**v5:**
+```go
+func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo
+func (e *Echo) Any(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo
+func (e *Echo) Match(...) Routes // Returns Routes type
+func (e *Echo) Router() Router // Returns interface
+```
+
+**New Types:**
+```go
+type RouteInfo struct {
+ Name string
+ Method string
+ Path string
+ Parameters []string
+}
+
+type Routes []RouteInfo // Collection with helper methods
+```
+
+**Impact:** 🔴 **BREAKING CHANGE**
+- Route registration methods return `RouteInfo` instead of `*Route`
+- New `Routes` collection type with filtering methods
+- `Route` struct still exists but used differently
+
+---
+
+### 5. **Response Type Changed**
+
+**v4:**
+```go
+func (c Context) Response() *Response
+type Response struct {
+ Writer http.ResponseWriter
+ Status int
+ Size int64
+ Committed bool
+}
+func NewResponse(w http.ResponseWriter, e *Echo) *Response
+```
+
+**v5:**
+```go
+func (c *Context) Response() http.ResponseWriter
+type Response struct {
+ http.ResponseWriter // Embedded
+ Status int
+ Size int64
+ Committed bool
+}
+func NewResponse(w http.ResponseWriter, logger *slog.Logger) *Response
+func UnwrapResponse(rw http.ResponseWriter) (*Response, error)
+```
+
+**Changes:**
+- Context.Response() returns `http.ResponseWriter` instead of `*Response`
+- Response now embeds `http.ResponseWriter`
+- NewResponse takes `*slog.Logger` instead of `*Echo`
+- New `UnwrapResponse()` helper function
+
+---
+
+### 6. **HTTPError Simplified**
+
+**v4:**
+```go
+type HTTPError struct {
+ Internal error
+ Message interface{} // Can be any type
+ Code int
+}
+
+func NewHTTPError(code int, message ...interface{}) *HTTPError
+```
+
+**v5:**
+```go
+type HTTPError struct {
+ Code int
+ Message string // Now string only
+ // Has unexported fields (Internal moved)
+}
+
+func NewHTTPError(code int, message string) *HTTPError
+func (he HTTPError) Wrap(err error) error // New method
+func (he *HTTPError) StatusCode() int // Implements HTTPStatusCoder
+```
+
+**Changes:**
+- `Message` field changed from `interface{}` to `string`
+- `NewHTTPError()` now takes `string` instead of `...interface{}`
+- Added `HTTPStatusCoder` interface and `StatusCode()` method
+- Added `Wrap(err error)` method for error wrapping
+
+---
+
+### 7. **HTTPErrorHandler Signature Changed**
+
+**v4:**
+```go
+type HTTPErrorHandler func(err error, c Context)
+
+func (e *Echo) DefaultHTTPErrorHandler(err error, c Context)
+```
+
+**v5:**
+```go
+type HTTPErrorHandler func(c *Context, err error) // Parameters swapped!
+
+func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler // Now a factory
+```
+
+**Impact:** 🔴 **BREAKING CHANGE**
+- Parameter order reversed: `(c *Context, err error)` instead of `(err error, c Context)`
+- DefaultHTTPErrorHandler is now a factory function that returns HTTPErrorHandler
+- Takes `exposeError` bool to control error message exposure
+
+---
+
+## Notable API Changes in v5
+
+### 1. **Generic Parameter Extraction Functions (Updated Signatures)**
+
+These helpers keep the same generic API but now accept `*Context`, and the
+form helpers are renamed from `FormParam*` to `FormValue*`:
+
+```go
+// Query Parameters
+func QueryParam[T any](c *Context, key string, opts ...any) (T, error)
+func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error)
+func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error)
+func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error)
+
+// Path Parameters
+func PathParam[T any](c *Context, paramName string, opts ...any) (T, error)
+func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...any) (T, error)
+
+// Form Values
+func FormValue[T any](c *Context, key string, opts ...any) (T, error)
+func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error)
+func FormValues[T any](c *Context, key string, opts ...any) ([]T, error)
+func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error)
+
+// Generic Parsing
+func ParseValue[T any](value string, opts ...any) (T, error)
+func ParseValueOr[T any](value string, defaultValue T, opts ...any) (T, error)
+func ParseValues[T any](values []string, opts ...any) ([]T, error)
+func ParseValuesOr[T any](values []string, defaultValue []T, opts ...any) ([]T, error)
+```
+
+`FormParam*` was renamed to `FormValue*`; the rest keep names but now take `*Context`.
+
+**Supported Types:**
+- bool, string
+- int, int8, int16, int32, int64
+- uint, uint8, uint16, uint32, uint64
+- float32, float64
+- time.Time, time.Duration
+- BindUnmarshaler, encoding.TextUnmarshaler, json.Unmarshaler
+
+**Example Usage:**
+```go
+// v5 - Type-safe parameter binding
+id, err := echo.PathParam[int](c, "id")
+page, err := echo.QueryParamOr[int](c, "page", 1)
+tags, err := echo.QueryParams[string](c, "tags")
+```
+
+---
+
+### 2. **Context Store Helpers Now Use `*Context`**
+
+```go
+// Type-safe context value retrieval
+func ContextGet[T any](c *Context, key string) (T, error)
+func ContextGetOr[T any](c *Context, key string, defaultValue T) (T, error)
+
+// Error types
+var ErrNonExistentKey = errors.New("non existent key")
+var ErrInvalidKeyType = errors.New("invalid key type")
+```
+
+These helpers existed in v4 with `Context` and now accept `*Context`.
+
+**Example:**
+```go
+// v5
+user, err := echo.ContextGet[*User](c, "user")
+count, err := echo.ContextGetOr[int](c, "count", 0)
+```
+
+---
+
+### 3. **PathValues Type**
+
+New structured path parameter handling:
+
+```go
+type PathValue struct {
+ Name string
+ Value string
+}
+
+type PathValues []PathValue
+
+func (p PathValues) Get(name string) (string, bool)
+func (p PathValues) GetOr(name string, defaultValue string) string
+
+// Context methods
+func (c *Context) PathValues() PathValues
+func (c *Context) SetPathValues(pathValues PathValues)
+```
+
+---
+
+### 4. **Time Parsing Options**
+
+```go
+type TimeLayout string
+
+const (
+ TimeLayoutUnixTime = TimeLayout("UnixTime")
+ TimeLayoutUnixTimeMilli = TimeLayout("UnixTimeMilli")
+ TimeLayoutUnixTimeNano = TimeLayout("UnixTimeNano")
+)
+
+type TimeOpts struct {
+ Layout TimeLayout
+ ParseInLocation *time.Location
+ ToInLocation *time.Location
+}
+```
+
+---
+
+### 5. **StartConfig for Server Configuration**
+
+```go
+type StartConfig struct {
+ Address string
+ HideBanner bool
+ HidePort bool
+ CertFilesystem fs.FS
+ TLSConfig *tls.Config
+ ListenerNetwork string
+ ListenerAddrFunc func(addr net.Addr)
+ GracefulTimeout time.Duration
+ OnShutdownError func(err error)
+ BeforeServeFunc func(s *http.Server) error
+}
+
+func (sc StartConfig) Start(ctx context.Context, h http.Handler) error
+func (sc StartConfig) StartTLS(ctx context.Context, h http.Handler, certFile, keyFile any) error
+```
+
+**Example:**
+```go
+// v5 - More control over server startup
+ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+defer cancel()
+
+sc := echo.StartConfig{
+ Address: ":8080",
+ GracefulTimeout: 10 * time.Second,
+}
+if err := sc.Start(ctx, e); err != nil {
+ log.Fatal(err)
+}
+```
+
+---
+
+### 6. **Echo Config and Constructors**
+
+```go
+type Config struct {
+ // Configuration for Echo (logger, binder, renderer, etc.)
+}
+
+func NewWithConfig(config Config) *Echo
+```
+
+This adds a configuration struct for creating an `Echo` instance without
+mutating fields after `New()`.
+
+---
+
+### 7. **Enhanced Routing Features**
+
+```go
+// New route methods
+func (e *Echo) AddRoute(route Route) (RouteInfo, error)
+func (e *Echo) Middlewares() []MiddlewareFunc
+func (e *Echo) PreMiddlewares() []MiddlewareFunc
+type AddRouteError struct{ ... }
+
+// Routes collection with filters
+type Routes []RouteInfo
+
+func (r Routes) Clone() Routes
+func (r Routes) FilterByMethod(method string) (Routes, error)
+func (r Routes) FilterByName(name string) (Routes, error)
+func (r Routes) FilterByPath(path string) (Routes, error)
+func (r Routes) FindByMethodPath(method string, path string) (RouteInfo, error)
+func (r Routes) Reverse(routeName string, pathValues ...any) (string, error)
+
+// RouteInfo operations
+func (r RouteInfo) Clone() RouteInfo
+func (r RouteInfo) Reverse(pathValues ...any) string
+```
+
+---
+
+### 8. **Middleware Configuration Interface**
+
+```go
+type MiddlewareConfigurator interface {
+ ToMiddleware() (MiddlewareFunc, error)
+}
+```
+
+Allows middleware configs to be converted to middleware without panicking.
+
+---
+
+### 9. **New Context Methods**
+
+```go
+// v5 additions
+func (c *Context) FileFS(file string, filesystem fs.FS) error
+func (c *Context) FormValueOr(name, defaultValue string) string
+func (c *Context) InitializeRoute(ri *RouteInfo, pathValues *PathValues)
+func (c *Context) ParamOr(name, defaultValue string) string
+func (c *Context) QueryParamOr(name, defaultValue string) string
+func (c *Context) RouteInfo() RouteInfo
+```
+
+---
+
+### 10. **Virtual Host Support**
+
+```go
+func NewVirtualHostHandler(vhosts map[string]*Echo) *Echo
+```
+
+Creates an Echo instance that routes requests to different Echo instances based on host.
+
+---
+
+### 11. **New Binder Functions**
+
+```go
+func BindBody(c *Context, target any) error
+func BindHeaders(c *Context, target any) error
+func BindPathValues(c *Context, target any) error // Renamed from BindPathParams
+func BindQueryParams(c *Context, target any) error
+```
+
+Top-level binding functions that work with `*Context`.
+
+---
+
+### 12. **New echotest Package**
+
+```go
+package echotest // import "github.com/labstack/echo/v5/echotest"
+
+func LoadBytes(t *testing.T, name string, opts ...loadBytesOpts) []byte
+func TrimNewlineEnd(bytes []byte) []byte
+type ContextConfig struct{ ... }
+type MultipartForm struct{ ... }
+type MultipartFormFile struct{ ... }
+```
+
+Helpers for loading fixtures and constructing test contexts.
+
+---
+
+## Removed APIs in v5
+
+### Constants
+
+```go
+// v4 - Removed in v5
+const CONNECT = http.MethodConnect // Use http.MethodConnect directly
+```
+
+**Reason:** Deprecated in v4, use stdlib `http.Method*` constants instead.
+
+---
+
+### Constants Added in v5
+
+```go
+// v5 additions
+const (
+ NotFoundRouteName = "echo_route_not_found_name"
+)
+```
+
+---
+
+### Error Variable Changes
+
+**v4 exports:**
+```go
+ErrBadRequest
+ErrInvalidKeyType
+ErrNonExistentKey
+```
+
+**v5 exports:**
+```go
+ErrBadRequest // Now backed by unexported httpError type
+ErrValidatorNotRegistered // New
+ErrInvalidKeyType
+ErrNonExistentKey
+```
+
+**Reason:** v5 centralizes on `NewHTTPError(code, message)` rather than a broad set
+of predefined HTTP error variables.
+
+---
+
+### Functions Removed
+
+```go
+// v4 - Removed in v5
+func GetPath(r *http.Request) string // Use r.URL.Path or r.URL.RawPath
+```
+
+### Variables Removed
+
+```go
+// v4 - Removed in v5
+var MethodNotAllowedHandler = func(c Context) error { ... }
+var NotFoundHandler = func(c Context) error { ... }
+```
+
+### Functions Renamed
+
+```go
+// v4
+func FormParam[T any](c Context, key string, opts ...any) (T, error)
+func FormParamOr[T any](c Context, key string, defaultValue T, opts ...any) (T, error)
+func FormParams[T any](c Context, key string, opts ...any) ([]T, error)
+func FormParamsOr[T any](c Context, key string, defaultValue []T, opts ...any) ([]T, error)
+
+// v5
+func FormValue[T any](c *Context, key string, opts ...any) (T, error)
+func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error)
+func FormValues[T any](c *Context, key string, opts ...any) ([]T, error)
+func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error)
+```
+
+---
+
+### Type Methods Removed/Changed
+
+**Echo struct changes:**
+```go
+// v4 fields removed in v5
+type Echo struct {
+ StdLogger *stdLog.Logger // Removed
+ Server *http.Server // Removed (use StartConfig)
+ TLSServer *http.Server // Removed (use StartConfig)
+ Listener net.Listener // Removed (use StartConfig)
+ TLSListener net.Listener // Removed (use StartConfig)
+ AutoTLSManager autocert.Manager // Removed
+ ListenerNetwork string // Removed
+ OnAddRouteHandler func(...) // Changed to OnAddRoute
+ DisableHTTP2 bool // Removed (use StartConfig)
+ Debug bool // Removed
+ HideBanner bool // Removed (use StartConfig)
+ HidePort bool // Removed (use StartConfig)
+}
+
+// v5 Echo struct (simplified)
+type Echo struct {
+ Binder Binder
+ Filesystem fs.FS // NEW
+ Renderer Renderer
+ Validator Validator
+ JSONSerializer JSONSerializer
+ IPExtractor IPExtractor
+ OnAddRoute func(route Route) error // Simplified
+ HTTPErrorHandler HTTPErrorHandler
+ Logger *slog.Logger // Changed from Logger interface
+}
+```
+
+---
+
+**Context interface → struct:**
+```go
+// v4
+type Context interface {
+ // Had: SetResponse(*Response)
+ Response() *Response
+
+ // Had: ParamNames(), SetParamNames(), ParamValues(), SetParamValues()
+ // These are removed in v5 (use PathValues() instead)
+}
+
+// v5
+type Context struct {
+ // Concrete struct with unexported fields
+}
+
+func (c *Context) Response() http.ResponseWriter // Changed return type
+func (c *Context) PathValues() PathValues // Replaces ParamNames/Values
+```
+
+---
+
+**Types removed:**
+```go
+// v4
+type Map map[string]interface{}
+```
+
+**Group changes:**
+```go
+// v4
+func (g *Group) File(path, file string) // No return value
+func (g *Group) Static(pathPrefix, fsRoot string) // No return value
+func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) // No return value
+
+// v5
+func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo
+func (g *Group) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo
+func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo
+```
+
+Now return `RouteInfo` and accept middleware.
+
+---
+
+### Value Binder Factory Name Changes
+
+```go
+// v4
+func PathParamsBinder(c Context) *ValueBinder
+func QueryParamsBinder(c Context) *ValueBinder
+func FormFieldBinder(c Context) *ValueBinder
+
+// v5
+func PathValuesBinder(c *Context) *ValueBinder // Renamed
+func QueryParamsBinder(c *Context) *ValueBinder
+func FormFieldBinder(c *Context) *ValueBinder
+```
+
+---
+
+## Type Signature Changes
+
+### Binder Interface
+
+```go
+// v4
+type Binder interface {
+ Bind(i interface{}, c Context) error
+}
+
+// v5
+type Binder interface {
+ Bind(c *Context, target any) error // Parameters swapped!
+}
+```
+
+---
+
+### DefaultBinder Methods
+
+```go
+// v4
+func (b *DefaultBinder) Bind(i interface{}, c Context) error
+func (b *DefaultBinder) BindBody(c Context, i interface{}) error
+func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error
+
+// v5
+func (b *DefaultBinder) Bind(c *Context, target any) error // Swapped params
+// BindBody, BindPathParams, etc. are now top-level functions
+```
+
+---
+
+### JSONSerializer Interface
+
+```go
+// v4
+type JSONSerializer interface {
+ Serialize(c Context, i interface{}, indent string) error
+ Deserialize(c Context, i interface{}) error
+}
+
+// v5
+type JSONSerializer interface {
+ Serialize(c *Context, target any, indent string) error
+ Deserialize(c *Context, target any) error
+}
+```
+
+---
+
+### Renderer Interface
+
+```go
+// v4
+type Renderer interface {
+ Render(io.Writer, string, interface{}, Context) error
+}
+
+// v5
+type Renderer interface {
+ Render(c *Context, w io.Writer, templateName string, data any) error
+}
+```
+
+Parameters reordered with Context first.
+
+---
+
+### NewBindingError
+
+```go
+// v4
+func NewBindingError(sourceParam string, values []string, message interface{}, internalError error) error
+
+// v5
+func NewBindingError(sourceParam string, values []string, message string, err error) error
+```
+
+Message parameter changed from `interface{}` to `string`.
+
+---
+
+### HandlerName
+
+```go
+// v5 only
+func HandlerName(h HandlerFunc) string
+```
+
+New utility function to get handler function name.
+
+---
+
+## Middleware Package Changes
+
+### Signature and Type Updates
+
+```go
+// CORS now accepts optional allow-origins
+func CORS(allowOrigins ...string) echo.MiddlewareFunc
+
+// BodyLimit now accepts bytes
+func BodyLimit(limitBytes int64) echo.MiddlewareFunc
+
+// DefaultSkipper now uses *echo.Context
+func DefaultSkipper(c *echo.Context) bool
+
+// Trailing slash configs renamed/split
+func AddTrailingSlashWithConfig(config AddTrailingSlashConfig) echo.MiddlewareFunc
+func RemoveTrailingSlashWithConfig(config RemoveTrailingSlashConfig) echo.MiddlewareFunc
+type AddTrailingSlashConfig struct{ ... }
+type RemoveTrailingSlashConfig struct{ ... }
+
+// Auth + extractor signatures now use *echo.Context and add ExtractorSource
+type BasicAuthValidator func(c *echo.Context, user string, password string) (bool, error)
+type Extractor func(c *echo.Context) (string, error)
+type ExtractorSource string
+type KeyAuthValidator func(c *echo.Context, key string, source ExtractorSource) (bool, error)
+type KeyAuthErrorHandler func(c *echo.Context, err error) error
+
+// BodyDump handler now includes err
+type BodyDumpHandler func(c *echo.Context, reqBody []byte, resBody []byte, err error)
+
+// ValuesExtractor now returns extractor source and CreateExtractors takes a limit
+type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error)
+func CreateExtractors(lookups string, limit uint) ([]ValuesExtractor, error)
+type ValueExtractorError struct{ ... }
+
+// New constants
+const KB = 1024
+
+// Rate limiter store now takes a float64 limit
+func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore)
+```
+
+### Added Middleware Exports
+
+```go
+var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key")
+var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key")
+var RedirectHTTPSConfig = RedirectConfig{ ... }
+var RedirectHTTPSWWWConfig = RedirectConfig{ ... }
+var RedirectNonHTTPSWWWConfig = RedirectConfig{ ... }
+var RedirectNonWWWConfig = RedirectConfig{ ... }
+var RedirectWWWConfig = RedirectConfig{ ... }
+```
+
+### Removed/Consolidated Middleware Exports
+
+```go
+// Removed in v5
+func Logger() echo.MiddlewareFunc
+func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc
+func Timeout() echo.MiddlewareFunc
+func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc
+type ErrKeyAuthMissing struct{ ... }
+type CSRFErrorHandler func(err error, c echo.Context) error
+type LoggerConfig struct{ ... }
+type LogErrorFunc func(c echo.Context, err error, stack []byte) error
+type TargetProvider interface{ ... }
+type TrailingSlashConfig struct{ ... }
+type TimeoutConfig struct{ ... }
+```
+
+Also removed defaults: `DefaultBasicAuthConfig`, `DefaultBodyDumpConfig`, `DefaultBodyLimitConfig`,
+`DefaultCORSConfig`, `DefaultDecompressConfig`, `DefaultGzipConfig`, `DefaultLoggerConfig`,
+`DefaultRedirectConfig`, `DefaultRequestIDConfig`, `DefaultRewriteConfig`, `DefaultTimeoutConfig`,
+`DefaultTrailingSlashConfig`.
+
+---
+
+## Router Interface Changes
+
+### v4 Router (Concrete Struct)
+
+```go
+type Router struct { ... }
+
+func NewRouter(e *Echo) *Router
+func (r *Router) Add(method, path string, h HandlerFunc)
+func (r *Router) Find(method, path string, c Context)
+func (r *Router) Reverse(name string, params ...interface{}) string
+func (r *Router) Routes() []*Route
+```
+
+### v5 Router (Interface + DefaultRouter)
+
+```go
+type Router interface {
+ Add(routable Route) (RouteInfo, error)
+ Remove(method string, path string) error
+ Routes() Routes
+ Route(c *Context) HandlerFunc
+}
+
+type DefaultRouter struct { ... }
+
+func NewRouter(config RouterConfig) *DefaultRouter
+func NewConcurrentRouter(r Router) Router // NEW
+
+type RouterConfig struct {
+ NotFoundHandler HandlerFunc
+ MethodNotAllowedHandler HandlerFunc
+ OptionsMethodHandler HandlerFunc
+ AllowOverwritingRoute bool
+ UnescapePathParamValues bool
+ UseEscapedPathForMatching bool
+}
+```
+
+**Key Changes:**
+- Router is now an interface
+- DefaultRouter is the concrete implementation
+- Add() returns `(RouteInfo, error)` instead of being void
+- New `Remove()` method
+- New `Route()` method replaces `Find()`
+- Configuration through `RouterConfig`
+
+---
+
+## Echo Instance Method Changes
+
+### Route Registration
+
+```go
+// v4
+func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route
+
+// v5
+func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo
+func (e *Echo) AddRoute(route Route) (RouteInfo, error) // NEW
+```
+
+### Static File Serving
+
+```go
+// v4
+func (e *Echo) Static(pathPrefix, fsRoot string) *Route
+func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route
+func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route
+func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route
+
+// v5
+func (e *Echo) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo
+func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo
+func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo
+func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo
+```
+
+Return type changed from `*Route` to `RouteInfo`.
+
+### Server Management
+
+```go
+// v4
+func (e *Echo) Start(address string) error
+func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) error
+func (e *Echo) StartAutoTLS(address string) error
+func (e *Echo) StartH2CServer(address string, h2s *http2.Server) error
+func (e *Echo) StartServer(s *http.Server) error
+func (e *Echo) Shutdown(ctx context.Context) error
+func (e *Echo) Close() error
+func (e *Echo) ListenerAddr() net.Addr
+func (e *Echo) TLSListenerAddr() net.Addr
+func (e *Echo) DefaultHTTPErrorHandler(err error, c Context)
+
+// v5
+func (e *Echo) Start(address string) error // Simplified
+func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request)
+
+// Removed: StartTLS, StartAutoTLS, StartH2CServer, StartServer
+// Use StartConfig instead for advanced server configuration
+// Removed: Shutdown, Close, ListenerAddr, TLSListenerAddr
+// Removed: DefaultHTTPErrorHandler (now a top-level factory function)
+```
+
+**v5 provides** `StartConfig` type for all advanced server configuration.
+
+### Router Access
+
+```go
+// v4
+func (e *Echo) Router() *Router
+func (e *Echo) Routers() map[string]*Router // For multi-host
+func (e *Echo) Routes() []*Route
+func (e *Echo) Reverse(name string, params ...interface{}) string
+func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string
+func (e *Echo) URL(h HandlerFunc, params ...interface{}) string
+func (e *Echo) Host(name string, m ...MiddlewareFunc) *Group
+
+// v5
+func (e *Echo) Router() Router // Returns interface
+// Removed: Routers(), Reverse(), URI(), URL(), Host()
+// Use router.Routes() and Routes.Reverse() instead
+```
+
+---
+
+## NewContext Changes
+
+```go
+// v4
+func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context
+func NewResponse(w http.ResponseWriter, e *Echo) *Response
+
+// v5
+func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) *Context
+func NewContext(r *http.Request, w http.ResponseWriter, opts ...any) *Context // Standalone
+func NewResponse(w http.ResponseWriter, logger *slog.Logger) *Response
+```
+
+---
+
+## Migration Guide Summary
+
+If you are using Linux you can migrate easier parts like that:
+```bash
+find . -type f -name "*.go" -exec sed -i 's/ echo.Context/ *echo.Context/g' {} +
+find . -type f -name "*.go" -exec sed -i 's/echo\/v4/echo\/v5/g' {} +
+```
+or in your favorite IDE
+
+Replace all:
+1. ` echo.Context` -> ` *echo.Context`
+2. `echo/v4` -> `echo/v5`
+
+
+### 1. Update All Handler Signatures
+
+```go
+// Before
+func MyHandler(c echo.Context) error { ... }
+
+// After
+func MyHandler(c *echo.Context) error { ... }
+```
+
+### 2. Update Logger Usage
+
+```go
+// Before
+e.Logger.Info("Server started")
+c.Logger().Error("Something went wrong")
+
+// After
+e.Logger.Info("Server started")
+c.Logger().Error("Something went wrong") // Same API, different logger
+```
+
+### 3. Use Type-Safe Parameter Extraction
+
+```go
+// Before
+idStr := c.Param("id")
+id, err := strconv.Atoi(idStr)
+
+// After
+id, err := echo.PathParam[int](c, "id")
+```
+
+### 4. Update Error Handler
+
+```go
+// Before
+e.HTTPErrorHandler = func(err error, c echo.Context) {
+ // handle error
+}
+
+// After
+e.HTTPErrorHandler = func(c *echo.Context, err error) { // Swapped!
+ // handle error
+}
+
+// Or use factory
+e.HTTPErrorHandler = echo.DefaultHTTPErrorHandler(true) // exposeError=true
+```
+
+### 5. Update Server Startup
+
+```go
+// Before
+e.Start(":8080")
+e.StartTLS(":443", "cert.pem", "key.pem")
+
+// After
+// Simple
+e.Start(":8080")
+
+// Advanced with graceful shutdown
+ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
+defer cancel()
+sc := echo.StartConfig{Address: ":8080"}
+sc.Start(ctx, e)
+```
+
+### 6. Update Route Info Access
+
+```go
+// Before
+routes := e.Routes()
+for _, r := range routes {
+ fmt.Println(r.Method, r.Path)
+}
+
+// After
+routes := e.Router().Routes()
+for _, r := range routes {
+ fmt.Println(r.Method, r.Path)
+}
+```
+
+### 7. Update HTTPError Creation
+
+```go
+// Before
+return echo.NewHTTPError(400, "invalid request", someDetail)
+
+// After
+return echo.NewHTTPError(400, "invalid request")
+```
+
+### 8. Update Custom Binder
+
+```go
+// Before
+type MyBinder struct{}
+func (b *MyBinder) Bind(i interface{}, c echo.Context) error { ... }
+
+// After
+type MyBinder struct{}
+func (b *MyBinder) Bind(c *echo.Context, target any) error { ... } // Swapped!
+```
+
+### 9. Path Parameters
+
+```go
+// Before
+names := c.ParamNames()
+values := c.ParamValues()
+
+// After
+pathValues := c.PathValues()
+for _, pv := range pathValues {
+ fmt.Println(pv.Name, pv.Value)
+}
+```
+
+### 10. Response Access
+
+```go
+// Before
+resp := c.Response()
+resp.Header().Set("X-Custom", "value")
+
+// After
+c.Response().Header().Set("X-Custom", "value") // Returns http.ResponseWriter
+
+// To get *echo.Response
+resp, err := echo.UnwrapResponse(c.Response())
+```
+
+### Go Version Requirements
+
+- **v4**: Go 1.24.0 (per `go.mod`)
+- **v5**: Go 1.25.0 (per `go.mod`)
+
+---
+
+**Generated by comparing `go doc` output from master (v4.15.0) and v5 (v5.0.0-alpha) branches**
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 000000000..37d1adb66
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,838 @@
+# Changelog
+
+## v5.0.3 - 2026-02-06
+
+**Security**
+
+* Fix directory traversal vulnerability under Windows in Static middleware when default Echo filesystem is used. Reported by @shblue21.
+
+This applies to cases when:
+- Windows is used as OS
+- `middleware.StaticConfig.Filesystem` is `nil` (default)
+- `echo.Filesystem` is has not been set explicitly (default)
+
+Exposure is restricted to the active process working directory and its subfolders.
+
+
+## v5.0.2 - 2026-02-02
+
+**Security**
+
+* Fix Static middleware with `config.Browse=true` lists all files/subfolders from `config.Filesystem` root and not starting from `config.Root` in https://github.com/labstack/echo/pull/2887
+
+
+## v5.0.1 - 2026-01-28
+
+* Panic MW: will now return a custom PanicStackError with stack trace by @aldas in https://github.com/labstack/echo/pull/2871
+* Docs: add missing err parameter to DenyHandler example by @cgalibern in https://github.com/labstack/echo/pull/2878
+* improve: improve websocket checks in IsWebSocket() [per RFC 6455] by @raju-mechatronics in https://github.com/labstack/echo/pull/2875
+* fix: Context.Json() should not send status code before serialization is complete by @aldas in https://github.com/labstack/echo/pull/2877
+
+
+## v5.0.0 - 2026-01-18
+
+Echo `v5` is maintenance release with **major breaking changes**
+- `Context` is now struct instead of interface and we can add method to it in the future in minor versions.
+- Adds new `Router` interface for possible new routing implementations.
+- Drops old logging interface and uses moderm `log/slog` instead.
+- Rearranges alot of methods/function signatures to make them more consistent.
+
+Upgrade notes and `v4` support:
+- Echo `v4` is supported with **security*** updates and **bug** fixes until **2026-12-31**
+- If you are using Echo in a production environment, it is recommended to wait until after 2026-03-31 before upgrading.
+- Until 2026-03-31, any critical issues requiring breaking `v5` API changes will be addressed, even if this violates semantic versioning.
+
+See [API_CHANGES_V5.md](./API_CHANGES_V5.md) for public API changes between `v4` and `v5`, notes on **upgrading**.
+
+Upgrading TLDR:
+
+If you are using Linux you can migrate easier parts like that:
+```bash
+find . -type f -name "*.go" -exec sed -i 's/ echo.Context/ *echo.Context/g' {} +
+find . -type f -name "*.go" -exec sed -i 's/echo\/v4/echo\/v5/g' {} +
+```
+macOS
+```bash
+find . -type f -name "*.go" -exec sed -i '' 's/ echo.Context/ *echo.Context/g' {} +
+find . -type f -name "*.go" -exec sed -i '' 's/echo\/v4/echo\/v5/g' {} +
+```
+
+or in your favorite IDE
+
+Replace all:
+1. ` echo.Context` -> ` *echo.Context`
+2. `echo/v4` -> `echo/v5`
+
+This should solve most of the issues. Probably the hardest part is updating all the tests.
+
+
+## v4.15.0 - 2026-01-01
+
+
+**Security**
+
+NB: **If your application relies on cross-origin or same-site (same subdomain) requests do not blindly push this version to production**
+
+
+The CSRF middleware now supports the [**Sec-Fetch-Site**](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Sec-Fetch-Site) header as a modern, defense-in-depth approach to [CSRF
+protection](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers), implementing the OWASP-recommended Fetch Metadata API alongside the traditional token-based mechanism.
+
+**How it works:**
+
+Modern browsers automatically send the `Sec-Fetch-Site` header with all requests, indicating the relationship
+between the request origin and the target. The middleware uses this to make security decisions:
+
+- **`same-origin`** or **`none`**: Requests are allowed (exact origin match or direct user navigation)
+- **`same-site`**: Falls back to token validation (e.g., subdomain to main domain)
+- **`cross-site`**: Blocked by default with 403 error for unsafe methods (POST, PUT, DELETE, PATCH)
+
+For browsers that don't send this header (older browsers), the middleware seamlessly falls back to
+traditional token-based CSRF protection.
+
+**New Configuration Options:**
+- `TrustedOrigins []string`: Allowlist specific origins for cross-site requests (useful for OAuth callbacks, webhooks)
+- `AllowSecFetchSiteFunc func(echo.Context) (bool, error)`: Custom logic for same-site/cross-site request validation
+
+**Example:**
+ ```go
+ e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{
+ // Allow OAuth callbacks from trusted provider
+ TrustedOrigins: []string{"https://oauth-provider.com"},
+
+ // Custom validation for same-site requests
+ AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) {
+ // Your custom authorization logic here
+ return validateCustomAuth(c), nil
+ // return true, err // blocks request with error
+ // return true, nil // allows CSRF request through
+ // return false, nil // falls back to legacy token logic
+ },
+ }))
+ ```
+PR: https://github.com/labstack/echo/pull/2858
+
+**Type-Safe Generic Parameter Binding**
+
+* Added generic functions for type-safe parameter extraction and context access by @aldas in https://github.com/labstack/echo/pull/2856
+
+ Echo now provides generic functions for extracting path, query, and form parameters with automatic type conversion,
+ eliminating manual string parsing and type assertions.
+
+ **New Functions:**
+ - Path parameters: `PathParam[T]`, `PathParamOr[T]`
+ - Query parameters: `QueryParam[T]`, `QueryParamOr[T]`, `QueryParams[T]`, `QueryParamsOr[T]`
+ - Form values: `FormParam[T]`, `FormParamOr[T]`, `FormParams[T]`, `FormParamsOr[T]`
+ - Context store: `ContextGet[T]`, `ContextGetOr[T]`
+
+ **Supported Types:**
+ Primitives (`bool`, `string`, `int`/`uint` variants, `float32`/`float64`), `time.Duration`, `time.Time`
+ (with custom layouts and Unix timestamp support), and custom types implementing `BindUnmarshaler`,
+ `TextUnmarshaler`, or `JSONUnmarshaler`.
+
+ **Example:**
+ ```go
+ // Before: Manual parsing
+ idStr := c.Param("id")
+ id, err := strconv.Atoi(idStr)
+
+ // After: Type-safe with automatic parsing
+ id, err := echo.PathParam[int](c, "id")
+
+ // With default values
+ page, err := echo.QueryParamOr[int](c, "page", 1)
+ limit, err := echo.QueryParamOr[int](c, "limit", 20)
+
+ // Type-safe context access (no more panics from type assertions)
+ user, err := echo.ContextGet[*User](c, "user")
+ ```
+
+PR: https://github.com/labstack/echo/pull/2856
+
+
+
+**DEPRECATION NOTICE** Timeout Middleware Deprecated - Use ContextTimeout Instead
+
+The `middleware.Timeout` middleware has been **deprecated** due to fundamental architectural issues that cause
+data races. Use `middleware.ContextTimeout` or `middleware.ContextTimeoutWithConfig` instead.
+
+**Why is this being deprecated?**
+
+The Timeout middleware manipulates response writers across goroutine boundaries, which causes data races that
+cannot be reliably fixed without a complete architectural redesign. The middleware:
+
+- Swaps the response writer using `http.TimeoutHandler`
+- Must be the first middleware in the chain (fragile constraint)
+- Can cause races with other middleware (Logger, metrics, custom middleware)
+- Has been the source of multiple race condition fixes over the years
+
+**What should you use instead?**
+
+The `ContextTimeout` middleware (available since v4.12.0) provides timeout functionality using Go's standard
+context mechanism. It is:
+
+- Race-free by design
+- Can be placed anywhere in the middleware chain
+- Simpler and more maintainable
+- Compatible with all other middleware
+
+**Migration Guide:**
+
+```go
+// Before (deprecated):
+e.Use(middleware.Timeout())
+
+// After (recommended):
+e.Use(middleware.ContextTimeout(30 * time.Second))
+```
+
+**Important Behavioral Differences:**
+
+1. **Handler cooperation required**: With ContextTimeout, your handlers must check `context.Done()` for cooperative
+ cancellation. The old Timeout middleware would send a 503 response regardless of handler cooperation, but had
+ data race issues.
+
+2. **Error handling**: ContextTimeout returns errors through the standard error handling flow. Handlers that receive
+ `context.DeadlineExceeded` should handle it appropriately:
+
+```go
+e.GET("/long-task", func(c echo.Context) error {
+ ctx := c.Request().Context()
+
+ // Example: database query with context
+ result, err := db.QueryContext(ctx, "SELECT * FROM large_table")
+ if err != nil {
+ if errors.Is(err, context.DeadlineExceeded) {
+ // Handle timeout
+ return echo.NewHTTPError(http.StatusServiceUnavailable, "Request timeout")
+ }
+ return err
+ }
+
+ return c.JSON(http.StatusOK, result)
+})
+```
+
+3. **Background tasks**: For long-running background tasks, use goroutines with context:
+
+```go
+e.GET("/async-task", func(c echo.Context) error {
+ ctx := c.Request().Context()
+
+ resultCh := make(chan Result, 1)
+ errCh := make(chan error, 1)
+
+ go func() {
+ result, err := performLongTask(ctx)
+ if err != nil {
+ errCh <- err
+ return
+ }
+ resultCh <- result
+ }()
+
+ select {
+ case result := <-resultCh:
+ return c.JSON(http.StatusOK, result)
+ case err := <-errCh:
+ return err
+ case <-ctx.Done():
+ return echo.NewHTTPError(http.StatusServiceUnavailable, "Request timeout")
+ }
+})
+```
+
+**Enhancements**
+
+* Fixes by @aldas in https://github.com/labstack/echo/pull/2852
+* Generic functions by @aldas in https://github.com/labstack/echo/pull/2856
+* CRSF with Sec-Fetch-Site checks by @aldas in https://github.com/labstack/echo/pull/2858
+
+
+## v4.14.0 - 2025-12-11
+
+`middleware.Logger` has been deprecated. For request logging, use `middleware.RequestLogger` or
+`middleware.RequestLoggerWithConfig`.
+
+`middleware.RequestLogger` replaces `middleware.Logger`, offering comparable configuration while relying on the
+Go standard library’s new `slog` logger.
+
+The previous default output format was JSON. The new default follows the standard `slog` logger settings.
+To continue emitting request logs in JSON, configure `slog` accordingly:
+```go
+slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil)))
+e.Use(middleware.RequestLogger())
+```
+
+
+**Security**
+
+* Logger middleware json string escaping and deprecation by @aldas in https://github.com/labstack/echo/pull/2849
+
+
+
+**Enhancements**
+
+* Update deps by @aldas in https://github.com/labstack/echo/pull/2807
+* refactor to use reflect.TypeFor by @cuiweixie in https://github.com/labstack/echo/pull/2812
+* Use Go 1.25 in CI by @aldas in https://github.com/labstack/echo/pull/2810
+* Modernize context.go by replacing interface{} with any by @vishr in https://github.com/labstack/echo/pull/2822
+* Fix typo in SetParamValues comment by @vishr in https://github.com/labstack/echo/pull/2828
+* Fix typo in ContextTimeout middleware comment by @vishr in https://github.com/labstack/echo/pull/2827
+* Improve BasicAuth middleware: use strings.Cut and RFC compliance by @vishr in https://github.com/labstack/echo/pull/2825
+* Fix duplicate plus operator in router backtracking logic by @yuya-morimoto in https://github.com/labstack/echo/pull/2832
+* Replace custom private IP range check with built-in net.IP.IsPrivate by @kumapower17 in https://github.com/labstack/echo/pull/2835
+* Ensure proxy connection is closed in proxyRaw function(#2837) by @kumapower17 in https://github.com/labstack/echo/pull/2838
+* Update deps by @aldas in https://github.com/labstack/echo/pull/2843
+* Update golang.org/x/* deps by @aldas in https://github.com/labstack/echo/pull/2850
+
+
+
+## v4.13.4 - 2025-05-22
+
+**Enhancements**
+
+* chore: fix some typos in comment by @zhuhaicity in https://github.com/labstack/echo/pull/2735
+* CI: test with Go 1.24 by @aldas in https://github.com/labstack/echo/pull/2748
+* Add support for TLS WebSocket proxy by @t-ibayashi-safie in https://github.com/labstack/echo/pull/2762
+
+**Security**
+
+* Update dependencies for [GO-2025-3487](https://pkg.go.dev/vuln/GO-2025-3487), [GO-2025-3503](https://pkg.go.dev/vuln/GO-2025-3503) and [GO-2025-3595](https://pkg.go.dev/vuln/GO-2025-3595) in https://github.com/labstack/echo/pull/2780
+
+
+## v4.13.3 - 2024-12-19
+
+**Security**
+
+* Update golang.org/x/net dependency [GO-2024-3333](https://pkg.go.dev/vuln/GO-2024-3333) in https://github.com/labstack/echo/pull/2722
+
+
+## v4.13.2 - 2024-12-12
+
+**Security**
+
+* Update dependencies (dependabot reports [GO-2024-3321](https://pkg.go.dev/vuln/GO-2024-3321)) in https://github.com/labstack/echo/pull/2721
+
+
+## v4.13.1 - 2024-12-11
+
+**Fixes**
+
+* Fix BindBody ignoring `Transfer-Encoding: chunked` requests by @178inaba in https://github.com/labstack/echo/pull/2717
+
+
+
+## v4.13.0 - 2024-12-04
+
+**BREAKING CHANGE** JWT Middleware Removed from Core use [labstack/echo-jwt](https://github.com/labstack/echo-jwt) instead
+
+The JWT middleware has been **removed from Echo core** due to another security vulnerability, [CVE-2024-51744](https://nvd.nist.gov/vuln/detail/CVE-2024-51744). For more details, refer to issue [#2699](https://github.com/labstack/echo/issues/2699). A drop-in replacement is available in the [labstack/echo-jwt](https://github.com/labstack/echo-jwt) repository.
+
+**Important**: Direct assignments like `token := c.Get("user").(*jwt.Token)` will now cause a panic due to an invalid cast. Update your code accordingly. Replace the current imports from `"github.com/golang-jwt/jwt"` in your handlers to the new middleware version using `"github.com/golang-jwt/jwt/v5"`.
+
+
+Background:
+
+The version of `golang-jwt/jwt` (v3.2.2) previously used in Echo core has been in an unmaintained state for some time. This is not the first vulnerability affecting this library; earlier issues were addressed in [PR #1946](https://github.com/labstack/echo/pull/1946).
+JWT middleware was marked as deprecated in Echo core as of [v4.10.0](https://github.com/labstack/echo/releases/tag/v4.10.0) on 2022-12-27. If you did not notice that, consider leveraging tools like [Staticcheck](https://staticcheck.dev/) to catch such deprecations earlier in you dev/CI flow. For bonus points - check out [gosec](https://github.com/securego/gosec).
+
+We sincerely apologize for any inconvenience caused by this change. While we strive to maintain backward compatibility within Echo core, recurring security issues with third-party dependencies have forced this decision.
+
+**Enhancements**
+
+* remove jwt middleware by @stevenwhitehead in https://github.com/labstack/echo/pull/2701
+* optimization: struct alignment by @behnambm in https://github.com/labstack/echo/pull/2636
+* bind: Maintain backwards compatibility for map[string]interface{} binding by @thesaltree in https://github.com/labstack/echo/pull/2656
+* Add Go 1.23 to CI by @aldas in https://github.com/labstack/echo/pull/2675
+* improve `MultipartForm` test by @martinyonatann in https://github.com/labstack/echo/pull/2682
+* `bind` : add support of multipart multi files by @martinyonatann in https://github.com/labstack/echo/pull/2684
+* Add TemplateRenderer struct to ease creating renderers for `html/template` and `text/template` packages. by @aldas in https://github.com/labstack/echo/pull/2690
+* Refactor TestBasicAuth to utilize table-driven test format by @ErikOlson in https://github.com/labstack/echo/pull/2688
+* Remove broken header by @aldas in https://github.com/labstack/echo/pull/2705
+* fix(bind body): content-length can be -1 by @phamvinhdat in https://github.com/labstack/echo/pull/2710
+* CORS middleware should compile allowOrigin regexp at creation by @aldas in https://github.com/labstack/echo/pull/2709
+* Shorten Github issue template and add test example by @aldas in https://github.com/labstack/echo/pull/2711
+
+
+## v4.12.0 - 2024-04-15
+
+**Security**
+
+* Update golang.org/x/net dep because of [GO-2024-2687](https://pkg.go.dev/vuln/GO-2024-2687) by @aldas in https://github.com/labstack/echo/pull/2625
+
+
+**Enhancements**
+
+* binder: make binding to Map work better with string destinations by @aldas in https://github.com/labstack/echo/pull/2554
+* README.md: add Encore as sponsor by @marcuskohlberg in https://github.com/labstack/echo/pull/2579
+* Reorder paragraphs in README.md by @aldas in https://github.com/labstack/echo/pull/2581
+* CI: upgrade actions/checkout to v4 by @aldas in https://github.com/labstack/echo/pull/2584
+* Remove default charset from 'application/json' Content-Type header by @doortts in https://github.com/labstack/echo/pull/2568
+* CI: Use Go 1.22 by @aldas in https://github.com/labstack/echo/pull/2588
+* binder: allow binding to a nil map by @georgmu in https://github.com/labstack/echo/pull/2574
+* Add Skipper Unit Test In BasicBasicAuthConfig and Add More Detail Explanation regarding BasicAuthValidator by @RyoKusnadi in https://github.com/labstack/echo/pull/2461
+* fix some typos by @teslaedison in https://github.com/labstack/echo/pull/2603
+* fix: some typos by @pomadev in https://github.com/labstack/echo/pull/2596
+* Allow ResponseWriters to unwrap writers when flushing/hijacking by @aldas in https://github.com/labstack/echo/pull/2595
+* Add SPDX licence comments to files. by @aldas in https://github.com/labstack/echo/pull/2604
+* Upgrade deps by @aldas in https://github.com/labstack/echo/pull/2605
+* Change type definition blocks to single declarations. This helps copy… by @aldas in https://github.com/labstack/echo/pull/2606
+* Fix Real IP logic by @cl-bvl in https://github.com/labstack/echo/pull/2550
+* Default binder can use `UnmarshalParams(params []string) error` inter… by @aldas in https://github.com/labstack/echo/pull/2607
+* Default binder can bind pointer to slice as struct field. For example `*[]string` by @aldas in https://github.com/labstack/echo/pull/2608
+* Remove maxparam dependence from Context by @aldas in https://github.com/labstack/echo/pull/2611
+* When route is registered with empty path it is normalized to `/`. by @aldas in https://github.com/labstack/echo/pull/2616
+* proxy middleware should use httputil.ReverseProxy for SSE requests by @aldas in https://github.com/labstack/echo/pull/2624
+
+
+## v4.11.4 - 2023-12-20
+
+**Security**
+
+* Upgrade golang.org/x/crypto to v0.17.0 to fix vulnerability [issue](https://pkg.go.dev/vuln/GO-2023-2402) [#2562](https://github.com/labstack/echo/pull/2562)
+
+**Enhancements**
+
+* Update deps and mark Go version to 1.18 as this is what golang.org/x/* use [#2563](https://github.com/labstack/echo/pull/2563)
+* Request logger: add example for Slog https://pkg.go.dev/log/slog [#2543](https://github.com/labstack/echo/pull/2543)
+
+
+## v4.11.3 - 2023-11-07
+
+**Security**
+
+* 'c.Attachment' and 'c.Inline' should escape filename in 'Content-Disposition' header to avoid 'Reflect File Download' vulnerability. [#2541](https://github.com/labstack/echo/pull/2541)
+
+**Enhancements**
+
+* Tests: refactor context tests to be separate functions [#2540](https://github.com/labstack/echo/pull/2540)
+* Proxy middleware: reuse echo request context [#2537](https://github.com/labstack/echo/pull/2537)
+* Mark unmarshallable yaml struct tags as ignored [#2536](https://github.com/labstack/echo/pull/2536)
+
+
+## v4.11.2 - 2023-10-11
+
+**Security**
+
+* Bump golang.org/x/net to prevent CVE-2023-39325 / CVE-2023-44487 HTTP/2 Rapid Reset Attack [#2527](https://github.com/labstack/echo/pull/2527)
+* fix(sec): randomString bias introduced by #2490 [#2492](https://github.com/labstack/echo/pull/2492)
+* CSRF/RequestID mw: switch math/random usage to crypto/random [#2490](https://github.com/labstack/echo/pull/2490)
+
+**Enhancements**
+
+* Delete unused context in body_limit.go [#2483](https://github.com/labstack/echo/pull/2483)
+* Use Go 1.21 in CI [#2505](https://github.com/labstack/echo/pull/2505)
+* Fix some typos [#2511](https://github.com/labstack/echo/pull/2511)
+* Allow CORS middleware to send Access-Control-Max-Age: 0 [#2518](https://github.com/labstack/echo/pull/2518)
+* Bump dependancies [#2522](https://github.com/labstack/echo/pull/2522)
+
+## v4.11.1 - 2023-07-16
+
+**Fixes**
+
+* Fix `Gzip` middleware not sending response code for no content responses (404, 301/302 redirects etc) [#2481](https://github.com/labstack/echo/pull/2481)
+
+
+## v4.11.0 - 2023-07-14
+
+
+**Fixes**
+
+* Fixes the proxy middleware concurrency issue of calling the Next() proxy target on Round Robin Balancer [#2409](https://github.com/labstack/echo/pull/2409)
+* Fix `group.RouteNotFound` not working when group has attached middlewares [#2411](https://github.com/labstack/echo/pull/2411)
+* Fix global error handler return error message when message is an error [#2456](https://github.com/labstack/echo/pull/2456)
+* Do not use global timeNow variables [#2477](https://github.com/labstack/echo/pull/2477)
+
+
+**Enhancements**
+
+* Added a optional config variable to disable centralized error handler in recovery middleware [#2410](https://github.com/labstack/echo/pull/2410)
+* refactor: use `strings.ReplaceAll` directly [#2424](https://github.com/labstack/echo/pull/2424)
+* Add support for Go1.20 `http.rwUnwrapper` to Response struct [#2425](https://github.com/labstack/echo/pull/2425)
+* Check whether is nil before invoking centralized error handling [#2429](https://github.com/labstack/echo/pull/2429)
+* Proper colon support in `echo.Reverse` method [#2416](https://github.com/labstack/echo/pull/2416)
+* Fix misuses of a vs an in documentation comments [#2436](https://github.com/labstack/echo/pull/2436)
+* Add link to slog.Handler library for Echo logging into README.md [#2444](https://github.com/labstack/echo/pull/2444)
+* In proxy middleware Support retries of failed proxy requests [#2414](https://github.com/labstack/echo/pull/2414)
+* gofmt fixes to comments [#2452](https://github.com/labstack/echo/pull/2452)
+* gzip response only if it exceeds a minimal length [#2267](https://github.com/labstack/echo/pull/2267)
+* Upgrade packages [#2475](https://github.com/labstack/echo/pull/2475)
+
+
+## v4.10.2 - 2023-02-22
+
+**Security**
+
+* `filepath.Clean` behaviour has changed in Go 1.20 - adapt to it [#2406](https://github.com/labstack/echo/pull/2406)
+* Add `middleware.CORSConfig.UnsafeWildcardOriginWithAllowCredentials` to make UNSAFE usages of wildcard origin + allow cretentials less likely [#2405](https://github.com/labstack/echo/pull/2405)
+
+**Enhancements**
+
+* Add more HTTP error values [#2277](https://github.com/labstack/echo/pull/2277)
+
+
+## v4.10.1 - 2023-02-19
+
+**Security**
+
+* Upgrade deps due to the latest golang.org/x/net vulnerability [#2402](https://github.com/labstack/echo/pull/2402)
+
+
+**Enhancements**
+
+* Add new JWT repository to the README [#2377](https://github.com/labstack/echo/pull/2377)
+* Return an empty string for ctx.path if there is no registered path [#2385](https://github.com/labstack/echo/pull/2385)
+* Add context timeout middleware [#2380](https://github.com/labstack/echo/pull/2380)
+* Update link to jaegertracing [#2394](https://github.com/labstack/echo/pull/2394)
+
+
+## v4.10.0 - 2022-12-27
+
+**Security**
+
+* We are deprecating JWT middleware in this repository. Please use https://github.com/labstack/echo-jwt instead.
+
+ JWT middleware is moved to separate repository to allow us to bump/upgrade version of JWT implementation (`github.com/golang-jwt/jwt`) we are using
+which we can not do in Echo core because this would break backwards compatibility guarantees we try to maintain.
+
+* This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are
+ several vulnerabilities fixed in these libraries.
+
+ Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise.
+
+
+**Enhancements**
+
+* Bump x/text to 0.3.8 [#2305](https://github.com/labstack/echo/pull/2305)
+* Bump dependencies and add notes about Go releases we support [#2336](https://github.com/labstack/echo/pull/2336)
+* Add helper interface for ProxyBalancer interface [#2316](https://github.com/labstack/echo/pull/2316)
+* Expose `middleware.CreateExtractors` function so we can use it from echo-contrib repository [#2338](https://github.com/labstack/echo/pull/2338)
+* Refactor func(Context) error to HandlerFunc [#2315](https://github.com/labstack/echo/pull/2315)
+* Improve function comments [#2329](https://github.com/labstack/echo/pull/2329)
+* Add new method HTTPError.WithInternal [#2340](https://github.com/labstack/echo/pull/2340)
+* Replace io/ioutil package usages [#2342](https://github.com/labstack/echo/pull/2342)
+* Add staticcheck to CI flow [#2343](https://github.com/labstack/echo/pull/2343)
+* Replace relative path determination from proprietary to std [#2345](https://github.com/labstack/echo/pull/2345)
+* Remove square brackets from ipv6 addresses in XFF (X-Forwarded-For header) [#2182](https://github.com/labstack/echo/pull/2182)
+* Add testcases for some BodyLimit middleware configuration options [#2350](https://github.com/labstack/echo/pull/2350)
+* Additional configuration options for RequestLogger and Logger middleware [#2341](https://github.com/labstack/echo/pull/2341)
+* Add route to request log [#2162](https://github.com/labstack/echo/pull/2162)
+* GitHub Workflows security hardening [#2358](https://github.com/labstack/echo/pull/2358)
+* Add govulncheck to CI and bump dependencies [#2362](https://github.com/labstack/echo/pull/2362)
+* Fix rate limiter docs [#2366](https://github.com/labstack/echo/pull/2366)
+* Refactor how `e.Routes()` work and introduce `e.OnAddRouteHandler` callback [#2337](https://github.com/labstack/echo/pull/2337)
+
+
+## v4.9.1 - 2022-10-12
+
+**Fixes**
+
+* Fix logger panicing (when template is set to empty) by bumping dependency version [#2295](https://github.com/labstack/echo/issues/2295)
+
+**Enhancements**
+
+* Improve CORS documentation [#2272](https://github.com/labstack/echo/pull/2272)
+* Update readme about supported Go versions [#2291](https://github.com/labstack/echo/pull/2291)
+* Tests: improve error handling on closing body [#2254](https://github.com/labstack/echo/pull/2254)
+* Tests: refactor some of the assertions in tests [#2275](https://github.com/labstack/echo/pull/2275)
+* Tests: refactor assertions [#2301](https://github.com/labstack/echo/pull/2301)
+
+## v4.9.0 - 2022-09-04
+
+**Security**
+
+* Fix open redirect vulnerability in handlers serving static directories (e.Static, e.StaticFs, echo.StaticDirectoryHandler) [#2260](https://github.com/labstack/echo/pull/2260)
+
+**Enhancements**
+
+* Allow configuring ErrorHandler in CSRF middleware [#2257](https://github.com/labstack/echo/pull/2257)
+* Replace HTTP method constants in tests with stdlib constants [#2247](https://github.com/labstack/echo/pull/2247)
+
+
+## v4.8.0 - 2022-08-10
+
+**Most notable things**
+
+You can now add any arbitrary HTTP method type as a route [#2237](https://github.com/labstack/echo/pull/2237)
+```go
+e.Add("COPY", "/*", func(c echo.Context) error
+ return c.String(http.StatusOK, "OK COPY")
+})
+```
+
+You can add custom 404 handler for specific paths [#2217](https://github.com/labstack/echo/pull/2217)
+```go
+e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })
+
+g := e.Group("/images")
+g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })
+```
+
+**Enhancements**
+
+* Add new value binding methods (UnixTimeMilli,TextUnmarshaler,JSONUnmarshaler) to Valuebinder [#2127](https://github.com/labstack/echo/pull/2127)
+* Refactor: body_limit middleware unit test [#2145](https://github.com/labstack/echo/pull/2145)
+* Refactor: Timeout mw: rework how test waits for timeout. [#2187](https://github.com/labstack/echo/pull/2187)
+* BasicAuth middleware returns 500 InternalServerError on invalid base64 strings but should return 400 [#2191](https://github.com/labstack/echo/pull/2191)
+* Refactor: duplicated findStaticChild process at findChildWithLabel [#2176](https://github.com/labstack/echo/pull/2176)
+* Allow different param names in different methods with same path scheme [#2209](https://github.com/labstack/echo/pull/2209)
+* Add support for registering handlers for different 404 routes [#2217](https://github.com/labstack/echo/pull/2217)
+* Middlewares should use errors.As() instead of type assertion on HTTPError [#2227](https://github.com/labstack/echo/pull/2227)
+* Allow arbitrary HTTP method types to be added as routes [#2237](https://github.com/labstack/echo/pull/2237)
+
+## v4.7.2 - 2022-03-16
+
+**Fixes**
+
+* Fix nil pointer exception when calling Start again after address binding error [#2131](https://github.com/labstack/echo/pull/2131)
+* Fix CSRF middleware not being able to extract token from multipart/form-data form [#2136](https://github.com/labstack/echo/pull/2136)
+* Fix Timeout middleware write race [#2126](https://github.com/labstack/echo/pull/2126)
+
+**Enhancements**
+
+* Recover middleware should not log panic for aborted handler [#2134](https://github.com/labstack/echo/pull/2134)
+
+
+## v4.7.1 - 2022-03-13
+
+**Fixes**
+
+* Fix `e.Static`, `.File()`, `c.Attachment()` being picky with paths starting with `./`, `../` and `/` after 4.7.0 introduced echo.Filesystem support (Go1.16+) [#2123](https://github.com/labstack/echo/pull/2123)
+
+**Enhancements**
+
+* Remove some unused code [#2116](https://github.com/labstack/echo/pull/2116)
+
+
+## v4.7.0 - 2022-03-01
+
+**Enhancements**
+
+* Add JWT, KeyAuth, CSRF multivalue extractors [#2060](https://github.com/labstack/echo/pull/2060)
+* Add LogErrorFunc to recover middleware [#2072](https://github.com/labstack/echo/pull/2072)
+* Add support for HEAD method query params binding [#2027](https://github.com/labstack/echo/pull/2027)
+* Improve filesystem support with echo.FileFS, echo.StaticFS, group.FileFS, group.StaticFS [#2064](https://github.com/labstack/echo/pull/2064)
+
+**Fixes**
+
+* Fix X-Real-IP bug, improve tests [#2007](https://github.com/labstack/echo/pull/2007)
+* Minor syntax fixes [#1994](https://github.com/labstack/echo/pull/1994), [#2102](https://github.com/labstack/echo/pull/2102), [#2102](https://github.com/labstack/echo/pull/2102)
+
+**General**
+
+* Add cache-control and connection headers [#2103](https://github.com/labstack/echo/pull/2103)
+* Add Retry-After header constant [#2078](https://github.com/labstack/echo/pull/2078)
+* Upgrade `go` directive in `go.mod` to 1.17 [#2049](https://github.com/labstack/echo/pull/2049)
+* Add Pagoda [#2077](https://github.com/labstack/echo/pull/2077) and Souin [#2069](https://github.com/labstack/echo/pull/2069) to 3rd-party middlewares in README
+
+## v4.6.3 - 2022-01-10
+
+**Fixes**
+
+* Fixed Echo version number in greeting message which was not incremented to `4.6.2` [#2066](https://github.com/labstack/echo/issues/2066)
+
+
+## v4.6.2 - 2022-01-08
+
+**Fixes**
+
+* Fixed route containing escaped colon should be matchable but is not matched to request path [#2047](https://github.com/labstack/echo/pull/2047)
+* Fixed a problem that returned wrong content-encoding when the gzip compressed content was empty. [#1921](https://github.com/labstack/echo/pull/1921)
+* Update (test) dependencies [#2021](https://github.com/labstack/echo/pull/2021)
+
+
+**Enhancements**
+
+* Add support for configurable target header for the request_id middleware [#2040](https://github.com/labstack/echo/pull/2040)
+* Change decompress middleware to use stream decompression instead of buffering [#2018](https://github.com/labstack/echo/pull/2018)
+* Documentation updates
+
+
+## v4.6.1 - 2021-09-26
+
+**Enhancements**
+
+* Add start time to request logger middleware values [#1991](https://github.com/labstack/echo/pull/1991)
+
+## v4.6.0 - 2021-09-20
+
+Introduced a new [request logger](https://github.com/labstack/echo/blob/master/middleware/request_logger.go) middleware
+to help with cases when you want to use some other logging library in your application.
+
+**Fixes**
+
+* fix timeout middleware warning: superfluous response.WriteHeader [#1905](https://github.com/labstack/echo/issues/1905)
+
+**Enhancements**
+
+* Add Cookie to KeyAuth middleware's KeyLookup [#1929](https://github.com/labstack/echo/pull/1929)
+* JWT middleware should ignore case of auth scheme in request header [#1951](https://github.com/labstack/echo/pull/1951)
+* Refactor default error handler to return first if response is already committed [#1956](https://github.com/labstack/echo/pull/1956)
+* Added request logger middleware which helps to use custom logger library for logging requests. [#1980](https://github.com/labstack/echo/pull/1980)
+* Allow escaping of colon in route path so Google Cloud API "custom methods" could be implemented [#1988](https://github.com/labstack/echo/pull/1988)
+
+## v4.5.0 - 2021-08-01
+
+**Important notes**
+
+A **BREAKING CHANGE** is introduced for JWT middleware users.
+The JWT library used for the JWT middleware had to be changed from [github.com/dgrijalva/jwt-go](https://github.com/dgrijalva/jwt-go) to
+[github.com/golang-jwt/jwt](https://github.com/golang-jwt/jwt) due former library being unmaintained and affected by security
+issues.
+The [github.com/golang-jwt/jwt](https://github.com/golang-jwt/jwt) project is a drop-in replacement, but supports only the latest 2 Go versions.
+So for JWT middleware users Go 1.15+ is required. For detailed information please read [#1940](https://github.com/labstack/echo/discussions/)
+
+To change the library imports in all .go files in your project replace all occurrences of `dgrijalva/jwt-go` with `golang-jwt/jwt`.
+
+For Linux CLI you can use:
+```bash
+find -type f -name "*.go" -exec sed -i "s/dgrijalva\/jwt-go/golang-jwt\/jwt/g" {} \;
+go mod tidy
+```
+
+**Fixes**
+
+* Change JWT library to `github.com/golang-jwt/jwt` [#1946](https://github.com/labstack/echo/pull/1946)
+
+## v4.4.0 - 2021-07-12
+
+**Fixes**
+
+* Split HeaderXForwardedFor header only by comma [#1878](https://github.com/labstack/echo/pull/1878)
+* Fix Timeout middleware Context propagation [#1910](https://github.com/labstack/echo/pull/1910)
+
+**Enhancements**
+
+* Bind data using headers as source [#1866](https://github.com/labstack/echo/pull/1866)
+* Adds JWTConfig.ParseTokenFunc to JWT middleware to allow different libraries implementing JWT parsing. [#1887](https://github.com/labstack/echo/pull/1887)
+* Adding tests for Echo#Host [#1895](https://github.com/labstack/echo/pull/1895)
+* Adds RequestIDHandler function to RequestID middleware [#1898](https://github.com/labstack/echo/pull/1898)
+* Allow for custom JSON encoding implementations [#1880](https://github.com/labstack/echo/pull/1880)
+
+## v4.3.0 - 2021-05-08
+
+**Important notes**
+
+* Route matching has improvements for following cases:
+ 1. Correctly match routes with parameter part as last part of route (with trailing backslash)
+ 2. Considering handlers when resolving routes and search for matching http method handler
+* Echo minimal Go version is now 1.13.
+
+**Fixes**
+
+* When url ends with slash first param route is the match [#1804](https://github.com/labstack/echo/pull/1812)
+* Router should check if node is suitable as matching route by path+method and if not then continue search in tree [#1808](https://github.com/labstack/echo/issues/1808)
+* Fix timeout middleware not writing response correctly when handler panics [#1864](https://github.com/labstack/echo/pull/1864)
+* Fix binder not working with embedded pointer structs [#1861](https://github.com/labstack/echo/pull/1861)
+* Add Go 1.16 to CI and drop 1.12 specific code [#1850](https://github.com/labstack/echo/pull/1850)
+
+**Enhancements**
+
+* Make KeyFunc public in JWT middleware [#1756](https://github.com/labstack/echo/pull/1756)
+* Add support for optional filesystem to the static middleware [#1797](https://github.com/labstack/echo/pull/1797)
+* Add a custom error handler to key-auth middleware [#1847](https://github.com/labstack/echo/pull/1847)
+* Allow JWT token to be looked up from multiple sources [#1845](https://github.com/labstack/echo/pull/1845)
+
+## v4.2.2 - 2021-04-07
+
+**Fixes**
+
+* Allow proxy middleware to use query part in rewrite (#1802)
+* Fix timeout middleware not sending status code when handler returns an error (#1805)
+* Fix Bind() when target is array/slice and path/query params complains bind target not being struct (#1835)
+* Fix panic in redirect middleware on short host name (#1813)
+* Fix timeout middleware docs (#1836)
+
+## v4.2.1 - 2021-03-08
+
+**Important notes**
+
+Due to a datarace the config parameters for the newly added timeout middleware required a change.
+See the [docs](https://echo.labstack.com/middleware/timeout).
+A performance regression has been fixed, even bringing better performance than before for some routing scenarios.
+
+**Fixes**
+
+* Fix performance regression caused by path escaping (#1777, #1798, #1799, aldas)
+* Avoid context canceled errors (#1789, clwluvw)
+* Improve router to use on stack backtracking (#1791, aldas, stffabi)
+* Fix panic in timeout middleware not being not recovered and cause application crash (#1794, aldas)
+* Fix Echo.Serve() not serving on HTTP port correctly when TLSListener is used (#1785, #1793, aldas)
+* Apply go fmt (#1788, Le0tk0k)
+* Uses strings.Equalfold (#1790, rkilingr)
+* Improve code quality (#1792, withshubh)
+
+This release was made possible by our **contributors**:
+aldas, clwluvw, lammel, Le0tk0k, maciej-jezierski, rkilingr, stffabi, withshubh
+
+## v4.2.0 - 2021-02-11
+
+**Important notes**
+
+The behaviour for binding data has been reworked for compatibility with echo before v4.1.11 by
+enforcing `explicit tagging` for processing parameters. This **may break** your code if you
+expect combined handling of query/path/form params.
+Please see the updated documentation for [request](https://echo.labstack.com/guide/request) and [binding](https://echo.labstack.com/guide/request)
+
+The handling for rewrite rules has been slightly adjusted to expand `*` to a non-greedy `(.*?)` capture group. This is only relevant if multiple asterisks are used in your rules.
+Please see [rewrite](https://echo.labstack.com/middleware/rewrite) and [proxy](https://echo.labstack.com/middleware/proxy) for details.
+
+**Security**
+
+* Fix directory traversal vulnerability for Windows (#1718, little-cui)
+* Fix open redirect vulnerability with trailing slash (#1771,#1775 aldas,GeoffreyFrogeye)
+
+**Enhancements**
+
+* Add Echo#ListenerNetwork as configuration (#1667, pafuent)
+* Add ability to change the status code using response beforeFuncs (#1706, RashadAnsari)
+* Echo server startup to allow data race free access to listener address
+* Binder: Restore pre v4.1.11 behaviour for c.Bind() to use query params only for GET or DELETE methods (#1727, aldas)
+* Binder: Add separate methods to bind only query params, path params or request body (#1681, aldas)
+* Binder: New fluent binder for query/path/form parameter binding (#1717, #1736, aldas)
+* Router: Performance improvements for missed routes (#1689, pafuent)
+* Router: Improve performance for Real-IP detection using IndexByte instead of Split (#1640, imxyb)
+* Middleware: Support real regex rules for rewrite and proxy middleware (#1767)
+* Middleware: New rate limiting middleware (#1724, iambenkay)
+* Middleware: New timeout middleware implementation for go1.13+ (#1743, )
+* Middleware: Allow regex pattern for CORS middleware (#1623, KlotzAndrew)
+* Middleware: Add IgnoreBase parameter to static middleware (#1701, lnenad, iambenkay)
+* Middleware: Add an optional custom function to CORS middleware to validate origin (#1651, curvegrid)
+* Middleware: Support form fields in JWT middleware (#1704, rkfg)
+* Middleware: Use sync.Pool for (de)compress middleware to improve performance (#1699, #1672, pafuent)
+* Middleware: Add decompress middleware to support gzip compressed requests (#1687, arun0009)
+* Middleware: Add ErrJWTInvalid for JWT middleware (#1627, juanbelieni)
+* Middleware: Add SameSite mode for CSRF cookies to support iframes (#1524, pr0head)
+
+**Fixes**
+
+* Fix handling of special trailing slash case for partial prefix (#1741, stffabi)
+* Fix handling of static routes with trailing slash (#1747)
+* Fix Static files route not working (#1671, pwli0755, lammel)
+* Fix use of caret(^) in regex for rewrite middleware (#1588, chotow)
+* Fix Echo#Reverse for Any type routes (#1695, pafuent)
+* Fix Router#Find panic with infinite loop (#1661, pafuent)
+* Fix Router#Find panic fails on Param paths (#1659, pafuent)
+* Fix DefaultHTTPErrorHandler with Debug=true (#1477, lammel)
+* Fix incorrect CORS headers (#1669, ulasakdeniz)
+* Fix proxy middleware rewritePath to use url with updated tests (#1630, arun0009)
+* Fix rewritePath for proxy middleware to use escaped path in (#1628, arun0009)
+* Remove unless defer (#1656, imxyb)
+
+**General**
+
+* New maintainers for Echo: Roland Lammel (@lammel) and Pablo Andres Fuente (@pafuent)
+* Add GitHub action to compare benchmarks (#1702, pafuent)
+* Binding query/path params and form fields to struct only works for explicit tags (#1729,#1734, aldas)
+* Add support for Go 1.15 in CI (#1683, asahasrabuddhe)
+* Add test for request id to remain unchanged if provided (#1719, iambenkay)
+* Refactor echo instance listener access and startup to speed up testing (#1735, aldas)
+* Refactor and improve various tests for binding and routing
+* Run test workflow only for relevant changes (#1637, #1636, pofl)
+* Update .travis.yml (#1662, santosh653)
+* Update README.md with an recents framework benchmark (#1679, pafuent)
+
+This release was made possible by **over 100 commits** from more than **20 contributors**:
+asahasrabuddhe, aldas, AndrewKlotz, arun0009, chotow, curvegrid, iambenkay, imxyb,
+juanbelieni, lammel, little-cui, lnenad, pafuent, pofl, pr0head, pwli, RashadAnsari,
+rkfg, santosh653, segfiner, stffabi, ulasakdeniz
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 000000000..decbf0792
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,99 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## About This Project
+
+Echo is a high performance, minimalist Go web framework. This is the main repository for Echo v4, which is available as a Go module at `github.com/labstack/echo/v4`.
+
+## Development Commands
+
+The project uses a Makefile for common development tasks:
+
+- `make check` - Run linting, vetting, and race condition tests (default target)
+- `make init` - Install required linting tools (golint, staticcheck)
+- `make lint` - Run staticcheck and golint
+- `make vet` - Run go vet
+- `make test` - Run short tests
+- `make race` - Run tests with race detector
+- `make benchmark` - Run benchmarks
+
+Example commands for development:
+```bash
+# Setup development environment
+make init
+
+# Run all checks (lint, vet, race)
+make check
+
+# Run specific tests
+go test ./middleware/...
+go test -race ./...
+
+# Run benchmarks
+make benchmark
+```
+
+## Code Architecture
+
+### Core Components
+
+**Echo Instance (`echo.go`)**
+- The `Echo` struct is the top-level framework instance
+- Contains router, middleware stacks, and server configuration
+- Not goroutine-safe for mutations after server start
+
+**Context (`context.go`)**
+- The `Context` interface represents HTTP request/response context
+- Provides methods for request/response handling, path parameters, data binding
+- Core abstraction for request processing
+
+**Router (`router.go`)**
+- Radix tree-based HTTP router with smart route prioritization
+- Supports static routes, parameterized routes (`/users/:id`), and wildcard routes (`/static/*`)
+- Each HTTP method has its own routing tree
+
+**Middleware (`middleware/`)**
+- Extensive middleware system with 50+ built-in middlewares
+- Middleware can be applied at Echo, Group, or individual route level
+- Common middleware: Logger, Recover, CORS, JWT, Rate Limiting, etc.
+
+### Key Patterns
+
+**Middleware Chain**
+- Pre-middleware runs before routing
+- Regular middleware runs after routing but before handlers
+- Middleware functions have signature `func(next echo.HandlerFunc) echo.HandlerFunc`
+
+**Route Groups**
+- Routes can be grouped with common prefixes and middleware
+- Groups support nested sub-groups
+- Defined in `group.go`
+
+**Data Binding**
+- Automatic binding of request data (JSON, XML, form) to Go structs
+- Implemented in `binder.go` with support for custom binders
+
+**Error Handling**
+- Centralized error handling via `HTTPErrorHandler`
+- Automatic panic recovery with stack traces
+
+## File Organization
+
+- Root directory: Core Echo functionality (echo.go, context.go, router.go, etc.)
+- `middleware/`: All built-in middleware implementations
+- `_test/`: Test fixtures and utilities
+- `_fixture/`: Test data files
+
+## Code Style
+
+- Go code uses tabs for indentation (per .editorconfig)
+- Follows standard Go conventions and formatting
+- Uses gofmt, golint, and staticcheck for code quality
+
+## Testing
+
+- Standard Go testing with `testing` package
+- Tests include unit tests, integration tests, and benchmarks
+- Race condition testing is required (`make race`)
+- Test files follow `*_test.go` naming convention
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
index a14f926e5..2f18411bd 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
The MIT License (MIT)
-Copyright (c) 2015 LabStack
+Copyright (c) 2022 LabStack
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
@@ -19,4 +19,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-
diff --git a/Makefile b/Makefile
new file mode 100644
index 000000000..bd075bbae
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,32 @@
+PKG := "github.com/labstack/echo"
+PKG_LIST := $(shell go list ${PKG}/...)
+
+.DEFAULT_GOAL := check
+check: lint vet race ## Check project
+
+init:
+ @go install golang.org/x/lint/golint@latest
+ @go install honnef.co/go/tools/cmd/staticcheck@latest
+
+lint: ## Lint the files
+ @staticcheck ${PKG_LIST}
+ @golint -set_exit_status ${PKG_LIST}
+
+vet: ## Vet the files
+ @go vet ${PKG_LIST}
+
+test: ## Run tests
+ @go test -short ${PKG_LIST}
+
+race: ## Run tests with data race detector
+ @go test -race ${PKG_LIST}
+
+benchmark: ## Run benchmarks
+ @go test -run="-" -benchmem -bench=".*" ${PKG_LIST}
+
+help: ## Display this help screen
+ @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
+
+goversion ?= "1.25"
+test_version: ## Run tests inside Docker with given version (defaults to 1.25 oldest supported). Example: make test_version goversion=1.25
+ @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check"
diff --git a/README.md b/README.md
index 7f8c21744..ca6dfbf5d 100644
--- a/README.md
+++ b/README.md
@@ -1,124 +1,155 @@
-# [Echo](http://echo.labstack.com) [](http://godoc.org/github.com/labstack/echo) [](https://travis-ci.org/labstack/echo) [](https://coveralls.io/r/labstack/echo) [](https://gitter.im/labstack/echo)
-
-A fast and unfancy micro web framework for Golang.
-
-## Features
-
-- Fast HTTP router which smartly prioritize routes.
-- Extensible middleware, supports:
- - `echo.MiddlewareFunc`
- - `func(echo.HandlerFunc) echo.HandlerFunc`
- - `echo.HandlerFunc`
- - `func(*echo.Context) error`
- - `func(http.Handler) http.Handler`
- - `http.Handler`
- - `http.HandlerFunc`
- - `func(http.ResponseWriter, *http.Request)`
-- Extensible handler, supports:
- - `echo.HandlerFunc`
- - `func(*echo.Context) error`
- - `http.Handler`
- - `http.HandlerFunc`
- - `func(http.ResponseWriter, *http.Request)`
-- Sub-router/Groups
-- Handy functions to send variety of HTTP response:
- - HTML
- - HTML via templates
- - String
- - JSON
- - XML
- - NoContent
- - Redirect
- - Error
-- Build-in support for:
- - Favicon
- - Index file
- - Static files
- - WebSocket
-- Centralized HTTP error handling.
-- Customizable HTTP request binding function.
-- Customizable HTTP response rendering function, allowing you to use any HTML template engine.
-
-## Performance
-
-Based on [vishr/go-http-routing-benchmark] (https://github.com/vishr/go-http-routing-benchmark), June 5, 2015.
-
-##### [GitHub API](http://developer.github.com/v3)
-
-> Echo: 38662 ns/op, 0 B/op, 0 allocs/op
-
-
+[](https://sourcegraph.com/github.com/labstack/echo?badge)
+[](https://pkg.go.dev/github.com/labstack/echo/v4)
+[](https://goreportcard.com/report/github.com/labstack/echo)
+[](https://github.com/labstack/echo/actions)
+[](https://codecov.io/gh/labstack/echo)
+[](https://github.com/labstack/echo/discussions)
+[](https://twitter.com/labstack)
+[](https://raw.githubusercontent.com/labstack/echo/master/LICENSE)
-```
-BenchmarkAce_GithubAll 20000 93675 ns/op 13792 B/op 167 allocs/op
-BenchmarkBear_GithubAll 10000 264194 ns/op 79952 B/op 943 allocs/op
-BenchmarkBeego_GithubAll 2000 1109160 ns/op 146272 B/op 2092 allocs/op
-BenchmarkBone_GithubAll 1000 2063973 ns/op 648016 B/op 8119 allocs/op
-BenchmarkDenco_GithubAll 20000 83114 ns/op 20224 B/op 167 allocs/op
-BenchmarkEcho_GithubAll 30000 38662 ns/op 0 B/op 0 allocs/op
-BenchmarkGin_GithubAll 30000 43467 ns/op 0 B/op 0 allocs/op
-BenchmarkGocraftWeb_GithubAll 5000 386829 ns/op 133280 B/op 1889 allocs/op
-BenchmarkGoji_GithubAll 3000 561131 ns/op 56113 B/op 334 allocs/op
-BenchmarkGoJsonRest_GithubAll 3000 490789 ns/op 135995 B/op 2940 allocs/op
-BenchmarkGoRestful_GithubAll 100 15569513 ns/op 797239 B/op 7725 allocs/op
-BenchmarkGorillaMux_GithubAll 200 7431130 ns/op 153137 B/op 1791 allocs/op
-BenchmarkHttpRouter_GithubAll 30000 51192 ns/op 13792 B/op 167 allocs/op
-BenchmarkHttpTreeMux_GithubAll 10000 138164 ns/op 56112 B/op 334 allocs/op
-BenchmarkKocha_GithubAll 10000 139625 ns/op 23304 B/op 843 allocs/op
-BenchmarkMacaron_GithubAll 2000 709932 ns/op 224960 B/op 2315 allocs/op
-BenchmarkMartini_GithubAll 100 10261331 ns/op 237953 B/op 2686 allocs/op
-BenchmarkPat_GithubAll 500 3989686 ns/op 1504104 B/op 32222 allocs/op
-BenchmarkPossum_GithubAll 5000 259165 ns/op 97441 B/op 812 allocs/op
-BenchmarkR2router_GithubAll 10000 240345 ns/op 77328 B/op 1182 allocs/op
-BenchmarkRevel_GithubAll 2000 1203336 ns/op 345554 B/op 5918 allocs/op
-BenchmarkRivet_GithubAll 10000 247213 ns/op 84272 B/op 1079 allocs/op
-BenchmarkTango_GithubAll 5000 379960 ns/op 87081 B/op 2470 allocs/op
-BenchmarkTigerTonic_GithubAll 2000 931401 ns/op 241089 B/op 6052 allocs/op
-BenchmarkTraffic_GithubAll 200 7292170 ns/op 2664770 B/op 22390 allocs/op
-BenchmarkVulcan_GithubAll 5000 271682 ns/op 19894 B/op 609 allocs/op
-BenchmarkZeus_GithubAll 2000 748827 ns/op 300688 B/op 2648 allocs/op
-```
+## Echo
+
+High performance, extensible, minimalist Go web framework.
+
+* [Official website](https://echo.labstack.com)
+* [Quick start](https://echo.labstack.com/docs/quick-start)
+* [Middlewares](https://echo.labstack.com/docs/category/middleware)
+
+Help and questions: [Github Discussions](https://github.com/labstack/echo/discussions)
+
+
+### Feature Overview
+
+- Optimized HTTP router which smartly prioritize routes
+- Build robust and scalable RESTful APIs
+- Group APIs
+- Extensible middleware framework
+- Define middleware at root, group or route level
+- Data binding for JSON, XML and form payload
+- Handy functions to send variety of HTTP responses
+- Centralized HTTP error handling
+- Template rendering with any template engine
+- Define your format for the logger
+- Highly customizable
+- Automatic TLS via Let’s Encrypt
+- HTTP/2 support
+
+## Sponsors
+
+
+
+
+Click [here](https://github.com/sponsors/labstack) for more information on sponsorship.
+
+## [Guide](https://echo.labstack.com/guide)
+
+### Supported Echo versions
+
+- Latest major version of Echo is `v5` as of 2026-01-18.
+ - Until 2026-03-31, any critical issues requiring breaking API changes will be addressed, even if this violates semantic versioning.
+ - See [API_CHANGES_V5.md](./API_CHANGES_V5.md) for public API changes between `v4` and `v5`, notes on upgrading.
+ - If you are using Echo in a production environment, it is recommended to wait until after 2026-03-31 before upgrading.
+- Echo `v4` is supported with **security*** updates and **bug** fixes until **2026-12-31**
-## Installation
+
+### Installation
```sh
-$ go get github.com/labstack/echo
+// go get github.com/labstack/echo/{version}
+go get github.com/labstack/echo/v5
```
+Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions.
+
+
+### Example
+
+```go
+package main
-## [Recipes](https://github.com/labstack/echo/tree/master/recipes)
+import (
+ "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v5/middleware"
+ "log/slog"
+ "net/http"
+)
-- [File Upload](http://echo.labstack.com/recipes/file-upload)
-- [Streaming File Upload](http://echo.labstack.com/recipes/streaming-file-upload)
-- [Streaming Response](http://echo.labstack.com/recipes/streaming-response)
-- [WebSocket](http://echo.labstack.com/recipes/websocket)
-- [Subdomains](http://echo.labstack.com/recipes/subdomains)
-- [JWT Authentication](http://echo.labstack.com/recipes/jwt-authentication)
-- [Graceful Shutdown](http://echo.labstack.com/recipes/graceful-shutdown)
+func main() {
+ // Echo instance
+ e := echo.New()
-##[Guide](http://echo.labstack.com/guide)
+ // Middleware
+ e.Use(middleware.RequestLogger()) // use the RequestLogger middleware with slog logger
+ e.Use(middleware.Recover()) // recover panics as errors for proper error handling
-## Echo System
+ // Routes
+ e.GET("/", hello)
+
+ // Start server
+ if err := e.Start(":8080"); err != nil {
+ slog.Error("failed to start server", "error", err)
+ }
+}
+
+// Handler
+func hello(c *echo.Context) error {
+ return c.String(http.StatusOK, "Hello, World!")
+}
+```
-Community created packages for Echo
+# Official middleware repositories
-- [echo-logrus](https://github.com/deoxxa/echo-logrus)
-- [go_middleware](https://github.com/rightscale/go_middleware)
-- [permissions2](https://github.com/xyproto/permissions2)
-- [permissionbolt](https://github.com/xyproto/permissionbolt)
-- [echo-middleware](https://github.com/syntaqx/echo-middleware)
+Following list of middleware is maintained by Echo team.
+
+| Repository | Description |
+|------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware |
+| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](https://github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares |
+
+# Third-party middleware repositories
+
+Be careful when adding 3rd party middleware. Echo teams does not have time or manpower to guarantee safety and quality
+of middlewares in this list.
+
+| Repository | Description |
+|------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| [oapi-codegen/oapi-codegen](https://github.com/oapi-codegen/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator |
+| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. |
+| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. |
+| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. |
+| [github.com/samber/slog-echo](https://github.com/samber/slog-echo) | Go [slog](https://pkg.go.dev/golang.org/x/exp/slog) logging library wrapper for Echo logger interface. |
+| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. |
+| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. |
+| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code |
+
+Please send a PR to add your own library here.
## Contribute
**Use issues for everything**
-- Report problems
-- Discuss before sending pull request
-- Suggest new features
-- Improve/fix documentation
+- For a small change, just send a PR.
+- For bigger changes open an issue for discussion before sending a PR.
+- PR should have:
+ - Test case
+ - Documentation
+ - Example (If it makes sense)
+- You can also contribute by:
+ - Reporting issues
+ - Suggesting new features or enhancements
+ - Improve/fix documentation
## Credits
-- [Vishal Rana](https://github.com/vishr) - Author
-- [Nitin Rana](https://github.com/nr17) - Consultant
+
+- [Vishal Rana](https://github.com/vishr) (Author)
+- [Nitin Rana](https://github.com/nr17) (Consultant)
+- [Roland Lammel](https://github.com/lammel) (Maintainer)
+- [Martti T.](https://github.com/aldas) (Maintainer)
+- [Pablo Andres Fuente](https://github.com/pafuent) (Maintainer)
- [Contributors](https://github.com/labstack/echo/graphs/contributors)
## License
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 000000000..efb618697
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,15 @@
+# Security Policy
+
+## Supported Versions
+
+| Version | Supported |
+|-----------|-------------------------------------|
+| 5.x.x | :white_check_mark: |
+| >= 4.15.x | :white_check_mark: until 2026.12.31 |
+| < 4.15 | :x: |
+
+## Reporting a Vulnerability
+
+https://github.com/labstack/echo/security/advisories/new
+
+or look for maintainers email(s) in commits and email them.
diff --git a/_fixture/_fixture/README.md b/_fixture/_fixture/README.md
new file mode 100644
index 000000000..21a785851
--- /dev/null
+++ b/_fixture/_fixture/README.md
@@ -0,0 +1 @@
+This directory is used for the static middleware test
\ No newline at end of file
diff --git a/_fixture/certs/README.md b/_fixture/certs/README.md
new file mode 100644
index 000000000..e27d4b139
--- /dev/null
+++ b/_fixture/certs/README.md
@@ -0,0 +1,13 @@
+To generate a valid certificate and private key use the following command:
+
+```bash
+# In OpenSSL ≥ 1.1.1
+openssl req -x509 -newkey rsa:4096 -sha256 -days 9999 -nodes \
+ -keyout key.pem -out cert.pem -subj "/CN=localhost" \
+ -addext "subjectAltName=DNS:localhost,IP:127.0.0.1,IP:::1"
+```
+
+To check a certificate use the following command:
+```bash
+openssl x509 -in cert.pem -text
+```
diff --git a/_fixture/certs/cert.pem b/_fixture/certs/cert.pem
new file mode 100644
index 000000000..d88cf3fec
--- /dev/null
+++ b/_fixture/certs/cert.pem
@@ -0,0 +1,30 @@
+-----BEGIN CERTIFICATE-----
+MIIFODCCAyCgAwIBAgIUaTvDluaMf+VJgYHQ0HFTS3yuCHYwDQYJKoZIhvcNAQEL
+BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIxMDIyNzIxMzQ0MVoXDTQ4MDcx
+NDIxMzQ0MVowFDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEF
+AAOCAg8AMIICCgKCAgEAnqyyAAnWFH2TH7Epj5yfZxYrBvizydZe1Wo/1WpGR2IK
+QT+qIul5sEKX/ERqEOXsawSrL3fw9cuSM8Z2vD/57ZZdoSR7XIdVaMDEQenJ968a
+HObu4D27uBQwIwrM5ELgnd+fC4gis64nIu+2GSfHumZXi7lLW7DbNm8oWkMqI6tY
+2s2wx2hwGYNVJrwSn4WGnkzhQ5U5mkcsLELMx7GR0Qnv6P7sNGZVeqMU7awkcSpR
+crKR1OUP7XCJkEq83WLHSx50+QZv7LiyDmGnujHevRbdSHlcFfHZtaufYat+qICe
+S3XADwRQe/0VSsmja6u3DAHy7VmL8PNisAdkopQZrhiI9OvGrpGZffs9zn+s/jeX
+N1bqVDihCMiEjqXMlHx2oj3AXrZTFxb7y7Ap9C07nf70lpxQWW9SjMYRF98JBiHF
+eJbQkNVkmz6T8ielQbX0l46F2SGK98oyFCGNIAZBUdj5CcS1E6w/lk4t58/em0k7
+3wFC5qg0g0wfIbNSmxljBNxnaBYUqyaaAJJhpaEoOebm4RYV58hQ0FbMfpnLnSh4
+dYStsk6i1PumWoa7D45DTtxF3kH7TB3YOB5aWaNGAPQC1m4Qcd23YB5Rd/ABirSp
+ux6/cFGosjSfJ/G+G0RhNUpmcbDJvFSOhD2WCuieVhCTAzp+VPIA9bSqD+InlT0C
+AwEAAaOBgTB/MB0GA1UdDgQWBBQZyM//SvzYKokQZI/0MVGb6PkH+zAfBgNVHSME
+GDAWgBQZyM//SvzYKokQZI/0MVGb6PkH+zAPBgNVHRMBAf8EBTADAQH/MCwGA1Ud
+EQQlMCOCCWxvY2FsaG9zdIcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG
+9w0BAQsFAAOCAgEAKGAJQmQ/KLw8iMb5QsyxxAonVjJ1eDAhNM3GWdHpM0/GFamO
+vVtATLQQldwDiZJvrsCQPEc8ctZ2Utvg/StLQ3+rZpsvt0+gcUlLJK61qguwYqb2
++T7VK5s7V/OyI/tsuboOW50Pka9vQHV+Z0aM06Yu+HNDAq/UTpEOb/3MQvZd6Ooy
+PTpZtFb/+5jIQa1dIsfFWmpBxF0+wUd9GEkX3j7nekwoZfJ8Ze4GWYERZbOFpDAQ
+rIHdthH5VJztnpQJmaKqzgIOF+Rurwlp5ecSC33xNNjDaYtuf/fiWnoKGhHVSBhT
+61+0yxn3rTgh/Dsm95xY00rSX6lmcvI+kRNTUc8GGPz0ajBH6xyY7bNhfMjmnSW/
+C/XTEDbTAhT7ndWC5vvzp7ZU0TvN+WY6A0f2kxSnnrEk6QRUvRtKkjAkmAFz8exi
+ttBBW0I3E5HNIC5CYRimq/9z+3clM/P1KbNblwuC65bL+PZ+nzFnn5hFaK9eLPol
+OwZQXv7IvAw8GfgLTrEUT7eBCQwe1IqesA7NTxF1BVwmNUb2XamvQZ7ly67QybRw
+0uJq80XjpVjBWYTTQy1dsnC2OTKdqGsV9TVIDR+UGfIG9cxL70pEbiSH2AX+IDCy
+i3kNIvpXgBliAyOjW6Hj1fv6dNfAat/hqEfnquWkfvcs3HNrG/InwpwNAUs=
+-----END CERTIFICATE-----
diff --git a/_fixture/certs/key.pem b/_fixture/certs/key.pem
new file mode 100644
index 000000000..0276c224e
--- /dev/null
+++ b/_fixture/certs/key.pem
@@ -0,0 +1,52 @@
+-----BEGIN PRIVATE KEY-----
+MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQCerLIACdYUfZMf
+sSmPnJ9nFisG+LPJ1l7Vaj/VakZHYgpBP6oi6XmwQpf8RGoQ5exrBKsvd/D1y5Iz
+xna8P/ntll2hJHtch1VowMRB6cn3rxoc5u7gPbu4FDAjCszkQuCd358LiCKzrici
+77YZJ8e6ZleLuUtbsNs2byhaQyojq1jazbDHaHAZg1UmvBKfhYaeTOFDlTmaRyws
+QszHsZHRCe/o/uw0ZlV6oxTtrCRxKlFyspHU5Q/tcImQSrzdYsdLHnT5Bm/suLIO
+Yae6Md69Ft1IeVwV8dm1q59hq36ogJ5LdcAPBFB7/RVKyaNrq7cMAfLtWYvw82Kw
+B2SilBmuGIj068aukZl9+z3Of6z+N5c3VupUOKEIyISOpcyUfHaiPcBetlMXFvvL
+sCn0LTud/vSWnFBZb1KMxhEX3wkGIcV4ltCQ1WSbPpPyJ6VBtfSXjoXZIYr3yjIU
+IY0gBkFR2PkJxLUTrD+WTi3nz96bSTvfAULmqDSDTB8hs1KbGWME3GdoFhSrJpoA
+kmGloSg55ubhFhXnyFDQVsx+mcudKHh1hK2yTqLU+6ZahrsPjkNO3EXeQftMHdg4
+HlpZo0YA9ALWbhBx3bdgHlF38AGKtKm7Hr9wUaiyNJ8n8b4bRGE1SmZxsMm8VI6E
+PZYK6J5WEJMDOn5U8gD1tKoP4ieVPQIDAQABAoICAEHF2CsH6MOpofi7GT08cR7s
+I33KTcxWngzc9ATk/qjMTO/rEf1Sxmx3zkR1n3nNtQhPcR5GG43nin0HwWQbKOCB
+OeJ4GuKp/o9jiHbCEEQpQyvD1jUBofSV+bYs3e2ogy8t6OGA1tGgWPy0XMlkoff0
+QEnczw3864FO5m0z9h2/Ax//r02ZTw5kUEG0KAwT709jEuVO0AfRhM/8CKKmSola
+EyaDtSmrWbdyLlSuzJRUNFrVBno3UTjdM0iqkks6jN3ojBhFwNNhY/1uIXafAXNk
+LOnD1JYMIHCb6X809VWnqvYgozIWWb5rlA3iM2mITmId1LLqMYX5fWj2R5LUzSek
+H+XG+F9FIouTaL1ACoXr0zyeY5N5YJdyXYa1tThdW+axX9ZrnPgeiQrmxzKPIyb7
+LLlVtNBQUg/t5tX80KyYjkNUu4j3oq/uBYPi0m//ovwMyi9bSbbyPT+cDXuXX5Bc
+oY7wyn3evXX0c1R7vdJLZLkLu+ctVex/9hvMjeW/mMasDjLnqY7pF3Skct1SX5N2
+U8YVU9bGvFpLEwM9lmi/T7bcv+zbmGPlfTsZiFrCsixPLn7sX7y5M4L8au8O0jh0
+nHm/8rWVg1Qw0Hobg3tA8FjeMa8Sr2fYmkNLVKFzhuJLxknTJLaUbX5CymNqWP4H
+OctvfSY0nSZ1eQpBkQaJAoIBAQDTb/NhYCfaJBLXHVMy/VYd7kWGZ+I87artcE/l
+8u0pJ8XOP4kp0otFIumpHUFodysAeP6HrI79MuJB40fy91HzWZC+NrPufFFFuZ0z
+Ld1o3Y5nAeoZmMlf1F12Oe3OQZy7nm9eNNkfeoVtKqDv4FhAqk+aoMor86HscKsR
+C6HlZFdGc7kX0ylrQAXPq9KLhcvUU9oAUpbqTbhYK83IebRJgFDG45HkVo9SUHpF
+dmCFSb91eZpRGpdfNLCuLiSu52TebayaUCnceeAt8SyeiChJ/TwWmRRDJS0QUv6h
+s3Wdp+cx9ANoujA4XzAs8Fld5IZ4bcG5jjwD62/tJyWrCC5DAoIBAQDAHfHjrYCK
+GHBrMj+MA7cK7fCJUn/iJLSLGgo2ANYF5oq9gaCwHCtKIyB9DN/KiY0JpJ6PWg+Q
+9Difq23YXiJjNEBS5EFTu9UwWAr1RhSAegrfHxm0sDbcAx31NtDYvBsADCWQYmzc
+KPfBshf5K4g/VCIj2VzC2CE6kNtdhqLU6AV2Pi1Tl1S82xWoAjHy91tDmlFQNWCj
+B2ZnZ7tY9zuwDfeBBOVCPHICgl5Q4PrY1KEWEXiNxgbtkNmOPAsY9WSqgOsP9pWK
+J924gdCCvovINzZtgRisxKth6Fkhra+VCsheg9SWvgR09Deo6CCoSwYxOSb0cjh2
+oyX5Rb1kJ7Z/AoIBAQCX2iNVoBV/GcFeNXV3fXLH9ESCj0FwuNC1zp/TanDhyerK
+gd8k5k2Xzcc66gP73vpHUJ6dGlVni4/r+ivGV9HHkF/f/LGlaiuEhBZel2YY1mZb
+nIhg8dZOuNqW+mvMYlsKdHNPmW0GqpwBF0iWfu1jI+4gA7Kvdj6o7RIvH8eaVEJK
+GvqoHcP1fvmteJ2yDtmhGMfMy4QPqtnmmS8l+CJ/V2SsMuyorXIpkBsAoFAZ6ilT
+WY53CT4F5nWt4v39j7pl9SatfT1TV0SmOjvtb6Rf3zu0jyR6RMzkmHa/839ZRylI
+OxPntzDCi7qxy7yjLmlVPJ6RgZGgzwqHrEHlX+65AoIBAQCEzu6d3x5B2N02LZli
+eFr8MjqbI64GLiulEY5HgNJzZ8k3cjocJI0Ehj36VIEMaYRXSzbVkIO8SCgwsPiR
+n5mUDNX+t441jV62Odbxcc3Qdw226rABieOSupDmKEu92GOt57e8FV5939BOVYhf
+FunsJYQoViXbCEAIVYVgJSfBmNfVwuvgonfQyn8xErtm4/pyRGa71PqGGSKAj2Qi
+/16CuVUFGtZFsLV76JW8wZqHdI4bTF6TW3cEmaLbwcRGL7W0bMSS13rO8/pBh3QW
+PhUxhoGYt6rQHHEBkPa04nXDyZ10QRwgTSGVnBIyMK4KyTpxorm8OI2x7dzdcomX
+iCCPAoIBAETwfr2JKPb/AzrKhhbZgU+sLVn3WH/nb68VheNEmGOzsqXaSHCR2NOq
+/ow7bawjc8yUIhBRzokR4F/7jGolOmfdq0MYFb6/YokssKfv1ugxBhmvOxpZ6F6E
+cERJ8Ex/ffQU053gLR/0ammddVuS1GR5I/jEdP0lJVh0xapoZNUlT5dWYCgo20hY
+ZAmKpU+veyUn+5Li0pmm959vnLK5LJzEA5mpz3w1QPPtVwQs05dwmEV3CRAcCeeh
+8sXp49WNCSW4I3BxuTZzRV845SGIFhZwgVV42PTp2LPKl2p6E7Bk8xpUCCvBpALp
+QmA5yIMx+u2Jpr7fUsXEXEPTEhvjff0=
+-----END PRIVATE KEY-----
diff --git a/_fixture/dist/private.txt b/_fixture/dist/private.txt
new file mode 100644
index 000000000..0f9d2435b
--- /dev/null
+++ b/_fixture/dist/private.txt
@@ -0,0 +1 @@
+private file
diff --git a/_fixture/dist/public/assets/readme.md b/_fixture/dist/public/assets/readme.md
new file mode 100644
index 000000000..50590f554
--- /dev/null
+++ b/_fixture/dist/public/assets/readme.md
@@ -0,0 +1 @@
+readme in assets
diff --git a/_fixture/dist/public/assets/subfolder/subfolder.md b/_fixture/dist/public/assets/subfolder/subfolder.md
new file mode 100644
index 000000000..74c928b2f
--- /dev/null
+++ b/_fixture/dist/public/assets/subfolder/subfolder.md
@@ -0,0 +1 @@
+file inside subfolder
diff --git a/_fixture/dist/public/index.html b/_fixture/dist/public/index.html
new file mode 100644
index 000000000..df6d9015a
--- /dev/null
+++ b/_fixture/dist/public/index.html
@@ -0,0 +1 @@
+Hello from index
diff --git a/_fixture/dist/public/test.txt b/_fixture/dist/public/test.txt
new file mode 100644
index 000000000..dd937160d
--- /dev/null
+++ b/_fixture/dist/public/test.txt
@@ -0,0 +1 @@
+test.txt contents
diff --git a/examples/website/public/favicon.ico b/_fixture/favicon.ico
similarity index 100%
rename from examples/website/public/favicon.ico
rename to _fixture/favicon.ico
diff --git a/_fixture/folder/index.html b/_fixture/folder/index.html
new file mode 100644
index 000000000..9b07a7588
--- /dev/null
+++ b/_fixture/folder/index.html
@@ -0,0 +1,9 @@
+
+
+
+
+ Echo
+
+
+
+
diff --git a/_fixture/images/walle.png b/_fixture/images/walle.png
new file mode 100644
index 000000000..493985d4a
Binary files /dev/null and b/_fixture/images/walle.png differ
diff --git a/_fixture/index.html b/_fixture/index.html
new file mode 100644
index 000000000..9b07a7588
--- /dev/null
+++ b/_fixture/index.html
@@ -0,0 +1,9 @@
+
+
+
+
+ Echo
+
+
+
+
diff --git a/bind.go b/bind.go
new file mode 100644
index 000000000..050e8973b
--- /dev/null
+++ b/bind.go
@@ -0,0 +1,472 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "encoding"
+ "encoding/xml"
+ "errors"
+ "mime/multipart"
+ "net/http"
+ "reflect"
+ "strconv"
+ "strings"
+ "time"
+)
+
+// Binder is the interface that wraps the Bind method.
+type Binder interface {
+ Bind(c *Context, target any) error
+}
+
+// DefaultBinder is the default implementation of the Binder interface.
+type DefaultBinder struct{}
+
+// BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
+// Types that don't implement this, but do implement encoding.TextUnmarshaler
+// will use that interface instead.
+type BindUnmarshaler interface {
+ // UnmarshalParam decodes and assigns a value from an form or query param.
+ UnmarshalParam(param string) error
+}
+
+// bindMultipleUnmarshaler is used by binder to unmarshal multiple values from request at once to
+// type implementing this interface. For example request could have multiple query fields `?a=1&a=2&b=test` in that case
+// for `a` following slice `["1", "2"] will be passed to unmarshaller.
+type bindMultipleUnmarshaler interface {
+ UnmarshalParams(params []string) error
+}
+
+// BindPathValues binds path parameter values to bindable object
+func BindPathValues(c *Context, target any) error {
+ params := map[string][]string{}
+ for _, param := range c.PathValues() {
+ params[param.Name] = []string{param.Value}
+ }
+ if err := bindData(target, params, "param", nil); err != nil {
+ return ErrBadRequest.Wrap(err)
+ }
+ return nil
+}
+
+// BindQueryParams binds query params to bindable object
+func BindQueryParams(c *Context, target any) error {
+ if err := bindData(target, c.QueryParams(), "query", nil); err != nil {
+ return ErrBadRequest.Wrap(err)
+ }
+ return nil
+}
+
+// BindBody binds request body contents to bindable object
+// NB: then binding forms take note that this implementation uses standard library form parsing
+// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm
+// See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm
+// See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm
+func BindBody(c *Context, target any) (err error) {
+ req := c.Request()
+ if req.ContentLength == 0 {
+ return
+ }
+
+ // mediatype is found like `mime.ParseMediaType()` does it
+ base, _, _ := strings.Cut(req.Header.Get(HeaderContentType), ";")
+ mediatype := strings.TrimSpace(base)
+
+ switch mediatype {
+ case MIMEApplicationJSON:
+ if err = c.Echo().JSONSerializer.Deserialize(c, target); err != nil {
+ var hErr *HTTPError
+ if errors.As(err, &hErr) {
+ return err
+ }
+ return ErrBadRequest.Wrap(err)
+ }
+ case MIMEApplicationXML, MIMETextXML:
+ if err = xml.NewDecoder(req.Body).Decode(target); err != nil {
+ return ErrBadRequest.Wrap(err)
+ }
+ case MIMEApplicationForm:
+ params, err := c.FormValues()
+ if err != nil {
+ return ErrBadRequest.Wrap(err)
+ }
+ if err = bindData(target, params, "form", nil); err != nil {
+ return ErrBadRequest.Wrap(err)
+ }
+ case MIMEMultipartForm:
+ params, err := c.MultipartForm()
+ if err != nil {
+ return ErrBadRequest.Wrap(err)
+ }
+ if err = bindData(target, params.Value, "form", params.File); err != nil {
+ return ErrBadRequest.Wrap(err)
+ }
+ default:
+ return &HTTPError{Code: http.StatusUnsupportedMediaType}
+ }
+ return nil
+}
+
+// BindHeaders binds HTTP headers to a bindable object
+func BindHeaders(c *Context, target any) error {
+ if err := bindData(target, c.Request().Header, "header", nil); err != nil {
+ return ErrBadRequest.Wrap(err)
+ }
+ return nil
+}
+
+// Bind implements the `Binder#Bind` function.
+// Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous
+// step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathValues.
+func (b *DefaultBinder) Bind(c *Context, target any) error {
+ if err := BindPathValues(c, target); err != nil {
+ return err
+ }
+ // Only bind query parameters for GET/DELETE/HEAD to avoid unexpected behavior with destination struct binding from body.
+ // For example a request URL `&id=1&lang=en` with body `{"id":100,"lang":"de"}` would lead to precedence issues.
+ // The HTTP method check restores pre-v4.1.11 behavior to avoid these problems (see issue #1670)
+ method := c.Request().Method
+ if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead {
+ if err := BindQueryParams(c, target); err != nil {
+ return err
+ }
+ }
+ return BindBody(c, target)
+}
+
+// bindData will bind data ONLY fields in destination struct that have EXPLICIT tag
+func bindData(destination any, data map[string][]string, tag string, dataFiles map[string][]*multipart.FileHeader) error {
+ if destination == nil || (len(data) == 0 && len(dataFiles) == 0) {
+ return nil
+ }
+ hasFiles := len(dataFiles) > 0
+ typ := reflect.TypeOf(destination).Elem()
+ val := reflect.ValueOf(destination).Elem()
+
+ // Support binding to limited Map destinations:
+ // - map[string][]string,
+ // - map[string]string <-- (binds first value from data slice)
+ // - map[string]any
+ // You are better off binding to struct but there are user who want this map feature. Source of data for these cases are:
+ // params,query,header,form as these sources produce string values, most of the time slice of strings, actually.
+ if typ.Kind() == reflect.Map && typ.Key().Kind() == reflect.String {
+ k := typ.Elem().Kind()
+ isElemInterface := k == reflect.Interface
+ isElemString := k == reflect.String
+ isElemSliceOfStrings := k == reflect.Slice && typ.Elem().Elem().Kind() == reflect.String
+ if !(isElemSliceOfStrings || isElemString || isElemInterface) {
+ return nil
+ }
+ if val.IsNil() {
+ val.Set(reflect.MakeMap(typ))
+ }
+ for k, v := range data {
+ if isElemString {
+ val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0]))
+ } else if isElemInterface {
+ // To maintain backward compatibility, we always bind to the first string value
+ // and not the slice of strings when dealing with map[string]any{}
+ val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0]))
+ } else {
+ val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v))
+ }
+ }
+ return nil
+ }
+
+ // !struct
+ if typ.Kind() != reflect.Struct {
+ if tag == "param" || tag == "query" || tag == "header" {
+ // incompatible type, data is probably to be found in the body
+ return nil
+ }
+ return errors.New("binding element must be a struct")
+ }
+
+ for i := 0; i < typ.NumField(); i++ { // iterate over all destination fields
+ typeField := typ.Field(i)
+ structField := val.Field(i)
+ if typeField.Anonymous {
+ if structField.Kind() == reflect.Ptr {
+ structField = structField.Elem()
+ }
+ }
+ if !structField.CanSet() {
+ continue
+ }
+ structFieldKind := structField.Kind()
+ inputFieldName := typeField.Tag.Get(tag)
+ if typeField.Anonymous && structFieldKind == reflect.Struct && inputFieldName != "" {
+ // if anonymous struct with query/param/form tags, report an error
+ return errors.New("query/param/form tags are not allowed with anonymous struct field")
+ }
+
+ if inputFieldName == "" {
+ // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contain fields with tags).
+ // structs that implement BindUnmarshaler are bound only when they have explicit tag
+ if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct {
+ if err := bindData(structField.Addr().Interface(), data, tag, dataFiles); err != nil {
+ return err
+ }
+ }
+ // does not have explicit tag and is not an ordinary struct - so move to next field
+ continue
+ }
+
+ if hasFiles {
+ if ok, err := isFieldMultipartFile(structField.Type()); err != nil {
+ return err
+ } else if ok {
+ if ok := setMultipartFileHeaderTypes(structField, inputFieldName, dataFiles); ok {
+ continue
+ }
+ }
+ }
+
+ inputValue, exists := data[inputFieldName]
+ if !exists {
+ // Go json.Unmarshal supports case-insensitive binding. However the
+ // url params are bound case-sensitive which is inconsistent. To
+ // fix this we must check all of the map values in a
+ // case-insensitive search.
+ for k, v := range data {
+ if strings.EqualFold(k, inputFieldName) {
+ inputValue = v
+ exists = true
+ break
+ }
+ }
+ }
+
+ if !exists {
+ continue
+ }
+
+ // NOTE: algorithm here is not particularly sophisticated. It probably does not work with absurd types like `**[]*int`
+ // but it is smart enough to handle niche cases like `*int`,`*[]string`,`[]*int` .
+
+ // try unmarshalling first, in case we're dealing with an alias to an array type
+ if ok, err := unmarshalInputsToField(typeField.Type.Kind(), inputValue, structField); ok {
+ if err != nil {
+ return err
+ }
+ continue
+ }
+
+ formatTag := typeField.Tag.Get("format")
+ if ok, err := unmarshalInputToField(typeField.Type.Kind(), inputValue[0], structField, formatTag); ok {
+ if err != nil {
+ return err
+ }
+ continue
+ }
+
+ // we could be dealing with pointer to slice `*[]string` so dereference it. There are weird OpenAPI generators
+ // that could create struct fields like that.
+ if structFieldKind == reflect.Pointer {
+ structFieldKind = structField.Elem().Kind()
+ structField = structField.Elem()
+ }
+
+ if structFieldKind == reflect.Slice {
+ sliceOf := structField.Type().Elem().Kind()
+ numElems := len(inputValue)
+ slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
+ for j := 0; j < numElems; j++ {
+ if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil {
+ return err
+ }
+ }
+ structField.Set(slice)
+ continue
+ }
+
+ if err := setWithProperType(structFieldKind, inputValue[0], structField); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
+ // But also call it here, in case we're dealing with an array of BindUnmarshalers
+ // Note: format tag not available in this context, so empty string is passed
+ if ok, err := unmarshalInputToField(valueKind, val, structField, ""); ok {
+ return err
+ }
+
+ switch valueKind {
+ case reflect.Ptr:
+ return setWithProperType(structField.Elem().Kind(), val, structField.Elem())
+ case reflect.Int:
+ return setIntField(val, 0, structField)
+ case reflect.Int8:
+ return setIntField(val, 8, structField)
+ case reflect.Int16:
+ return setIntField(val, 16, structField)
+ case reflect.Int32:
+ return setIntField(val, 32, structField)
+ case reflect.Int64:
+ return setIntField(val, 64, structField)
+ case reflect.Uint:
+ return setUintField(val, 0, structField)
+ case reflect.Uint8:
+ return setUintField(val, 8, structField)
+ case reflect.Uint16:
+ return setUintField(val, 16, structField)
+ case reflect.Uint32:
+ return setUintField(val, 32, structField)
+ case reflect.Uint64:
+ return setUintField(val, 64, structField)
+ case reflect.Bool:
+ return setBoolField(val, structField)
+ case reflect.Float32:
+ return setFloatField(val, 32, structField)
+ case reflect.Float64:
+ return setFloatField(val, 64, structField)
+ case reflect.String:
+ structField.SetString(val)
+ default:
+ return errors.New("unknown type")
+ }
+ return nil
+}
+
+func unmarshalInputsToField(valueKind reflect.Kind, values []string, field reflect.Value) (bool, error) {
+ if valueKind == reflect.Ptr {
+ if field.IsNil() {
+ field.Set(reflect.New(field.Type().Elem()))
+ }
+ field = field.Elem()
+ }
+
+ fieldIValue := field.Addr().Interface()
+ unmarshaler, ok := fieldIValue.(bindMultipleUnmarshaler)
+ if !ok {
+ return false, nil
+ }
+ return true, unmarshaler.UnmarshalParams(values)
+}
+
+func unmarshalInputToField(valueKind reflect.Kind, val string, field reflect.Value, formatTag string) (bool, error) {
+ if valueKind == reflect.Ptr {
+ if field.IsNil() {
+ field.Set(reflect.New(field.Type().Elem()))
+ }
+ field = field.Elem()
+ }
+
+ fieldIValue := field.Addr().Interface()
+ // Handle time.Time with custom format tag
+ if formatTag != "" {
+ if _, isTime := fieldIValue.(*time.Time); isTime {
+ t, err := time.Parse(formatTag, val)
+ if err != nil {
+ return true, err
+ }
+ field.Set(reflect.ValueOf(t))
+ return true, nil
+ }
+ }
+
+ switch unmarshaler := fieldIValue.(type) {
+ case BindUnmarshaler:
+ return true, unmarshaler.UnmarshalParam(val)
+ case encoding.TextUnmarshaler:
+ return true, unmarshaler.UnmarshalText([]byte(val))
+ }
+
+ return false, nil
+}
+
+func setIntField(value string, bitSize int, field reflect.Value) error {
+ if value == "" {
+ value = "0"
+ }
+ intVal, err := strconv.ParseInt(value, 10, bitSize)
+ if err == nil {
+ field.SetInt(intVal)
+ }
+ return err
+}
+
+func setUintField(value string, bitSize int, field reflect.Value) error {
+ if value == "" {
+ value = "0"
+ }
+ uintVal, err := strconv.ParseUint(value, 10, bitSize)
+ if err == nil {
+ field.SetUint(uintVal)
+ }
+ return err
+}
+
+func setBoolField(value string, field reflect.Value) error {
+ if value == "" {
+ value = "false"
+ }
+ boolVal, err := strconv.ParseBool(value)
+ if err == nil {
+ field.SetBool(boolVal)
+ }
+ return err
+}
+
+func setFloatField(value string, bitSize int, field reflect.Value) error {
+ if value == "" {
+ value = "0.0"
+ }
+ floatVal, err := strconv.ParseFloat(value, bitSize)
+ if err == nil {
+ field.SetFloat(floatVal)
+ }
+ return err
+}
+
+var (
+ // NOT supported by bind as you can NOT check easily empty struct being actual file or not
+ multipartFileHeaderType = reflect.TypeFor[multipart.FileHeader]()
+ // supported by bind as you can check by nil value if file existed or not
+ multipartFileHeaderPointerType = reflect.TypeFor[*multipart.FileHeader]()
+ multipartFileHeaderSliceType = reflect.TypeFor[[]multipart.FileHeader]()
+ multipartFileHeaderPointerSliceType = reflect.TypeFor[[]*multipart.FileHeader]()
+)
+
+func isFieldMultipartFile(field reflect.Type) (bool, error) {
+ switch field {
+ case multipartFileHeaderPointerType,
+ multipartFileHeaderSliceType,
+ multipartFileHeaderPointerSliceType:
+ return true, nil
+ case multipartFileHeaderType:
+ return true, errors.New("binding to multipart.FileHeader struct is not supported, use pointer to struct")
+ default:
+ return false, nil
+ }
+}
+
+func setMultipartFileHeaderTypes(structField reflect.Value, inputFieldName string, files map[string][]*multipart.FileHeader) bool {
+ fileHeaders := files[inputFieldName]
+ if len(fileHeaders) == 0 {
+ return false
+ }
+
+ result := true
+ switch structField.Type() {
+ case multipartFileHeaderPointerSliceType:
+ structField.Set(reflect.ValueOf(fileHeaders))
+ case multipartFileHeaderSliceType:
+ headers := make([]multipart.FileHeader, len(fileHeaders))
+ for i, fileHeader := range fileHeaders {
+ headers[i] = *fileHeader
+ }
+ structField.Set(reflect.ValueOf(headers))
+ case multipartFileHeaderPointerType:
+ structField.Set(reflect.ValueOf(fileHeaders[0]))
+ default:
+ result = false
+ }
+
+ return result
+}
diff --git a/bind_test.go b/bind_test.go
new file mode 100644
index 000000000..1d5f8ca41
--- /dev/null
+++ b/bind_test.go
@@ -0,0 +1,1693 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "bytes"
+ "encoding/json"
+ "encoding/xml"
+ "errors"
+ "fmt"
+ "io"
+ "mime/multipart"
+ "net/http"
+ "net/http/httptest"
+ "net/http/httputil"
+ "net/url"
+ "reflect"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+type bindTestStruct struct {
+ T Timestamp
+ GoT time.Time
+ PtrI16 *int16
+ PtrUI *uint
+ Tptr *Timestamp
+ PtrF32 *float32
+ PtrB *bool
+ PtrI32 *int32
+ GoTptr *time.Time
+ PtrI64 *int64
+ PtrI *int
+ PtrI8 *int8
+ PtrF64 *float64
+ PtrUI8 *uint8
+ PtrUI64 *uint64
+ PtrUI16 *uint16
+ PtrS *string
+ PtrUI32 *uint32
+ S string
+ cantSet string
+ DoesntExist string
+ SA StringArray
+ F64 float64
+ I int
+ UI64 uint64
+ UI uint
+ I64 int64
+ F32 float32
+ UI32 uint32
+ I32 int32
+ UI16 uint16
+ I16 int16
+ B bool
+ UI8 uint8
+ I8 int8
+}
+
+type bindTestStructWithTags struct {
+ T Timestamp `json:"T" form:"T"`
+ GoT time.Time `json:"GoT" form:"GoT"`
+ PtrI16 *int16 `json:"PtrI16" form:"PtrI16"`
+ PtrUI *uint `json:"PtrUI" form:"PtrUI"`
+ Tptr *Timestamp `json:"Tptr" form:"Tptr"`
+ PtrF32 *float32 `json:"PtrF32" form:"PtrF32"`
+ PtrB *bool `json:"PtrB" form:"PtrB"`
+ PtrI32 *int32 `json:"PtrI32" form:"PtrI32"`
+ GoTptr *time.Time `json:"GoTptr" form:"GoTptr"`
+ PtrI64 *int64 `json:"PtrI64" form:"PtrI64"`
+ PtrI *int `json:"PtrI" form:"PtrI"`
+ PtrI8 *int8 `json:"PtrI8" form:"PtrI8"`
+ PtrF64 *float64 `json:"PtrF64" form:"PtrF64"`
+ PtrUI8 *uint8 `json:"PtrUI8" form:"PtrUI8"`
+ PtrUI64 *uint64 `json:"PtrUI64" form:"PtrUI64"`
+ PtrUI16 *uint16 `json:"PtrUI16" form:"PtrUI16"`
+ PtrS *string `json:"PtrS" form:"PtrS"`
+ PtrUI32 *uint32 `json:"PtrUI32" form:"PtrUI32"`
+ S string `json:"S" form:"S"`
+ cantSet string
+ DoesntExist string `json:"DoesntExist" form:"DoesntExist"`
+ SA StringArray `json:"SA" form:"SA"`
+ F64 float64 `json:"F64" form:"F64"`
+ I int `json:"I" form:"I"`
+ UI64 uint64 `json:"UI64" form:"UI64"`
+ UI uint `json:"UI" form:"UI"`
+ I64 int64 `json:"I64" form:"I64"`
+ F32 float32 `json:"F32" form:"F32"`
+ UI32 uint32 `json:"UI32" form:"UI32"`
+ I32 int32 `json:"I32" form:"I32"`
+ UI16 uint16 `json:"UI16" form:"UI16"`
+ I16 int16 `json:"I16" form:"I16"`
+ B bool `json:"B" form:"B"`
+ UI8 uint8 `json:"UI8" form:"UI8"`
+ I8 int8 `json:"I8" form:"I8"`
+}
+
+type Timestamp time.Time
+type TA []Timestamp
+type StringArray []string
+type Struct struct {
+ Foo string
+}
+type Bar struct {
+ Baz int `json:"baz" query:"baz"`
+}
+
+func (t *Timestamp) UnmarshalParam(src string) error {
+ ts, err := time.Parse(time.RFC3339, src)
+ *t = Timestamp(ts)
+ return err
+}
+
+func (a *StringArray) UnmarshalParam(src string) error {
+ *a = StringArray(strings.Split(src, ","))
+ return nil
+}
+
+func (s *Struct) UnmarshalParam(src string) error {
+ *s = Struct{
+ Foo: src,
+ }
+ return nil
+}
+
+func (t bindTestStruct) GetCantSet() string {
+ return t.cantSet
+}
+
+var values = map[string][]string{
+ "I": {"0"},
+ "PtrI": {"0"},
+ "I8": {"8"},
+ "PtrI8": {"8"},
+ "I16": {"16"},
+ "PtrI16": {"16"},
+ "I32": {"32"},
+ "PtrI32": {"32"},
+ "I64": {"64"},
+ "PtrI64": {"64"},
+ "UI": {"0"},
+ "PtrUI": {"0"},
+ "UI8": {"8"},
+ "PtrUI8": {"8"},
+ "UI16": {"16"},
+ "PtrUI16": {"16"},
+ "UI32": {"32"},
+ "PtrUI32": {"32"},
+ "UI64": {"64"},
+ "PtrUI64": {"64"},
+ "B": {"true"},
+ "PtrB": {"true"},
+ "F32": {"32.5"},
+ "PtrF32": {"32.5"},
+ "F64": {"64.5"},
+ "PtrF64": {"64.5"},
+ "S": {"test"},
+ "PtrS": {"test"},
+ "cantSet": {"test"},
+ "T": {"2016-12-06T19:09:05+01:00"},
+ "Tptr": {"2016-12-06T19:09:05+01:00"},
+ "GoT": {"2016-12-06T19:09:05+01:00"},
+ "GoTptr": {"2016-12-06T19:09:05+01:00"},
+ "ST": {"bar"},
+}
+
+// ptr return pointer to value. This is useful as `v := []*int8{&int8(1)}` will not compile
+func ptr[T any](value T) *T {
+ return &value
+}
+
+func TestToMultipleFields(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/?id=1&ID=2", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ type Root struct {
+ ID int64 `query:"id"`
+ Child2 struct {
+ ID int64
+ }
+ Child1 struct {
+ ID int64 `query:"id"`
+ }
+ }
+
+ u := new(Root)
+ err := c.Bind(u)
+ if assert.NoError(t, err) {
+ assert.Equal(t, int64(1), u.ID) // perfectly reasonable
+ assert.Equal(t, int64(1), u.Child1.ID) // untagged struct containing tagged field gets filled (by tag)
+ assert.Equal(t, int64(0), u.Child2.ID) // untagged struct containing untagged field should not be bind
+ }
+}
+
+func TestBindJSON(t *testing.T) {
+ testBindOkay(t, strings.NewReader(userJSON), nil, MIMEApplicationJSON)
+ testBindOkay(t, strings.NewReader(userJSON), dummyQuery, MIMEApplicationJSON)
+ testBindArrayOkay(t, strings.NewReader(usersJSON), nil, MIMEApplicationJSON)
+ testBindArrayOkay(t, strings.NewReader(usersJSON), dummyQuery, MIMEApplicationJSON)
+ testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{})
+ testBindError(t, strings.NewReader(userJSONInvalidType), MIMEApplicationJSON, &json.UnmarshalTypeError{})
+}
+
+func TestBindXML(t *testing.T) {
+ testBindOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML)
+ testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML)
+ testBindArrayOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML)
+ testBindArrayOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML)
+ testBindError(t, strings.NewReader(invalidContent), MIMEApplicationXML, errors.New(""))
+ testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMEApplicationXML, &strconv.NumError{})
+ testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMEApplicationXML, &xml.SyntaxError{})
+ testBindOkay(t, strings.NewReader(userXML), nil, MIMETextXML)
+ testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMETextXML)
+ testBindError(t, strings.NewReader(invalidContent), MIMETextXML, errors.New(""))
+ testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMETextXML, &strconv.NumError{})
+ testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMETextXML, &xml.SyntaxError{})
+}
+
+func TestBindForm(t *testing.T) {
+
+ testBindOkay(t, strings.NewReader(userForm), nil, MIMEApplicationForm)
+ testBindOkay(t, strings.NewReader(userForm), dummyQuery, MIMEApplicationForm)
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userForm))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ req.Header.Set(HeaderContentType, MIMEApplicationForm)
+ err := c.Bind(&[]struct{ Field string }{})
+ assert.Error(t, err)
+}
+
+func TestBindQueryParams(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ u := new(user)
+ err := c.Bind(u)
+ if assert.NoError(t, err) {
+ assert.Equal(t, 1, u.ID)
+ assert.Equal(t, "Jon Snow", u.Name)
+ }
+}
+
+func TestBindQueryParamsCaseInsensitive(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/?ID=1&NAME=Jon+Snow", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ u := new(user)
+ err := c.Bind(u)
+ if assert.NoError(t, err) {
+ assert.Equal(t, 1, u.ID)
+ assert.Equal(t, "Jon Snow", u.Name)
+ }
+}
+
+func TestBindQueryParamsCaseSensitivePrioritized(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/?id=1&ID=2&NAME=Jon+Snow&name=Jon+Doe", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ u := new(user)
+ err := c.Bind(u)
+ if assert.NoError(t, err) {
+ assert.Equal(t, 1, u.ID)
+ assert.Equal(t, "Jon Doe", u.Name)
+ }
+}
+
+func TestBindHeaderParam(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set("Name", "Jon Doe")
+ req.Header.Set("Id", "2")
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ u := new(user)
+ err := BindHeaders(c, u)
+ if assert.NoError(t, err) {
+ assert.Equal(t, 2, u.ID)
+ assert.Equal(t, "Jon Doe", u.Name)
+ }
+}
+
+func TestBindHeaderParamBadType(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set("Id", "salamander")
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ u := new(user)
+ err := BindHeaders(c, u)
+ assert.Error(t, err)
+
+ httpErr, ok := err.(*HTTPError)
+ if assert.True(t, ok) {
+ assert.Equal(t, http.StatusBadRequest, httpErr.Code)
+ }
+}
+
+func TestBindUnmarshalParam(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ result := struct {
+ T Timestamp `query:"ts"`
+ ST Struct
+ StWithTag struct {
+ Foo string `query:"st"`
+ }
+ TA []Timestamp `query:"ta"`
+ SA StringArray `query:"sa"`
+ }{}
+ err := c.Bind(&result)
+ ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC))
+
+ if assert.NoError(t, err) {
+ // assert.Equal( Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T)
+ assert.Equal(t, ts, result.T)
+ assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA)
+ assert.Equal(t, []Timestamp{ts, ts}, result.TA)
+ assert.Equal(t, Struct{""}, result.ST) // child struct does not have a field with matching tag
+ assert.Equal(t, "baz", result.StWithTag.Foo) // child struct has field with matching tag
+ }
+}
+
+func TestBindUnmarshalText(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ result := struct {
+ T time.Time `query:"ts"`
+ ST Struct
+ TA []time.Time `query:"ta"`
+ SA StringArray `query:"sa"`
+ }{}
+ err := c.Bind(&result)
+ ts := time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)
+ if assert.NoError(t, err) {
+ // assert.Equal(t, Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T)
+ assert.Equal(t, ts, result.T)
+ assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA)
+ assert.Equal(t, []time.Time{ts, ts}, result.TA)
+ assert.Equal(t, Struct{""}, result.ST) // field in child struct does not have tag
+ }
+}
+
+func TestBindUnmarshalParamPtr(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ result := struct {
+ Tptr *Timestamp `query:"ts"`
+ }{}
+ err := c.Bind(&result)
+ if assert.NoError(t, err) {
+ assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), *result.Tptr)
+ }
+}
+
+func TestBindUnmarshalParamAnonymousFieldPtr(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/?baz=1", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ result := struct {
+ *Bar
+ }{&Bar{}}
+ err := c.Bind(&result)
+ if assert.NoError(t, err) {
+ assert.Equal(t, 1, result.Baz)
+ }
+}
+
+func TestBindUnmarshalParamAnonymousFieldPtrNil(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/?baz=1", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ result := struct {
+ *Bar
+ }{}
+ err := c.Bind(&result)
+ if assert.NoError(t, err) {
+ assert.Nil(t, result.Bar)
+ }
+}
+
+func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, `/?bar={"baz":100}&baz=1`, nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ result := struct {
+ *Bar `json:"bar" query:"bar"`
+ }{&Bar{}}
+ err := c.Bind(&result)
+ assert.Contains(t, err.Error(), "query/param/form tags are not allowed with anonymous struct field")
+}
+
+func TestBindUnmarshalTextPtr(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ result := struct {
+ Tptr *time.Time `query:"ts"`
+ }{}
+ err := c.Bind(&result)
+ if assert.NoError(t, err) {
+ assert.Equal(t, time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC), *result.Tptr)
+ }
+}
+
+func TestBindMultipartForm(t *testing.T) {
+ bodyBuffer := new(bytes.Buffer)
+ mw := multipart.NewWriter(bodyBuffer)
+ mw.WriteField("id", "1")
+ mw.WriteField("name", "Jon Snow")
+ mw.Close()
+ body := bodyBuffer.Bytes()
+
+ testBindOkay(t, bytes.NewReader(body), nil, mw.FormDataContentType())
+ testBindOkay(t, bytes.NewReader(body), dummyQuery, mw.FormDataContentType())
+}
+
+func TestBindUnsupportedMediaType(t *testing.T) {
+ testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{})
+}
+
+func TestDefaultBinder_bindDataToMap(t *testing.T) {
+ exampleData := map[string][]string{
+ "multiple": {"1", "2"},
+ "single": {"3"},
+ }
+
+ t.Run("ok, bind to map[string]string", func(t *testing.T) {
+ dest := map[string]string{}
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t,
+ map[string]string{
+ "multiple": "1",
+ "single": "3",
+ },
+ dest,
+ )
+ })
+
+ t.Run("ok, bind to map[string]string with nil map", func(t *testing.T) {
+ var dest map[string]string
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t,
+ map[string]string{
+ "multiple": "1",
+ "single": "3",
+ },
+ dest,
+ )
+ })
+
+ t.Run("ok, bind to map[string][]string", func(t *testing.T) {
+ dest := map[string][]string{}
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t,
+ map[string][]string{
+ "multiple": {"1", "2"},
+ "single": {"3"},
+ },
+ dest,
+ )
+ })
+
+ t.Run("ok, bind to map[string][]string with nil map", func(t *testing.T) {
+ var dest map[string][]string
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t,
+ map[string][]string{
+ "multiple": {"1", "2"},
+ "single": {"3"},
+ },
+ dest,
+ )
+ })
+
+ t.Run("ok, bind to map[string]interface", func(t *testing.T) {
+ dest := map[string]any{}
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t,
+ map[string]any{
+ "multiple": "1",
+ "single": "3",
+ },
+ dest,
+ )
+ })
+
+ t.Run("ok, bind to map[string]interface with nil map", func(t *testing.T) {
+ var dest map[string]any
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t,
+ map[string]any{
+ "multiple": "1",
+ "single": "3",
+ },
+ dest,
+ )
+ })
+
+ t.Run("ok, bind to map[string]int skips", func(t *testing.T) {
+ dest := map[string]int{}
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t, map[string]int{}, dest)
+ })
+
+ t.Run("ok, bind to map[string]int skips with nil map", func(t *testing.T) {
+ var dest map[string]int
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t, map[string]int(nil), dest)
+ })
+
+ t.Run("ok, bind to map[string][]int skips", func(t *testing.T) {
+ dest := map[string][]int{}
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t, map[string][]int{}, dest)
+ })
+
+ t.Run("ok, bind to map[string][]int skips with nil map", func(t *testing.T) {
+ var dest map[string][]int
+ assert.NoError(t, bindData(&dest, exampleData, "param", nil))
+ assert.Equal(t, map[string][]int(nil), dest)
+ })
+}
+
+func TestBindbindData(t *testing.T) {
+ ts := new(bindTestStruct)
+ err := bindData(ts, values, "form", nil)
+ assert.NoError(t, err)
+
+ assert.Equal(t, 0, ts.I)
+ assert.Equal(t, int8(0), ts.I8)
+ assert.Equal(t, int16(0), ts.I16)
+ assert.Equal(t, int32(0), ts.I32)
+ assert.Equal(t, int64(0), ts.I64)
+ assert.Equal(t, uint(0), ts.UI)
+ assert.Equal(t, uint8(0), ts.UI8)
+ assert.Equal(t, uint16(0), ts.UI16)
+ assert.Equal(t, uint32(0), ts.UI32)
+ assert.Equal(t, uint64(0), ts.UI64)
+ assert.Equal(t, false, ts.B)
+ assert.Equal(t, float32(0), ts.F32)
+ assert.Equal(t, float64(0), ts.F64)
+ assert.Equal(t, "", ts.S)
+ assert.Equal(t, "", ts.cantSet)
+}
+
+func TestBindParam(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ c.InitializeRoute(
+ &RouteInfo{Path: "/users/:id/:name"},
+ &PathValues{
+ {Name: "id", Value: "1"},
+ {Name: "name", Value: "Jon Snow"},
+ },
+ )
+
+ u := new(user)
+ err := c.Bind(u)
+ if assert.NoError(t, err) {
+ assert.Equal(t, 1, u.ID)
+ assert.Equal(t, "Jon Snow", u.Name)
+ }
+
+ // Second test for the absence of a param
+ c2 := e.NewContext(req, rec)
+ c2.InitializeRoute(
+ &RouteInfo{Path: "/users/:id"},
+ &PathValues{
+ {Name: "id", Value: "1"},
+ },
+ )
+
+ u = new(user)
+ err = c2.Bind(u)
+ if assert.NoError(t, err) {
+ assert.Equal(t, 1, u.ID)
+ assert.Equal(t, "", u.Name)
+ }
+
+ // Bind something with param and post data payload
+ body := bytes.NewBufferString(`{ "name": "Jon Snow" }`)
+ e2 := New()
+ req2 := httptest.NewRequest(http.MethodPost, "/", body)
+ req2.Header.Set(HeaderContentType, MIMEApplicationJSON)
+
+ rec2 := httptest.NewRecorder()
+
+ c3 := e2.NewContext(req2, rec2)
+ c3.InitializeRoute(
+ &RouteInfo{Path: "/users/:id"},
+ &PathValues{
+ {Name: "id", Value: "1"},
+ },
+ )
+
+ u = new(user)
+ err = c3.Bind(u)
+ if assert.NoError(t, err) {
+ assert.Equal(t, 1, u.ID)
+ assert.Equal(t, "Jon Snow", u.Name)
+ }
+}
+
+func TestBindUnmarshalTypeError(t *testing.T) {
+ body := bytes.NewBufferString(`{ "id": "text" }`)
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", body)
+ req.Header.Set(HeaderContentType, MIMEApplicationJSON)
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ u := new(user)
+
+ err := c.Bind(u)
+
+ assert.EqualError(t, err, `code=400, message=Bad Request, err=json: cannot unmarshal string into Go struct field user.id of type int`)
+}
+
+func TestBindSetWithProperType(t *testing.T) {
+ ts := new(bindTestStruct)
+ typ := reflect.TypeOf(ts).Elem()
+ val := reflect.ValueOf(ts).Elem()
+ for i := 0; i < typ.NumField(); i++ {
+ typeField := typ.Field(i)
+ structField := val.Field(i)
+ if !structField.CanSet() {
+ continue
+ }
+ if len(values[typeField.Name]) == 0 {
+ continue
+ }
+ val := values[typeField.Name][0]
+ err := setWithProperType(typeField.Type.Kind(), val, structField)
+ assert.NoError(t, err)
+ }
+ assertBindTestStruct(t, ts)
+
+ type foo struct {
+ Bar bytes.Buffer
+ }
+ v := &foo{}
+ typ = reflect.TypeOf(v).Elem()
+ val = reflect.ValueOf(v).Elem()
+ assert.Error(t, setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0)))
+}
+
+func BenchmarkBindbindDataWithTags(b *testing.B) {
+ b.ReportAllocs()
+ ts := new(bindTestStructWithTags)
+ var err error
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ err = bindData(ts, values, "form", nil)
+ }
+ assert.NoError(b, err)
+ assertBindTestStruct(b, (*bindTestStruct)(ts))
+}
+
+func assertBindTestStruct(tb testing.TB, ts *bindTestStruct) {
+ assert.Equal(tb, 0, ts.I)
+ assert.Equal(tb, int8(8), ts.I8)
+ assert.Equal(tb, int16(16), ts.I16)
+ assert.Equal(tb, int32(32), ts.I32)
+ assert.Equal(tb, int64(64), ts.I64)
+ assert.Equal(tb, uint(0), ts.UI)
+ assert.Equal(tb, uint8(8), ts.UI8)
+ assert.Equal(tb, uint16(16), ts.UI16)
+ assert.Equal(tb, uint32(32), ts.UI32)
+ assert.Equal(tb, uint64(64), ts.UI64)
+ assert.Equal(tb, true, ts.B)
+ assert.Equal(tb, float32(32.5), ts.F32)
+ assert.Equal(tb, float64(64.5), ts.F64)
+ assert.Equal(tb, "test", ts.S)
+ assert.Equal(tb, "", ts.GetCantSet())
+}
+
+func testBindOkay(t *testing.T, r io.Reader, query url.Values, ctype string) {
+ e := New()
+ path := "/"
+ if len(query) > 0 {
+ path += "?" + query.Encode()
+ }
+ req := httptest.NewRequest(http.MethodPost, path, r)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ req.Header.Set(HeaderContentType, ctype)
+ u := new(user)
+ err := c.Bind(u)
+ if assert.Equal(t, nil, err) {
+ assert.Equal(t, 1, u.ID)
+ assert.Equal(t, "Jon Snow", u.Name)
+ }
+}
+
+func testBindArrayOkay(t *testing.T, r io.Reader, query url.Values, ctype string) {
+ e := New()
+ path := "/"
+ if len(query) > 0 {
+ path += "?" + query.Encode()
+ }
+ req := httptest.NewRequest(http.MethodPost, path, r)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ req.Header.Set(HeaderContentType, ctype)
+ u := []user{}
+ err := c.Bind(&u)
+ if assert.NoError(t, err) {
+ assert.Equal(t, 1, len(u))
+ assert.Equal(t, 1, u[0].ID)
+ assert.Equal(t, "Jon Snow", u[0].Name)
+ }
+}
+
+func testBindError(t *testing.T, r io.Reader, ctype string, expectedInternal error) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", r)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ req.Header.Set(HeaderContentType, ctype)
+ u := new(user)
+ err := c.Bind(u)
+
+ switch {
+ case strings.HasPrefix(ctype, MIMEApplicationJSON), strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML),
+ strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
+ if assert.IsType(t, new(HTTPError), err) {
+ assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).Code)
+ assert.IsType(t, expectedInternal, err.(*HTTPError).Unwrap())
+ }
+ default:
+ if assert.IsType(t, new(HTTPError), err) {
+ assert.Equal(t, ErrUnsupportedMediaType, err)
+ assert.IsType(t, expectedInternal, err.(*HTTPError).Unwrap())
+ }
+ }
+}
+
+func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) {
+ // tests to check binding behaviour when multiple sources (path params, query params and request body) are in use
+ // binding is done in steps and one source could overwrite previous source bound data
+ // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed
+
+ type Opts struct {
+ Node string `json:"node" form:"node" query:"node" param:"node"`
+ Lang string
+ ID int `json:"id" form:"id" query:"id"`
+ }
+
+ var testCases = []struct {
+ givenContent io.Reader
+ whenBindTarget any
+ expect any
+ name string
+ givenURL string
+ givenMethod string
+ expectError string
+ whenNoPathValues bool
+ }{
+ {
+ name: "ok, POST bind to struct with: path param + query param + body",
+ givenMethod: http.MethodPost,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`{"id": 1}`),
+ expect: &Opts{ID: 1, Node: "node_from_path"}, // query params are not used, node is filled from path
+ },
+ {
+ name: "ok, PUT bind to struct with: path param + query param + body",
+ givenMethod: http.MethodPut,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`{"id": 1}`),
+ expect: &Opts{ID: 1, Node: "node_from_path"}, // query params are not used
+ },
+ {
+ name: "ok, GET bind to struct with: path param + query param + body",
+ givenMethod: http.MethodGet,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`{"id": 1}`),
+ expect: &Opts{ID: 1, Node: "xxx"}, // query overwrites previous path value
+ },
+ {
+ name: "ok, GET bind to struct with: path param + query param + body",
+ givenMethod: http.MethodGet,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
+ expect: &Opts{ID: 1, Node: "zzz"}, // body is bound last and overwrites previous (path,query) values
+ },
+ {
+ name: "ok, DELETE bind to struct with: path param + query param + body",
+ givenMethod: http.MethodDelete,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
+ expect: &Opts{ID: 1, Node: "zzz"}, // for DELETE body is bound after query params
+ },
+ {
+ name: "ok, POST bind to struct with: path param + body",
+ givenMethod: http.MethodPost,
+ givenURL: "/api/real_node/endpoint",
+ givenContent: strings.NewReader(`{"id": 1}`),
+ expect: &Opts{ID: 1, Node: "node_from_path"},
+ },
+ {
+ name: "ok, POST bind to struct with path + query + body = body has priority",
+ givenMethod: http.MethodPost,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
+ expect: &Opts{ID: 1, Node: "zzz"}, // field value from content has higher priority
+ },
+ {
+ name: "nok, POST body bind failure",
+ givenMethod: http.MethodPost,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`{`),
+ expect: &Opts{ID: 0, Node: "node_from_path"}, // query binding has already modified bind target
+ expectError: "code=400, message=Bad Request, err=unexpected EOF",
+ },
+ {
+ name: "nok, GET with body bind failure when types are not convertible",
+ givenMethod: http.MethodGet,
+ givenURL: "/api/real_node/endpoint?id=nope",
+ givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
+ expect: &Opts{ID: 0, Node: "node_from_path"}, // path params binding has already modified bind target
+ expectError: `code=400, message=Bad Request, err=strconv.ParseInt: parsing "nope": invalid syntax`,
+ },
+ {
+ name: "nok, GET body bind failure - trying to bind json array to struct",
+ givenMethod: http.MethodGet,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`[{"id": 1}]`),
+ expect: &Opts{ID: 0, Node: "xxx"}, // query binding has already modified bind target
+ expectError: `code=400, message=Bad Request, err=json: cannot unmarshal array into Go value of type echo.Opts`,
+ },
+ { // query param is ignored as we do not know where exactly to bind it in slice
+ name: "ok, GET bind to struct slice, ignore query param",
+ givenMethod: http.MethodGet,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`[{"id": 1}]`),
+ whenNoPathValues: true,
+ whenBindTarget: &[]Opts{},
+ expect: &[]Opts{
+ {ID: 1, Node: ""},
+ },
+ },
+ { // binding query params interferes with body. b.BindBody() should be used to bind only body to slice
+ name: "ok, POST binding to slice should not be affected query params types",
+ givenMethod: http.MethodPost,
+ givenURL: "/api/real_node/endpoint?id=nope&node=xxx",
+ givenContent: strings.NewReader(`[{"id": 1}]`),
+ whenNoPathValues: true,
+ whenBindTarget: &[]Opts{},
+ expect: &[]Opts{{ID: 1}},
+ expectError: "",
+ },
+ { // path param is ignored as we do not know where exactly to bind it in slice
+ name: "ok, GET bind to struct slice, ignore path param",
+ givenMethod: http.MethodGet,
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenContent: strings.NewReader(`[{"id": 1}]`),
+ whenBindTarget: &[]Opts{},
+ expect: &[]Opts{
+ {ID: 1, Node: ""},
+ },
+ },
+ {
+ name: "ok, GET body bind json array to slice",
+ givenMethod: http.MethodGet,
+ givenURL: "/api/real_node/endpoint",
+ givenContent: strings.NewReader(`[{"id": 1}]`),
+ whenNoPathValues: true,
+ whenBindTarget: &[]Opts{},
+ expect: &[]Opts{{ID: 1, Node: ""}},
+ expectError: "",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ // assume route we are testing is "/api/:node/endpoint?some_query_params=here"
+ req := httptest.NewRequest(tc.givenMethod, tc.givenURL, tc.givenContent)
+ req.Header.Set(HeaderContentType, MIMEApplicationJSON)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ if !tc.whenNoPathValues {
+ c.SetPathValues(PathValues{
+ {Name: "node", Value: "node_from_path"},
+ })
+ }
+
+ var bindTarget any
+ if tc.whenBindTarget != nil {
+ bindTarget = tc.whenBindTarget
+ } else {
+ bindTarget = &Opts{}
+ }
+ b := new(DefaultBinder)
+
+ err := b.Bind(c, bindTarget)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, bindTarget)
+ })
+ }
+}
+
+func TestDefaultBinder_BindBody(t *testing.T) {
+ // tests to check binding behaviour when multiple sources (path params, query params and request body) are in use
+ // generally when binding from request body - URL and path params are ignored - unless form is being bound.
+ // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed
+
+ type Node struct {
+ Node string `json:"node" xml:"node" form:"node" query:"node" param:"node"`
+ ID int `json:"id" xml:"id" form:"id" query:"id"`
+ }
+ type Nodes struct {
+ Nodes []Node `xml:"node" form:"node"`
+ }
+
+ var testCases = []struct {
+ givenContent io.Reader
+ whenBindTarget any
+ expect any
+ name string
+ givenURL string
+ givenMethod string
+ givenContentType string
+ expectError string
+ whenNoPathValues bool
+ whenChunkedBody bool
+ }{
+ {
+ name: "ok, JSON POST bind to struct with: path + query + empty field in body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationJSON,
+ givenContent: strings.NewReader(`{"id": 1}`),
+ expect: &Node{ID: 1, Node: ""}, // path params or query params should not interfere with body
+ },
+ {
+ name: "ok, JSON POST bind to struct with: path + query + body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationJSON,
+ givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
+ expect: &Node{ID: 1, Node: "zzz"}, // field value from content has higher priority
+ },
+ {
+ name: "ok, JSON POST body bind json array to slice (has matching path/query params)",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationJSON,
+ givenContent: strings.NewReader(`[{"id": 1}]`),
+ whenNoPathValues: true,
+ whenBindTarget: &[]Node{},
+ expect: &[]Node{{ID: 1, Node: ""}},
+ expectError: "",
+ },
+ { // rare case as GET is not usually used to send request body
+ name: "ok, JSON GET bind to struct with: path + query + empty field in body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodGet,
+ givenContentType: MIMEApplicationJSON,
+ givenContent: strings.NewReader(`{"id": 1}`),
+ expect: &Node{ID: 1, Node: ""}, // path params or query params should not interfere with body
+ },
+ { // rare case as GET is not usually used to send request body
+ name: "ok, JSON GET bind to struct with: path + query + body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodGet,
+ givenContentType: MIMEApplicationJSON,
+ givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`),
+ expect: &Node{ID: 1, Node: "zzz"}, // field value from content has higher priority
+ },
+ {
+ name: "nok, JSON POST body bind failure",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationJSON,
+ givenContent: strings.NewReader(`{`),
+ expect: &Node{ID: 0, Node: ""},
+ expectError: "code=400, message=Bad Request, err=unexpected EOF",
+ },
+ {
+ name: "ok, XML POST bind to struct with: path + query + empty body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationXML,
+ givenContent: strings.NewReader(`1yyy`),
+ expect: &Node{ID: 1, Node: "yyy"},
+ },
+ {
+ name: "ok, XML POST bind array to slice with: path + query + body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationXML,
+ givenContent: strings.NewReader(`1yyy`),
+ whenBindTarget: &Nodes{},
+ expect: &Nodes{Nodes: []Node{{ID: 1, Node: "yyy"}}},
+ },
+ {
+ name: "nok, XML POST bind failure",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationXML,
+ givenContent: strings.NewReader(`<`),
+ expect: &Node{ID: 0, Node: ""},
+ expectError: "code=400, message=Bad Request, err=XML syntax error on line 1: unexpected EOF",
+ },
+ {
+ name: "ok, FORM POST bind to struct with: path + query + body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationForm,
+ givenContent: strings.NewReader(`id=1&node=yyy`),
+ expect: &Node{ID: 1, Node: "yyy"},
+ },
+ {
+ // NB: form values are taken from BOTH body and query for POST/PUT/PATCH by standard library implementation
+ // See: https://golang.org/pkg/net/http/#Request.ParseForm
+ name: "ok, FORM POST bind to struct with: path + query + empty field in body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationForm,
+ givenContent: strings.NewReader(`id=1`),
+ expect: &Node{ID: 1, Node: "xxx"},
+ },
+ {
+ // NB: form values are taken from query by standard library implementation
+ // See: https://golang.org/pkg/net/http/#Request.ParseForm
+ name: "ok, FORM GET bind to struct with: path + query + empty field in body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodGet,
+ givenContentType: MIMEApplicationForm,
+ givenContent: strings.NewReader(`id=1`),
+ expect: &Node{ID: 0, Node: "xxx"}, // 'xxx' is taken from URL and body is not used with GET by implementation
+ },
+ {
+ name: "nok, unsupported content type",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMETextPlain,
+ givenContent: strings.NewReader(``),
+ expect: &Node{ID: 0, Node: ""},
+ expectError: "code=415, message=Unsupported Media Type",
+ },
+ // FIXME: REASON in Go 1.24 and earlier http.NoBody would result ContentLength=-1
+ // but as of Go 1.25 http.NoBody would result ContentLength=0
+ // I am too lazy to bother documenting this as 2 version specific tests.
+ //{
+ // name: "nok, JSON POST with http.NoBody",
+ // givenURL: "/api/real_node/endpoint?node=xxx",
+ // givenMethod: http.MethodPost,
+ // givenContentType: MIMEApplicationJSON,
+ // givenContent: http.NoBody,
+ // expect: &Node{ID: 0, Node: ""},
+ // expectError: "code=400, message=EOF, internal=EOF",
+ //},
+ {
+ name: "ok, JSON POST with empty body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationJSON,
+ givenContent: strings.NewReader(""),
+ expect: &Node{ID: 0, Node: ""},
+ },
+ {
+ name: "ok, JSON POST bind to struct with: path + query + chunked body",
+ givenURL: "/api/real_node/endpoint?node=xxx",
+ givenMethod: http.MethodPost,
+ givenContentType: MIMEApplicationJSON,
+ givenContent: httputil.NewChunkedReader(strings.NewReader("18\r\n" + `{"id": 1, "node": "zzz"}` + "\r\n0\r\n")),
+ whenChunkedBody: true,
+ expect: &Node{ID: 1, Node: "zzz"},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ // assume route we are testing is "/api/:node/endpoint?some_query_params=here"
+ req := httptest.NewRequest(tc.givenMethod, tc.givenURL, tc.givenContent)
+ switch tc.givenContentType {
+ case MIMEApplicationXML:
+ req.Header.Set(HeaderContentType, MIMEApplicationXML)
+ case MIMEApplicationForm:
+ req.Header.Set(HeaderContentType, MIMEApplicationForm)
+ case MIMEApplicationJSON:
+ req.Header.Set(HeaderContentType, MIMEApplicationJSON)
+ }
+ if tc.whenChunkedBody {
+ req.ContentLength = -1
+ req.TransferEncoding = append(req.TransferEncoding, "chunked")
+ }
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ if !tc.whenNoPathValues {
+ c.SetPathValues(PathValues{
+ {Name: "node", Value: "real_node"},
+ })
+ }
+
+ var bindTarget any
+ if tc.whenBindTarget != nil {
+ bindTarget = tc.whenBindTarget
+ } else {
+ bindTarget = &Node{}
+ }
+
+ err := BindBody(c, bindTarget)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, bindTarget)
+ })
+ }
+}
+
+func testBindURL(queryString string, target any) error {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, queryString, nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ return c.Bind(target)
+}
+
+type unixTimestamp struct {
+ Time time.Time
+}
+
+func (t *unixTimestamp) UnmarshalParam(param string) error {
+ n, err := strconv.ParseInt(param, 10, 64)
+ if err != nil {
+ return fmt.Errorf("'%s' is not an integer", param)
+ }
+ *t = unixTimestamp{Time: time.Unix(n, 0)}
+ return err
+}
+
+type IntArrayA []int
+
+// UnmarshalParam converts value to *Int64Slice. This allows the API to accept
+// a comma-separated list of integers as a query parameter.
+func (i *IntArrayA) UnmarshalParam(value string) error {
+ var values = strings.Split(value, ",")
+ var numbers = make([]int, 0, len(values))
+
+ for _, v := range values {
+ n, err := strconv.ParseInt(v, 10, 64)
+ if err != nil {
+ return fmt.Errorf("'%s' is not an integer", v)
+ }
+
+ numbers = append(numbers, int(n))
+ }
+
+ *i = append(*i, numbers...)
+ return nil
+}
+
+func TestBindUnmarshalParamExtras(t *testing.T) {
+ // this test documents how bind handles `BindUnmarshaler` interface:
+ // NOTE: BindUnmarshaler chooses first input value to be bound.
+
+ t.Run("nok, unmarshalling fails", func(t *testing.T) {
+ result := struct {
+ V unixTimestamp `query:"t"`
+ }{}
+ err := testBindURL("/?t=xxxx", &result)
+
+ assert.EqualError(t, err, `code=400, message=Bad Request, err='xxxx' is not an integer`)
+ })
+
+ t.Run("ok, target is struct", func(t *testing.T) {
+ result := struct {
+ V unixTimestamp `query:"t"`
+ }{}
+ err := testBindURL("/?t=1710095540&t=1710095541", &result)
+
+ assert.NoError(t, err)
+ expect := unixTimestamp{
+ Time: time.Unix(1710095540, 0),
+ }
+ assert.Equal(t, expect, result.V)
+ })
+
+ t.Run("ok, target is an alias to slice and is nil, append only values from first", func(t *testing.T) {
+ result := struct {
+ V IntArrayA `query:"a"`
+ }{}
+ err := testBindURL("/?a=1,2,3&a=4,5,6", &result)
+
+ assert.NoError(t, err)
+ assert.Equal(t, IntArrayA([]int{1, 2, 3}), result.V)
+ })
+
+ t.Run("ok, target is an alias to slice and is nil, single input", func(t *testing.T) {
+ result := struct {
+ V IntArrayA `query:"a"`
+ }{}
+ err := testBindURL("/?a=1,2", &result)
+
+ assert.NoError(t, err)
+ assert.Equal(t, IntArrayA([]int{1, 2}), result.V)
+ })
+
+ t.Run("ok, target is pointer an alias to slice and is nil", func(t *testing.T) {
+ result := struct {
+ V *IntArrayA `query:"a"`
+ }{}
+ err := testBindURL("/?a=1&a=4,5,6", &result)
+
+ assert.NoError(t, err)
+ var expected = IntArrayA([]int{1})
+ assert.Equal(t, &expected, result.V)
+ })
+
+ t.Run("ok, target is pointer an alias to slice and is NOT nil", func(t *testing.T) {
+ result := struct {
+ V *IntArrayA `query:"a"`
+ }{}
+ result.V = new(IntArrayA) // NOT nil
+
+ err := testBindURL("/?a=1&a=4,5,6", &result)
+
+ assert.NoError(t, err)
+ var expected = IntArrayA([]int{1})
+ assert.Equal(t, &expected, result.V)
+ })
+}
+
+type unixTimestampLast struct {
+ Time time.Time
+}
+
+// this is silly example for `bindMultipleUnmarshaler` for type that uses last input value for unmarshalling
+func (t *unixTimestampLast) UnmarshalParams(params []string) error {
+ lastInput := params[len(params)-1]
+ n, err := strconv.ParseInt(lastInput, 10, 64)
+ if err != nil {
+ return fmt.Errorf("'%s' is not an integer", lastInput)
+ }
+ *t = unixTimestampLast{Time: time.Unix(n, 0)}
+ return err
+}
+
+type IntArrayB []int
+
+func (i *IntArrayB) UnmarshalParams(params []string) error {
+ var numbers = make([]int, 0, len(params))
+
+ for _, param := range params {
+ var values = strings.Split(param, ",")
+ for _, v := range values {
+ n, err := strconv.ParseInt(v, 10, 64)
+ if err != nil {
+ return fmt.Errorf("'%s' is not an integer", v)
+ }
+ numbers = append(numbers, int(n))
+ }
+ }
+
+ *i = append(*i, numbers...)
+ return nil
+}
+
+func TestBindUnmarshalParams(t *testing.T) {
+ // this test documents how bind handles `bindMultipleUnmarshaler` interface:
+
+ t.Run("nok, unmarshalling fails", func(t *testing.T) {
+ result := struct {
+ V unixTimestampLast `query:"t"`
+ }{}
+ err := testBindURL("/?t=xxxx", &result)
+
+ assert.EqualError(t, err, "code=400, message=Bad Request, err='xxxx' is not an integer")
+ })
+
+ t.Run("ok, target is struct", func(t *testing.T) {
+ result := struct {
+ V unixTimestampLast `query:"t"`
+ }{}
+ err := testBindURL("/?t=1710095540&t=1710095541", &result)
+
+ assert.NoError(t, err)
+ expect := unixTimestampLast{
+ Time: time.Unix(1710095541, 0),
+ }
+ assert.Equal(t, expect, result.V)
+ })
+
+ t.Run("ok, target is an alias to slice and is nil, append multiple inputs", func(t *testing.T) {
+ result := struct {
+ V IntArrayB `query:"a"`
+ }{}
+ err := testBindURL("/?a=1,2,3&a=4,5,6", &result)
+
+ assert.NoError(t, err)
+ assert.Equal(t, IntArrayB([]int{1, 2, 3, 4, 5, 6}), result.V)
+ })
+
+ t.Run("ok, target is an alias to slice and is nil, single input", func(t *testing.T) {
+ result := struct {
+ V IntArrayB `query:"a"`
+ }{}
+ err := testBindURL("/?a=1,2", &result)
+
+ assert.NoError(t, err)
+ assert.Equal(t, IntArrayB([]int{1, 2}), result.V)
+ })
+
+ t.Run("ok, target is pointer an alias to slice and is nil", func(t *testing.T) {
+ result := struct {
+ V *IntArrayB `query:"a"`
+ }{}
+ err := testBindURL("/?a=1&a=4,5,6", &result)
+
+ assert.NoError(t, err)
+ var expected = IntArrayB([]int{1, 4, 5, 6})
+ assert.Equal(t, &expected, result.V)
+ })
+
+ t.Run("ok, target is pointer an alias to slice and is NOT nil", func(t *testing.T) {
+ result := struct {
+ V *IntArrayB `query:"a"`
+ }{}
+ result.V = new(IntArrayB) // NOT nil
+
+ err := testBindURL("/?a=1&a=4,5,6", &result)
+ assert.NoError(t, err)
+ var expected = IntArrayB([]int{1, 4, 5, 6})
+ assert.Equal(t, &expected, result.V)
+ })
+}
+
+func TestBindInt8(t *testing.T) {
+ t.Run("nok, binding fails", func(t *testing.T) {
+ type target struct {
+ V int8 `query:"v"`
+ }
+ p := target{}
+ err := testBindURL("/?v=x&v=2", &p)
+ assert.EqualError(t, err, `code=400, message=Bad Request, err=strconv.ParseInt: parsing "x": invalid syntax`)
+ })
+
+ t.Run("nok, int8 embedded in struct", func(t *testing.T) {
+ type target struct {
+ int8 `query:"v"` // embedded field is `Anonymous`. We can only set public fields
+ }
+ p := target{}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+ assert.Equal(t, target{0}, p)
+ })
+
+ t.Run("nok, pointer to int8 embedded in struct", func(t *testing.T) {
+ type target struct {
+ *int8 `query:"v"` // embedded field is `Anonymous`. We can only set public fields
+ }
+ p := target{}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+
+ assert.Equal(t, target{int8: nil}, p)
+ })
+
+ t.Run("ok, bind int8 as struct field", func(t *testing.T) {
+ type target struct {
+ V int8 `query:"v"`
+ }
+ p := target{V: 127}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+ assert.Equal(t, target{V: 1}, p)
+ })
+
+ t.Run("ok, bind pointer to int8 as struct field, value is nil", func(t *testing.T) {
+ type target struct {
+ V *int8 `query:"v"`
+ }
+ p := target{}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+ assert.Equal(t, target{V: ptr(int8(1))}, p)
+ })
+
+ t.Run("ok, bind pointer to int8 as struct field, value is set", func(t *testing.T) {
+ type target struct {
+ V *int8 `query:"v"`
+ }
+ p := target{V: ptr(int8(127))}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+ assert.Equal(t, target{V: ptr(int8(1))}, p)
+ })
+
+ t.Run("ok, bind int8 slice as struct field, value is nil", func(t *testing.T) {
+ type target struct {
+ V []int8 `query:"v"`
+ }
+ p := target{}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+ assert.Equal(t, target{V: []int8{1, 2}}, p)
+ })
+
+ t.Run("ok, bind slice of int8 as struct field, value is set", func(t *testing.T) {
+ type target struct {
+ V []int8 `query:"v"`
+ }
+ p := target{V: []int8{111}}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+ assert.Equal(t, target{V: []int8{1, 2}}, p)
+ })
+
+ t.Run("ok, bind slice of pointer to int8 as struct field, value is set", func(t *testing.T) {
+ type target struct {
+ V []*int8 `query:"v"`
+ }
+ p := target{V: []*int8{ptr(int8(127))}}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+ assert.Equal(t, target{V: []*int8{ptr(int8(1)), ptr(int8(2))}}, p)
+ })
+
+ t.Run("ok, bind pointer to slice of int8 as struct field, value is set", func(t *testing.T) {
+ type target struct {
+ V *[]int8 `query:"v"`
+ }
+ p := target{V: &[]int8{111}}
+ err := testBindURL("/?v=1&v=2", &p)
+ assert.NoError(t, err)
+ assert.Equal(t, target{V: &[]int8{1, 2}}, p)
+ })
+}
+
+func TestBindMultipartFormFiles(t *testing.T) {
+ file1 := createTestFormFile("file", "file1.txt")
+ file11 := createTestFormFile("file", "file11.txt")
+ file2 := createTestFormFile("file2", "file2.txt")
+ filesA := createTestFormFile("files", "filesA.txt")
+ filesB := createTestFormFile("files", "filesB.txt")
+
+ t.Run("nok, can not bind to multipart file struct", func(t *testing.T) {
+ var target struct {
+ File multipart.FileHeader `form:"file"`
+ }
+ err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored
+
+ assert.EqualError(t, err, `code=400, message=Bad Request, err=binding to multipart.FileHeader struct is not supported, use pointer to struct`)
+ })
+
+ t.Run("ok, bind single multipart file to pointer to multipart file", func(t *testing.T) {
+ var target struct {
+ File *multipart.FileHeader `form:"file"`
+ }
+ err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored
+
+ assert.NoError(t, err)
+ assertMultipartFileHeader(t, target.File, file1)
+ })
+
+ t.Run("ok, bind multiple multipart files to pointer to multipart file", func(t *testing.T) {
+ var target struct {
+ File *multipart.FileHeader `form:"file"`
+ }
+ err := bindMultipartFiles(t, &target, file1, file11)
+
+ assert.NoError(t, err)
+ assertMultipartFileHeader(t, target.File, file1) // should choose first one
+ })
+
+ t.Run("ok, bind multiple multipart files to slice of multipart file", func(t *testing.T) {
+ var target struct {
+ Files []multipart.FileHeader `form:"files"`
+ }
+ err := bindMultipartFiles(t, &target, filesA, filesB, file1)
+
+ assert.NoError(t, err)
+
+ assert.Len(t, target.Files, 2)
+ assertMultipartFileHeader(t, &target.Files[0], filesA)
+ assertMultipartFileHeader(t, &target.Files[1], filesB)
+ })
+
+ t.Run("ok, bind multiple multipart files to slice of pointer to multipart file", func(t *testing.T) {
+ var target struct {
+ Files []*multipart.FileHeader `form:"files"`
+ }
+ err := bindMultipartFiles(t, &target, filesA, filesB, file1)
+
+ assert.NoError(t, err)
+
+ assert.Len(t, target.Files, 2)
+ assertMultipartFileHeader(t, target.Files[0], filesA)
+ assertMultipartFileHeader(t, target.Files[1], filesB)
+ })
+}
+
+type testFormFile struct {
+ Fieldname string
+ Filename string
+ Content []byte
+}
+
+func createTestFormFile(formFieldName string, filename string) testFormFile {
+ return testFormFile{
+ Fieldname: formFieldName,
+ Filename: filename,
+ Content: []byte(strings.Repeat(filename, 10)),
+ }
+}
+
+func bindMultipartFiles(t *testing.T, target any, files ...testFormFile) error {
+ var body bytes.Buffer
+ mw := multipart.NewWriter(&body)
+
+ for _, file := range files {
+ fw, err := mw.CreateFormFile(file.Fieldname, file.Filename)
+ assert.NoError(t, err)
+
+ n, err := fw.Write(file.Content)
+ assert.NoError(t, err)
+ assert.Equal(t, len(file.Content), n)
+ }
+
+ err := mw.Close()
+ assert.NoError(t, err)
+
+ req, err := http.NewRequest(http.MethodPost, "/", &body)
+ assert.NoError(t, err)
+ req.Header.Set("Content-Type", mw.FormDataContentType())
+
+ rec := httptest.NewRecorder()
+
+ e := New()
+ c := e.NewContext(req, rec)
+ return c.Bind(target)
+}
+
+func assertMultipartFileHeader(t *testing.T, fh *multipart.FileHeader, file testFormFile) {
+ assert.Equal(t, file.Filename, fh.Filename)
+ assert.Equal(t, int64(len(file.Content)), fh.Size)
+ fl, err := fh.Open()
+ assert.NoError(t, err)
+ body, err := io.ReadAll(fl)
+ assert.NoError(t, err)
+ assert.Equal(t, string(file.Content), string(body))
+ err = fl.Close()
+ assert.NoError(t, err)
+}
+
+func TestTimeFormatBinding(t *testing.T) {
+ type TestStruct struct {
+ DateTimeLocal time.Time `form:"datetime_local" format:"2006-01-02T15:04"`
+ Date time.Time `query:"date" format:"2006-01-02"`
+ CustomFormat time.Time `form:"custom" format:"01/02/2006 15:04:05"`
+ DefaultTime time.Time `form:"default_time"` // No format tag - should use default parsing
+ PtrTime *time.Time `query:"ptr_time" format:"2006-01-02"`
+ }
+
+ testCases := []struct {
+ name string
+ contentType string
+ data string
+ queryParams string
+ expect TestStruct
+ expectError bool
+ }{
+ {
+ name: "ok, datetime-local format binding",
+ contentType: MIMEApplicationForm,
+ data: "datetime_local=2023-12-25T14:30&default_time=2023-12-25T14:30:45Z",
+ expect: TestStruct{
+ DateTimeLocal: time.Date(2023, 12, 25, 14, 30, 0, 0, time.UTC),
+ DefaultTime: time.Date(2023, 12, 25, 14, 30, 45, 0, time.UTC),
+ },
+ },
+ {
+ name: "ok, date format binding via query params",
+ queryParams: "?date=2023-01-15&ptr_time=2023-02-20",
+ expect: TestStruct{
+ Date: time.Date(2023, 1, 15, 0, 0, 0, 0, time.UTC),
+ PtrTime: &time.Time{},
+ },
+ },
+ {
+ name: "ok, custom format via form data",
+ contentType: MIMEApplicationForm,
+ data: "custom=12/25/2023 14:30:45",
+ expect: TestStruct{
+ CustomFormat: time.Date(2023, 12, 25, 14, 30, 45, 0, time.UTC),
+ },
+ },
+ {
+ name: "nok, invalid format should fail",
+ contentType: MIMEApplicationForm,
+ data: "datetime_local=invalid-date",
+ expectError: true,
+ },
+ {
+ name: "nok, wrong format should fail",
+ contentType: MIMEApplicationForm,
+ data: "datetime_local=2023-12-25", // Missing time part
+ expectError: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ var req *http.Request
+
+ if tc.contentType == MIMEApplicationJSON {
+ req = httptest.NewRequest(http.MethodPost, "/"+tc.queryParams, strings.NewReader(tc.data))
+ req.Header.Set(HeaderContentType, tc.contentType)
+ } else if tc.contentType == MIMEApplicationForm {
+ req = httptest.NewRequest(http.MethodPost, "/"+tc.queryParams, strings.NewReader(tc.data))
+ req.Header.Set(HeaderContentType, tc.contentType)
+ } else {
+ req = httptest.NewRequest(http.MethodGet, "/"+tc.queryParams, nil)
+ }
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ var result TestStruct
+ err := c.Bind(&result)
+
+ if tc.expectError {
+ assert.Error(t, err)
+ return
+ }
+
+ assert.NoError(t, err)
+
+ // Check individual fields since time comparison can be tricky
+ if !tc.expect.DateTimeLocal.IsZero() {
+ assert.True(t, tc.expect.DateTimeLocal.Equal(result.DateTimeLocal),
+ "DateTimeLocal: expected %v, got %v", tc.expect.DateTimeLocal, result.DateTimeLocal)
+ }
+ if !tc.expect.Date.IsZero() {
+ assert.True(t, tc.expect.Date.Equal(result.Date),
+ "Date: expected %v, got %v", tc.expect.Date, result.Date)
+ }
+ if !tc.expect.CustomFormat.IsZero() {
+ assert.True(t, tc.expect.CustomFormat.Equal(result.CustomFormat),
+ "CustomFormat: expected %v, got %v", tc.expect.CustomFormat, result.CustomFormat)
+ }
+ if !tc.expect.DefaultTime.IsZero() {
+ assert.True(t, tc.expect.DefaultTime.Equal(result.DefaultTime),
+ "DefaultTime: expected %v, got %v", tc.expect.DefaultTime, result.DefaultTime)
+ }
+ if tc.expect.PtrTime != nil {
+ assert.NotNil(t, result.PtrTime)
+ if result.PtrTime != nil {
+ expectedPtr := time.Date(2023, 2, 20, 0, 0, 0, 0, time.UTC)
+ assert.True(t, expectedPtr.Equal(*result.PtrTime),
+ "PtrTime: expected %v, got %v", expectedPtr, *result.PtrTime)
+ }
+ }
+ })
+ }
+}
diff --git a/binder.go b/binder.go
new file mode 100644
index 000000000..32029ec0f
--- /dev/null
+++ b/binder.go
@@ -0,0 +1,1329 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "encoding"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+)
+
+/**
+ Following functions provide handful of methods for binding to Go native types from request query or path parameters.
+ * QueryParamsBinder(c) - binds query parameters (source URL)
+ * PathValuesBinder(c) - binds path parameters (source URL)
+ * FormFieldBinder(c) - binds form fields (source URL + body)
+
+ Example:
+ ```go
+ var length int64
+ err := echo.QueryParamsBinder(c).Int64("length", &length).BindError()
+ ```
+
+ For every supported type there are following methods:
+ * ("param", &destination) - if parameter value exists then binds it to given destination of that type i.e Int64(...).
+ * Must("param", &destination) - parameter value is required to exist, binds it to given destination of that type i.e MustInt64(...).
+ * s("param", &destination) - (for slices) if parameter values exists then binds it to given destination of that type i.e Int64s(...).
+ * Musts("param", &destination) - (for slices) parameter value is required to exist, binds it to given destination of that type i.e MustInt64s(...).
+
+ for some slice types `BindWithDelimiter("param", &dest, ",")` supports splitting parameter values before type conversion is done
+ i.e. URL `/api/search?id=1,2,3&id=1` can be bind to `[]int64{1,2,3,1}`
+
+ `FailFast` flags binder to stop binding after first bind error during binder call chain. Enabled by default.
+ `BindError()` returns first bind error from binder and resets errors in binder. Useful along with `FailFast()` method
+ to do binding and returns on first problem
+ `BindErrors()` returns all bind errors from binder and resets errors in binder.
+
+ Types that are supported:
+ * bool
+ * float32
+ * float64
+ * int
+ * int8
+ * int16
+ * int32
+ * int64
+ * uint
+ * uint8/byte (does not support `bytes()`. Use BindUnmarshaler/CustomFunc to convert value from base64 etc to []byte{})
+ * uint16
+ * uint32
+ * uint64
+ * string
+ * time
+ * duration
+ * BindUnmarshaler() interface
+ * TextUnmarshaler() interface
+ * JSONUnmarshaler() interface
+ * UnixTime() - converts unix time (integer) to time.Time
+ * UnixTimeMilli() - converts unix time with millisecond precision (integer) to time.Time
+ * UnixTimeNano() - converts unix time with nanosecond precision (integer) to time.Time
+ * CustomFunc() - callback function for your custom conversion logic. Signature `func(values []string) []error`
+*/
+
+// BindingError represents an error that occurred while binding request data.
+type BindingError struct {
+ // Field is the field name where value binding failed
+ Field string `json:"field"`
+ *HTTPError
+ // Values of parameter that failed to bind.
+ Values []string `json:"-"`
+}
+
+// NewBindingError creates new instance of binding error
+func NewBindingError(sourceParam string, values []string, message string, err error) error {
+ return &BindingError{
+ Field: sourceParam,
+ Values: values,
+ HTTPError: &HTTPError{Code: http.StatusBadRequest, Message: message, err: err},
+ }
+}
+
+// Error returns error message
+func (be *BindingError) Error() string {
+ return fmt.Sprintf("%s, field=%s", be.HTTPError.Error(), be.Field)
+}
+
+// ValueBinder provides utility methods for binding query or path parameter to various Go built-in types
+type ValueBinder struct {
+ // ValueFunc is used to get single parameter (first) value from request
+ ValueFunc func(sourceParam string) string
+ // ValuesFunc is used to get all values for parameter from request. i.e. `/api/search?ids=1&ids=2`
+ ValuesFunc func(sourceParam string) []string
+ // ErrorFunc is used to create errors. Allows you to use your own error type, that for example marshals to your specific json response
+ ErrorFunc func(sourceParam string, values []string, message string, internalError error) error
+ errors []error
+ // failFast is flag for binding methods to return without attempting to bind when previous binding already failed
+ failFast bool
+}
+
+// QueryParamsBinder creates query parameter value binder
+func QueryParamsBinder(c *Context) *ValueBinder {
+ return &ValueBinder{
+ failFast: true,
+ ValueFunc: c.QueryParam,
+ ValuesFunc: func(sourceParam string) []string {
+ values, ok := c.QueryParams()[sourceParam]
+ if !ok {
+ return nil
+ }
+ return values
+ },
+ ErrorFunc: NewBindingError,
+ }
+}
+
+// PathValuesBinder creates path parameter value binder
+func PathValuesBinder(c *Context) *ValueBinder {
+ return &ValueBinder{
+ failFast: true,
+ ValueFunc: c.Param,
+ ValuesFunc: func(sourceParam string) []string {
+ // path parameter should not have multiple values so getting values does not make sense but lets not error out here
+ value := c.Param(sourceParam)
+ if value == "" {
+ return nil
+ }
+ return []string{value}
+ },
+ ErrorFunc: NewBindingError,
+ }
+}
+
+// FormFieldBinder creates form field value binder
+// For all requests, FormFieldBinder parses the raw query from the URL and uses query params as form fields
+//
+// For POST, PUT, and PATCH requests, it also reads the request body, parses it
+// as a form and uses query params as form fields. Request body parameters take precedence over URL query
+// string values in r.Form.
+//
+// NB: when binding forms take note that this implementation uses standard library form parsing
+// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm
+// See https://golang.org/pkg/net/http/#Request.ParseForm
+func FormFieldBinder(c *Context) *ValueBinder {
+ vb := &ValueBinder{
+ failFast: true,
+ ValueFunc: func(sourceParam string) string {
+ return c.Request().FormValue(sourceParam)
+ },
+ ErrorFunc: NewBindingError,
+ }
+ vb.ValuesFunc = func(sourceParam string) []string {
+ if c.Request().Form == nil {
+ // this is same as `Request().FormValue()` does internally
+ _, _ = c.MultipartForm() // we want to trigger c.request.ParseMultipartForm(c.formParseMaxMemory)
+ }
+ values, ok := c.Request().Form[sourceParam]
+ if !ok {
+ return nil
+ }
+ return values
+ }
+
+ return vb
+}
+
+// FailFast set internal flag to indicate if binding methods will return early (without binding) when previous bind failed
+// NB: call this method before any other binding methods as it modifies binding methods behaviour
+func (b *ValueBinder) FailFast(value bool) *ValueBinder {
+ b.failFast = value
+ return b
+}
+
+func (b *ValueBinder) setError(err error) {
+ if b.errors == nil {
+ b.errors = []error{err}
+ return
+ }
+ b.errors = append(b.errors, err)
+}
+
+// BindError returns first seen bind error and resets/empties binder errors for further calls
+func (b *ValueBinder) BindError() error {
+ if b.errors == nil {
+ return nil
+ }
+ err := b.errors[0]
+ b.errors = nil // reset errors so next chain will start from zero
+ return err
+}
+
+// BindErrors returns all bind errors and resets/empties binder errors for further calls
+func (b *ValueBinder) BindErrors() []error {
+ if b.errors == nil {
+ return nil
+ }
+ errors := b.errors
+ b.errors = nil // reset errors so next chain will start from zero
+ return errors
+}
+
+// CustomFunc binds parameter values with Func. Func is called only when parameter values exist.
+func (b *ValueBinder) CustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder {
+ return b.customFunc(sourceParam, customFunc, false)
+}
+
+// MustCustomFunc requires parameter values to exist to bind with Func. Returns error when value does not exist.
+func (b *ValueBinder) MustCustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder {
+ return b.customFunc(sourceParam, customFunc, true)
+}
+
+func (b *ValueBinder) customFunc(sourceParam string, customFunc func(values []string) []error, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ values := b.ValuesFunc(sourceParam)
+ if len(values) == 0 {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+ if errs := customFunc(values); errs != nil {
+ b.errors = append(b.errors, errs...)
+ }
+ return b
+}
+
+// String binds parameter to string variable
+func (b *ValueBinder) String(sourceParam string, dest *string) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ return b
+ }
+ *dest = value
+ return b
+}
+
+// MustString requires parameter value to exist to bind to string variable. Returns error when value does not exist
+func (b *ValueBinder) MustString(sourceParam string, dest *string) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil))
+ return b
+ }
+ *dest = value
+ return b
+}
+
+// Strings binds parameter values to slice of string
+func (b *ValueBinder) Strings(sourceParam string, dest *[]string) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValuesFunc(sourceParam)
+ if value == nil {
+ return b
+ }
+ *dest = value
+ return b
+}
+
+// MustStrings requires parameter values to exist to bind to slice of string variables. Returns error when value does not exist
+func (b *ValueBinder) MustStrings(sourceParam string, dest *[]string) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValuesFunc(sourceParam)
+ if value == nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ return b
+ }
+ *dest = value
+ return b
+}
+
+// BindUnmarshaler binds parameter to destination implementing BindUnmarshaler interface
+func (b *ValueBinder) BindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ tmp := b.ValueFunc(sourceParam)
+ if tmp == "" {
+ return b
+ }
+
+ if err := dest.UnmarshalParam(tmp); err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to BindUnmarshaler interface", err))
+ }
+ return b
+}
+
+// MustBindUnmarshaler requires parameter value to exist to bind to destination implementing BindUnmarshaler interface.
+// Returns error when value does not exist
+func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil))
+ return b
+ }
+
+ if err := dest.UnmarshalParam(value); err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to BindUnmarshaler interface", err))
+ }
+ return b
+}
+
+// JSONUnmarshaler binds parameter to destination implementing json.Unmarshaler interface
+func (b *ValueBinder) JSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ tmp := b.ValueFunc(sourceParam)
+ if tmp == "" {
+ return b
+ }
+
+ if err := dest.UnmarshalJSON([]byte(tmp)); err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err))
+ }
+ return b
+}
+
+// MustJSONUnmarshaler requires parameter value to exist to bind to destination implementing json.Unmarshaler interface.
+// Returns error when value does not exist
+func (b *ValueBinder) MustJSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ tmp := b.ValueFunc(sourceParam)
+ if tmp == "" {
+ b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil))
+ return b
+ }
+
+ if err := dest.UnmarshalJSON([]byte(tmp)); err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err))
+ }
+ return b
+}
+
+// TextUnmarshaler binds parameter to destination implementing encoding.TextUnmarshaler interface
+func (b *ValueBinder) TextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ tmp := b.ValueFunc(sourceParam)
+ if tmp == "" {
+ return b
+ }
+
+ if err := dest.UnmarshalText([]byte(tmp)); err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err))
+ }
+ return b
+}
+
+// MustTextUnmarshaler requires parameter value to exist to bind to destination implementing encoding.TextUnmarshaler interface.
+// Returns error when value does not exist
+func (b *ValueBinder) MustTextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ tmp := b.ValueFunc(sourceParam)
+ if tmp == "" {
+ b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil))
+ return b
+ }
+
+ if err := dest.UnmarshalText([]byte(tmp)); err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err))
+ }
+ return b
+}
+
+// BindWithDelimiter binds parameter to destination by suitable conversion function.
+// Delimiter is used before conversion to split parameter value to separate values
+func (b *ValueBinder) BindWithDelimiter(sourceParam string, dest any, delimiter string) *ValueBinder {
+ return b.bindWithDelimiter(sourceParam, dest, delimiter, false)
+}
+
+// MustBindWithDelimiter requires parameter value to exist to bind destination by suitable conversion function.
+// Delimiter is used before conversion to split parameter value to separate values
+func (b *ValueBinder) MustBindWithDelimiter(sourceParam string, dest any, delimiter string) *ValueBinder {
+ return b.bindWithDelimiter(sourceParam, dest, delimiter, true)
+}
+
+func (b *ValueBinder) bindWithDelimiter(sourceParam string, dest any, delimiter string, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ values := b.ValuesFunc(sourceParam)
+ if len(values) == 0 {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+ tmpValues := make([]string, 0, len(values))
+ for _, v := range values {
+ tmpValues = append(tmpValues, strings.Split(v, delimiter)...)
+ }
+
+ switch d := dest.(type) {
+ case *[]string:
+ *d = tmpValues
+ return b
+ case *[]bool:
+ return b.bools(sourceParam, tmpValues, d)
+ case *[]int64, *[]int32, *[]int16, *[]int8, *[]int:
+ return b.ints(sourceParam, tmpValues, d)
+ case *[]uint64, *[]uint32, *[]uint16, *[]uint8, *[]uint: // *[]byte is same as *[]uint8
+ return b.uints(sourceParam, tmpValues, d)
+ case *[]float64, *[]float32:
+ return b.floats(sourceParam, tmpValues, d)
+ case *[]time.Duration:
+ return b.durations(sourceParam, tmpValues, d)
+ default:
+ // support only cases when destination is slice
+ // does not support time.Time as it needs argument (layout) for parsing or BindUnmarshaler
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "unsupported bind type", nil))
+ return b
+ }
+}
+
+// Int64 binds parameter to int64 variable
+func (b *ValueBinder) Int64(sourceParam string, dest *int64) *ValueBinder {
+ return b.intValue(sourceParam, dest, 64, false)
+}
+
+// MustInt64 requires parameter value to exist to bind to int64 variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt64(sourceParam string, dest *int64) *ValueBinder {
+ return b.intValue(sourceParam, dest, 64, true)
+}
+
+// Int32 binds parameter to int32 variable
+func (b *ValueBinder) Int32(sourceParam string, dest *int32) *ValueBinder {
+ return b.intValue(sourceParam, dest, 32, false)
+}
+
+// MustInt32 requires parameter value to exist to bind to int32 variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt32(sourceParam string, dest *int32) *ValueBinder {
+ return b.intValue(sourceParam, dest, 32, true)
+}
+
+// Int16 binds parameter to int16 variable
+func (b *ValueBinder) Int16(sourceParam string, dest *int16) *ValueBinder {
+ return b.intValue(sourceParam, dest, 16, false)
+}
+
+// MustInt16 requires parameter value to exist to bind to int16 variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt16(sourceParam string, dest *int16) *ValueBinder {
+ return b.intValue(sourceParam, dest, 16, true)
+}
+
+// Int8 binds parameter to int8 variable
+func (b *ValueBinder) Int8(sourceParam string, dest *int8) *ValueBinder {
+ return b.intValue(sourceParam, dest, 8, false)
+}
+
+// MustInt8 requires parameter value to exist to bind to int8 variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt8(sourceParam string, dest *int8) *ValueBinder {
+ return b.intValue(sourceParam, dest, 8, true)
+}
+
+// Int binds parameter to int variable
+func (b *ValueBinder) Int(sourceParam string, dest *int) *ValueBinder {
+ return b.intValue(sourceParam, dest, 0, false)
+}
+
+// MustInt requires parameter value to exist to bind to int variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt(sourceParam string, dest *int) *ValueBinder {
+ return b.intValue(sourceParam, dest, 0, true)
+}
+
+func (b *ValueBinder) intValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+
+ return b.int(sourceParam, value, dest, bitSize)
+}
+
+func (b *ValueBinder) int(sourceParam string, value string, dest any, bitSize int) *ValueBinder {
+ n, err := strconv.ParseInt(value, 10, bitSize)
+ if err != nil {
+ if bitSize == 0 {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to int", err))
+ } else {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to int%v", bitSize), err))
+ }
+ return b
+ }
+
+ switch d := dest.(type) {
+ case *int64:
+ *d = n
+ case *int32:
+ *d = int32(n) // #nosec G115
+ case *int16:
+ *d = int16(n) // #nosec G115
+ case *int8:
+ *d = int8(n) // #nosec G115
+ case *int:
+ *d = int(n)
+ }
+ return b
+}
+
+func (b *ValueBinder) intsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ values := b.ValuesFunc(sourceParam)
+ if len(values) == 0 {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, values, "required field value is empty", nil))
+ }
+ return b
+ }
+ return b.ints(sourceParam, values, dest)
+}
+
+func (b *ValueBinder) ints(sourceParam string, values []string, dest any) *ValueBinder {
+ switch d := dest.(type) {
+ case *[]int64:
+ tmp := make([]int64, len(values))
+ for i, v := range values {
+ b.int(sourceParam, v, &tmp[i], 64)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]int32:
+ tmp := make([]int32, len(values))
+ for i, v := range values {
+ b.int(sourceParam, v, &tmp[i], 32)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]int16:
+ tmp := make([]int16, len(values))
+ for i, v := range values {
+ b.int(sourceParam, v, &tmp[i], 16)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]int8:
+ tmp := make([]int8, len(values))
+ for i, v := range values {
+ b.int(sourceParam, v, &tmp[i], 8)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]int:
+ tmp := make([]int, len(values))
+ for i, v := range values {
+ b.int(sourceParam, v, &tmp[i], 0)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ }
+ return b
+}
+
+// Int64s binds parameter to slice of int64
+func (b *ValueBinder) Int64s(sourceParam string, dest *[]int64) *ValueBinder {
+ return b.intsValue(sourceParam, dest, false)
+}
+
+// MustInt64s requires parameter value to exist to bind to int64 slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt64s(sourceParam string, dest *[]int64) *ValueBinder {
+ return b.intsValue(sourceParam, dest, true)
+}
+
+// Int32s binds parameter to slice of int32
+func (b *ValueBinder) Int32s(sourceParam string, dest *[]int32) *ValueBinder {
+ return b.intsValue(sourceParam, dest, false)
+}
+
+// MustInt32s requires parameter value to exist to bind to int32 slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt32s(sourceParam string, dest *[]int32) *ValueBinder {
+ return b.intsValue(sourceParam, dest, true)
+}
+
+// Int16s binds parameter to slice of int16
+func (b *ValueBinder) Int16s(sourceParam string, dest *[]int16) *ValueBinder {
+ return b.intsValue(sourceParam, dest, false)
+}
+
+// MustInt16s requires parameter value to exist to bind to int16 slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt16s(sourceParam string, dest *[]int16) *ValueBinder {
+ return b.intsValue(sourceParam, dest, true)
+}
+
+// Int8s binds parameter to slice of int8
+func (b *ValueBinder) Int8s(sourceParam string, dest *[]int8) *ValueBinder {
+ return b.intsValue(sourceParam, dest, false)
+}
+
+// MustInt8s requires parameter value to exist to bind to int8 slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustInt8s(sourceParam string, dest *[]int8) *ValueBinder {
+ return b.intsValue(sourceParam, dest, true)
+}
+
+// Ints binds parameter to slice of int
+func (b *ValueBinder) Ints(sourceParam string, dest *[]int) *ValueBinder {
+ return b.intsValue(sourceParam, dest, false)
+}
+
+// MustInts requires parameter value to exist to bind to int slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustInts(sourceParam string, dest *[]int) *ValueBinder {
+ return b.intsValue(sourceParam, dest, true)
+}
+
+// Uint64 binds parameter to uint64 variable
+func (b *ValueBinder) Uint64(sourceParam string, dest *uint64) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 64, false)
+}
+
+// MustUint64 requires parameter value to exist to bind to uint64 variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint64(sourceParam string, dest *uint64) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 64, true)
+}
+
+// Uint32 binds parameter to uint32 variable
+func (b *ValueBinder) Uint32(sourceParam string, dest *uint32) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 32, false)
+}
+
+// MustUint32 requires parameter value to exist to bind to uint32 variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint32(sourceParam string, dest *uint32) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 32, true)
+}
+
+// Uint16 binds parameter to uint16 variable
+func (b *ValueBinder) Uint16(sourceParam string, dest *uint16) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 16, false)
+}
+
+// MustUint16 requires parameter value to exist to bind to uint16 variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint16(sourceParam string, dest *uint16) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 16, true)
+}
+
+// Uint8 binds parameter to uint8 variable
+func (b *ValueBinder) Uint8(sourceParam string, dest *uint8) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 8, false)
+}
+
+// MustUint8 requires parameter value to exist to bind to uint8 variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint8(sourceParam string, dest *uint8) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 8, true)
+}
+
+// Byte binds parameter to byte variable
+func (b *ValueBinder) Byte(sourceParam string, dest *byte) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 8, false)
+}
+
+// MustByte requires parameter value to exist to bind to byte variable. Returns error when value does not exist
+func (b *ValueBinder) MustByte(sourceParam string, dest *byte) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 8, true)
+}
+
+// Uint binds parameter to uint variable
+func (b *ValueBinder) Uint(sourceParam string, dest *uint) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 0, false)
+}
+
+// MustUint requires parameter value to exist to bind to uint variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint(sourceParam string, dest *uint) *ValueBinder {
+ return b.uintValue(sourceParam, dest, 0, true)
+}
+
+func (b *ValueBinder) uintValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+
+ return b.uint(sourceParam, value, dest, bitSize)
+}
+
+func (b *ValueBinder) uint(sourceParam string, value string, dest any, bitSize int) *ValueBinder {
+ n, err := strconv.ParseUint(value, 10, bitSize)
+ if err != nil {
+ if bitSize == 0 {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to uint", err))
+ } else {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to uint%v", bitSize), err))
+ }
+ return b
+ }
+
+ switch d := dest.(type) {
+ case *uint64:
+ *d = n
+ case *uint32:
+ *d = uint32(n) // #nosec G115
+ case *uint16:
+ *d = uint16(n) // #nosec G115
+ case *uint8: // byte is alias to uint8
+ *d = uint8(n) // #nosec G115
+ case *uint:
+ *d = uint(n) // #nosec G115
+ }
+ return b
+}
+
+func (b *ValueBinder) uintsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ values := b.ValuesFunc(sourceParam)
+ if len(values) == 0 {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, values, "required field value is empty", nil))
+ }
+ return b
+ }
+ return b.uints(sourceParam, values, dest)
+}
+
+func (b *ValueBinder) uints(sourceParam string, values []string, dest any) *ValueBinder {
+ switch d := dest.(type) {
+ case *[]uint64:
+ tmp := make([]uint64, len(values))
+ for i, v := range values {
+ b.uint(sourceParam, v, &tmp[i], 64)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]uint32:
+ tmp := make([]uint32, len(values))
+ for i, v := range values {
+ b.uint(sourceParam, v, &tmp[i], 32)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]uint16:
+ tmp := make([]uint16, len(values))
+ for i, v := range values {
+ b.uint(sourceParam, v, &tmp[i], 16)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]uint8: // byte is alias to uint8
+ tmp := make([]uint8, len(values))
+ for i, v := range values {
+ b.uint(sourceParam, v, &tmp[i], 8)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]uint:
+ tmp := make([]uint, len(values))
+ for i, v := range values {
+ b.uint(sourceParam, v, &tmp[i], 0)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ }
+ return b
+}
+
+// Uint64s binds parameter to slice of uint64
+func (b *ValueBinder) Uint64s(sourceParam string, dest *[]uint64) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, false)
+}
+
+// MustUint64s requires parameter value to exist to bind to uint64 slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint64s(sourceParam string, dest *[]uint64) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, true)
+}
+
+// Uint32s binds parameter to slice of uint32
+func (b *ValueBinder) Uint32s(sourceParam string, dest *[]uint32) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, false)
+}
+
+// MustUint32s requires parameter value to exist to bind to uint32 slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint32s(sourceParam string, dest *[]uint32) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, true)
+}
+
+// Uint16s binds parameter to slice of uint16
+func (b *ValueBinder) Uint16s(sourceParam string, dest *[]uint16) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, false)
+}
+
+// MustUint16s requires parameter value to exist to bind to uint16 slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint16s(sourceParam string, dest *[]uint16) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, true)
+}
+
+// Uint8s binds parameter to slice of uint8
+func (b *ValueBinder) Uint8s(sourceParam string, dest *[]uint8) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, false)
+}
+
+// MustUint8s requires parameter value to exist to bind to uint8 slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustUint8s(sourceParam string, dest *[]uint8) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, true)
+}
+
+// Uints binds parameter to slice of uint
+func (b *ValueBinder) Uints(sourceParam string, dest *[]uint) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, false)
+}
+
+// MustUints requires parameter value to exist to bind to uint slice variable. Returns error when value does not exist
+func (b *ValueBinder) MustUints(sourceParam string, dest *[]uint) *ValueBinder {
+ return b.uintsValue(sourceParam, dest, true)
+}
+
+// Bool binds parameter to bool variable
+func (b *ValueBinder) Bool(sourceParam string, dest *bool) *ValueBinder {
+ return b.boolValue(sourceParam, dest, false)
+}
+
+// MustBool requires parameter value to exist to bind to bool variable. Returns error when value does not exist
+func (b *ValueBinder) MustBool(sourceParam string, dest *bool) *ValueBinder {
+ return b.boolValue(sourceParam, dest, true)
+}
+
+func (b *ValueBinder) boolValue(sourceParam string, dest *bool, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+ return b.bool(sourceParam, value, dest)
+}
+
+func (b *ValueBinder) bool(sourceParam string, value string, dest *bool) *ValueBinder {
+ n, err := strconv.ParseBool(value)
+ if err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to bool", err))
+ return b
+ }
+
+ *dest = n
+ return b
+}
+
+func (b *ValueBinder) boolsValue(sourceParam string, dest *[]bool, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ values := b.ValuesFunc(sourceParam)
+ if len(values) == 0 {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+ return b.bools(sourceParam, values, dest)
+}
+
+func (b *ValueBinder) bools(sourceParam string, values []string, dest *[]bool) *ValueBinder {
+ tmp := make([]bool, len(values))
+ for i, v := range values {
+ b.bool(sourceParam, v, &tmp[i])
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *dest = tmp
+ }
+ return b
+}
+
+// Bools binds parameter values to slice of bool variables
+func (b *ValueBinder) Bools(sourceParam string, dest *[]bool) *ValueBinder {
+ return b.boolsValue(sourceParam, dest, false)
+}
+
+// MustBools requires parameter values to exist to bind to slice of bool variables. Returns error when values does not exist
+func (b *ValueBinder) MustBools(sourceParam string, dest *[]bool) *ValueBinder {
+ return b.boolsValue(sourceParam, dest, true)
+}
+
+// Float64 binds parameter to float64 variable
+func (b *ValueBinder) Float64(sourceParam string, dest *float64) *ValueBinder {
+ return b.floatValue(sourceParam, dest, 64, false)
+}
+
+// MustFloat64 requires parameter value to exist to bind to float64 variable. Returns error when value does not exist
+func (b *ValueBinder) MustFloat64(sourceParam string, dest *float64) *ValueBinder {
+ return b.floatValue(sourceParam, dest, 64, true)
+}
+
+// Float32 binds parameter to float32 variable
+func (b *ValueBinder) Float32(sourceParam string, dest *float32) *ValueBinder {
+ return b.floatValue(sourceParam, dest, 32, false)
+}
+
+// MustFloat32 requires parameter value to exist to bind to float32 variable. Returns error when value does not exist
+func (b *ValueBinder) MustFloat32(sourceParam string, dest *float32) *ValueBinder {
+ return b.floatValue(sourceParam, dest, 32, true)
+}
+
+func (b *ValueBinder) floatValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+
+ return b.float(sourceParam, value, dest, bitSize)
+}
+
+func (b *ValueBinder) float(sourceParam string, value string, dest any, bitSize int) *ValueBinder {
+ n, err := strconv.ParseFloat(value, bitSize)
+ if err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to float%v", bitSize), err))
+ return b
+ }
+
+ switch d := dest.(type) {
+ case *float64:
+ *d = n
+ case *float32:
+ *d = float32(n)
+ }
+ return b
+}
+
+func (b *ValueBinder) floatsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ values := b.ValuesFunc(sourceParam)
+ if len(values) == 0 {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+ return b.floats(sourceParam, values, dest)
+}
+
+func (b *ValueBinder) floats(sourceParam string, values []string, dest any) *ValueBinder {
+ switch d := dest.(type) {
+ case *[]float64:
+ tmp := make([]float64, len(values))
+ for i, v := range values {
+ b.float(sourceParam, v, &tmp[i], 64)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ case *[]float32:
+ tmp := make([]float32, len(values))
+ for i, v := range values {
+ b.float(sourceParam, v, &tmp[i], 32)
+ if b.failFast && b.errors != nil {
+ return b
+ }
+ }
+ if b.errors == nil {
+ *d = tmp
+ }
+ }
+ return b
+}
+
+// Float64s binds parameter values to slice of float64 variables
+func (b *ValueBinder) Float64s(sourceParam string, dest *[]float64) *ValueBinder {
+ return b.floatsValue(sourceParam, dest, false)
+}
+
+// MustFloat64s requires parameter values to exist to bind to slice of float64 variables. Returns error when values does not exist
+func (b *ValueBinder) MustFloat64s(sourceParam string, dest *[]float64) *ValueBinder {
+ return b.floatsValue(sourceParam, dest, true)
+}
+
+// Float32s binds parameter values to slice of float32 variables
+func (b *ValueBinder) Float32s(sourceParam string, dest *[]float32) *ValueBinder {
+ return b.floatsValue(sourceParam, dest, false)
+}
+
+// MustFloat32s requires parameter values to exist to bind to slice of float32 variables. Returns error when values does not exist
+func (b *ValueBinder) MustFloat32s(sourceParam string, dest *[]float32) *ValueBinder {
+ return b.floatsValue(sourceParam, dest, true)
+}
+
+// Time binds parameter to time.Time variable
+func (b *ValueBinder) Time(sourceParam string, dest *time.Time, layout string) *ValueBinder {
+ return b.time(sourceParam, dest, layout, false)
+}
+
+// MustTime requires parameter value to exist to bind to time.Time variable. Returns error when value does not exist
+func (b *ValueBinder) MustTime(sourceParam string, dest *time.Time, layout string) *ValueBinder {
+ return b.time(sourceParam, dest, layout, true)
+}
+
+func (b *ValueBinder) time(sourceParam string, dest *time.Time, layout string, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil))
+ }
+ return b
+ }
+ t, err := time.Parse(layout, value)
+ if err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Time", err))
+ return b
+ }
+ *dest = t
+ return b
+}
+
+// Times binds parameter values to slice of time.Time variables
+func (b *ValueBinder) Times(sourceParam string, dest *[]time.Time, layout string) *ValueBinder {
+ return b.times(sourceParam, dest, layout, false)
+}
+
+// MustTimes requires parameter values to exist to bind to slice of time.Time variables. Returns error when values does not exist
+func (b *ValueBinder) MustTimes(sourceParam string, dest *[]time.Time, layout string) *ValueBinder {
+ return b.times(sourceParam, dest, layout, true)
+}
+
+func (b *ValueBinder) times(sourceParam string, dest *[]time.Time, layout string, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ values := b.ValuesFunc(sourceParam)
+ if len(values) == 0 {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+
+ tmp := make([]time.Time, len(values))
+ for i, v := range values {
+ t, err := time.Parse(layout, v)
+ if err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{v}, "failed to bind field value to Time", err))
+ if b.failFast {
+ return b
+ }
+ continue
+ }
+ tmp[i] = t
+ }
+ if b.errors == nil {
+ *dest = tmp
+ }
+ return b
+}
+
+// Duration binds parameter to time.Duration variable
+func (b *ValueBinder) Duration(sourceParam string, dest *time.Duration) *ValueBinder {
+ return b.duration(sourceParam, dest, false)
+}
+
+// MustDuration requires parameter value to exist to bind to time.Duration variable. Returns error when value does not exist
+func (b *ValueBinder) MustDuration(sourceParam string, dest *time.Duration) *ValueBinder {
+ return b.duration(sourceParam, dest, true)
+}
+
+func (b *ValueBinder) duration(sourceParam string, dest *time.Duration, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil))
+ }
+ return b
+ }
+ t, err := time.ParseDuration(value)
+ if err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Duration", err))
+ return b
+ }
+ *dest = t
+ return b
+}
+
+// Durations binds parameter values to slice of time.Duration variables
+func (b *ValueBinder) Durations(sourceParam string, dest *[]time.Duration) *ValueBinder {
+ return b.durationsValue(sourceParam, dest, false)
+}
+
+// MustDurations requires parameter values to exist to bind to slice of time.Duration variables. Returns error when values does not exist
+func (b *ValueBinder) MustDurations(sourceParam string, dest *[]time.Duration) *ValueBinder {
+ return b.durationsValue(sourceParam, dest, true)
+}
+
+func (b *ValueBinder) durationsValue(sourceParam string, dest *[]time.Duration, valueMustExist bool) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ values := b.ValuesFunc(sourceParam)
+ if len(values) == 0 {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil))
+ }
+ return b
+ }
+ return b.durations(sourceParam, values, dest)
+}
+
+func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]time.Duration) *ValueBinder {
+ tmp := make([]time.Duration, len(values))
+ for i, v := range values {
+ t, err := time.ParseDuration(v)
+ if err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{v}, "failed to bind field value to Duration", err))
+ if b.failFast {
+ return b
+ }
+ continue
+ }
+ tmp[i] = t
+ }
+ if b.errors == nil {
+ *dest = tmp
+ }
+ return b
+}
+
+// UnixTime binds parameter to time.Time variable (in local Time corresponding to the given Unix time).
+//
+// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00
+//
+// Note:
+// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
+func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder {
+ return b.unixTime(sourceParam, dest, false, time.Second)
+}
+
+// MustUnixTime requires parameter value to exist to bind to time.Duration variable (in local time corresponding
+// to the given Unix time). Returns error when value does not exist.
+//
+// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00
+//
+// Note:
+// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
+func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder {
+ return b.unixTime(sourceParam, dest, true, time.Second)
+}
+
+// UnixTimeMilli binds parameter to time.Time variable (in local time corresponding to the given Unix time in millisecond precision).
+//
+// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
+//
+// Note:
+// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
+func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
+ return b.unixTime(sourceParam, dest, false, time.Millisecond)
+}
+
+// MustUnixTimeMilli requires parameter value to exist to bind to time.Duration variable (in local time corresponding
+// to the given Unix time in millisecond precision). Returns error when value does not exist.
+//
+// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00
+//
+// Note:
+// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
+func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder {
+ return b.unixTime(sourceParam, dest, true, time.Millisecond)
+}
+
+// UnixTimeNano binds parameter to time.Time variable (in local time corresponding to the given Unix time in nanosecond precision).
+//
+// Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00
+// Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00
+// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00
+//
+// Note:
+// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
+// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
+func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
+ return b.unixTime(sourceParam, dest, false, time.Nanosecond)
+}
+
+// MustUnixTimeNano requires parameter value to exist to bind to time.Duration variable (in local Time corresponding
+// to the given Unix time value in nano second precision). Returns error when value does not exist.
+//
+// Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00
+// Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00
+// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00
+//
+// Note:
+// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal
+// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example.
+func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder {
+ return b.unixTime(sourceParam, dest, true, time.Nanosecond)
+}
+
+func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, precision time.Duration) *ValueBinder {
+ if b.failFast && b.errors != nil {
+ return b
+ }
+
+ value := b.ValueFunc(sourceParam)
+ if value == "" {
+ if valueMustExist {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil))
+ }
+ return b
+ }
+
+ n, err := strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Time", err))
+ return b
+ }
+
+ switch precision {
+ case time.Second:
+ *dest = time.Unix(n, 0)
+ case time.Millisecond:
+ *dest = time.UnixMilli(n)
+ case time.Nanosecond:
+ *dest = time.Unix(0, n)
+ }
+ return b
+}
diff --git a/binder_external_test.go b/binder_external_test.go
new file mode 100644
index 000000000..d83c891b3
--- /dev/null
+++ b/binder_external_test.go
@@ -0,0 +1,134 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+// run tests as external package to get real feel for API
+package echo_test
+
+import (
+ "encoding/base64"
+ "fmt"
+ "log"
+ "net/http"
+ "net/http/httptest"
+
+ "github.com/labstack/echo/v5"
+)
+
+func ExampleValueBinder_BindErrors() {
+ // example route function that binds query params to different destinations and returns all bind errors in one go
+ routeFunc := func(c *echo.Context) error {
+ var opts struct {
+ IDs []int64
+ Active bool
+ }
+ length := int64(50) // default length is 50
+
+ b := echo.QueryParamsBinder(c)
+
+ errs := b.Int64("length", &length).
+ Int64s("ids", &opts.IDs).
+ Bool("active", &opts.Active).
+ BindErrors() // returns all errors
+ if errs != nil {
+ for _, err := range errs {
+ bErr := err.(*echo.BindingError)
+ log.Printf("in case you want to access what field: %s values: %v failed", bErr.Field, bErr.Values)
+ }
+ return fmt.Errorf("%v fields failed to bind", len(errs))
+ }
+ fmt.Printf("active = %v, length = %v, ids = %v", opts.Active, length, opts.IDs)
+
+ return c.JSON(http.StatusOK, opts)
+ }
+
+ e := echo.New()
+ c := e.NewContext(
+ httptest.NewRequest(http.MethodGet, "/api/endpoint?active=true&length=25&ids=1&ids=2&ids=3", nil),
+ httptest.NewRecorder(),
+ )
+
+ _ = routeFunc(c)
+
+ // Output: active = true, length = 25, ids = [1 2 3]
+}
+
+func ExampleValueBinder_BindError() {
+ // example route function that binds query params to different destinations and stops binding on first bind error
+ failFastRouteFunc := func(c *echo.Context) error {
+ var opts struct {
+ IDs []int64
+ Active bool
+ }
+ length := int64(50) // default length is 50
+
+ // create binder that stops binding at first error
+ b := echo.QueryParamsBinder(c)
+
+ err := b.Int64("length", &length).
+ Int64s("ids", &opts.IDs).
+ Bool("active", &opts.Active).
+ BindError() // returns first binding error
+ if err != nil {
+ bErr := err.(*echo.BindingError)
+ return fmt.Errorf("my own custom error for field: %s values: %v", bErr.Field, bErr.Values)
+ }
+ fmt.Printf("active = %v, length = %v, ids = %v\n", opts.Active, length, opts.IDs)
+
+ return c.JSON(http.StatusOK, opts)
+ }
+
+ e := echo.New()
+ c := e.NewContext(
+ httptest.NewRequest(http.MethodGet, "/api/endpoint?active=true&length=25&ids=1&ids=2&ids=3", nil),
+ httptest.NewRecorder(),
+ )
+
+ _ = failFastRouteFunc(c)
+
+ // Output: active = true, length = 25, ids = [1 2 3]
+}
+
+func ExampleValueBinder_CustomFunc() {
+ // example route function that binds query params using custom function closure
+ routeFunc := func(c *echo.Context) error {
+ length := int64(50) // default length is 50
+ var binary []byte
+
+ b := echo.QueryParamsBinder(c)
+ errs := b.Int64("length", &length).
+ CustomFunc("base64", func(values []string) []error {
+ if len(values) == 0 {
+ return nil
+ }
+ decoded, err := base64.URLEncoding.DecodeString(values[0])
+ if err != nil {
+ // in this example we use only first param value but url could contain multiple params in reality and
+ // therefore in theory produce multiple binding errors
+ return []error{echo.NewBindingError("base64", values[0:1], "failed to decode base64", err)}
+ }
+ binary = decoded
+ return nil
+ }).
+ BindErrors() // returns all errors
+
+ if errs != nil {
+ for _, err := range errs {
+ bErr := err.(*echo.BindingError)
+ log.Printf("in case you want to access what field: %s values: %v failed", bErr.Field, bErr.Values)
+ }
+ return fmt.Errorf("%v fields failed to bind", len(errs))
+ }
+ fmt.Printf("length = %v, base64 = %s", length, binary)
+
+ return c.JSON(http.StatusOK, "ok")
+ }
+
+ e := echo.New()
+ c := e.NewContext(
+ httptest.NewRequest(http.MethodGet, "/api/endpoint?length=25&base64=SGVsbG8gV29ybGQ%3D", nil),
+ httptest.NewRecorder(),
+ )
+ _ = routeFunc(c)
+
+ // Output: length = 25, base64 = Hello World
+}
diff --git a/binder_generic.go b/binder_generic.go
new file mode 100644
index 000000000..0c0eb9089
--- /dev/null
+++ b/binder_generic.go
@@ -0,0 +1,563 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "encoding"
+ "encoding/json"
+ "fmt"
+ "strconv"
+ "time"
+)
+
+// TimeLayout specifies the format for parsing time values in request parameters.
+// It can be a standard Go time layout string or one of the special Unix time layouts.
+type TimeLayout string
+
+// TimeOpts is options for parsing time.Time values
+type TimeOpts struct {
+ // Layout specifies the format for parsing time values in request parameters.
+ // It can be a standard Go time layout string or one of the special Unix time layouts.
+ //
+ // Parsing layout defaults to: echo.TimeLayout(time.RFC3339Nano)
+ // - To convert to custom layout use `echo.TimeLayout("2006-01-02")`
+ // - To convert unix timestamp (integer) to time.Time use `echo.TimeLayoutUnixTime`
+ // - To convert unix timestamp in milliseconds to time.Time use `echo.TimeLayoutUnixTimeMilli`
+ // - To convert unix timestamp in nanoseconds to time.Time use `echo.TimeLayoutUnixTimeNano`
+ Layout TimeLayout
+
+ // ParseInLocation is location used with time.ParseInLocation for layout that do not contain
+ // timezone information to set output time in given location.
+ // Defaults to time.UTC
+ ParseInLocation *time.Location
+
+ // ToInLocation is location to which parsed time is converted to after parsing.
+ // The parsed time will be converted using time.In(ToInLocation).
+ // Defaults to time.UTC
+ ToInLocation *time.Location
+}
+
+// TimeLayout constants for parsing Unix timestamps in different precisions.
+const (
+ TimeLayoutUnixTime = TimeLayout("UnixTime") // Unix timestamp in seconds
+ TimeLayoutUnixTimeMilli = TimeLayout("UnixTimeMilli") // Unix timestamp in milliseconds
+ TimeLayoutUnixTimeNano = TimeLayout("UnixTimeNano") // Unix timestamp in nanoseconds
+)
+
+// PathParam extracts and parses a path parameter from the context by name.
+// It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found.
+//
+// Empty String Handling:
+// If the parameter exists but has an empty value, the zero value of type T is returned
+// with no error. For example, a path parameter with value "" returns (0, nil) for int types.
+// This differs from standard library behavior where parsing empty strings returns errors.
+// To treat empty values as errors, validate the result separately or check the raw value.
+//
+// See ParseValue for supported types and options
+func PathParam[T any](c *Context, paramName string, opts ...any) (T, error) {
+ for _, pv := range c.PathValues() {
+ if pv.Name == paramName {
+ v, err := ParseValue[T](pv.Value, opts...)
+ if err != nil {
+ return v, NewBindingError(paramName, []string{pv.Value}, "path value", err)
+ }
+ return v, nil
+ }
+ }
+ var zero T
+ return zero, ErrNonExistentKey
+}
+
+// PathParamOr extracts and parses a path parameter from the context by name.
+// Returns defaultValue if the parameter is not found or has an empty value.
+// Returns an error only if parsing fails (e.g., "abc" for int type).
+//
+// Example:
+// id, err := echo.PathParamOr[int](c, "id", 0)
+// // If "id" is missing: returns (0, nil)
+// // If "id" is "123": returns (123, nil)
+// // If "id" is "abc": returns (0, BindingError)
+//
+// See ParseValue for supported types and options
+func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...any) (T, error) {
+ for _, pv := range c.PathValues() {
+ if pv.Name == paramName {
+ v, err := ParseValueOr[T](pv.Value, defaultValue, opts...)
+ if err != nil {
+ return v, NewBindingError(paramName, []string{pv.Value}, "path value", err)
+ }
+ return v, nil
+ }
+ }
+ return defaultValue, nil
+}
+
+// QueryParam extracts and parses a single query parameter from the request by key.
+// It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found.
+//
+// Empty String Handling:
+// If the parameter exists but has an empty value (?key=), the zero value of type T is returned
+// with no error. For example, "?count=" returns (0, nil) for int types.
+// This differs from standard library behavior where parsing empty strings returns errors.
+// To treat empty values as errors, validate the result separately or check the raw value.
+//
+// Behavior Summary:
+// - Missing key (?other=value): returns (zero, ErrNonExistentKey)
+// - Empty value (?key=): returns (zero, nil)
+// - Invalid value (?key=abc for int): returns (zero, BindingError)
+//
+// See ParseValue for supported types and options
+func QueryParam[T any](c *Context, key string, opts ...any) (T, error) {
+ values, ok := c.QueryParams()[key]
+ if !ok {
+ var zero T
+ return zero, ErrNonExistentKey
+ }
+ if len(values) == 0 {
+ var zero T
+ return zero, nil
+ }
+ value := values[0]
+ v, err := ParseValue[T](value, opts...)
+ if err != nil {
+ return v, NewBindingError(key, []string{value}, "query param", err)
+ }
+ return v, nil
+}
+
+// QueryParamOr extracts and parses a single query parameter from the request by key.
+// Returns defaultValue if the parameter is not found or has an empty value.
+// Returns an error only if parsing fails (e.g., "abc" for int type).
+//
+// Example:
+// page, err := echo.QueryParamOr[int](c, "page", 1)
+// // If "page" is missing: returns (1, nil)
+// // If "page" is "5": returns (5, nil)
+// // If "page" is "abc": returns (1, BindingError)
+//
+// See ParseValue for supported types and options
+func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) {
+ values, ok := c.QueryParams()[key]
+ if !ok {
+ return defaultValue, nil
+ }
+ if len(values) == 0 {
+ return defaultValue, nil
+ }
+ value := values[0]
+ v, err := ParseValueOr[T](value, defaultValue, opts...)
+ if err != nil {
+ return v, NewBindingError(key, []string{value}, "query param", err)
+ }
+ return v, nil
+}
+
+// QueryParams extracts and parses all values for a query parameter key as a slice.
+// It returns the typed slice and an error if binding any value fails. Returns ErrNonExistentKey if parameter not found.
+//
+// See ParseValues for supported types and options
+func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error) {
+ values, ok := c.QueryParams()[key]
+ if !ok {
+ return nil, ErrNonExistentKey
+ }
+
+ result, err := ParseValues[T](values, opts...)
+ if err != nil {
+ return nil, NewBindingError(key, values, "query params", err)
+ }
+ return result, nil
+}
+
+// QueryParamsOr extracts and parses all values for a query parameter key as a slice.
+// Returns defaultValue if the parameter is not found.
+// Returns an error only if parsing any value fails.
+//
+// Example:
+// ids, err := echo.QueryParamsOr[int](c, "ids", []int{})
+// // If "ids" is missing: returns ([], nil)
+// // If "ids" is "1&ids=2": returns ([1, 2], nil)
+// // If "ids" contains "abc": returns ([], BindingError)
+//
+// See ParseValues for supported types and options
+func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) {
+ values, ok := c.QueryParams()[key]
+ if !ok {
+ return defaultValue, nil
+ }
+
+ result, err := ParseValuesOr[T](values, defaultValue, opts...)
+ if err != nil {
+ return nil, NewBindingError(key, values, "query params", err)
+ }
+ return result, nil
+}
+
+// FormValue extracts and parses a single form value from the request by key.
+// It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found.
+//
+// Empty String Handling:
+// If the form field exists but has an empty value, the zero value of type T is returned
+// with no error. For example, an empty form field returns (0, nil) for int types.
+// This differs from standard library behavior where parsing empty strings returns errors.
+// To treat empty values as errors, validate the result separately or check the raw value.
+//
+// See ParseValue for supported types and options
+func FormValue[T any](c *Context, key string, opts ...any) (T, error) {
+ formValues, err := c.FormValues()
+ if err != nil {
+ var zero T
+ return zero, fmt.Errorf("failed to parse form value, key: %s, err: %w", key, err)
+ }
+ values, ok := formValues[key]
+ if !ok {
+ var zero T
+ return zero, ErrNonExistentKey
+ }
+ if len(values) == 0 {
+ var zero T
+ return zero, nil
+ }
+ value := values[0]
+ v, err := ParseValue[T](value, opts...)
+ if err != nil {
+ return v, NewBindingError(key, []string{value}, "form value", err)
+ }
+ return v, nil
+}
+
+// FormValueOr extracts and parses a single form value from the request by key.
+// Returns defaultValue if the parameter is not found or has an empty value.
+// Returns an error only if parsing fails or form parsing errors occur.
+//
+// Example:
+// limit, err := echo.FormValueOr[int](c, "limit", 100)
+// // If "limit" is missing: returns (100, nil)
+// // If "limit" is "50": returns (50, nil)
+// // If "limit" is "abc": returns (100, BindingError)
+//
+// See ParseValue for supported types and options
+func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) {
+ formValues, err := c.FormValues()
+ if err != nil {
+ var zero T
+ return zero, fmt.Errorf("failed to parse form value, key: %s, err: %w", key, err)
+ }
+ values, ok := formValues[key]
+ if !ok {
+ return defaultValue, nil
+ }
+ if len(values) == 0 {
+ return defaultValue, nil
+ }
+ value := values[0]
+ v, err := ParseValueOr[T](value, defaultValue, opts...)
+ if err != nil {
+ return v, NewBindingError(key, []string{value}, "form value", err)
+ }
+ return v, nil
+}
+
+// FormValues extracts and parses all values for a form values key as a slice.
+// It returns the typed slice and an error if binding any value fails. Returns ErrNonExistentKey if parameter not found.
+//
+// See ParseValues for supported types and options
+func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) {
+ formValues, err := c.FormValues()
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse form values, key: %s, err: %w", key, err)
+ }
+ values, ok := formValues[key]
+ if !ok {
+ return nil, ErrNonExistentKey
+ }
+ result, err := ParseValues[T](values, opts...)
+ if err != nil {
+ return nil, NewBindingError(key, values, "form values", err)
+ }
+ return result, nil
+}
+
+// FormValuesOr extracts and parses all values for a form values key as a slice.
+// Returns defaultValue if the parameter is not found.
+// Returns an error only if parsing any value fails or form parsing errors occur.
+//
+// Example:
+// tags, err := echo.FormValuesOr[string](c, "tags", []string{})
+// // If "tags" is missing: returns ([], nil)
+// // If form parsing fails: returns (nil, error)
+//
+// See ParseValues for supported types and options
+func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) {
+ formValues, err := c.FormValues()
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse form values, key: %s, err: %w", key, err)
+ }
+ values, ok := formValues[key]
+ if !ok {
+ return defaultValue, nil
+ }
+ result, err := ParseValuesOr[T](values, defaultValue, opts...)
+ if err != nil {
+ return nil, NewBindingError(key, values, "form values", err)
+ }
+ return result, nil
+}
+
+// ParseValues parses value to generic type slice. Same types are supported as ParseValue
+// function but the result type is slice instead of scalar value.
+//
+// See ParseValue for supported types and options
+func ParseValues[T any](values []string, opts ...any) ([]T, error) {
+ var zero []T
+ return ParseValuesOr(values, zero, opts...)
+}
+
+// ParseValuesOr parses value to generic type slice, when value is empty defaultValue is returned.
+// Same types are supported as ParseValue function but the result type is slice instead of scalar value.
+//
+// See ParseValue for supported types and options
+func ParseValuesOr[T any](values []string, defaultValue []T, opts ...any) ([]T, error) {
+ if len(values) == 0 {
+ return defaultValue, nil
+ }
+ result := make([]T, 0, len(values))
+ for _, v := range values {
+ tmp, err := ParseValue[T](v, opts...)
+ if err != nil {
+ return nil, err
+ }
+ result = append(result, tmp)
+ }
+ return result, nil
+}
+
+// ParseValue parses value to generic type
+//
+// Types that are supported:
+// - bool
+// - float32
+// - float64
+// - int
+// - int8
+// - int16
+// - int32
+// - int64
+// - uint
+// - uint8/byte
+// - uint16
+// - uint32
+// - uint64
+// - string
+// - echo.BindUnmarshaler interface
+// - encoding.TextUnmarshaler interface
+// - json.Unmarshaler interface
+// - time.Duration
+// - time.Time use echo.TimeOpts or echo.TimeLayout to set time parsing configuration
+func ParseValue[T any](value string, opts ...any) (T, error) {
+ var zero T
+ return ParseValueOr(value, zero, opts...)
+}
+
+// ParseValueOr parses value to generic type, when value is empty defaultValue is returned.
+//
+// Types that are supported:
+// - bool
+// - float32
+// - float64
+// - int
+// - int8
+// - int16
+// - int32
+// - int64
+// - uint
+// - uint8/byte
+// - uint16
+// - uint32
+// - uint64
+// - string
+// - echo.BindUnmarshaler interface
+// - encoding.TextUnmarshaler interface
+// - json.Unmarshaler interface
+// - time.Duration
+// - time.Time use echo.TimeOpts or echo.TimeLayout to set time parsing configuration
+func ParseValueOr[T any](value string, defaultValue T, opts ...any) (T, error) {
+ if len(value) == 0 {
+ return defaultValue, nil
+ }
+ var tmp T
+ if err := bindValue(value, &tmp, opts...); err != nil {
+ var zero T
+ return zero, fmt.Errorf("failed to parse value, err: %w", err)
+ }
+ return tmp, nil
+}
+
+func bindValue(value string, dest any, opts ...any) error {
+ // NOTE: if this function is ever made public the dest should be checked for nil
+ // values when dealing with interfaces
+ if len(opts) > 0 {
+ if _, isTime := dest.(*time.Time); !isTime {
+ return fmt.Errorf("options are only supported for time.Time, got %T", dest)
+ }
+ }
+
+ switch d := dest.(type) {
+ case *bool:
+ n, err := strconv.ParseBool(value)
+ if err != nil {
+ return err
+ }
+ *d = n
+ case *float32:
+ n, err := strconv.ParseFloat(value, 32)
+ if err != nil {
+ return err
+ }
+ *d = float32(n)
+ case *float64:
+ n, err := strconv.ParseFloat(value, 64)
+ if err != nil {
+ return err
+ }
+ *d = n
+ case *int:
+ n, err := strconv.ParseInt(value, 10, 0)
+ if err != nil {
+ return err
+ }
+ *d = int(n)
+ case *int8:
+ n, err := strconv.ParseInt(value, 10, 8)
+ if err != nil {
+ return err
+ }
+ *d = int8(n)
+ case *int16:
+ n, err := strconv.ParseInt(value, 10, 16)
+ if err != nil {
+ return err
+ }
+ *d = int16(n)
+ case *int32:
+ n, err := strconv.ParseInt(value, 10, 32)
+ if err != nil {
+ return err
+ }
+ *d = int32(n)
+ case *int64:
+ n, err := strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ return err
+ }
+ *d = n
+ case *uint:
+ n, err := strconv.ParseUint(value, 10, 0)
+ if err != nil {
+ return err
+ }
+ *d = uint(n)
+ case *uint8:
+ n, err := strconv.ParseUint(value, 10, 8)
+ if err != nil {
+ return err
+ }
+ *d = uint8(n)
+ case *uint16:
+ n, err := strconv.ParseUint(value, 10, 16)
+ if err != nil {
+ return err
+ }
+ *d = uint16(n)
+ case *uint32:
+ n, err := strconv.ParseUint(value, 10, 32)
+ if err != nil {
+ return err
+ }
+ *d = uint32(n)
+ case *uint64:
+ n, err := strconv.ParseUint(value, 10, 64)
+ if err != nil {
+ return err
+ }
+ *d = n
+ case *string:
+ *d = value
+ case *time.Duration:
+ t, err := time.ParseDuration(value)
+ if err != nil {
+ return err
+ }
+ *d = t
+ case *time.Time:
+ to := TimeOpts{
+ Layout: TimeLayout(time.RFC3339Nano),
+ ParseInLocation: time.UTC,
+ ToInLocation: time.UTC,
+ }
+ for _, o := range opts {
+ switch v := o.(type) {
+ case TimeOpts:
+ if v.Layout != "" {
+ to.Layout = v.Layout
+ }
+ if v.ParseInLocation != nil {
+ to.ParseInLocation = v.ParseInLocation
+ }
+ if v.ToInLocation != nil {
+ to.ToInLocation = v.ToInLocation
+ }
+ case TimeLayout:
+ to.Layout = v
+ }
+ }
+ var t time.Time
+ var err error
+ switch to.Layout {
+ case TimeLayoutUnixTime:
+ n, err := strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ return err
+ }
+ t = time.Unix(n, 0)
+ case TimeLayoutUnixTimeMilli:
+ n, err := strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ return err
+ }
+ t = time.UnixMilli(n)
+ case TimeLayoutUnixTimeNano:
+ n, err := strconv.ParseInt(value, 10, 64)
+ if err != nil {
+ return err
+ }
+ t = time.Unix(0, n)
+ default:
+ if to.ParseInLocation != nil {
+ t, err = time.ParseInLocation(string(to.Layout), value, to.ParseInLocation)
+ } else {
+ t, err = time.Parse(string(to.Layout), value)
+ }
+ if err != nil {
+ return err
+ }
+ }
+ *d = t.In(to.ToInLocation)
+ case BindUnmarshaler:
+ if err := d.UnmarshalParam(value); err != nil {
+ return err
+ }
+ case encoding.TextUnmarshaler:
+ if err := d.UnmarshalText([]byte(value)); err != nil {
+ return err
+ }
+ case json.Unmarshaler:
+ if err := d.UnmarshalJSON([]byte(value)); err != nil {
+ return err
+ }
+ default:
+ return fmt.Errorf("unsupported value type: %T", dest)
+ }
+ return nil
+}
diff --git a/binder_generic_test.go b/binder_generic_test.go
new file mode 100644
index 000000000..849d75962
--- /dev/null
+++ b/binder_generic_test.go
@@ -0,0 +1,1616 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "cmp"
+ "encoding/json"
+ "fmt"
+ "math"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+// TextUnmarshalerType implements encoding.TextUnmarshaler but NOT BindUnmarshaler
+type TextUnmarshalerType struct {
+ Value string
+}
+
+func (t *TextUnmarshalerType) UnmarshalText(data []byte) error {
+ s := string(data)
+ if s == "invalid" {
+ return fmt.Errorf("invalid value: %s", s)
+ }
+ t.Value = strings.ToUpper(s)
+ return nil
+}
+
+// JSONUnmarshalerType implements json.Unmarshaler but NOT BindUnmarshaler or TextUnmarshaler
+type JSONUnmarshalerType struct {
+ Value string
+}
+
+func (j *JSONUnmarshalerType) UnmarshalJSON(data []byte) error {
+ return json.Unmarshal(data, &j.Value)
+}
+
+func TestPathParam(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenKey string
+ givenValue string
+ expect bool
+ expectErr string
+ }{
+ {
+ name: "ok",
+ givenValue: "true",
+ expect: true,
+ },
+ {
+ name: "nok, non existent key",
+ givenKey: "missing",
+ givenValue: "true",
+ expect: false,
+ expectErr: ErrNonExistentKey.Error(),
+ },
+ {
+ name: "nok, invalid value",
+ givenValue: "can_parse_me",
+ expect: false,
+ expectErr: `code=400, message=path value, err=failed to parse value, err: strconv.ParseBool: parsing "can_parse_me": invalid syntax, field=key`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := NewContext(nil, nil)
+ c.SetPathValues(PathValues{{
+ Name: cmp.Or(tc.givenKey, "key"),
+ Value: tc.givenValue,
+ }})
+
+ v, err := PathParam[bool](c, "key")
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestPathParam_UnsupportedType(t *testing.T) {
+ c := NewContext(nil, nil)
+ c.SetPathValues(PathValues{{Name: "key", Value: "true"}})
+
+ v, err := PathParam[[]bool](c, "key")
+
+ expectErr := "code=400, message=path value, err=failed to parse value, err: unsupported value type: *[]bool, field=key"
+ assert.EqualError(t, err, expectErr)
+ assert.Equal(t, []bool(nil), v)
+}
+
+func TestQueryParam(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ expect bool
+ expectErr string
+ }{
+ {
+ name: "ok",
+ givenURL: "/?key=true",
+ expect: true,
+ },
+ {
+ name: "nok, non existent key",
+ givenURL: "/?different=true",
+ expect: false,
+ expectErr: ErrNonExistentKey.Error(),
+ },
+ {
+ name: "nok, invalid value",
+ givenURL: "/?key=invalidbool",
+ expect: false,
+ expectErr: `code=400, message=query param, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
+ c := NewContext(req, nil)
+
+ v, err := QueryParam[bool](c, "key")
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestQueryParam_UnsupportedType(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil)
+ c := NewContext(req, nil)
+
+ v, err := QueryParam[[]bool](c, "key")
+
+ expectErr := "code=400, message=query param, err=failed to parse value, err: unsupported value type: *[]bool, field=key"
+ assert.EqualError(t, err, expectErr)
+ assert.Equal(t, []bool(nil), v)
+}
+
+func TestQueryParams(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ expect []bool
+ expectErr string
+ }{
+ {
+ name: "ok",
+ givenURL: "/?key=true&key=false",
+ expect: []bool{true, false},
+ },
+ {
+ name: "nok, non existent key",
+ givenURL: "/?different=true",
+ expect: []bool(nil),
+ expectErr: ErrNonExistentKey.Error(),
+ },
+ {
+ name: "nok, invalid value",
+ givenURL: "/?key=true&key=invalidbool",
+ expect: []bool(nil),
+ expectErr: `code=400, message=query params, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
+ c := NewContext(req, nil)
+
+ v, err := QueryParams[bool](c, "key")
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestQueryParams_UnsupportedType(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil)
+ c := NewContext(req, nil)
+
+ v, err := QueryParams[[]bool](c, "key")
+
+ expectErr := "code=400, message=query params, err=failed to parse value, err: unsupported value type: *[]bool, field=key"
+ assert.EqualError(t, err, expectErr)
+ assert.Equal(t, [][]bool(nil), v)
+}
+
+func TestFormValue(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ expect bool
+ expectErr string
+ }{
+ {
+ name: "ok",
+ givenURL: "/?key=true",
+ expect: true,
+ },
+ {
+ name: "nok, non existent key",
+ givenURL: "/?different=true",
+ expect: false,
+ expectErr: ErrNonExistentKey.Error(),
+ },
+ {
+ name: "nok, invalid value",
+ givenURL: "/?key=invalidbool",
+ expect: false,
+ expectErr: `code=400, message=form value, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
+ c := NewContext(req, nil)
+
+ v, err := FormValue[bool](c, "key")
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestFormValue_UnsupportedType(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil)
+ c := NewContext(req, nil)
+
+ v, err := FormValue[[]bool](c, "key")
+
+ expectErr := "code=400, message=form value, err=failed to parse value, err: unsupported value type: *[]bool, field=key"
+ assert.EqualError(t, err, expectErr)
+ assert.Equal(t, []bool(nil), v)
+}
+
+func TestFormValues(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ expect []bool
+ expectErr string
+ }{
+ {
+ name: "ok",
+ givenURL: "/?key=true&key=false",
+ expect: []bool{true, false},
+ },
+ {
+ name: "nok, non existent key",
+ givenURL: "/?different=true",
+ expect: []bool(nil),
+ expectErr: ErrNonExistentKey.Error(),
+ },
+ {
+ name: "nok, invalid value",
+ givenURL: "/?key=true&key=invalidbool",
+ expect: []bool(nil),
+ expectErr: `code=400, message=form values, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
+ c := NewContext(req, nil)
+
+ v, err := FormValues[bool](c, "key")
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestFormValues_UnsupportedType(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil)
+ c := NewContext(req, nil)
+
+ v, err := FormValues[[]bool](c, "key")
+
+ expectErr := "code=400, message=form values, err=failed to parse value, err: unsupported value type: *[]bool, field=key"
+ assert.EqualError(t, err, expectErr)
+ assert.Equal(t, [][]bool(nil), v)
+}
+
+func TestParseValue_bool(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect bool
+ expectErr error
+ }{
+ {
+ name: "ok, true",
+ when: "true",
+ expect: true,
+ },
+ {
+ name: "ok, false",
+ when: "false",
+ expect: false,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: true,
+ },
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: false,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[bool](tc.when)
+ if tc.expectErr != nil {
+ assert.ErrorIs(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_float32(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect float32
+ expectErr string
+ }{
+ {
+ name: "ok, 123.345",
+ when: "123.345",
+ expect: 123.345,
+ },
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, Inf",
+ when: "+Inf",
+ expect: float32(math.Inf(1)),
+ },
+ {
+ name: "ok, Inf",
+ when: "-Inf",
+ expect: float32(math.Inf(-1)),
+ },
+ {
+ name: "ok, NaN",
+ when: "NaN",
+ expect: float32(math.NaN()),
+ },
+ {
+ name: "ok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseFloat: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[float32](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ if math.IsNaN(float64(tc.expect)) {
+ if !math.IsNaN(float64(v)) {
+ t.Fatal("expected NaN but got non NaN")
+ }
+ } else {
+ assert.Equal(t, tc.expect, v)
+ }
+ })
+ }
+}
+
+func TestParseValue_float64(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect float64
+ expectErr string
+ }{
+ {
+ name: "ok, 123.345",
+ when: "123.345",
+ expect: 123.345,
+ },
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, Inf",
+ when: "+Inf",
+ expect: math.Inf(1),
+ },
+ {
+ name: "ok, Inf",
+ when: "-Inf",
+ expect: math.Inf(-1),
+ },
+ {
+ name: "ok, NaN",
+ when: "NaN",
+ expect: math.NaN(),
+ },
+ {
+ name: "ok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseFloat: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[float64](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ if math.IsNaN(tc.expect) {
+ if !math.IsNaN(v) {
+ t.Fatal("expected NaN but got non NaN")
+ }
+ } else {
+ assert.Equal(t, tc.expect, v)
+ }
+ })
+ }
+}
+
+func TestParseValue_int(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect int
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, -1",
+ when: "-1",
+ expect: -1,
+ },
+ {
+ name: "ok, max int (64bit)",
+ when: "9223372036854775807",
+ expect: 9223372036854775807,
+ },
+ {
+ name: "ok, min int (64bit)",
+ when: "-9223372036854775808",
+ expect: -9223372036854775808,
+ },
+ {
+ name: "ok, overflow max int (64bit)",
+ when: "9223372036854775808",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "9223372036854775808": value out of range`,
+ },
+ {
+ name: "ok, underflow min int (64bit)",
+ when: "-9223372036854775809",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-9223372036854775809": value out of range`,
+ },
+ {
+ name: "ok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[int](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_uint(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect uint
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, max uint (64bit)",
+ when: "18446744073709551615",
+ expect: 18446744073709551615,
+ },
+ {
+ name: "nok, overflow max uint (64bit)",
+ when: "18446744073709551616",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "18446744073709551616": value out of range`,
+ },
+ {
+ name: "nok, negative value",
+ when: "-1",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[uint](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_int8(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect int8
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, -1",
+ when: "-1",
+ expect: -1,
+ },
+ {
+ name: "ok, max int8",
+ when: "127",
+ expect: 127,
+ },
+ {
+ name: "ok, min int8",
+ when: "-128",
+ expect: -128,
+ },
+ {
+ name: "nok, overflow max int8",
+ when: "128",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "128": value out of range`,
+ },
+ {
+ name: "nok, underflow min int8",
+ when: "-129",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-129": value out of range`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[int8](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_int16(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect int16
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, -1",
+ when: "-1",
+ expect: -1,
+ },
+ {
+ name: "ok, max int16",
+ when: "32767",
+ expect: 32767,
+ },
+ {
+ name: "ok, min int16",
+ when: "-32768",
+ expect: -32768,
+ },
+ {
+ name: "nok, overflow max int16",
+ when: "32768",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "32768": value out of range`,
+ },
+ {
+ name: "nok, underflow min int16",
+ when: "-32769",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-32769": value out of range`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[int16](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_int32(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect int32
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, -1",
+ when: "-1",
+ expect: -1,
+ },
+ {
+ name: "ok, max int32",
+ when: "2147483647",
+ expect: 2147483647,
+ },
+ {
+ name: "ok, min int32",
+ when: "-2147483648",
+ expect: -2147483648,
+ },
+ {
+ name: "nok, overflow max int32",
+ when: "2147483648",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "2147483648": value out of range`,
+ },
+ {
+ name: "nok, underflow min int32",
+ when: "-2147483649",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-2147483649": value out of range`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[int32](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_int64(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect int64
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, -1",
+ when: "-1",
+ expect: -1,
+ },
+ {
+ name: "ok, max int64",
+ when: "9223372036854775807",
+ expect: 9223372036854775807,
+ },
+ {
+ name: "ok, min int64",
+ when: "-9223372036854775808",
+ expect: -9223372036854775808,
+ },
+ {
+ name: "nok, overflow max int64",
+ when: "9223372036854775808",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "9223372036854775808": value out of range`,
+ },
+ {
+ name: "nok, underflow min int64",
+ when: "-9223372036854775809",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-9223372036854775809": value out of range`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[int64](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_uint8(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect uint8
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, max uint8",
+ when: "255",
+ expect: 255,
+ },
+ {
+ name: "nok, overflow max uint8",
+ when: "256",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "256": value out of range`,
+ },
+ {
+ name: "nok, negative value",
+ when: "-1",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[uint8](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_uint16(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect uint16
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, max uint16",
+ when: "65535",
+ expect: 65535,
+ },
+ {
+ name: "nok, overflow max uint16",
+ when: "65536",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "65536": value out of range`,
+ },
+ {
+ name: "nok, negative value",
+ when: "-1",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[uint16](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_uint32(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect uint32
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, max uint32",
+ when: "4294967295",
+ expect: 4294967295,
+ },
+ {
+ name: "nok, overflow max uint32",
+ when: "4294967296",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "4294967296": value out of range`,
+ },
+ {
+ name: "nok, negative value",
+ when: "-1",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[uint32](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_uint64(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect uint64
+ expectErr string
+ }{
+ {
+ name: "ok, 0",
+ when: "0",
+ expect: 0,
+ },
+ {
+ name: "ok, 1",
+ when: "1",
+ expect: 1,
+ },
+ {
+ name: "ok, max uint64",
+ when: "18446744073709551615",
+ expect: 18446744073709551615,
+ },
+ {
+ name: "nok, overflow max uint64",
+ when: "18446744073709551616",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "18446744073709551616": value out of range`,
+ },
+ {
+ name: "nok, negative value",
+ when: "-1",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`,
+ },
+ {
+ name: "nok, invalid value",
+ when: "X",
+ expect: 0,
+ expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[uint64](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_string(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect string
+ expectErr string
+ }{
+ {
+ name: "ok, my",
+ when: "my",
+ expect: "my",
+ },
+ {
+ name: "ok, empty",
+ when: "",
+ expect: "",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[string](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_Duration(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect time.Duration
+ expectErr string
+ }{
+ {
+ name: "ok, 10h11m01s",
+ when: "10h11m01s",
+ expect: 10*time.Hour + 11*time.Minute + 1*time.Second,
+ },
+ {
+ name: "ok, empty",
+ when: "",
+ expect: 0,
+ },
+ {
+ name: "ok, invalid",
+ when: "0x0",
+ expect: 0,
+ expectErr: `failed to parse value, err: time: unknown unit "x" in duration "0x0"`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[time.Duration](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_Time(t *testing.T) {
+ tallinn, err := time.LoadLocation("Europe/Tallinn")
+ if err != nil {
+ t.Fatal(err)
+ }
+ berlin, err := time.LoadLocation("Europe/Berlin")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ parse := func(t *testing.T, layout string, s string) time.Time {
+ result, err := time.Parse(layout, s)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return result
+ }
+
+ parseInLoc := func(t *testing.T, layout string, s string, loc *time.Location) time.Time {
+ result, err := time.ParseInLocation(layout, s, loc)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return result
+ }
+
+ var testCases = []struct {
+ name string
+ when string
+ whenLayout TimeLayout
+ whenTimeOpts *TimeOpts
+ expect time.Time
+ expectErr string
+ }{
+ {
+ name: "ok, defaults to RFC3339Nano",
+ when: "2006-01-02T15:04:05.999999999Z",
+ expect: parse(t, time.RFC3339Nano, "2006-01-02T15:04:05.999999999Z"),
+ },
+ {
+ name: "ok, custom TimeOpt",
+ when: "2006-01-02",
+ whenTimeOpts: &TimeOpts{
+ Layout: time.DateOnly,
+ ParseInLocation: tallinn,
+ ToInLocation: berlin,
+ },
+ expect: parseInLoc(t, time.DateTime, "2006-01-01 23:00:00", berlin),
+ },
+ {
+ name: "ok, custom layout",
+ when: "2006-01-02",
+ whenLayout: TimeLayout(time.DateOnly),
+ expect: parse(t, time.DateOnly, "2006-01-02"),
+ },
+ {
+ name: "ok, TimeLayoutUnixTime",
+ when: "1766604665",
+ whenLayout: TimeLayoutUnixTime,
+ expect: parse(t, time.RFC3339Nano, "2025-12-24T19:31:05Z"),
+ },
+ {
+ name: "nok, TimeLayoutUnixTime, invalid value",
+ when: "176x6604665",
+ whenLayout: TimeLayoutUnixTime,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "176x6604665": invalid syntax`,
+ },
+ {
+ name: "ok, TimeLayoutUnixTimeMilli",
+ when: "1766604665123",
+ whenLayout: TimeLayoutUnixTimeMilli,
+ expect: parse(t, time.RFC3339Nano, "2025-12-24T19:31:05.123Z"),
+ },
+ {
+ name: "nok, TimeLayoutUnixTimeMilli, invalid value",
+ when: "1x766604665123",
+ whenLayout: TimeLayoutUnixTimeMilli,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "1x766604665123": invalid syntax`,
+ },
+ {
+ name: "ok, TimeLayoutUnixTimeMilli",
+ when: "1766604665999999999",
+ whenLayout: TimeLayoutUnixTimeNano,
+ expect: parse(t, time.RFC3339Nano, "2025-12-24T19:31:05.999999999Z"),
+ },
+ {
+ name: "nok, TimeLayoutUnixTimeMilli, invalid value",
+ when: "1x766604665999999999",
+ whenLayout: TimeLayoutUnixTimeNano,
+ expectErr: `failed to parse value, err: strconv.ParseInt: parsing "1x766604665999999999": invalid syntax`,
+ },
+ {
+ name: "ok, invalid",
+ when: "xx",
+ expect: time.Time{},
+ expectErr: `failed to parse value, err: parsing time "xx" as "2006-01-02T15:04:05.999999999Z07:00": cannot parse "xx" as "2006"`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ var opts []any
+ if tc.whenLayout != "" {
+ opts = append(opts, tc.whenLayout)
+ }
+ if tc.whenTimeOpts != nil {
+ opts = append(opts, *tc.whenTimeOpts)
+ }
+ v, err := ParseValue[time.Time](tc.when, opts...)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_OptionsOnlyForTime(t *testing.T) {
+ _, err := ParseValue[int]("test", TimeLayoutUnixTime)
+ assert.EqualError(t, err, `failed to parse value, err: options are only supported for time.Time, got *int`)
+}
+
+func TestParseValue_BindUnmarshaler(t *testing.T) {
+ exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00")
+
+ var testCases = []struct {
+ name string
+ when string
+ expect Timestamp
+ expectErr string
+ }{
+ {
+ name: "ok",
+ when: "2020-12-23T09:45:31+02:00",
+ expect: Timestamp(exampleTime),
+ },
+ {
+ name: "nok, invalid value",
+ when: "2020-12-23T09:45:3102:00",
+ expect: Timestamp{},
+ expectErr: `failed to parse value, err: parsing time "2020-12-23T09:45:3102:00" as "2006-01-02T15:04:05Z07:00": cannot parse "02:00" as "Z07:00"`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[Timestamp](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_TextUnmarshaler(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect TextUnmarshalerType
+ expectErr string
+ }{
+ {
+ name: "ok, converts to uppercase",
+ when: "hello",
+ expect: TextUnmarshalerType{Value: "HELLO"},
+ },
+ {
+ name: "ok, empty string",
+ when: "",
+ expect: TextUnmarshalerType{Value: ""},
+ },
+ {
+ name: "nok, invalid value",
+ when: "invalid",
+ expect: TextUnmarshalerType{},
+ expectErr: "failed to parse value, err: invalid value: invalid",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[TextUnmarshalerType](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValue_JSONUnmarshaler(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when string
+ expect JSONUnmarshalerType
+ expectErr string
+ }{
+ {
+ name: "ok, valid JSON string",
+ when: `"hello"`,
+ expect: JSONUnmarshalerType{Value: "hello"},
+ },
+ {
+ name: "ok, empty JSON string",
+ when: `""`,
+ expect: JSONUnmarshalerType{Value: ""},
+ },
+ {
+ name: "nok, invalid JSON",
+ when: "not-json",
+ expect: JSONUnmarshalerType{},
+ expectErr: "failed to parse value, err: invalid character 'o' in literal null (expecting 'u')",
+ },
+ {
+ name: "nok, unquoted string",
+ when: "hello",
+ expect: JSONUnmarshalerType{},
+ expectErr: "failed to parse value, err: invalid character 'h' looking for beginning of value",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValue[JSONUnmarshalerType](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestParseValues_bools(t *testing.T) {
+ var testCases = []struct {
+ name string
+ when []string
+ expect []bool
+ expectErr string
+ }{
+ {
+ name: "ok",
+ when: []string{"true", "0", "false", "1"},
+ expect: []bool{true, false, false, true},
+ },
+ {
+ name: "nok",
+ when: []string{"true", "10"},
+ expect: nil,
+ expectErr: `failed to parse value, err: strconv.ParseBool: parsing "10": invalid syntax`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ v, err := ParseValues[bool](tc.when)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestPathParamOr(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenKey string
+ givenValue string
+ defaultValue int
+ expect int
+ expectErr string
+ }{
+ {
+ name: "ok, param exists",
+ givenKey: "id",
+ givenValue: "123",
+ defaultValue: 999,
+ expect: 123,
+ },
+ {
+ name: "ok, param missing - returns default",
+ givenKey: "other",
+ givenValue: "123",
+ defaultValue: 999,
+ expect: 999,
+ },
+ {
+ name: "ok, param exists but empty - returns default",
+ givenKey: "id",
+ givenValue: "",
+ defaultValue: 999,
+ expect: 999,
+ },
+ {
+ name: "nok, invalid value",
+ givenKey: "id",
+ givenValue: "invalid",
+ defaultValue: 999,
+ expectErr: "code=400, message=path value, err=failed to parse value",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := NewContext(nil, nil)
+ c.SetPathValues(PathValues{{Name: tc.givenKey, Value: tc.givenValue}})
+
+ v, err := PathParamOr[int](c, "id", tc.defaultValue)
+ if tc.expectErr != "" {
+ assert.ErrorContains(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestQueryParamOr(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ defaultValue int
+ expect int
+ expectErr string
+ }{
+ {
+ name: "ok, param exists",
+ givenURL: "/?key=42",
+ defaultValue: 999,
+ expect: 42,
+ },
+ {
+ name: "ok, param missing - returns default",
+ givenURL: "/?other=42",
+ defaultValue: 999,
+ expect: 999,
+ },
+ {
+ name: "ok, param exists but empty - returns default",
+ givenURL: "/?key=",
+ defaultValue: 999,
+ expect: 999,
+ },
+ {
+ name: "nok, invalid value",
+ givenURL: "/?key=invalid",
+ defaultValue: 999,
+ expectErr: "code=400, message=query param",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
+ c := NewContext(req, nil)
+
+ v, err := QueryParamOr[int](c, "key", tc.defaultValue)
+ if tc.expectErr != "" {
+ assert.ErrorContains(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestQueryParamsOr(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ defaultValue []int
+ expect []int
+ expectErr string
+ }{
+ {
+ name: "ok, params exist",
+ givenURL: "/?key=1&key=2&key=3",
+ defaultValue: []int{999},
+ expect: []int{1, 2, 3},
+ },
+ {
+ name: "ok, params missing - returns default",
+ givenURL: "/?other=1",
+ defaultValue: []int{7, 8, 9},
+ expect: []int{7, 8, 9},
+ },
+ {
+ name: "nok, invalid value",
+ givenURL: "/?key=1&key=invalid",
+ defaultValue: []int{999},
+ expectErr: "code=400, message=query params",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
+ c := NewContext(req, nil)
+
+ v, err := QueryParamsOr[int](c, "key", tc.defaultValue)
+ if tc.expectErr != "" {
+ assert.ErrorContains(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestFormValueOr(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ defaultValue string
+ expect string
+ expectErr string
+ }{
+ {
+ name: "ok, value exists",
+ givenURL: "/?name=john",
+ defaultValue: "default",
+ expect: "john",
+ },
+ {
+ name: "ok, value missing - returns default",
+ givenURL: "/?other=john",
+ defaultValue: "default",
+ expect: "default",
+ },
+ {
+ name: "ok, value exists but empty - returns default",
+ givenURL: "/?name=",
+ defaultValue: "default",
+ expect: "default",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
+ c := NewContext(req, nil)
+
+ v, err := FormValueOr[string](c, "name", tc.defaultValue)
+ if tc.expectErr != "" {
+ assert.ErrorContains(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func TestFormValuesOr(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ defaultValue []string
+ expect []string
+ expectErr string
+ }{
+ {
+ name: "ok, values exist",
+ givenURL: "/?tags=go&tags=rust&tags=python",
+ defaultValue: []string{"default"},
+ expect: []string{"go", "rust", "python"},
+ },
+ {
+ name: "ok, values missing - returns default",
+ givenURL: "/?other=value",
+ defaultValue: []string{"a", "b"},
+ expect: []string{"a", "b"},
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil)
+ c := NewContext(req, nil)
+
+ v, err := FormValuesOr[string](c, "tags", tc.defaultValue)
+ if tc.expectErr != "" {
+ assert.ErrorContains(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
diff --git a/binder_test.go b/binder_test.go
new file mode 100644
index 000000000..8eced8208
--- /dev/null
+++ b/binder_test.go
@@ -0,0 +1,3252 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "github.com/stretchr/testify/assert"
+ "io"
+ "math/big"
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+)
+
+func createTestContext(URL string, body io.Reader, pathValues map[string]string) *Context {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, URL, body)
+ if body != nil {
+ req.Header.Set(HeaderContentType, MIMEApplicationJSON)
+ }
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ if len(pathValues) > 0 {
+ params := make(PathValues, 0)
+ for name, value := range pathValues {
+ params = append(params, PathValue{
+ Name: name,
+ Value: value,
+ })
+ }
+ c.SetPathValues(params)
+ }
+
+ return c
+}
+
+func TestBindingError_Error(t *testing.T) {
+ err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error"))
+ assert.EqualError(t, err, `code=400, message=bind failed, err=internal error, field=id`)
+
+ bErr := err.(*BindingError)
+ assert.Equal(t, 400, bErr.Code)
+ assert.Equal(t, "bind failed", bErr.Message)
+ assert.Equal(t, errors.New("internal error"), bErr.err)
+
+ assert.Equal(t, "id", bErr.Field)
+ assert.Equal(t, []string{"1", "nope"}, bErr.Values)
+}
+
+func TestBindingError_ErrorJSON(t *testing.T) {
+ err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error"))
+
+ resp, _ := json.Marshal(err)
+
+ assert.Equal(t, `{"field":"id","message":"bind failed"}`, string(resp))
+}
+
+func TestPathValuesBinder(t *testing.T) {
+ c := createTestContext("/api/user/999", nil, map[string]string{
+ "id": "1",
+ "nr": "2",
+ "slice": "3",
+ })
+ b := PathValuesBinder(c)
+
+ id := int64(99)
+ nr := int64(88)
+ var slice = make([]int64, 0)
+ var notExisting = make([]int64, 0)
+ err := b.Int64("id", &id).
+ Int64("nr", &nr).
+ Int64s("slice", &slice).
+ Int64s("not_existing", ¬Existing).
+ BindError()
+
+ assert.NoError(t, err)
+ assert.Equal(t, int64(1), id)
+ assert.Equal(t, int64(2), nr)
+ assert.Equal(t, []int64{3}, slice) // binding params to slice does not make sense but it should not panic either
+ assert.Equal(t, []int64{}, notExisting) // binding params to slice does not make sense but it should not panic either
+}
+
+func TestQueryParamsBinder_FailFast(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError []string
+ givenFailFast bool
+ }{
+ {
+ name: "ok, FailFast=true stops at first error",
+ whenURL: "/api/user/999?nr=en&id=nope",
+ givenFailFast: true,
+ expectError: []string{
+ `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "nope": invalid syntax, field=id`,
+ },
+ },
+ {
+ name: "ok, FailFast=false encounters all errors",
+ whenURL: "/api/user/999?nr=en&id=nope",
+ givenFailFast: false,
+ expectError: []string{
+ `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "nope": invalid syntax, field=id`,
+ `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "en": invalid syntax, field=nr`,
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, map[string]string{"id": "999"})
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ id := int64(99)
+ nr := int64(88)
+ errs := b.Int64("id", &id).
+ Int64("nr", &nr).
+ BindErrors()
+
+ assert.Len(t, errs, len(tc.expectError))
+ for _, err := range errs {
+ assert.Contains(t, tc.expectError, err.Error())
+ }
+ })
+ }
+}
+
+func TestFormFieldBinder(t *testing.T) {
+ e := New()
+ body := `texta=foo&slice=5`
+ req := httptest.NewRequest(http.MethodPost, "/api/search?id=1&nr=2&slice=3&slice=4", strings.NewReader(body))
+ req.Header.Set(HeaderContentLength, strconv.Itoa(len(body)))
+ req.Header.Set(HeaderContentType, MIMEApplicationForm)
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ b := FormFieldBinder(c)
+
+ var texta string
+ id := int64(99)
+ nr := int64(88)
+ var slice = make([]int64, 0)
+ var notExisting = make([]int64, 0)
+ err := b.
+ Int64s("slice", &slice).
+ Int64("id", &id).
+ Int64("nr", &nr).
+ String("texta", &texta).
+ Int64s("notExisting", ¬Existing).
+ BindError()
+
+ assert.NoError(t, err)
+ assert.Equal(t, "foo", texta)
+ assert.Equal(t, int64(1), id)
+ assert.Equal(t, int64(2), nr)
+ assert.Equal(t, []int64{5, 3, 4}, slice)
+ assert.Equal(t, []int64{}, notExisting)
+}
+
+func TestValueBinder_errorStopsBinding(t *testing.T) {
+ // this test documents "feature" that binding multiple params can change destination if it was bound before
+ // failing parameter binding
+
+ c := createTestContext("/api/user/999?id=1&nr=nope", nil, nil)
+ b := QueryParamsBinder(c)
+
+ id := int64(99) // will be changed before nr binding fails
+ nr := int64(88) // will not be changed
+ err := b.Int64("id", &id).
+ Int64("nr", &nr).
+ BindError()
+
+ assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=nr")
+ assert.Equal(t, int64(1), id)
+ assert.Equal(t, int64(88), nr)
+}
+
+func TestValueBinder_BindError(t *testing.T) {
+ c := createTestContext("/api/user/999?nr=en&id=nope", nil, nil)
+ b := QueryParamsBinder(c)
+
+ id := int64(99)
+ nr := int64(88)
+ err := b.Int64("id", &id).
+ Int64("nr", &nr).
+ BindError()
+
+ assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=id")
+ assert.Nil(t, b.errors)
+ assert.Nil(t, b.BindError())
+}
+
+func TestValueBinder_GetValues(t *testing.T) {
+ var testCases = []struct {
+ whenValuesFunc func(sourceParam string) []string
+ name string
+ expectError string
+ expect []int64
+ }{
+ {
+ name: "ok, default implementation",
+ expect: []int64{1, 101},
+ },
+ {
+ name: "ok, values returns nil",
+ whenValuesFunc: func(sourceParam string) []string {
+ return nil
+ },
+ expect: []int64(nil),
+ },
+ {
+ name: "ok, values returns empty slice",
+ whenValuesFunc: func(sourceParam string) []string {
+ return []string{}
+ },
+ expect: []int64(nil),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext("/search?nr=en&id=1&id=101", nil, nil)
+ b := QueryParamsBinder(c)
+ if tc.whenValuesFunc != nil {
+ b.ValuesFunc = tc.whenValuesFunc
+ }
+
+ var IDs []int64
+ err := b.Int64s("id", &IDs).BindError()
+
+ assert.Equal(t, tc.expect, IDs)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_CustomFuncWithError(t *testing.T) {
+ c := createTestContext("/search?nr=en&id=1&id=101", nil, nil)
+ b := QueryParamsBinder(c)
+
+ id := int64(99)
+ givenCustomFunc := func(values []string) []error {
+ assert.Equal(t, []string{"1", "101"}, values)
+
+ return []error{
+ errors.New("first error"),
+ errors.New("second error"),
+ }
+ }
+ err := b.CustomFunc("id", givenCustomFunc).BindError()
+
+ assert.Equal(t, int64(99), id)
+ assert.EqualError(t, err, "first error")
+}
+
+func TestValueBinder_CustomFunc(t *testing.T) {
+ var testCases = []struct {
+ expectValue any
+ name string
+ whenURL string
+ givenFuncErrors []error
+ expectParamValues []string
+ expectErrors []string
+ givenFailFast bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectParamValues: []string{"1", "100"},
+ expectValue: int64(1000),
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nr=en",
+ expectParamValues: []string{},
+ expectValue: int64(99),
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectParamValues: []string{"1", "100"},
+ expectValue: int64(99),
+ expectErrors: []string{"previous error"},
+ },
+ {
+ name: "nok, func returns errors",
+ givenFuncErrors: []error{
+ errors.New("first error"),
+ errors.New("second error"),
+ },
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectParamValues: []string{"1", "100"},
+ expectValue: int64(99),
+ expectErrors: []string{"first error", "second error"},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ id := int64(99)
+ givenCustomFunc := func(values []string) []error {
+ assert.Equal(t, tc.expectParamValues, values)
+ if tc.givenFuncErrors == nil {
+ id = 1000 // emulated conversion and setting value
+ return nil
+ }
+ return tc.givenFuncErrors
+ }
+ errs := b.CustomFunc("id", givenCustomFunc).BindErrors()
+
+ assert.Equal(t, tc.expectValue, id)
+ if tc.expectErrors != nil {
+ assert.Len(t, errs, len(tc.expectErrors))
+ for _, err := range errs {
+ assert.Contains(t, tc.expectErrors, err.Error())
+ }
+ } else {
+ assert.Nil(t, errs)
+ }
+ })
+ }
+}
+
+func TestValueBinder_MustCustomFunc(t *testing.T) {
+ var testCases = []struct {
+ expectValue any
+ name string
+ whenURL string
+ givenFuncErrors []error
+ expectParamValues []string
+ expectErrors []string
+ givenFailFast bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectParamValues: []string{"1", "100"},
+ expectValue: int64(1000),
+ },
+ {
+ name: "nok, params values empty, returns error, value is not changed",
+ whenURL: "/search?nr=en",
+ expectParamValues: []string{},
+ expectValue: int64(99),
+ expectErrors: []string{"code=400, message=required field value is empty, field=id"},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectParamValues: []string{"1", "100"},
+ expectValue: int64(99),
+ expectErrors: []string{"previous error"},
+ },
+ {
+ name: "nok, func returns errors",
+ givenFuncErrors: []error{
+ errors.New("first error"),
+ errors.New("second error"),
+ },
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectParamValues: []string{"1", "100"},
+ expectValue: int64(99),
+ expectErrors: []string{"first error", "second error"},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ id := int64(99)
+ givenCustomFunc := func(values []string) []error {
+ assert.Equal(t, tc.expectParamValues, values)
+ if tc.givenFuncErrors == nil {
+ id = 1000 // emulated conversion and setting value
+ return nil
+ }
+ return tc.givenFuncErrors
+ }
+ errs := b.MustCustomFunc("id", givenCustomFunc).BindErrors()
+
+ assert.Equal(t, tc.expectValue, id)
+ if tc.expectErrors != nil {
+ assert.Len(t, errs, len(tc.expectErrors))
+ for _, err := range errs {
+ assert.Contains(t, tc.expectErrors, err.Error())
+ }
+ } else {
+ assert.Nil(t, errs)
+ }
+ })
+ }
+}
+
+func TestValueBinder_String(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectValue string
+ expectError string
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=en¶m=de",
+ expectValue: "en",
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nr=en",
+ expectValue: "default",
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectValue: "default",
+ expectError: "previous error",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=en¶m=de",
+ expectValue: "en",
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nr=en",
+ expectValue: "default",
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectValue: "default",
+ expectError: "previous error",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := "default"
+ var err error
+ if tc.whenMust {
+ err = b.MustString("param", &dest).BindError()
+ } else {
+ err = b.String("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Strings(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []string
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=en¶m=de",
+ expectValue: []string{"en", "de"},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nr=en",
+ expectValue: []string{"default"},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectValue: []string{"default"},
+ expectError: "previous error",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=en¶m=de",
+ expectValue: []string{"en", "de"},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nr=en",
+ expectValue: []string{"default"},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?nr=en&id=1&id=100",
+ expectValue: []string{"default"},
+ expectError: "previous error",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := []string{"default"}
+ var err error
+ if tc.whenMust {
+ err = b.MustStrings("param", &dest).BindError()
+ } else {
+ err = b.Strings("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Int64_intValue(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue int64
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 1,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: 99,
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 99,
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 99,
+ expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 1,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: 99,
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 99,
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 99,
+ expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := int64(99)
+ var err error
+ if tc.whenMust {
+ err = b.MustInt64("param", &dest).BindError()
+ } else {
+ err = b.Int64("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Int_errorMessage(t *testing.T) {
+ // int/uint (without byte size) has a little bit different error message so test these separately
+ c := createTestContext("/search?param=nope", nil, nil)
+ b := QueryParamsBinder(c).FailFast(false)
+
+ destInt := 99
+ destUint := uint(98)
+ errs := b.Int("param", &destInt).Uint("param", &destUint).BindErrors()
+
+ assert.Equal(t, 99, destInt)
+ assert.Equal(t, uint(98), destUint)
+ assert.EqualError(t, errs[0], `code=400, message=failed to bind field value to int, err=strconv.ParseInt: parsing "nope": invalid syntax, field=param`)
+ assert.EqualError(t, errs[1], `code=400, message=failed to bind field value to uint, err=strconv.ParseUint: parsing "nope": invalid syntax, field=param`)
+}
+
+func TestValueBinder_Uint64_uintValue(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue uint64
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 1,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: 99,
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 99,
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 99,
+ expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 1,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: 99,
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 99,
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 99,
+ expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := uint64(99)
+ var err error
+ if tc.whenMust {
+ err = b.MustUint64("param", &dest).BindError()
+ } else {
+ err = b.Uint64("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Int_Types(t *testing.T) {
+ type target struct {
+ int64 int64
+ mustInt64 int64
+ uint64 uint64
+ mustUint64 uint64
+
+ int32 int32
+ mustInt32 int32
+ uint32 uint32
+ mustUint32 uint32
+
+ int16 int16
+ mustInt16 int16
+ uint16 uint16
+ mustUint16 uint16
+
+ int8 int8
+ mustInt8 int8
+ uint8 uint8
+ mustUint8 uint8
+
+ byte byte
+ mustByte byte
+
+ int int
+ mustInt int
+ uint uint
+ mustUint uint
+ }
+ types := []string{
+ "int64=1",
+ "mustInt64=2",
+ "uint64=3",
+ "mustUint64=4",
+
+ "int32=5",
+ "mustInt32=6",
+ "uint32=7",
+ "mustUint32=8",
+
+ "int16=9",
+ "mustInt16=10",
+ "uint16=11",
+ "mustUint16=12",
+
+ "int8=13",
+ "mustInt8=14",
+ "uint8=15",
+ "mustUint8=16",
+
+ "byte=17",
+ "mustByte=18",
+
+ "int=19",
+ "mustInt=20",
+ "uint=21",
+ "mustUint=22",
+ }
+ c := createTestContext("/search?"+strings.Join(types, "&"), nil, nil)
+ b := QueryParamsBinder(c)
+
+ dest := target{}
+ err := b.
+ Int64("int64", &dest.int64).
+ MustInt64("mustInt64", &dest.mustInt64).
+ Uint64("uint64", &dest.uint64).
+ MustUint64("mustUint64", &dest.mustUint64).
+ Int32("int32", &dest.int32).
+ MustInt32("mustInt32", &dest.mustInt32).
+ Uint32("uint32", &dest.uint32).
+ MustUint32("mustUint32", &dest.mustUint32).
+ Int16("int16", &dest.int16).
+ MustInt16("mustInt16", &dest.mustInt16).
+ Uint16("uint16", &dest.uint16).
+ MustUint16("mustUint16", &dest.mustUint16).
+ Int8("int8", &dest.int8).
+ MustInt8("mustInt8", &dest.mustInt8).
+ Uint8("uint8", &dest.uint8).
+ MustUint8("mustUint8", &dest.mustUint8).
+ Byte("byte", &dest.byte).
+ MustByte("mustByte", &dest.mustByte).
+ Int("int", &dest.int).
+ MustInt("mustInt", &dest.mustInt).
+ Uint("uint", &dest.uint).
+ MustUint("mustUint", &dest.mustUint).
+ BindError()
+
+ assert.NoError(t, err)
+ assert.Equal(t, int64(1), dest.int64)
+ assert.Equal(t, int64(2), dest.mustInt64)
+ assert.Equal(t, uint64(3), dest.uint64)
+ assert.Equal(t, uint64(4), dest.mustUint64)
+
+ assert.Equal(t, int32(5), dest.int32)
+ assert.Equal(t, int32(6), dest.mustInt32)
+ assert.Equal(t, uint32(7), dest.uint32)
+ assert.Equal(t, uint32(8), dest.mustUint32)
+
+ assert.Equal(t, int16(9), dest.int16)
+ assert.Equal(t, int16(10), dest.mustInt16)
+ assert.Equal(t, uint16(11), dest.uint16)
+ assert.Equal(t, uint16(12), dest.mustUint16)
+
+ assert.Equal(t, int8(13), dest.int8)
+ assert.Equal(t, int8(14), dest.mustInt8)
+ assert.Equal(t, uint8(15), dest.uint8)
+ assert.Equal(t, uint8(16), dest.mustUint8)
+
+ assert.Equal(t, uint8(17), dest.byte)
+ assert.Equal(t, uint8(18), dest.mustByte)
+
+ assert.Equal(t, 19, dest.int)
+ assert.Equal(t, 20, dest.mustInt)
+ assert.Equal(t, uint(21), dest.uint)
+ assert.Equal(t, uint(22), dest.mustUint)
+}
+
+func TestValueBinder_Int64s_intsValue(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []int64
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=1¶m=2¶m=1",
+ expectValue: []int64{1, 2, 1},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: []int64{99},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []int64{99},
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []int64{99},
+ expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1¶m=2¶m=1",
+ expectValue: []int64{1, 2, 1},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: []int64{99},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []int64{99},
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []int64{99},
+ expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := []int64{99} // when values are set with bind - contents before bind is gone
+ var err error
+ if tc.whenMust {
+ err = b.MustInt64s("param", &dest).BindError()
+ } else {
+ err = b.Int64s("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Uint64s_uintsValue(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []uint64
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=1¶m=2¶m=1",
+ expectValue: []uint64{1, 2, 1},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: []uint64{99},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []uint64{99},
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []uint64{99},
+ expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1¶m=2¶m=1",
+ expectValue: []uint64{1, 2, 1},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: []uint64{99},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []uint64{99},
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []uint64{99},
+ expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := []uint64{99} // when values are set with bind - contents before bind is gone
+ var err error
+ if tc.whenMust {
+ err = b.MustUint64s("param", &dest).BindError()
+ } else {
+ err = b.Uint64s("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Ints_Types(t *testing.T) {
+ type target struct {
+ int64 []int64
+ mustInt64 []int64
+ uint64 []uint64
+ mustUint64 []uint64
+
+ int32 []int32
+ mustInt32 []int32
+ uint32 []uint32
+ mustUint32 []uint32
+
+ int16 []int16
+ mustInt16 []int16
+ uint16 []uint16
+ mustUint16 []uint16
+
+ int8 []int8
+ mustInt8 []int8
+ uint8 []uint8
+ mustUint8 []uint8
+
+ int []int
+ mustInt []int
+ uint []uint
+ mustUint []uint
+ }
+ types := []string{
+ "int64=1",
+ "mustInt64=2",
+ "uint64=3",
+ "mustUint64=4",
+
+ "int32=5",
+ "mustInt32=6",
+ "uint32=7",
+ "mustUint32=8",
+
+ "int16=9",
+ "mustInt16=10",
+ "uint16=11",
+ "mustUint16=12",
+
+ "int8=13",
+ "mustInt8=14",
+ "uint8=15",
+ "mustUint8=16",
+
+ "int=19",
+ "mustInt=20",
+ "uint=21",
+ "mustUint=22",
+ }
+ url := "/search?"
+ for _, v := range types {
+ url = url + "&" + v + "&" + v
+ }
+ c := createTestContext(url, nil, nil)
+ b := QueryParamsBinder(c)
+
+ dest := target{}
+ err := b.
+ Int64s("int64", &dest.int64).
+ MustInt64s("mustInt64", &dest.mustInt64).
+ Uint64s("uint64", &dest.uint64).
+ MustUint64s("mustUint64", &dest.mustUint64).
+ Int32s("int32", &dest.int32).
+ MustInt32s("mustInt32", &dest.mustInt32).
+ Uint32s("uint32", &dest.uint32).
+ MustUint32s("mustUint32", &dest.mustUint32).
+ Int16s("int16", &dest.int16).
+ MustInt16s("mustInt16", &dest.mustInt16).
+ Uint16s("uint16", &dest.uint16).
+ MustUint16s("mustUint16", &dest.mustUint16).
+ Int8s("int8", &dest.int8).
+ MustInt8s("mustInt8", &dest.mustInt8).
+ Uint8s("uint8", &dest.uint8).
+ MustUint8s("mustUint8", &dest.mustUint8).
+ Ints("int", &dest.int).
+ MustInts("mustInt", &dest.mustInt).
+ Uints("uint", &dest.uint).
+ MustUints("mustUint", &dest.mustUint).
+ BindError()
+
+ assert.NoError(t, err)
+ assert.Equal(t, []int64{1, 1}, dest.int64)
+ assert.Equal(t, []int64{2, 2}, dest.mustInt64)
+ assert.Equal(t, []uint64{3, 3}, dest.uint64)
+ assert.Equal(t, []uint64{4, 4}, dest.mustUint64)
+
+ assert.Equal(t, []int32{5, 5}, dest.int32)
+ assert.Equal(t, []int32{6, 6}, dest.mustInt32)
+ assert.Equal(t, []uint32{7, 7}, dest.uint32)
+ assert.Equal(t, []uint32{8, 8}, dest.mustUint32)
+
+ assert.Equal(t, []int16{9, 9}, dest.int16)
+ assert.Equal(t, []int16{10, 10}, dest.mustInt16)
+ assert.Equal(t, []uint16{11, 11}, dest.uint16)
+ assert.Equal(t, []uint16{12, 12}, dest.mustUint16)
+
+ assert.Equal(t, []int8{13, 13}, dest.int8)
+ assert.Equal(t, []int8{14, 14}, dest.mustInt8)
+ assert.Equal(t, []uint8{15, 15}, dest.uint8)
+ assert.Equal(t, []uint8{16, 16}, dest.mustUint8)
+
+ assert.Equal(t, []int{19, 19}, dest.int)
+ assert.Equal(t, []int{20, 20}, dest.mustInt)
+ assert.Equal(t, []uint{21, 21}, dest.uint)
+ assert.Equal(t, []uint{22, 22}, dest.mustUint)
+}
+
+func TestValueBinder_Ints_Types_FailFast(t *testing.T) {
+ // FailFast() should stop parsing and return early
+ errTmpl := "code=400, message=failed to bind field value to %v, err=strconv.Parse%v: parsing \"nope\": invalid syntax, field=param"
+ c := createTestContext("/search?param=1¶m=nope¶m=2", nil, nil)
+
+ var dest64 []int64
+ err := QueryParamsBinder(c).FailFast(true).Int64s("param", &dest64).BindError()
+ assert.Equal(t, []int64(nil), dest64)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int64", "Int"))
+
+ var dest32 []int32
+ err = QueryParamsBinder(c).FailFast(true).Int32s("param", &dest32).BindError()
+ assert.Equal(t, []int32(nil), dest32)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int32", "Int"))
+
+ var dest16 []int16
+ err = QueryParamsBinder(c).FailFast(true).Int16s("param", &dest16).BindError()
+ assert.Equal(t, []int16(nil), dest16)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int16", "Int"))
+
+ var dest8 []int8
+ err = QueryParamsBinder(c).FailFast(true).Int8s("param", &dest8).BindError()
+ assert.Equal(t, []int8(nil), dest8)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int8", "Int"))
+
+ var dest []int
+ err = QueryParamsBinder(c).FailFast(true).Ints("param", &dest).BindError()
+ assert.Equal(t, []int(nil), dest)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int", "Int"))
+
+ var destu64 []uint64
+ err = QueryParamsBinder(c).FailFast(true).Uint64s("param", &destu64).BindError()
+ assert.Equal(t, []uint64(nil), destu64)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint64", "Uint"))
+
+ var destu32 []uint32
+ err = QueryParamsBinder(c).FailFast(true).Uint32s("param", &destu32).BindError()
+ assert.Equal(t, []uint32(nil), destu32)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint32", "Uint"))
+
+ var destu16 []uint16
+ err = QueryParamsBinder(c).FailFast(true).Uint16s("param", &destu16).BindError()
+ assert.Equal(t, []uint16(nil), destu16)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint16", "Uint"))
+
+ var destu8 []uint8
+ err = QueryParamsBinder(c).FailFast(true).Uint8s("param", &destu8).BindError()
+ assert.Equal(t, []uint8(nil), destu8)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint8", "Uint"))
+
+ var destu []uint
+ err = QueryParamsBinder(c).FailFast(true).Uints("param", &destu).BindError()
+ assert.Equal(t, []uint(nil), destu)
+ assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint", "Uint"))
+}
+
+func TestValueBinder_Bool(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ expectValue bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=true¶m=1",
+ expectValue: true,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: false,
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: false,
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: false,
+ expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: true,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: false,
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: false,
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: false,
+ expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := false
+ var err error
+ if tc.whenMust {
+ err = b.MustBool("param", &dest).BindError()
+ } else {
+ err = b.Bool("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Bools(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []bool
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=true¶m=false¶m=1¶m=0",
+ expectValue: []bool{true, false, true, false},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: []bool(nil),
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []bool(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=true¶m=nope¶m=100",
+ expectValue: []bool(nil),
+ expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "nok, conversion fails fast, value is not changed",
+ givenFailFast: true,
+ whenURL: "/search?param=true¶m=nope¶m=100",
+ expectValue: []bool(nil),
+ expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=true¶m=false¶m=1¶m=0",
+ expectValue: []bool{true, false, true, false},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: []bool(nil),
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []bool(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []bool(nil),
+ expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ b.errors = tc.givenBindErrors
+
+ var dest []bool
+ var err error
+ if tc.whenMust {
+ err = b.MustBools("param", &dest).BindError()
+ } else {
+ err = b.Bools("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Float64(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue float64
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=4.3¶m=1",
+ expectValue: 4.3,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: 1.123,
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 1.123,
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 1.123,
+ expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=4.3¶m=100",
+ expectValue: 4.3,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: 1.123,
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 1.123,
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 1.123,
+ expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := 1.123
+ var err error
+ if tc.whenMust {
+ err = b.MustFloat64("param", &dest).BindError()
+ } else {
+ err = b.Float64("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Float64s(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []float64
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=4.3¶m=0",
+ expectValue: []float64{4.3, 0},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: []float64(nil),
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []float64(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []float64(nil),
+ expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "nok, conversion fails fast, value is not changed",
+ givenFailFast: true,
+ whenURL: "/search?param=0¶m=nope¶m=100",
+ expectValue: []float64(nil),
+ expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=4.3¶m=0",
+ expectValue: []float64{4.3, 0},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: []float64(nil),
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []float64(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []float64(nil),
+ expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ b.errors = tc.givenBindErrors
+
+ var dest []float64
+ var err error
+ if tc.whenMust {
+ err = b.MustFloat64s("param", &dest).BindError()
+ } else {
+ err = b.Float64s("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Float32(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue float32
+ givenNoFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=4.3¶m=1",
+ expectValue: 4.3,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: 1.123,
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenNoFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 1.123,
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 1.123,
+ expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=4.3¶m=100",
+ expectValue: 4.3,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: 1.123,
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenNoFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 1.123,
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 1.123,
+ expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenNoFailFast)
+ if tc.givenNoFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := float32(1.123)
+ var err error
+ if tc.whenMust {
+ err = b.MustFloat32("param", &dest).BindError()
+ } else {
+ err = b.Float32("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Float32s(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []float32
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=4.3¶m=0",
+ expectValue: []float32{4.3, 0},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: []float32(nil),
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []float32(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []float32(nil),
+ expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "nok, conversion fails fast, value is not changed",
+ givenFailFast: true,
+ whenURL: "/search?param=0¶m=nope¶m=100",
+ expectValue: []float32(nil),
+ expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=4.3¶m=0",
+ expectValue: []float32{4.3, 0},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: []float32(nil),
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []float32(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []float32(nil),
+ expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ b.errors = tc.givenBindErrors
+
+ var dest []float32
+ var err error
+ if tc.whenMust {
+ err = b.MustFloat32s("param", &dest).BindError()
+ } else {
+ err = b.Float32s("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Time(t *testing.T) {
+ exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00")
+ var testCases = []struct {
+ expectValue time.Time
+ name string
+ whenURL string
+ whenLayout string
+ expectError string
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00",
+ whenLayout: time.RFC3339,
+ expectValue: exampleTime,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: time.Time{},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: time.Time{},
+ expectError: "previous error",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00",
+ whenLayout: time.RFC3339,
+ expectValue: exampleTime,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: time.Time{},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: time.Time{},
+ expectError: "previous error",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := time.Time{}
+ var err error
+ if tc.whenMust {
+ err = b.MustTime("param", &dest, tc.whenLayout).BindError()
+ } else {
+ err = b.Time("param", &dest, tc.whenLayout).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Times(t *testing.T) {
+ exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00")
+ exampleTime2, _ := time.Parse(time.RFC3339, "2000-01-02T09:45:31+00:00")
+ var testCases = []struct {
+ name string
+ whenURL string
+ whenLayout string
+ expectError string
+ givenBindErrors []error
+ expectValue []time.Time
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00",
+ whenLayout: time.RFC3339,
+ expectValue: []time.Time{exampleTime, exampleTime2},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: []time.Time(nil),
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []time.Time(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00",
+ whenLayout: time.RFC3339,
+ expectValue: []time.Time{exampleTime, exampleTime2},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: []time.Time(nil),
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []time.Time(nil),
+ expectError: "previous error",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ b.errors = tc.givenBindErrors
+
+ layout := time.RFC3339
+ if tc.whenLayout != "" {
+ layout = tc.whenLayout
+ }
+
+ var dest []time.Time
+ var err error
+ if tc.whenMust {
+ err = b.MustTimes("param", &dest, layout).BindError()
+ } else {
+ err = b.Times("param", &dest, layout).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Duration(t *testing.T) {
+ example := 42 * time.Second
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue time.Duration
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=42s¶m=1ms",
+ expectValue: example,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: 0,
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 0,
+ expectError: "previous error",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=42s¶m=1ms",
+ expectValue: example,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: 0,
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: 0,
+ expectError: "previous error",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ var dest time.Duration
+ var err error
+ if tc.whenMust {
+ err = b.MustDuration("param", &dest).BindError()
+ } else {
+ err = b.Duration("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_Durations(t *testing.T) {
+ exampleDuration := 42 * time.Second
+ exampleDuration2 := 1 * time.Millisecond
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []time.Duration
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=42s¶m=1ms",
+ expectValue: []time.Duration{exampleDuration, exampleDuration2},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: []time.Duration(nil),
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []time.Duration(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=42s¶m=1ms",
+ expectValue: []time.Duration{exampleDuration, exampleDuration2},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: []time.Duration(nil),
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ givenBindErrors: []error{errors.New("previous error")},
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []time.Duration(nil),
+ expectError: "previous error",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ b.errors = tc.givenBindErrors
+
+ var dest []time.Duration
+ var err error
+ if tc.whenMust {
+ err = b.MustDurations("param", &dest).BindError()
+ } else {
+ err = b.Durations("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_BindUnmarshaler(t *testing.T) {
+ exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00")
+
+ var testCases = []struct {
+ expectValue Timestamp
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00",
+ expectValue: Timestamp(exampleTime),
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: Timestamp{},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: Timestamp{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: Timestamp{},
+ expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00",
+ expectValue: Timestamp(exampleTime),
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: Timestamp{},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: Timestamp{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: Timestamp{},
+ expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ var dest Timestamp
+ var err error
+ if tc.whenMust {
+ err = b.MustBindUnmarshaler("param", &dest).BindError()
+ } else {
+ err = b.BindUnmarshaler("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_JSONUnmarshaler(t *testing.T) {
+ example := big.NewInt(999)
+
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ expectValue big.Int
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=999¶m=998",
+ expectValue: *example,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: big.Int{},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: big.Int{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=xxx",
+ expectValue: big.Int{},
+ expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=999¶m=998",
+ expectValue: *example,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: big.Int{},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=xxx",
+ expectValue: big.Int{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=xxx",
+ expectValue: big.Int{},
+ expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ var dest big.Int
+ var err error
+ if tc.whenMust {
+ err = b.MustJSONUnmarshaler("param", &dest).BindError()
+ } else {
+ err = b.JSONUnmarshaler("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_TextUnmarshaler(t *testing.T) {
+ example := big.NewInt(999)
+
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ expectValue big.Int
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=999¶m=998",
+ expectValue: *example,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: big.Int{},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: big.Int{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=xxx",
+ expectValue: big.Int{},
+ expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=999¶m=998",
+ expectValue: *example,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: big.Int{},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=xxx",
+ expectValue: big.Int{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=xxx",
+ expectValue: big.Int{},
+ expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ var dest big.Int
+ var err error
+ if tc.whenMust {
+ err = b.MustTextUnmarshaler("param", &dest).BindError()
+ } else {
+ err = b.TextUnmarshaler("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_BindWithDelimiter_types(t *testing.T) {
+ var testCases = []struct {
+ expect any
+ name string
+ whenURL string
+ }{
+ {
+ name: "ok, strings",
+ expect: []string{"1", "2", "1"},
+ },
+ {
+ name: "ok, int64",
+ expect: []int64{1, 2, 1},
+ },
+ {
+ name: "ok, int32",
+ expect: []int32{1, 2, 1},
+ },
+ {
+ name: "ok, int16",
+ expect: []int16{1, 2, 1},
+ },
+ {
+ name: "ok, int8",
+ expect: []int8{1, 2, 1},
+ },
+ {
+ name: "ok, int",
+ expect: []int{1, 2, 1},
+ },
+ {
+ name: "ok, uint64",
+ expect: []uint64{1, 2, 1},
+ },
+ {
+ name: "ok, uint32",
+ expect: []uint32{1, 2, 1},
+ },
+ {
+ name: "ok, uint16",
+ expect: []uint16{1, 2, 1},
+ },
+ {
+ name: "ok, uint8",
+ expect: []uint8{1, 2, 1},
+ },
+ {
+ name: "ok, uint",
+ expect: []uint{1, 2, 1},
+ },
+ {
+ name: "ok, float64",
+ expect: []float64{1, 2, 1},
+ },
+ {
+ name: "ok, float32",
+ expect: []float32{1, 2, 1},
+ },
+ {
+ name: "ok, bool",
+ whenURL: "/search?param=1,false¶m=true",
+ expect: []bool{true, false, true},
+ },
+ {
+ name: "ok, Duration",
+ whenURL: "/search?param=1s,42s¶m=1ms",
+ expect: []time.Duration{1 * time.Second, 42 * time.Second, 1 * time.Millisecond},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ URL := "/search?param=1,2¶m=1"
+ if tc.whenURL != "" {
+ URL = tc.whenURL
+ }
+ c := createTestContext(URL, nil, nil)
+ b := QueryParamsBinder(c)
+
+ switch tc.expect.(type) {
+ case []string:
+ var dest []string
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []int64:
+ var dest []int64
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []int32:
+ var dest []int32
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []int16:
+ var dest []int16
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []int8:
+ var dest []int8
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []int:
+ var dest []int
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []uint64:
+ var dest []uint64
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []uint32:
+ var dest []uint32
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []uint16:
+ var dest []uint16
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []uint8:
+ var dest []uint8
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []uint:
+ var dest []uint
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []float64:
+ var dest []float64
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []float32:
+ var dest []float32
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []bool:
+ var dest []bool
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ case []time.Duration:
+ var dest []time.Duration
+ assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError())
+ assert.Equal(t, tc.expect, dest)
+ default:
+ assert.Fail(t, "invalid type")
+ }
+ })
+ }
+}
+
+func TestValueBinder_BindWithDelimiter(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []int64
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value",
+ whenURL: "/search?param=1,2¶m=1",
+ expectValue: []int64{1, 2, 1},
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: []int64(nil),
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []int64(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []int64(nil),
+ expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1,2¶m=1",
+ expectValue: []int64{1, 2, 1},
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: []int64(nil),
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []int64(nil),
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []int64(nil),
+ expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ var dest []int64
+ var err error
+ if tc.whenMust {
+ err = b.MustBindWithDelimiter("param", &dest, ",").BindError()
+ } else {
+ err = b.BindWithDelimiter("param", &dest, ",").BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestBindWithDelimiter_invalidType(t *testing.T) {
+ c := createTestContext("/search?param=1¶m=100", nil, nil)
+ b := QueryParamsBinder(c)
+
+ var dest []BindUnmarshaler
+ err := b.BindWithDelimiter("param", &dest, ",").BindError()
+ assert.Equal(t, []BindUnmarshaler(nil), dest)
+ assert.EqualError(t, err, "code=400, message=unsupported bind type, field=param")
+}
+
+func TestValueBinder_UnixTime(t *testing.T) {
+ exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43+00:00") // => 1609180603
+ var testCases = []struct {
+ expectValue time.Time
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value, unix time in seconds",
+ whenURL: "/search?param=1609180603¶m=1609180604",
+ expectValue: exampleTime,
+ },
+ {
+ name: "ok, binds value, unix time over int32 value",
+ whenURL: "/search?param=2147483648¶m=1609180604",
+ expectValue: time.Unix(2147483648, 0),
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: time.Time{},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: time.Time{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: time.Time{},
+ expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1609180603¶m=1609180604",
+ expectValue: exampleTime,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: time.Time{},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: time.Time{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: time.Time{},
+ expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := time.Time{}
+ var err error
+ if tc.whenMust {
+ err = b.MustUnixTime("param", &dest).BindError()
+ } else {
+ err = b.UnixTime("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano())
+ assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC))
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_UnixTimeMilli(t *testing.T) {
+ exampleTime, _ := time.Parse(time.RFC3339Nano, "2022-03-13T15:13:30.140000000+00:00") // => 1647184410140
+ var testCases = []struct {
+ expectValue time.Time
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value, unix time in milliseconds",
+ whenURL: "/search?param=1647184410140¶m=1647184410199",
+ expectValue: exampleTime,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: time.Time{},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: time.Time{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: time.Time{},
+ expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1647184410140¶m=1647184410199",
+ expectValue: exampleTime,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: time.Time{},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: time.Time{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: time.Time{},
+ expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := time.Time{}
+ var err error
+ if tc.whenMust {
+ err = b.MustUnixTimeMilli("param", &dest).BindError()
+ } else {
+ err = b.UnixTimeMilli("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano())
+ assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC))
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_UnixTimeNano(t *testing.T) {
+ exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43.000000000+00:00") // => 1609180603
+ exampleTimeNano, _ := time.Parse(time.RFC3339Nano, "2020-12-28T18:36:43.123456789+00:00") // => 1609180603123456789
+ exampleTimeNanoBelowSec, _ := time.Parse(time.RFC3339Nano, "1970-01-01T00:00:00.999999999+00:00")
+ var testCases = []struct {
+ expectValue time.Time
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "ok, binds value, unix time in nano seconds (sec precision)",
+ whenURL: "/search?param=1609180603000000000¶m=1609180604",
+ expectValue: exampleTime,
+ },
+ {
+ name: "ok, binds value, unix time in nano seconds",
+ whenURL: "/search?param=1609180603123456789¶m=1609180604",
+ expectValue: exampleTimeNano,
+ },
+ {
+ name: "ok, binds value, unix time in nano seconds (below 1 sec)",
+ whenURL: "/search?param=999999999¶m=1609180604",
+ expectValue: exampleTimeNanoBelowSec,
+ },
+ {
+ name: "ok, params values empty, value is not changed",
+ whenURL: "/search?nope=1",
+ expectValue: time.Time{},
+ },
+ {
+ name: "nok, previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: time.Time{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: time.Time{},
+ expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ {
+ name: "ok (must), binds value",
+ whenMust: true,
+ whenURL: "/search?param=1609180603000000000¶m=1609180604",
+ expectValue: exampleTime,
+ },
+ {
+ name: "ok (must), params values empty, returns error, value is not changed",
+ whenMust: true,
+ whenURL: "/search?nope=1",
+ expectValue: time.Time{},
+ expectError: "code=400, message=required field value is empty, field=param",
+ },
+ {
+ name: "nok (must), previous errors fail fast without binding value",
+ givenFailFast: true,
+ whenMust: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: time.Time{},
+ expectError: "previous error",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: time.Time{},
+ expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := time.Time{}
+ var err error
+ if tc.whenMust {
+ err = b.MustUnixTimeNano("param", &dest).BindError()
+ } else {
+ err = b.UnixTimeNano("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano())
+ assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC))
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) {
+ type Opts struct {
+ Param int64 `query:"param"`
+ }
+ c := createTestContext("/search?param=1¶m=100", nil, nil)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ binder := new(DefaultBinder)
+ for i := 0; i < b.N; i++ {
+ var dest Opts
+ _ = binder.Bind(c, &dest)
+ }
+}
+
+func BenchmarkValueBinder_BindInt64_single(b *testing.B) {
+ c := createTestContext("/search?param=1¶m=100", nil, nil)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ type Opts struct {
+ Param int64
+ }
+ binder := QueryParamsBinder(c)
+ for i := 0; i < b.N; i++ {
+ var dest Opts
+ _ = binder.Int64("param", &dest.Param).BindError()
+ }
+}
+
+func BenchmarkRawFunc_Int64_single(b *testing.B) {
+ c := createTestContext("/search?param=1¶m=100", nil, nil)
+
+ rawFunc := func(input string, defaultValue int64) (int64, bool) {
+ if input == "" {
+ return defaultValue, true
+ }
+ n, err := strconv.Atoi(input)
+ if err != nil {
+ return 0, false
+ }
+ return int64(n), true
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ type Opts struct {
+ Param int64
+ }
+ for i := 0; i < b.N; i++ {
+ var dest Opts
+ if n, ok := rawFunc(c.QueryParam("param"), 1); ok {
+ dest.Param = n
+ }
+ }
+}
+
+func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) {
+ type Opts struct {
+ String string `query:"string"`
+ Strings []string `query:"strings"`
+ Int64 int64 `query:"int64"`
+ Uint64 uint64 `query:"uint64"`
+ Int32 int32 `query:"int32"`
+ Uint32 uint32 `query:"uint32"`
+ Int16 int16 `query:"int16"`
+ Uint16 uint16 `query:"uint16"`
+ Int8 int8 `query:"int8"`
+ Uint8 uint8 `query:"uint8"`
+ }
+ c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ binder := new(DefaultBinder)
+ for i := 0; i < b.N; i++ {
+ var dest Opts
+ _ = binder.Bind(c, &dest)
+ if dest.Int64 != 1 {
+ b.Fatalf("int64!=1")
+ }
+ }
+}
+
+func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) {
+ type Opts struct {
+ String string `query:"string"`
+ Strings []string `query:"strings"`
+ Int64 int64 `query:"int64"`
+ Uint64 uint64 `query:"uint64"`
+ Int32 int32 `query:"int32"`
+ Uint32 uint32 `query:"uint32"`
+ Int16 int16 `query:"int16"`
+ Uint16 uint16 `query:"uint16"`
+ Int8 int8 `query:"int8"`
+ Uint8 uint8 `query:"uint8"`
+ }
+ c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ binder := QueryParamsBinder(c)
+ for i := 0; i < b.N; i++ {
+ var dest Opts
+ _ = binder.
+ Int64("int64", &dest.Int64).
+ Int32("int32", &dest.Int32).
+ Int16("int16", &dest.Int16).
+ Int8("int8", &dest.Int8).
+ String("string", &dest.String).
+ Uint64("int64", &dest.Uint64).
+ Uint32("int32", &dest.Uint32).
+ Uint16("int16", &dest.Uint16).
+ Uint8("int8", &dest.Uint8).
+ Strings("strings", &dest.Strings).
+ BindError()
+ if dest.Int64 != 1 {
+ b.Fatalf("int64!=1")
+ }
+ }
+}
+
+func TestValueBinder_TimeError(t *testing.T) {
+ var testCases = []struct {
+ expectValue time.Time
+ name string
+ whenURL string
+ whenLayout string
+ expectError string
+ givenBindErrors []error
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: time.Time{},
+ expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\": extra text: \"nope\", field=param",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: time.Time{},
+ expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\": extra text: \"nope\", field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ dest := time.Time{}
+ var err error
+ if tc.whenMust {
+ err = b.MustTime("param", &dest, tc.whenLayout).BindError()
+ } else {
+ err = b.Time("param", &dest, tc.whenLayout).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_TimesError(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ whenLayout string
+ expectError string
+ givenBindErrors []error
+ expectValue []time.Time
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "nok, fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []time.Time(nil),
+ expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []time.Time(nil),
+ expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []time.Time(nil),
+ expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ b.errors = tc.givenBindErrors
+
+ layout := time.RFC3339
+ if tc.whenLayout != "" {
+ layout = tc.whenLayout
+ }
+
+ var dest []time.Time
+ var err error
+ if tc.whenMust {
+ err = b.MustTimes("param", &dest, layout).BindError()
+ } else {
+ err = b.Times("param", &dest, layout).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_DurationError(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue time.Duration
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 0,
+ expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: 0,
+ expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ if tc.givenFailFast {
+ b.errors = []error{errors.New("previous error")}
+ }
+
+ var dest time.Duration
+ var err error
+ if tc.whenMust {
+ err = b.MustDuration("param", &dest).BindError()
+ } else {
+ err = b.Duration("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValueBinder_DurationsError(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectError string
+ givenBindErrors []error
+ expectValue []time.Duration
+ givenFailFast bool
+ whenMust bool
+ }{
+ {
+ name: "nok, fail fast without binding value",
+ givenFailFast: true,
+ whenURL: "/search?param=1¶m=100",
+ expectValue: []time.Duration(nil),
+ expectError: "code=400, message=failed to bind field value to Duration, err=time: missing unit in duration \"1\", field=param",
+ },
+ {
+ name: "nok, conversion fails, value is not changed",
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []time.Duration(nil),
+ expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param",
+ },
+ {
+ name: "nok (must), conversion fails, value is not changed",
+ whenMust: true,
+ whenURL: "/search?param=nope¶m=100",
+ expectValue: []time.Duration(nil),
+ expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ c := createTestContext(tc.whenURL, nil, nil)
+ b := QueryParamsBinder(c).FailFast(tc.givenFailFast)
+ b.errors = tc.givenBindErrors
+
+ var dest []time.Duration
+ var err error
+ if tc.whenMust {
+ err = b.MustDurations("param", &dest).BindError()
+ } else {
+ err = b.Durations("param", &dest).BindError()
+ }
+
+ assert.Equal(t, tc.expectValue, dest)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
diff --git a/codecov.yml b/codecov.yml
new file mode 100644
index 000000000..0fa3a3f18
--- /dev/null
+++ b/codecov.yml
@@ -0,0 +1,11 @@
+coverage:
+ status:
+ project:
+ default:
+ threshold: 1%
+ patch:
+ default:
+ threshold: 1%
+
+comment:
+ require_changes: true
\ No newline at end of file
diff --git a/context.go b/context.go
index f081f6d80..f91ea7a60 100644
--- a/context.go
+++ b/context.go
@@ -1,189 +1,659 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package echo
import (
- "encoding/json"
+ "bytes"
"encoding/xml"
- "net/http"
-
+ "errors"
"fmt"
-
- "golang.org/x/net/websocket"
+ "io"
+ "io/fs"
+ "log/slog"
+ "mime/multipart"
+ "net"
+ "net/http"
"net/url"
+ "path"
+ "path/filepath"
+ "strings"
+ "sync"
)
-type (
- // Context represents context for the current request. It holds request and
- // response objects, path parameters, data and registered handler.
- Context struct {
- request *http.Request
- response *Response
- socket *websocket.Conn
- pnames []string
- pvalues []string
- query url.Values
- store store
- echo *Echo
- }
- store map[string]interface{}
+const (
+ // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain.
+ // Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests.
+ // It is added to context only when Router does not find matching method handler for request.
+ ContextKeyHeaderAllow = "echo_header_allow"
)
-// NewContext creates a Context object.
-func NewContext(req *http.Request, res *Response, e *Echo) *Context {
- return &Context{
- request: req,
- response: res,
- echo: e,
- pvalues: make([]string, *e.maxParam),
- store: make(store),
+const (
+ // defaultMemory is default value for memory limit that is used when
+ // parsing multipart forms (See (*http.Request).ParseMultipartForm)
+ defaultMemory int64 = 32 << 20 // 32 MB
+ indexPage = "index.html"
+)
+
+// Context represents the context of the current HTTP request. It holds request and
+// response objects, path, path parameters, data and registered handler.
+type Context struct {
+ request *http.Request
+ orgResponse *Response
+ response http.ResponseWriter
+ query url.Values
+
+ // formParseMaxMemory is used for http.Request.ParseMultipartForm
+ formParseMaxMemory int64
+
+ route *RouteInfo
+ pathValues *PathValues
+
+ store map[string]any
+ echo *Echo
+ logger *slog.Logger
+
+ path string
+ lock sync.RWMutex
+}
+
+// NewContext returns a new Context instance.
+//
+// Note: request,response and e can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway
+// these arguments are useful when creating context for tests and cases like that.
+func NewContext(r *http.Request, w http.ResponseWriter, opts ...any) *Context {
+ var e *Echo
+ for _, opt := range opts {
+ switch v := opt.(type) {
+ case *Echo:
+ e = v
+ }
+ }
+ return newContext(r, w, e)
+}
+
+func newContext(r *http.Request, w http.ResponseWriter, e *Echo) *Context {
+ c := &Context{
+ pathValues: nil,
+ store: make(map[string]any),
+ echo: e,
+ logger: nil,
+ }
+ var logger *slog.Logger
+ paramLen := int32(0)
+ formParseMaxMemory := defaultMemory
+ if e != nil {
+ paramLen = e.contextPathParamAllocSize.Load()
+ logger = e.Logger
+ formParseMaxMemory = e.formParseMaxMemory
+ }
+ if logger == nil {
+ logger = slog.Default()
+ }
+ c.logger = logger
+ p := make(PathValues, 0, paramLen)
+ c.pathValues = &p
+
+ c.SetRequest(r)
+ c.orgResponse = NewResponse(w, logger)
+ c.response = c.orgResponse
+ c.formParseMaxMemory = formParseMaxMemory
+ return c
+}
+
+// Reset resets the context after request completes. It must be called along
+// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
+// See `Echo#ServeHTTP()`
+func (c *Context) Reset(r *http.Request, w http.ResponseWriter) {
+ c.request = r
+ c.orgResponse.reset(w)
+ c.response = c.orgResponse
+ c.query = nil
+ c.store = nil
+ c.logger = c.echo.Logger
+
+ c.route = nil
+ c.path = ""
+ // NOTE: empty by setting length to 0. PathValues has to have capacity of c.echo.contextPathParamAllocSize at all times
+ *c.pathValues = (*c.pathValues)[:0]
+}
+
+func (c *Context) writeContentType(value string) {
+ header := c.response.Header()
+ if header.Get(HeaderContentType) == "" {
+ header.Set(HeaderContentType, value)
}
}
-// Request returns *http.Request.
+// Request returns `*http.Request`.
func (c *Context) Request() *http.Request {
return c.request
}
-// Response returns *Response.
-func (c *Context) Response() *Response {
+// SetRequest sets `*http.Request`.
+func (c *Context) SetRequest(r *http.Request) {
+ c.request = r
+}
+
+// Response returns `*Response`.
+func (c *Context) Response() http.ResponseWriter {
return c.response
}
-// Socket returns *websocket.Conn.
-func (c *Context) Socket() *websocket.Conn {
- return c.socket
+// SetResponse sets `*http.ResponseWriter`. Some context methods and/or middleware require that given ResponseWriter implements following
+// method `Unwrap() http.ResponseWriter` which eventually should return *echo.Response instance.
+func (c *Context) SetResponse(r http.ResponseWriter) {
+ c.response = r
}
-// P returns path parameter by index.
-func (c *Context) P(i int) (value string) {
- l := len(c.pnames)
- if i < l {
- value = c.pvalues[i]
+// IsTLS returns true if HTTP connection is TLS otherwise false.
+func (c *Context) IsTLS() bool {
+ return c.request.TLS != nil
+}
+
+// IsWebSocket returns true if HTTP connection is WebSocket otherwise false.
+func (c *Context) IsWebSocket() bool {
+ upgrade := c.request.Header.Get(HeaderUpgrade)
+ connection := c.request.Header.Get(HeaderConnection)
+ return strings.EqualFold(upgrade, "websocket") && strings.Contains(strings.ToLower(connection), "upgrade")
+}
+
+// Scheme returns the HTTP protocol scheme, `http` or `https`.
+func (c *Context) Scheme() string {
+ // Can't use `r.Request.URL.Scheme`
+ // See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0
+ if c.IsTLS() {
+ return "https"
}
- return
+ if scheme := c.request.Header.Get(HeaderXForwardedProto); scheme != "" {
+ return scheme
+ }
+ if scheme := c.request.Header.Get(HeaderXForwardedProtocol); scheme != "" {
+ return scheme
+ }
+ if ssl := c.request.Header.Get(HeaderXForwardedSsl); ssl == "on" {
+ return "https"
+ }
+ if scheme := c.request.Header.Get(HeaderXUrlScheme); scheme != "" {
+ return scheme
+ }
+ return "http"
}
-// Param returns path parameter by name.
-func (c *Context) Param(name string) (value string) {
- l := len(c.pnames)
- for i, n := range c.pnames {
- if n == name && i < l {
- value = c.pvalues[i]
- break
+// RealIP returns the client's network address based on `X-Forwarded-For`
+// or `X-Real-IP` request header.
+// The behavior can be configured using `Echo#IPExtractor`.
+func (c *Context) RealIP() string {
+ if c.echo != nil && c.echo.IPExtractor != nil {
+ return c.echo.IPExtractor(c.request)
+ }
+ // Fall back to legacy behavior
+ if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" {
+ i := strings.IndexAny(ip, ",")
+ if i > 0 {
+ xffip := strings.TrimSpace(ip[:i])
+ xffip = strings.TrimPrefix(xffip, "[")
+ xffip = strings.TrimSuffix(xffip, "]")
+ return xffip
}
+ return ip
}
- return
+ if ip := c.request.Header.Get(HeaderXRealIP); ip != "" {
+ ip = strings.TrimPrefix(ip, "[")
+ ip = strings.TrimSuffix(ip, "]")
+ return ip
+ }
+ ra, _, _ := net.SplitHostPort(c.request.RemoteAddr)
+ return ra
}
-// Query returns query parameter by name.
-func (c *Context) Query(name string) string {
+// Path returns the registered path for the handler.
+func (c *Context) Path() string {
+ return c.path
+}
+
+// SetPath sets the registered path for the handler.
+func (c *Context) SetPath(p string) {
+ c.path = p
+}
+
+// RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route.
+//
+// RouteInfo returns generic "empty" struct for these cases:
+// * Context is accessed before Routing is done. For example inside Pre middlewares (`e.Pre()`)
+// * Router did not find matching route - 404 (route not found)
+// * Router did not find matching route with same method - 405 (method not allowed)
+func (c *Context) RouteInfo() RouteInfo {
+ if c.route != nil {
+ return c.route.Clone()
+ }
+ return RouteInfo{}
+}
+
+// Param returns path parameter by name.
+func (c *Context) Param(name string) string {
+ return c.pathValues.GetOr(name, "")
+}
+
+// ParamOr returns the path parameter or default value for the provided name.
+//
+// Notes for DefaultRouter implementation:
+// Path parameter could be empty for cases like that:
+// * route `/release-:version/bin` and request URL is `/release-/bin`
+// * route `/api/:version/image.jpg` and request URL is `/api//image.jpg`
+// but not when path parameter is last part of route path
+// * route `/download/file.:ext` will not match request `/download/file.`
+func (c *Context) ParamOr(name, defaultValue string) string {
+ return c.pathValues.GetOr(name, defaultValue)
+}
+
+// PathValues returns path parameter values.
+func (c *Context) PathValues() PathValues {
+ return *c.pathValues
+}
+
+// SetPathValues sets path parameters for current request.
+func (c *Context) SetPathValues(pathValues PathValues) {
+ if pathValues == nil {
+ panic("context SetPathValues called with nil PathValues")
+ }
+ c.setPathValues(&pathValues)
+}
+
+// InitializeRoute sets the route related variables of this request to the context.
+func (c *Context) InitializeRoute(ri *RouteInfo, pathValues *PathValues) {
+ c.route = ri
+ c.path = ri.Path
+ c.setPathValues(pathValues)
+}
+
+func (c *Context) setPathValues(pv *PathValues) {
+ // Router accesses c.pathValues by index and may resize it to full capacity during routing
+ // for that to work without going out-of-bounds we must make sure that c.pathValues slice is not replaced with smaller
+ // slice than Router can set when routing Route with maximum amount of parameters.
+ pathValues := c.pathValues
+ if cap(*c.pathValues) < len(*pv) {
+ // normally we should not end up here. pathValues is normally sized to Echo.contextPathParamAllocSize which should not
+ // be smaller than anything router knows as maximum path parameter count to be.
+ tmp := make(PathValues, len(*pv))
+ c.pathValues = &tmp
+ pathValues = c.pathValues
+ } else if len(*c.pathValues) != len(*pv) {
+ *pathValues = (*pathValues)[0:len(*pv)] // resize slice to given params length for copy to work
+ }
+ copy(*pathValues, *pv)
+}
+
+// QueryParam returns the query param for the provided name.
+func (c *Context) QueryParam(name string) string {
if c.query == nil {
c.query = c.request.URL.Query()
}
return c.query.Get(name)
}
-// Form returns form parameter by name.
-func (c *Context) Form(name string) string {
+// QueryParamOr returns the query param or default value for the provided name.
+// Note: QueryParamOr does not distinguish if query had no value by that name or value was empty string
+// This means URLs `/test?search=` and `/test` would both return `1` for `c.QueryParamOr("search", "1")`
+func (c *Context) QueryParamOr(name, defaultValue string) string {
+ value := c.QueryParam(name)
+ if value == "" {
+ value = defaultValue
+ }
+ return value
+}
+
+// QueryParams returns the query parameters as `url.Values`.
+func (c *Context) QueryParams() url.Values {
+ if c.query == nil {
+ c.query = c.request.URL.Query()
+ }
+ return c.query
+}
+
+// QueryString returns the URL query string.
+func (c *Context) QueryString() string {
+ return c.request.URL.RawQuery
+}
+
+// FormValue returns the form field value for the provided name.
+func (c *Context) FormValue(name string) string {
return c.request.FormValue(name)
}
+// FormValueOr returns the form field value or default value for the provided name.
+// Note: FormValueOr does not distinguish if form had no value by that name or value was empty string
+func (c *Context) FormValueOr(name, defaultValue string) string {
+ value := c.FormValue(name)
+ if value == "" {
+ value = defaultValue
+ }
+ return value
+}
+
+// FormValues returns the form field values as `url.Values`.
+func (c *Context) FormValues() (url.Values, error) {
+ if strings.HasPrefix(c.request.Header.Get(HeaderContentType), MIMEMultipartForm) {
+ if err := c.request.ParseMultipartForm(c.formParseMaxMemory); err != nil {
+ return nil, err
+ }
+ } else {
+ if err := c.request.ParseForm(); err != nil {
+ return nil, err
+ }
+ }
+ return c.request.Form, nil
+}
+
+// FormFile returns the multipart form file for the provided name.
+func (c *Context) FormFile(name string) (*multipart.FileHeader, error) {
+ f, fh, err := c.request.FormFile(name)
+ if err != nil {
+ return nil, err
+ }
+ _ = f.Close()
+ return fh, nil
+}
+
+// MultipartForm returns the multipart form.
+func (c *Context) MultipartForm() (*multipart.Form, error) {
+ err := c.request.ParseMultipartForm(c.formParseMaxMemory)
+ return c.request.MultipartForm, err
+}
+
+// Cookie returns the named cookie provided in the request.
+func (c *Context) Cookie(name string) (*http.Cookie, error) {
+ return c.request.Cookie(name)
+}
+
+// SetCookie adds a `Set-Cookie` header in HTTP response.
+func (c *Context) SetCookie(cookie *http.Cookie) {
+ http.SetCookie(c.Response(), cookie)
+}
+
+// Cookies returns the HTTP cookies sent with the request.
+func (c *Context) Cookies() []*http.Cookie {
+ return c.request.Cookies()
+}
+
// Get retrieves data from the context.
-func (c *Context) Get(key string) interface{} {
+// Method returns any(nil) when key does not exist which is different from typed nil (eg. []byte(nil)).
+func (c *Context) Get(key string) any {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
return c.store[key]
}
// Set saves data in the context.
-func (c *Context) Set(key string, val interface{}) {
+func (c *Context) Set(key string, val any) {
+ c.lock.Lock()
+ defer c.lock.Unlock()
+
if c.store == nil {
- c.store = make(store)
+ c.store = make(map[string]any)
}
c.store[key] = val
}
-// Bind binds the request body into specified type v. Default binder does it
-// based on Content-Type header.
-func (c *Context) Bind(i interface{}) error {
- return c.echo.binder(c.request, i)
+// Bind binds path params, query params and the request body into provided type `i`. The default binder
+// binds body based on Content-Type header.
+func (c *Context) Bind(i any) error {
+ return c.echo.Binder.Bind(c, i)
+}
+
+// Validate validates provided `i`. It is usually called after `Context#Bind()`.
+// Validator must be registered using `Echo#Validator`.
+func (c *Context) Validate(i any) error {
+ if c.echo.Validator == nil {
+ return ErrValidatorNotRegistered
+ }
+ return c.echo.Validator.Validate(i)
}
-// Render invokes the registered HTML template renderer and sends a text/html
-// response with status code.
-func (c *Context) Render(code int, name string, data interface{}) (err error) {
- if c.echo.renderer == nil {
- return RendererNotRegistered
+// Render renders a template with data and sends a text/html response with status
+// code. Renderer must be registered using `Echo.Renderer`.
+func (c *Context) Render(code int, name string, data any) (err error) {
+ if c.echo.Renderer == nil {
+ return ErrRendererNotRegistered
+ }
+ // as Renderer.Render can fail, and in that case we need to delay sending status code to the client until
+ // (global) error handler decides the correct status code for the error to be sent to the client, so we need to write
+ // the rendered template to the buffer first.
+ //
+ // html.Template.ExecuteTemplate() documentations writes:
+ // > If an error occurs executing the template or writing its output,
+ // > execution stops, but partial results may already have been written to
+ // > the output writer.
+
+ buf := new(bytes.Buffer)
+ if err = c.echo.Renderer.Render(c, buf, name, data); err != nil {
+ return
}
- c.response.Header().Set(ContentType, TextHTML)
+ return c.HTMLBlob(code, buf.Bytes())
+}
+
+// HTML sends an HTTP response with status code.
+func (c *Context) HTML(code int, html string) (err error) {
+ return c.HTMLBlob(code, []byte(html))
+}
+
+// HTMLBlob sends an HTTP blob response with status code.
+func (c *Context) HTMLBlob(code int, b []byte) (err error) {
+ return c.Blob(code, MIMETextHTMLCharsetUTF8, b)
+}
+
+// String sends a string response with status code.
+func (c *Context) String(code int, s string) (err error) {
+ return c.Blob(code, MIMETextPlainCharsetUTF8, []byte(s))
+}
+
+func (c *Context) jsonPBlob(code int, callback string, i any) (err error) {
+ c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8)
c.response.WriteHeader(code)
- if err = c.echo.renderer.Render(c.response, name, data); err != nil {
- c.response.clear()
+ if _, err = c.response.Write([]byte(callback + "(")); err != nil {
+ return
+ }
+ if err = c.echo.JSONSerializer.Serialize(c, i, ""); err != nil {
+ return
+ }
+ if _, err = c.response.Write([]byte(");")); err != nil {
+ return
}
return
}
-// HTML formats according to a format specifier and sends text/html response with
-// status code.
-func (c *Context) HTML(code int, format string, a ...interface{}) (err error) {
- c.response.Header().Set(ContentType, TextHTML)
+func (c *Context) json(code int, i any, indent string) error {
+ c.writeContentType(MIMEApplicationJSON)
+
+ // as JSONSerializer.Serialize can fail, and in that case we need to delay sending status code to the client until
+ // (global) error handler decides correct status code for the error to be sent to the client.
+ // For that we need to use writer that can store the proposed status code until the first Write is called.
+ if r, err := UnwrapResponse(c.response); err == nil {
+ r.Status = code
+ } else {
+ resp := c.Response()
+ c.SetResponse(&delayedStatusWriter{ResponseWriter: resp, status: code})
+ defer c.SetResponse(resp)
+ }
+
+ return c.echo.JSONSerializer.Serialize(c, i, indent)
+}
+
+// JSON sends a JSON response with status code.
+func (c *Context) JSON(code int, i any) (err error) {
+ return c.json(code, i, "")
+}
+
+// JSONPretty sends a pretty-print JSON with status code.
+func (c *Context) JSONPretty(code int, i any, indent string) (err error) {
+ return c.json(code, i, indent)
+}
+
+// JSONBlob sends a JSON blob response with status code.
+func (c *Context) JSONBlob(code int, b []byte) (err error) {
+ return c.Blob(code, MIMEApplicationJSON, b)
+}
+
+// JSONP sends a JSONP response with status code. It uses `callback` to construct
+// the JSONP payload.
+func (c *Context) JSONP(code int, callback string, i any) (err error) {
+ return c.jsonPBlob(code, callback, i)
+}
+
+// JSONPBlob sends a JSONP blob response with status code. It uses `callback`
+// to construct the JSONP payload.
+func (c *Context) JSONPBlob(code int, callback string, b []byte) (err error) {
+ c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8)
c.response.WriteHeader(code)
- if _, err = fmt.Fprintf(c.response, format, a...); err != nil {
- c.response.clear()
+ if _, err = c.response.Write([]byte(callback + "(")); err != nil {
+ return
}
+ if _, err = c.response.Write(b); err != nil {
+ return
+ }
+ _, err = c.response.Write([]byte(");"))
return
}
-// String formats according to a format specifier and sends text/plain response
-// with status code.
-func (c *Context) String(code int, format string, a ...interface{}) (err error) {
- c.response.Header().Set(ContentType, TextPlain)
+func (c *Context) xml(code int, i any, indent string) (err error) {
+ c.writeContentType(MIMEApplicationXMLCharsetUTF8)
c.response.WriteHeader(code)
- if _, err = fmt.Fprintf(c.response, format, a...); err != nil {
- c.response.clear()
+ enc := xml.NewEncoder(c.response)
+ if indent != "" {
+ enc.Indent("", indent)
}
- return
+ if _, err = c.response.Write([]byte(xml.Header)); err != nil {
+ return
+ }
+ return enc.Encode(i)
}
-// JSON sends an application/json response with status code.
-func (c *Context) JSON(code int, i interface{}) (err error) {
- c.response.Header().Set(ContentType, ApplicationJSON)
+// XML sends an XML response with status code.
+func (c *Context) XML(code int, i any) (err error) {
+ return c.xml(code, i, "")
+}
+
+// XMLPretty sends a pretty-print XML with status code.
+func (c *Context) XMLPretty(code int, i any, indent string) (err error) {
+ return c.xml(code, i, indent)
+}
+
+// XMLBlob sends an XML blob response with status code.
+func (c *Context) XMLBlob(code int, b []byte) (err error) {
+ c.writeContentType(MIMEApplicationXMLCharsetUTF8)
c.response.WriteHeader(code)
- if err = json.NewEncoder(c.response).Encode(i); err != nil {
- c.response.clear()
+ if _, err = c.response.Write([]byte(xml.Header)); err != nil {
+ return
}
+ _, err = c.response.Write(b)
return
}
-// XML sends an application/xml response with status code.
-func (c *Context) XML(code int, i interface{}) (err error) {
- c.response.Header().Set(ContentType, ApplicationXML)
+// Blob sends a blob response with status code and content type.
+func (c *Context) Blob(code int, contentType string, b []byte) (err error) {
+ c.writeContentType(contentType)
c.response.WriteHeader(code)
- c.response.Write([]byte(xml.Header))
- if err = xml.NewEncoder(c.response).Encode(i); err != nil {
- c.response.clear()
- }
+ _, err = c.response.Write(b)
return
}
+// Stream sends a streaming response with status code and content type.
+func (c *Context) Stream(code int, contentType string, r io.Reader) (err error) {
+ c.writeContentType(contentType)
+ c.response.WriteHeader(code)
+ _, err = io.Copy(c.response, r)
+ return
+}
+
+// File sends a response with the content of the file.
+func (c *Context) File(file string) error {
+ return fsFile(c, file, c.echo.Filesystem)
+}
+
+// FileFS serves file from given file system.
+//
+// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
+// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
+// including `assets/images` as their prefix.
+func (c *Context) FileFS(file string, filesystem fs.FS) error {
+ return fsFile(c, file, filesystem)
+}
+
+func fsFile(c *Context, file string, filesystem fs.FS) error {
+ file = path.Clean(file) // `os.Open` and `os.DirFs.Open()` behave differently, later does not like ``, `.`, `..` at all, but we allowed those now need to clean
+ f, err := filesystem.Open(file)
+ if err != nil {
+ return ErrNotFound
+ }
+ defer f.Close()
+
+ fi, _ := f.Stat()
+ if fi.IsDir() {
+ file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect.
+ f, err = filesystem.Open(file)
+ if err != nil {
+ return ErrNotFound
+ }
+ defer f.Close()
+ if fi, err = f.Stat(); err != nil {
+ return err
+ }
+ }
+ ff, ok := f.(io.ReadSeeker)
+ if !ok {
+ return errors.New("file does not implement io.ReadSeeker")
+ }
+ http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff)
+ return nil
+}
+
+// Attachment sends a response as attachment, prompting client to save the file.
+func (c *Context) Attachment(file, name string) error {
+ return c.contentDisposition(file, name, "attachment")
+}
+
+// Inline sends a response as inline, opening the file in the browser.
+func (c *Context) Inline(file, name string) error {
+ return c.contentDisposition(file, name, "inline")
+}
+
+var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
+
+func (c *Context) contentDisposition(file, name, dispositionType string) error {
+ c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf(`%s; filename="%s"`, dispositionType, quoteEscaper.Replace(name)))
+ return c.File(file)
+}
+
// NoContent sends a response with no body and a status code.
func (c *Context) NoContent(code int) error {
c.response.WriteHeader(code)
return nil
}
-// Redirect redirects the request using http.Redirect with status code.
+// Redirect redirects the request to a provided URL with status code.
func (c *Context) Redirect(code int, url string) error {
- http.Redirect(c.response, c.request, url, code)
- return nil
+ if code < 300 || code > 308 {
+ return ErrInvalidRedirectCode
+ }
+ c.response.Header().Set(HeaderLocation, url)
+ c.response.WriteHeader(code)
+ return nil
+}
+
+// Logger returns logger in Context
+func (c *Context) Logger() *slog.Logger {
+ if c.logger != nil {
+ return c.logger
+ }
+ return c.echo.Logger
}
-// Error invokes the registered HTTP error handler. Generally used by middleware.
-func (c *Context) Error(err error) {
- c.echo.httpErrorHandler(err, c)
+// SetLogger sets logger in Context
+func (c *Context) SetLogger(logger *slog.Logger) {
+ c.logger = logger
}
-func (c *Context) reset(r *http.Request, w http.ResponseWriter, e *Echo) {
- c.request = r
- c.response.reset(w)
- c.query = nil
- c.store = nil
- c.echo = e
+// Echo returns the `Echo` instance.
+func (c *Context) Echo() *Echo {
+ return c.echo
}
diff --git a/context_generic.go b/context_generic.go
new file mode 100644
index 000000000..7cf8b296c
--- /dev/null
+++ b/context_generic.go
@@ -0,0 +1,43 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import "errors"
+
+// ErrNonExistentKey is error that is returned when key does not exist
+var ErrNonExistentKey = errors.New("non existent key")
+
+// ErrInvalidKeyType is error that is returned when the value is not castable to expected type.
+var ErrInvalidKeyType = errors.New("invalid key type")
+
+// ContextGet retrieves a value from the context store or ErrNonExistentKey error the key is missing.
+// Returns ErrInvalidKeyType error if the value is not castable to type T.
+func ContextGet[T any](c *Context, key string) (T, error) {
+ c.lock.RLock()
+ defer c.lock.RUnlock()
+
+ val, ok := c.store[key]
+ if !ok {
+ var zero T
+ return zero, ErrNonExistentKey
+ }
+
+ typed, ok := val.(T)
+ if !ok {
+ var zero T
+ return zero, ErrInvalidKeyType
+ }
+
+ return typed, nil
+}
+
+// ContextGetOr retrieves a value from the context store or returns a default value when the key
+// is missing. Returns ErrInvalidKeyType error if the value is not castable to type T.
+func ContextGetOr[T any](c *Context, key string, defaultValue T) (T, error) {
+ typed, err := ContextGet[T](c, key)
+ if err == ErrNonExistentKey {
+ return defaultValue, nil
+ }
+ return typed, err
+}
diff --git a/context_generic_test.go b/context_generic_test.go
new file mode 100644
index 000000000..ce468ac3e
--- /dev/null
+++ b/context_generic_test.go
@@ -0,0 +1,70 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestContextGetOK(t *testing.T) {
+ c := NewContext(nil, nil)
+
+ c.Set("key", int64(123))
+
+ v, err := ContextGet[int64](c, "key")
+ assert.NoError(t, err)
+ assert.Equal(t, int64(123), v)
+}
+
+func TestContextGetNonExistentKey(t *testing.T) {
+ c := NewContext(nil, nil)
+
+ c.Set("key", int64(123))
+
+ v, err := ContextGet[int64](c, "nope")
+ assert.ErrorIs(t, err, ErrNonExistentKey)
+ assert.Equal(t, int64(0), v)
+}
+
+func TestContextGetInvalidCast(t *testing.T) {
+ c := NewContext(nil, nil)
+
+ c.Set("key", int64(123))
+
+ v, err := ContextGet[bool](c, "key")
+ assert.ErrorIs(t, err, ErrInvalidKeyType)
+ assert.Equal(t, false, v)
+}
+
+func TestContextGetOrOK(t *testing.T) {
+ c := NewContext(nil, nil)
+
+ c.Set("key", int64(123))
+
+ v, err := ContextGetOr[int64](c, "key", 999)
+ assert.NoError(t, err)
+ assert.Equal(t, int64(123), v)
+}
+
+func TestContextGetOrNonExistentKey(t *testing.T) {
+ c := NewContext(nil, nil)
+
+ c.Set("key", int64(123))
+
+ v, err := ContextGetOr[int64](c, "nope", 999)
+ assert.NoError(t, err)
+ assert.Equal(t, int64(999), v)
+}
+
+func TestContextGetOrInvalidCast(t *testing.T) {
+ c := NewContext(nil, nil)
+
+ c.Set("key", int64(123))
+
+ v, err := ContextGetOr[float32](c, "key", float32(999))
+ assert.ErrorIs(t, err, ErrInvalidKeyType)
+ assert.Equal(t, float32(0), v)
+}
diff --git a/context_test.go b/context_test.go
index 2e4f0a6dd..5945c9ecc 100644
--- a/context_test.go
+++ b/context_test.go
@@ -1,193 +1,1476 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package echo
import (
- "errors"
+ "bytes"
+ "crypto/tls"
+ "encoding/json"
+ "encoding/xml"
+ "fmt"
"io"
+ "io/fs"
+ "log/slog"
+ "math"
+ "mime/multipart"
"net/http"
"net/http/httptest"
+ "net/url"
+ "os"
+ "strings"
"testing"
"text/template"
+ "time"
- "strings"
-
- "encoding/xml"
"github.com/stretchr/testify/assert"
- "net/url"
)
-type (
- Template struct {
- templates *template.Template
+type Template struct {
+ templates *template.Template
+}
+
+var testUser = user{ID: 1, Name: "Jon Snow"}
+
+func BenchmarkAllocJSONP(b *testing.B) {
+ e := New()
+ e.Logger = slog.New(slog.DiscardHandler)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ c.JSONP(http.StatusOK, "callback", testUser)
}
-)
+}
-func (t *Template) Render(w io.Writer, name string, data interface{}) error {
+func BenchmarkAllocJSON(b *testing.B) {
+ e := New()
+ e.Logger = slog.New(slog.DiscardHandler)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ c.JSON(http.StatusOK, testUser)
+ }
+}
+
+func BenchmarkAllocXML(b *testing.B) {
+ e := New()
+ e.Logger = slog.New(slog.DiscardHandler)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ c.XML(http.StatusOK, testUser)
+ }
+}
+
+func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) {
+ c := Context{request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}},
+ }}
+ for i := 0; i < b.N; i++ {
+ c.RealIP()
+ }
+}
+
+func (t *Template) Render(c *Context, w io.Writer, name string, data any) error {
return t.templates.ExecuteTemplate(w, name, data)
}
-func TestContext(t *testing.T) {
- userJSON := `{"id":"1","name":"Joe"}`
- userXML := `1Joe`
+func TestContextEcho(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec)
- req, _ := http.NewRequest(POST, "/", strings.NewReader(userJSON))
+ assert.Equal(t, e, c.Echo())
+}
+
+func TestContextRequest(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
- c := NewContext(req, NewResponse(rec), New())
- // Request
+ c := e.NewContext(req, rec)
+
assert.NotNil(t, c.Request())
+ assert.Equal(t, req, c.Request())
+}
+
+func TestContextResponse(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec)
- // Response
assert.NotNil(t, c.Response())
+}
- // Socket
- assert.Nil(t, c.Socket())
+func TestContextRenderTemplate(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
- // Param by id
- c.pnames = []string{"id"}
- c.pvalues = []string{"1"}
- assert.Equal(t, "1", c.P(0))
+ c := e.NewContext(req, rec)
- // Param by name
- assert.Equal(t, "1", c.Param("id"))
+ tmpl := &Template{
+ templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")),
+ }
+ c.Echo().Renderer = tmpl
+ err := c.Render(http.StatusOK, "hello", "Jon Snow")
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "Hello, Jon Snow!", rec.Body.String())
+ }
+}
- // Store
- c.Set("user", "Joe")
- assert.Equal(t, "Joe", c.Get("user"))
+func TestContextRenderTemplateError(t *testing.T) {
+ // we test that when template rendering fails, no response is sent to the client yet, so the global error handler can decide what to do
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
- //------
- // Bind
- //------
+ tmpl := &Template{
+ templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")),
+ }
+ c.Echo().Renderer = tmpl
+ err := c.Render(http.StatusOK, "not_existing", "Jon Snow")
- // JSON
- testBind(t, c, "application/json")
+ assert.EqualError(t, err, `template: no template "not_existing" associated with template "hello"`)
+ assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
+ assert.Empty(t, rec.Body.String()) // body must not be sent to the client
+}
- // XML
- c.request, _ = http.NewRequest(POST, "/", strings.NewReader(userXML))
- testBind(t, c, ApplicationXML)
+func TestContextRenderErrorsOnNoRenderer(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
- // Unsupported
- testBind(t, c, "")
+ c := e.NewContext(req, rec)
- //--------
- // Render
- //--------
+ c.Echo().Renderer = nil
+ assert.Error(t, c.Render(http.StatusOK, "hello", "Jon Snow"))
+}
- tpl := &Template{
- templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")),
+func TestContextStream(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
+ r, w := io.Pipe()
+ go func() {
+ defer w.Close()
+ for i := 0; i < 3; i++ {
+ fmt.Fprintf(w, "data: index %v\n\n", i)
+ time.Sleep(5 * time.Millisecond)
+ }
+ }()
+
+ err := c.Stream(http.StatusOK, "text/event-stream", r)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "text/event-stream", rec.Header().Get(HeaderContentType))
+ assert.Equal(t, "data: index 0\n\ndata: index 1\n\ndata: index 2\n\n", rec.Body.String())
}
- c.echo.SetRenderer(tpl)
- err := c.Render(http.StatusOK, "hello", "Joe")
+}
+
+func TestContextHTML(t *testing.T) {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := NewContext(req, rec)
+
+ err := c.HTML(http.StatusOK, "Hi, Jon Snow")
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, "Hello, Joe!", rec.Body.String())
+ assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, "Hi, Jon Snow", rec.Body.String())
}
+}
- c.echo.renderer = nil
- err = c.Render(http.StatusOK, "hello", "Joe")
- assert.Error(t, err)
+func TestContextHTMLBlob(t *testing.T) {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := NewContext(req, rec)
- // JSON
- req.Header.Set(Accept, ApplicationJSON)
- rec = httptest.NewRecorder()
- c = NewContext(req, NewResponse(rec), New())
- err = c.JSON(http.StatusOK, user{"1", "Joe"})
+ err := c.HTMLBlob(http.StatusOK, []byte("Hi, Jon Snow"))
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, ApplicationJSON, rec.Header().Get(ContentType))
- assert.Equal(t, userJSON, strings.TrimSpace(rec.Body.String()))
+ assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, "Hi, Jon Snow", rec.Body.String())
}
+}
+
+func TestContextJSON(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ c := e.NewContext(req, rec)
- // XML
- req.Header.Set(Accept, ApplicationXML)
- rec = httptest.NewRecorder()
- c = NewContext(req, NewResponse(rec), New())
- err = c.XML(http.StatusOK, user{"1", "Joe"})
+ err := c.JSON(http.StatusOK, user{ID: 1, Name: "Jon Snow"})
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, ApplicationXML, rec.Header().Get(ContentType))
- assert.Equal(t, xml.Header, xml.Header, rec.Body.String())
+ assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSON+"\n", rec.Body.String())
}
+}
+
+func TestContextJSONErrorsOut(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ c := e.NewContext(req, rec)
+
+ err := c.JSON(http.StatusOK, make(chan bool))
+ assert.EqualError(t, err, "json: unsupported type: chan bool")
+
+ assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
+ assert.Empty(t, rec.Body.String()) // body must not be sent to the client
+}
+
+func TestContextJSONWithNotEchoResponse(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ c := e.NewContext(req, rec)
+
+ c.SetResponse(rec)
+
+ err := c.JSON(http.StatusCreated, map[string]float64{"foo": math.NaN()})
+ assert.EqualError(t, err, "json: unsupported value: NaN")
+
+ assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
+ assert.Empty(t, rec.Body.String()) // body must not be sent to the client
+}
- // String
- req.Header.Set(Accept, TextPlain)
- rec = httptest.NewRecorder()
- c = NewContext(req, NewResponse(rec), New())
- err = c.String(http.StatusOK, "Hello, World!")
+func TestContextJSONPretty(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
+ err := c.JSONPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, " ")
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, TextPlain, rec.Header().Get(ContentType))
- assert.Equal(t, "Hello, World!", rec.Body.String())
+ assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
}
+}
+
+func TestContextJSONWithEmptyIntent(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
+ u := user{ID: 1, Name: "Jon Snow"}
+ emptyIndent := ""
+ buf := new(bytes.Buffer)
- // HTML
- req.Header.Set(Accept, TextHTML)
- rec = httptest.NewRecorder()
- c = NewContext(req, NewResponse(rec), New())
- err = c.HTML(http.StatusOK, "Hello, World!")
+ enc := json.NewEncoder(buf)
+ enc.SetIndent(emptyIndent, emptyIndent)
+ _ = enc.Encode(u)
+ err := c.JSONPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, emptyIndent)
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, TextHTML, rec.Header().Get(ContentType))
- assert.Equal(t, "Hello, World!", rec.Body.String())
+ assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, buf.String(), rec.Body.String())
}
+}
- // NoContent
- rec = httptest.NewRecorder()
- c = NewContext(req, NewResponse(rec), New())
- c.NoContent(http.StatusOK)
- assert.Equal(t, http.StatusOK, c.response.status)
+func TestContextJSONP(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
- // Redirect
- rec = httptest.NewRecorder()
- c = NewContext(req, NewResponse(rec), New())
- assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
+ callback := "callback"
+ err := c.JSONP(http.StatusOK, callback, user{ID: 1, Name: "Jon Snow"})
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String())
+ }
+}
- // Error
- rec = httptest.NewRecorder()
- c = NewContext(req, NewResponse(rec), New())
- c.Error(errors.New("error"))
- assert.Equal(t, http.StatusInternalServerError, c.response.status)
+func TestContextJSONBlob(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
- // reset
- c.reset(req, NewResponse(httptest.NewRecorder()), New())
+ data, err := json.Marshal(user{ID: 1, Name: "Jon Snow"})
+ assert.NoError(t, err)
+ err = c.JSONBlob(http.StatusOK, data)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSON, rec.Body.String())
+ }
}
-func TestContextQuery(t *testing.T) {
- q := make(url.Values)
- q.Set("name", "joe")
- q.Set("email", "joe@labstack.com")
+func TestContextJSONPBlob(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
- req, err := http.NewRequest(GET, "/", nil)
+ callback := "callback"
+ data, err := json.Marshal(user{ID: 1, Name: "Jon Snow"})
assert.NoError(t, err)
- req.URL.RawQuery = q.Encode()
+ err = c.JSONPBlob(http.StatusOK, callback, data)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, callback+"("+userJSON+");", rec.Body.String())
+ }
+}
- c := NewContext(req, nil, New())
- assert.Equal(t, "joe", c.Query("name"))
- assert.Equal(t, "joe@labstack.com", c.Query("email"))
+func TestContextXML(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+ err := c.XML(http.StatusOK, user{ID: 1, Name: "Jon Snow"})
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+userXML, rec.Body.String())
+ }
}
-func TestContextForm(t *testing.T) {
- f := make(url.Values)
- f.Set("name", "joe")
- f.Set("email", "joe@labstack.com")
+func TestContextXMLPretty(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
+ err := c.XMLPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, " ")
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String())
+ }
+}
+
+func TestContextXMLBlob(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
- req, err := http.NewRequest(POST, "/", strings.NewReader(f.Encode()))
+ data, err := xml.Marshal(user{ID: 1, Name: "Jon Snow"})
assert.NoError(t, err)
- req.Header.Add(ContentType, ApplicationForm)
+ err = c.XMLBlob(http.StatusOK, data)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+userXML, rec.Body.String())
+ }
+}
+
+func TestContextXMLWithEmptyIntent(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
+ u := user{ID: 1, Name: "Jon Snow"}
+ emptyIndent := ""
+ buf := new(bytes.Buffer)
+
+ enc := xml.NewEncoder(buf)
+ enc.Indent(emptyIndent, emptyIndent)
+ _ = enc.Encode(u)
+ err := c.XMLPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, emptyIndent)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, xml.Header+buf.String(), rec.Body.String())
+ }
+}
+
+func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ err := c.JSON(http.StatusCreated, user{ID: 1, Name: "Jon Snow"})
+
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusCreated, rec.Code)
+ assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType))
+ assert.Equal(t, userJSON+"\n", rec.Body.String())
+ }
+}
+
+func TestContextAttachment(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenName string
+ expectHeader string
+ }{
+ {
+ name: "ok",
+ whenName: "walle.png",
+ expectHeader: `attachment; filename="walle.png"`,
+ },
+ {
+ name: "ok, escape quotes in malicious filename",
+ whenName: `malicious.sh"; \"; dummy=.txt`,
+ expectHeader: `attachment; filename="malicious.sh\"; \\\"; dummy=.txt"`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
+ err := c.Attachment("_fixture/images/walle.png", tc.whenName)
+ if assert.NoError(t, err) {
+ assert.Equal(t, tc.expectHeader, rec.Header().Get(HeaderContentDisposition))
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, 219885, rec.Body.Len())
+ }
+ })
+ }
+}
+
+func TestContextInline(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenName string
+ expectHeader string
+ }{
+ {
+ name: "ok",
+ whenName: "walle.png",
+ expectHeader: `inline; filename="walle.png"`,
+ },
+ {
+ name: "ok, escape quotes in malicious filename",
+ whenName: `malicious.sh"; \"; dummy=.txt`,
+ expectHeader: `inline; filename="malicious.sh\"; \\\"; dummy=.txt"`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, rec)
+
+ err := c.Inline("_fixture/images/walle.png", tc.whenName)
+ if assert.NoError(t, err) {
+ assert.Equal(t, tc.expectHeader, rec.Header().Get(HeaderContentDisposition))
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, 219885, rec.Body.Len())
+ }
+ })
+ }
+}
+
+func TestContextNoContent(t *testing.T) {
+ e := New()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/?pretty", nil)
+ c := e.NewContext(req, rec)
+
+ c.NoContent(http.StatusOK)
+ assert.Equal(t, http.StatusOK, rec.Code)
+}
+
+func TestContextCookie(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ theme := "theme=light"
+ user := "user=Jon Snow"
+ req.Header.Add(HeaderCookie, theme)
+ req.Header.Add(HeaderCookie, user)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ // Read single
+ cookie, err := c.Cookie("theme")
+ if assert.NoError(t, err) {
+ assert.Equal(t, "theme", cookie.Name)
+ assert.Equal(t, "light", cookie.Value)
+ }
+
+ // Read multiple
+ for _, cookie := range c.Cookies() {
+ switch cookie.Name {
+ case "theme":
+ assert.Equal(t, "light", cookie.Value)
+ case "user":
+ assert.Equal(t, "Jon Snow", cookie.Value)
+ }
+ }
+
+ // Write
+ cookie = &http.Cookie{
+ Name: "SSID",
+ Value: "Ap4PGTEq",
+ Domain: "labstack.com",
+ Path: "/",
+ Expires: time.Now(),
+ Secure: true,
+ HttpOnly: true,
+ }
+ c.SetCookie(cookie)
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "labstack.com")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Secure")
+ assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly")
+}
- c := NewContext(req, nil, New())
- assert.Equal(t, "joe", c.Form("name"))
- assert.Equal(t, "joe@labstack.com", c.Form("email"))
+func TestContext_PathValues(t *testing.T) {
+ var testCases = []struct {
+ name string
+ given PathValues
+ expect PathValues
+ }{
+ {
+ name: "param exists",
+ given: PathValues{
+ {Name: "uid", Value: "101"},
+ {Name: "fid", Value: "501"},
+ },
+ expect: PathValues{
+ {Name: "uid", Value: "101"},
+ {Name: "fid", Value: "501"},
+ },
+ },
+ {
+ name: "params is empty",
+ given: PathValues{},
+ expect: PathValues{},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, nil)
+
+ c.SetPathValues(tc.given)
+
+ assert.EqualValues(t, tc.expect, c.PathValues())
+ })
+ }
+}
+
+func TestContext_PathParam(t *testing.T) {
+ var testCases = []struct {
+ name string
+ given PathValues
+ whenParamName string
+ expect string
+ }{
+ {
+ name: "param exists",
+ given: PathValues{
+ {Name: "uid", Value: "101"},
+ {Name: "fid", Value: "501"},
+ },
+ whenParamName: "uid",
+ expect: "101",
+ },
+ {
+ name: "multiple same param values exists - return first",
+ given: PathValues{
+ {Name: "uid", Value: "101"},
+ {Name: "uid", Value: "202"},
+ {Name: "fid", Value: "501"},
+ },
+ whenParamName: "uid",
+ expect: "101",
+ },
+ {
+ name: "param does not exists",
+ given: PathValues{
+ {Name: "uid", Value: "101"},
+ },
+ whenParamName: "nope",
+ expect: "",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, nil)
+
+ c.SetPathValues(tc.given)
+
+ assert.EqualValues(t, tc.expect, c.Param(tc.whenParamName))
+ })
+ }
+}
+
+func TestContext_PathParamDefault(t *testing.T) {
+ var testCases = []struct {
+ name string
+ given PathValues
+ whenParamName string
+ whenDefaultValue string
+ expect string
+ }{
+ {
+ name: "param exists",
+ given: PathValues{
+ {Name: "uid", Value: "101"},
+ {Name: "fid", Value: "501"},
+ },
+ whenParamName: "uid",
+ whenDefaultValue: "999",
+ expect: "101",
+ },
+ {
+ name: "param exists and is empty",
+ given: PathValues{
+ {Name: "uid", Value: ""},
+ {Name: "fid", Value: "501"},
+ },
+ whenParamName: "uid",
+ whenDefaultValue: "999",
+ expect: "", // <-- this is different from QueryParamOr behaviour
+ },
+ {
+ name: "param does not exists",
+ given: PathValues{
+ {Name: "uid", Value: "101"},
+ },
+ whenParamName: "nope",
+ whenDefaultValue: "999",
+ expect: "999",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ c := e.NewContext(req, nil)
+
+ c.SetPathValues(tc.given)
+
+ assert.EqualValues(t, tc.expect, c.ParamOr(tc.whenParamName, tc.whenDefaultValue))
+ })
+ }
+}
+
+func TestContextGetAndSetPathValuesMutability(t *testing.T) {
+ t.Run("c.PathValues() does not return copy and modifying raw slice mutates value in context", func(t *testing.T) {
+ e := New()
+ e.contextPathParamAllocSize.Store(1)
+
+ req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
+ c := e.NewContext(req, nil)
+
+ params := PathValues{{Name: "foo", Value: "101"}}
+ c.SetPathValues(params)
+
+ // round-trip param values with modification
+ paramVals := c.PathValues()
+ assert.Equal(t, params, c.PathValues())
+
+ // PathValues() does not return copy and modifying raw slice mutates value in context
+ paramVals[0] = PathValue{Name: "xxx", Value: "yyy"}
+ assert.Equal(t, PathValues{PathValue{Name: "xxx", Value: "yyy"}}, c.PathValues())
+ })
+
+ t.Run("calling SetPathValues with bigger size changes capacity in context", func(t *testing.T) {
+ e := New()
+ e.contextPathParamAllocSize.Store(1)
+
+ req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
+ c := e.NewContext(req, nil)
+ // increase path param capacity in context
+ pathValues := PathValues{
+ {Name: "aaa", Value: "bbb"},
+ {Name: "ccc", Value: "ddd"},
+ }
+ c.SetPathValues(pathValues)
+ assert.Equal(t, pathValues, c.PathValues())
+
+ // shouldn't explode during Reset() afterwards!
+ assert.NotPanics(t, func() {
+ c.Reset(nil, nil)
+ })
+ assert.Equal(t, PathValues{}, c.PathValues())
+ assert.Len(t, *c.pathValues, 0)
+ assert.Equal(t, 2, cap(*c.pathValues))
+ })
+
+ t.Run("calling SetPathValues with smaller size slice does not change capacity in context", func(t *testing.T) {
+ e := New()
+
+ req := httptest.NewRequest(http.MethodGet, "/:foo", nil)
+ c := e.NewContext(req, nil)
+ c.pathValues = &PathValues{
+ {Name: "aaa", Value: "bbb"},
+ {Name: "ccc", Value: "ddd"},
+ }
+
+ pathValues := PathValues{
+ {Name: "aaa", Value: "bbb"},
+ }
+ // given pathValues slice is smaller. this should not decrease c.pathValues capacity
+ c.SetPathValues(pathValues)
+ assert.Equal(t, pathValues, c.PathValues())
+
+ // shouldn't explode during Reset() afterwards!
+ assert.NotPanics(t, func() {
+ c.Reset(nil, nil)
+ })
+ assert.Equal(t, PathValues{}, c.PathValues())
+ assert.Len(t, *c.pathValues, 0)
+ assert.Equal(t, 2, cap(*c.pathValues))
+ })
+
+}
+
+// Issue #1655
+func TestContext_SetParamNamesShouldNotModifyPathValuesCapacity(t *testing.T) {
+ e := New()
+ c := e.NewContext(nil, nil)
+
+ assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load())
+ expectedTwoParams := PathValues{
+ {Name: "1", Value: "one"},
+ {Name: "2", Value: "two"},
+ }
+ c.SetPathValues(expectedTwoParams)
+ assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load())
+ assert.Equal(t, expectedTwoParams, c.PathValues())
+
+ expectedThreeParams := PathValues{
+ {Name: "1", Value: "one"},
+ {Name: "2", Value: "two"},
+ {Name: "3", Value: "three"},
+ }
+ c.SetPathValues(expectedThreeParams)
+ assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load())
+ assert.Equal(t, expectedThreeParams, c.PathValues())
+}
+
+func TestContextFormValue(t *testing.T) {
+ f := make(url.Values)
+ f.Set("name", "Jon Snow")
+ f.Set("email", "jon@labstack.com")
+
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
+ req.Header.Add(HeaderContentType, MIMEApplicationForm)
+ c := e.NewContext(req, nil)
+
+ // FormValue
+ assert.Equal(t, "Jon Snow", c.FormValue("name"))
+ assert.Equal(t, "jon@labstack.com", c.FormValue("email"))
+
+ // FormValueOr
+ assert.Equal(t, "Jon Snow", c.FormValueOr("name", "nope"))
+ assert.Equal(t, "default", c.FormValueOr("missing", "default"))
+
+ // FormValues
+ values, err := c.FormValues()
+ if assert.NoError(t, err) {
+ assert.Equal(t, url.Values{
+ "name": []string{"Jon Snow"},
+ "email": []string{"jon@labstack.com"},
+ }, values)
+ }
+
+ // Multipart FormParams error
+ req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
+ req.Header.Add(HeaderContentType, MIMEMultipartForm)
+ c = e.NewContext(req, nil)
+ values, err = c.FormValues()
+ assert.Nil(t, values)
+ assert.Error(t, err)
+}
+
+func TestContext_QueryParams(t *testing.T) {
+ var testCases = []struct {
+ expect url.Values
+ name string
+ givenURL string
+ }{
+ {
+ name: "multiple values in url",
+ givenURL: "/?test=1&test=2&email=jon%40labstack.com",
+ expect: url.Values{
+ "test": []string{"1", "2"},
+ "email": []string{"jon@labstack.com"},
+ },
+ },
+ {
+ name: "single value in url",
+ givenURL: "/?nope=1",
+ expect: url.Values{
+ "nope": []string{"1"},
+ },
+ },
+ {
+ name: "no query params in url",
+ givenURL: "/?",
+ expect: url.Values{},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
+ e := New()
+ c := e.NewContext(req, nil)
+
+ assert.Equal(t, tc.expect, c.QueryParams())
+ })
+ }
+}
+
+func TestContext_QueryParam(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ whenParamName string
+ expect string
+ }{
+ {
+ name: "value exists in url",
+ givenURL: "/?test=1",
+ whenParamName: "test",
+ expect: "1",
+ },
+ {
+ name: "multiple values exists in url",
+ givenURL: "/?test=9&test=8",
+ whenParamName: "test",
+ expect: "9", // <-- first value in returned
+ },
+ {
+ name: "value does not exists in url",
+ givenURL: "/?nope=1",
+ whenParamName: "test",
+ expect: "",
+ },
+ {
+ name: "value is empty in url",
+ givenURL: "/?test=",
+ whenParamName: "test",
+ expect: "",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
+ e := New()
+ c := e.NewContext(req, nil)
+
+ assert.Equal(t, tc.expect, c.QueryParam(tc.whenParamName))
+ })
+ }
+}
+
+func TestContext_QueryParamDefault(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenURL string
+ whenParamName string
+ whenDefaultValue string
+ expect string
+ }{
+ {
+ name: "value exists in url",
+ givenURL: "/?test=1",
+ whenParamName: "test",
+ whenDefaultValue: "999",
+ expect: "1",
+ },
+ {
+ name: "value does not exists in url",
+ givenURL: "/?nope=1",
+ whenParamName: "test",
+ whenDefaultValue: "999",
+ expect: "999",
+ },
+ {
+ name: "value is empty in url",
+ givenURL: "/?test=",
+ whenParamName: "test",
+ whenDefaultValue: "999",
+ expect: "999",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
+ e := New()
+ c := e.NewContext(req, nil)
+
+ assert.Equal(t, tc.expect, c.QueryParamOr(tc.whenParamName, tc.whenDefaultValue))
+ })
+ }
+}
+
+func TestContextFormFile(t *testing.T) {
+ e := New()
+ buf := new(bytes.Buffer)
+ mr := multipart.NewWriter(buf)
+ w, err := mr.CreateFormFile("file", "test")
+ if assert.NoError(t, err) {
+ w.Write([]byte("test"))
+ }
+ mr.Close()
+ req := httptest.NewRequest(http.MethodPost, "/", buf)
+ req.Header.Set(HeaderContentType, mr.FormDataContentType())
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ f, err := c.FormFile("file")
+ if assert.NoError(t, err) {
+ assert.Equal(t, "test", f.Filename)
+ }
+}
+
+func TestContextMultipartForm(t *testing.T) {
+ e := New()
+ buf := new(bytes.Buffer)
+ mw := multipart.NewWriter(buf)
+ mw.WriteField("name", "Jon Snow")
+ fileContent := "This is a test file"
+ w, err := mw.CreateFormFile("file", "test.txt")
+ if assert.NoError(t, err) {
+ w.Write([]byte(fileContent))
+ }
+ mw.Close()
+ req := httptest.NewRequest(http.MethodPost, "/", buf)
+ req.Header.Set(HeaderContentType, mw.FormDataContentType())
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ f, err := c.MultipartForm()
+ if assert.NoError(t, err) {
+ assert.NotNil(t, f)
+
+ files := f.File["file"]
+ if assert.Len(t, files, 1) {
+ file := files[0]
+ assert.Equal(t, "test.txt", file.Filename)
+ assert.Equal(t, int64(len(fileContent)), file.Size)
+ }
+ }
+}
+
+func TestContextRedirect(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
+ assert.Equal(t, http.StatusMovedPermanently, rec.Code)
+ assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
+ assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo"))
+}
+
+func TestContextGet(t *testing.T) {
+ var testCases = []struct {
+ name string
+ given any
+ whenKey string
+ expect any
+ }{
+ {
+ name: "ok, value exist",
+ given: "Jon Snow",
+ whenKey: "key",
+ expect: "Jon Snow",
+ },
+ {
+ name: "ok, value does not exist",
+ given: "Jon Snow",
+ whenKey: "nope",
+ expect: nil,
+ },
+ {
+ name: "ok, value is nil value",
+ given: []byte(nil),
+ whenKey: "key",
+ expect: []byte(nil),
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ var c = new(Context)
+ c.Set("key", tc.given)
+
+ v := c.Get(tc.whenKey)
+ assert.Equal(t, tc.expect, v)
+ })
+ }
+}
+
+func BenchmarkContext_Store(b *testing.B) {
+ e := &Echo{}
+
+ c := &Context{
+ echo: e,
+ }
+
+ for n := 0; n < b.N; n++ {
+ c.Set("name", "Jon Snow")
+ if c.Get("name") != "Jon Snow" {
+ b.Fail()
+ }
+ }
+}
+
+type validator struct{}
+
+func (*validator) Validate(i any) error {
+ return nil
+}
+
+func TestContext_Validate(t *testing.T) {
+ e := New()
+ c := e.NewContext(nil, nil)
+
+ assert.Error(t, c.Validate(struct{}{}))
+
+ e.Validator = &validator{}
+ assert.NoError(t, c.Validate(struct{}{}))
+}
+
+func TestContext_QueryString(t *testing.T) {
+ e := New()
+
+ queryString := "query=string&var=val"
+
+ req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil)
+ c := e.NewContext(req, nil)
+
+ assert.Equal(t, queryString, c.QueryString())
+}
+
+func TestContext_Request(t *testing.T) {
+ var c = new(Context)
+
+ assert.Nil(t, c.Request())
+
+ req := httptest.NewRequest(http.MethodGet, "/path", nil)
+ c.SetRequest(req)
+
+ assert.Equal(t, req, c.Request())
+}
+
+func TestContext_Scheme(t *testing.T) {
+ tests := []struct {
+ c *Context
+ s string
+ }{
+ {
+ &Context{
+ request: &http.Request{
+ TLS: &tls.ConnectionState{},
+ },
+ },
+ "https",
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedProto: []string{"https"}},
+ },
+ },
+ "https",
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedProtocol: []string{"http"}},
+ },
+ },
+ "http",
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedSsl: []string{"on"}},
+ },
+ },
+ "https",
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{HeaderXUrlScheme: []string{"https"}},
+ },
+ },
+ "https",
+ },
+ {
+ &Context{
+ request: &http.Request{},
+ },
+ "http",
+ },
+ }
+
+ for _, tt := range tests {
+ assert.Equal(t, tt.s, tt.c.Scheme())
+ }
+}
+
+func TestContext_IsWebSocket(t *testing.T) {
+ tests := []struct {
+ c *Context
+ ws assert.BoolAssertionFunc
+ }{
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{
+ HeaderUpgrade: []string{"websocket"},
+ HeaderConnection: []string{"upgrade"},
+ },
+ },
+ },
+ assert.True,
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{
+ HeaderUpgrade: []string{"Websocket"},
+ HeaderConnection: []string{"Upgrade"},
+ },
+ },
+ },
+ assert.True,
+ },
+ {
+ &Context{
+ request: &http.Request{},
+ },
+ assert.False,
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{HeaderUpgrade: []string{"other"}},
+ },
+ },
+ assert.False,
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{
+ HeaderUpgrade: []string{"websocket"},
+ HeaderConnection: []string{"close"},
+ },
+ },
+ },
+ assert.False,
+ },
+ }
+
+ for i, tt := range tests {
+ t.Run(fmt.Sprintf("test %d", i+1), func(t *testing.T) {
+ tt.ws(t, tt.c.IsWebSocket())
+ })
+ }
}
-func testBind(t *testing.T, c *Context, ct string) {
- c.request.Header.Set(ContentType, ct)
+func TestContext_Bind(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ c := e.NewContext(req, nil)
u := new(user)
+
+ req.Header.Add(HeaderContentType, MIMEApplicationJSON)
err := c.Bind(u)
- if ct == "" {
- assert.Error(t, UnsupportedMediaType)
- } else if assert.NoError(t, err) {
- assert.Equal(t, "1", u.ID)
- assert.Equal(t, "Joe", u.Name)
+ assert.NoError(t, err)
+ assert.Equal(t, &user{ID: 1, Name: "Jon Snow"}, u)
+}
+
+func TestContext_RealIP(t *testing.T) {
+ tests := []struct {
+ c *Context
+ s string
+ }{
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}},
+ },
+ },
+ "127.0.0.1",
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1,127.0.1.1"}},
+ },
+ },
+ "127.0.0.1",
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1"}},
+ },
+ },
+ "127.0.0.1",
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348], 2001:db8::1, "}},
+ },
+ },
+ "2001:db8:85a3:8d3:1319:8a2e:370:7348",
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348],[2001:db8::1]"}},
+ },
+ },
+ "2001:db8:85a3:8d3:1319:8a2e:370:7348",
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{HeaderXForwardedFor: []string{"2001:db8:85a3:8d3:1319:8a2e:370:7348"}},
+ },
+ },
+ "2001:db8:85a3:8d3:1319:8a2e:370:7348",
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{
+ "X-Real-Ip": []string{"192.168.0.1"},
+ },
+ },
+ },
+ "192.168.0.1",
+ },
+ {
+ &Context{
+ request: &http.Request{
+ Header: http.Header{
+ "X-Real-Ip": []string{"[2001:db8::1]"},
+ },
+ },
+ },
+ "2001:db8::1",
+ },
+
+ {
+ &Context{
+ request: &http.Request{
+ RemoteAddr: "89.89.89.89:1654",
+ },
+ },
+ "89.89.89.89",
+ },
+ }
+
+ for _, tt := range tests {
+ assert.Equal(t, tt.s, tt.c.RealIP())
}
}
+
+func TestContext_File(t *testing.T) {
+ var testCases = []struct {
+ whenFS fs.FS
+ name string
+ whenFile string
+ expectError string
+ expectStartsWith []byte
+ expectStatus int
+ }{
+ {
+ name: "ok, from default file system",
+ whenFile: "_fixture/images/walle.png",
+ whenFS: nil,
+ expectStatus: http.StatusOK,
+ expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ },
+ {
+ name: "ok, from custom file system",
+ whenFile: "walle.png",
+ whenFS: os.DirFS("_fixture/images"),
+ expectStatus: http.StatusOK,
+ expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ },
+ {
+ name: "nok, not existent file",
+ whenFile: "not.png",
+ whenFS: os.DirFS("_fixture/images"),
+ expectStatus: http.StatusOK,
+ expectStartsWith: nil,
+ expectError: "Not Found",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ if tc.whenFS != nil {
+ e.Filesystem = tc.whenFS
+ }
+
+ handler := func(ec *Context) error {
+ return ec.File(tc.whenFile)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := handler(c)
+
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ body := rec.Body.Bytes()
+ if len(body) > len(tc.expectStartsWith) {
+ body = body[:len(tc.expectStartsWith)]
+ }
+ assert.Equal(t, tc.expectStartsWith, body)
+ })
+ }
+}
+
+func TestContext_FileFS(t *testing.T) {
+ var testCases = []struct {
+ whenFS fs.FS
+ name string
+ whenFile string
+ expectError string
+ expectStartsWith []byte
+ expectStatus int
+ }{
+ {
+ name: "ok",
+ whenFile: "walle.png",
+ whenFS: os.DirFS("_fixture/images"),
+ expectStatus: http.StatusOK,
+ expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ },
+ {
+ name: "nok, not existent file",
+ whenFile: "not.png",
+ whenFS: os.DirFS("_fixture/images"),
+ expectStatus: http.StatusOK,
+ expectStartsWith: nil,
+ expectError: "Not Found",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+
+ handler := func(ec *Context) error {
+ return ec.FileFS(tc.whenFile, tc.whenFS)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/match.png", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := handler(c)
+
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ body := rec.Body.Bytes()
+ if len(body) > len(tc.expectStartsWith) {
+ body = body[:len(tc.expectStartsWith)]
+ }
+ assert.Equal(t, tc.expectStartsWith, body)
+ })
+ }
+}
+
+func TestLogger(t *testing.T) {
+ e := New()
+ c := e.NewContext(nil, nil)
+
+ log1 := c.Logger()
+ assert.NotNil(t, log1)
+ assert.Equal(t, e.Logger, log1)
+
+ customLogger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+ c.SetLogger(customLogger)
+ assert.Equal(t, customLogger, c.Logger())
+
+ // Resetting the context returns the initial Echo logger
+ c.Reset(nil, nil)
+ assert.Equal(t, e.Logger, c.Logger())
+}
+
+func TestRouteInfo(t *testing.T) {
+ e := New()
+ c := e.NewContext(nil, nil)
+
+ orgRI := RouteInfo{
+ Name: "root",
+ Method: http.MethodGet,
+ Path: "/*",
+ Parameters: []string{"*"},
+ }
+ c.route = &orgRI
+ ri := c.RouteInfo()
+ assert.Equal(t, orgRI, ri)
+
+ // Test mutability when middlewares start to change things
+
+ // RouteInfo inside context will not be affected when returned instance is changed
+ expect := orgRI.Clone()
+ ri.Path = "changed"
+ ri.Parameters[0] = "changed"
+ assert.Equal(t, expect, c.RouteInfo())
+
+ // RouteInfo inside context will not be affected when returned instance is changed
+ expect = c.RouteInfo()
+ orgRI.Name = "changed"
+ assert.NotEqual(t, expect, c.RouteInfo())
+}
diff --git a/echo.go b/echo.go
index f76553e93..4855e8429 100644
--- a/echo.go
+++ b/echo.go
@@ -1,580 +1,834 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+/*
+Package echo implements high performance, minimalist Go web framework.
+
+Example:
+
+ package main
+
+ import (
+ "log/slog"
+ "net/http"
+
+ "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v5/middleware"
+ )
+
+ // Handler
+ func hello(c *echo.Context) error {
+ return c.String(http.StatusOK, "Hello, World!")
+ }
+
+ func main() {
+ // Echo instance
+ e := echo.New()
+
+ // Middleware
+ e.Use(middleware.RequestLogger())
+ e.Use(middleware.Recover())
+
+ // Routes
+ e.GET("/", hello)
+
+ // Start server
+ if err := e.Start(":8080"); err != nil {
+ slog.Error("failed to start server", "error", err)
+ }
+ }
+
+Learn more at https://echo.labstack.com
+*/
package echo
import (
- "bytes"
+ stdContext "context"
"encoding/json"
"errors"
"fmt"
- "io"
- "log"
+ "io/fs"
+ "log/slog"
"net/http"
- spath "path"
- "reflect"
- "runtime"
+ "net/url"
+ "os"
+ "os/signal"
+ "path/filepath"
"strings"
"sync"
-
- "encoding/xml"
-
- "github.com/bradfitz/http2"
- "github.com/labstack/gommon/color"
- "golang.org/x/net/websocket"
+ "sync/atomic"
+ "syscall"
)
-type (
- Echo struct {
- prefix string
- middleware []MiddlewareFunc
- http2 bool
- maxParam *int
- notFoundHandler HandlerFunc
- defaultHTTPErrorHandler HTTPErrorHandler
- httpErrorHandler HTTPErrorHandler
- binder BindFunc
- renderer Renderer
- pool sync.Pool
- debug bool
- router *Router
- }
+// Echo is the top-level framework instance.
+//
+// Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these
+// fields from handlers/middlewares and changing field values at the same time leads to data-races.
+// Same rule applies to adding new routes after server has been started - Adding a route is not Goroutine safe action.
+type Echo struct {
+ serveHTTPFunc func(http.ResponseWriter, *http.Request)
- Route struct {
- Method string
- Path string
- Handler Handler
- }
+ Binder Binder
+ Filesystem fs.FS
+ Renderer Renderer
+ Validator Validator
+ JSONSerializer JSONSerializer
+ IPExtractor IPExtractor
+ OnAddRoute func(route Route) error
+ HTTPErrorHandler HTTPErrorHandler
+ Logger *slog.Logger
- HTTPError struct {
- code int
- message string
- }
+ contextPool sync.Pool
- Middleware interface{}
- MiddlewareFunc func(HandlerFunc) HandlerFunc
- Handler interface{}
- HandlerFunc func(*Context) error
+ router Router
- // HTTPErrorHandler is a centralized HTTP error handler.
- HTTPErrorHandler func(error, *Context)
+ // premiddleware are middlewares that are called before routing is done
+ premiddleware []MiddlewareFunc
- BindFunc func(*http.Request, interface{}) error
+ // middleware are middlewares that are called after routing is done and before handler is called
+ middleware []MiddlewareFunc
- // Renderer is the interface that wraps the Render method.
- //
- // Render renders the HTML template with given name and specified data.
- // It writes the output to w.
- Renderer interface {
- Render(w io.Writer, name string, data interface{}) error
- }
-)
+ contextPathParamAllocSize atomic.Int32
-const (
- // CONNECT HTTP method
- CONNECT = "CONNECT"
- // DELETE HTTP method
- DELETE = "DELETE"
- // GET HTTP method
- GET = "GET"
- // HEAD HTTP method
- HEAD = "HEAD"
- // OPTIONS HTTP method
- OPTIONS = "OPTIONS"
- // PATCH HTTP method
- PATCH = "PATCH"
- // POST HTTP method
- POST = "POST"
- // PUT HTTP method
- PUT = "PUT"
- // TRACE HTTP method
- TRACE = "TRACE"
-
- //-------------
- // Media types
- //-------------
-
- ApplicationJSON = "application/json; charset=utf-8"
- ApplicationXML = "application/xml; charset=utf-8"
- ApplicationForm = "application/x-www-form-urlencoded"
- ApplicationProtobuf = "application/protobuf"
- ApplicationMsgpack = "application/msgpack"
- TextHTML = "text/html; charset=utf-8"
- TextPlain = "text/plain; charset=utf-8"
- MultipartForm = "multipart/form-data"
-
- //---------
- // Headers
- //---------
-
- Accept = "Accept"
- AcceptEncoding = "Accept-Encoding"
- Authorization = "Authorization"
- ContentDisposition = "Content-Disposition"
- ContentEncoding = "Content-Encoding"
- ContentLength = "Content-Length"
- ContentType = "Content-Type"
- Location = "Location"
- Upgrade = "Upgrade"
- Vary = "Vary"
-
- //-----------
- // Protocols
- //-----------
-
- WebSocket = "websocket"
-
- indexFile = "index.html"
-)
+ // formParseMaxMemory is passed to Context for multipart form parsing (See http.Request.ParseMultipartForm)
+ formParseMaxMemory int64
+}
-var (
- methods = [...]string{
- CONNECT,
- DELETE,
- GET,
- HEAD,
- OPTIONS,
- PATCH,
- POST,
- PUT,
- TRACE,
- }
-
- //--------
- // Errors
- //--------
-
- UnsupportedMediaType = errors.New("echo ⇒ unsupported media type")
- RendererNotRegistered = errors.New("echo ⇒ renderer not registered")
-)
+// JSONSerializer is the interface that encodes and decodes JSON to and from interfaces.
+type JSONSerializer interface {
+ Serialize(c *Context, target any, indent string) error
+ Deserialize(c *Context, target any) error
+}
-// New creates an instance of Echo.
-func New() (e *Echo) {
- e = &Echo{maxParam: new(int)}
- e.pool.New = func() interface{} {
- return NewContext(nil, new(Response), e)
- }
- e.router = NewRouter(e)
+// HTTPErrorHandler is a centralized HTTP error handler.
+type HTTPErrorHandler func(c *Context, err error)
- //----------
- // Defaults
- //----------
+// HandlerFunc defines a function to serve HTTP requests.
+type HandlerFunc func(c *Context) error
- if runtime.GOOS == "windows" {
- e.DisableColoredLog()
- }
- e.HTTP2(false)
- e.notFoundHandler = func(c *Context) error {
- return NewHTTPError(http.StatusNotFound)
- }
- e.defaultHTTPErrorHandler = func(err error, c *Context) {
- code := http.StatusInternalServerError
- msg := http.StatusText(code)
- if he, ok := err.(*HTTPError); ok {
- code = he.code
- msg = he.message
- }
- if e.debug {
- msg = err.Error()
- }
- http.Error(c.response, msg, code)
- }
- e.SetHTTPErrorHandler(e.defaultHTTPErrorHandler)
- e.SetBinder(func(r *http.Request, v interface{}) error {
- ct := r.Header.Get(ContentType)
- err := UnsupportedMediaType
- if strings.HasPrefix(ApplicationJSON, ct) {
- err = json.NewDecoder(r.Body).Decode(v)
- } else if strings.HasPrefix(ApplicationXML, ct) {
- err = xml.NewDecoder(r.Body).Decode(v)
- }
- return err
- })
- return
-}
+// MiddlewareFunc defines a function to process middleware.
+type MiddlewareFunc func(next HandlerFunc) HandlerFunc
-// Router returns router.
-func (e *Echo) Router() *Router {
- return e.router
+// MiddlewareConfigurator defines interface for creating middleware handlers with possibility to return configuration errors instead of panicking.
+type MiddlewareConfigurator interface {
+ ToMiddleware() (MiddlewareFunc, error)
}
-// DisableColoredLog disables colored log.
-func (e *Echo) DisableColoredLog() {
- color.Disable()
+// Validator is the interface that wraps the Validate function.
+type Validator interface {
+ Validate(i any) error
}
-// HTTP2 enables HTTP2 support.
-func (e *Echo) HTTP2(on bool) {
- e.http2 = on
-}
+// MIME types
+const (
+ // MIMEApplicationJSON JavaScript Object Notation (JSON) https://www.rfc-editor.org/rfc/rfc8259
+ MIMEApplicationJSON = "application/json"
+ // Deprecated: Please use MIMEApplicationJSON instead. JSON should be encoded using UTF-8 by default.
+ // No "charset" parameter is defined for this registration.
+ // Adding one really has no effect on compliant recipients.
+ // See RFC 8259, section 8.1. https://datatracker.ietf.org/doc/html/rfc8259#section-8.1n"
+ MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8
+ MIMEApplicationJavaScript = "application/javascript"
+ MIMEApplicationJavaScriptCharsetUTF8 = MIMEApplicationJavaScript + "; " + charsetUTF8
+ MIMEApplicationXML = "application/xml"
+ MIMEApplicationXMLCharsetUTF8 = MIMEApplicationXML + "; " + charsetUTF8
+ MIMETextXML = "text/xml"
+ MIMETextXMLCharsetUTF8 = MIMETextXML + "; " + charsetUTF8
+ MIMEApplicationForm = "application/x-www-form-urlencoded"
+ MIMEApplicationProtobuf = "application/protobuf"
+ MIMEApplicationMsgpack = "application/msgpack"
+ MIMETextHTML = "text/html"
+ MIMETextHTMLCharsetUTF8 = MIMETextHTML + "; " + charsetUTF8
+ MIMETextPlain = "text/plain"
+ MIMETextPlainCharsetUTF8 = MIMETextPlain + "; " + charsetUTF8
+ MIMEMultipartForm = "multipart/form-data"
+ MIMEOctetStream = "application/octet-stream"
+)
-// DefaultHTTPErrorHandler invokes the default HTTP error handler.
-func (e *Echo) DefaultHTTPErrorHandler(err error, c *Context) {
- e.defaultHTTPErrorHandler(err, c)
-}
+const (
+ charsetUTF8 = "charset=UTF-8"
+ // PROPFIND Method can be used on collection and property resources.
+ PROPFIND = "PROPFIND"
+ // REPORT Method can be used to get information about a resource, see rfc 3253
+ REPORT = "REPORT"
+ // RouteNotFound is special method type for routes handling "route not found" (404) cases
+ RouteNotFound = "echo_route_not_found"
+ // RouteAny is special method type that matches any HTTP method in request. Any has lower
+ // priority that other methods that have been registered with Router to that path.
+ RouteAny = "echo_route_any"
+)
-// SetHTTPErrorHandler registers a custom Echo.HTTPErrorHandler.
-func (e *Echo) SetHTTPErrorHandler(h HTTPErrorHandler) {
- e.httpErrorHandler = h
-}
+// Headers
+const (
+ HeaderAccept = "Accept"
+ HeaderAcceptEncoding = "Accept-Encoding"
+ // HeaderAllow is the name of the "Allow" header field used to list the set of methods
+ // advertised as supported by the target resource. Returning an Allow header is mandatory
+ // for status 405 (method not found) and useful for the OPTIONS method in responses.
+ // See RFC 7231: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1
+ HeaderAllow = "Allow"
+ HeaderAuthorization = "Authorization"
+ HeaderContentDisposition = "Content-Disposition"
+ HeaderContentEncoding = "Content-Encoding"
+ HeaderContentLength = "Content-Length"
+ HeaderContentType = "Content-Type"
+ HeaderCookie = "Cookie"
+ HeaderSetCookie = "Set-Cookie"
+ HeaderIfModifiedSince = "If-Modified-Since"
+ HeaderLastModified = "Last-Modified"
+ HeaderLocation = "Location"
+ HeaderRetryAfter = "Retry-After"
+ HeaderUpgrade = "Upgrade"
+ HeaderVary = "Vary"
+ HeaderWWWAuthenticate = "WWW-Authenticate"
+ HeaderXForwardedFor = "X-Forwarded-For"
+ HeaderXForwardedProto = "X-Forwarded-Proto"
+ HeaderXForwardedProtocol = "X-Forwarded-Protocol"
+ HeaderXForwardedSsl = "X-Forwarded-Ssl"
+ HeaderXUrlScheme = "X-Url-Scheme"
+ HeaderXHTTPMethodOverride = "X-HTTP-Method-Override"
+ HeaderXRealIP = "X-Real-Ip"
+ HeaderXRequestID = "X-Request-Id"
+ HeaderXCorrelationID = "X-Correlation-Id"
+ HeaderXRequestedWith = "X-Requested-With"
+ HeaderServer = "Server"
+
+ // HeaderOrigin request header indicates the origin (scheme, hostname, and port) that caused the request.
+ // See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin
+ HeaderOrigin = "Origin"
+ HeaderCacheControl = "Cache-Control"
+ HeaderConnection = "Connection"
+
+ // Access control
+ HeaderAccessControlRequestMethod = "Access-Control-Request-Method"
+ HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers"
+ HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin"
+ HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods"
+ HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers"
+ HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials"
+ HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers"
+ HeaderAccessControlMaxAge = "Access-Control-Max-Age"
+
+ // Security
+ HeaderStrictTransportSecurity = "Strict-Transport-Security"
+ HeaderXContentTypeOptions = "X-Content-Type-Options"
+ HeaderXXSSProtection = "X-XSS-Protection"
+ HeaderXFrameOptions = "X-Frame-Options"
+ HeaderContentSecurityPolicy = "Content-Security-Policy"
+ HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only"
+ HeaderXCSRFToken = "X-CSRF-Token" // #nosec G101
+ HeaderReferrerPolicy = "Referrer-Policy"
+
+ // HeaderSecFetchSite fetch metadata request header indicates the relationship between a request initiator's
+ // origin and the origin of the requested resource.
+ // See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Sec-Fetch-Site
+ HeaderSecFetchSite = "Sec-Fetch-Site"
+)
-// SetBinder registers a custom binder. It's invoked by Context.Bind().
-func (e *Echo) SetBinder(b BindFunc) {
- e.binder = b
+// Config is configuration for NewWithConfig function
+type Config struct {
+ // Logger is the slog logger instance used for application-wide structured logging.
+ // If not set, a default TextHandler writing to stdout is created.
+ Logger *slog.Logger
+
+ // HTTPErrorHandler is the centralized error handler that processes errors returned
+ // by handlers and middleware, converting them to appropriate HTTP responses.
+ // If not set, DefaultHTTPErrorHandler(false) is used.
+ HTTPErrorHandler HTTPErrorHandler
+
+ // Router is the HTTP request router responsible for matching URLs to handlers
+ // using a radix tree-based algorithm.
+ // If not set, NewRouter(RouterConfig{}) is used.
+ Router Router
+
+ // OnAddRoute is an optional callback hook executed when routes are registered.
+ // Useful for route validation, logging, or custom route processing.
+ // If not set, no callback is executed.
+ OnAddRoute func(route Route) error
+
+ // Filesystem is the fs.FS implementation used for serving static files.
+ // Supports os.DirFS, embed.FS, and custom implementations.
+ // If not set, defaults to current working directory.
+ Filesystem fs.FS
+
+ // Binder handles automatic data binding from HTTP requests to Go structs.
+ // Supports JSON, XML, form data, query parameters, and path parameters.
+ // If not set, DefaultBinder is used.
+ Binder Binder
+
+ // Validator provides optional struct validation after data binding.
+ // Commonly used with third-party validation libraries.
+ // If not set, Context.Validate() returns ErrValidatorNotRegistered.
+ Validator Validator
+
+ // Renderer provides template rendering for generating HTML responses.
+ // Requires integration with a template engine like html/template.
+ // If not set, Context.Render() returns ErrRendererNotRegistered.
+ Renderer Renderer
+
+ // JSONSerializer handles JSON encoding and decoding for HTTP requests/responses.
+ // Can be replaced with faster alternatives like jsoniter or sonic.
+ // If not set, DefaultJSONSerializer using encoding/json is used.
+ JSONSerializer JSONSerializer
+
+ // IPExtractor defines the strategy for extracting the real client IP address
+ // from requests, particularly important when behind proxies or load balancers.
+ // Used for rate limiting, access control, and logging.
+ // If not set, falls back to checking X-Forwarded-For and X-Real-IP headers.
+ IPExtractor IPExtractor
+
+ // FormParseMaxMemory is default value for memory limit that is used
+ // when parsing multipart forms (See (*http.Request).ParseMultipartForm)
+ FormParseMaxMemory int64
+}
+
+// NewWithConfig creates an instance of Echo with given configuration.
+func NewWithConfig(config Config) *Echo {
+ e := New()
+ if config.Logger != nil {
+ e.Logger = config.Logger
+ }
+ if config.HTTPErrorHandler != nil {
+ e.HTTPErrorHandler = config.HTTPErrorHandler
+ }
+ if config.Router != nil {
+ e.router = config.Router
+ }
+ if config.OnAddRoute != nil {
+ e.OnAddRoute = config.OnAddRoute
+ }
+ if config.Filesystem != nil {
+ e.Filesystem = config.Filesystem
+ }
+ if config.Binder != nil {
+ e.Binder = config.Binder
+ }
+ if config.Validator != nil {
+ e.Validator = config.Validator
+ }
+ if config.Renderer != nil {
+ e.Renderer = config.Renderer
+ }
+ if config.JSONSerializer != nil {
+ e.JSONSerializer = config.JSONSerializer
+ }
+ if config.IPExtractor != nil {
+ e.IPExtractor = config.IPExtractor
+ }
+ if config.FormParseMaxMemory > 0 {
+ e.formParseMaxMemory = config.FormParseMaxMemory
+ }
+ return e
}
-// SetRenderer registers an HTML template renderer. It's invoked by Context.Render().
-func (e *Echo) SetRenderer(r Renderer) {
- e.renderer = r
-}
+// New creates an instance of Echo.
+func New() *Echo {
+ logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
+ e := &Echo{
+ Logger: logger,
+ Filesystem: newDefaultFS(),
+ Binder: &DefaultBinder{},
+ JSONSerializer: &DefaultJSONSerializer{},
+ formParseMaxMemory: defaultMemory,
+ }
-// SetDebug sets debug mode.
-func (e *Echo) SetDebug(on bool) {
- e.debug = on
+ e.serveHTTPFunc = e.serveHTTP
+ e.router = NewRouter(RouterConfig{})
+ e.HTTPErrorHandler = DefaultHTTPErrorHandler(false)
+ e.contextPool.New = func() any {
+ return newContext(nil, nil, e)
+ }
+ return e
}
-// Debug returns debug mode.
-func (e *Echo) Debug() bool {
- return e.debug
+// NewContext returns a new Context instance.
+//
+// Note: both request and response can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway
+// these arguments are useful when creating context for tests and cases like that.
+func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) *Context {
+ return newContext(r, w, e)
}
-// Use adds handler to the middleware chain.
-func (e *Echo) Use(m ...Middleware) {
- for _, h := range m {
- e.middleware = append(e.middleware, wrapMiddleware(h))
- }
+// Router returns the default router.
+func (e *Echo) Router() Router {
+ return e.router
}
-// Connect adds a CONNECT route > handler to the router.
-func (e *Echo) Connect(path string, h Handler) {
- e.add(CONNECT, path, h)
-}
+// DefaultHTTPErrorHandler creates new default HTTP error handler implementation. It sends a JSON response
+// with status code. `exposeError` parameter decides if returned message will contain also error message or not
+//
+// Note: DefaultHTTPErrorHandler does not log errors. Use middleware for it if errors need to be logged (separately)
+// Note: In case errors happens in middleware call-chain that is returning from handler (which did not return an error).
+// When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from
+// handler. Then the error that global error handler received will be ignored because we have already "committed" the
+// response and status code header has been sent to the client.
+func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler {
+ return func(c *Context, err error) {
+ if r, _ := UnwrapResponse(c.response); r != nil && r.Committed {
+ return
+ }
-// Delete adds a DELETE route > handler to the router.
-func (e *Echo) Delete(path string, h Handler) {
- e.add(DELETE, path, h)
-}
+ code := http.StatusInternalServerError
+ var sc HTTPStatusCoder
+ if errors.As(err, &sc) {
+ if tmp := sc.StatusCode(); tmp != 0 {
+ code = tmp
+ }
+ }
+
+ var result any
+ switch m := sc.(type) {
+ case json.Marshaler: // this type knows how to format itself to JSON
+ result = m
+ case *HTTPError:
+ sText := m.Message
+ if sText == "" {
+ sText = http.StatusText(code)
+ }
+ msg := map[string]any{"message": sText}
+ if exposeError {
+ if wrappedErr := m.Unwrap(); wrappedErr != nil {
+ msg["error"] = wrappedErr.Error()
+ }
+ }
+ result = msg
+ default:
+ msg := map[string]any{"message": http.StatusText(code)}
+ if exposeError {
+ msg["error"] = err.Error()
+ }
+ result = msg
+ }
-// Get adds a GET route > handler to the router.
-func (e *Echo) Get(path string, h Handler) {
- e.add(GET, path, h)
+ var cErr error
+ if c.Request().Method == http.MethodHead { // Issue #608
+ cErr = c.NoContent(code)
+ } else {
+ cErr = c.JSON(code, result)
+ }
+ if cErr != nil {
+ c.Logger().Error("echo default error handler failed to send error to client", "error", cErr) // truly rare case. ala client already disconnected
+ }
+ }
}
-// Head adds a HEAD route > handler to the router.
-func (e *Echo) Head(path string, h Handler) {
- e.add(HEAD, path, h)
+// Pre adds middleware to the chain which is run before router tries to find matching route.
+// Meaning middleware is executed even for 404 (not found) cases.
+func (e *Echo) Pre(middleware ...MiddlewareFunc) {
+ e.premiddleware = append(e.premiddleware, middleware...)
}
-// Options adds an OPTIONS route > handler to the router.
-func (e *Echo) Options(path string, h Handler) {
- e.add(OPTIONS, path, h)
+// Use adds middleware to the chain which is run after router has found matching route and before route/request handler method is executed.
+func (e *Echo) Use(middleware ...MiddlewareFunc) {
+ e.middleware = append(e.middleware, middleware...)
}
-// Patch adds a PATCH route > handler to the router.
-func (e *Echo) Patch(path string, h Handler) {
- e.add(PATCH, path, h)
+// CONNECT registers a new CONNECT route for a path with matching handler in the
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return e.Add(http.MethodConnect, path, h, m...)
}
-// Post adds a POST route > handler to the router.
-func (e *Echo) Post(path string, h Handler) {
- e.add(POST, path, h)
+// DELETE registers a new DELETE route for a path with matching handler in the router
+// with optional route-level middleware. Panics on error.
+func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return e.Add(http.MethodDelete, path, h, m...)
}
-// Put adds a PUT route > handler to the router.
-func (e *Echo) Put(path string, h Handler) {
- e.add(PUT, path, h)
+// GET registers a new GET route for a path with matching handler in the router
+// with optional route-level middleware. Panics on error.
+func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return e.Add(http.MethodGet, path, h, m...)
}
-// Trace adds a TRACE route > handler to the router.
-func (e *Echo) Trace(path string, h Handler) {
- e.add(TRACE, path, h)
+// HEAD registers a new HEAD route for a path with matching handler in the
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return e.Add(http.MethodHead, path, h, m...)
}
-// WebSocket adds a WebSocket route > handler to the router.
-func (e *Echo) WebSocket(path string, h HandlerFunc) {
- e.Get(path, func(c *Context) (err error) {
- wss := websocket.Server{
- Handler: func(ws *websocket.Conn) {
- c.socket = ws
- c.response.status = http.StatusSwitchingProtocols
- err = h(c)
- },
- }
- wss.ServeHTTP(c.response, c.request)
- return err
- })
+// OPTIONS registers a new OPTIONS route for a path with matching handler in the
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return e.Add(http.MethodOptions, path, h, m...)
}
-func (e *Echo) add(method, path string, h Handler) {
- path = e.prefix + path
- e.router.Add(method, path, wrapHandler(h), e)
- r := Route{
- Method: method,
- Path: path,
- Handler: runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name(),
- }
- e.router.routes = append(e.router.routes, r)
+// PATCH registers a new PATCH route for a path with matching handler in the
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return e.Add(http.MethodPatch, path, h, m...)
}
-// Index serves index file.
-func (e *Echo) Index(file string) {
- e.ServeFile("/", file)
+// POST registers a new POST route for a path with matching handler in the
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return e.Add(http.MethodPost, path, h, m...)
}
-// Favicon serves the default favicon - GET /favicon.ico.
-func (e *Echo) Favicon(file string) {
- e.ServeFile("/favicon.ico", file)
+// PUT registers a new PUT route for a path with matching handler in the
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return e.Add(http.MethodPut, path, h, m...)
}
-// Static serves static files from a directory. It's an alias for `Echo.ServeDir`
-func (e *Echo) Static(path, dir string) {
- e.ServeDir(path, dir)
+// TRACE registers a new TRACE route for a path with matching handler in the
+// router with optional route-level middleware. Panics on error.
+func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return e.Add(http.MethodTrace, path, h, m...)
}
-// ServeDir serves files from a directory.
-func (e *Echo) ServeDir(path, dir string) {
- e.Get(path+"*", func(c *Context) error {
- return serveFile(dir, c.P(0), c) // Param `_name`
- })
+// RouteNotFound registers a special-case route which is executed when no other route is found (i.e. HTTP 404 cases)
+// for current request URL.
+// Path supports static and named/any parameters just like other http method is defined. Generally path is ended with
+// wildcard/match-any character (`/*`, `/download/*` etc).
+//
+// Example: `e.RouteNotFound("/*", func(c *echo.Context) error { return c.NoContent(http.StatusNotFound) })`
+func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return e.Add(RouteNotFound, path, h, m...)
}
-// ServeFile serves a file.
-func (e *Echo) ServeFile(path, file string) {
- e.Get(path, func(c *Context) error {
- dir, file := spath.Split(file)
- return serveFile(dir, file, c)
- })
+// Any registers a new route for all HTTP methods (supported by Echo) and path with matching handler
+// in the router with optional route-level middleware.
+//
+// Note: this method only adds specific set of supported HTTP methods as handler and is not true
+// "catch-any-arbitrary-method" way of matching requests.
+func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
+ return e.Add(RouteAny, path, handler, middleware...)
}
-func serveFile(dir, file string, c *Context) error {
- fs := http.Dir(dir)
- f, err := fs.Open(file)
- if err != nil {
- return NewHTTPError(http.StatusNotFound)
+// Match registers a new route for multiple HTTP methods and path with matching
+// handler in the router with optional route-level middleware. Panics on error.
+func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
+ errs := make([]error, 0)
+ ris := make(Routes, 0)
+ for _, m := range methods {
+ ri, err := e.AddRoute(Route{
+ Method: m,
+ Path: path,
+ Handler: handler,
+ Middlewares: middleware,
+ })
+ if err != nil {
+ errs = append(errs, err)
+ continue
+ }
+ ris = append(ris, ri)
}
+ if len(errs) > 0 {
+ panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
+ }
+ return ris
+}
+
+// Static registers a new route with path prefix to serve static files from the provided root directory.
+func (e *Echo) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo {
+ subFs := MustSubFS(e.Filesystem, fsRoot)
+ return e.Add(
+ http.MethodGet,
+ pathPrefix+"*",
+ StaticDirectoryHandler(subFs, false),
+ middleware...,
+ )
+}
+
+// StaticFS registers a new route with path prefix to serve static files from the provided file system.
+//
+// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
+// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
+// including `assets/images` as their prefix.
+func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo {
+ return e.Add(
+ http.MethodGet,
+ pathPrefix+"*",
+ StaticDirectoryHandler(filesystem, false),
+ middleware...,
+ )
+}
+
+// StaticDirectoryHandler creates handler function to serve files from provided file system
+// When disablePathUnescaping is set then file name from path is not unescaped and is served as is.
+func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc {
+ return func(c *Context) error {
+ p := c.Param("*")
+ if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice
+ tmpPath, err := url.PathUnescape(p)
+ if err != nil {
+ return fmt.Errorf("failed to unescape path variable: %w", err)
+ }
+ p = tmpPath
+ }
- fi, _ := f.Stat()
- if fi.IsDir() {
- file = spath.Join(file, indexFile)
- f, err = fs.Open(file)
+ // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid
+ name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/")))
+ fi, err := fs.Stat(fileSystem, name)
if err != nil {
- return NewHTTPError(http.StatusForbidden)
+ return ErrNotFound
}
- fi, _ = f.Stat()
- }
-
- http.ServeContent(c.response, c.request, fi.Name(), fi.ModTime(), f)
- return nil
-}
-
-// Group creates a new sub router with prefix. It inherits all properties from
-// the parent. Passing middleware overrides parent middleware.
-func (e *Echo) Group(prefix string, m ...Middleware) *Group {
- g := &Group{*e}
- g.echo.prefix += prefix
- if len(m) > 0 {
- g.echo.middleware = nil
- g.Use(m...)
- }
- return g
-}
-
-// URI generates a URI from handler.
-func (e *Echo) URI(h Handler, params ...interface{}) string {
- uri := new(bytes.Buffer)
- pl := len(params)
- n := 0
- hn := runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
- for _, r := range e.router.routes {
- if r.Handler == hn {
- for i, l := 0, len(r.Path); i < l; i++ {
- if r.Path[i] == ':' && n < pl {
- for ; i < l && r.Path[i] != '/'; i++ {
- }
- uri.WriteString(fmt.Sprintf("%v", params[n]))
- n++
- }
- if i < l {
- uri.WriteByte(r.Path[i])
- }
- }
- break
+
+ // If the request is for a directory and does not end with "/"
+ p = c.Request().URL.Path // path must not be empty.
+ if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' {
+ // Redirect to ends with "/"
+ return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/"))
}
+ return fsFile(c, name, fileSystem)
}
- return uri.String()
}
-// URL is an alias for `URI` function.
-func (e *Echo) URL(h Handler, params ...interface{}) string {
- return e.URI(h, params...)
+// FileFS registers a new route with path to serve file from the provided file system.
+func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo {
+ return e.GET(path, StaticFileHandler(file, filesystem), m...)
}
-// Routes returns the registered routes.
-func (e *Echo) Routes() []Route {
- return e.router.routes
+// StaticFileHandler creates handler function to serve file from provided file system
+func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc {
+ return func(c *Context) error {
+ return fsFile(c, file, filesystem)
+ }
}
-// ServeHTTP implements `http.Handler` interface, which serves HTTP requests.
-func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- c := e.pool.Get().(*Context)
- h, echo := e.router.Find(r.Method, r.URL.Path, c)
- if echo != nil {
- e = echo
+// File registers a new route with path to serve a static file with optional route-level middleware. Panics on error.
+func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo {
+ handler := func(c *Context) error {
+ return c.File(file)
}
- c.reset(r, w, e)
- if h == nil {
- h = e.notFoundHandler
+ return e.Add(http.MethodGet, path, handler, middleware...)
+}
+
+// AddRoute registers a new Route with default host Router
+func (e *Echo) AddRoute(route Route) (RouteInfo, error) {
+ return e.add(route)
+}
+
+func (e *Echo) add(route Route) (RouteInfo, error) {
+ if e.OnAddRoute != nil {
+ if err := e.OnAddRoute(route); err != nil {
+ return RouteInfo{}, err
+ }
}
- // Chain middleware with handler in the end
- for i := len(e.middleware) - 1; i >= 0; i-- {
- h = e.middleware[i](h)
+ ri, err := e.router.Add(route)
+ if err != nil {
+ return RouteInfo{}, err
}
- // Execute chain
- if err := h(c); err != nil {
- e.httpErrorHandler(err, c)
+ paramsCount := int32(len(ri.Parameters)) // #nosec G115
+ if paramsCount > e.contextPathParamAllocSize.Load() {
+ e.contextPathParamAllocSize.Store(paramsCount)
}
+ return ri, nil
+}
+
+// Add registers a new route for an HTTP method and path with matching handler
+// in the router with optional route-level middleware.
+func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
+ ri, err := e.add(
+ Route{
+ Method: method,
+ Path: path,
+ Handler: handler,
+ Middlewares: middleware,
+ Name: "",
+ },
+ )
+ if err != nil {
+ panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
+ }
+ return ri
+}
- e.pool.Put(c)
+// Group creates a new router group with prefix and optional group-level middleware.
+func (e *Echo) Group(prefix string, m ...MiddlewareFunc) (g *Group) {
+ g = &Group{prefix: prefix, echo: e}
+ g.Use(m...)
+ return
}
-// Server returns the internal *http.Server.
-func (e *Echo) Server(addr string) *http.Server {
- s := &http.Server{Addr: addr}
- s.Handler = e
- if e.http2 {
- http2.ConfigureServer(s, nil)
- }
- return s
+// PreMiddlewares returns registered pre middlewares. These are middleware to the chain
+// which are run before router tries to find matching route.
+// Use this method to build your own ServeHTTP method.
+//
+// NOTE: returned slice is not a copy. Do not mutate.
+func (e *Echo) PreMiddlewares() []MiddlewareFunc {
+ return e.premiddleware
}
-// Run runs a server.
-func (e *Echo) Run(addr string) {
- s := e.Server(addr)
- e.run(s)
+// Middlewares returns registered route level middlewares. Does not contain any group level
+// middlewares. Use this method to build your own ServeHTTP method.
+//
+// NOTE: returned slice is not a copy. Do not mutate.
+func (e *Echo) Middlewares() []MiddlewareFunc {
+ return e.middleware
}
-// RunTLS runs a server with TLS configuration.
-func (e *Echo) RunTLS(addr, certFile, keyFile string) {
- s := e.Server(addr)
- e.run(s, certFile, keyFile)
+// AcquireContext returns an empty `Context` instance from the pool.
+// You must return the context by calling `ReleaseContext()`.
+func (e *Echo) AcquireContext() *Context {
+ return e.contextPool.Get().(*Context)
}
-// RunServer runs a custom server.
-func (e *Echo) RunServer(s *http.Server) {
- e.run(s)
+// ReleaseContext returns the `Context` instance back to the pool.
+// You must call it after `AcquireContext()`.
+func (e *Echo) ReleaseContext(c *Context) {
+ e.contextPool.Put(c)
}
-// RunTLSServer runs a custom server with TLS configuration.
-func (e *Echo) RunTLSServer(s *http.Server, certFile, keyFile string) {
- e.run(s, certFile, keyFile)
+// ServeHTTP implements `http.Handler` interface, which serves HTTP requests.
+func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ e.serveHTTPFunc(w, r)
}
-func (e *Echo) run(s *http.Server, files ...string) {
- if len(files) == 0 {
- log.Fatal(s.ListenAndServe())
- } else if len(files) == 2 {
- log.Fatal(s.ListenAndServeTLS(files[0], files[1]))
+// serveHTTP implements `http.Handler` interface, which serves HTTP requests.
+func (e *Echo) serveHTTP(w http.ResponseWriter, r *http.Request) {
+ c := e.contextPool.Get().(*Context)
+ defer e.contextPool.Put(c)
+
+ c.Reset(r, w)
+ var h HandlerFunc
+
+ if e.premiddleware == nil {
+ h = applyMiddleware(e.router.Route(c), e.middleware...)
} else {
- log.Fatal("echo => invalid TLS configuration")
+ h = func(cc *Context) error {
+ h1 := applyMiddleware(e.router.Route(cc), e.middleware...)
+ return h1(cc)
+ }
+ h = applyMiddleware(h, e.premiddleware...)
+ }
+
+ // Execute chain
+ if err := h(c); err != nil {
+ e.HTTPErrorHandler(c, err)
}
}
-func NewHTTPError(code int, msg ...string) *HTTPError {
- he := &HTTPError{code: code, message: http.StatusText(code)}
- if len(msg) > 0 {
- m := msg[0]
- he.message = m
+// Start stars HTTP server on given address with Echo as a handler serving requests. The server can be shutdown by
+// sending os.Interrupt signal with `ctrl+c`. Method returns only errors that are not http.ErrServerClosed.
+//
+// Note: this method is created for use in examples/demos and is deliberately simple without providing configuration
+// options.
+//
+// In need of customization use:
+//
+// ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+// defer cancel()
+// sc := echo.StartConfig{Address: ":8080"}
+// if err := sc.Start(ctx, e); err != nil && !errors.Is(err, http.ErrServerClosed) {
+// slog.Error(err.Error())
+// }
+//
+// // or standard library `http.Server`
+//
+// s := http.Server{Addr: ":8080", Handler: e}
+// if err := s.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+// slog.Error(err.Error())
+// }
+func (e *Echo) Start(address string) error {
+ sc := StartConfig{Address: address}
+ ctx, cancel := signal.NotifyContext(stdContext.Background(), os.Interrupt, syscall.SIGTERM) // start shutdown process on ctrl+c
+ defer cancel()
+ return sc.Start(ctx, e)
+}
+
+// WrapHandler wraps `http.Handler` into `echo.HandlerFunc`.
+func WrapHandler(h http.Handler) HandlerFunc {
+ return func(c *Context) error {
+ req := c.Request()
+ req.Pattern = c.Path()
+ for _, p := range c.PathValues() {
+ req.SetPathValue(p.Name, p.Value)
+ }
+
+ h.ServeHTTP(c.Response(), req)
+ return nil
}
- return he
}
-// SetCode sets code.
-func (e *HTTPError) SetCode(code int) {
- e.code = code
+// WrapMiddleware wraps `func(http.Handler) http.Handler` into `echo.MiddlewareFunc`
+func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc {
+ return func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) (err error) {
+ req := c.Request()
+ req.Pattern = c.Path()
+ for _, p := range c.PathValues() {
+ req.SetPathValue(p.Name, p.Value)
+ }
+
+ m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ c.SetRequest(r)
+ c.SetResponse(NewResponse(w, c.echo.Logger))
+ err = next(c)
+ })).ServeHTTP(c.Response(), req)
+ return
+ }
+ }
}
-// Code returns code.
-func (e *HTTPError) Code() int {
- return e.code
+func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc {
+ for i := len(middleware) - 1; i >= 0; i-- {
+ h = middleware[i](h)
+ }
+ return h
}
-// Error returns message.
-func (e *HTTPError) Error() string {
- return e.message
+// defaultFS emulates os.Open behaviour with filesystem opened by `os.DirFs`. Difference between `os.Open` and `fs.Open`
+// is that FS does not allow to open path that start with `..` or `/` etc. For example previously you could have `../images`
+// in your application but `fs := os.DirFS("./")` would not allow you to use `fs.Open("../images")` and this would break
+// all old applications that rely on being able to traverse up from current executable run path.
+// NB: private because you really should use fs.FS implementation instances
+type defaultFS struct {
+ fs fs.FS
+ prefix string
}
-// wrapMiddleware wraps middleware.
-func wrapMiddleware(m Middleware) MiddlewareFunc {
- switch m := m.(type) {
- case MiddlewareFunc:
- return m
- case func(HandlerFunc) HandlerFunc:
- return m
- case HandlerFunc:
- return wrapHandlerFuncMW(m)
- case func(*Context) error:
- return wrapHandlerFuncMW(m)
- case func(http.Handler) http.Handler:
- return func(h HandlerFunc) HandlerFunc {
- return func(c *Context) (err error) {
- m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- c.response.writer = w
- c.request = r
- err = h(c)
- })).ServeHTTP(c.response.writer, c.request)
- return
- }
- }
- case http.Handler:
- return wrapHTTPHandlerFuncMW(m.ServeHTTP)
- case func(http.ResponseWriter, *http.Request):
- return wrapHTTPHandlerFuncMW(m)
- default:
- panic("echo => unknown middleware")
+func newDefaultFS() *defaultFS {
+ dir, _ := os.Getwd()
+ return &defaultFS{
+ prefix: dir,
+ fs: os.DirFS(dir),
}
}
-// wrapHandlerFuncMW wraps HandlerFunc middleware.
-func wrapHandlerFuncMW(m HandlerFunc) MiddlewareFunc {
- return func(h HandlerFunc) HandlerFunc {
- return func(c *Context) error {
- if err := m(c); err != nil {
- return err
- }
- return h(c)
+func (fs defaultFS) Open(name string) (fs.File, error) {
+ return fs.fs.Open(name)
+}
+
+func subFS(currentFs fs.FS, root string) (fs.FS, error) {
+ root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows
+ if dFS, ok := currentFs.(*defaultFS); ok {
+ // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS.
+ // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we
+ // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs
+ if !filepath.IsAbs(root) {
+ root = filepath.Join(dFS.prefix, root)
}
+ return &defaultFS{
+ prefix: root,
+ fs: os.DirFS(root),
+ }, nil
}
+ return fs.Sub(currentFs, root)
}
-// wrapHTTPHandlerFuncMW wraps http.HandlerFunc middleware.
-func wrapHTTPHandlerFuncMW(m http.HandlerFunc) MiddlewareFunc {
- return func(h HandlerFunc) HandlerFunc {
- return func(c *Context) error {
- if !c.response.committed {
- m.ServeHTTP(c.response.writer, c.request)
- }
- return h(c)
- }
+// MustSubFS creates sub FS from current filesystem or panic on failure.
+// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules.
+//
+// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with
+// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to
+// create sub fs which uses necessary prefix for directory path.
+func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS {
+ subFs, err := subFS(currentFs, fsRoot)
+ if err != nil {
+ panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err))
}
+ return subFs
}
-// wrapHandler wraps handler.
-func wrapHandler(h Handler) HandlerFunc {
- switch h := h.(type) {
- case HandlerFunc:
- return h
- case func(*Context) error:
- return h
- case http.Handler, http.HandlerFunc:
- return func(c *Context) error {
- h.(http.Handler).ServeHTTP(c.response, c.request)
- return nil
- }
- case func(http.ResponseWriter, *http.Request):
- return func(c *Context) error {
- h(c.response, c.request)
- return nil
- }
- default:
- panic("echo => unknown handler")
+func sanitizeURI(uri string) string {
+ // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri
+ // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash
+ if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') {
+ uri = "/" + strings.TrimLeft(uri, `/\`)
}
+ return uri
}
diff --git a/echo_test.go b/echo_test.go
index 00d4d7a56..f26eed8e2 100644
--- a/echo_test.go
+++ b/echo_test.go
@@ -1,168 +1,469 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package echo
import (
"bytes"
+ stdContext "context"
+ "errors"
"fmt"
+ "io/fs"
+ "log/slog"
+ "net"
"net/http"
"net/http/httptest"
- "testing"
-
- "reflect"
+ "net/url"
+ "os"
+ "runtime"
"strings"
-
- "errors"
+ "testing"
+ "time"
"github.com/stretchr/testify/assert"
- "golang.org/x/net/websocket"
)
-type (
- user struct {
- ID string `json:"id" xml:"id"`
- Name string `json:"name" xml:"name"`
- }
+type user struct {
+ ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"`
+ Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"`
+}
+
+const (
+ userJSON = `{"id":1,"name":"Jon Snow"}`
+ usersJSON = `[{"id":1,"name":"Jon Snow"}]`
+ userXML = `1Jon Snow`
+ userForm = `id=1&name=Jon Snow`
+ invalidContent = "invalid content"
+ userJSONInvalidType = `{"id":"1","name":"Jon Snow"}`
+ userXMLConvertNumberError = `Number oneJon Snow`
+ userXMLUnsupportedTypeError = `<>Number one>Jon Snow`
)
+const userJSONPretty = `{
+ "id": 1,
+ "name": "Jon Snow"
+}`
+
+const userXMLPretty = `
+ 1
+ Jon Snow
+`
+
+var dummyQuery = url.Values{"dummy": []string{"useless"}}
+
func TestEcho(t *testing.T) {
e := New()
- req, _ := http.NewRequest(GET, "/", nil)
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
- c := NewContext(req, NewResponse(rec), e)
+ c := e.NewContext(req, rec)
// Router
assert.NotNil(t, e.Router())
- // Debug
- e.SetDebug(true)
- assert.True(t, e.Debug())
+ e.HTTPErrorHandler(c, errors.New("error"))
- // DefaultHTTPErrorHandler
- e.DefaultHTTPErrorHandler(errors.New("error"), c)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
-func TestEchoIndex(t *testing.T) {
- e := New()
- e.Index("examples/website/public/index.html")
- c, b := request(GET, "/", e)
- assert.Equal(t, http.StatusOK, c)
- assert.NotEmpty(t, b)
+func TestNewWithConfig(t *testing.T) {
+ e := NewWithConfig(Config{})
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e.GET("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "Hello, World!")
+ })
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusTeapot, rec.Code)
+ assert.Equal(t, `Hello, World!`, rec.Body.String())
}
-func TestEchoFavicon(t *testing.T) {
- e := New()
- e.Favicon("examples/website/public/favicon.ico")
- c, b := request(GET, "/favicon.ico", e)
- assert.Equal(t, http.StatusOK, c)
- assert.NotEmpty(t, b)
+func TestEcho_StaticFS(t *testing.T) {
+ var testCases = []struct {
+ givenFs fs.FS
+ name string
+ givenPrefix string
+ givenFsRoot string
+ whenURL string
+ expectHeaderLocation string
+ expectBodyStartsWith string
+ expectStatus int
+ }{
+ {
+ name: "ok",
+ givenPrefix: "/images",
+ givenFs: os.DirFS("./_fixture/images"),
+ whenURL: "/images/walle.png",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
+ },
+ {
+ name: "ok, from sub fs",
+ givenPrefix: "/images",
+ givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"),
+ whenURL: "/images/walle.png",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
+ },
+ {
+ name: "No file",
+ givenPrefix: "/images",
+ givenFs: os.DirFS("_fixture/scripts"),
+ whenURL: "/images/bolt.png",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Directory",
+ givenPrefix: "/images",
+ givenFs: os.DirFS("_fixture/images"),
+ whenURL: "/images/",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Directory Redirect",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture/"),
+ whenURL: "/folder",
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/folder/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Directory Redirect with non-root path",
+ givenPrefix: "/static",
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/static",
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/static/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory 404 (request URL without slash)",
+ givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/folder", // no trailing slash
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Prefixed directory redirect (without slash redirect to slash)",
+ givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/folder", // no trailing slash
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/folder/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Directory with index.html",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory with index.html (prefix ending with slash)",
+ givenPrefix: "/assets/",
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/assets/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory with index.html (prefix ending without slash)",
+ givenPrefix: "/assets",
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/assets/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Sub-directory with index.html",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture"),
+ whenURL: "/folder/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "do not allow directory traversal (backslash - windows separator)",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture/"),
+ whenURL: `/..\\middleware/basic_auth.go`,
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "do not allow directory traversal (slash - unix separator)",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture/"),
+ whenURL: `/../middleware/basic_auth.go`,
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "open redirect vulnerability",
+ givenPrefix: "/",
+ givenFs: os.DirFS("_fixture/"),
+ whenURL: "/open.redirect.hackercom%2f..",
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/open.redirect.hackercom/../", // location starting with `//open` would be very bad
+ expectBodyStartsWith: "",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+
+ tmpFs := tc.givenFs
+ if tc.givenFsRoot != "" {
+ tmpFs = MustSubFS(tmpFs, tc.givenFsRoot)
+ }
+ e.StaticFS(tc.givenPrefix, tmpFs)
+
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ body := rec.Body.String()
+ if tc.expectBodyStartsWith != "" {
+ assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
+ } else {
+ assert.Equal(t, "", body)
+ }
+
+ if tc.expectHeaderLocation != "" {
+ assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0])
+ } else {
+ _, ok := rec.Result().Header["Location"]
+ assert.False(t, ok)
+ }
+ })
+ }
}
-func TestEchoStatic(t *testing.T) {
+func TestEcho_FileFS(t *testing.T) {
+ var testCases = []struct {
+ whenFS fs.FS
+ name string
+ whenPath string
+ whenFile string
+ givenURL string
+ expectStartsWith []byte
+ expectCode int
+ }{
+ {
+ name: "ok",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "walle.png",
+ givenURL: "/walle",
+ expectCode: http.StatusOK,
+ expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ },
+ {
+ name: "nok, requesting invalid path",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "walle.png",
+ givenURL: "/walle.png",
+ expectCode: http.StatusNotFound,
+ expectStartsWith: []byte(`{"message":"Not Found"}`),
+ },
+ {
+ name: "nok, serving not existent file from filesystem",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "not-existent.png",
+ givenURL: "/walle",
+ expectCode: http.StatusNotFound,
+ expectStartsWith: []byte(`{"message":"Not Found"}`),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS)
+
+ req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectCode, rec.Code)
+
+ body := rec.Body.Bytes()
+ if len(body) > len(tc.expectStartsWith) {
+ body = body[:len(tc.expectStartsWith)]
+ }
+ assert.Equal(t, tc.expectStartsWith, body)
+ })
+ }
+}
+
+func TestEcho_StaticPanic(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenRoot string
+ }{
+ {
+ name: "panics for ../",
+ givenRoot: "../assets",
+ },
+ {
+ name: "panics for /",
+ givenRoot: "/assets",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ e.Filesystem = os.DirFS("./")
+
+ assert.Panics(t, func() {
+ e.Static("../assets", tc.givenRoot)
+ })
+ })
+ }
+}
+
+func TestEchoStaticRedirectIndex(t *testing.T) {
e := New()
- // OK
- e.Static("/scripts", "examples/website/public/scripts")
- c, b := request(GET, "/scripts/main.js", e)
- assert.Equal(t, http.StatusOK, c)
- assert.NotEmpty(t, b)
+ // HandlerFunc
+ ri := e.Static("/static", "_fixture")
+ assert.Equal(t, http.MethodGet, ri.Method)
+ assert.Equal(t, "/static*", ri.Path)
+ assert.Equal(t, "GET:/static*", ri.Name)
+ assert.Equal(t, []string{"*"}, ri.Parameters)
+
+ ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond)
+ defer cancel()
+ addr, err := startOnRandomPort(ctx, e)
+ if err != nil {
+ assert.Fail(t, err.Error())
+ }
- // No file
- e.Static("/scripts", "examples/website/public/scripts")
- c, _ = request(GET, "/scripts/index.js", e)
- assert.Equal(t, http.StatusNotFound, c)
+ code, body, err := doGet(fmt.Sprintf("http://%v/static", addr))
+ assert.NoError(t, err)
+ assert.True(t, strings.HasPrefix(body, ""))
+ assert.Equal(t, http.StatusOK, code)
+}
- // Directory
- e.Static("/scripts", "examples/website/public/scripts")
- c, _ = request(GET, "/scripts", e)
- assert.Equal(t, http.StatusForbidden, c)
+func TestEchoFile(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenPath string
+ givenFile string
+ whenPath string
+ expectStartsWith string
+ expectCode int
+ }{
+ {
+ name: "ok",
+ givenPath: "/walle",
+ givenFile: "_fixture/images/walle.png",
+ whenPath: "/walle",
+ expectCode: http.StatusOK,
+ expectStartsWith: string([]byte{0x89, 0x50, 0x4e}),
+ },
+ {
+ name: "ok with relative path",
+ givenPath: "/",
+ givenFile: "./go.mod",
+ whenPath: "/",
+ expectCode: http.StatusOK,
+ expectStartsWith: "module github.com/labstack/echo/v",
+ },
+ {
+ name: "nok file does not exist",
+ givenPath: "/",
+ givenFile: "./this-file-does-not-exist",
+ whenPath: "/",
+ expectCode: http.StatusNotFound,
+ expectStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ }
- // Directory with index.html
- e.Static("/", "examples/website/public")
- c, r := request(GET, "/", e)
- assert.Equal(t, http.StatusOK, c)
- assert.Equal(t, true, strings.HasPrefix(r, ""))
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New() // we are using echo.defaultFS instance
+ e.File(tc.givenPath, tc.givenFile)
- // Sub-directory with index.html
- c, r = request(GET, "/folder", e)
- assert.Equal(t, http.StatusOK, c)
- assert.Equal(t, "sub directory", r)
+ c, b := request(http.MethodGet, tc.whenPath, e)
+ assert.Equal(t, tc.expectCode, c)
+
+ if len(b) > len(tc.expectStartsWith) {
+ b = b[:len(tc.expectStartsWith)]
+ }
+ assert.Equal(t, tc.expectStartsWith, b)
+ })
+ }
}
func TestEchoMiddleware(t *testing.T) {
e := New()
buf := new(bytes.Buffer)
- // echo.MiddlewareFunc
- e.Use(MiddlewareFunc(func(h HandlerFunc) HandlerFunc {
+ e.Pre(func(next HandlerFunc) HandlerFunc {
return func(c *Context) error {
- buf.WriteString("a")
- return h(c)
+ // before route match is found RouteInfo does not exist
+ assert.Equal(t, RouteInfo{}, c.RouteInfo())
+ buf.WriteString("-1")
+ return next(c)
}
- }))
+ })
- // func(echo.HandlerFunc) echo.HandlerFunc
- e.Use(func(h HandlerFunc) HandlerFunc {
+ e.Use(func(next HandlerFunc) HandlerFunc {
return func(c *Context) error {
- buf.WriteString("b")
- return h(c)
+ buf.WriteString("1")
+ return next(c)
}
})
- // echo.HandlerFunc
- e.Use(HandlerFunc(func(c *Context) error {
- buf.WriteString("c")
- return nil
- }))
-
- // func(*echo.Context) error
- e.Use(func(c *Context) error {
- buf.WriteString("d")
- return nil
- })
-
- // func(http.Handler) http.Handler
- e.Use(func(h http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- buf.WriteString("e")
- h.ServeHTTP(w, r)
- })
- })
-
- // http.Handler
- e.Use(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- buf.WriteString("f")
- })))
-
- // http.HandlerFunc
- e.Use(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- buf.WriteString("g")
- }))
-
- // func(http.ResponseWriter, *http.Request)
- e.Use(func(w http.ResponseWriter, r *http.Request) {
- buf.WriteString("h")
+ e.Use(func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ buf.WriteString("2")
+ return next(c)
+ }
})
- // Unknown
- assert.Panics(t, func() {
- e.Use(nil)
+ e.Use(func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ buf.WriteString("3")
+ return next(c)
+ }
})
// Route
- e.Get("/", func(c *Context) error {
- return c.String(http.StatusOK, "Hello!")
+ e.GET("/", func(c *Context) error {
+ return c.String(http.StatusOK, "OK")
})
- c, b := request(GET, "/", e)
- assert.Equal(t, "abcdefgh", buf.String())
+ c, b := request(http.MethodGet, "/", e)
+ assert.Equal(t, "-1123", buf.String())
assert.Equal(t, http.StatusOK, c)
- assert.Equal(t, "Hello!", b)
+ assert.Equal(t, "OK", b)
+}
- // Error
- e.Use(func(*Context) error {
- return errors.New("error")
+func TestEchoMiddlewareError(t *testing.T) {
+ e := New()
+ e.Use(func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ return errors.New("error")
+ }
})
- c, b = request(GET, "/", e)
+ e.GET("/", notFoundHandler)
+ c, _ := request(http.MethodGet, "/", e)
assert.Equal(t, http.StatusInternalServerError, c)
}
@@ -170,235 +471,757 @@ func TestEchoHandler(t *testing.T) {
e := New()
// HandlerFunc
- e.Get("/1", HandlerFunc(func(c *Context) error {
- return c.String(http.StatusOK, "1")
- }))
-
- // func(*echo.Context) error
- e.Get("/2", func(c *Context) error {
- return c.String(http.StatusOK, "2")
+ e.GET("/ok", func(c *Context) error {
+ return c.String(http.StatusOK, "OK")
})
- // http.Handler/http.HandlerFunc
- e.Get("/3", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte("3"))
+ c, b := request(http.MethodGet, "/ok", e)
+ assert.Equal(t, http.StatusOK, c)
+ assert.Equal(t, "OK", b)
+}
+
+func TestEchoWrapHandler(t *testing.T) {
+ e := New()
+
+ var actualID string
+ var actualPattern string
+ e.GET("/:id", WrapHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte("test"))
+ actualID = r.PathValue("id")
+ actualPattern = r.Pattern
+ })))
+
+ req := httptest.NewRequest(http.MethodGet, "/123", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "test", rec.Body.String())
+ assert.Equal(t, "123", actualID)
+ assert.Equal(t, "/:id", actualPattern)
+}
+
+func TestEchoWrapMiddleware(t *testing.T) {
+ e := New()
+
+ var actualID string
+ var actualPattern string
+ e.Use(WrapMiddleware(func(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ actualID = r.PathValue("id")
+ actualPattern = r.Pattern
+ h.ServeHTTP(w, r)
+ })
}))
- // func(http.ResponseWriter, *http.Request)
- e.Get("/4", func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte("4"))
+ e.GET("/:id", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
})
- for _, p := range []string{"1", "2", "3", "4"} {
- c, b := request(GET, "/"+p, e)
- assert.Equal(t, http.StatusOK, c)
- assert.Equal(t, p, b)
- }
+ req := httptest.NewRequest(http.MethodGet, "/123", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
- // Unknown
- assert.Panics(t, func() {
- e.Get("/5", nil)
- })
+ assert.Equal(t, http.StatusTeapot, rec.Code)
+ assert.Equal(t, "OK", rec.Body.String())
+ assert.Equal(t, "123", actualID)
+ assert.Equal(t, "/:id", actualPattern)
}
func TestEchoConnect(t *testing.T) {
e := New()
- testMethod(t, CONNECT, "/", e)
+
+ ri := e.CONNECT("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodConnect, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodConnect+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodConnect, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoDelete(t *testing.T) {
e := New()
- testMethod(t, DELETE, "/", e)
+
+ ri := e.DELETE("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodDelete, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodDelete+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodDelete, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoGet(t *testing.T) {
e := New()
- testMethod(t, GET, "/", e)
+
+ ri := e.GET("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodGet, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodGet+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodGet, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoHead(t *testing.T) {
e := New()
- testMethod(t, HEAD, "/", e)
+
+ ri := e.HEAD("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodHead, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodHead+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodHead, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoOptions(t *testing.T) {
e := New()
- testMethod(t, OPTIONS, "/", e)
+
+ ri := e.OPTIONS("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodOptions, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodOptions+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodOptions, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoPatch(t *testing.T) {
e := New()
- testMethod(t, PATCH, "/", e)
+
+ ri := e.PATCH("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPatch, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodPatch+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodPatch, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoPost(t *testing.T) {
e := New()
- testMethod(t, POST, "/", e)
+
+ ri := e.POST("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPost, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodPost+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodPost, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoPut(t *testing.T) {
e := New()
- testMethod(t, PUT, "/", e)
+
+ ri := e.PUT("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPut, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodPut+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodPut, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
func TestEchoTrace(t *testing.T) {
e := New()
- testMethod(t, TRACE, "/", e)
+
+ ri := e.TRACE("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodTrace, ri.Method)
+ assert.Equal(t, "/", ri.Path)
+ assert.Equal(t, http.MethodTrace+":/", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodTrace, "/", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, "OK", body)
}
-func TestEchoWebSocket(t *testing.T) {
+func TestEcho_Any(t *testing.T) {
e := New()
- e.WebSocket("/ws", func(c *Context) error {
- c.socket.Write([]byte("test"))
- return nil
+
+ ri := e.Any("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK from ANY")
})
- srv := httptest.NewServer(e)
- defer srv.Close()
- addr := srv.Listener.Addr().String()
- origin := "http://localhost"
- url := fmt.Sprintf("ws://%s/ws", addr)
- ws, err := websocket.Dial(url, "", origin)
- if assert.NoError(t, err) {
- ws.Write([]byte("test"))
- defer ws.Close()
- buf := new(bytes.Buffer)
- buf.ReadFrom(ws)
- assert.Equal(t, "test", buf.String())
- }
+
+ assert.Equal(t, RouteAny, ri.Method)
+ assert.Equal(t, "/activate", ri.Path)
+ assert.Equal(t, RouteAny+":/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodTrace, "/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK from ANY`, body)
}
-func TestEchoURL(t *testing.T) {
+func TestEcho_Any_hasLowerPriority(t *testing.T) {
e := New()
- static := func(*Context) error { return nil }
- getUser := func(*Context) error { return nil }
- getFile := func(*Context) error { return nil }
+ e.Any("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "ANY")
+ })
+ e.GET("/activate", func(c *Context) error {
+ return c.String(http.StatusLocked, "GET")
+ })
- e.Get("/static/file", static)
- e.Get("/users/:id", getUser)
- g := e.Group("/group")
- g.Get("/users/:uid/files/:fid", getFile)
+ status, body := request(http.MethodTrace, "/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `ANY`, body)
- assert.Equal(t, "/static/file", e.URL(static))
- assert.Equal(t, "/users/:id", e.URL(getUser))
- assert.Equal(t, "/users/1", e.URL(getUser, "1"))
- assert.Equal(t, "/group/users/1/files/:fid", e.URL(getFile, "1"))
- assert.Equal(t, "/group/users/1/files/1", e.URL(getFile, "1", "1"))
+ status, body = request(http.MethodGet, "/activate", e)
+ assert.Equal(t, http.StatusLocked, status)
+ assert.Equal(t, `GET`, body)
}
-func TestEchoRoutes(t *testing.T) {
+func TestEchoMatch(t *testing.T) { // JFC
e := New()
- h := func(*Context) error { return nil }
- routes := []Route{
- {GET, "/users/:user/events", h},
- {GET, "/users/:user/events/public", h},
- {POST, "/repos/:owner/:repo/git/refs", h},
- {POST, "/repos/:owner/:repo/git/tags", h},
- }
- for _, r := range routes {
- e.add(r.Method, r.Path, h)
+ ris := e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c *Context) error {
+ return c.String(http.StatusOK, "Match")
+ })
+ assert.Len(t, ris, 2)
+}
+
+func TestEchoServeHTTPPathEncoding(t *testing.T) {
+ e := New()
+ e.GET("/with/slash", func(c *Context) error {
+ return c.String(http.StatusOK, "/with/slash")
+ })
+ e.GET("/:id", func(c *Context) error {
+ return c.String(http.StatusOK, c.Param("id"))
+ })
+
+ var testCases = []struct {
+ name string
+ whenURL string
+ expectURL string
+ expectStatus int
+ }{
+ {
+ name: "url with encoding is not decoded for routing",
+ whenURL: "/with%2Fslash",
+ expectURL: "with%2Fslash", // `%2F` is not decoded to `/` for routing
+ expectStatus: http.StatusOK,
+ },
+ {
+ name: "url without encoding is used as is",
+ whenURL: "/with/slash",
+ expectURL: "/with/slash",
+ expectStatus: http.StatusOK,
+ },
}
- for i, r := range e.Routes() {
- assert.Equal(t, routes[i].Method, r.Method)
- assert.Equal(t, routes[i].Path, r.Path)
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ assert.Equal(t, tc.expectURL, rec.Body.String())
+ })
}
}
func TestEchoGroup(t *testing.T) {
e := New()
buf := new(bytes.Buffer)
- e.Use(func(*Context) error {
- buf.WriteString("0")
- return nil
- })
- h := func(*Context) error { return nil }
+ e.Use(MiddlewareFunc(func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ buf.WriteString("0")
+ return next(c)
+ }
+ }))
+ h := func(c *Context) error {
+ return c.NoContent(http.StatusOK)
+ }
//--------
// Routes
//--------
- e.Get("/users", h)
+ e.GET("/users", h)
// Group
g1 := e.Group("/group1")
- g1.Use(func(*Context) error {
- buf.WriteString("1")
- return nil
+ g1.Use(func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ buf.WriteString("1")
+ return next(c)
+ }
})
- g1.Get("/", h)
+ g1.GET("", h)
- // Group with no parent middleware
- g2 := e.Group("/group2", func(*Context) error {
- buf.WriteString("2")
- return nil
+ // Nested groups with middleware
+ g2 := e.Group("/group2")
+ g2.Use(func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ buf.WriteString("2")
+ return next(c)
+ }
})
- g2.Get("/", h)
-
- // Nested groups
- g3 := e.Group("/group3")
- g4 := g3.Group("/group4")
- g4.Get("/", func(c *Context) error {
- return c.NoContent(http.StatusOK)
+ g3 := g2.Group("/group3")
+ g3.Use(func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ buf.WriteString("3")
+ return next(c)
+ }
})
+ g3.GET("", h)
- request(GET, "/users", e)
- // println(len(e.middleware))
+ request(http.MethodGet, "/users", e)
assert.Equal(t, "0", buf.String())
buf.Reset()
- request(GET, "/group1/", e)
- // println(len(g1.echo.middleware))
+ request(http.MethodGet, "/group1", e)
assert.Equal(t, "01", buf.String())
buf.Reset()
- request(GET, "/group2/", e)
- assert.Equal(t, "2", buf.String())
+ request(http.MethodGet, "/group2/group3", e)
+ assert.Equal(t, "023", buf.String())
+}
- buf.Reset()
- c, _ := request(GET, "/group3/group4/", e)
- assert.Equal(t, http.StatusOK, c)
+func TestEcho_RouteNotFound(t *testing.T) {
+ var testCases = []struct {
+ expectRoute any
+ name string
+ whenURL string
+ expectCode int
+ }{
+ {
+ name: "404, route to static not found handler /a/c/xx",
+ whenURL: "/a/c/xx",
+ expectRoute: "GET /a/c/xx",
+ expectCode: http.StatusNotFound,
+ },
+ {
+ name: "404, route to path param not found handler /a/:file",
+ whenURL: "/a/echo.exe",
+ expectRoute: "GET /a/:file",
+ expectCode: http.StatusNotFound,
+ },
+ {
+ name: "404, route to any not found handler /*",
+ whenURL: "/b/echo.exe",
+ expectRoute: "GET /*",
+ expectCode: http.StatusNotFound,
+ },
+ {
+ name: "200, route /a/c/df to /a/c/df",
+ whenURL: "/a/c/df",
+ expectRoute: "GET /a/c/df",
+ expectCode: http.StatusOK,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+
+ okHandler := func(c *Context) error {
+ return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
+ }
+ notFoundHandler := func(c *Context) error {
+ return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path())
+ }
+
+ e.GET("/", okHandler)
+ e.GET("/a/c/df", okHandler)
+ e.GET("/a/b*", okHandler)
+ e.PUT("/*", okHandler)
+
+ e.RouteNotFound("/a/c/xx", notFoundHandler) // static
+ e.RouteNotFound("/a/:file", notFoundHandler) // param
+ e.RouteNotFound("/*", notFoundHandler) // any
+
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectCode, rec.Code)
+ assert.Equal(t, tc.expectRoute, rec.Body.String())
+ })
+ }
}
func TestEchoNotFound(t *testing.T) {
e := New()
- r, _ := http.NewRequest(GET, "/files", nil)
- w := httptest.NewRecorder()
- e.ServeHTTP(w, r)
- assert.Equal(t, http.StatusNotFound, w.Code)
+ req := httptest.NewRequest(http.MethodGet, "/files", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestEchoMethodNotAllowed(t *testing.T) {
+ e := New()
+
+ e.GET("/", func(c *Context) error {
+ return c.String(http.StatusOK, "Echo!")
+ })
+ req := httptest.NewRequest(http.MethodPost, "/", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusMethodNotAllowed, rec.Code)
+ assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow))
+}
+
+func TestEcho_OnAddRoute(t *testing.T) {
+ exampleRoute := Route{
+ Method: http.MethodGet,
+ Path: "/api/files/:id",
+ Handler: notFoundHandler,
+ Middlewares: nil,
+ Name: "x",
+ }
+
+ var testCases = []struct {
+ whenRoute Route
+ whenError error
+ name string
+ expectError string
+ expectAdded []string
+ expectLen int
+ }{
+ {
+ name: "ok",
+ whenRoute: exampleRoute,
+ whenError: nil,
+ expectAdded: []string{"/static", "/api/files/:id"},
+ expectError: "",
+ expectLen: 2,
+ },
+ {
+ name: "nok, error is returned",
+ whenRoute: exampleRoute,
+ whenError: errors.New("nope"),
+ expectAdded: []string{"/static"},
+ expectError: "nope",
+ expectLen: 1,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+
+ e := New()
+
+ added := make([]string, 0)
+ cnt := 0
+ e.OnAddRoute = func(route Route) error {
+ if cnt > 0 && tc.whenError != nil { // we want to GET /static to succeed for nok tests
+ return tc.whenError
+ }
+ cnt++
+ added = append(added, route.Path)
+ return nil
+ }
+
+ e.GET("/static", notFoundHandler)
+
+ var err error
+ _, err = e.AddRoute(tc.whenRoute)
+
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ assert.Len(t, e.Router().Routes(), tc.expectLen)
+ assert.Equal(t, tc.expectAdded, added)
+ })
+ }
}
-func TestEchoHTTPError(t *testing.T) {
- m := http.StatusText(http.StatusBadRequest)
- he := NewHTTPError(http.StatusBadRequest, m)
- assert.Equal(t, http.StatusBadRequest, he.Code())
- assert.Equal(t, m, he.Error())
+func TestEchoContext(t *testing.T) {
+ e := New()
+ c := e.AcquireContext()
+ assert.IsType(t, new(Context), c)
+ e.ReleaseContext(c)
}
-func TestEchoServer(t *testing.T) {
+func TestPreMiddlewares(t *testing.T) {
e := New()
- s := e.Server(":1323")
- assert.IsType(t, &http.Server{}, s)
+ assert.Equal(t, 0, len(e.PreMiddlewares()))
+
+ e.Pre(func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ return next(c)
+ }
+ })
+
+ assert.Equal(t, 1, len(e.PreMiddlewares()))
}
-func testMethod(t *testing.T, method, path string, e *Echo) {
- m := fmt.Sprintf("%c%s", method[0], strings.ToLower(method[1:]))
- p := reflect.ValueOf(path)
- h := reflect.ValueOf(func(c *Context) error {
- c.String(http.StatusOK, method)
- return nil
+func TestMiddlewares(t *testing.T) {
+ e := New()
+ assert.Equal(t, 0, len(e.Middlewares()))
+
+ e.Use(func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ return next(c)
+ }
})
- i := interface{}(e)
- reflect.ValueOf(i).MethodByName(m).Call([]reflect.Value{p, h})
- _, body := request(method, path, e)
- if body != method {
- t.Errorf("expected body `%s`, got %s.", method, body)
+
+ assert.Equal(t, 1, len(e.Middlewares()))
+}
+
+func TestEcho_Start(t *testing.T) {
+ e := New()
+ e.GET("/", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+ rndPort, err := net.Listen("tcp", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer rndPort.Close()
+ errChan := make(chan error, 1)
+ go func() {
+ errChan <- e.Start(rndPort.Addr().String())
+ }()
+
+ select {
+ case <-time.After(250 * time.Millisecond):
+ t.Fatal("start did not error out")
+ case err := <-errChan:
+ expectContains := "bind: address already in use"
+ if runtime.GOOS == "windows" {
+ expectContains = "bind: Only one usage of each socket address"
+ }
+ assert.Contains(t, err.Error(), expectContains)
}
}
func request(method, path string, e *Echo) (int, string) {
- r, _ := http.NewRequest(method, path, nil)
+ req := httptest.NewRequest(method, path, nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ return rec.Code, rec.Body.String()
+}
+
+type customError struct {
+ Code int
+ Message string
+}
+
+func (ce *customError) StatusCode() int {
+ return ce.Code
+}
+
+func (ce *customError) MarshalJSON() ([]byte, error) {
+ return []byte(fmt.Sprintf(`{"x":"%v"}`, ce.Message)), nil
+}
+
+func (ce *customError) Error() string {
+ return ce.Message
+}
+
+func TestDefaultHTTPErrorHandler(t *testing.T) {
+ var testCases = []struct {
+ whenError error
+ name string
+ whenMethod string
+ expectBody string
+ expectLogged string
+ expectStatus int
+ givenExposeError bool
+ givenLoggerFunc bool
+ }{
+ {
+ name: "ok, expose error = true, HTTPError, no wrapped err",
+ givenExposeError: true,
+ whenError: &HTTPError{Code: http.StatusTeapot, Message: "my_error"},
+ expectStatus: http.StatusTeapot,
+ expectBody: `{"message":"my_error"}` + "\n",
+ },
+ {
+ name: "ok, expose error = true, HTTPError + wrapped error",
+ givenExposeError: true,
+ whenError: HTTPError{Code: http.StatusTeapot, Message: "my_error"}.Wrap(errors.New("internal_error")),
+ expectStatus: http.StatusTeapot,
+ expectBody: `{"error":"internal_error","message":"my_error"}` + "\n",
+ },
+ {
+ name: "ok, expose error = true, HTTPError + wrapped HTTPError",
+ givenExposeError: true,
+ whenError: HTTPError{Code: http.StatusTeapot, Message: "my_error"}.Wrap(&HTTPError{Code: http.StatusTeapot, Message: "early_error"}),
+ expectStatus: http.StatusTeapot,
+ expectBody: `{"error":"code=418, message=early_error","message":"my_error"}` + "\n",
+ },
+ {
+ name: "ok, expose error = false, HTTPError",
+ whenError: &HTTPError{Code: http.StatusTeapot, Message: "my_error"},
+ expectStatus: http.StatusTeapot,
+ expectBody: `{"message":"my_error"}` + "\n",
+ },
+ {
+ name: "ok, expose error = false, HTTPError, no message",
+ whenError: &HTTPError{Code: http.StatusTeapot, Message: ""},
+ expectStatus: http.StatusTeapot,
+ expectBody: `{"message":"I'm a teapot"}` + "\n",
+ },
+ {
+ name: "ok, expose error = false, HTTPError + internal HTTPError",
+ whenError: HTTPError{Code: http.StatusTooEarly, Message: "my_error"}.Wrap(&HTTPError{Code: http.StatusTeapot, Message: "early_error"}),
+ expectStatus: http.StatusTooEarly,
+ expectBody: `{"message":"my_error"}` + "\n",
+ },
+ {
+ name: "ok, expose error = true, Error",
+ givenExposeError: true,
+ whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")),
+ expectStatus: http.StatusInternalServerError,
+ expectBody: `{"error":"my errors wraps: internal_error","message":"Internal Server Error"}` + "\n",
+ },
+ {
+ name: "ok, expose error = false, Error",
+ whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")),
+ expectStatus: http.StatusInternalServerError,
+ expectBody: `{"message":"Internal Server Error"}` + "\n",
+ },
+ {
+ name: "ok, http.HEAD, expose error = true, Error",
+ givenExposeError: true,
+ whenMethod: http.MethodHead,
+ whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")),
+ expectStatus: http.StatusInternalServerError,
+ expectBody: ``,
+ },
+ {
+ name: "ok, custom error implement MarshalJSON + HTTPStatusCoder",
+ whenMethod: http.MethodGet,
+ whenError: &customError{Code: http.StatusTeapot, Message: "custom error msg"},
+ expectStatus: http.StatusTeapot,
+ expectBody: `{"x":"custom error msg"}` + "\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ buf := new(bytes.Buffer)
+ e := New()
+ e.Logger = slog.New(slog.DiscardHandler)
+ e.Any("/path", func(c *Context) error {
+ return tc.whenError
+ })
+
+ e.HTTPErrorHandler = DefaultHTTPErrorHandler(tc.givenExposeError)
+
+ method := http.MethodGet
+ if tc.whenMethod != "" {
+ method = tc.whenMethod
+ }
+ c, b := request(method, "/path", e)
+
+ assert.Equal(t, tc.expectStatus, c)
+ assert.Equal(t, tc.expectBody, b)
+ assert.Equal(t, tc.expectLogged, buf.String())
+ })
+ }
+}
+
+func TestDefaultHTTPErrorHandler_CommitedResponse(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ resp := httptest.NewRecorder()
+ c := e.NewContext(req, resp)
+
+ c.orgResponse.Committed = true
+ errHandler := DefaultHTTPErrorHandler(false)
+
+ errHandler(c, errors.New("my_error"))
+ assert.Equal(t, http.StatusOK, resp.Code)
+}
+
+func benchmarkEchoRoutes(b *testing.B, routes []testRoute) {
+ e := New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ u := req.URL
w := httptest.NewRecorder()
- e.ServeHTTP(w, r)
- return w.Code, w.Body.String()
+
+ b.ReportAllocs()
+
+ // Add routes
+ for _, route := range routes {
+ e.Add(route.Method, route.Path, func(c *Context) error {
+ return nil
+ })
+ }
+
+ // Find routes
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ for _, route := range routes {
+ req.Method = route.Method
+ u.Path = route.Path
+ e.ServeHTTP(w, req)
+ }
+ }
+}
+
+func BenchmarkEchoStaticRoutes(b *testing.B) {
+ benchmarkEchoRoutes(b, staticRoutes)
+}
+
+func BenchmarkEchoStaticRoutesMisses(b *testing.B) {
+ benchmarkEchoRoutes(b, staticRoutes)
+}
+
+func BenchmarkEchoGitHubAPI(b *testing.B) {
+ benchmarkEchoRoutes(b, gitHubAPI)
+}
+
+func BenchmarkEchoGitHubAPIMisses(b *testing.B) {
+ benchmarkEchoRoutes(b, gitHubAPI)
+}
+
+func BenchmarkEchoParseAPI(b *testing.B) {
+ benchmarkEchoRoutes(b, parseAPI)
}
diff --git a/echotest/context.go b/echotest/context.go
new file mode 100644
index 000000000..2f665705d
--- /dev/null
+++ b/echotest/context.go
@@ -0,0 +1,183 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echotest
+
+import (
+ "bytes"
+ "io"
+ "mime/multipart"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+)
+
+// ContextConfig is configuration for creating echo.Context for testing purposes.
+type ContextConfig struct {
+ // Request will be used instead of default `httptest.NewRequest(http.MethodGet, "/", nil)`
+ Request *http.Request
+
+ // Response will be used instead of default `httptest.NewRecorder()`
+ Response *httptest.ResponseRecorder
+
+ // QueryValues wil be set as Request.URL.RawQuery value
+ QueryValues url.Values
+
+ // Headers wil be set as Request.Header value
+ Headers http.Header
+
+ // PathValues initializes context.PathValues with given value.
+ PathValues echo.PathValues
+
+ // RouteInfo initializes context.RouteInfo() with given value
+ RouteInfo *echo.RouteInfo
+
+ // FormValues creates form-urlencoded form out of given values. If there is no
+ // `content-type` header it will be set to `application/x-www-form-urlencoded`
+ // In case Request was not set the Request.Method is set to `POST`
+ //
+ // FormValues, MultipartForm and JSONBody are mutually exclusive.
+ FormValues url.Values
+
+ // MultipartForm creates multipart form out of given value. If there is no
+ // `content-type` header it will be set to `multipart/form-data`
+ // In case Request was not set the Request.Method is set to `POST`
+ //
+ // FormValues, MultipartForm and JSONBody are mutually exclusive.
+ MultipartForm *MultipartForm
+
+ // JSONBody creates JSON body out of given bytes. If there is no
+ // `content-type` header it will be set to `application/json`
+ // In case Request was not set the Request.Method is set to `POST`
+ //
+ // FormValues, MultipartForm and JSONBody are mutually exclusive.
+ JSONBody []byte
+}
+
+// MultipartForm is used to create multipart form out of given value
+type MultipartForm struct {
+ Fields map[string]string
+ Files []MultipartFormFile
+}
+
+// MultipartFormFile is used to create file in multipart form out of given value
+type MultipartFormFile struct {
+ Fieldname string
+ Filename string
+ Content []byte
+}
+
+// ToContext converts ContextConfig to echo.Context
+func (conf ContextConfig) ToContext(t *testing.T) *echo.Context {
+ c, _ := conf.ToContextRecorder(t)
+ return c
+}
+
+// ToContextRecorder converts ContextConfig to echo.Context and httptest.ResponseRecorder
+func (conf ContextConfig) ToContextRecorder(t *testing.T) (*echo.Context, *httptest.ResponseRecorder) {
+ if conf.Response == nil {
+ conf.Response = httptest.NewRecorder()
+ }
+ isDefaultRequest := false
+ if conf.Request == nil {
+ isDefaultRequest = true
+ conf.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+ }
+
+ if len(conf.QueryValues) > 0 {
+ conf.Request.URL.RawQuery = conf.QueryValues.Encode()
+ }
+ if len(conf.Headers) > 0 {
+ conf.Request.Header = conf.Headers
+ }
+ if len(conf.FormValues) > 0 {
+ body := strings.NewReader(url.Values(conf.FormValues).Encode())
+ conf.Request.Body = io.NopCloser(body)
+ conf.Request.ContentLength = int64(body.Len())
+
+ if conf.Request.Header.Get(echo.HeaderContentType) == "" {
+ conf.Request.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
+ }
+ if isDefaultRequest {
+ conf.Request.Method = http.MethodPost
+ }
+ } else if conf.MultipartForm != nil {
+ var body bytes.Buffer
+ mw := multipart.NewWriter(&body)
+ for field, value := range conf.MultipartForm.Fields {
+ if err := mw.WriteField(field, value); err != nil {
+ t.Fatal(err)
+ }
+ }
+ for _, file := range conf.MultipartForm.Files {
+ fw, err := mw.CreateFormFile(file.Fieldname, file.Filename)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err = fw.Write(file.Content); err != nil {
+ t.Fatal(err)
+ }
+ }
+ if err := mw.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ conf.Request.Body = io.NopCloser(&body)
+ conf.Request.ContentLength = int64(body.Len())
+ if conf.Request.Header.Get(echo.HeaderContentType) == "" {
+ conf.Request.Header.Set(echo.HeaderContentType, mw.FormDataContentType())
+ }
+ if isDefaultRequest {
+ conf.Request.Method = http.MethodPost
+ }
+ } else if conf.JSONBody != nil {
+ body := bytes.NewReader(conf.JSONBody)
+ conf.Request.Body = io.NopCloser(body)
+ conf.Request.ContentLength = int64(body.Len())
+
+ if conf.Request.Header.Get(echo.HeaderContentType) == "" {
+ conf.Request.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
+ }
+ if isDefaultRequest {
+ conf.Request.Method = http.MethodPost
+ }
+ }
+
+ ec := echo.NewContext(conf.Request, conf.Response, echo.New())
+ if conf.RouteInfo == nil {
+ conf.RouteInfo = &echo.RouteInfo{
+ Name: "",
+ Method: conf.Request.Method,
+ Path: "/test",
+ Parameters: []string{},
+ }
+ for _, p := range conf.PathValues {
+ conf.RouteInfo.Parameters = append(conf.RouteInfo.Parameters, p.Name)
+ }
+ }
+ ec.InitializeRoute(conf.RouteInfo, &conf.PathValues)
+ return ec, conf.Response
+}
+
+// ServeWithHandler serves ContextConfig with given handler and returns httptest.ResponseRecorder for response checking
+func (conf ContextConfig) ServeWithHandler(t *testing.T, handler echo.HandlerFunc, opts ...any) *httptest.ResponseRecorder {
+ c, rec := conf.ToContextRecorder(t)
+
+ errHandler := echo.DefaultHTTPErrorHandler(false)
+ for _, opt := range opts {
+ switch o := opt.(type) {
+ case echo.HTTPErrorHandler:
+ errHandler = o
+ }
+ }
+
+ err := handler(c)
+ if err != nil {
+ errHandler(c, err)
+ }
+ return rec
+}
diff --git a/echotest/context_external_test.go b/echotest/context_external_test.go
new file mode 100644
index 000000000..d98257148
--- /dev/null
+++ b/echotest/context_external_test.go
@@ -0,0 +1,27 @@
+package echotest_test
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/labstack/echo/v5/echotest"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestToContext_JSONBody(t *testing.T) {
+ c := echotest.ContextConfig{
+ JSONBody: echotest.LoadBytes(t, "testdata/test.json"),
+ }.ToContext(t)
+
+ payload := struct {
+ Field string `json:"field"`
+ }{}
+ if err := c.Bind(&payload); err != nil {
+ t.Fatal(err)
+ }
+
+ assert.Equal(t, "value", payload.Field)
+ assert.Equal(t, http.MethodPost, c.Request().Method)
+ assert.Equal(t, echo.MIMEApplicationJSON, c.Request().Header.Get(echo.HeaderContentType))
+}
diff --git a/echotest/context_test.go b/echotest/context_test.go
new file mode 100644
index 000000000..66815e4b0
--- /dev/null
+++ b/echotest/context_test.go
@@ -0,0 +1,157 @@
+package echotest
+
+import (
+ "net/http"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestServeWithHandler(t *testing.T) {
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, c.QueryParam("key"))
+ }
+ testConf := ContextConfig{
+ QueryValues: url.Values{"key": []string{"value"}},
+ }
+
+ resp := testConf.ServeWithHandler(t, handler)
+
+ assert.Equal(t, http.StatusOK, resp.Code)
+ assert.Equal(t, "value", resp.Body.String())
+}
+
+func TestServeWithHandler_error(t *testing.T) {
+ handler := func(c *echo.Context) error {
+ return echo.NewHTTPError(http.StatusBadRequest, "something went wrong")
+ }
+ testConf := ContextConfig{
+ QueryValues: url.Values{"key": []string{"value"}},
+ }
+
+ customErrHandler := echo.DefaultHTTPErrorHandler(true)
+
+ resp := testConf.ServeWithHandler(t, handler, customErrHandler)
+
+ assert.Equal(t, http.StatusBadRequest, resp.Code)
+ assert.Equal(t, `{"message":"something went wrong"}`+"\n", resp.Body.String())
+}
+
+func TestToContext_QueryValues(t *testing.T) {
+ testConf := ContextConfig{
+ QueryValues: url.Values{"t": []string{"2006-01-02"}},
+ }
+ c := testConf.ToContext(t)
+
+ v, err := echo.QueryParam[string](c, "t")
+
+ assert.NoError(t, err)
+ assert.Equal(t, "2006-01-02", v)
+}
+
+func TestToContext_Headers(t *testing.T) {
+ testConf := ContextConfig{
+ Headers: http.Header{echo.HeaderXRequestID: []string{"ABC"}},
+ }
+ c := testConf.ToContext(t)
+
+ id := c.Request().Header.Get(echo.HeaderXRequestID)
+
+ assert.Equal(t, "ABC", id)
+}
+
+func TestToContext_PathValues(t *testing.T) {
+ testConf := ContextConfig{
+ PathValues: echo.PathValues{{
+ Name: "key",
+ Value: "value",
+ }},
+ }
+ c := testConf.ToContext(t)
+
+ key := c.Param("key")
+
+ assert.Equal(t, "value", key)
+}
+
+func TestToContext_RouteInfo(t *testing.T) {
+ testConf := ContextConfig{
+ RouteInfo: &echo.RouteInfo{
+ Name: "my_route",
+ Method: http.MethodGet,
+ Path: "/:id",
+ Parameters: []string{"id"},
+ },
+ }
+ c := testConf.ToContext(t)
+
+ ri := c.RouteInfo()
+
+ assert.Equal(t, echo.RouteInfo{
+ Name: "my_route",
+ Method: http.MethodGet,
+ Path: "/:id",
+ Parameters: []string{"id"},
+ }, ri)
+}
+
+func TestToContext_FormValues(t *testing.T) {
+ testConf := ContextConfig{
+ FormValues: url.Values{"key": []string{"value"}},
+ }
+ c := testConf.ToContext(t)
+
+ assert.Equal(t, "value", c.FormValue("key"))
+ assert.Equal(t, http.MethodPost, c.Request().Method)
+ assert.Equal(t, echo.MIMEApplicationForm, c.Request().Header.Get(echo.HeaderContentType))
+}
+
+func TestToContext_MultipartForm(t *testing.T) {
+ testConf := ContextConfig{
+ MultipartForm: &MultipartForm{
+ Fields: map[string]string{
+ "key": "value",
+ },
+ Files: []MultipartFormFile{
+ {
+ Fieldname: "file",
+ Filename: "test.json",
+ Content: LoadBytes(t, "testdata/test.json"),
+ },
+ },
+ },
+ }
+ c := testConf.ToContext(t)
+
+ assert.Equal(t, "value", c.FormValue("key"))
+ assert.Equal(t, http.MethodPost, c.Request().Method)
+ assert.Equal(t, true, strings.HasPrefix(c.Request().Header.Get(echo.HeaderContentType), "multipart/form-data; boundary="))
+
+ fv, err := c.FormFile("file")
+ if err != nil {
+ t.Fatal(err)
+ }
+ assert.Equal(t, "test.json", fv.Filename)
+ assert.Equal(t, int64(23), fv.Size)
+}
+
+func TestToContext_JSONBody(t *testing.T) {
+ testConf := ContextConfig{
+ JSONBody: LoadBytes(t, "testdata/test.json"),
+ }
+ c := testConf.ToContext(t)
+
+ payload := struct {
+ Field string `json:"field"`
+ }{}
+ if err := c.Bind(&payload); err != nil {
+ t.Fatal(err)
+ }
+
+ assert.Equal(t, "value", payload.Field)
+ assert.Equal(t, http.MethodPost, c.Request().Method)
+ assert.Equal(t, echo.MIMEApplicationJSON, c.Request().Header.Get(echo.HeaderContentType))
+}
diff --git a/echotest/reader.go b/echotest/reader.go
new file mode 100644
index 000000000..0caceca02
--- /dev/null
+++ b/echotest/reader.go
@@ -0,0 +1,46 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echotest
+
+import (
+ "os"
+ "path/filepath"
+ "runtime"
+ "testing"
+)
+
+type loadBytesOpts func([]byte) []byte
+
+// TrimNewlineEnd instructs LoadBytes to remove `\n` from the end of loaded file.
+func TrimNewlineEnd(bytes []byte) []byte {
+ bLen := len(bytes)
+ if bLen > 1 && bytes[bLen-1] == '\n' {
+ bytes = bytes[:bLen-1]
+ }
+ return bytes
+}
+
+// LoadBytes is helper to load file contents relative to current (where test file is) package
+// directory.
+func LoadBytes(t *testing.T, name string, opts ...loadBytesOpts) []byte {
+ bytes := loadBytes(t, name, 2)
+
+ for _, f := range opts {
+ bytes = f(bytes)
+ }
+
+ return bytes
+}
+
+func loadBytes(t *testing.T, name string, callDepth int) []byte {
+ _, b, _, _ := runtime.Caller(callDepth)
+ basepath := filepath.Dir(b)
+
+ path := filepath.Join(basepath, name) // relative path
+ bytes, err := os.ReadFile(path)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return bytes[:]
+}
diff --git a/echotest/reader_external_test.go b/echotest/reader_external_test.go
new file mode 100644
index 000000000..43fd57416
--- /dev/null
+++ b/echotest/reader_external_test.go
@@ -0,0 +1,25 @@
+package echotest_test
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/labstack/echo/v5/echotest"
+ "github.com/stretchr/testify/assert"
+)
+
+const testJSONContent = `{
+ "field": "value"
+}`
+
+func TestLoadBytesOK(t *testing.T) {
+ data := echotest.LoadBytes(t, "testdata/test.json")
+ assert.Equal(t, []byte(testJSONContent+"\n"), data)
+}
+
+func TestLoadBytes_custom(t *testing.T) {
+ data := echotest.LoadBytes(t, "testdata/test.json", func(bytes []byte) []byte {
+ return []byte(strings.ToUpper(string(bytes)))
+ })
+ assert.Equal(t, []byte(strings.ToUpper(testJSONContent)+"\n"), data)
+}
diff --git a/echotest/reader_test.go b/echotest/reader_test.go
new file mode 100644
index 000000000..23b3c2dd2
--- /dev/null
+++ b/echotest/reader_test.go
@@ -0,0 +1,21 @@
+package echotest
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+const testJSONContent = `{
+ "field": "value"
+}`
+
+func TestLoadBytesOK(t *testing.T) {
+ data := LoadBytes(t, "testdata/test.json")
+ assert.Equal(t, []byte(testJSONContent+"\n"), data)
+}
+
+func TestLoadBytesOK_TrimNewlineEnd(t *testing.T) {
+ data := LoadBytes(t, "testdata/test.json", TrimNewlineEnd)
+ assert.Equal(t, []byte(testJSONContent), data)
+}
diff --git a/echotest/testdata/test.json b/echotest/testdata/test.json
new file mode 100644
index 000000000..94ae65f17
--- /dev/null
+++ b/echotest/testdata/test.json
@@ -0,0 +1,3 @@
+{
+ "field": "value"
+}
diff --git a/examples/crud/server.go b/examples/crud/server.go
deleted file mode 100644
index 683acb8bb..000000000
--- a/examples/crud/server.go
+++ /dev/null
@@ -1,75 +0,0 @@
-package main
-
-import (
- "net/http"
- "strconv"
-
- "github.com/labstack/echo"
- mw "github.com/labstack/echo/middleware"
-)
-
-type (
- user struct {
- ID int
- Name string
- }
-)
-
-var (
- users = map[int]*user{}
- seq = 1
-)
-
-//----------
-// Handlers
-//----------
-
-func createUser(c *echo.Context) error {
- u := &user{
- ID: seq,
- }
- if err := c.Bind(u); err != nil {
- return err
- }
- users[u.ID] = u
- seq++
- return c.JSON(http.StatusCreated, u)
-}
-
-func getUser(c *echo.Context) error {
- id, _ := strconv.Atoi(c.Param("id"))
- return c.JSON(http.StatusOK, users[id])
-}
-
-func updateUser(c *echo.Context) error {
- u := new(user)
- if err := c.Bind(u); err != nil {
- return err
- }
- id, _ := strconv.Atoi(c.Param("id"))
- users[id].Name = u.Name
- return c.JSON(http.StatusOK, users[id])
-}
-
-func deleteUser(c *echo.Context) error {
- id, _ := strconv.Atoi(c.Param("id"))
- delete(users, id)
- return c.NoContent(http.StatusNoContent)
-}
-
-func main() {
- e := echo.New()
-
- // Middleware
- e.Use(mw.Logger())
- e.Use(mw.Recover())
-
- // Routes
- e.Post("/users", createUser)
- e.Get("/users/:id", getUser)
- e.Patch("/users/:id", updateUser)
- e.Delete("/users/:id", deleteUser)
-
- // Start server
- e.Run(":1323")
-}
diff --git a/examples/hello/server.go b/examples/hello/server.go
deleted file mode 100644
index 88883d074..000000000
--- a/examples/hello/server.go
+++ /dev/null
@@ -1,28 +0,0 @@
-package main
-
-import (
- "net/http"
-
- "github.com/labstack/echo"
- mw "github.com/labstack/echo/middleware"
-)
-
-// Handler
-func hello(c *echo.Context) error {
- return c.String(http.StatusOK, "Hello, World!\n")
-}
-
-func main() {
- // Echo instance
- e := echo.New()
-
- // Middleware
- e.Use(mw.Logger())
- e.Use(mw.Recover())
-
- // Routes
- e.Get("/", hello)
-
- // Start server
- e.Run(":1323")
-}
diff --git a/examples/middleware/server.go b/examples/middleware/server.go
deleted file mode 100644
index 503210c33..000000000
--- a/examples/middleware/server.go
+++ /dev/null
@@ -1,58 +0,0 @@
-package main
-
-import (
- "net/http"
-
- "github.com/labstack/echo"
- mw "github.com/labstack/echo/middleware"
-)
-
-// Handler
-func hello(c *echo.Context) error {
- return c.String(http.StatusOK, "Hello, World!\n")
-}
-
-func main() {
- // Echo instance
- e := echo.New()
-
- // Debug mode
- e.SetDebug(true)
-
- //------------
- // Middleware
- //------------
-
- // Logger
- e.Use(mw.Logger())
-
- // Recover
- e.Use(mw.Recover())
-
- // Basic auth
- e.Use(mw.BasicAuth(func(usr, pwd string) bool {
- if usr == "joe" && pwd == "secret" {
- return true
- }
- return false
- }))
-
- //-------
- // Slash
- //-------
-
- e.Use(mw.StripTrailingSlash())
-
- // or
-
- // e.Use(mw.RedirectToSlash())
-
- // Gzip
- e.Use(mw.Gzip())
-
- // Routes
- e.Get("/", hello)
-
- // Start server
- e.Run(":1323")
-}
diff --git a/examples/website/public/folder/index.html b/examples/website/public/folder/index.html
deleted file mode 100644
index 36b3a421a..000000000
--- a/examples/website/public/folder/index.html
+++ /dev/null
@@ -1 +0,0 @@
-sub directory
\ No newline at end of file
diff --git a/examples/website/public/index.html b/examples/website/public/index.html
deleted file mode 100644
index aed4f4668..000000000
--- a/examples/website/public/index.html
+++ /dev/null
@@ -1,15 +0,0 @@
-
-
-
-
-
- Echo
-
-
-
-
-
- Echo!
-
-
-
diff --git a/examples/website/public/scripts/main.js b/examples/website/public/scripts/main.js
deleted file mode 100644
index c3b96d214..000000000
--- a/examples/website/public/scripts/main.js
+++ /dev/null
@@ -1 +0,0 @@
-console.log("Echo!")
diff --git a/examples/website/public/views/welcome.html b/examples/website/public/views/welcome.html
deleted file mode 100644
index 5dc667c36..000000000
--- a/examples/website/public/views/welcome.html
+++ /dev/null
@@ -1 +0,0 @@
-{{define "welcome"}}Hello, {{.}}!{{end}}
diff --git a/examples/website/server.go b/examples/website/server.go
deleted file mode 100644
index 38886b92f..000000000
--- a/examples/website/server.go
+++ /dev/null
@@ -1,146 +0,0 @@
-package main
-
-import (
- "io"
- "net/http"
-
- "html/template"
-
- "github.com/labstack/echo"
- mw "github.com/labstack/echo/middleware"
- "github.com/rs/cors"
- "github.com/thoas/stats"
-)
-
-type (
- // Template provides HTML template rendering
- Template struct {
- templates *template.Template
- }
-
- user struct {
- ID string `json:"id"`
- Name string `json:"name"`
- }
-)
-
-var (
- users map[string]user
-)
-
-// Render HTML
-func (t *Template) Render(w io.Writer, name string, data interface{}) error {
- return t.templates.ExecuteTemplate(w, name, data)
-}
-
-//----------
-// Handlers
-//----------
-
-func welcome(c *echo.Context) error {
- return c.Render(http.StatusOK, "welcome", "Joe")
-}
-
-func createUser(c *echo.Context) error {
- u := new(user)
- if err := c.Bind(u); err != nil {
- return err
- }
- users[u.ID] = *u
- return c.JSON(http.StatusCreated, u)
-}
-
-func getUsers(c *echo.Context) error {
- return c.JSON(http.StatusOK, users)
-}
-
-func getUser(c *echo.Context) error {
- return c.JSON(http.StatusOK, users[c.P(0)])
-}
-
-func main() {
- e := echo.New()
-
- // Middleware
- e.Use(mw.Logger())
- e.Use(mw.Recover())
- e.Use(mw.Gzip())
-
- //------------------------
- // Third-party middleware
- //------------------------
-
- // https://github.com/rs/cors
- e.Use(cors.Default().Handler)
-
- // https://github.com/thoas/stats
- s := stats.New()
- e.Use(s.Handler)
- // Route
- e.Get("/stats", func(c *echo.Context) error {
- return c.JSON(http.StatusOK, s.Data())
- })
-
- // Serve index file
- e.Index("public/index.html")
-
- // Serve favicon
- e.Favicon("public/favicon.ico")
-
- // Serve static files
- e.Static("/scripts", "public/scripts")
-
- //--------
- // Routes
- //--------
-
- e.Post("/users", createUser)
- e.Get("/users", getUsers)
- e.Get("/users/:id", getUser)
-
- //-----------
- // Templates
- //-----------
-
- t := &Template{
- // Cached templates
- templates: template.Must(template.ParseFiles("public/views/welcome.html")),
- }
- e.SetRenderer(t)
- e.Get("/welcome", welcome)
-
- //-------
- // Group
- //-------
-
- // Group with parent middleware
- a := e.Group("/admin")
- a.Use(func(c *echo.Context) error {
- // Security middleware
- return nil
- })
- a.Get("", func(c *echo.Context) error {
- return c.String(http.StatusOK, "Welcome admin!")
- })
-
- // Group with no parent middleware
- g := e.Group("/files", func(c *echo.Context) error {
- // Security middleware
- return nil
- })
- g.Get("", func(c *echo.Context) error {
- return c.String(http.StatusOK, "Your files!")
- })
-
- // Start server
- e.Run(":1323")
-}
-
-func init() {
- users = map[string]user{
- "1": user{
- ID: "1",
- Name: "Wreck-It Ralph",
- },
- }
-}
diff --git a/go.mod b/go.mod
new file mode 100644
index 000000000..a2480a285
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,16 @@
+module github.com/labstack/echo/v5
+
+go 1.25.0
+
+require (
+ github.com/stretchr/testify v1.11.1
+ golang.org/x/net v0.49.0
+ golang.org/x/time v0.14.0
+)
+
+require (
+ github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/pmezard/go-difflib v1.0.0 // indirect
+ golang.org/x/text v0.33.0 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 000000000..f1e80fc13
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,16 @@
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
+github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
+golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
+golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
+golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
+golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
+golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
+golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/group.go b/group.go
index 856a33651..d81cd9163 100644
--- a/group.go
+++ b/group.go
@@ -1,69 +1,172 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package echo
-type (
- Group struct {
- echo Echo
- }
+import (
+ "io/fs"
+ "net/http"
)
-func (g *Group) Use(m ...Middleware) {
- for _, h := range m {
- g.echo.middleware = append(g.echo.middleware, wrapMiddleware(h))
- }
+// Group is a set of sub-routes for a specified route. It can be used for inner
+// routes that share a common middleware or functionality that should be separate
+// from the parent echo instance while still inheriting from it.
+type Group struct {
+ echo *Echo
+ prefix string
+ middleware []MiddlewareFunc
+}
+
+// Use implements `Echo#Use()` for sub-routes within the Group.
+// Group middlewares are not executed on request when there is no matching route found.
+func (g *Group) Use(middleware ...MiddlewareFunc) {
+ g.middleware = append(g.middleware, middleware...)
}
-func (g *Group) Connect(path string, h Handler) {
- g.echo.Connect(path, h)
+// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error.
+func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return g.Add(http.MethodConnect, path, h, m...)
}
-func (g *Group) Delete(path string, h Handler) {
- g.echo.Delete(path, h)
+// DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error.
+func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return g.Add(http.MethodDelete, path, h, m...)
}
-func (g *Group) Get(path string, h Handler) {
- g.echo.Get(path, h)
+// GET implements `Echo#GET()` for sub-routes within the Group. Panics on error.
+func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return g.Add(http.MethodGet, path, h, m...)
}
-func (g *Group) Head(path string, h Handler) {
- g.echo.Head(path, h)
+// HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error.
+func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return g.Add(http.MethodHead, path, h, m...)
}
-func (g *Group) Options(path string, h Handler) {
- g.echo.Options(path, h)
+// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error.
+func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return g.Add(http.MethodOptions, path, h, m...)
}
-func (g *Group) Patch(path string, h Handler) {
- g.echo.Patch(path, h)
+// PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error.
+func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return g.Add(http.MethodPatch, path, h, m...)
}
-func (g *Group) Post(path string, h Handler) {
- g.echo.Post(path, h)
+// POST implements `Echo#POST()` for sub-routes within the Group. Panics on error.
+func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return g.Add(http.MethodPost, path, h, m...)
}
-func (g *Group) Put(path string, h Handler) {
- g.echo.Put(path, h)
+// PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error.
+func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return g.Add(http.MethodPut, path, h, m...)
}
-func (g *Group) Trace(path string, h Handler) {
- g.echo.Trace(path, h)
+// TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error.
+func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return g.Add(http.MethodTrace, path, h, m...)
}
-func (g *Group) WebSocket(path string, h HandlerFunc) {
- g.echo.WebSocket(path, h)
+// Any implements `Echo#Any()` for sub-routes within the Group. Panics on error.
+func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
+ return g.Add(RouteAny, path, handler, middleware...)
}
-func (g *Group) Static(path, root string) {
- g.echo.Static(path, root)
+// Match implements `Echo#Match()` for sub-routes within the Group. Panics on error.
+func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes {
+ errs := make([]error, 0)
+ ris := make(Routes, 0)
+ for _, m := range methods {
+ ri, err := g.AddRoute(Route{
+ Method: m,
+ Path: path,
+ Handler: handler,
+ Middlewares: middleware,
+ })
+ if err != nil {
+ errs = append(errs, err)
+ continue
+ }
+ ris = append(ris, ri)
+ }
+ if len(errs) > 0 {
+ panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
+ }
+ return ris
}
-func (g *Group) ServeDir(path, root string) {
- g.echo.ServeDir(path, root)
+// Group creates a new sub-group with prefix and optional sub-group-level middleware.
+// Important! Group middlewares are only executed in case there was exact route match and not
+// for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add
+// a catch-all route `/*` for the group which handler returns always 404
+func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) {
+ m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware))
+ m = append(m, g.middleware...)
+ m = append(m, middleware...)
+ sg = g.echo.Group(g.prefix+prefix, m...)
+ return
}
-func (g *Group) ServeFile(path, file string) {
- g.echo.ServeFile(path, file)
+// Static implements `Echo#Static()` for sub-routes within the Group.
+func (g *Group) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo {
+ subFs := MustSubFS(g.echo.Filesystem, fsRoot)
+ return g.StaticFS(pathPrefix, subFs, middleware...)
+}
+
+// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group.
+//
+// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary
+// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths
+// including `assets/images` as their prefix.
+func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo {
+ return g.Add(
+ http.MethodGet,
+ pathPrefix+"*",
+ StaticDirectoryHandler(filesystem, false),
+ middleware...,
+ )
+}
+
+// FileFS implements `Echo#FileFS()` for sub-routes within the Group.
+func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo {
+ return g.GET(path, StaticFileHandler(file, filesystem), m...)
+}
+
+// File implements `Echo#File()` for sub-routes within the Group. Panics on error.
+func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo {
+ handler := func(c *Context) error {
+ return c.File(file)
+ }
+ return g.Add(http.MethodGet, path, handler, middleware...)
+}
+
+// RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group.
+//
+// Example: `g.RouteNotFound("/*", func(c *echo.Context) error { return c.NoContent(http.StatusNotFound) })`
+func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo {
+ return g.Add(RouteNotFound, path, h, m...)
+}
+
+// Add implements `Echo#Add()` for sub-routes within the Group. Panics on error.
+func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
+ ri, err := g.AddRoute(Route{
+ Method: method,
+ Path: path,
+ Handler: handler,
+ Middlewares: middleware,
+ })
+ if err != nil {
+ panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage
+ }
+ return ri
}
-func (g *Group) Group(prefix string, m ...Middleware) *Group {
- return g.echo.Group(prefix, m...)
+// AddRoute registers a new Routable with Router
+func (g *Group) AddRoute(route Route) (RouteInfo, error) {
+ // Combine middleware into a new slice to avoid accidentally passing the same slice for
+ // multiple routes, which would lead to later add() calls overwriting the
+ // middleware from earlier calls.
+ groupRoute := route.WithPrefix(g.prefix, append([]MiddlewareFunc{}, g.middleware...))
+ return g.echo.add(groupRoute)
}
diff --git a/group_test.go b/group_test.go
index f605993cc..7078b6497 100644
--- a/group_test.go
+++ b/group_test.go
@@ -1,21 +1,814 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package echo
-import "testing"
+import (
+ "io/fs"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) {
+ e := New()
+
+ called := false
+ mw := func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ called = true
+ return c.NoContent(http.StatusTeapot)
+ }
+ }
+ // even though group has middleware it will not be executed when there are no routes under that group
+ _ = e.Group("/group", mw)
+
+ status, body := request(http.MethodGet, "/group/nope", e)
+ assert.Equal(t, http.StatusNotFound, status)
+ assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
+
+ assert.False(t, called)
+}
+
+func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) {
+ e := New()
+
+ called := false
+ mw := func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ called = true
+ return c.NoContent(http.StatusTeapot)
+ }
+ }
+ // even though group has middleware and routes when we have no match on some route the middlewares for that
+ // group will not be executed
+ g := e.Group("/group", mw)
+ g.GET("/yes", handlerFunc)
+
+ status, body := request(http.MethodGet, "/group/nope", e)
+ assert.Equal(t, http.StatusNotFound, status)
+ assert.Equal(t, `{"message":"Not Found"}`+"\n", body)
+
+ assert.False(t, called)
+}
+
+func TestGroup_multiLevelGroup(t *testing.T) {
+ e := New()
+
+ api := e.Group("/api")
+ users := api.Group("/users")
+ users.GET("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ status, body := request(http.MethodGet, "/api/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
-func TestGroup(t *testing.T) {
- g := New().Group("/group")
+func TestGroupFile(t *testing.T) {
+ e := New()
+ g := e.Group("/group")
+ g.File("/walle", "_fixture/images/walle.png")
+ expectedData, err := os.ReadFile("_fixture/images/walle.png")
+ assert.Nil(t, err)
+ req := httptest.NewRequest(http.MethodGet, "/group/walle", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, expectedData, rec.Body.Bytes())
+}
+
+func TestGroupRouteMiddleware(t *testing.T) {
+ // Ensure middleware slices are not re-used
+ e := New()
+ g := e.Group("/group")
h := func(*Context) error { return nil }
- g.Connect("/", h)
- g.Delete("/", h)
- g.Get("/", h)
- g.Head("/", h)
- g.Options("/", h)
- g.Patch("/", h)
- g.Post("/", h)
- g.Put("/", h)
- g.Trace("/", h)
- g.WebSocket("/ws", h)
- g.Static("/scripts", "scripts")
- g.ServeDir("/scripts", "scripts")
- g.ServeFile("/scripts/main.js", "scripts/main.js")
+ m1 := func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ return next(c)
+ }
+ }
+ m2 := func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ return next(c)
+ }
+ }
+ m3 := func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ return next(c)
+ }
+ }
+ m4 := func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ return c.NoContent(404)
+ }
+ }
+ m5 := func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ return c.NoContent(405)
+ }
+ }
+ g.Use(m1, m2, m3)
+ g.GET("/404", h, m4)
+ g.GET("/405", h, m5)
+
+ c, _ := request(http.MethodGet, "/group/404", e)
+ assert.Equal(t, 404, c)
+ c, _ = request(http.MethodGet, "/group/405", e)
+ assert.Equal(t, 405, c)
+}
+
+func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) {
+ // Ensure middleware and match any routes do not conflict
+ e := New()
+ g := e.Group("/group")
+ m1 := func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ return next(c)
+ }
+ }
+ m2 := func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ return c.String(http.StatusOK, c.RouteInfo().Path)
+ }
+ }
+ h := func(c *Context) error {
+ return c.String(http.StatusOK, c.RouteInfo().Path)
+ }
+ g.Use(m1)
+ g.GET("/help", h, m2)
+ g.GET("/*", h, m2)
+ g.GET("", h, m2)
+ e.GET("unrelated", h, m2)
+ e.GET("*", h, m2)
+
+ _, m := request(http.MethodGet, "/group/help", e)
+ assert.Equal(t, "/group/help", m)
+ _, m = request(http.MethodGet, "/group/help/other", e)
+ assert.Equal(t, "/group/*", m)
+ _, m = request(http.MethodGet, "/group/404", e)
+ assert.Equal(t, "/group/*", m)
+ _, m = request(http.MethodGet, "/group", e)
+ assert.Equal(t, "/group", m)
+ _, m = request(http.MethodGet, "/other", e)
+ assert.Equal(t, "/*", m)
+ _, m = request(http.MethodGet, "/", e)
+ assert.Equal(t, "/*", m)
+
+}
+
+func TestGroup_CONNECT(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.CONNECT("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodConnect, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodConnect, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_DELETE(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.DELETE("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodDelete, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodDelete, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_HEAD(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.HEAD("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodHead, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, http.MethodHead+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodHead, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_OPTIONS(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.OPTIONS("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodOptions, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodOptions, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_PATCH(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.PATCH("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPatch, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodPatch, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_POST(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.POST("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPost, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, http.MethodPost+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodPost, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_PUT(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.PUT("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodPut, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, http.MethodPut+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodPut, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_TRACE(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.TRACE("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ assert.Equal(t, http.MethodTrace, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodTrace, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+}
+
+func TestGroup_RouteNotFound(t *testing.T) {
+ var testCases = []struct {
+ expectRoute any
+ name string
+ whenURL string
+ expectCode int
+ }{
+ {
+ name: "404, route to static not found handler /group/a/c/xx",
+ whenURL: "/group/a/c/xx",
+ expectRoute: "GET /group/a/c/xx",
+ expectCode: http.StatusNotFound,
+ },
+ {
+ name: "404, route to path param not found handler /group/a/:file",
+ whenURL: "/group/a/echo.exe",
+ expectRoute: "GET /group/a/:file",
+ expectCode: http.StatusNotFound,
+ },
+ {
+ name: "404, route to any not found handler /group/*",
+ whenURL: "/group/b/echo.exe",
+ expectRoute: "GET /group/*",
+ expectCode: http.StatusNotFound,
+ },
+ {
+ name: "200, route /group/a/c/df to /group/a/c/df",
+ whenURL: "/group/a/c/df",
+ expectRoute: "GET /group/a/c/df",
+ expectCode: http.StatusOK,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ g := e.Group("/group")
+
+ okHandler := func(c *Context) error {
+ return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
+ }
+ notFoundHandler := func(c *Context) error {
+ return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path())
+ }
+
+ g.GET("/", okHandler)
+ g.GET("/a/c/df", okHandler)
+ g.GET("/a/b*", okHandler)
+ g.PUT("/*", okHandler)
+
+ g.RouteNotFound("/a/c/xx", notFoundHandler) // static
+ g.RouteNotFound("/a/:file", notFoundHandler) // param
+ g.RouteNotFound("/*", notFoundHandler) // any
+
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectCode, rec.Code)
+ assert.Equal(t, tc.expectRoute, rec.Body.String())
+ })
+ }
+}
+
+func TestGroup_Any(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ ri := users.Any("/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK from ANY")
+ })
+
+ assert.Equal(t, RouteAny, ri.Method)
+ assert.Equal(t, "/users/activate", ri.Path)
+ assert.Equal(t, RouteAny+":/users/activate", ri.Name)
+ assert.Nil(t, ri.Parameters)
+
+ status, body := request(http.MethodTrace, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK from ANY`, body)
+}
+
+func TestGroup_Match(t *testing.T) {
+ e := New()
+
+ myMethods := []string{http.MethodGet, http.MethodPost}
+ users := e.Group("/users")
+ ris := users.Match(myMethods, "/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+ assert.Len(t, ris, 2)
+
+ for _, m := range myMethods {
+ status, body := request(m, "/users/activate", e)
+ assert.Equal(t, http.StatusTeapot, status)
+ assert.Equal(t, `OK`, body)
+ }
+}
+
+func TestGroup_MatchWithErrors(t *testing.T) {
+ e := New()
+
+ users := e.Group("/users")
+ users.GET("/activate", func(c *Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+ myMethods := []string{http.MethodGet, http.MethodPost}
+
+ errs := func() (errs []error) {
+ defer func() {
+ if r := recover(); r != nil {
+ if tmpErr, ok := r.([]error); ok {
+ errs = tmpErr
+ return
+ }
+ panic(r)
+ }
+ }()
+
+ users.Match(myMethods, "/activate", func(c *Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+ return nil
+ }()
+ assert.Len(t, errs, 1)
+ assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed")
+
+ for _, m := range myMethods {
+ status, body := request(m, "/users/activate", e)
+
+ expect := http.StatusTeapot
+ if m == http.MethodGet {
+ expect = http.StatusOK
+ }
+ assert.Equal(t, expect, status)
+ assert.Equal(t, `OK`, body)
+ }
+}
+
+func TestGroup_Static(t *testing.T) {
+ e := New()
+
+ g := e.Group("/books")
+ ri := g.Static("/download", "_fixture")
+ assert.Equal(t, http.MethodGet, ri.Method)
+ assert.Equal(t, "/books/download*", ri.Path)
+ assert.Equal(t, "GET:/books/download*", ri.Name)
+ assert.Equal(t, []string{"*"}, ri.Parameters)
+
+ req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+ body := rec.Body.String()
+ assert.True(t, strings.HasPrefix(body, ""))
+}
+
+func TestGroup_StaticMultiTest(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenPrefix string
+ givenRoot string
+ whenURL string
+ expectHeaderLocation string
+ expectBodyStartsWith string
+ expectBodyNotContains string
+ expectStatus int
+ }{
+ {
+ name: "ok",
+ givenPrefix: "/images",
+ givenRoot: "_fixture/images",
+ whenURL: "/test/images/walle.png",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
+ },
+ {
+ name: "ok, without prefix",
+ givenPrefix: "",
+ givenRoot: "_fixture/images",
+ whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png`
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}),
+ },
+ {
+ name: "nok, without prefix does not serve dir index",
+ givenPrefix: "",
+ givenRoot: "_fixture/images",
+ whenURL: "/test/", // `/test` + `*` creates route `/test*`
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "No file",
+ givenPrefix: "/images",
+ givenRoot: "_fixture/scripts",
+ whenURL: "/test/images/bolt.png",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Directory",
+ givenPrefix: "/images",
+ givenRoot: "_fixture/images",
+ whenURL: "/test/images/",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Directory Redirect",
+ givenPrefix: "/",
+ givenRoot: "_fixture",
+ whenURL: "/test/folder",
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/test/folder/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Directory Redirect with non-root path",
+ givenPrefix: "/static",
+ givenRoot: "_fixture",
+ whenURL: "/test/static",
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/test/static/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory 404 (request URL without slash)",
+ givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder"
+ givenRoot: "_fixture",
+ whenURL: "/test/folder", // no trailing slash
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "Prefixed directory redirect (without slash redirect to slash)",
+ givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/*
+ givenRoot: "_fixture",
+ whenURL: "/test/folder", // no trailing slash
+ expectStatus: http.StatusMovedPermanently,
+ expectHeaderLocation: "/test/folder/",
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Directory with index.html",
+ givenPrefix: "/",
+ givenRoot: "_fixture",
+ whenURL: "/test/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory with index.html (prefix ending with slash)",
+ givenPrefix: "/assets/",
+ givenRoot: "_fixture",
+ whenURL: "/test/assets/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Prefixed directory with index.html (prefix ending without slash)",
+ givenPrefix: "/assets",
+ givenRoot: "_fixture",
+ whenURL: "/test/assets/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "Sub-directory with index.html",
+ givenPrefix: "/",
+ givenRoot: "_fixture",
+ whenURL: "/test/folder/",
+ expectStatus: http.StatusOK,
+ expectBodyStartsWith: "",
+ },
+ {
+ name: "nok, URL encoded path traversal (single encoding, slash - unix separator)",
+ givenRoot: "_fixture/dist/public",
+ whenURL: "/%2e%2e%2fprivate.txt",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ expectBodyNotContains: `private file`,
+ },
+ {
+ name: "nok, URL encoded path traversal (single encoding, backslash - windows separator)",
+ givenRoot: "_fixture/dist/public",
+ whenURL: "/%2e%2e%5cprivate.txt",
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ expectBodyNotContains: `private file`,
+ },
+ {
+ name: "do not allow directory traversal (backslash - windows separator)",
+ givenPrefix: "/",
+ givenRoot: "_fixture/",
+ whenURL: `/test/..\\middleware/basic_auth.go`,
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ {
+ name: "do not allow directory traversal (slash - unix separator)",
+ givenPrefix: "/",
+ givenRoot: "_fixture/",
+ whenURL: `/test/../middleware/basic_auth.go`,
+ expectStatus: http.StatusNotFound,
+ expectBodyStartsWith: "{\"message\":\"Not Found\"}\n",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+
+ g := e.Group("/test")
+ g.Static(tc.givenPrefix, tc.givenRoot)
+
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ body := rec.Body.String()
+ if tc.expectBodyStartsWith != "" {
+ assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith))
+ } else {
+ assert.Equal(t, "", body)
+ }
+ if tc.expectBodyNotContains != "" {
+ assert.NotContains(t, body, tc.expectBodyNotContains)
+ }
+
+ if tc.expectHeaderLocation != "" {
+ assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0])
+ } else {
+ _, ok := rec.Result().Header["Location"]
+ assert.False(t, ok)
+ }
+ })
+ }
+}
+
+func TestGroup_FileFS(t *testing.T) {
+ var testCases = []struct {
+ whenFS fs.FS
+ name string
+ whenPath string
+ whenFile string
+ givenURL string
+ expectStartsWith []byte
+ expectCode int
+ }{
+ {
+ name: "ok",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "walle.png",
+ givenURL: "/assets/walle",
+ expectCode: http.StatusOK,
+ expectStartsWith: []byte{0x89, 0x50, 0x4e},
+ },
+ {
+ name: "nok, requesting invalid path",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "walle.png",
+ givenURL: "/assets/walle.png",
+ expectCode: http.StatusNotFound,
+ expectStartsWith: []byte(`{"message":"Not Found"}`),
+ },
+ {
+ name: "nok, serving not existent file from filesystem",
+ whenPath: "/walle",
+ whenFS: os.DirFS("_fixture/images"),
+ whenFile: "not-existent.png",
+ givenURL: "/assets/walle",
+ expectCode: http.StatusNotFound,
+ expectStartsWith: []byte(`{"message":"Not Found"}`),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ g := e.Group("/assets")
+ g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS)
+
+ req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectCode, rec.Code)
+
+ body := rec.Body.Bytes()
+ if len(body) > len(tc.expectStartsWith) {
+ body = body[:len(tc.expectStartsWith)]
+ }
+ assert.Equal(t, tc.expectStartsWith, body)
+ })
+ }
+}
+
+func TestGroup_StaticPanic(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenRoot string
+ }{
+ {
+ name: "panics for ../",
+ givenRoot: "../images",
+ },
+ {
+ name: "panics for /",
+ givenRoot: "/images",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := New()
+ e.Filesystem = os.DirFS("./")
+
+ g := e.Group("/assets")
+
+ assert.Panics(t, func() {
+ g.Static("/images", tc.givenRoot)
+ })
+ })
+ }
+}
+
+func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) {
+ var testCases = []struct {
+ expectBody any
+ name string
+ whenURL string
+ expectCode int
+ givenCustom404 bool
+ expectMiddlewareCalled bool
+ }{
+ {
+ name: "ok, custom 404 handler is called with middleware",
+ givenCustom404: true,
+ whenURL: "/group/test3",
+ expectBody: "404 GET /group/*",
+ expectCode: http.StatusNotFound,
+ expectMiddlewareCalled: true, // because RouteNotFound is added after middleware is added
+ },
+ {
+ name: "ok, default group 404 handler is not called with middleware",
+ givenCustom404: false,
+ whenURL: "/group/test3",
+ expectBody: "404 GET /*",
+ expectCode: http.StatusNotFound,
+ expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added
+ },
+ {
+ name: "ok, (no slash) default group 404 handler is called with middleware",
+ givenCustom404: false,
+ whenURL: "/group",
+ expectBody: "404 GET /*",
+ expectCode: http.StatusNotFound,
+ expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+
+ okHandler := func(c *Context) error {
+ return c.String(http.StatusOK, c.Request().Method+" "+c.Path())
+ }
+ notFoundHandler := func(c *Context) error {
+ return c.String(http.StatusNotFound, "404 "+c.Request().Method+" "+c.Path())
+ }
+
+ e := New()
+ e.GET("/test1", okHandler)
+ e.RouteNotFound("/*", notFoundHandler)
+
+ g := e.Group("/group")
+ g.GET("/test1", okHandler)
+
+ middlewareCalled := false
+ g.Use(func(next HandlerFunc) HandlerFunc {
+ return func(c *Context) error {
+ middlewareCalled = true
+ return next(c)
+ }
+ })
+ if tc.givenCustom404 {
+ g.RouteNotFound("/*", notFoundHandler)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectMiddlewareCalled, middlewareCalled)
+ assert.Equal(t, tc.expectCode, rec.Code)
+ assert.Equal(t, tc.expectBody, rec.Body.String())
+ })
+ }
}
diff --git a/httperror.go b/httperror.go
new file mode 100644
index 000000000..6e14da3d9
--- /dev/null
+++ b/httperror.go
@@ -0,0 +1,117 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+)
+
+// HTTPStatusCoder is interface that errors can implement to produce status code for HTTP response
+type HTTPStatusCoder interface {
+ StatusCode() int
+}
+
+// StatusCode returns status code from error if it implements HTTPStatusCoder interface.
+// If error does not implement the interface it returns 0.
+func StatusCode(err error) int {
+ var sc HTTPStatusCoder
+ if errors.As(err, &sc) {
+ return sc.StatusCode()
+ }
+ return 0
+}
+
+// Following errors can produce HTTP status code by implementing HTTPStatusCoder interface
+var (
+ ErrBadRequest = &httpError{http.StatusBadRequest} // 400
+ ErrUnauthorized = &httpError{http.StatusUnauthorized} // 401
+ ErrForbidden = &httpError{http.StatusForbidden} // 403
+ ErrNotFound = &httpError{http.StatusNotFound} // 404
+ ErrMethodNotAllowed = &httpError{http.StatusMethodNotAllowed} // 405
+ ErrRequestTimeout = &httpError{http.StatusRequestTimeout} // 408
+ ErrStatusRequestEntityTooLarge = &httpError{http.StatusRequestEntityTooLarge} // 413
+ ErrUnsupportedMediaType = &httpError{http.StatusUnsupportedMediaType} // 415
+ ErrTooManyRequests = &httpError{http.StatusTooManyRequests} // 429
+ ErrInternalServerError = &httpError{http.StatusInternalServerError} // 500
+ ErrBadGateway = &httpError{http.StatusBadGateway} // 502
+ ErrServiceUnavailable = &httpError{http.StatusServiceUnavailable} // 503
+)
+
+// Following errors fall into 500 (InternalServerError) category
+var (
+ ErrValidatorNotRegistered = errors.New("validator not registered")
+ ErrRendererNotRegistered = errors.New("renderer not registered")
+ ErrInvalidRedirectCode = errors.New("invalid redirect status code")
+ ErrCookieNotFound = errors.New("cookie not found")
+ ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
+ ErrInvalidListenerNetwork = errors.New("invalid listener network")
+)
+
+// NewHTTPError creates new instance of HTTPError
+func NewHTTPError(code int, message string) *HTTPError {
+ return &HTTPError{
+ Code: code,
+ Message: message,
+ }
+}
+
+// HTTPError represents an error that occurred while handling a request.
+type HTTPError struct {
+ // Code is status code for HTTP response
+ Code int `json:"-"`
+ Message string `json:"message"`
+ err error
+}
+
+// StatusCode returns status code for HTTP response
+func (he *HTTPError) StatusCode() int {
+ return he.Code
+}
+
+// Error makes it compatible with `error` interface.
+func (he *HTTPError) Error() string {
+ msg := he.Message
+ if msg == "" {
+ msg = http.StatusText(he.Code)
+ }
+ if he.err == nil {
+ return fmt.Sprintf("code=%d, message=%v", he.Code, msg)
+ }
+ return fmt.Sprintf("code=%d, message=%v, err=%v", he.Code, msg, he.err.Error())
+}
+
+// Wrap eturns new HTTPError with given errors wrapped inside
+func (he HTTPError) Wrap(err error) error {
+ return &HTTPError{
+ Code: he.Code,
+ Message: he.Message,
+ err: err,
+ }
+}
+
+func (he *HTTPError) Unwrap() error {
+ return he.err
+}
+
+type httpError struct {
+ code int
+}
+
+func (he httpError) StatusCode() int {
+ return he.code
+}
+
+func (he httpError) Error() string {
+ return http.StatusText(he.code) // does not include status code
+}
+
+func (he httpError) Wrap(err error) error {
+ return &HTTPError{
+ Code: he.code,
+ Message: http.StatusText(he.code),
+ err: err,
+ }
+}
diff --git a/httperror_external_test.go b/httperror_external_test.go
new file mode 100644
index 000000000..91acdca25
--- /dev/null
+++ b/httperror_external_test.go
@@ -0,0 +1,52 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+// run tests as external package to get real feel for API
+package echo_test
+
+import (
+ "encoding/json"
+ "fmt"
+ "github.com/labstack/echo/v5"
+ "net/http"
+ "net/http/httptest"
+)
+
+func ExampleDefaultHTTPErrorHandler() {
+ e := echo.New()
+ e.GET("/api/endpoint", func(c *echo.Context) error {
+ return &apiError{
+ Code: http.StatusBadRequest,
+ Body: map[string]any{"message": "custom error"},
+ }
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/api/endpoint?err=1", nil)
+ resp := httptest.NewRecorder()
+
+ e.ServeHTTP(resp, req)
+
+ fmt.Printf("%d %s", resp.Code, resp.Body.String())
+
+ // Output: 400 {"error":{"message":"custom error"}}
+}
+
+type apiError struct {
+ Code int
+ Body any
+}
+
+func (e *apiError) StatusCode() int {
+ return e.Code
+}
+
+func (e *apiError) MarshalJSON() ([]byte, error) {
+ type body struct {
+ Error any `json:"error"`
+ }
+ return json.Marshal(body{Error: e.Body})
+}
+
+func (e *apiError) Error() string {
+ return http.StatusText(e.Code)
+}
diff --git a/httperror_test.go b/httperror_test.go
new file mode 100644
index 000000000..0a91bbc9c
--- /dev/null
+++ b/httperror_test.go
@@ -0,0 +1,109 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestHTTPError_StatusCode(t *testing.T) {
+ var err error = &HTTPError{Code: http.StatusBadRequest, Message: "my error message"}
+
+ code := 0
+ var sc HTTPStatusCoder
+ if errors.As(err, &sc) {
+ code = sc.StatusCode()
+ }
+ assert.Equal(t, http.StatusBadRequest, code)
+}
+
+func TestHTTPError_Error(t *testing.T) {
+ var testCases = []struct {
+ name string
+ error error
+ expect string
+ }{
+ {
+ name: "ok, without message",
+ error: &HTTPError{Code: http.StatusBadRequest},
+ expect: "code=400, message=Bad Request",
+ },
+ {
+ name: "ok, with message",
+ error: &HTTPError{Code: http.StatusBadRequest, Message: "my error message"},
+ expect: "code=400, message=my error message",
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ assert.Equal(t, tc.expect, tc.error.Error())
+ })
+ }
+}
+
+func TestHTTPError_WrapUnwrap(t *testing.T) {
+ err := &HTTPError{Code: http.StatusBadRequest, Message: "bad"}
+ wrapped := err.Wrap(errors.New("my_error")).(*HTTPError)
+
+ err.Code = http.StatusOK
+ err.Message = "changed"
+
+ assert.Equal(t, http.StatusBadRequest, wrapped.Code)
+ assert.Equal(t, "bad", wrapped.Message)
+
+ assert.Equal(t, errors.New("my_error"), wrapped.Unwrap())
+ assert.Equal(t, "code=400, message=bad, err=my_error", wrapped.Error())
+}
+
+func TestNewHTTPError(t *testing.T) {
+ err := NewHTTPError(http.StatusBadRequest, "bad")
+ err2 := &HTTPError{Code: http.StatusBadRequest, Message: "bad"}
+
+ assert.Equal(t, err2, err)
+}
+
+func TestStatusCode(t *testing.T) {
+ var testCases = []struct {
+ name string
+ err error
+ expect int
+ }{
+ {
+ name: "ok, HTTPError",
+ err: &HTTPError{Code: http.StatusNotFound},
+ expect: http.StatusNotFound,
+ },
+ {
+ name: "ok, sentinel error",
+ err: ErrNotFound,
+ expect: http.StatusNotFound,
+ },
+ {
+ name: "ok, wrapped HTTPError",
+ err: fmt.Errorf("wrapped: %w", &HTTPError{Code: http.StatusTeapot}),
+ expect: http.StatusTeapot,
+ },
+ {
+ name: "nok, normal error",
+ err: errors.New("error"),
+ expect: 0,
+ },
+ {
+ name: "nok, nil",
+ err: nil,
+ expect: 0,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ assert.Equal(t, tc.expect, StatusCode(tc.err))
+ })
+ }
+}
diff --git a/ip.go b/ip.go
new file mode 100644
index 000000000..e2b287bfd
--- /dev/null
+++ b/ip.go
@@ -0,0 +1,267 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "net"
+ "net/http"
+ "strings"
+)
+
+/**
+By: https://github.com/tmshn (See: https://github.com/labstack/echo/pull/1478 , https://github.com/labstack/echox/pull/134 )
+Source: https://echo.labstack.com/guide/ip-address/
+
+IP address plays fundamental role in HTTP; it's used for access control, auditing, geo-based access analysis and more.
+Echo provides handy method [`Context#RealIP()`](https://godoc.org/github.com/labstack/echo#Context) for that.
+
+However, it is not trivial to retrieve the _real_ IP address from requests especially when you put L7 proxies before the application.
+In such situation, _real_ IP needs to be relayed on HTTP layer from proxies to your app, but you must not trust HTTP headers unconditionally.
+Otherwise, you might give someone a chance of deceiving you. **A security risk!**
+
+To retrieve IP address reliably/securely, you must let your application be aware of the entire architecture of your infrastructure.
+In Echo, this can be done by configuring `Echo#IPExtractor` appropriately.
+This guides show you why and how.
+
+> Note: if you don't set `Echo#IPExtractor` explicitly, Echo fallback to legacy behavior, which is not a good choice.
+
+Let's start from two questions to know the right direction:
+
+1. Do you put any HTTP (L7) proxy in front of the application?
+ - It includes both cloud solutions (such as AWS ALB or GCP HTTP LB) and OSS ones (such as Nginx, Envoy or Istio ingress gateway).
+2. If yes, what HTTP header do your proxies use to pass client IP to the application?
+
+## Case 1. With no proxy
+
+If you put no proxy (e.g.: directory facing to the internet), all you need to (and have to) see is IP address from network layer.
+Any HTTP header is untrustable because the clients have full control what headers to be set.
+
+In this case, use `echo.ExtractIPDirect()`.
+
+```go
+e.IPExtractor = echo.ExtractIPDirect()
+```
+
+## Case 2. With proxies using `X-Forwarded-For` header
+
+[`X-Forwared-For` (XFF)](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For) is the popular header
+to relay clients' IP addresses.
+At each hop on the proxies, they append the request IP address at the end of the header.
+
+Following example diagram illustrates this behavior.
+
+```text
+┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐
+│ "Origin" │───────────>│ Proxy 1 │───────────>│ Proxy 2 │───────────>│ Your app │
+│ (IP: a) │ │ (IP: b) │ │ (IP: c) │ │ │
+└──────────┘ └──────────┘ └──────────┘ └──────────┘
+
+Case 1.
+XFF: "" "a" "a, b"
+ ~~~~~~
+Case 2.
+XFF: "x" "x, a" "x, a, b"
+ ~~~~~~~~~
+ ↑ What your app will see
+```
+
+In this case, use **first _untrustable_ IP reading from right**. Never use first one reading from left, as it is
+configurable by client. Here "trustable" means "you are sure the IP address belongs to your infrastructure".
+In above example, if `b` and `c` are trustable, the IP address of the client is `a` for both cases, never be `x`.
+
+In Echo, use `ExtractIPFromXFFHeader(...TrustOption)`.
+
+```go
+e.IPExtractor = echo.ExtractIPFromXFFHeader()
+```
+
+By default, it trusts internal IP addresses (loopback, link-local unicast, private-use and unique local address
+from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and
+[RFC4193](https://tools.ietf.org/html/rfc4193)).
+To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s.
+
+E.g.:
+
+```go
+e.IPExtractor = echo.ExtractIPFromXFFHeader(
+ TrustLinkLocal(false),
+ TrustIPRanges(lbIPRange),
+)
+```
+
+- Ref: https://godoc.org/github.com/labstack/echo#TrustOption
+
+## Case 3. With proxies using `X-Real-IP` header
+
+`X-Real-IP` is another HTTP header to relay clients' IP addresses, but it carries only one address unlike XFF.
+
+If your proxies set this header, use `ExtractIPFromRealIPHeader(...TrustOption)`.
+
+```go
+e.IPExtractor = echo.ExtractIPFromRealIPHeader()
+```
+
+Again, it trusts internal IP addresses by default (loopback, link-local unicast, private-use and unique local address
+from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and
+[RFC4193](https://tools.ietf.org/html/rfc4193)).
+To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s.
+
+- Ref: https://godoc.org/github.com/labstack/echo#TrustOption
+
+> **Never forget** to configure the outermost proxy (i.e.; at the edge of your infrastructure) **not to pass through incoming headers**.
+> Otherwise there is a chance of fraud, as it is what clients can control.
+
+## About default behavior
+
+In default behavior, Echo sees all of first XFF header, X-Real-IP header and IP from network layer.
+
+As you might already notice, after reading this article, this is not good.
+Sole reason this is default is just backward compatibility.
+
+## Private IP ranges
+
+See: https://en.wikipedia.org/wiki/Private_network
+
+Private IPv4 address ranges (RFC 1918):
+* 10.0.0.0 – 10.255.255.255 (24-bit block)
+* 172.16.0.0 – 172.31.255.255 (20-bit block)
+* 192.168.0.0 – 192.168.255.255 (16-bit block)
+
+Private IPv6 address ranges:
+* fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA)
+
+*/
+
+type ipChecker struct {
+ trustExtraRanges []*net.IPNet
+ trustLoopback bool
+ trustLinkLocal bool
+ trustPrivateNet bool
+}
+
+// TrustOption is config for which IP address to trust
+type TrustOption func(*ipChecker)
+
+// TrustLoopback configures if you trust loopback address (default: true).
+func TrustLoopback(v bool) TrustOption {
+ return func(c *ipChecker) {
+ c.trustLoopback = v
+ }
+}
+
+// TrustLinkLocal configures if you trust link-local address (default: true).
+func TrustLinkLocal(v bool) TrustOption {
+ return func(c *ipChecker) {
+ c.trustLinkLocal = v
+ }
+}
+
+// TrustPrivateNet configures if you trust private network address (default: true).
+func TrustPrivateNet(v bool) TrustOption {
+ return func(c *ipChecker) {
+ c.trustPrivateNet = v
+ }
+}
+
+// TrustIPRange add trustable IP ranges using CIDR notation.
+func TrustIPRange(ipRange *net.IPNet) TrustOption {
+ return func(c *ipChecker) {
+ c.trustExtraRanges = append(c.trustExtraRanges, ipRange)
+ }
+}
+
+func newIPChecker(configs []TrustOption) *ipChecker {
+ checker := &ipChecker{trustLoopback: true, trustLinkLocal: true, trustPrivateNet: true}
+ for _, configure := range configs {
+ configure(checker)
+ }
+ return checker
+}
+
+func (c *ipChecker) trust(ip net.IP) bool {
+ if c.trustLoopback && ip.IsLoopback() {
+ return true
+ }
+ if c.trustLinkLocal && ip.IsLinkLocalUnicast() {
+ return true
+ }
+ if c.trustPrivateNet && ip.IsPrivate() {
+ return true
+ }
+ for _, trustedRange := range c.trustExtraRanges {
+ if trustedRange.Contains(ip) {
+ return true
+ }
+ }
+ return false
+}
+
+// IPExtractor is a function to extract IP addr from http.Request.
+// Set appropriate one to Echo#IPExtractor.
+// See https://echo.labstack.com/guide/ip-address for more details.
+type IPExtractor func(*http.Request) string
+
+// ExtractIPDirect extracts IP address using actual IP address.
+// Use this if your server faces to internet directory (i.e.: uses no proxy).
+func ExtractIPDirect() IPExtractor {
+ return extractIP
+}
+
+func extractIP(req *http.Request) string {
+ host, _, err := net.SplitHostPort(req.RemoteAddr)
+ if err != nil {
+ if net.ParseIP(req.RemoteAddr) != nil {
+ return req.RemoteAddr
+ }
+ return ""
+ }
+ return host
+}
+
+// ExtractIPFromRealIPHeader extracts IP address using x-real-ip header.
+// Use this if you put proxy which uses this header.
+func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor {
+ checker := newIPChecker(options)
+ return func(req *http.Request) string {
+ realIP := req.Header.Get(HeaderXRealIP)
+ if realIP != "" {
+ realIP = strings.TrimPrefix(realIP, "[")
+ realIP = strings.TrimSuffix(realIP, "]")
+ if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) {
+ return realIP
+ }
+ }
+ return extractIP(req)
+ }
+}
+
+// ExtractIPFromXFFHeader extracts IP address using x-forwarded-for header.
+// Use this if you put proxy which uses this header.
+// This returns nearest untrustable IP. If all IPs are trustable, returns furthest one (i.e.: XFF[0]).
+func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor {
+ checker := newIPChecker(options)
+ return func(req *http.Request) string {
+ directIP := extractIP(req)
+ xffs := req.Header[HeaderXForwardedFor]
+ if len(xffs) == 0 {
+ return directIP
+ }
+ ips := append(strings.Split(strings.Join(xffs, ","), ","), directIP)
+ for i := len(ips) - 1; i >= 0; i-- {
+ ips[i] = strings.TrimSpace(ips[i])
+ ips[i] = strings.TrimPrefix(ips[i], "[")
+ ips[i] = strings.TrimSuffix(ips[i], "]")
+ ip := net.ParseIP(ips[i])
+ if ip == nil {
+ // Unable to parse IP; cannot trust entire records
+ return directIP
+ }
+ if !checker.trust(ip) {
+ return ip.String()
+ }
+ }
+ // All of the IPs are trusted; return first element because it is furthest from server (best effort strategy).
+ return strings.TrimSpace(ips[0])
+ }
+}
diff --git a/ip_test.go b/ip_test.go
new file mode 100644
index 000000000..29bf6afde
--- /dev/null
+++ b/ip_test.go
@@ -0,0 +1,716 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "net"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func mustParseCIDR(s string) *net.IPNet {
+ _, IPNet, err := net.ParseCIDR(s)
+ if err != nil {
+ panic(err)
+ }
+ return IPNet
+}
+
+func TestIPChecker_TrustOption(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenIP string
+ givenOptions []TrustOption
+ expect bool
+ }{
+ {
+ name: "ip is within trust range, trusts additional private IPV6 network",
+ givenOptions: []TrustOption{
+ TrustLoopback(false),
+ TrustLinkLocal(false),
+ TrustPrivateNet(false),
+ // this is private IPv6 ip
+ // CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48
+ // Address: 2001:0db8:0000:0000:0000:0000:0000:0103
+ // Range start: 2001:0db8:0000:0000:0000:0000:0000:0000
+ // Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff
+ TrustIPRange(mustParseCIDR("2001:db8::103/48")),
+ },
+ whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103",
+ expect: true,
+ },
+ {
+ name: "ip is within trust range, trusts additional private IPV6 network",
+ givenOptions: []TrustOption{
+ TrustIPRange(mustParseCIDR("2001:db8::103/48")),
+ },
+ whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103",
+ expect: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ checker := newIPChecker(tc.givenOptions)
+
+ result := checker.trust(net.ParseIP(tc.whenIP))
+ assert.Equal(t, tc.expect, result)
+ })
+ }
+}
+
+func TestTrustIPRange(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenRange string
+ whenIP string
+ expect bool
+ }{
+ {
+ name: "ip is within trust range, IPV6 network range",
+ // CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48
+ // Address: 2001:0db8:0000:0000:0000:0000:0000:0103
+ // Range start: 2001:0db8:0000:0000:0000:0000:0000:0000
+ // Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff
+ givenRange: "2001:db8::103/48",
+ whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103",
+ expect: true,
+ },
+ {
+ name: "ip is outside (upper bounds) of trust range, IPV6 network range",
+ givenRange: "2001:db8::103/48",
+ whenIP: "2001:0db8:0001:0000:0000:0000:0000:0000",
+ expect: false,
+ },
+ {
+ name: "ip is outside (lower bounds) of trust range, IPV6 network range",
+ givenRange: "2001:db8::103/48",
+ whenIP: "2001:0db7:ffff:ffff:ffff:ffff:ffff:ffff",
+ expect: false,
+ },
+ {
+ name: "ip is within trust range, IPV4 network range",
+ // CIDR Notation: 8.8.8.8/24
+ // Address: 8.8.8.8
+ // Range start: 8.8.8.0
+ // Range end: 8.8.8.255
+ givenRange: "8.8.8.0/24",
+ whenIP: "8.8.8.8",
+ expect: true,
+ },
+ {
+ name: "ip is within trust range, IPV4 network range",
+ // CIDR Notation: 8.8.8.8/24
+ // Address: 8.8.8.8
+ // Range start: 8.8.8.0
+ // Range end: 8.8.8.255
+ givenRange: "8.8.8.0/24",
+ whenIP: "8.8.8.8",
+ expect: true,
+ },
+ {
+ name: "ip is outside (upper bounds) of trust range, IPV4 network range",
+ givenRange: "8.8.8.0/24",
+ whenIP: "8.8.9.0",
+ expect: false,
+ },
+ {
+ name: "ip is outside (lower bounds) of trust range, IPV4 network range",
+ givenRange: "8.8.8.0/24",
+ whenIP: "8.8.7.255",
+ expect: false,
+ },
+ {
+ name: "public ip, trust everything in IPV4 network range",
+ givenRange: "0.0.0.0/0",
+ whenIP: "8.8.8.8",
+ expect: true,
+ },
+ {
+ name: "internal ip, trust everything in IPV4 network range",
+ givenRange: "0.0.0.0/0",
+ whenIP: "127.0.10.1",
+ expect: true,
+ },
+ {
+ name: "public ip, trust everything in IPV6 network range",
+ givenRange: "::/0",
+ whenIP: "2a00:1450:4026:805::200e",
+ expect: true,
+ },
+ {
+ name: "internal ip, trust everything in IPV6 network range",
+ givenRange: "::/0",
+ whenIP: "0:0:0:0:0:0:0:1",
+ expect: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ cidr := mustParseCIDR(tc.givenRange)
+ checker := newIPChecker([]TrustOption{
+ TrustLoopback(false), // disable to avoid interference
+ TrustLinkLocal(false), // disable to avoid interference
+ TrustPrivateNet(false), // disable to avoid interference
+
+ TrustIPRange(cidr),
+ })
+
+ result := checker.trust(net.ParseIP(tc.whenIP))
+ assert.Equal(t, tc.expect, result)
+ })
+ }
+}
+
+func TestTrustPrivateNet(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenIP string
+ expect bool
+ }{
+ {
+ name: "do not trust public IPv4 address",
+ whenIP: "8.8.8.8",
+ expect: false,
+ },
+ {
+ name: "do not trust public IPv6 address",
+ whenIP: "2a00:1450:4026:805::200e",
+ expect: false,
+ },
+
+ { // Class A: 10.0.0.0 — 10.255.255.255
+ name: "do not trust IPv4 just outside of class A (lower bounds)",
+ whenIP: "9.255.255.255",
+ expect: false,
+ },
+ {
+ name: "do not trust IPv4 just outside of class A (upper bounds)",
+ whenIP: "11.0.0.0",
+ expect: false,
+ },
+ {
+ name: "trust IPv4 of class A (lower bounds)",
+ whenIP: "10.0.0.0",
+ expect: true,
+ },
+ {
+ name: "trust IPv4 of class A (upper bounds)",
+ whenIP: "10.255.255.255",
+ expect: true,
+ },
+
+ { // Class B: 172.16.0.0 — 172.31.255.255
+ name: "do not trust IPv4 just outside of class B (lower bounds)",
+ whenIP: "172.15.255.255",
+ expect: false,
+ },
+ {
+ name: "do not trust IPv4 just outside of class B (upper bounds)",
+ whenIP: "172.32.0.0",
+ expect: false,
+ },
+ {
+ name: "trust IPv4 of class B (lower bounds)",
+ whenIP: "172.16.0.0",
+ expect: true,
+ },
+ {
+ name: "trust IPv4 of class B (upper bounds)",
+ whenIP: "172.31.255.255",
+ expect: true,
+ },
+
+ { // Class C: 192.168.0.0 — 192.168.255.255
+ name: "do not trust IPv4 just outside of class C (lower bounds)",
+ whenIP: "192.167.255.255",
+ expect: false,
+ },
+ {
+ name: "do not trust IPv4 just outside of class C (upper bounds)",
+ whenIP: "192.169.0.0",
+ expect: false,
+ },
+ {
+ name: "trust IPv4 of class C (lower bounds)",
+ whenIP: "192.168.0.0",
+ expect: true,
+ },
+ {
+ name: "trust IPv4 of class C (upper bounds)",
+ whenIP: "192.168.255.255",
+ expect: true,
+ },
+
+ { // fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA)
+ // splits the address block in two equally sized halves, fc00::/8 and fd00::/8.
+ // https://en.wikipedia.org/wiki/Unique_local_address
+ name: "trust IPv6 private address",
+ whenIP: "fdfc:3514:2cb3:4bd5::",
+ expect: true,
+ },
+ {
+ name: "do not trust IPv6 just out of /fd (upper bounds)",
+ whenIP: "/fe00:0000:0000:0000:0000",
+ expect: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ checker := newIPChecker([]TrustOption{
+ TrustLoopback(false), // disable to avoid interference
+ TrustLinkLocal(false), // disable to avoid interference
+
+ TrustPrivateNet(true),
+ })
+
+ result := checker.trust(net.ParseIP(tc.whenIP))
+ assert.Equal(t, tc.expect, result)
+ })
+ }
+}
+
+func TestTrustLinkLocal(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenIP string
+ expect bool
+ }{
+ {
+ name: "trust link local IPv4 address (lower bounds)",
+ whenIP: "169.254.0.0",
+ expect: true,
+ },
+ {
+ name: "trust link local IPv4 address (upper bounds)",
+ whenIP: "169.254.255.255",
+ expect: true,
+ },
+ {
+ name: "do not trust link local IPv4 address (outside of lower bounds)",
+ whenIP: "169.253.255.255",
+ expect: false,
+ },
+ {
+ name: "do not trust link local IPv4 address (outside of upper bounds)",
+ whenIP: "169.255.0.0",
+ expect: false,
+ },
+ {
+ name: "trust link local IPv6 address ",
+ whenIP: "fe80::1",
+ expect: true,
+ },
+ {
+ name: "do not trust link local IPv6 address ",
+ whenIP: "fec0::1",
+ expect: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ checker := newIPChecker([]TrustOption{
+ TrustLoopback(false), // disable to avoid interference
+ TrustPrivateNet(false), // disable to avoid interference
+
+ TrustLinkLocal(true),
+ })
+
+ result := checker.trust(net.ParseIP(tc.whenIP))
+ assert.Equal(t, tc.expect, result)
+ })
+ }
+}
+
+func TestTrustLoopback(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenIP string
+ expect bool
+ }{
+ {
+ name: "trust IPv4 as localhost",
+ whenIP: "127.0.0.1",
+ expect: true,
+ },
+ {
+ name: "trust IPv6 as localhost",
+ whenIP: "::1",
+ expect: true,
+ },
+ {
+ name: "do not trust public ip as localhost",
+ whenIP: "8.8.8.8",
+ expect: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ checker := newIPChecker([]TrustOption{
+ TrustLinkLocal(false), // disable to avoid interference
+ TrustPrivateNet(false), // disable to avoid interference
+
+ TrustLoopback(true),
+ })
+
+ result := checker.trust(net.ParseIP(tc.whenIP))
+ assert.Equal(t, tc.expect, result)
+ })
+ }
+}
+
+func TestExtractIPDirect(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenRequest http.Request
+ expectIP string
+ }{
+ {
+ name: "request has no headers, extracts IP from request remote addr",
+ whenRequest: http.Request{
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "remote addr is IP without port, extracts IP directly",
+ whenRequest: http.Request{
+ RemoteAddr: "203.0.113.1",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "remote addr is IPv6 without port, extracts IP directly",
+ whenRequest: http.Request{
+ RemoteAddr: "2001:db8::1",
+ },
+ expectIP: "2001:db8::1",
+ },
+ {
+ name: "remote addr is IPv6 with port",
+ whenRequest: http.Request{
+ RemoteAddr: "[2001:db8::1]:8080",
+ },
+ expectIP: "2001:db8::1",
+ },
+ {
+ name: "remote addr is invalid, returns empty string",
+ whenRequest: http.Request{
+ RemoteAddr: "invalid-ip-format",
+ },
+ expectIP: "",
+ },
+ {
+ name: "request is from external IP has X-Real-Ip header, extractor still extracts IP from request remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"203.0.113.10"},
+ },
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "request is from internal IP and has Real-IP header, extractor still extracts internal IP from request remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"203.0.113.10"},
+ },
+ RemoteAddr: "127.0.0.1:8080",
+ },
+ expectIP: "127.0.0.1",
+ },
+ {
+ name: "request is from external IP and has XFF + Real-IP header, extractor still extracts external IP from request remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"203.0.113.10"},
+ HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
+ },
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "request is from internal IP and has XFF + Real-IP header, extractor still extracts internal IP from request remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"127.0.0.1"},
+ HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
+ },
+ RemoteAddr: "127.0.0.1:8080",
+ },
+ expectIP: "127.0.0.1",
+ },
+ {
+ name: "request is from external IP and has XFF header, extractor still extracts external IP from request remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
+ },
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "request is from internal IP and has XFF header, extractor still extracts internal IP from request remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"},
+ },
+ RemoteAddr: "127.0.0.1:8080",
+ },
+ expectIP: "127.0.0.1",
+ },
+ {
+ name: "request is from internal IP and has INVALID XFF header, extractor still extracts internal IP from request remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"this.is.broken.lol, 169.254.0.101"},
+ },
+ RemoteAddr: "127.0.0.1:8080",
+ },
+ expectIP: "127.0.0.1",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ extractedIP := ExtractIPDirect()(&tc.whenRequest)
+ assert.Equal(t, tc.expectIP, extractedIP)
+ })
+ }
+}
+
+func TestExtractIPFromRealIPHeader(t *testing.T) {
+ _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24")
+ _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64")
+
+ var testCases = []struct {
+ whenRequest http.Request
+ name string
+ expectIP string
+ givenTrustOptions []TrustOption
+ }{
+ {
+ name: "request has no headers, extracts IP from request remote addr",
+ whenRequest: http.Request{
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "request is from external IP has INVALID external X-Real-Ip header, extract IP from remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"xxx.yyy.zzz.ccc"}, // <-- this is invalid
+ },
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"203.0.113.199"}, // <-- this is untrusted
+ },
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"[2001:db8::113:199]"}, // <-- this is untrusted
+ },
+ RemoteAddr: "[2001:db8::113:1]:8080",
+ },
+ expectIP: "2001:db8::113:1",
+ },
+ {
+ name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
+ givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
+ TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
+ },
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"203.0.113.199"},
+ },
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.199",
+ },
+ {
+ name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
+ givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
+ TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64"
+ },
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"[2001:db8::113:199]"},
+ },
+ RemoteAddr: "[2001:db8::113:1]:8080",
+ },
+ expectIP: "2001:db8::113:199",
+ },
+ {
+ name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
+ givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
+ TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
+ },
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"203.0.113.199"},
+ HeaderXForwardedFor: []string{"203.0.113.198, 203.0.113.197"}, // <-- should not affect anything
+ },
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.199",
+ },
+ {
+ name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header",
+ givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy"
+ TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64"
+ },
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXRealIP: []string{"[2001:db8::113:199]"},
+ HeaderXForwardedFor: []string{"[2001:db8::113:198], [2001:db8::113:197]"}, // <-- should not affect anything
+ },
+ RemoteAddr: "[2001:db8::113:1]:8080",
+ },
+ expectIP: "2001:db8::113:199",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ extractedIP := ExtractIPFromRealIPHeader(tc.givenTrustOptions...)(&tc.whenRequest)
+ assert.Equal(t, tc.expectIP, extractedIP)
+ })
+ }
+}
+
+func TestExtractIPFromXFFHeader(t *testing.T) {
+ _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24")
+ _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64")
+
+ var testCases = []struct {
+ whenRequest http.Request
+ name string
+ expectIP string
+ givenTrustOptions []TrustOption
+ }{
+ {
+ name: "request has no headers, extracts IP from request remote addr",
+ whenRequest: http.Request{
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "request has INVALID external XFF header, extract IP from remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"xxx.yyy.zzz.ccc, 127.0.0.2"}, // <-- this is invalid
+ },
+ RemoteAddr: "127.0.0.1:8080",
+ },
+ expectIP: "127.0.0.1",
+ },
+ {
+ name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"127.0.0.3, 127.0.0.2, 127.0.0.1"},
+ },
+ RemoteAddr: "127.0.0.1:8080",
+ },
+ expectIP: "127.0.0.3",
+ },
+ {
+ name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"[fe80::3], [fe80::2], [fe80::1]"},
+ },
+ RemoteAddr: "[fe80::1]:8080",
+ },
+ expectIP: "fe80::3",
+ },
+ {
+ name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"203.0.113.199"}, // <-- this is untrusted
+ },
+ RemoteAddr: "203.0.113.1:8080",
+ },
+ expectIP: "203.0.113.1",
+ },
+ {
+ name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr",
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"[2001:db8::1]"}, // <-- this is untrusted
+ },
+ RemoteAddr: "[2001:db8::2]:8080",
+ },
+ expectIP: "2001:db8::2",
+ },
+ {
+ name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header",
+ givenTrustOptions: []TrustOption{
+ TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24"
+ },
+ // from request its seems that request has been proxied through 6 servers.
+ // 1) 203.0.1.100 (this is external IP set by 203.0.100.100 which we do not trust - could be spoofed)
+ // 2) 203.0.100.100 (this is outside of our network but set by 203.0.113.199 which we trust to set correct IPs)
+ // 3) 203.0.113.199 (we trust, for example maybe our proxy from some other office)
+ // 4) 192.168.1.100 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products)
+ // 5) 127.0.0.1 (is proxy on localhost. maybe we have Nginx in front of our Echo instance doing some routing)
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"203.0.1.100, 203.0.100.100, 203.0.113.199, 192.168.1.100"},
+ },
+ RemoteAddr: "127.0.0.1:8080", // IP of proxy upstream of our APP
+ },
+ expectIP: "203.0.100.100", // this is first trusted IP in XFF chain
+ },
+ {
+ name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header",
+ givenTrustOptions: []TrustOption{
+ TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64"
+ },
+ // from request its seems that request has been proxied through 6 servers.
+ // 1) 2001:db8:1::1:100 (this is external IP set by 2001:db8:2::100:100 which we do not trust - could be spoofed)
+ // 2) 2001:db8:2::100:100 (this is outside of our network but set by 2001:db8::113:199 which we trust to set correct IPs)
+ // 3) 2001:db8::113:199 (we trust, for example maybe our proxy from some other office)
+ // 4) fd12:3456:789a:1::1 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products)
+ // 5) fe80::1 (is proxy on localhost. maybe we have Nginx in front of our Echo instance doing some routing)
+ whenRequest: http.Request{
+ Header: http.Header{
+ HeaderXForwardedFor: []string{"[2001:db8:1::1:100], [2001:db8:2::100:100], [2001:db8::113:199], [fd12:3456:789a:1::1]"},
+ },
+ RemoteAddr: "[fe80::1]:8080", // IP of proxy upstream of our APP
+ },
+ expectIP: "2001:db8:2::100:100", // this is first trusted IP in XFF chain
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ extractedIP := ExtractIPFromXFFHeader(tc.givenTrustOptions...)(&tc.whenRequest)
+ assert.Equal(t, tc.expectIP, extractedIP)
+ })
+ }
+}
diff --git a/json.go b/json.go
new file mode 100644
index 000000000..a969ccb8c
--- /dev/null
+++ b/json.go
@@ -0,0 +1,29 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "encoding/json"
+)
+
+// DefaultJSONSerializer implements JSON encoding using encoding/json.
+type DefaultJSONSerializer struct{}
+
+// Serialize converts an interface into a json and writes it to the response.
+// You can optionally use the indent parameter to produce pretty JSONs.
+func (d DefaultJSONSerializer) Serialize(c *Context, target any, indent string) error {
+ enc := json.NewEncoder(c.Response())
+ if indent != "" {
+ enc.SetIndent("", indent)
+ }
+ return enc.Encode(target)
+}
+
+// Deserialize reads a JSON from a request body and converts it into an interface.
+func (d DefaultJSONSerializer) Deserialize(c *Context, target any) error {
+ if err := json.NewDecoder(c.Request().Body).Decode(target); err != nil {
+ return ErrBadRequest.Wrap(err)
+ }
+ return nil
+}
diff --git a/json_test.go b/json_test.go
new file mode 100644
index 000000000..1804b3e82
--- /dev/null
+++ b/json_test.go
@@ -0,0 +1,100 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package echo
+
+import (
+ "github.com/stretchr/testify/assert"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+// Note this test is deliberately simple as there's not a lot to test.
+// Just need to ensure it writes JSONs. The heavy work is done by the context methods.
+func TestDefaultJSONCodec_Encode(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ // Echo
+ assert.Equal(t, e, c.Echo())
+
+ // Request
+ assert.NotNil(t, c.Request())
+
+ // Response
+ assert.NotNil(t, c.Response())
+
+ //--------
+ // Default JSON encoder
+ //--------
+
+ enc := new(DefaultJSONSerializer)
+
+ err := enc.Serialize(c, user{ID: 1, Name: "Jon Snow"}, "")
+ if assert.NoError(t, err) {
+ assert.Equal(t, userJSON+"\n", rec.Body.String())
+ }
+
+ req = httptest.NewRequest(http.MethodPost, "/", nil)
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+ err = enc.Serialize(c, user{ID: 1, Name: "Jon Snow"}, " ")
+ if assert.NoError(t, err) {
+ assert.Equal(t, userJSONPretty+"\n", rec.Body.String())
+ }
+}
+
+// Note this test is deliberately simple as there's not a lot to test.
+// Just need to ensure it writes JSONs. The heavy work is done by the context methods.
+func TestDefaultJSONCodec_Decode(t *testing.T) {
+ e := New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ // Echo
+ assert.Equal(t, e, c.Echo())
+
+ // Request
+ assert.NotNil(t, c.Request())
+
+ // Response
+ assert.NotNil(t, c.Response())
+
+ //--------
+ // Default JSON encoder
+ //--------
+
+ enc := new(DefaultJSONSerializer)
+
+ var u = user{}
+ err := enc.Deserialize(c, &u)
+ if assert.NoError(t, err) {
+ assert.Equal(t, u, user{ID: 1, Name: "Jon Snow"})
+ }
+
+ var userUnmarshalSyntaxError = user{}
+ req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(invalidContent))
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+ err = enc.Deserialize(c, &userUnmarshalSyntaxError)
+ assert.IsType(t, &HTTPError{}, err)
+ assert.EqualError(t, err, "code=400, message=Bad Request, err=invalid character 'i' looking for beginning of value")
+
+ var userUnmarshalTypeError = struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ }{}
+
+ req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+ err = enc.Deserialize(c, &userUnmarshalTypeError)
+ assert.IsType(t, &HTTPError{}, err)
+ assert.EqualError(t, err, "code=400, message=Bad Request, err=json: cannot unmarshal number into Go struct field .id of type string")
+
+}
diff --git a/middleware/DEVELOPMENT.md b/middleware/DEVELOPMENT.md
new file mode 100644
index 000000000..77cb226dd
--- /dev/null
+++ b/middleware/DEVELOPMENT.md
@@ -0,0 +1,11 @@
+# Development Guidelines for middlewares
+
+## Best practices:
+
+* Do not use `panic` in middleware creator functions in case of invalid configuration.
+* In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead
+ because previous middlewares up in call chain could have logic for dealing with returned errors.
+* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they
+ want to create middleware with panics or with returning errors on configuration errors.
+* When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same.
+
diff --git a/middleware/auth.go b/middleware/auth.go
deleted file mode 100644
index 1dc5c4ea9..000000000
--- a/middleware/auth.go
+++ /dev/null
@@ -1,51 +0,0 @@
-package middleware
-
-import (
- "encoding/base64"
- "net/http"
-
- "github.com/labstack/echo"
-)
-
-type (
- BasicValidateFunc func(string, string) bool
-)
-
-const (
- Basic = "Basic"
-)
-
-// BasicAuth returns an HTTP basic authentication middleware.
-//
-// For valid credentials it calls the next handler.
-// For invalid Authorization header it sends "404 - Bad Request" response.
-// For invalid credentials, it sends "401 - Unauthorized" response.
-func BasicAuth(fn BasicValidateFunc) echo.HandlerFunc {
- return func(c *echo.Context) error {
- // Skip WebSocket
- if (c.Request().Header.Get(echo.Upgrade)) == echo.WebSocket {
- return nil
- }
-
- auth := c.Request().Header.Get(echo.Authorization)
- l := len(Basic)
- he := echo.NewHTTPError(http.StatusBadRequest)
-
- if len(auth) > l+1 && auth[:l] == Basic {
- b, err := base64.StdEncoding.DecodeString(auth[l+1:])
- if err == nil {
- cred := string(b)
- for i := 0; i < len(cred); i++ {
- if cred[i] == ':' {
- // Verify credentials
- if fn(cred[:i], cred[i+1:]) {
- return nil
- }
- he.SetCode(http.StatusUnauthorized)
- }
- }
- }
- }
- return he
- }
-}
diff --git a/middleware/auth_test.go b/middleware/auth_test.go
deleted file mode 100644
index c953d9278..000000000
--- a/middleware/auth_test.go
+++ /dev/null
@@ -1,54 +0,0 @@
-package middleware
-
-import (
- "encoding/base64"
- "net/http"
- "net/http/httptest"
- "testing"
-
- "github.com/labstack/echo"
- "github.com/stretchr/testify/assert"
-)
-
-func TestBasicAuth(t *testing.T) {
- req, _ := http.NewRequest(echo.GET, "/", nil)
- rec := httptest.NewRecorder()
- c := echo.NewContext(req, echo.NewResponse(rec), echo.New())
- fn := func(u, p string) bool {
- if u == "joe" && p == "secret" {
- return true
- }
- return false
- }
- ba := BasicAuth(fn)
-
- // Valid credentials
- auth := Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
- req.Header.Set(echo.Authorization, auth)
- assert.NoError(t, ba(c))
-
- //---------------------
- // Invalid credentials
- //---------------------
-
- // Incorrect password
- auth = Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
- req.Header.Set(echo.Authorization, auth)
- he := ba(c).(*echo.HTTPError)
- assert.Equal(t, http.StatusUnauthorized, he.Code())
-
- // Empty Authorization header
- req.Header.Set(echo.Authorization, "")
- he = ba(c).(*echo.HTTPError)
- assert.Equal(t, http.StatusBadRequest, he.Code())
-
- // Invalid Authorization header
- auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
- req.Header.Set(echo.Authorization, auth)
- he = ba(c).(*echo.HTTPError)
- assert.Equal(t, http.StatusBadRequest, he.Code())
-
- // WebSocket
- c.Request().Header.Set(echo.Upgrade, echo.WebSocket)
- assert.NoError(t, ba(c))
-}
diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go
new file mode 100644
index 000000000..e0a284c67
--- /dev/null
+++ b/middleware/basic_auth.go
@@ -0,0 +1,156 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "bytes"
+ "cmp"
+ "encoding/base64"
+ "errors"
+ "strconv"
+ "strings"
+
+ "github.com/labstack/echo/v5"
+)
+
+// BasicAuthConfig defines the config for BasicAuthWithConfig middleware.
+//
+// SECURITY: The Validator function is responsible for securely comparing credentials.
+// See BasicAuthValidator documentation for guidance on preventing timing attacks.
+type BasicAuthConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers
+ // this function would be called once for each header until first valid result is returned
+ // Required.
+ Validator BasicAuthValidator
+
+ // Realm is a string to define realm attribute of BasicAuthWithConfig.
+ // Default value "Restricted".
+ Realm string
+
+ // AllowedCheckLimit set how many headers are allowed to be checked. This is useful
+ // environments like corporate test environments with application proxies restricting
+ // access to environment with their own auth scheme.
+ // Defaults to 1.
+ AllowedCheckLimit uint
+}
+
+// BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials.
+//
+// SECURITY WARNING: To prevent timing attacks that could allow attackers to enumerate
+// valid usernames or passwords, validator implementations MUST use constant-time
+// comparison for credential checking. Use crypto/subtle.ConstantTimeCompare instead
+// of standard string equality (==) or switch statements.
+//
+// Example of SECURE implementation:
+//
+// import "crypto/subtle"
+//
+// validator := func(c *echo.Context, username, password string) (bool, error) {
+// // Fetch expected credentials from database/config
+// expectedUser := "admin"
+// expectedPass := "secretpassword"
+//
+// // Use constant-time comparison to prevent timing attacks
+// userMatch := subtle.ConstantTimeCompare([]byte(username), []byte(expectedUser)) == 1
+// passMatch := subtle.ConstantTimeCompare([]byte(password), []byte(expectedPass)) == 1
+//
+// if userMatch && passMatch {
+// return true, nil
+// }
+// return false, nil
+// }
+//
+// Example of INSECURE implementation (DO NOT USE):
+//
+// // VULNERABLE TO TIMING ATTACKS - DO NOT USE
+// validator := func(c *echo.Context, username, password string) (bool, error) {
+// if username == "admin" && password == "secret" { // Timing leak!
+// return true, nil
+// }
+// return false, nil
+// }
+type BasicAuthValidator func(c *echo.Context, user string, password string) (bool, error)
+
+const (
+ basic = "basic"
+ defaultRealm = "Restricted"
+)
+
+// BasicAuth returns an BasicAuth middleware.
+//
+// For valid credentials it calls the next handler.
+// For missing or invalid credentials, it sends "401 - Unauthorized" response.
+func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
+ return BasicAuthWithConfig(BasicAuthConfig{Validator: fn})
+}
+
+// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config.
+func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration
+func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Validator == nil {
+ return nil, errors.New("echo basic-auth middleware requires a validator function")
+ }
+ if config.Skipper == nil {
+ config.Skipper = DefaultSkipper
+ }
+ realm := defaultRealm
+ if config.Realm != "" {
+ realm = config.Realm
+ }
+ realm = strconv.Quote(realm)
+ limit := cmp.Or(config.AllowedCheckLimit, 1)
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ var lastError error
+ l := len(basic)
+ i := uint(0)
+ for _, auth := range c.Request().Header[echo.HeaderAuthorization] {
+ if i >= limit {
+ break
+ }
+ if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) {
+ continue
+ }
+ i++
+
+ // Invalid base64 shouldn't be treated as error
+ // instead should be treated as invalid client input
+ b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
+ if errDecode != nil {
+ lastError = echo.ErrBadRequest.Wrap(errDecode)
+ continue
+ }
+ idx := bytes.IndexByte(b, ':')
+ if idx >= 0 {
+ valid, errValidate := config.Validator(c, string(b[:idx]), string(b[idx+1:]))
+ if errValidate != nil {
+ lastError = errValidate
+ } else if valid {
+ return next(c)
+ }
+ }
+ }
+
+ if lastError != nil {
+ return lastError
+ }
+
+ // Need to return `401` for browsers to pop-up login box.
+ c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm)
+ return echo.ErrUnauthorized
+ }
+ }, nil
+}
diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go
new file mode 100644
index 000000000..42386354f
--- /dev/null
+++ b/middleware/basic_auth_test.go
@@ -0,0 +1,240 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "crypto/subtle"
+ "encoding/base64"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestBasicAuth(t *testing.T) {
+ validatorFunc := func(c *echo.Context, u, p string) (bool, error) {
+ // Use constant-time comparison to prevent timing attacks
+ userMatch := subtle.ConstantTimeCompare([]byte(u), []byte("joe")) == 1
+ passMatch := subtle.ConstantTimeCompare([]byte(p), []byte("secret")) == 1
+
+ if userMatch && passMatch {
+ return true, nil
+ }
+
+ // Special case for testing error handling
+ if u == "error" {
+ return false, errors.New(p)
+ }
+
+ return false, nil
+ }
+ defaultConfig := BasicAuthConfig{Validator: validatorFunc}
+
+ var testCases = []struct {
+ name string
+ givenConfig BasicAuthConfig
+ whenAuth []string
+ expectHeader string
+ expectErr string
+ }{
+ {
+ name: "ok",
+ givenConfig: defaultConfig,
+ whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
+ },
+ {
+ name: "ok, multiple",
+ givenConfig: BasicAuthConfig{Validator: validatorFunc, AllowedCheckLimit: 2},
+ whenAuth: []string{
+ "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")),
+ basic + " NOT_BASE64",
+ basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
+ },
+ },
+ {
+ name: "nok, multiple, valid out of limit",
+ givenConfig: BasicAuthConfig{Validator: validatorFunc, AllowedCheckLimit: 1},
+ whenAuth: []string{
+ "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")),
+ basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid_password")),
+ // limit only check first and should not check auth below
+ basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
+ },
+ expectHeader: basic + ` realm="Restricted"`,
+ expectErr: "Unauthorized",
+ },
+ {
+ name: "nok, invalid Authorization header",
+ givenConfig: defaultConfig,
+ whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
+ expectHeader: basic + ` realm="Restricted"`,
+ expectErr: "Unauthorized",
+ },
+ {
+ name: "nok, not base64 Authorization header",
+ givenConfig: defaultConfig,
+ whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"},
+ expectErr: "code=400, message=Bad Request, err=illegal base64 data at input byte 3",
+ },
+ {
+ name: "nok, missing Authorization header",
+ givenConfig: defaultConfig,
+ expectHeader: basic + ` realm="Restricted"`,
+ expectErr: "Unauthorized",
+ },
+ {
+ name: "ok, realm",
+ givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
+ whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
+ },
+ {
+ name: "ok, realm, case-insensitive header scheme",
+ givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
+ whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))},
+ },
+ {
+ name: "nok, realm, invalid Authorization header",
+ givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"},
+ whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
+ expectHeader: basic + ` realm="someRealm"`,
+ expectErr: "Unauthorized",
+ },
+ {
+ name: "nok, validator func returns an error",
+ givenConfig: defaultConfig,
+ whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))},
+ expectErr: "my_error",
+ },
+ {
+ name: "ok, skipped",
+ givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c *echo.Context) bool {
+ return true
+ }},
+ whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ res := httptest.NewRecorder()
+ c := e.NewContext(req, res)
+
+ config := tc.givenConfig
+
+ mw, err := config.ToMiddleware()
+ assert.NoError(t, err)
+
+ h := mw(func(c *echo.Context) error {
+ return c.String(http.StatusTeapot, "test")
+ })
+
+ if len(tc.whenAuth) != 0 {
+ for _, a := range tc.whenAuth {
+ req.Header.Add(echo.HeaderAuthorization, a)
+ }
+ }
+ err = h(c)
+
+ if tc.expectErr != "" {
+ assert.Equal(t, http.StatusOK, res.Code)
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.Equal(t, http.StatusTeapot, res.Code)
+ assert.NoError(t, err)
+ }
+ if tc.expectHeader != "" {
+ assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate))
+ }
+ })
+ }
+}
+
+func TestBasicAuth_panic(t *testing.T) {
+ assert.Panics(t, func() {
+ mw := BasicAuth(nil)
+ assert.NotNil(t, mw)
+ })
+
+ mw := BasicAuth(func(c *echo.Context, user string, password string) (bool, error) {
+ return true, nil
+ })
+ assert.NotNil(t, mw)
+}
+
+func TestBasicAuthWithConfig_panic(t *testing.T) {
+ assert.Panics(t, func() {
+ mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil})
+ assert.NotNil(t, mw)
+ })
+
+ mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c *echo.Context, user string, password string) (bool, error) {
+ return true, nil
+ }})
+ assert.NotNil(t, mw)
+}
+
+func TestBasicAuthRealm(t *testing.T) {
+ e := echo.New()
+ mockValidator := func(c *echo.Context, u, p string) (bool, error) {
+ return false, nil // Always fail to trigger WWW-Authenticate header
+ }
+
+ tests := []struct {
+ name string
+ realm string
+ expectedAuth string
+ }{
+ {
+ name: "Default realm",
+ realm: "Restricted",
+ expectedAuth: `basic realm="Restricted"`,
+ },
+ {
+ name: "Custom realm",
+ realm: "My API",
+ expectedAuth: `basic realm="My API"`,
+ },
+ {
+ name: "Realm with special characters",
+ realm: `Realm with "quotes" and \backslashes`,
+ expectedAuth: `basic realm="Realm with \"quotes\" and \\backslashes"`,
+ },
+ {
+ name: "Empty realm (falls back to default)",
+ realm: "",
+ expectedAuth: `basic realm="Restricted"`,
+ },
+ {
+ name: "Realm with unicode",
+ realm: "测试领域",
+ expectedAuth: `basic realm="测试领域"`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ res := httptest.NewRecorder()
+ c := e.NewContext(req, res)
+
+ h := BasicAuthWithConfig(BasicAuthConfig{
+ Validator: mockValidator,
+ Realm: tt.realm,
+ })(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ err := h(c)
+
+ assert.Equal(t, echo.ErrUnauthorized, err)
+ assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate))
+ })
+ }
+}
diff --git a/middleware/body_dump.go b/middleware/body_dump.go
new file mode 100644
index 000000000..d5c823c9b
--- /dev/null
+++ b/middleware/body_dump.go
@@ -0,0 +1,201 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "io"
+ "net"
+ "net/http"
+ "sync"
+
+ "github.com/labstack/echo/v5"
+)
+
+// BodyDumpConfig defines the config for BodyDump middleware.
+type BodyDumpConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // Handler receives request, response payloads and handler error if there are any.
+ // Required.
+ Handler BodyDumpHandler
+
+ // MaxRequestBytes limits how much of the request body to dump.
+ // If the request body exceeds this limit, only the first MaxRequestBytes
+ // are dumped. The handler callback receives truncated data.
+ // Default: 5 * MB (5,242,880 bytes)
+ // Set to -1 to disable limits (not recommended in production).
+ MaxRequestBytes int64
+
+ // MaxResponseBytes limits how much of the response body to dump.
+ // If the response body exceeds this limit, only the first MaxResponseBytes
+ // are dumped. The handler callback receives truncated data.
+ // Default: 5 * MB (5,242,880 bytes)
+ // Set to -1 to disable limits (not recommended in production).
+ MaxResponseBytes int64
+}
+
+// BodyDumpHandler receives the request and response payload.
+type BodyDumpHandler func(c *echo.Context, reqBody []byte, resBody []byte, err error)
+
+type bodyDumpResponseWriter struct {
+ io.Writer
+ http.ResponseWriter
+}
+
+// BodyDump returns a BodyDump middleware.
+//
+// BodyDump middleware captures the request and response payload and calls the
+// registered handler.
+//
+// SECURITY: By default, this limits dumped bodies to 5MB to prevent memory exhaustion
+// attacks. To customize limits, use BodyDumpWithConfig. To disable limits (not recommended
+// in production), explicitly set MaxRequestBytes and MaxResponseBytes to -1.
+func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc {
+ return BodyDumpWithConfig(BodyDumpConfig{Handler: handler})
+}
+
+// BodyDumpWithConfig returns a BodyDump middleware with config.
+// See: `BodyDump()`.
+//
+// SECURITY: If MaxRequestBytes and MaxResponseBytes are not set (zero values), they default
+// to 5MB each to prevent DoS attacks via large payloads. Set them explicitly to -1 to disable
+// limits if needed for your use case.
+func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration
+func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Handler == nil {
+ return nil, errors.New("echo body-dump middleware requires a handler function")
+ }
+ if config.Skipper == nil {
+ config.Skipper = DefaultSkipper
+ }
+ if config.MaxRequestBytes == 0 {
+ config.MaxRequestBytes = 5 * MB
+ }
+ if config.MaxResponseBytes == 0 {
+ config.MaxResponseBytes = 5 * MB
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ reqBuf := bodyDumpBufferPool.Get().(*bytes.Buffer)
+ reqBuf.Reset()
+ defer bodyDumpBufferPool.Put(reqBuf)
+
+ var bodyReader io.Reader = c.Request().Body
+ if config.MaxRequestBytes > 0 {
+ bodyReader = io.LimitReader(c.Request().Body, config.MaxRequestBytes)
+ }
+ _, readErr := io.Copy(reqBuf, bodyReader)
+ if readErr != nil && readErr != io.EOF {
+ return readErr
+ }
+ if config.MaxRequestBytes > 0 {
+ // Drain any remaining body data to prevent connection issues
+ _, _ = io.Copy(io.Discard, c.Request().Body)
+ _ = c.Request().Body.Close()
+ }
+
+ reqBody := make([]byte, reqBuf.Len())
+ copy(reqBody, reqBuf.Bytes())
+ c.Request().Body = io.NopCloser(bytes.NewReader(reqBody))
+
+ // response part
+ resBuf := bodyDumpBufferPool.Get().(*bytes.Buffer)
+ resBuf.Reset()
+ defer bodyDumpBufferPool.Put(resBuf)
+
+ var respWriter io.Writer
+ if config.MaxResponseBytes > 0 {
+ respWriter = &limitedWriter{
+ response: c.Response(),
+ dumpBuf: resBuf,
+ limit: config.MaxResponseBytes,
+ }
+ } else {
+ respWriter = io.MultiWriter(c.Response(), resBuf)
+ }
+ writer := &bodyDumpResponseWriter{
+ Writer: respWriter,
+ ResponseWriter: c.Response(),
+ }
+ c.SetResponse(writer)
+
+ err := next(c)
+
+ // Callback
+ config.Handler(c, reqBody, resBuf.Bytes(), err)
+
+ return err
+ }
+ }, nil
+}
+
+func (w *bodyDumpResponseWriter) WriteHeader(code int) {
+ w.ResponseWriter.WriteHeader(code)
+}
+
+func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) {
+ return w.Writer.Write(b)
+}
+
+func (w *bodyDumpResponseWriter) Flush() {
+ err := http.NewResponseController(w.ResponseWriter).Flush()
+ if err != nil && errors.Is(err, http.ErrNotSupported) {
+ panic(errors.New("response writer flushing is not supported"))
+ }
+}
+
+func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ return http.NewResponseController(w.ResponseWriter).Hijack()
+}
+
+func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter {
+ return w.ResponseWriter
+}
+
+var bodyDumpBufferPool = sync.Pool{
+ New: func() any {
+ return new(bytes.Buffer)
+ },
+}
+
+type limitedWriter struct {
+ response http.ResponseWriter
+ dumpBuf *bytes.Buffer
+ dumped int64
+ limit int64
+}
+
+func (w *limitedWriter) Write(b []byte) (n int, err error) {
+ // Always write full data to actual response (don't truncate client response)
+ n, err = w.response.Write(b)
+ if err != nil {
+ return n, err
+ }
+
+ // Write to dump buffer only up to limit
+ if w.dumped < w.limit {
+ remaining := w.limit - w.dumped
+ toDump := int64(n)
+ if toDump > remaining {
+ toDump = remaining
+ }
+ w.dumpBuf.Write(b[:toDump])
+ w.dumped += toDump
+ }
+
+ return n, nil
+}
diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go
new file mode 100644
index 000000000..f493e75c8
--- /dev/null
+++ b/middleware/body_dump_test.go
@@ -0,0 +1,581 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestBodyDump(t *testing.T) {
+ e := echo.New()
+ hw := "Hello, World!"
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := func(c *echo.Context) error {
+ body, err := io.ReadAll(c.Request().Body)
+ if err != nil {
+ return err
+ }
+ return c.String(http.StatusOK, string(body))
+ }
+
+ requestBody := ""
+ responseBody := ""
+ mw, err := BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ requestBody = string(reqBody)
+ responseBody = string(resBody)
+ }}.ToMiddleware()
+ assert.NoError(t, err)
+
+ if assert.NoError(t, mw(h)(c)) {
+ assert.Equal(t, requestBody, hw)
+ assert.Equal(t, responseBody, hw)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, hw, rec.Body.String())
+ }
+
+}
+
+func TestBodyDump_skipper(t *testing.T) {
+ e := echo.New()
+
+ isCalled := false
+ mw, err := BodyDumpConfig{
+ Skipper: func(c *echo.Context) bool {
+ return true
+ },
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ isCalled = true
+ },
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}"))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := func(c *echo.Context) error {
+ return errors.New("some error")
+ }
+
+ err = mw(h)(c)
+ assert.EqualError(t, err, "some error")
+ assert.False(t, isCalled)
+}
+
+func TestBodyDump_fails(t *testing.T) {
+ e := echo.New()
+ hw := "Hello, World!"
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := func(c *echo.Context) error {
+ return errors.New("some error")
+ }
+
+ mw, err := BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}}.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.EqualError(t, err, "some error")
+ assert.Equal(t, http.StatusOK, rec.Code)
+
+}
+
+func TestBodyDumpWithConfig_panic(t *testing.T) {
+ assert.Panics(t, func() {
+ mw := BodyDumpWithConfig(BodyDumpConfig{
+ Skipper: nil,
+ Handler: nil,
+ })
+ assert.NotNil(t, mw)
+ })
+
+ assert.NotPanics(t, func() {
+ mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}})
+ assert.NotNil(t, mw)
+ })
+}
+
+func TestBodyDump_panic(t *testing.T) {
+ assert.Panics(t, func() {
+ mw := BodyDump(nil)
+ assert.NotNil(t, mw)
+ })
+
+ assert.NotPanics(t, func() {
+ BodyDump(func(c *echo.Context, reqBody, resBody []byte, err error) {})
+ })
+}
+
+func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) {
+ bdrw := bodyDumpResponseWriter{
+ ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush
+ }
+ assert.PanicsWithError(t, "response writer flushing is not supported", func() {
+ bdrw.Flush()
+ })
+}
+
+func TestBodyDumpResponseWriter_CanFlush(t *testing.T) {
+ trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
+ bdrw := bodyDumpResponseWriter{
+ ResponseWriter: &trwu,
+ }
+ bdrw.Flush()
+ assert.Equal(t, 1, trwu.unwrapCalled)
+}
+
+func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) {
+ trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
+ bdrw := bodyDumpResponseWriter{
+ ResponseWriter: trwu,
+ }
+ result := bdrw.Unwrap()
+ assert.Equal(t, trwu, result)
+}
+
+func TestBodyDumpResponseWriter_CanHijack(t *testing.T) {
+ trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
+ bdrw := bodyDumpResponseWriter{
+ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
+ }
+ _, _, err := bdrw.Hijack()
+ assert.EqualError(t, err, "can hijack")
+}
+
+func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) {
+ trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
+ bdrw := bodyDumpResponseWriter{
+ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
+ }
+ _, _, err := bdrw.Hijack()
+ assert.EqualError(t, err, "feature not supported")
+}
+
+func TestBodyDump_ReadError(t *testing.T) {
+ e := echo.New()
+
+ // Create a reader that fails during read
+ failingReader := &failingReadCloser{
+ data: []byte("partial data"),
+ failAt: 7, // Fail after 7 bytes
+ failWith: errors.New("connection reset"),
+ }
+
+ req := httptest.NewRequest(http.MethodPost, "/", failingReader)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ // This handler should not be reached if body read fails
+ body, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(body))
+ }
+
+ requestBodyReceived := ""
+ mw := BodyDump(func(c *echo.Context, reqBody, resBody []byte, err error) {
+ requestBodyReceived = string(reqBody)
+ })
+
+ err := mw(h)(c)
+
+ // Verify error is propagated
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "connection reset")
+
+ // Verify handler was not executed (callback wouldn't have received data)
+ assert.Empty(t, requestBodyReceived)
+}
+
+// failingReadCloser is a helper type for testing read errors
+type failingReadCloser struct {
+ data []byte
+ pos int
+ failAt int
+ failWith error
+}
+
+func (f *failingReadCloser) Read(p []byte) (n int, err error) {
+ if f.pos >= f.failAt {
+ return 0, f.failWith
+ }
+
+ n = copy(p, f.data[f.pos:])
+ f.pos += n
+
+ if f.pos >= f.failAt {
+ return n, f.failWith
+ }
+
+ return n, nil
+}
+
+func (f *failingReadCloser) Close() error {
+ return nil
+}
+
+func TestBodyDump_RequestWithinLimit(t *testing.T) {
+ e := echo.New()
+ requestData := "Hello, World!"
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ body, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(body))
+ }
+
+ requestBodyDumped := ""
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ requestBodyDumped = string(reqBody)
+ },
+ MaxRequestBytes: 1 * MB, // 1MB limit
+ MaxResponseBytes: 1 * MB,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, requestData, requestBodyDumped, "Small request should be fully dumped")
+ assert.Equal(t, requestData, rec.Body.String(), "Handler should receive full request")
+}
+
+func TestBodyDump_RequestExceedsLimit(t *testing.T) {
+ e := echo.New()
+ // Create 2KB of data but limit to 1KB
+ largeData := strings.Repeat("A", 2*1024)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeData))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ body, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(body))
+ }
+
+ requestBodyDumped := ""
+ limit := int64(1024) // 1KB limit
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ requestBodyDumped = string(reqBody)
+ },
+ MaxRequestBytes: limit,
+ MaxResponseBytes: 1 * MB,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, int(limit), len(requestBodyDumped), "Dumped request should be truncated to limit")
+ assert.Equal(t, strings.Repeat("A", 1024), requestBodyDumped, "Dumped data should match first N bytes")
+ // Handler should receive truncated data (what was dumped)
+ assert.Equal(t, strings.Repeat("A", 1024), rec.Body.String())
+}
+
+func TestBodyDump_RequestAtExactLimit(t *testing.T) {
+ e := echo.New()
+ exactData := strings.Repeat("B", 1024)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(exactData))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ body, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(body))
+ }
+
+ requestBodyDumped := ""
+ limit := int64(1024)
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ requestBodyDumped = string(reqBody)
+ },
+ MaxRequestBytes: limit,
+ MaxResponseBytes: 1 * MB,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, int(limit), len(requestBodyDumped), "Exact limit should dump full data")
+ assert.Equal(t, exactData, requestBodyDumped)
+}
+
+func TestBodyDump_ResponseWithinLimit(t *testing.T) {
+ e := echo.New()
+ responseData := "Response data"
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ return c.String(http.StatusOK, responseData)
+ }
+
+ responseBodyDumped := ""
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ responseBodyDumped = string(resBody)
+ },
+ MaxRequestBytes: 1 * MB,
+ MaxResponseBytes: 1 * MB,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, responseData, responseBodyDumped, "Small response should be fully dumped")
+ assert.Equal(t, responseData, rec.Body.String(), "Client should receive full response")
+}
+
+func TestBodyDump_ResponseExceedsLimit(t *testing.T) {
+ e := echo.New()
+ largeResponse := strings.Repeat("X", 2*1024) // 2KB
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ return c.String(http.StatusOK, largeResponse)
+ }
+
+ responseBodyDumped := ""
+ limit := int64(1024) // 1KB limit
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ responseBodyDumped = string(resBody)
+ },
+ MaxRequestBytes: 1 * MB,
+ MaxResponseBytes: limit,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ // Dump should be truncated
+ assert.Equal(t, int(limit), len(responseBodyDumped), "Dumped response should be truncated to limit")
+ assert.Equal(t, strings.Repeat("X", 1024), responseBodyDumped)
+ // Client should still receive full response!
+ assert.Equal(t, largeResponse, rec.Body.String(), "Client must receive full response despite dump truncation")
+}
+
+func TestBodyDump_ClientGetsFullResponse(t *testing.T) {
+ e := echo.New()
+ // This is critical - even when dump is limited, client gets everything
+ largeResponse := strings.Repeat("DATA", 500) // 2KB
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ // Write response in chunks to test incremental writes
+ for i := 0; i < 4; i++ {
+ c.Response().Write([]byte(strings.Repeat("DATA", 125)))
+ }
+ return nil
+ }
+
+ responseBodyDumped := ""
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ responseBodyDumped = string(resBody)
+ },
+ MaxRequestBytes: 1 * MB,
+ MaxResponseBytes: 512, // Very small limit
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, 512, len(responseBodyDumped), "Dump should be limited")
+ assert.Equal(t, largeResponse, rec.Body.String(), "Client must get full response")
+}
+
+func TestBodyDump_BothLimitsSimultaneous(t *testing.T) {
+ e := echo.New()
+ largeRequest := strings.Repeat("Q", 2*1024)
+ largeResponse := strings.Repeat("R", 2*1024)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeRequest))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ io.ReadAll(c.Request().Body) // Consume request
+ return c.String(http.StatusOK, largeResponse)
+ }
+
+ requestBodyDumped := ""
+ responseBodyDumped := ""
+ limit := int64(1024)
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ requestBodyDumped = string(reqBody)
+ responseBodyDumped = string(resBody)
+ },
+ MaxRequestBytes: limit,
+ MaxResponseBytes: limit,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, int(limit), len(requestBodyDumped), "Request dump should be limited")
+ assert.Equal(t, int(limit), len(responseBodyDumped), "Response dump should be limited")
+}
+
+func TestBodyDump_DefaultConfig(t *testing.T) {
+ e := echo.New()
+ smallData := "test"
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(smallData))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ body, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(body))
+ }
+
+ requestBodyDumped := ""
+ // Use default config which should have 1MB limits
+ config := BodyDumpConfig{}
+ config.Handler = func(c *echo.Context, reqBody, resBody []byte, err error) {
+ requestBodyDumped = string(reqBody)
+ }
+ mw, err := config.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, smallData, requestBodyDumped)
+}
+
+func TestBodyDump_LargeRequestDosPrevention(t *testing.T) {
+ e := echo.New()
+ // Simulate a very large request (10MB) that could cause OOM
+ largeSize := 10 * 1024 * 1024 // 10MB
+ largeData := strings.Repeat("M", largeSize)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeData))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ body, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(body))
+ }
+
+ requestBodyDumped := ""
+ limit := int64(1 * MB) // Only dump 1MB max
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ requestBodyDumped = string(reqBody)
+ },
+ MaxRequestBytes: limit,
+ MaxResponseBytes: 1 * MB,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ // Verify only 1MB was dumped, not 10MB
+ assert.Equal(t, int(limit), len(requestBodyDumped), "Should only dump up to limit, preventing DoS")
+ assert.Less(t, len(requestBodyDumped), largeSize, "Dump should be much smaller than full request")
+}
+
+func TestBodyDump_LargeResponseDosPrevention(t *testing.T) {
+ e := echo.New()
+ // Simulate a very large response (10MB)
+ largeSize := 10 * 1024 * 1024 // 10MB
+ largeResponse := strings.Repeat("R", largeSize)
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h := func(c *echo.Context) error {
+ return c.String(http.StatusOK, largeResponse)
+ }
+
+ responseBodyDumped := ""
+ limit := int64(1 * MB) // Only dump 1MB max
+ mw, err := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ responseBodyDumped = string(resBody)
+ },
+ MaxRequestBytes: 1 * MB,
+ MaxResponseBytes: limit,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ // Verify only 1MB was dumped, not 10MB
+ assert.Equal(t, int(limit), len(responseBodyDumped), "Should only dump up to limit, preventing DoS")
+ assert.Less(t, len(responseBodyDumped), largeSize, "Dump should be much smaller than full response")
+ // Client still gets full response
+ assert.Equal(t, largeSize, rec.Body.Len(), "Client must receive full response")
+}
+
+func BenchmarkBodyDump_WithLimit(b *testing.B) {
+ e := echo.New()
+ requestData := strings.Repeat("data", 256) // 1KB
+ responseData := strings.Repeat("resp", 256) // 1KB
+
+ h := func(c *echo.Context) error {
+ io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, responseData)
+ }
+
+ mw, _ := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {
+ // Simulate logging
+ _ = len(reqBody) + len(resBody)
+ },
+ MaxRequestBytes: 1 * MB,
+ MaxResponseBytes: 1 * MB,
+ }.ToMiddleware()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ mw(h)(c)
+ }
+}
+
+func BenchmarkBodyDump_BufferPooling(b *testing.B) {
+ e := echo.New()
+ requestData := strings.Repeat("x", 1024)
+ responseData := "response"
+
+ h := func(c *echo.Context) error {
+ io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, responseData)
+ }
+
+ mw, _ := BodyDumpConfig{
+ Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {},
+ MaxRequestBytes: 1 * MB,
+ MaxResponseBytes: 1 * MB,
+ }.ToMiddleware()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ mw(h)(c)
+ }
+}
diff --git a/middleware/body_limit.go b/middleware/body_limit.go
new file mode 100644
index 000000000..4f1963e18
--- /dev/null
+++ b/middleware/body_limit.go
@@ -0,0 +1,99 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "io"
+ "net/http"
+ "sync"
+
+ "github.com/labstack/echo/v5"
+)
+
+// BodyLimitConfig defines the config for BodyLimitWithConfig middleware.
+type BodyLimitConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // LimitBytes is maximum allowed size in bytes for a request body
+ LimitBytes int64
+}
+
+type limitedReader struct {
+ BodyLimitConfig
+ reader io.ReadCloser
+ read int64
+}
+
+// BodyLimit returns a BodyLimit middleware.
+//
+// BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it
+// sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request
+// header and actual content read, which makes it super secure.
+func BodyLimit(limitBytes int64) echo.MiddlewareFunc {
+ return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes})
+}
+
+// BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for
+// a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response.
+// The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which
+// makes it super secure.
+func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration
+func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Skipper == nil {
+ config.Skipper = DefaultSkipper
+ }
+ pool := sync.Pool{
+ New: func() any {
+ return &limitedReader{BodyLimitConfig: config}
+ },
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+ req := c.Request()
+
+ // Based on content length
+ if req.ContentLength > config.LimitBytes {
+ return echo.ErrStatusRequestEntityTooLarge
+ }
+
+ // Based on content read
+ r, ok := pool.Get().(*limitedReader)
+ if !ok {
+ return echo.NewHTTPError(http.StatusInternalServerError, "invalid pool object")
+ }
+ r.Reset(req.Body)
+ defer pool.Put(r)
+ req.Body = r
+
+ return next(c)
+ }
+ }, nil
+}
+
+func (r *limitedReader) Read(b []byte) (n int, err error) {
+ n, err = r.reader.Read(b)
+ r.read += int64(n)
+ if r.read > r.LimitBytes {
+ return n, echo.ErrStatusRequestEntityTooLarge
+ }
+ return
+}
+
+func (r *limitedReader) Close() error {
+ return r.reader.Close()
+}
+
+func (r *limitedReader) Reset(reader io.ReadCloser) {
+ r.reader = reader
+ r.read = 0
+}
diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go
new file mode 100644
index 000000000..5529f5d84
--- /dev/null
+++ b/middleware/body_limit_test.go
@@ -0,0 +1,166 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "bytes"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestBodyLimitConfig_ToMiddleware(t *testing.T) {
+ e := echo.New()
+ hw := []byte("Hello, World!")
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := func(c *echo.Context) error {
+ body, err := io.ReadAll(c.Request().Body)
+ if err != nil {
+ return err
+ }
+ return c.String(http.StatusOK, string(body))
+ }
+
+ // Based on content length (within limit)
+ mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(h)(c)
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, hw, rec.Body.Bytes())
+ }
+
+ // Based on content read (overlimit)
+ mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
+ assert.NoError(t, err)
+ he := mw(h)(c).(echo.HTTPStatusCoder)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
+
+ // Based on content read (within limit)
+ req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ req.ContentLength = -1
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+
+ mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware()
+ assert.NoError(t, err)
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "Hello, World!", rec.Body.String())
+
+ // Based on content read (overlimit)
+ req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ req.ContentLength = -1
+ rec = httptest.NewRecorder()
+ c = e.NewContext(req, rec)
+ mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware()
+ assert.NoError(t, err)
+ he = mw(h)(c).(echo.HTTPStatusCoder)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
+}
+
+func TestBodyLimitReader(t *testing.T) {
+ hw := []byte("Hello, World!")
+
+ config := BodyLimitConfig{
+ Skipper: DefaultSkipper,
+ LimitBytes: 2,
+ }
+ reader := &limitedReader{
+ BodyLimitConfig: config,
+ reader: io.NopCloser(bytes.NewReader(hw)),
+ }
+
+ // read all should return ErrStatusRequestEntityTooLarge
+ _, err := io.ReadAll(reader)
+ he := err.(echo.HTTPStatusCoder)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
+
+ // reset reader and read two bytes must succeed
+ bt := make([]byte, 2)
+ reader.Reset(io.NopCloser(bytes.NewReader(hw)))
+ n, err := reader.Read(bt)
+ assert.Equal(t, 2, n)
+ assert.Equal(t, nil, err)
+}
+
+func TestBodyLimit_skipper(t *testing.T) {
+ e := echo.New()
+ h := func(c *echo.Context) error {
+ body, err := io.ReadAll(c.Request().Body)
+ if err != nil {
+ return err
+ }
+ return c.String(http.StatusOK, string(body))
+ }
+ mw, err := BodyLimitConfig{
+ Skipper: func(c *echo.Context) bool {
+ return true
+ },
+ LimitBytes: 2,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ hw := []byte("Hello, World!")
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err = mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, hw, rec.Body.Bytes())
+}
+
+func TestBodyLimitWithConfig(t *testing.T) {
+ e := echo.New()
+ hw := []byte("Hello, World!")
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := func(c *echo.Context) error {
+ body, err := io.ReadAll(c.Request().Body)
+ if err != nil {
+ return err
+ }
+ return c.String(http.StatusOK, string(body))
+ }
+
+ mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB})
+
+ err := mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, hw, rec.Body.Bytes())
+}
+
+func TestBodyLimit(t *testing.T) {
+ e := echo.New()
+ hw := []byte("Hello, World!")
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := func(c *echo.Context) error {
+ body, err := io.ReadAll(c.Request().Body)
+ if err != nil {
+ return err
+ }
+ return c.String(http.StatusOK, string(body))
+ }
+
+ mw := BodyLimit(2 * MB)
+
+ err := mw(h)(c)
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, hw, rec.Body.Bytes())
+}
diff --git a/middleware/compress.go b/middleware/compress.go
index eba2c7c4e..7754d5db8 100644
--- a/middleware/compress.go
+++ b/middleware/compress.go
@@ -1,61 +1,235 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
"bufio"
+ "bytes"
"compress/gzip"
+ "errors"
"io"
"net"
"net/http"
"strings"
+ "sync"
- "github.com/labstack/echo"
+ "github.com/labstack/echo/v5"
)
-type (
- gzipWriter struct {
- io.Writer
- http.ResponseWriter
- }
+const (
+ gzipScheme = "gzip"
)
-func (w gzipWriter) Write(b []byte) (int, error) {
- if w.Header().Get(echo.ContentType) == "" {
- w.Header().Set(echo.ContentType, http.DetectContentType(b))
- }
- return w.Writer.Write(b)
+// GzipConfig defines the config for Gzip middleware.
+type GzipConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // Gzip compression level.
+ // Optional. Default value -1.
+ Level int
+
+ // Length threshold before gzip compression is applied.
+ // Optional. Default value 0.
+ //
+ // Most of the time you will not need to change the default. Compressing
+ // a short response might increase the transmitted data because of the
+ // gzip format overhead. Compressing the response will also consume CPU
+ // and time on the server and the client (for decompressing). Depending on
+ // your use case such a threshold might be useful.
+ //
+ // See also:
+ // https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits
+ MinLength int
}
-func (w gzipWriter) Flush() error {
- return w.Writer.(*gzip.Writer).Flush()
+type gzipResponseWriter struct {
+ io.Writer
+ http.ResponseWriter
+ wroteHeader bool
+ wroteBody bool
+ minLength int
+ minLengthExceeded bool
+ buffer *bytes.Buffer
+ code int
}
-func (w gzipWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
- return w.ResponseWriter.(http.Hijacker).Hijack()
+// Gzip returns a middleware which compresses HTTP response using gzip compression scheme.
+func Gzip() echo.MiddlewareFunc {
+ return GzipWithConfig(GzipConfig{})
}
-func (w *gzipWriter) CloseNotify() <-chan bool {
- return w.ResponseWriter.(http.CloseNotifier).CloseNotify()
+// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme.
+func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
}
-// Gzip returns a middleware which compresses HTTP response using gzip compression
-// scheme.
-func Gzip() echo.MiddlewareFunc {
- scheme := "gzip"
+// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration
+func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Skipper == nil {
+ config.Skipper = DefaultSkipper
+ }
+ if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression
+ return nil, errors.New("invalid gzip level")
+ }
+ if config.Level == 0 {
+ config.Level = -1
+ }
+ if config.MinLength < 0 {
+ config.MinLength = 0
+ }
+
+ pool := gzipCompressPool(config)
+ bpool := bufferPool()
- return func(h echo.HandlerFunc) echo.HandlerFunc {
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) error {
- c.Response().Header().Add(echo.Vary, echo.AcceptEncoding)
- if strings.Contains(c.Request().Header.Get(echo.AcceptEncoding), scheme) {
- w := gzip.NewWriter(c.Response().Writer())
- defer w.Close()
- gw := gzipWriter{Writer: w, ResponseWriter: c.Response().Writer()}
- c.Response().Header().Set(echo.ContentEncoding, scheme)
- c.Response().SetWriter(gw)
+ if config.Skipper(c) {
+ return next(c)
}
- if err := h(c); err != nil {
- c.Error(err)
+
+ res := c.Response()
+ res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
+ if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
+ i := pool.Get()
+ w, ok := i.(*gzip.Writer)
+ if !ok {
+ return echo.NewHTTPError(http.StatusInternalServerError, "invalid pool object")
+ }
+ rw := res
+ w.Reset(rw)
+ buf := bpool.Get().(*bytes.Buffer)
+ buf.Reset()
+
+ grw := &gzipResponseWriter{
+ Writer: w,
+ ResponseWriter: rw,
+ minLength: config.MinLength,
+ buffer: buf,
+ }
+ c.SetResponse(grw)
+ defer func() {
+ // There are different reasons for cases when we have not yet written response to the client and now need to do so.
+ // a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now.
+ // b) body is shorter than our minimum length threshold and being buffered currently and needs to be written
+ if !grw.wroteBody {
+ if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
+ res.Header().Del(echo.HeaderContentEncoding)
+ }
+ if grw.wroteHeader {
+ rw.WriteHeader(grw.code)
+ }
+ // We have to reset response to it's pristine state when
+ // nothing is written to body or error is returned.
+ // See issue #424, #407.
+ c.SetResponse(rw)
+ w.Reset(io.Discard)
+ } else if !grw.minLengthExceeded {
+ // Write uncompressed response
+ c.SetResponse(rw)
+ if grw.wroteHeader {
+ grw.ResponseWriter.WriteHeader(grw.code)
+ }
+ _, _ = grw.buffer.WriteTo(rw)
+ w.Reset(io.Discard)
+ }
+ _ = w.Close()
+ bpool.Put(buf)
+ pool.Put(w)
+ }()
+ }
+ return next(c)
+ }
+ }, nil
+}
+
+func (w *gzipResponseWriter) WriteHeader(code int) {
+ w.Header().Del(echo.HeaderContentLength) // Issue #444
+
+ w.wroteHeader = true
+
+ // Delay writing of the header until we know if we'll actually compress the response
+ w.code = code
+}
+
+func (w *gzipResponseWriter) Write(b []byte) (int, error) {
+ if w.Header().Get(echo.HeaderContentType) == "" {
+ w.Header().Set(echo.HeaderContentType, http.DetectContentType(b))
+ }
+ w.wroteBody = true
+
+ if !w.minLengthExceeded {
+ n, err := w.buffer.Write(b)
+
+ if w.buffer.Len() >= w.minLength {
+ w.minLengthExceeded = true
+
+ // The minimum length is exceeded, add Content-Encoding header and write the header
+ w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
+ if w.wroteHeader {
+ w.ResponseWriter.WriteHeader(w.code)
}
- return nil
+
+ return w.Writer.Write(w.buffer.Bytes())
}
+
+ return n, err
+ }
+
+ return w.Writer.Write(b)
+}
+
+func (w *gzipResponseWriter) Flush() {
+ if !w.minLengthExceeded {
+ // Enforce compression because we will not know how much more data will come
+ w.minLengthExceeded = true
+ w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
+ if w.wroteHeader {
+ w.ResponseWriter.WriteHeader(w.code)
+ }
+
+ _, _ = w.Writer.Write(w.buffer.Bytes())
+ }
+
+ if gw, ok := w.Writer.(*gzip.Writer); ok {
+ gw.Flush()
+ }
+ _ = http.NewResponseController(w.ResponseWriter).Flush()
+}
+
+func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ return http.NewResponseController(w.ResponseWriter).Hijack()
+}
+
+func (w *gzipResponseWriter) Unwrap() http.ResponseWriter {
+ return w.ResponseWriter
+}
+
+func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
+ if p, ok := w.ResponseWriter.(http.Pusher); ok {
+ return p.Push(target, opts)
+ }
+ return http.ErrNotSupported
+}
+
+func gzipCompressPool(config GzipConfig) sync.Pool {
+ return sync.Pool{
+ New: func() any {
+ w, err := gzip.NewWriterLevel(io.Discard, config.Level)
+ if err != nil {
+ return err
+ }
+ return w
+ },
+ }
+}
+
+func bufferPool() sync.Pool {
+ return sync.Pool{
+ New: func() any {
+ b := &bytes.Buffer{}
+ return b
+ },
}
}
diff --git a/middleware/compress_test.go b/middleware/compress_test.go
index 85cc9ca21..084ffc9c7 100644
--- a/middleware/compress_test.go
+++ b/middleware/compress_test.go
@@ -1,121 +1,398 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
"bytes"
"compress/gzip"
+ "io"
"net/http"
"net/http/httptest"
+ "os"
"testing"
"time"
- "github.com/labstack/echo"
+ "github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
-type closeNotifyingRecorder struct {
- *httptest.ResponseRecorder
- closed chan bool
-}
+func TestGzip_NoAcceptEncodingHeader(t *testing.T) {
+ // Skip if no Accept-Encoding header
+ h := Gzip()(func(c *echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })
-func newCloseNotifyingRecorder() *closeNotifyingRecorder {
- return &closeNotifyingRecorder{
- httptest.NewRecorder(),
- make(chan bool, 1),
- }
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := h(c)
+ assert.NoError(t, err)
+
+ assert.Equal(t, "test", rec.Body.String())
}
-func (c *closeNotifyingRecorder) close() {
- c.closed <- true
+func TestMustGzipWithConfig_panics(t *testing.T) {
+ assert.Panics(t, func() {
+ GzipWithConfig(GzipConfig{Level: 999})
+ })
}
-func (c *closeNotifyingRecorder) CloseNotify() <-chan bool {
- return c.closed
+func TestGzip_AcceptEncodingHeader(t *testing.T) {
+ h := Gzip()(func(c *echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })
+
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := h(c)
+ assert.NoError(t, err)
+
+ assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
+ assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
+
+ r, err := gzip.NewReader(rec.Body)
+ assert.NoError(t, err)
+ buf := new(bytes.Buffer)
+ defer r.Close()
+ buf.ReadFrom(r)
+ assert.Equal(t, "test", buf.String())
}
-func TestGzip(t *testing.T) {
- req, _ := http.NewRequest(echo.GET, "/", nil)
+func TestGzip_chunked(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
- c := echo.NewContext(req, echo.NewResponse(rec), echo.New())
- h := func(c *echo.Context) error {
- c.Response().Write([]byte("test")) // For Content-Type sniffing
+ c := e.NewContext(req, rec)
+
+ chunkChan := make(chan struct{})
+ waitChan := make(chan struct{})
+ h := Gzip()(func(c *echo.Context) error {
+ rc := http.NewResponseController(c.Response())
+ c.Response().Header().Set("Content-Type", "text/event-stream")
+ c.Response().Header().Set("Transfer-Encoding", "chunked")
+
+ // Write and flush the first part of the data
+ c.Response().Write([]byte("first\n"))
+ rc.Flush()
+
+ chunkChan <- struct{}{}
+ <-waitChan
+
+ // Write and flush the second part of the data
+ c.Response().Write([]byte("second\n"))
+ rc.Flush()
+
+ chunkChan <- struct{}{}
+ <-waitChan
+
+ // Write the final part of the data and return
+ c.Response().Write([]byte("third"))
+
+ chunkChan <- struct{}{}
return nil
+ })
+
+ go func() {
+ err := h(c)
+ chunkChan <- struct{}{}
+ assert.NoError(t, err)
+ }()
+
+ <-chunkChan // wait for first write
+ waitChan <- struct{}{}
+
+ <-chunkChan // wait for second write
+ waitChan <- struct{}{}
+
+ <-chunkChan // wait for final write in handler
+ <-chunkChan // wait for return from handler
+ time.Sleep(5 * time.Millisecond) // to have time for flushing
+
+ assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
+
+ r, err := gzip.NewReader(rec.Body)
+ assert.NoError(t, err)
+ buf := new(bytes.Buffer)
+ buf.ReadFrom(r)
+ assert.Equal(t, "first\nsecond\nthird", buf.String())
+}
+
+func TestGzip_NoContent(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := Gzip()(func(c *echo.Context) error {
+ return c.NoContent(http.StatusNoContent)
+ })
+ if assert.NoError(t, h(c)) {
+ assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
+ assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
+ assert.Equal(t, 0, len(rec.Body.Bytes()))
}
+}
- // Skip if no Accept-Encoding header
- Gzip()(h)(c)
- assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, "test", rec.Body.String())
+func TestGzip_Empty(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := Gzip()(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "")
+ })
+ if assert.NoError(t, h(c)) {
+ assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
+ assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get(echo.HeaderContentType))
+ r, err := gzip.NewReader(rec.Body)
+ if assert.NoError(t, err) {
+ var buf bytes.Buffer
+ buf.ReadFrom(r)
+ assert.Equal(t, "", buf.String())
+ }
+ }
+}
- req, _ = http.NewRequest(echo.GET, "/", nil)
- req.Header.Set(echo.AcceptEncoding, "gzip")
- rec = httptest.NewRecorder()
- c = echo.NewContext(req, echo.NewResponse(rec), echo.New())
+func TestGzip_ErrorReturned(t *testing.T) {
+ e := echo.New()
+ e.Use(Gzip())
+ e.GET("/", func(c *echo.Context) error {
+ return echo.ErrNotFound
+ })
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusNotFound, rec.Code)
+ assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
+}
+
+func TestGzipWithConfig_invalidLevel(t *testing.T) {
+ mw, err := GzipConfig{Level: 12}.ToMiddleware()
+ assert.EqualError(t, err, "invalid gzip level")
+ assert.Nil(t, mw)
+}
+
+// Issue #806
+func TestGzipWithStatic(t *testing.T) {
+ e := echo.New()
+ e.Filesystem = os.DirFS("../")
+
+ e.Use(Gzip())
+ e.Static("/test", "_fixture/images")
+ req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
- // Gzip
- Gzip()(h)(c)
assert.Equal(t, http.StatusOK, rec.Code)
- assert.Equal(t, "gzip", rec.Header().Get(echo.ContentEncoding))
- assert.Contains(t, rec.Header().Get(echo.ContentType), echo.TextPlain)
+ // Data is written out in chunks when Content-Length == "", so only
+ // validate the content length if it's not set.
+ if cl := rec.Header().Get("Content-Length"); cl != "" {
+ assert.Equal(t, cl, rec.Body.Len())
+ }
+ r, err := gzip.NewReader(rec.Body)
+ if assert.NoError(t, err) {
+ defer r.Close()
+ want, err := os.ReadFile("../_fixture/images/walle.png")
+ if assert.NoError(t, err) {
+ buf := new(bytes.Buffer)
+ buf.ReadFrom(r)
+ assert.Equal(t, want, buf.Bytes())
+ }
+ }
+}
+
+func TestGzipWithMinLength(t *testing.T) {
+ e := echo.New()
+ // Minimal response length
+ e.Use(GzipWithConfig(GzipConfig{MinLength: 10}))
+ e.GET("/", func(c *echo.Context) error {
+ c.Response().Write([]byte("foobarfoobar"))
+ return nil
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
r, err := gzip.NewReader(rec.Body)
- defer r.Close()
if assert.NoError(t, err) {
buf := new(bytes.Buffer)
+ defer r.Close()
buf.ReadFrom(r)
- assert.Equal(t, "test", buf.String())
+ assert.Equal(t, "foobarfoobar", buf.String())
}
}
-func TestGzipFlush(t *testing.T) {
+func TestGzipWithMinLengthTooShort(t *testing.T) {
+ e := echo.New()
+ // Minimal response length
+ e.Use(GzipWithConfig(GzipConfig{MinLength: 10}))
+ e.GET("/", func(c *echo.Context) error {
+ c.Response().Write([]byte("test"))
+ return nil
+ })
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding))
+ assert.Contains(t, rec.Body.String(), "test")
+}
+
+func TestGzipWithResponseWithoutBody(t *testing.T) {
+ e := echo.New()
+
+ e.Use(Gzip())
+ e.GET("/", func(c *echo.Context) error {
+ return c.Redirect(http.StatusMovedPermanently, "http://localhost")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec := httptest.NewRecorder()
- buf := new(bytes.Buffer)
- w := gzip.NewWriter(buf)
- gw := gzipWriter{Writer: w, ResponseWriter: rec}
- n0 := buf.Len()
- if n0 != 0 {
- t.Fatalf("buffer size = %d before writes; want 0", n0)
- }
+ e.ServeHTTP(rec, req)
- if err := gw.Flush(); err != nil {
- t.Fatal(err)
- }
+ assert.Equal(t, http.StatusMovedPermanently, rec.Code)
+ assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding))
+}
+
+func TestGzipWithMinLengthChunked(t *testing.T) {
+ e := echo.New()
+
+ // Gzip chunked
+ chunkBuf := make([]byte, 5)
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
- n1 := buf.Len()
- if n1 == 0 {
- t.Fatal("no data after first flush")
+ var r *gzip.Reader = nil
+
+ c := e.NewContext(req, rec)
+ next := func(c *echo.Context) error {
+ rc := http.NewResponseController(c.Response())
+ c.Response().Header().Set("Content-Type", "text/event-stream")
+ c.Response().Header().Set("Transfer-Encoding", "chunked")
+
+ // Write and flush the first part of the data
+ c.Response().Write([]byte("test\n"))
+ rc.Flush()
+
+ // Read the first part of the data
+ assert.True(t, rec.Flushed)
+ assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
+
+ var err error
+ r, err = gzip.NewReader(rec.Body)
+ assert.NoError(t, err)
+
+ _, err = io.ReadFull(r, chunkBuf)
+ assert.NoError(t, err)
+ assert.Equal(t, "test\n", string(chunkBuf))
+
+ // Write and flush the second part of the data
+ c.Response().Write([]byte("test\n"))
+ rc.Flush()
+
+ _, err = io.ReadFull(r, chunkBuf)
+ assert.NoError(t, err)
+ assert.Equal(t, "test\n", string(chunkBuf))
+
+ // Write the final part of the data and return
+ c.Response().Write([]byte("test"))
+ return nil
}
+ err := GzipWithConfig(GzipConfig{MinLength: 10})(next)(c)
+
+ assert.NoError(t, err)
+ assert.NotNil(t, r)
+
+ buf := new(bytes.Buffer)
+
+ buf.ReadFrom(r)
+ assert.Equal(t, "test", buf.String())
- gw.Write([]byte("x"))
+ r.Close()
+}
- n2 := buf.Len()
- if n1 != n2 {
- t.Fatalf("after writing a single byte, size changed from %d to %d; want no change", n1, n2)
+func TestGzipWithMinLengthNoContent(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c *echo.Context) error {
+ return c.NoContent(http.StatusNoContent)
+ })
+ if assert.NoError(t, h(c)) {
+ assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
+ assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
+ assert.Equal(t, 0, len(rec.Body.Bytes()))
}
+}
- if err := gw.Flush(); err != nil {
- t.Fatal(err)
+func TestGzipResponseWriter_CanUnwrap(t *testing.T) {
+ trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
+ bdrw := gzipResponseWriter{
+ ResponseWriter: trwu,
}
+ result := bdrw.Unwrap()
+ assert.Equal(t, trwu, result)
+}
- n3 := buf.Len()
- if n2 == n3 {
- t.Fatal("Flush didn't flush any data")
+func TestGzipResponseWriter_CanHijack(t *testing.T) {
+ trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}}
+ bdrw := gzipResponseWriter{
+ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}
+ _, _, err := bdrw.Hijack()
+ assert.EqualError(t, err, "can hijack")
}
-func TestGzipCloseNotify(t *testing.T) {
- rec := newCloseNotifyingRecorder()
- buf := new(bytes.Buffer)
- w := gzip.NewWriter(buf)
- gw := gzipWriter{Writer: w, ResponseWriter: rec}
- closed := false
- notifier := gw.CloseNotify()
- rec.close()
-
- select {
- case <-notifier:
- closed = true
- case <-time.After(time.Second):
+func TestGzipResponseWriter_CanNotHijack(t *testing.T) {
+ trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()}
+ bdrw := gzipResponseWriter{
+ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping
}
+ _, _, err := bdrw.Hijack()
+ assert.EqualError(t, err, "feature not supported")
+}
+
+func BenchmarkGzip(b *testing.B) {
+ e := echo.New()
- assert.Equal(t, closed, true)
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
+
+ h := Gzip()(func(c *echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ // Gzip
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h(c)
+ }
}
diff --git a/middleware/context_timeout.go b/middleware/context_timeout.go
new file mode 100644
index 000000000..68465199a
--- /dev/null
+++ b/middleware/context_timeout.go
@@ -0,0 +1,71 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "context"
+ "errors"
+ "time"
+
+ "github.com/labstack/echo/v5"
+)
+
+// ContextTimeoutConfig defines the config for ContextTimeout middleware.
+type ContextTimeoutConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // ErrorHandler is a function when error arises in middeware execution.
+ ErrorHandler func(c *echo.Context, err error) error
+
+ // Timeout configures a timeout for the middleware
+ Timeout time.Duration
+}
+
+// ContextTimeout returns a middleware which returns error (503 Service Unavailable error) to client
+// when underlying method returns context.DeadlineExceeded error.
+func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc {
+ return ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: timeout})
+}
+
+// ContextTimeoutWithConfig returns a Timeout middleware with config.
+func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts Config to middleware.
+func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Timeout == 0 {
+ return nil, errors.New("timeout must be set")
+ }
+ if config.Skipper == nil {
+ config.Skipper = DefaultSkipper
+ }
+ if config.ErrorHandler == nil {
+ config.ErrorHandler = func(c *echo.Context, err error) error {
+ if err != nil && errors.Is(err, context.DeadlineExceeded) {
+ return echo.ErrServiceUnavailable.Wrap(err)
+ }
+ return err
+ }
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ timeoutContext, cancel := context.WithTimeout(c.Request().Context(), config.Timeout)
+ defer cancel()
+
+ c.SetRequest(c.Request().WithContext(timeoutContext))
+
+ if err := next(c); err != nil {
+ return config.ErrorHandler(c, err)
+ }
+ return nil
+ }
+ }, nil
+}
diff --git a/middleware/context_timeout_test.go b/middleware/context_timeout_test.go
new file mode 100644
index 000000000..c7ba76beb
--- /dev/null
+++ b/middleware/context_timeout_test.go
@@ -0,0 +1,229 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "context"
+ "errors"
+ "github.com/labstack/echo/v5"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestContextTimeoutSkipper(t *testing.T) {
+ t.Parallel()
+ m := ContextTimeoutWithConfig(ContextTimeoutConfig{
+ Skipper: func(context *echo.Context) bool {
+ return true
+ },
+ Timeout: 10 * time.Millisecond,
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e := echo.New()
+ c := e.NewContext(req, rec)
+
+ err := m(func(c *echo.Context) error {
+ if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil {
+ return err
+ }
+
+ return errors.New("response from handler")
+ })(c)
+
+ // if not skipped we would have not returned error due context timeout logic
+ assert.EqualError(t, err, "response from handler")
+}
+
+func TestContextTimeoutWithTimeout0(t *testing.T) {
+ t.Parallel()
+ assert.Panics(t, func() {
+ ContextTimeout(time.Duration(0))
+ })
+}
+
+func TestContextTimeoutErrorOutInHandler(t *testing.T) {
+ t.Parallel()
+ m := ContextTimeoutWithConfig(ContextTimeoutConfig{
+ // Timeout has to be defined or the whole flow for timeout middleware will be skipped
+ Timeout: 10 * time.Millisecond,
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e := echo.New()
+ c := e.NewContext(req, rec)
+
+ rec.Code = 1 // we want to be sure that even 200 will not be sent
+ err := m(func(c *echo.Context) error {
+ // this error must not be written to the client response. Middlewares upstream of timeout middleware must be able
+ // to handle returned error and this can be done only then handler has not yet committed (written status code)
+ // the response.
+ return echo.NewHTTPError(http.StatusTeapot, "err")
+ })(c)
+
+ assert.Error(t, err)
+ assert.EqualError(t, err, "code=418, message=err")
+ assert.Equal(t, 1, rec.Code)
+ assert.Equal(t, "", rec.Body.String())
+}
+
+func TestContextTimeoutSuccessfulRequest(t *testing.T) {
+ t.Parallel()
+ m := ContextTimeoutWithConfig(ContextTimeoutConfig{
+ // Timeout has to be defined or the whole flow for timeout middleware will be skipped
+ Timeout: 10 * time.Millisecond,
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e := echo.New()
+ c := e.NewContext(req, rec)
+
+ err := m(func(c *echo.Context) error {
+ return c.JSON(http.StatusCreated, map[string]string{"data": "ok"})
+ })(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusCreated, rec.Code)
+ assert.Equal(t, "{\"data\":\"ok\"}\n", rec.Body.String())
+}
+
+func TestContextTimeoutTestRequestClone(t *testing.T) {
+ t.Parallel()
+ req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode()))
+ req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"})
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ rec := httptest.NewRecorder()
+
+ m := ContextTimeoutWithConfig(ContextTimeoutConfig{
+ // Timeout has to be defined or the whole flow for timeout middleware will be skipped
+ Timeout: 1 * time.Second,
+ })
+
+ e := echo.New()
+ c := e.NewContext(req, rec)
+
+ err := m(func(c *echo.Context) error {
+ // Cookie test
+ cookie, err := c.Request().Cookie("cookie")
+ if assert.NoError(t, err) {
+ assert.EqualValues(t, "cookie", cookie.Name)
+ assert.EqualValues(t, "value", cookie.Value)
+ }
+
+ // Form values
+ if assert.NoError(t, c.Request().ParseForm()) {
+ assert.EqualValues(t, "value", c.Request().FormValue("form"))
+ }
+
+ // Query string
+ assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0])
+ return nil
+ })(c)
+
+ assert.NoError(t, err)
+}
+
+func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) {
+ t.Parallel()
+
+ timeout := 10 * time.Millisecond
+ m := ContextTimeoutWithConfig(ContextTimeoutConfig{
+ Timeout: timeout,
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e := echo.New()
+ c := e.NewContext(req, rec)
+
+ err := m(func(c *echo.Context) error {
+ if err := sleepWithContext(c.Request().Context(), time.Duration(80*time.Millisecond)); err != nil {
+ return err
+ }
+ return c.String(http.StatusOK, "Hello, World!")
+ })(c)
+
+ assert.Error(t, err)
+ if assert.IsType(t, &echo.HTTPError{}, err) {
+ assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
+ assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message)
+ }
+}
+
+func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) {
+ t.Parallel()
+
+ timeoutErrorHandler := func(c *echo.Context, err error) error {
+ if err != nil {
+ if errors.Is(err, context.DeadlineExceeded) {
+ return &echo.HTTPError{
+ Code: http.StatusServiceUnavailable,
+ Message: "Timeout! change me",
+ }
+ }
+ return err
+ }
+ return nil
+ }
+
+ timeout := 50 * time.Millisecond
+ m := ContextTimeoutWithConfig(ContextTimeoutConfig{
+ Timeout: timeout,
+ ErrorHandler: timeoutErrorHandler,
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e := echo.New()
+ c := e.NewContext(req, rec)
+
+ err := m(func(c *echo.Context) error {
+ // NOTE: Very short periods are not reliable for tests due to Go routine scheduling and the unpredictable order
+ // for 1) request and 2) time goroutine. For most OS this works as expected, but MacOS seems most flaky.
+
+ if err := sleepWithContext(c.Request().Context(), 100*time.Millisecond); err != nil {
+ return err
+ }
+
+ // The Request Context should have a Deadline set by http.ContextTimeoutHandler
+ if _, ok := c.Request().Context().Deadline(); !ok {
+ assert.Fail(t, "No timeout set on Request Context")
+ }
+ return c.String(http.StatusOK, "Hello, World!")
+ })(c)
+
+ assert.IsType(t, &echo.HTTPError{}, err)
+ assert.Error(t, err)
+ assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code)
+ assert.Equal(t, "Timeout! change me", err.(*echo.HTTPError).Message)
+}
+
+func sleepWithContext(ctx context.Context, d time.Duration) error {
+ timer := time.NewTimer(d)
+
+ defer func() {
+ _ = timer.Stop()
+ }()
+
+ select {
+ case <-ctx.Done():
+ return context.DeadlineExceeded
+ case <-timer.C:
+ return nil
+ }
+}
diff --git a/middleware/cors.go b/middleware/cors.go
new file mode 100644
index 000000000..96ed16985
--- /dev/null
+++ b/middleware/cors.go
@@ -0,0 +1,300 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "strconv"
+ "strings"
+
+ "github.com/labstack/echo/v5"
+)
+
+// CORSConfig defines the config for CORS middleware.
+type CORSConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // AllowOrigins determines the value of the Access-Control-Allow-Origin
+ // response header. This header defines a list of origins that may access the
+ // resource.
+ //
+ // Origin consist of following parts: `scheme + "://" + host + optional ":" + port`
+ // Wildcard can be used, but has to be set explicitly []string{"*"}
+ // Example: `https://example.com`, `http://example.com:8080`, `*`
+ //
+ // Security: use extreme caution when handling the origin, and carefully
+ // validate any logic. Remember that attackers may register hostile domain names.
+ // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
+ // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
+ //
+ // Mandatory.
+ AllowOrigins []string
+
+ // UnsafeAllowOriginFunc is an optional custom function to validate the origin. It takes the
+ // origin as an argument and returns
+ // - string, allowed origin
+ // - bool, true if allowed or false otherwise.
+ // - error, if an error is returned, it is returned immediately by the handler.
+ // If this option is set, AllowOrigins is ignored.
+ //
+ // Security: use extreme caution when handling the origin, and carefully
+ // validate any logic. Remember that attackers may register hostile (sub)domain names.
+ // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
+ //
+ // Sub-domain checks example:
+ // UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) {
+ // if strings.HasSuffix(origin, ".example.com") {
+ // return origin, true, nil
+ // }
+ // return "", false, nil
+ // },
+ //
+ // Optional.
+ UnsafeAllowOriginFunc func(c *echo.Context, origin string) (allowedOrigin string, allowed bool, err error)
+
+ // AllowMethods determines the value of the Access-Control-Allow-Methods
+ // response header. This header specified the list of methods allowed when
+ // accessing the resource. This is used in response to a preflight request.
+ //
+ // Optional. Default value DefaultCORSConfig.AllowMethods.
+ // If `allowMethods` is left empty, this middleware will fill for preflight
+ // request `Access-Control-Allow-Methods` header value
+ // from `Allow` header that echo.Router set into context.
+ //
+ // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
+ AllowMethods []string
+
+ // AllowHeaders determines the value of the Access-Control-Allow-Headers
+ // response header. This header is used in response to a preflight request to
+ // indicate which HTTP headers can be used when making the actual request.
+ //
+ // Optional. Defaults to empty list. No domains allowed for CORS.
+ //
+ // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
+ AllowHeaders []string
+
+ // AllowCredentials determines the value of the
+ // Access-Control-Allow-Credentials response header. This header indicates
+ // whether or not the response to the request can be exposed when the
+ // credentials mode (Request.credentials) is true. When used as part of a
+ // response to a preflight request, this indicates whether or not the actual
+ // request can be made using credentials. See also
+ // [MDN: Access-Control-Allow-Credentials].
+ //
+ // Optional. Default value false, in which case the header is not set.
+ //
+ // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`.
+ // See "Exploiting CORS misconfigurations for Bitcoins and bounties",
+ // https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
+ //
+ // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
+ AllowCredentials bool
+
+ // ExposeHeaders determines the value of Access-Control-Expose-Headers, which
+ // defines a list of headers that clients are allowed to access.
+ //
+ // Optional. Default value []string{}, in which case the header is not set.
+ //
+ // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header
+ ExposeHeaders []string
+
+ // MaxAge determines the value of the Access-Control-Max-Age response header.
+ // This header indicates how long (in seconds) the results of a preflight
+ // request can be cached.
+ // The header is set only if MaxAge != 0, negative value sends "0" which instructs browsers not to cache that response.
+ //
+ // Optional. Default value 0 - meaning header is not sent.
+ //
+ // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
+ MaxAge int
+}
+
+// CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
+// See also [MDN: Cross-Origin Resource Sharing (CORS)].
+//
+// Origin consist of following parts: `scheme + "://" + host + optional ":" + port`
+// Wildcard `*` can be used, but has to be set explicitly.
+// Example: `https://example.com`, `http://example.com:8080`, `*`
+//
+// Security: Poorly configured CORS can compromise security because it allows
+// relaxation of the browser's Same-Origin policy. See [Exploiting CORS
+// misconfigurations for Bitcoins and bounties] and [Portswigger: Cross-origin
+// resource sharing (CORS)] for more details.
+//
+// [MDN: Cross-Origin Resource Sharing (CORS)]: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
+// [Exploiting CORS misconfigurations for Bitcoins and bounties]: https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
+// [Portswigger: Cross-origin resource sharing (CORS)]: https://portswigger.net/web-security/cors
+func CORS(allowOrigins ...string) echo.MiddlewareFunc {
+ c := CORSConfig{
+ AllowOrigins: allowOrigins,
+ }
+ return CORSWithConfig(c)
+}
+
+// CORSWithConfig returns a CORS middleware with config or panics on invalid configuration.
+// See: [CORS].
+func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration
+func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ // Defaults
+ if config.Skipper == nil {
+ config.Skipper = DefaultSkipper
+ }
+ hasCustomAllowMethods := true
+ if len(config.AllowMethods) == 0 {
+ hasCustomAllowMethods = false
+ config.AllowMethods = []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}
+ }
+
+ allowMethods := strings.Join(config.AllowMethods, ",")
+ allowHeaders := strings.Join(config.AllowHeaders, ",")
+ exposeHeaders := strings.Join(config.ExposeHeaders, ",")
+
+ maxAge := "0"
+ if config.MaxAge > 0 {
+ maxAge = strconv.Itoa(config.MaxAge)
+ }
+
+ allowOriginFunc := config.UnsafeAllowOriginFunc
+ if config.UnsafeAllowOriginFunc == nil {
+ if len(config.AllowOrigins) == 0 {
+ return nil, errors.New("at least one AllowOrigins is required or UnsafeAllowOriginFunc must be provided")
+ }
+ allowOriginFunc = config.defaultAllowOriginFunc
+ for _, origin := range config.AllowOrigins {
+ if origin == "*" {
+ if config.AllowCredentials {
+ return nil, fmt.Errorf("* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc")
+ }
+ allowOriginFunc = config.starAllowOriginFunc
+ break
+ }
+ if err := validateOrigin(origin, "allow origin"); err != nil {
+ return nil, err
+ }
+ }
+ config.AllowOrigins = append([]string(nil), config.AllowOrigins...)
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ req := c.Request()
+ res := c.Response()
+ origin := req.Header.Get(echo.HeaderOrigin)
+
+ res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
+
+ // Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
+ // Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
+ // For simplicity we just consider method type and later `Origin` header.
+ preflight := req.Method == http.MethodOptions
+
+ // Although router adds special handler in case of OPTIONS method we avoid calling next for OPTIONS in this middleware
+ // as CORS requests do not have cookies / authentication headers by default, so we could get stuck in auth
+ // middlewares by calling next(c).
+ // But we still want to send `Allow` header as response in case of Non-CORS OPTIONS request as router default
+ // handler does.
+ routerAllowMethods := ""
+ if preflight {
+ tmpAllowMethods, ok := c.Get(echo.ContextKeyHeaderAllow).(string)
+ if ok && tmpAllowMethods != "" {
+ routerAllowMethods = tmpAllowMethods
+ c.Response().Header().Set(echo.HeaderAllow, routerAllowMethods)
+ }
+ }
+
+ // No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
+ if origin == "" {
+ if preflight { // req.Method=OPTIONS
+ return c.NoContent(http.StatusNoContent)
+ }
+ return next(c) // let non-browser calls through
+ }
+
+ allowedOrigin, allowed, err := allowOriginFunc(c, origin)
+ if err != nil {
+ return err
+ }
+ if !allowed {
+ // Origin existed and was NOT allowed
+ if preflight {
+ // From: https://github.com/labstack/echo/issues/2767
+ // If the request's origin isn't allowed by the CORS configuration,
+ // the middleware should simply omit the relevant CORS headers from the response
+ // and let the browser fail the CORS check (if any).
+ return c.NoContent(http.StatusNoContent)
+ }
+ // From: https://github.com/labstack/echo/issues/2767
+ // no CORS middleware should block non-preflight requests;
+ // such requests should be let through. One reason is that not all requests that
+ // carry an Origin header participate in the CORS protocol.
+ return next(c)
+ }
+
+ // Origin existed and was allowed
+
+ res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
+ if config.AllowCredentials {
+ res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
+ }
+
+ // Simple request will be let though
+ if !preflight {
+ if exposeHeaders != "" {
+ res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
+ }
+ return next(c)
+ }
+ // Below code is for Preflight (OPTIONS) request
+ //
+ // Preflight will end with c.NoContent(http.StatusNoContent) as we do not know if
+ // at the end of handler chain is actual OPTIONS route or 404/405 route which
+ // response code will confuse browsers
+ res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
+ res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
+
+ if !hasCustomAllowMethods && routerAllowMethods != "" {
+ res.Header().Set(echo.HeaderAccessControlAllowMethods, routerAllowMethods)
+ } else {
+ res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods)
+ }
+
+ if allowHeaders != "" {
+ res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders)
+ } else {
+ h := req.Header.Get(echo.HeaderAccessControlRequestHeaders)
+ if h != "" {
+ res.Header().Set(echo.HeaderAccessControlAllowHeaders, h)
+ }
+ }
+ if config.MaxAge != 0 {
+ res.Header().Set(echo.HeaderAccessControlMaxAge, maxAge)
+ }
+ return c.NoContent(http.StatusNoContent)
+ }
+ }, nil
+}
+
+func (config CORSConfig) starAllowOriginFunc(c *echo.Context, origin string) (string, bool, error) {
+ return "*", true, nil
+}
+
+func (config CORSConfig) defaultAllowOriginFunc(c *echo.Context, origin string) (string, bool, error) {
+ for _, allowedOrigin := range config.AllowOrigins {
+ if strings.EqualFold(allowedOrigin, origin) {
+ return allowedOrigin, true, nil
+ }
+ }
+ return "", false, nil
+}
diff --git a/middleware/cors_test.go b/middleware/cors_test.go
new file mode 100644
index 000000000..5de4ca063
--- /dev/null
+++ b/middleware/cors_test.go
@@ -0,0 +1,628 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "cmp"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestCORS(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodOptions, "/", nil) // Preflight request
+ req.Header.Set(echo.HeaderOrigin, "http://example.com")
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ mw := CORS("*")
+ handler := mw(func(c *echo.Context) error {
+ return nil
+ })
+
+ err := handler(c)
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusNoContent, rec.Code)
+ assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+}
+
+func TestCORSConfig(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenConfig *CORSConfig
+ whenMethod string
+ whenHeaders map[string]string
+ expectHeaders map[string]string
+ notExpectHeaders map[string]string
+ expectErr string
+ }{
+ {
+ name: "ok, wildcard origin",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"*"},
+ },
+ whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"},
+ expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "*"},
+ },
+ {
+ name: "ok, wildcard AllowedOrigin with no Origin header in request",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"*"},
+ },
+ notExpectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: ""},
+ },
+ {
+ name: "ok, specific AllowOrigins and AllowCredentials",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"http://localhost", "http://localhost:8080"},
+ AllowCredentials: true,
+ MaxAge: 3600,
+ },
+ whenHeaders: map[string]string{echo.HeaderOrigin: "http://localhost"},
+ expectHeaders: map[string]string{
+ echo.HeaderAccessControlAllowOrigin: "http://localhost",
+ echo.HeaderAccessControlAllowCredentials: "true",
+ },
+ },
+ {
+ name: "ok, preflight request with matching origin for `AllowOrigins`",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"http://localhost"},
+ AllowCredentials: true,
+ MaxAge: 3600,
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{
+ echo.HeaderOrigin: "http://localhost",
+ echo.HeaderContentType: echo.MIMEApplicationJSON,
+ },
+ expectHeaders: map[string]string{
+ echo.HeaderAccessControlAllowOrigin: "http://localhost",
+ echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
+ echo.HeaderAccessControlAllowCredentials: "true",
+ echo.HeaderAccessControlMaxAge: "3600",
+ },
+ },
+ {
+ name: "ok, preflight request when `Access-Control-Max-Age` is set",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"http://localhost"},
+ AllowCredentials: true,
+ MaxAge: 1,
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{
+ echo.HeaderOrigin: "http://localhost",
+ echo.HeaderContentType: echo.MIMEApplicationJSON,
+ },
+ expectHeaders: map[string]string{
+ echo.HeaderAccessControlMaxAge: "1",
+ },
+ },
+ {
+ name: "ok, preflight request when `Access-Control-Max-Age` is set to 0 - not to cache response",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"http://localhost"},
+ AllowCredentials: true,
+ MaxAge: -1, // forces `Access-Control-Max-Age: 0`
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{
+ echo.HeaderOrigin: "http://localhost",
+ echo.HeaderContentType: echo.MIMEApplicationJSON,
+ },
+ expectHeaders: map[string]string{
+ echo.HeaderAccessControlMaxAge: "0",
+ },
+ },
+ {
+ name: "ok, CORS check are skipped",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"http://localhost"},
+ AllowCredentials: true,
+ Skipper: func(c *echo.Context) bool {
+ return true
+ },
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{
+ echo.HeaderOrigin: "http://localhost",
+ echo.HeaderContentType: echo.MIMEApplicationJSON,
+ },
+ notExpectHeaders: map[string]string{
+ echo.HeaderAccessControlAllowOrigin: "localhost",
+ echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
+ echo.HeaderAccessControlAllowCredentials: "true",
+ echo.HeaderAccessControlMaxAge: "3600",
+ },
+ },
+ {
+ name: "nok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"*"},
+ AllowCredentials: true,
+ MaxAge: 3600,
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{
+ echo.HeaderOrigin: "localhost",
+ echo.HeaderContentType: echo.MIMEApplicationJSON,
+ },
+ expectErr: `* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc`,
+ },
+ {
+ name: "nok, preflight request with invalid `AllowOrigins` value",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"http://server", "missing-scheme"},
+ },
+ expectErr: `allow origin is missing scheme or host: missing-scheme`,
+ },
+ {
+ name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` false",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"*"},
+ AllowCredentials: false, // important for this testcase
+ MaxAge: 3600,
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{
+ echo.HeaderOrigin: "localhost",
+ echo.HeaderContentType: echo.MIMEApplicationJSON,
+ },
+ expectHeaders: map[string]string{
+ echo.HeaderAccessControlAllowOrigin: "*",
+ echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
+ echo.HeaderAccessControlMaxAge: "3600",
+ },
+ notExpectHeaders: map[string]string{
+ echo.HeaderAccessControlAllowCredentials: "",
+ },
+ },
+ {
+ name: "ok, INSECURE preflight request with wildcard `AllowOrigins` and `AllowCredentials` true",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"*"},
+ AllowCredentials: true,
+ MaxAge: 3600,
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{
+ echo.HeaderOrigin: "localhost",
+ echo.HeaderContentType: echo.MIMEApplicationJSON,
+ },
+ expectErr: `* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc`,
+ },
+ {
+ name: "ok, preflight request with Access-Control-Request-Headers",
+ givenConfig: &CORSConfig{
+ AllowOrigins: []string{"*"},
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{
+ echo.HeaderOrigin: "localhost",
+ echo.HeaderContentType: echo.MIMEApplicationJSON,
+ echo.HeaderAccessControlRequestHeaders: "Special-Request-Header",
+ },
+ expectHeaders: map[string]string{
+ echo.HeaderAccessControlAllowOrigin: "*",
+ echo.HeaderAccessControlAllowHeaders: "Special-Request-Header",
+ echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
+ },
+ },
+ {
+ name: "ok, preflight request with `AllowOrigins` which allow all subdomains aaa with *",
+ givenConfig: &CORSConfig{
+ UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (allowedOrigin string, allowed bool, err error) {
+ if strings.HasSuffix(origin, ".example.com") {
+ allowed = true
+ }
+ return origin, allowed, nil
+ },
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{echo.HeaderOrigin: "http://aaa.example.com"},
+ expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://aaa.example.com"},
+ },
+ {
+ name: "ok, preflight request with `AllowOrigins` which allow all subdomains bbb with *",
+ givenConfig: &CORSConfig{
+ UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) {
+ if strings.HasSuffix(origin, ".example.com") {
+ return origin, true, nil
+ }
+ return "", false, nil
+ },
+ },
+ whenMethod: http.MethodOptions,
+ whenHeaders: map[string]string{echo.HeaderOrigin: "http://bbb.example.com"},
+ expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://bbb.example.com"},
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ var mw echo.MiddlewareFunc
+ var err error
+ if tc.givenConfig != nil {
+ mw, err = tc.givenConfig.ToMiddleware()
+ } else {
+ mw, err = CORSConfig{}.ToMiddleware()
+ }
+ if err != nil {
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ return
+ }
+ t.Fatal(err)
+ }
+
+ h := mw(func(c *echo.Context) error {
+ return nil
+ })
+
+ method := cmp.Or(tc.whenMethod, http.MethodGet)
+ req := httptest.NewRequest(method, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ for k, v := range tc.whenHeaders {
+ req.Header.Set(k, v)
+ }
+
+ err = h(c)
+
+ assert.NoError(t, err)
+ header := rec.Header()
+ for k, v := range tc.expectHeaders {
+ assert.Equal(t, v, header.Get(k), "header: `%v` should be `%v`", k, v)
+ }
+ for k, v := range tc.notExpectHeaders {
+ if v == "" {
+ assert.Len(t, header.Values(k), 0, "header: `%v` should not be set", k)
+ } else {
+ assert.NotEqual(t, v, header.Get(k), "header: `%v` should not be `%v`", k, v)
+ }
+ }
+ })
+ }
+}
+
+func Test_allowOriginScheme(t *testing.T) {
+ tests := []struct {
+ domain, pattern string
+ expected bool
+ }{
+ {
+ domain: "http://example.com",
+ pattern: "http://example.com",
+ expected: true,
+ },
+ {
+ domain: "https://example.com",
+ pattern: "https://example.com",
+ expected: true,
+ },
+ {
+ domain: "http://example.com",
+ pattern: "https://example.com",
+ expected: false,
+ },
+ {
+ domain: "https://example.com",
+ pattern: "http://example.com",
+ expected: false,
+ },
+ }
+
+ e := echo.New()
+ for _, tt := range tests {
+ req := httptest.NewRequest(http.MethodOptions, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ req.Header.Set(echo.HeaderOrigin, tt.domain)
+ cors := CORSWithConfig(CORSConfig{
+ AllowOrigins: []string{tt.pattern},
+ })
+ h := cors(func(c *echo.Context) error { return echo.ErrNotFound })
+ h(c)
+
+ if tt.expected {
+ assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ } else {
+ assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
+ }
+ }
+}
+
+func TestCORSWithConfig_AllowMethods(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenAllowOrigins []string
+ givenAllowMethods []string
+ whenAllowContextKey string
+ whenOrigin string
+ expectAllow string
+ expectAccessControlAllowMethods string
+ }{
+ {
+ name: "custom AllowMethods, preflight, no origin, sets only allow header from context key",
+ givenAllowOrigins: []string{"*"},
+ givenAllowMethods: []string{http.MethodGet, http.MethodHead},
+ whenAllowContextKey: "OPTIONS, GET",
+ whenOrigin: "",
+ expectAllow: "OPTIONS, GET",
+ },
+ {
+ name: "default AllowMethods, preflight, no origin, no allow header in context key and in response",
+ givenAllowOrigins: []string{"*"},
+ givenAllowMethods: nil,
+ whenAllowContextKey: "",
+ whenOrigin: "",
+ expectAllow: "",
+ },
+ {
+ name: "custom AllowMethods, preflight, existing origin, sets both headers different values",
+ givenAllowOrigins: []string{"*"},
+ givenAllowMethods: []string{http.MethodGet, http.MethodHead},
+ whenAllowContextKey: "OPTIONS, GET",
+ whenOrigin: "http://google.com",
+ expectAllow: "OPTIONS, GET",
+ expectAccessControlAllowMethods: "GET,HEAD",
+ },
+ {
+ name: "default AllowMethods, preflight, existing origin, sets both headers",
+ givenAllowOrigins: []string{"*"},
+ givenAllowMethods: nil,
+ whenAllowContextKey: "OPTIONS, GET",
+ whenOrigin: "http://google.com",
+ expectAllow: "OPTIONS, GET",
+ expectAccessControlAllowMethods: "OPTIONS, GET",
+ },
+ {
+ name: "default AllowMethods, preflight, existing origin, no allows, sets only CORS allow methods",
+ givenAllowOrigins: []string{"*"},
+ givenAllowMethods: nil,
+ whenAllowContextKey: "",
+ whenOrigin: "http://google.com",
+ expectAllow: "",
+ expectAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+ e.GET("/test", func(c *echo.Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+
+ cors := CORSWithConfig(CORSConfig{
+ AllowOrigins: tc.givenAllowOrigins,
+ AllowMethods: tc.givenAllowMethods,
+ })
+
+ req := httptest.NewRequest(http.MethodOptions, "/test", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ req.Header.Set(echo.HeaderOrigin, tc.whenOrigin)
+ if tc.whenAllowContextKey != "" {
+ c.Set(echo.ContextKeyHeaderAllow, tc.whenAllowContextKey)
+ }
+
+ h := cors(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+ h(c)
+
+ assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow))
+ assert.Equal(t, tc.expectAccessControlAllowMethods, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
+ })
+ }
+}
+
+func TestCorsHeaders(t *testing.T) {
+ tests := []struct {
+ name string
+ originDomain string
+ method string
+ allowedOrigin string
+ expected bool
+ expectStatus int
+ expectAllowHeader string
+ }{
+ {
+ name: "non-preflight request, allow any origin, missing origin header = no CORS logic done",
+ originDomain: "",
+ allowedOrigin: "*",
+ method: http.MethodGet,
+ expected: false,
+ expectStatus: http.StatusOK,
+ },
+ {
+ name: "non-preflight request, allow any origin, specific origin domain",
+ originDomain: "http://example.com",
+ allowedOrigin: "*",
+ method: http.MethodGet,
+ expected: true,
+ expectStatus: http.StatusOK,
+ },
+ {
+ name: "non-preflight request, allow specific origin, missing origin header = no CORS logic done",
+ originDomain: "", // Request does not have Origin header
+ allowedOrigin: "http://example.com",
+ method: http.MethodGet,
+ expected: false,
+ expectStatus: http.StatusOK,
+ },
+ {
+ name: "non-preflight request, allow specific origin, different origin header = CORS logic failure",
+ originDomain: "http://bar.com",
+ allowedOrigin: "http://example.com",
+ method: http.MethodGet,
+ expected: false,
+ expectStatus: http.StatusOK,
+ },
+ {
+ name: "non-preflight request, allow specific origin, matching origin header = CORS logic done",
+ originDomain: "http://example.com",
+ allowedOrigin: "http://example.com",
+ method: http.MethodGet,
+ expected: true,
+ expectStatus: http.StatusOK,
+ },
+ {
+ name: "preflight, allow any origin, missing origin header = no CORS logic done",
+ originDomain: "", // Request does not have Origin header
+ allowedOrigin: "*",
+ method: http.MethodOptions,
+ expected: false,
+ expectStatus: http.StatusNoContent,
+ expectAllowHeader: "OPTIONS, GET, POST",
+ },
+ {
+ name: "preflight, allow any origin, existing origin header = CORS logic done",
+ originDomain: "http://example.com",
+ allowedOrigin: "*",
+ method: http.MethodOptions,
+ expected: true,
+ expectStatus: http.StatusNoContent,
+ expectAllowHeader: "OPTIONS, GET, POST",
+ },
+ {
+ name: "preflight, allow any origin, missing origin header = no CORS logic done",
+ originDomain: "", // Request does not have Origin header
+ allowedOrigin: "http://example.com",
+ method: http.MethodOptions,
+ expected: false,
+ expectStatus: http.StatusNoContent,
+ expectAllowHeader: "OPTIONS, GET, POST",
+ },
+ {
+ name: "preflight, allow specific origin, different origin header = no CORS logic done",
+ originDomain: "http://bar.com",
+ allowedOrigin: "http://example.com",
+ method: http.MethodOptions,
+ expected: false,
+ expectStatus: http.StatusNoContent,
+ expectAllowHeader: "OPTIONS, GET, POST",
+ },
+ {
+ name: "preflight, allow specific origin, matching origin header = CORS logic done",
+ originDomain: "http://example.com",
+ allowedOrigin: "http://example.com",
+ method: http.MethodOptions,
+ expected: true,
+ expectStatus: http.StatusNoContent,
+ expectAllowHeader: "OPTIONS, GET, POST",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ e.Use(CORSWithConfig(CORSConfig{
+ AllowOrigins: []string{tc.allowedOrigin},
+ //AllowCredentials: true,
+ //MaxAge: 3600,
+ }))
+
+ e.GET("/", func(c *echo.Context) error {
+ return c.String(http.StatusOK, "OK")
+ })
+ e.POST("/", func(c *echo.Context) error {
+ return c.String(http.StatusCreated, "OK")
+ })
+
+ req := httptest.NewRequest(tc.method, "/", nil)
+ rec := httptest.NewRecorder()
+
+ if tc.originDomain != "" {
+ req.Header.Set(echo.HeaderOrigin, tc.originDomain)
+ }
+
+ // we run through whole Echo handler chain to see how CORS works with Router OPTIONS handler
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
+ assert.Equal(t, tc.expectAllowHeader, rec.Header().Get(echo.HeaderAllow))
+ assert.Equal(t, tc.expectStatus, rec.Code)
+
+ expectedAllowOrigin := ""
+ if tc.allowedOrigin == "*" {
+ expectedAllowOrigin = "*"
+ } else {
+ expectedAllowOrigin = tc.originDomain
+ }
+ switch {
+ case tc.expected && tc.method == http.MethodOptions:
+ assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
+ assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+
+ assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
+
+ case tc.expected && tc.method == http.MethodGet:
+ assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
+ default:
+ assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
+ assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
+ }
+ })
+
+ }
+}
+
+func Test_allowOriginFunc(t *testing.T) {
+ returnTrue := func(c *echo.Context, origin string) (string, bool, error) {
+ return origin, true, nil
+ }
+ returnFalse := func(c *echo.Context, origin string) (string, bool, error) {
+ return origin, false, nil
+ }
+ returnError := func(c *echo.Context, origin string) (string, bool, error) {
+ return origin, true, errors.New("this is a test error")
+ }
+
+ allowOriginFuncs := []func(c *echo.Context, origin string) (string, bool, error){
+ returnTrue,
+ returnFalse,
+ returnError,
+ }
+
+ const origin = "http://example.com"
+
+ e := echo.New()
+ for _, allowOriginFunc := range allowOriginFuncs {
+ req := httptest.NewRequest(http.MethodOptions, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ req.Header.Set(echo.HeaderOrigin, origin)
+ cors, err := CORSConfig{UnsafeAllowOriginFunc: allowOriginFunc}.ToMiddleware()
+ assert.NoError(t, err)
+
+ h := cors(func(c *echo.Context) error { return echo.ErrNotFound })
+ err = h(c)
+
+ allowedOrigin, allowed, expectedErr := allowOriginFunc(c, origin)
+ if expectedErr != nil {
+ assert.Equal(t, expectedErr, err)
+ assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ continue
+ }
+
+ if allowed {
+ assert.Equal(t, allowedOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ } else {
+ assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
+ }
+ }
+}
diff --git a/middleware/csrf.go b/middleware/csrf.go
new file mode 100644
index 000000000..33757b760
--- /dev/null
+++ b/middleware/csrf.go
@@ -0,0 +1,293 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "crypto/subtle"
+ "net/http"
+ "slices"
+ "strings"
+ "time"
+
+ "github.com/labstack/echo/v5"
+)
+
+// CSRFConfig defines the config for CSRF middleware.
+type CSRFConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+ // TrustedOrigin permits any request with `Sec-Fetch-Site` header whose `Origin` header
+ // exactly matches the specified value.
+ // Values should be formated as Origin header "scheme://host[:port]".
+ //
+ // See [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin
+ // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers
+ TrustedOrigins []string
+
+ // AllowSecFetchSameSite allows custom behaviour for `Sec-Fetch-Site` requests that are about to
+ // fail with CRSF error, to be allowed or replaced with custom error.
+ // This function applies to `Sec-Fetch-Site` values:
+ // - `same-site` same registrable domain (subdomain and/or different port)
+ // - `cross-site` request originates from different site
+ // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers
+ AllowSecFetchSiteFunc func(c *echo.Context) (bool, error)
+
+ // TokenLength is the length of the generated token.
+ TokenLength uint8
+ // Optional. Default value 32.
+
+ // TokenLookup is a string in the form of ":" or ":,:" that is used
+ // to extract token from the request.
+ // Optional. Default value "header:X-CSRF-Token".
+ // Possible values:
+ // - "header:" or "header::"
+ // - "query:"
+ // - "form:"
+ // Multiple sources example:
+ // - "header:X-CSRF-Token,query:csrf"
+ TokenLookup string `yaml:"token_lookup"`
+
+ // Generator defines a function to generate token.
+ // Optional. Defaults tp randomString(TokenLength).
+ Generator func() string
+
+ // Context key to store generated CSRF token into context.
+ // Optional. Default value "csrf".
+ ContextKey string
+
+ // Name of the CSRF cookie. This cookie will store CSRF token.
+ // Optional. Default value "csrf".
+ CookieName string
+
+ // Domain of the CSRF cookie.
+ // Optional. Default value none.
+ CookieDomain string
+
+ // Path of the CSRF cookie.
+ // Optional. Default value none.
+ CookiePath string
+
+ // Max age (in seconds) of the CSRF cookie.
+ // Optional. Default value 86400 (24hr).
+ CookieMaxAge int
+
+ // Indicates if CSRF cookie is secure.
+ // Optional. Default value false.
+ CookieSecure bool
+
+ // Indicates if CSRF cookie is HTTP only.
+ // Optional. Default value false.
+ CookieHTTPOnly bool
+
+ // Indicates SameSite mode of the CSRF cookie.
+ // Optional. Default value SameSiteDefaultMode.
+ CookieSameSite http.SameSite
+
+ // ErrorHandler defines a function which is executed for returning custom errors.
+ ErrorHandler func(c *echo.Context, err error) error
+}
+
+// ErrCSRFInvalid is returned when CSRF check fails
+var ErrCSRFInvalid = &echo.HTTPError{Code: http.StatusForbidden, Message: "invalid csrf token"}
+
+// DefaultCSRFConfig is the default CSRF middleware config.
+var DefaultCSRFConfig = CSRFConfig{
+ Skipper: DefaultSkipper,
+ TokenLength: 32,
+ TokenLookup: "header:" + echo.HeaderXCSRFToken,
+ ContextKey: "csrf",
+ CookieName: "_csrf",
+ CookieMaxAge: 86400,
+ CookieSameSite: http.SameSiteDefaultMode,
+}
+
+// CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
+// See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
+func CSRF() echo.MiddlewareFunc {
+ return CSRFWithConfig(DefaultCSRFConfig)
+}
+
+// CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration.
+func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration
+func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ // Defaults
+ if config.Skipper == nil {
+ config.Skipper = DefaultCSRFConfig.Skipper
+ }
+ if config.TokenLength == 0 {
+ config.TokenLength = DefaultCSRFConfig.TokenLength
+ }
+ if config.Generator == nil {
+ config.Generator = createRandomStringGenerator(config.TokenLength)
+ }
+ if config.TokenLookup == "" {
+ config.TokenLookup = DefaultCSRFConfig.TokenLookup
+ }
+ if config.ContextKey == "" {
+ config.ContextKey = DefaultCSRFConfig.ContextKey
+ }
+ if config.CookieName == "" {
+ config.CookieName = DefaultCSRFConfig.CookieName
+ }
+ if config.CookieMaxAge == 0 {
+ config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge
+ }
+ if config.CookieSameSite == http.SameSiteNoneMode {
+ config.CookieSecure = true
+ }
+ if len(config.TrustedOrigins) > 0 {
+ if err := validateOrigins(config.TrustedOrigins, "trusted origin"); err != nil {
+ return nil, err
+ }
+ config.TrustedOrigins = append([]string(nil), config.TrustedOrigins...)
+ }
+
+ extractors, cErr := createExtractors(config.TokenLookup, 1)
+ if cErr != nil {
+ return nil, cErr
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ // use the `Sec-Fetch-Site` header as part of a modern approach to CSRF protection
+ allow, err := config.checkSecFetchSiteRequest(c)
+ if err != nil {
+ return err
+ }
+ if allow {
+ return next(c)
+ }
+
+ // Fallback to legacy token based CSRF protection
+
+ token := ""
+ if k, err := c.Cookie(config.CookieName); err != nil {
+ token = config.Generator() // Generate token
+ } else {
+ token = k.Value // Reuse token
+ }
+
+ switch c.Request().Method {
+ case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
+ default:
+ // Validate token only for requests which are not defined as 'safe' by RFC7231
+ var lastExtractorErr error
+ var lastTokenErr error
+ outer:
+ for _, extractor := range extractors {
+ clientTokens, _, err := extractor(c)
+ if err != nil {
+ lastExtractorErr = err
+ continue
+ }
+
+ for _, clientToken := range clientTokens {
+ if validateCSRFToken(token, clientToken) {
+ lastTokenErr = nil
+ lastExtractorErr = nil
+ break outer
+ }
+ lastTokenErr = ErrCSRFInvalid
+ }
+ }
+ var finalErr error
+ if lastTokenErr != nil {
+ finalErr = lastTokenErr
+ } else if lastExtractorErr != nil {
+ finalErr = echo.ErrBadRequest.Wrap(lastExtractorErr)
+ }
+ if finalErr != nil {
+ if config.ErrorHandler != nil {
+ return config.ErrorHandler(c, finalErr)
+ }
+ return finalErr
+ }
+ }
+
+ // Set CSRF cookie
+ cookie := new(http.Cookie)
+ cookie.Name = config.CookieName
+ cookie.Value = token
+ if config.CookiePath != "" {
+ cookie.Path = config.CookiePath
+ }
+ if config.CookieDomain != "" {
+ cookie.Domain = config.CookieDomain
+ }
+ if config.CookieSameSite != http.SameSiteDefaultMode {
+ cookie.SameSite = config.CookieSameSite
+ }
+ cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
+ cookie.Secure = config.CookieSecure
+ cookie.HttpOnly = config.CookieHTTPOnly
+ c.SetCookie(cookie)
+
+ // Store token in the context
+ c.Set(config.ContextKey, token)
+
+ // Protect clients from caching the response
+ c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)
+
+ return next(c)
+ }
+ }, nil
+}
+
+func validateCSRFToken(token, clientToken string) bool {
+ return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
+}
+
+var safeMethods = []string{http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace}
+
+func (config CSRFConfig) checkSecFetchSiteRequest(c *echo.Context) (bool, error) {
+ // https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers
+ // Sec-Fetch-Site values are:
+ // - `same-origin` exact origin match - allow always
+ // - `same-site` same registrable domain (subdomain and/or different port) - block, unless explicitly trusted
+ // - `cross-site` request originates from different site - block, unless explicitly trusted
+ // - `none` direct navigation (URL bar, bookmark) - allow always
+ secFetchSite := c.Request().Header.Get(echo.HeaderSecFetchSite)
+ if secFetchSite == "" {
+ return false, nil
+ }
+
+ if len(config.TrustedOrigins) > 0 {
+ // trusted sites ala OAuth callbacks etc. should be let through
+ origin := c.Request().Header.Get(echo.HeaderOrigin)
+ if origin != "" {
+ for _, trustedOrigin := range config.TrustedOrigins {
+ if strings.EqualFold(origin, trustedOrigin) {
+ return true, nil
+ }
+ }
+ }
+ }
+ isSafe := slices.Contains(safeMethods, c.Request().Method)
+ if !isSafe { // for state-changing request check SecFetchSite value
+ isSafe = secFetchSite == "same-origin" || secFetchSite == "none"
+ }
+
+ if isSafe {
+ return true, nil
+ }
+ // we are here when request is state-changing and `cross-site` or `same-site`
+
+ // Note: if you want to allow `same-site` use config.TrustedOrigins or `config.AllowSecFetchSiteFunc`
+ if config.AllowSecFetchSiteFunc != nil {
+ return config.AllowSecFetchSiteFunc(c)
+ }
+
+ if secFetchSite == "same-site" {
+ return false, echo.NewHTTPError(http.StatusForbidden, "same-site request blocked by CSRF")
+ }
+ return false, echo.NewHTTPError(http.StatusForbidden, "cross-site request blocked by CSRF")
+}
diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go
new file mode 100644
index 000000000..ddecc10e3
--- /dev/null
+++ b/middleware/csrf_test.go
@@ -0,0 +1,854 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "cmp"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestCSRF_tokenExtractors(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenTokenLookup string
+ whenCookieName string
+ givenCSRFCookie string
+ givenMethod string
+ givenQueryTokens map[string][]string
+ givenFormTokens map[string][]string
+ givenHeaderTokens map[string][]string
+ expectError string
+ expectToMiddlewareError string
+ }{
+ {
+ name: "ok, multiple token lookups sources, succeeds on last one",
+ whenTokenLookup: "header:X-CSRF-Token,form:csrf",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenHeaderTokens: map[string][]string{
+ echo.HeaderXCSRFToken: {"invalid_token"},
+ },
+ givenFormTokens: map[string][]string{
+ "csrf": {"token"},
+ },
+ },
+ {
+ name: "ok, token from POST form",
+ whenTokenLookup: "form:csrf",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenFormTokens: map[string][]string{
+ "csrf": {"token"},
+ },
+ },
+ {
+ name: "ok, token from POST form, second token passes",
+ whenTokenLookup: "form:csrf",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenFormTokens: map[string][]string{
+ "csrf": {"invalid", "token"},
+ },
+ expectError: "code=403, message=invalid csrf token",
+ },
+ {
+ name: "nok, invalid token from POST form",
+ whenTokenLookup: "form:csrf",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenFormTokens: map[string][]string{
+ "csrf": {"invalid_token"},
+ },
+ expectError: "code=403, message=invalid csrf token",
+ },
+ {
+ name: "nok, missing token from POST form",
+ whenTokenLookup: "form:csrf",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenFormTokens: map[string][]string{},
+ expectError: "code=400, message=Bad Request, err=missing value in the form",
+ },
+ {
+ name: "ok, token from POST header",
+ whenTokenLookup: "", // will use defaults
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenHeaderTokens: map[string][]string{
+ echo.HeaderXCSRFToken: {"token"},
+ },
+ },
+ {
+ name: "nok, token from POST header, tokens limited to 1, second token would pass",
+ whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenHeaderTokens: map[string][]string{
+ echo.HeaderXCSRFToken: {"invalid", "token"},
+ },
+ expectError: "code=403, message=invalid csrf token",
+ },
+ {
+ name: "nok, invalid token from POST header",
+ whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenHeaderTokens: map[string][]string{
+ echo.HeaderXCSRFToken: {"invalid_token"},
+ },
+ expectError: "code=403, message=invalid csrf token",
+ },
+ {
+ name: "nok, missing token from POST header",
+ whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPost,
+ givenHeaderTokens: map[string][]string{},
+ expectError: "code=400, message=Bad Request, err=missing value in request header",
+ },
+ {
+ name: "ok, token from PUT query param",
+ whenTokenLookup: "query:csrf-param",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPut,
+ givenQueryTokens: map[string][]string{
+ "csrf-param": {"token"},
+ },
+ },
+ {
+ name: "nok, token from PUT query form, second token would pass",
+ whenTokenLookup: "query:csrf",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPut,
+ givenQueryTokens: map[string][]string{
+ "csrf": {"invalid", "token"},
+ },
+ expectError: "code=403, message=invalid csrf token",
+ },
+ {
+ name: "nok, invalid token from PUT query form",
+ whenTokenLookup: "query:csrf",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPut,
+ givenQueryTokens: map[string][]string{
+ "csrf": {"invalid_token"},
+ },
+ expectError: "code=403, message=invalid csrf token",
+ },
+ {
+ name: "nok, missing token from PUT query form",
+ whenTokenLookup: "query:csrf",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPut,
+ givenQueryTokens: map[string][]string{},
+ expectError: "code=400, message=Bad Request, err=missing value in the query string",
+ },
+ {
+ name: "nok, invalid TokenLookup",
+ whenTokenLookup: "q",
+ givenCSRFCookie: "token",
+ givenMethod: http.MethodPut,
+ givenQueryTokens: map[string][]string{},
+ expectToMiddlewareError: "extractor source for lookup could not be split into needed parts: q",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ q := make(url.Values)
+ for queryParam, values := range tc.givenQueryTokens {
+ for _, v := range values {
+ q.Add(queryParam, v)
+ }
+ }
+
+ f := make(url.Values)
+ for formKey, values := range tc.givenFormTokens {
+ for _, v := range values {
+ f.Add(formKey, v)
+ }
+ }
+
+ var req *http.Request
+ switch tc.givenMethod {
+ case http.MethodGet:
+ req = httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil)
+ case http.MethodPost, http.MethodPut:
+ req = httptest.NewRequest(http.MethodPost, "/?"+q.Encode(), strings.NewReader(f.Encode()))
+ req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
+ }
+
+ for header, values := range tc.givenHeaderTokens {
+ for _, v := range values {
+ req.Header.Add(header, v)
+ }
+ }
+
+ if tc.givenCSRFCookie != "" {
+ req.Header.Set(echo.HeaderCookie, "_csrf="+tc.givenCSRFCookie)
+ }
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ config := CSRFConfig{
+ TokenLookup: tc.whenTokenLookup,
+ CookieName: tc.whenCookieName,
+ }
+ csrf, err := config.ToMiddleware()
+ if tc.expectToMiddlewareError != "" {
+ assert.EqualError(t, err, tc.expectToMiddlewareError)
+ return
+ } else if err != nil {
+ assert.NoError(t, err)
+ }
+
+ h := csrf(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ err = h(c)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestCSRFWithConfig(t *testing.T) {
+ token := randomString(16)
+
+ var testCases = []struct {
+ name string
+ givenConfig *CSRFConfig
+ whenMethod string
+ whenHeaders map[string]string
+ expectEmptyBody bool
+ expectMWError string
+ expectCookieContains string
+ expectErr string
+ }{
+ {
+ name: "ok, GET",
+ whenMethod: http.MethodGet,
+ expectCookieContains: "_csrf",
+ },
+ {
+ name: "ok, POST valid token",
+ whenHeaders: map[string]string{
+ echo.HeaderCookie: "_csrf=" + token,
+ echo.HeaderXCSRFToken: token,
+ },
+ whenMethod: http.MethodPost,
+ expectCookieContains: "_csrf",
+ },
+ {
+ name: "nok, POST without token",
+ whenMethod: http.MethodPost,
+ expectEmptyBody: true,
+ expectErr: `code=400, message=Bad Request, err=missing value in request header`,
+ },
+ {
+ name: "nok, POST empty token",
+ whenHeaders: map[string]string{echo.HeaderXCSRFToken: ""},
+ whenMethod: http.MethodPost,
+ expectEmptyBody: true,
+ expectErr: `code=403, message=invalid csrf token`,
+ },
+ {
+ name: "nok, invalid trusted origin in Config",
+ givenConfig: &CSRFConfig{
+ TrustedOrigins: []string{"http://example.com", "invalid"},
+ },
+ expectMWError: `trusted origin is missing scheme or host: invalid`,
+ },
+ {
+ name: "ok, TokenLength",
+ givenConfig: &CSRFConfig{
+ TokenLength: 16,
+ },
+ whenMethod: http.MethodGet,
+ expectCookieContains: "_csrf",
+ },
+ {
+ name: "ok, unsafe method + SecFetchSite=same-origin passes",
+ whenHeaders: map[string]string{
+ echo.HeaderSecFetchSite: "same-origin",
+ },
+ whenMethod: http.MethodPost,
+ },
+ {
+ name: "nok, unsafe method + SecFetchSite=same-cross blocked",
+ whenHeaders: map[string]string{
+ echo.HeaderSecFetchSite: "same-cross",
+ },
+ whenMethod: http.MethodPost,
+ expectEmptyBody: true,
+ expectErr: `code=403, message=cross-site request blocked by CSRF`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(cmp.Or(tc.whenMethod, http.MethodPost), "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ for key, value := range tc.whenHeaders {
+ req.Header.Set(key, value)
+ }
+
+ config := CSRFConfig{}
+ if tc.givenConfig != nil {
+ config = *tc.givenConfig
+ }
+ mw, err := config.ToMiddleware()
+ if tc.expectMWError != "" {
+ assert.EqualError(t, err, tc.expectMWError)
+ return
+ }
+ assert.NoError(t, err)
+
+ h := mw(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ err = h(c)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ expect := "test"
+ if tc.expectEmptyBody {
+ expect = ""
+ }
+ assert.Equal(t, expect, rec.Body.String())
+
+ if tc.expectCookieContains != "" {
+ assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), tc.expectCookieContains)
+ }
+ })
+ }
+}
+
+func TestCSRF(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ csrf := CSRF()
+ h := csrf(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ // Generate CSRF token
+ h(c)
+ assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf")
+
+}
+
+func TestCSRFSetSameSiteMode(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ csrf := CSRFWithConfig(CSRFConfig{
+ CookieSameSite: http.SameSiteStrictMode,
+ })
+
+ h := csrf(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ r := h(c)
+ assert.NoError(t, r)
+ assert.Regexp(t, "SameSite=Strict", rec.Header()["Set-Cookie"])
+}
+
+func TestCSRFWithoutSameSiteMode(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ csrf := CSRFWithConfig(CSRFConfig{})
+
+ h := csrf(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ r := h(c)
+ assert.NoError(t, r)
+ assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"])
+}
+
+func TestCSRFWithSameSiteDefaultMode(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ csrf := CSRFWithConfig(CSRFConfig{
+ CookieSameSite: http.SameSiteDefaultMode,
+ })
+
+ h := csrf(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ r := h(c)
+ assert.NoError(t, r)
+ assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"])
+}
+
+func TestCSRFWithSameSiteModeNone(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ csrf, err := CSRFConfig{
+ CookieSameSite: http.SameSiteNoneMode,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ h := csrf(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ r := h(c)
+ assert.NoError(t, r)
+ assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"])
+ assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"])
+}
+
+func TestCSRFConfig_skipper(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenSkip bool
+ expectCookies int
+ }{
+ {
+ name: "do skip",
+ whenSkip: true,
+ expectCookies: 0,
+ },
+ {
+ name: "do not skip",
+ whenSkip: false,
+ expectCookies: 1,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ csrf := CSRFWithConfig(CSRFConfig{
+ Skipper: func(c *echo.Context) bool {
+ return tc.whenSkip
+ },
+ })
+
+ h := csrf(func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ })
+
+ r := h(c)
+ assert.NoError(t, r)
+ cookie := rec.Header()["Set-Cookie"]
+ assert.Len(t, cookie, tc.expectCookies)
+ })
+ }
+}
+
+func TestCSRFErrorHandling(t *testing.T) {
+ cfg := CSRFConfig{
+ ErrorHandler: func(c *echo.Context, err error) error {
+ return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed")
+ },
+ }
+
+ e := echo.New()
+ e.POST("/", func(c *echo.Context) error {
+ return c.String(http.StatusNotImplemented, "should not end up here")
+ })
+
+ e.Use(CSRFWithConfig(cfg))
+
+ req := httptest.NewRequest(http.MethodPost, "/", nil)
+ res := httptest.NewRecorder()
+ e.ServeHTTP(res, req)
+
+ assert.Equal(t, http.StatusTeapot, res.Code)
+ assert.Equal(t, "{\"message\":\"error_handler_executed\"}\n", res.Body.String())
+}
+
+func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenConfig CSRFConfig
+ whenMethod string
+ whenSecFetchSite string
+ whenOrigin string
+ expectAllow bool
+ expectErr string
+ }{
+ {
+ name: "ok, unsafe POST, no SecFetchSite is not blocked",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "",
+ expectAllow: false, // should fall back to token CSRF
+ },
+ {
+ name: "ok, safe GET + same-origin passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodGet,
+ whenSecFetchSite: "same-origin",
+ expectAllow: true,
+ },
+ {
+ name: "ok, safe GET + none passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodGet,
+ whenSecFetchSite: "none",
+ expectAllow: true,
+ },
+ {
+ name: "ok, safe GET + same-site passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodGet,
+ whenSecFetchSite: "same-site",
+ expectAllow: true,
+ },
+ {
+ name: "ok, safe GET + cross-site passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodGet,
+ whenSecFetchSite: "cross-site",
+ expectAllow: true,
+ },
+ {
+ name: "nok, unsafe POST + cross-site is blocked",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ expectAllow: false,
+ expectErr: `code=403, message=cross-site request blocked by CSRF`,
+ },
+ {
+ name: "nok, unsafe POST + same-site is blocked",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "same-site",
+ expectAllow: false,
+ expectErr: `code=403, message=same-site request blocked by CSRF`,
+ },
+ {
+ name: "ok, unsafe POST + same-origin passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "same-origin",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe POST + none passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "none",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe PUT + same-origin passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPut,
+ whenSecFetchSite: "same-origin",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe PUT + none passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPut,
+ whenSecFetchSite: "none",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe DELETE + same-origin passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodDelete,
+ whenSecFetchSite: "same-origin",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe PATCH + same-origin passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPatch,
+ whenSecFetchSite: "same-origin",
+ expectAllow: true,
+ },
+ {
+ name: "nok, unsafe PUT + cross-site is blocked",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPut,
+ whenSecFetchSite: "cross-site",
+ expectAllow: false,
+ expectErr: `code=403, message=cross-site request blocked by CSRF`,
+ },
+ {
+ name: "nok, unsafe PUT + same-site is blocked",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPut,
+ whenSecFetchSite: "same-site",
+ expectAllow: false,
+ expectErr: `code=403, message=same-site request blocked by CSRF`,
+ },
+ {
+ name: "nok, unsafe DELETE + cross-site is blocked",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodDelete,
+ whenSecFetchSite: "cross-site",
+ expectAllow: false,
+ expectErr: `code=403, message=cross-site request blocked by CSRF`,
+ },
+ {
+ name: "nok, unsafe DELETE + same-site is blocked",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodDelete,
+ whenSecFetchSite: "same-site",
+ expectAllow: false,
+ expectErr: `code=403, message=same-site request blocked by CSRF`,
+ },
+ {
+ name: "nok, unsafe PATCH + cross-site is blocked",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPatch,
+ whenSecFetchSite: "cross-site",
+ expectAllow: false,
+ expectErr: `code=403, message=cross-site request blocked by CSRF`,
+ },
+ {
+ name: "ok, safe HEAD + same-origin passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodHead,
+ whenSecFetchSite: "same-origin",
+ expectAllow: true,
+ },
+ {
+ name: "ok, safe HEAD + cross-site passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodHead,
+ whenSecFetchSite: "cross-site",
+ expectAllow: true,
+ },
+ {
+ name: "ok, safe OPTIONS + cross-site passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodOptions,
+ whenSecFetchSite: "cross-site",
+ expectAllow: true,
+ },
+ {
+ name: "ok, safe TRACE + cross-site passes",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodTrace,
+ whenSecFetchSite: "cross-site",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe POST + cross-site + matching trusted origin passes",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://trusted.example.com"},
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ whenOrigin: "https://trusted.example.com",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe POST + same-site + matching trusted origin passes",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://trusted.example.com"},
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "same-site",
+ whenOrigin: "https://trusted.example.com",
+ expectAllow: true,
+ },
+ {
+ name: "nok, unsafe POST + cross-site + non-matching origin is blocked",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://trusted.example.com"},
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ whenOrigin: "https://evil.example.com",
+ expectAllow: false,
+ expectErr: `code=403, message=cross-site request blocked by CSRF`,
+ },
+ {
+ name: "ok, unsafe POST + cross-site + case-insensitive trusted origin match passes",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://trusted.example.com"},
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ whenOrigin: "https://TRUSTED.example.com",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe POST + same-origin + trusted origins configured but not matched passes",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://trusted.example.com"},
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "same-origin",
+ whenOrigin: "https://different.example.com",
+ expectAllow: true,
+ },
+ {
+ name: "nok, unsafe POST + cross-site + empty origin + trusted origins configured is blocked",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://trusted.example.com"},
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ whenOrigin: "",
+ expectAllow: false,
+ expectErr: `code=403, message=cross-site request blocked by CSRF`,
+ },
+ {
+ name: "ok, unsafe POST + cross-site + multiple trusted origins, second one matches",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://first.example.com", "https://second.example.com"},
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ whenOrigin: "https://second.example.com",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe POST + same-site + custom func allows",
+ givenConfig: CSRFConfig{
+ AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ return true, nil
+ },
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "same-site",
+ expectAllow: true,
+ },
+ {
+ name: "ok, unsafe POST + cross-site + custom func allows",
+ givenConfig: CSRFConfig{
+ AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ return true, nil
+ },
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ expectAllow: true,
+ },
+ {
+ name: "nok, unsafe POST + same-site + custom func returns custom error",
+ givenConfig: CSRFConfig{
+ AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ return false, echo.NewHTTPError(http.StatusTeapot, "custom error from func")
+ },
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "same-site",
+ expectAllow: false,
+ expectErr: `code=418, message=custom error from func`,
+ },
+ {
+ name: "nok, unsafe POST + cross-site + custom func returns false with nil error",
+ givenConfig: CSRFConfig{
+ AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ return false, nil
+ },
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ expectAllow: false,
+ expectErr: "", // custom func returns nil error, so no error expected
+ },
+ {
+ name: "nok, unsafe POST + invalid Sec-Fetch-Site value treated as cross-site",
+ givenConfig: CSRFConfig{},
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "invalid-value",
+ expectAllow: false,
+ expectErr: `code=403, message=cross-site request blocked by CSRF`,
+ },
+ {
+ name: "ok, unsafe POST + cross-site + trusted origin takes precedence over custom func",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://trusted.example.com"},
+ AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ return false, echo.NewHTTPError(http.StatusTeapot, "should not be called")
+ },
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ whenOrigin: "https://trusted.example.com",
+ expectAllow: true,
+ },
+ {
+ name: "nok, unsafe POST + cross-site + trusted origin not matched, custom func blocks",
+ givenConfig: CSRFConfig{
+ TrustedOrigins: []string{"https://trusted.example.com"},
+ AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) {
+ return false, echo.NewHTTPError(http.StatusTeapot, "custom block")
+ },
+ },
+ whenMethod: http.MethodPost,
+ whenSecFetchSite: "cross-site",
+ whenOrigin: "https://evil.example.com",
+ expectAllow: false,
+ expectErr: `code=418, message=custom block`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(tc.whenMethod, "/", nil)
+ if tc.whenSecFetchSite != "" {
+ req.Header.Set(echo.HeaderSecFetchSite, tc.whenSecFetchSite)
+ }
+ if tc.whenOrigin != "" {
+ req.Header.Set(echo.HeaderOrigin, tc.whenOrigin)
+ }
+
+ res := httptest.NewRecorder()
+ c := echo.NewContext(req, res)
+
+ allow, err := tc.givenConfig.checkSecFetchSiteRequest(c)
+
+ assert.Equal(t, tc.expectAllow, allow)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
diff --git a/middleware/decompress.go b/middleware/decompress.go
new file mode 100644
index 000000000..a384af2ea
--- /dev/null
+++ b/middleware/decompress.go
@@ -0,0 +1,155 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "compress/gzip"
+ "io"
+ "net/http"
+ "sync"
+
+ "github.com/labstack/echo/v5"
+)
+
+// DecompressConfig defines the config for Decompress middleware.
+type DecompressConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
+ GzipDecompressPool Decompressor
+
+ // MaxDecompressedSize limits the maximum size of decompressed request body in bytes.
+ // If the decompressed body exceeds this limit, the middleware returns HTTP 413 error.
+ // This prevents zip bomb attacks where small compressed payloads decompress to huge sizes.
+ // Default: 100 * MB (104,857,600 bytes)
+ // Set to -1 to disable limits (not recommended in production).
+ MaxDecompressedSize int64
+}
+
+// GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
+const GZIPEncoding string = "gzip"
+
+// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers
+type Decompressor interface {
+ gzipDecompressPool() sync.Pool
+}
+
+// DefaultGzipDecompressPool is the default implementation of Decompressor interface
+type DefaultGzipDecompressPool struct {
+}
+
+func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
+ return sync.Pool{New: func() any { return new(gzip.Reader) }}
+}
+
+// Decompress decompresses request body based if content encoding type is set to "gzip" with default config
+//
+// SECURITY: By default, this limits decompressed data to 100MB to prevent zip bomb attacks.
+// To customize the limit, use DecompressWithConfig. To disable limits (not recommended in production),
+// set MaxDecompressedSize to -1.
+func Decompress() echo.MiddlewareFunc {
+ return DecompressWithConfig(DecompressConfig{})
+}
+
+// DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration.
+//
+// SECURITY: If MaxDecompressedSize is not set (zero value), it defaults to 100MB to prevent
+// DoS attacks via zip bombs. Set to -1 to explicitly disable limits if needed for your use case.
+func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration
+func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Skipper == nil {
+ config.Skipper = DefaultSkipper
+ }
+ if config.GzipDecompressPool == nil {
+ config.GzipDecompressPool = &DefaultGzipDecompressPool{}
+ }
+ // Apply secure default for decompression limit
+ if config.MaxDecompressedSize == 0 {
+ config.MaxDecompressedSize = 100 * MB
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ pool := config.GzipDecompressPool.gzipDecompressPool()
+
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ if c.Request().Header.Get(echo.HeaderContentEncoding) != GZIPEncoding {
+ return next(c)
+ }
+
+ i := pool.Get()
+ gr, ok := i.(*gzip.Reader)
+ if !ok || gr == nil {
+ if err, isErr := i.(error); isErr {
+ return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
+ }
+ return echo.NewHTTPError(http.StatusInternalServerError, "unexpected type from gzip decompression pool")
+ }
+ defer pool.Put(gr)
+
+ b := c.Request().Body
+ defer b.Close()
+
+ if err := gr.Reset(b); err != nil {
+ if err == io.EOF { //ignore if body is empty
+ return next(c)
+ }
+ return err
+ }
+
+ // only Close gzip reader if it was set to a proper gzip source otherwise it will panic on close.
+ defer gr.Close()
+
+ // Apply decompression size limit to prevent zip bombs
+ if config.MaxDecompressedSize > 0 {
+ c.Request().Body = &limitedGzipReader{
+ Reader: gr,
+ remaining: config.MaxDecompressedSize,
+ limit: config.MaxDecompressedSize,
+ }
+ } else {
+ // -1 means explicitly unlimited (not recommended)
+ c.Request().Body = gr
+ }
+
+ return next(c)
+ }
+ }, nil
+}
+
+// limitedGzipReader wraps a gzip reader with size limiting to prevent zip bombs
+type limitedGzipReader struct {
+ *gzip.Reader
+ remaining int64
+ limit int64
+}
+
+func (r *limitedGzipReader) Read(p []byte) (n int, err error) {
+ if r.remaining <= 0 {
+ // Limit exceeded - return 413 error
+ return 0, echo.ErrStatusRequestEntityTooLarge
+ }
+
+ // Limit the read to remaining bytes
+ if int64(len(p)) > r.remaining {
+ p = p[:r.remaining]
+ }
+
+ n, err = r.Reader.Read(p)
+ r.remaining -= int64(n)
+
+ return n, err
+}
+
+func (r *limitedGzipReader) Close() error {
+ return r.Reader.Close()
+}
diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go
new file mode 100644
index 000000000..1823e94bb
--- /dev/null
+++ b/middleware/decompress_test.go
@@ -0,0 +1,508 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "bytes"
+ "compress/gzip"
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestDecompress(t *testing.T) {
+ e := echo.New()
+
+ h := Decompress()(func(c *echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })
+
+ // Decompress request body
+ body := `{"name": "echo"}`
+ gz, _ := gzipString(body)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := h(c)
+ assert.NoError(t, err)
+
+ assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
+ b, err := io.ReadAll(req.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, body, string(b))
+}
+
+func TestDecompress_skippedIfNoHeader(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ // Skip if no Content-Encoding header
+ h := Decompress()(func(c *echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })
+
+ err := h(c)
+ assert.NoError(t, err)
+ assert.Equal(t, "test", rec.Body.String())
+
+}
+
+func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{}.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })(c)
+ assert.NoError(t, err)
+
+ assert.Equal(t, "test", rec.Body.String())
+
+}
+
+func TestDecompressWithConfig_DefaultConfig(t *testing.T) {
+ e := echo.New()
+
+ h := Decompress()(func(c *echo.Context) error {
+ c.Response().Write([]byte("test")) // For Content-Type sniffing
+ return nil
+ })
+
+ // Decompress
+ body := `{"name": "echo"}`
+ gz, _ := gzipString(body)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := h(c)
+ assert.NoError(t, err)
+
+ assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
+ b, err := io.ReadAll(req.Body)
+ assert.NoError(t, err)
+ assert.Equal(t, body, string(b))
+}
+
+func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
+ e := echo.New()
+ body := `{"name":"echo"}`
+ gz, _ := gzipString(body)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ e.NewContext(req, rec)
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
+ b, err := io.ReadAll(req.Body)
+ assert.NoError(t, err)
+ assert.NotEqual(t, b, body)
+ assert.Equal(t, b, gz)
+}
+
+func TestDecompressNoContent(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := Decompress()(func(c *echo.Context) error {
+ return c.NoContent(http.StatusNoContent)
+ })
+
+ err := h(c)
+
+ if assert.NoError(t, err) {
+ assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
+ assert.Empty(t, rec.Header().Get(echo.HeaderContentType))
+ assert.Equal(t, 0, len(rec.Body.Bytes()))
+ }
+}
+
+func TestDecompressErrorReturned(t *testing.T) {
+ e := echo.New()
+ e.Use(Decompress())
+ e.GET("/", func(c *echo.Context) error {
+ return echo.ErrNotFound
+ })
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusNotFound, rec.Code)
+ assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
+}
+
+func TestDecompressSkipper(t *testing.T) {
+ e := echo.New()
+ e.Use(DecompressWithConfig(DecompressConfig{
+ Skipper: func(c *echo.Context) bool {
+ return c.Request().URL.Path == "/skip"
+ },
+ }))
+ body := `{"name": "echo"}`
+ req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSON)
+ reqBody, err := io.ReadAll(c.Request().Body)
+ assert.NoError(t, err)
+ assert.Equal(t, body, string(reqBody))
+}
+
+type TestDecompressPoolWithError struct {
+}
+
+func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool {
+ return sync.Pool{
+ New: func() any {
+ return errors.New("pool error")
+ },
+ }
+}
+
+func TestDecompressPoolError(t *testing.T) {
+ e := echo.New()
+ e.Use(DecompressWithConfig(DecompressConfig{
+ Skipper: DefaultSkipper,
+ GzipDecompressPool: &TestDecompressPoolWithError{},
+ }))
+ body := `{"name": "echo"}`
+ req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
+ reqBody, err := io.ReadAll(c.Request().Body)
+ assert.NoError(t, err)
+ assert.Equal(t, body, string(reqBody))
+ assert.Equal(t, rec.Code, http.StatusInternalServerError)
+}
+
+func BenchmarkDecompress(b *testing.B) {
+ e := echo.New()
+ body := `{"name": "echo"}`
+ gz, _ := gzipString(body)
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+
+ h := Decompress()(func(c *echo.Context) error {
+ c.Response().Write([]byte(body)) // For Content-Type sniffing
+ return nil
+ })
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ // Decompress
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h(c)
+ }
+}
+
+func gzipString(body string) ([]byte, error) {
+ var buf bytes.Buffer
+ gz := gzip.NewWriter(&buf)
+
+ _, err := gz.Write([]byte(body))
+ if err != nil {
+ return nil, err
+ }
+
+ if err := gz.Close(); err != nil {
+ return nil, err
+ }
+
+ return buf.Bytes(), nil
+}
+
+func TestDecompress_WithinLimit(t *testing.T) {
+ e := echo.New()
+ body := strings.Repeat("test data ", 100) // Small payload ~1KB
+ gz, _ := gzipString(body)
+
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{MaxDecompressedSize: 100 * MB}.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ b, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(b))
+ })(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, body, rec.Body.String())
+}
+
+func TestDecompress_ExceedsLimit(t *testing.T) {
+ e := echo.New()
+ // Create 2KB of data but limit to 1KB
+ largeBody := strings.Repeat("A", 2*1024)
+ gz, _ := gzipString(largeBody)
+
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ _, readErr := io.ReadAll(c.Request().Body)
+ return readErr
+ })(c)
+
+ // Should return 413 error
+ assert.Error(t, err)
+ he, ok := err.(echo.HTTPStatusCoder)
+ assert.True(t, ok)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
+}
+
+func TestDecompress_AtExactLimit(t *testing.T) {
+ e := echo.New()
+ exactBody := strings.Repeat("B", 1024) // Exactly 1KB
+ gz, _ := gzipString(exactBody)
+
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ b, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(b))
+ })(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, exactBody, rec.Body.String())
+}
+
+func TestDecompress_ZipBomb(t *testing.T) {
+ e := echo.New()
+ // Create highly compressed data that expands to 2MB
+ // but limit is 1MB
+ largeBody := bytes.Repeat([]byte("A"), 2*1024*1024) // 2MB
+ var buf bytes.Buffer
+ gzWriter := gzip.NewWriter(&buf)
+ gzWriter.Write(largeBody)
+ gzWriter.Close()
+
+ req := httptest.NewRequest(http.MethodPost, "/", &buf)
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{MaxDecompressedSize: 1 * MB}.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ _, readErr := io.ReadAll(c.Request().Body)
+ return readErr
+ })(c)
+
+ // Should return 413 error
+ assert.Error(t, err)
+ he, ok := err.(echo.HTTPStatusCoder)
+ assert.True(t, ok)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
+}
+
+func TestDecompress_UnlimitedExplicit(t *testing.T) {
+ e := echo.New()
+ largeBody := strings.Repeat("X", 10*1024) // 10KB
+ gz, _ := gzipString(largeBody)
+
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{MaxDecompressedSize: -1}.ToMiddleware() // Unlimited
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ b, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(b))
+ })(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, largeBody, rec.Body.String())
+}
+
+func TestDecompress_DefaultLimit(t *testing.T) {
+ e := echo.New()
+ smallBody := "test"
+ gz, _ := gzipString(smallBody)
+
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ // Use zero value which should apply 100MB default
+ h, err := DecompressConfig{}.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ b, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(b))
+ })(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, smallBody, rec.Body.String())
+}
+
+func TestDecompress_SmallCustomLimit(t *testing.T) {
+ e := echo.New()
+ body := strings.Repeat("D", 512) // 512 bytes
+ gz, _ := gzipString(body)
+
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ b, _ := io.ReadAll(c.Request().Body)
+ return c.String(http.StatusOK, string(b))
+ })(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, body, rec.Body.String())
+}
+
+func TestDecompress_MultipleReads(t *testing.T) {
+ e := echo.New()
+ // Test that limit is enforced across multiple Read() calls
+ largeBody := strings.Repeat("M", 2*1024) // 2KB
+ gz, _ := gzipString(largeBody)
+
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ // Read in small chunks
+ buf := make([]byte, 256)
+ total := 0
+ for {
+ n, readErr := c.Request().Body.Read(buf)
+ total += n
+ if readErr != nil {
+ if readErr == io.EOF {
+ return nil
+ }
+ return readErr
+ }
+ }
+ })(c)
+
+ // Should return 413 error from cumulative reads
+ assert.Error(t, err)
+ he, ok := err.(echo.HTTPStatusCoder)
+ assert.True(t, ok)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
+}
+
+func TestDecompress_LargePayloadDosPrevention(t *testing.T) {
+ e := echo.New()
+ // Simulate a DoS attack with highly compressed large payload
+ largeSize := 10 * 1024 * 1024 // 10MB decompressed
+ largeBody := bytes.Repeat([]byte("Z"), largeSize)
+ var buf bytes.Buffer
+ gzWriter := gzip.NewWriter(&buf)
+ gzWriter.Write(largeBody)
+ gzWriter.Close()
+
+ req := httptest.NewRequest(http.MethodPost, "/", &buf)
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h, err := DecompressConfig{MaxDecompressedSize: 1 * MB}.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = h(func(c *echo.Context) error {
+ _, readErr := io.ReadAll(c.Request().Body)
+ return readErr
+ })(c)
+
+ // Should prevent DoS by returning 413
+ assert.Error(t, err)
+ he, ok := err.(echo.HTTPStatusCoder)
+ assert.True(t, ok)
+ assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode())
+}
+
+func BenchmarkDecompress_WithLimit(b *testing.B) {
+ e := echo.New()
+ body := strings.Repeat("benchmark data ", 1000) // ~15KB
+ gz, _ := gzipString(body)
+
+ h, _ := DecompressConfig{MaxDecompressedSize: 100 * MB}.ToMiddleware()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz))
+ req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ h(func(c *echo.Context) error {
+ io.ReadAll(c.Request().Body)
+ return nil
+ })(c)
+ }
+}
diff --git a/middleware/extractor.go b/middleware/extractor.go
new file mode 100644
index 000000000..abb603186
--- /dev/null
+++ b/middleware/extractor.go
@@ -0,0 +1,253 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "fmt"
+ "net/textproto"
+ "strings"
+
+ "github.com/labstack/echo/v5"
+)
+
+const (
+ // extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource exhaustion
+ // attack vector
+ extractorLimit = 20
+)
+
+// ExtractorSource is type to indicate source for extracted value
+type ExtractorSource string
+
+const (
+ // ExtractorSourceHeader means value was extracted from request header
+ ExtractorSourceHeader ExtractorSource = "header"
+ // ExtractorSourceQuery means value was extracted from request query parameters
+ ExtractorSourceQuery ExtractorSource = "query"
+ // ExtractorSourcePathParam means value was extracted from route path parameters
+ ExtractorSourcePathParam ExtractorSource = "param"
+ // ExtractorSourceCookie means value was extracted from request cookies
+ ExtractorSourceCookie ExtractorSource = "cookie"
+ // ExtractorSourceForm means value was extracted from request form values
+ ExtractorSourceForm ExtractorSource = "form"
+)
+
+// ValueExtractorError is error type when middleware extractor is unable to extract value from lookups
+type ValueExtractorError struct {
+ message string
+}
+
+// Error returns errors text
+func (e *ValueExtractorError) Error() string {
+ return e.message
+}
+
+var errHeaderExtractorValueMissing = &ValueExtractorError{message: "missing value in request header"}
+var errHeaderExtractorValueInvalid = &ValueExtractorError{message: "invalid value in request header"}
+var errQueryExtractorValueMissing = &ValueExtractorError{message: "missing value in the query string"}
+var errParamExtractorValueMissing = &ValueExtractorError{message: "missing value in path params"}
+var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing value in cookies"}
+var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"}
+
+// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context.
+type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error)
+
+// CreateExtractors creates ValuesExtractors from given lookups.
+// lookups is a string in the form of ":" or ":,:" that is used
+// to extract key from the request.
+// Possible values:
+// - "header:" or "header::"
+// `` is argument value to cut/trim prefix of the extracted value. This is useful if header
+// value has static prefix like `Authorization: ` where part that we
+// want to cut is ` ` note the space at the end.
+// In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `.
+// - "query:"
+// - "param:"
+// - "form:"
+// - "cookie:"
+//
+// Multiple sources example:
+// - "header:Authorization,header:X-Api-Key"
+//
+// limit sets the maximum amount how many lookups can be returned.
+func CreateExtractors(lookups string, limit uint) ([]ValuesExtractor, error) {
+ return createExtractors(lookups, limit)
+}
+
+func createExtractors(lookups string, limit uint) ([]ValuesExtractor, error) {
+ if lookups == "" {
+ return nil, nil
+ }
+ if limit == 0 {
+ limit = 1
+ } else if limit > extractorLimit {
+ limit = extractorLimit
+ }
+
+ sources := strings.Split(lookups, ",")
+ var extractors = make([]ValuesExtractor, 0)
+ for _, source := range sources {
+ parts := strings.Split(source, ":")
+ if len(parts) < 2 {
+ return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source)
+ }
+
+ switch parts[0] {
+ case "query":
+ extractors = append(extractors, valuesFromQuery(parts[1], limit))
+ case "param":
+ extractors = append(extractors, valuesFromParam(parts[1], limit))
+ case "cookie":
+ extractors = append(extractors, valuesFromCookie(parts[1], limit))
+ case "form":
+ extractors = append(extractors, valuesFromForm(parts[1], limit))
+ case "header":
+ prefix := ""
+ if len(parts) > 2 {
+ prefix = parts[2]
+ }
+ extractors = append(extractors, valuesFromHeader(parts[1], prefix, limit))
+ }
+ }
+ return extractors, nil
+}
+
+// valuesFromHeader returns a functions that extracts values from the request header.
+// valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static
+// prefix like `Authorization: ` where part that we want to remove is ` `
+// note the space at the end. In case of basic authentication `Authorization: Basic ` prefix we want to remove
+// is `Basic `. In case of JWT tokens `Authorization: Bearer ` prefix is `Bearer `.
+// If prefix is left empty the whole value is returned.
+func valuesFromHeader(header string, valuePrefix string, limit uint) ValuesExtractor {
+ prefixLen := len(valuePrefix)
+ // standard library parses http.Request header keys in canonical form but we may provide something else so fix this
+ header = textproto.CanonicalMIMEHeaderKey(header)
+ if limit == 0 {
+ limit = 1
+ }
+ return func(c *echo.Context) ([]string, ExtractorSource, error) {
+ values := c.Request().Header.Values(header)
+ if len(values) == 0 {
+ return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing
+ }
+
+ i := uint(0)
+ result := make([]string, 0)
+ for _, value := range values {
+ if prefixLen == 0 {
+ result = append(result, value)
+ i++
+ if i >= limit {
+ break
+ }
+ } else if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) {
+ result = append(result, value[prefixLen:])
+ i++
+ if i >= limit {
+ break
+ }
+ }
+ }
+
+ if len(result) == 0 {
+ if prefixLen > 0 {
+ return nil, ExtractorSourceHeader, errHeaderExtractorValueInvalid
+ }
+ return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing
+ }
+ return result, ExtractorSourceHeader, nil
+ }
+}
+
+// valuesFromQuery returns a function that extracts values from the query string.
+func valuesFromQuery(param string, limit uint) ValuesExtractor {
+ if limit == 0 {
+ limit = 1
+ }
+ return func(c *echo.Context) ([]string, ExtractorSource, error) {
+ result := c.QueryParams()[param]
+ if len(result) == 0 {
+ return nil, ExtractorSourceQuery, errQueryExtractorValueMissing
+ } else if len(result) > int(limit)-1 {
+ result = result[:limit]
+ }
+ return result, ExtractorSourceQuery, nil
+ }
+}
+
+// valuesFromParam returns a function that extracts values from the url param string.
+func valuesFromParam(param string, limit uint) ValuesExtractor {
+ if limit == 0 {
+ limit = 1
+ }
+ return func(c *echo.Context) ([]string, ExtractorSource, error) {
+ result := make([]string, 0)
+ i := uint(0)
+ for _, p := range c.PathValues() {
+ if param != p.Name {
+ continue
+ }
+ result = append(result, p.Value)
+ i++
+ if i >= limit {
+ break
+ }
+ }
+ if len(result) == 0 {
+ return nil, ExtractorSourcePathParam, errParamExtractorValueMissing
+ }
+ return result, ExtractorSourcePathParam, nil
+ }
+}
+
+// valuesFromCookie returns a function that extracts values from the named cookie.
+func valuesFromCookie(name string, limit uint) ValuesExtractor {
+ if limit == 0 {
+ limit = 1
+ }
+ return func(c *echo.Context) ([]string, ExtractorSource, error) {
+ cookies := c.Cookies()
+ if len(cookies) == 0 {
+ return nil, ExtractorSourceCookie, errCookieExtractorValueMissing
+ }
+
+ i := uint(0)
+ result := make([]string, 0)
+ for _, cookie := range cookies {
+ if name != cookie.Name {
+ continue
+ }
+ result = append(result, cookie.Value)
+ i++
+ if i >= limit {
+ break
+ }
+ }
+ if len(result) == 0 {
+ return nil, ExtractorSourceCookie, errCookieExtractorValueMissing
+ }
+ return result, ExtractorSourceCookie, nil
+ }
+}
+
+// valuesFromForm returns a function that extracts values from the form field.
+func valuesFromForm(name string, limit uint) ValuesExtractor {
+ if limit == 0 {
+ limit = 1
+ }
+ return func(c *echo.Context) ([]string, ExtractorSource, error) {
+ if c.Request().Form == nil {
+ _, _ = c.MultipartForm() // we want to trigger c.request.ParseMultipartForm(c.formParseMaxMemory)
+ }
+ values := c.Request().Form[name]
+ if len(values) == 0 {
+ return nil, ExtractorSourceForm, errFormExtractorValueMissing
+ }
+ if len(values) > int(limit)-1 {
+ values = values[:limit]
+ }
+ result := append([]string{}, values...)
+ return result, ExtractorSourceForm, nil
+ }
+}
diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go
new file mode 100644
index 000000000..04cc7b829
--- /dev/null
+++ b/middleware/extractor_test.go
@@ -0,0 +1,625 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "bytes"
+ "fmt"
+ "mime/multipart"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestCreateExtractors(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenRequest func() *http.Request
+ givenPathValues echo.PathValues
+ whenLookups string
+ whenLimit uint
+ expectValues []string
+ expectSource ExtractorSource
+ expectCreateError string
+ expectError string
+ }{
+ {
+ name: "ok, header",
+ givenRequest: func() *http.Request {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAuthorization, "Bearer token")
+ return req
+ },
+ whenLookups: "header:Authorization:Bearer ",
+ expectValues: []string{"token"},
+ expectSource: ExtractorSourceHeader,
+ },
+ {
+ name: "ok, form",
+ givenRequest: func() *http.Request {
+ f := make(url.Values)
+ f.Set("name", "Jon Snow")
+
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
+ req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
+ return req
+ },
+ whenLookups: "form:name",
+ expectValues: []string{"Jon Snow"},
+ expectSource: ExtractorSourceForm,
+ },
+ {
+ name: "ok, cookie",
+ givenRequest: func() *http.Request {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderCookie, "_csrf=token")
+ return req
+ },
+ whenLookups: "cookie:_csrf",
+ expectValues: []string{"token"},
+ expectSource: ExtractorSourceCookie,
+ },
+ {
+ name: "ok, param",
+ givenPathValues: echo.PathValues{
+ {Name: "id", Value: "123"},
+ },
+ whenLookups: "param:id",
+ expectValues: []string{"123"},
+ expectSource: ExtractorSourcePathParam,
+ },
+ {
+ name: "ok, query",
+ givenRequest: func() *http.Request {
+ req := httptest.NewRequest(http.MethodGet, "/?id=999", nil)
+ return req
+ },
+ whenLookups: "query:id",
+ expectValues: []string{"999"},
+ expectSource: ExtractorSourceQuery,
+ },
+ {
+ name: "nok, invalid lookup",
+ whenLookups: "query",
+ expectCreateError: "extractor source for lookup could not be split into needed parts: query",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ if tc.givenRequest != nil {
+ req = tc.givenRequest()
+ }
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ if tc.givenPathValues != nil {
+ c.SetPathValues(tc.givenPathValues)
+ }
+
+ extractors, err := CreateExtractors(tc.whenLookups, tc.whenLimit)
+ if tc.expectCreateError != "" {
+ assert.EqualError(t, err, tc.expectCreateError)
+ return
+ }
+ assert.NoError(t, err)
+
+ for _, e := range extractors {
+ values, source, eErr := e(c)
+ assert.Equal(t, tc.expectValues, values)
+ assert.Equal(t, tc.expectSource, source)
+ if tc.expectError != "" {
+ assert.EqualError(t, eErr, tc.expectError)
+ return
+ }
+ assert.NoError(t, eErr)
+ }
+ })
+ }
+}
+
+func TestValuesFromHeader(t *testing.T) {
+ exampleRequest := func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
+ }
+
+ var testCases = []struct {
+ name string
+ givenRequest func(req *http.Request)
+ whenName string
+ whenValuePrefix string
+ whenLimit uint
+ expectValues []string
+ expectError string
+ }{
+ {
+ name: "ok, single value",
+ givenRequest: exampleRequest,
+ whenName: echo.HeaderAuthorization,
+ whenValuePrefix: "basic ",
+ expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
+ },
+ {
+ name: "ok, single value, case insensitive",
+ givenRequest: exampleRequest,
+ whenName: echo.HeaderAuthorization,
+ whenValuePrefix: "Basic ",
+ expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
+ },
+ {
+ name: "ok, multiple value",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
+ req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
+ },
+ whenName: echo.HeaderAuthorization,
+ whenValuePrefix: "basic ",
+ whenLimit: 2,
+ expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"},
+ },
+ {
+ name: "ok, empty prefix",
+ givenRequest: exampleRequest,
+ whenName: echo.HeaderAuthorization,
+ whenValuePrefix: "",
+ expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="},
+ },
+ {
+ name: "nok, no matching due different prefix",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
+ req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
+ },
+ whenName: echo.HeaderAuthorization,
+ whenValuePrefix: "Bearer ",
+ expectError: errHeaderExtractorValueInvalid.Error(),
+ },
+ {
+ name: "nok, no matching due different prefix",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
+ req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
+ },
+ whenName: echo.HeaderWWWAuthenticate,
+ whenValuePrefix: "",
+ expectError: errHeaderExtractorValueMissing.Error(),
+ },
+ {
+ name: "nok, no headers",
+ givenRequest: nil,
+ whenName: echo.HeaderAuthorization,
+ whenValuePrefix: "basic ",
+ expectError: errHeaderExtractorValueMissing.Error(),
+ },
+ {
+ name: "ok, prefix, cut values over extractorLimit",
+ givenRequest: func(req *http.Request) {
+ for i := 1; i <= 25; i++ {
+ req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("basic %v", i))
+ }
+ },
+ whenName: echo.HeaderAuthorization,
+ whenValuePrefix: "basic ",
+ whenLimit: extractorLimit,
+ expectValues: []string{
+ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
+ "11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
+ },
+ },
+ {
+ name: "ok, cut values over extractorLimit",
+ givenRequest: func(req *http.Request) {
+ for i := 1; i <= 25; i++ {
+ req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("%v", i))
+ }
+ },
+ whenName: echo.HeaderAuthorization,
+ whenValuePrefix: "",
+ whenLimit: extractorLimit,
+ expectValues: []string{
+ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
+ "11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ if tc.givenRequest != nil {
+ tc.givenRequest(req)
+ }
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix, tc.whenLimit)
+
+ values, source, err := extractor(c)
+ assert.Equal(t, tc.expectValues, values)
+ assert.Equal(t, ExtractorSourceHeader, source)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValuesFromQuery(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenQueryPart string
+ whenName string
+ whenLimit uint
+ expectValues []string
+ expectError string
+ }{
+ {
+ name: "ok, single value",
+ givenQueryPart: "?id=123&name=test",
+ whenName: "id",
+ expectValues: []string{"123"},
+ },
+ {
+ name: "ok, multiple value",
+ givenQueryPart: "?id=123&id=456&name=test",
+ whenName: "id",
+ whenLimit: 2,
+ expectValues: []string{"123", "456"},
+ },
+ {
+ name: "nok, missing value",
+ givenQueryPart: "?id=123&name=test",
+ whenName: "nope",
+ expectError: errQueryExtractorValueMissing.Error(),
+ },
+ {
+ name: "ok, cut values over extractorLimit",
+ givenQueryPart: "?name=test" +
+ "&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" +
+ "&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" +
+ "&id=21&id=22&id=23&id=24&id=25",
+ whenName: "id",
+ whenLimit: extractorLimit,
+ expectValues: []string{
+ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
+ "11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ extractor := valuesFromQuery(tc.whenName, tc.whenLimit)
+
+ values, source, err := extractor(c)
+ assert.Equal(t, tc.expectValues, values)
+ assert.Equal(t, ExtractorSourceQuery, source)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValuesFromParam(t *testing.T) {
+ examplePathValues := echo.PathValues{
+ {Name: "id", Value: "123"},
+ {Name: "gid", Value: "456"},
+ {Name: "gid", Value: "789"},
+ }
+ examplePathValues20 := make(echo.PathValues, 0)
+ for i := 1; i < 25; i++ {
+ examplePathValues20 = append(examplePathValues20, echo.PathValue{Name: "id", Value: fmt.Sprintf("%v", i)})
+ }
+
+ var testCases = []struct {
+ name string
+ givenPathValues echo.PathValues
+ whenName string
+ whenLimit uint
+ expectValues []string
+ expectError string
+ }{
+ {
+ name: "ok, single value",
+ givenPathValues: examplePathValues,
+ whenName: "id",
+ expectValues: []string{"123"},
+ },
+ {
+ name: "ok, multiple value",
+ givenPathValues: examplePathValues,
+ whenName: "gid",
+ whenLimit: 2,
+ expectValues: []string{"456", "789"},
+ },
+ {
+ name: "nok, no values",
+ givenPathValues: nil,
+ whenName: "nope",
+ expectValues: nil,
+ expectError: errParamExtractorValueMissing.Error(),
+ },
+ {
+ name: "nok, no matching value",
+ givenPathValues: examplePathValues,
+ whenName: "nope",
+ expectValues: nil,
+ expectError: errParamExtractorValueMissing.Error(),
+ },
+ {
+ name: "ok, cut values over extractorLimit",
+ givenPathValues: examplePathValues20,
+ whenName: "id",
+ whenLimit: extractorLimit,
+ expectValues: []string{
+ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
+ "11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ if tc.givenPathValues != nil {
+ c.SetPathValues(tc.givenPathValues)
+ }
+
+ extractor := valuesFromParam(tc.whenName, tc.whenLimit)
+
+ values, source, err := extractor(c)
+ assert.Equal(t, tc.expectValues, values)
+ assert.Equal(t, ExtractorSourcePathParam, source)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValuesFromCookie(t *testing.T) {
+ exampleRequest := func(req *http.Request) {
+ req.Header.Set(echo.HeaderCookie, "_csrf=token")
+ }
+
+ var testCases = []struct {
+ name string
+ givenRequest func(req *http.Request)
+ whenName string
+ whenLimit uint
+ expectValues []string
+ expectError string
+ }{
+ {
+ name: "ok, single value",
+ givenRequest: exampleRequest,
+ whenName: "_csrf",
+ expectValues: []string{"token"},
+ },
+ {
+ name: "ok, multiple value",
+ givenRequest: func(req *http.Request) {
+ req.Header.Add(echo.HeaderCookie, "_csrf=token")
+ req.Header.Add(echo.HeaderCookie, "_csrf=token2")
+ },
+ whenName: "_csrf",
+ whenLimit: 2,
+ expectValues: []string{"token", "token2"},
+ },
+ {
+ name: "nok, no matching cookie",
+ givenRequest: exampleRequest,
+ whenName: "xxx",
+ expectValues: nil,
+ expectError: errCookieExtractorValueMissing.Error(),
+ },
+ {
+ name: "nok, no cookies at all",
+ givenRequest: nil,
+ whenName: "xxx",
+ expectValues: nil,
+ expectError: errCookieExtractorValueMissing.Error(),
+ },
+ {
+ name: "ok, cut values over extractorLimit",
+ givenRequest: func(req *http.Request) {
+ for i := 1; i < 25; i++ {
+ req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", i))
+ }
+ },
+ whenName: "_csrf",
+ whenLimit: extractorLimit,
+ expectValues: []string{
+ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
+ "11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ if tc.givenRequest != nil {
+ tc.givenRequest(req)
+ }
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ extractor := valuesFromCookie(tc.whenName, tc.whenLimit)
+
+ values, source, err := extractor(c)
+ assert.Equal(t, tc.expectValues, values)
+ assert.Equal(t, ExtractorSourceCookie, source)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestValuesFromForm(t *testing.T) {
+ examplePostFormRequest := func(mod func(v *url.Values)) *http.Request {
+ f := make(url.Values)
+ f.Set("name", "Jon Snow")
+ f.Set("emails[]", "jon@labstack.com")
+ if mod != nil {
+ mod(&f)
+ }
+
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
+ req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
+
+ return req
+ }
+ exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request {
+ f := make(url.Values)
+ f.Set("name", "Jon Snow")
+ f.Set("emails[]", "jon@labstack.com")
+ if mod != nil {
+ mod(&f)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil)
+ return req
+ }
+
+ exampleMultiPartFormRequest := func(mod func(w *multipart.Writer)) *http.Request {
+ var b bytes.Buffer
+ w := multipart.NewWriter(&b)
+ w.WriteField("name", "Jon Snow")
+ w.WriteField("emails[]", "jon@labstack.com")
+ if mod != nil {
+ mod(w)
+ }
+
+ fw, _ := w.CreateFormFile("upload", "my.file")
+ fw.Write([]byte(`hi
`))
+ w.Close()
+
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(b.String()))
+ req.Header.Add(echo.HeaderContentType, w.FormDataContentType())
+
+ return req
+ }
+
+ var testCases = []struct {
+ name string
+ givenRequest *http.Request
+ whenName string
+ whenLimit uint
+ expectValues []string
+ expectError string
+ }{
+ {
+ name: "ok, POST form, single value",
+ givenRequest: examplePostFormRequest(nil),
+ whenName: "emails[]",
+ expectValues: []string{"jon@labstack.com"},
+ },
+ {
+ name: "ok, POST form, multiple value",
+ givenRequest: examplePostFormRequest(func(v *url.Values) {
+ v.Add("emails[]", "snow@labstack.com")
+ }),
+ whenName: "emails[]",
+ whenLimit: 2,
+ expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
+ },
+ {
+ name: "ok, POST multipart/form, multiple value",
+ givenRequest: exampleMultiPartFormRequest(func(w *multipart.Writer) {
+ w.WriteField("emails[]", "snow@labstack.com")
+ }),
+ whenName: "emails[]",
+ whenLimit: 2,
+ expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
+ },
+ {
+ name: "ok, GET form, single value",
+ givenRequest: exampleGetFormRequest(nil),
+ whenName: "emails[]",
+ expectValues: []string{"jon@labstack.com"},
+ },
+ {
+ name: "ok, GET form, multiple value",
+ givenRequest: examplePostFormRequest(func(v *url.Values) {
+ v.Add("emails[]", "snow@labstack.com")
+ }),
+ whenName: "emails[]",
+ whenLimit: 2,
+ expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
+ },
+ {
+ name: "nok, POST form, value missing",
+ givenRequest: examplePostFormRequest(nil),
+ whenName: "nope",
+ expectError: errFormExtractorValueMissing.Error(),
+ },
+ {
+ name: "ok, cut values over extractorLimit",
+ givenRequest: examplePostFormRequest(func(v *url.Values) {
+ for i := 1; i < 25; i++ {
+ v.Add("id[]", fmt.Sprintf("%v", i))
+ }
+ }),
+ whenName: "id[]",
+ whenLimit: extractorLimit,
+ expectValues: []string{
+ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
+ "11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := tc.givenRequest
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ extractor := valuesFromForm(tc.whenName, tc.whenLimit)
+
+ values, source, err := extractor(c)
+ assert.Equal(t, tc.expectValues, values)
+ assert.Equal(t, ExtractorSourceForm, source)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
diff --git a/middleware/key_auth.go b/middleware/key_auth.go
new file mode 100644
index 000000000..e14bd9e2e
--- /dev/null
+++ b/middleware/key_auth.go
@@ -0,0 +1,205 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "cmp"
+ "errors"
+ "fmt"
+ "net/http"
+
+ "github.com/labstack/echo/v5"
+)
+
+// KeyAuthConfig defines the config for KeyAuth middleware.
+//
+// SECURITY: The Validator function is responsible for securely comparing API keys.
+// See KeyAuthValidator documentation for guidance on preventing timing attacks.
+type KeyAuthConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // KeyLookup is a string in the form of ":" or ":,:" that is used
+ // to extract key from the request.
+ // Optional. Default value "header:Authorization".
+ // Possible values:
+ // - "header:" or "header::"
+ // `` is argument value to cut/trim prefix of the extracted value. This is useful if header
+ // value has static prefix like `Authorization: ` where part that we
+ // want to cut is ` ` note the space at the end.
+ // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `.
+ // - "query:"
+ // - "form:"
+ // - "cookie:"
+ // Multiple sources example:
+ // - "header:Authorization,header:X-Api-Key"
+ KeyLookup string
+
+ // AllowedCheckLimit set how many KeyLookup values are allowed to be checked. This is
+ // useful environments like corporate test environments with application proxies restricting
+ // access to environment with their own auth scheme.
+ AllowedCheckLimit uint
+
+ // Validator is a function to validate key.
+ // Required.
+ Validator KeyAuthValidator
+
+ // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator
+ // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key.
+ // It may be used to define a custom error.
+ //
+ // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler.
+ // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users
+ // In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain.
+ ErrorHandler KeyAuthErrorHandler
+
+ // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to
+ // ignore the error (by returning `nil`).
+ // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality.
+ // In that case you can use ErrorHandler to set a default public key auth value in the request context
+ // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then.
+ ContinueOnIgnoredError bool
+}
+
+// KeyAuthValidator defines a function to validate KeyAuth credentials.
+//
+// SECURITY WARNING: To prevent timing attacks that could allow attackers to enumerate
+// valid API keys, validator implementations MUST use constant-time comparison.
+// Use crypto/subtle.ConstantTimeCompare instead of standard string equality (==)
+// or switch statements.
+//
+// Example of SECURE implementation:
+//
+// import "crypto/subtle"
+//
+// validator := func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
+// // Fetch valid keys from database/config
+// validKeys := []string{"key1", "key2", "key3"}
+//
+// for _, validKey := range validKeys {
+// // Use constant-time comparison to prevent timing attacks
+// if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 {
+// return true, nil
+// }
+// }
+// return false, nil
+// }
+//
+// Example of INSECURE implementation (DO NOT USE):
+//
+// // VULNERABLE TO TIMING ATTACKS - DO NOT USE
+// validator := func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
+// switch key { // Timing leak!
+// case "valid-key":
+// return true, nil
+// default:
+// return false, nil
+// }
+// }
+type KeyAuthValidator func(c *echo.Context, key string, source ExtractorSource) (bool, error)
+
+// KeyAuthErrorHandler defines a function which is executed for an invalid key.
+type KeyAuthErrorHandler func(c *echo.Context, err error) error
+
+// ErrKeyMissing denotes an error raised when key value could not be extracted from request
+var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key")
+
+// ErrInvalidKey denotes an error raised when key value is invalid by validator
+var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key")
+
+// DefaultKeyAuthConfig is the default KeyAuth middleware config.
+var DefaultKeyAuthConfig = KeyAuthConfig{
+ Skipper: DefaultSkipper,
+ KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ",
+}
+
+// KeyAuth returns an KeyAuth middleware.
+//
+// For valid key it calls the next handler.
+// For invalid key, it sends "401 - Unauthorized" response.
+// For missing key, it sends "400 - Bad Request" response.
+func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc {
+ c := DefaultKeyAuthConfig
+ c.Validator = fn
+ return KeyAuthWithConfig(c)
+}
+
+// KeyAuthWithConfig returns an KeyAuth middleware or panics if configuration is invalid.
+//
+// For first valid key it calls the next handler.
+// For invalid key, it sends "401 - Unauthorized" response.
+// For missing key, it sends "400 - Bad Request" response.
+func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration
+func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Skipper == nil {
+ config.Skipper = DefaultKeyAuthConfig.Skipper
+ }
+ if config.KeyLookup == "" {
+ config.KeyLookup = DefaultKeyAuthConfig.KeyLookup
+ }
+ if config.Validator == nil {
+ return nil, errors.New("echo key-auth middleware requires a validator function")
+ }
+
+ limit := cmp.Or(config.AllowedCheckLimit, 1)
+
+ extractors, cErr := createExtractors(config.KeyLookup, limit)
+ if cErr != nil {
+ return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", cErr)
+ }
+ if len(extractors) == 0 {
+ return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string")
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ var lastExtractorErr error
+ var lastValidatorErr error
+ for _, extractor := range extractors {
+ keys, source, extrErr := extractor(c)
+ if extrErr != nil {
+ lastExtractorErr = extrErr
+ continue
+ }
+ for _, key := range keys {
+ valid, err := config.Validator(c, key, source)
+ if err != nil {
+ lastValidatorErr = err
+ continue
+ }
+ if !valid {
+ lastValidatorErr = ErrInvalidKey
+ continue
+ }
+ return next(c)
+ }
+ }
+
+ // prioritize validator errors over extracting errors
+ err := lastValidatorErr
+ if err == nil {
+ err = lastExtractorErr
+ }
+ if config.ErrorHandler != nil {
+ tmpErr := config.ErrorHandler(c, err)
+ if config.ContinueOnIgnoredError && tmpErr == nil {
+ return next(c)
+ }
+ return tmpErr
+ }
+ if lastValidatorErr == nil {
+ return ErrKeyMissing.Wrap(err)
+ }
+ return echo.ErrUnauthorized.Wrap(err)
+ }
+ }, nil
+}
diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go
new file mode 100644
index 000000000..49a917ed3
--- /dev/null
+++ b/middleware/key_auth_test.go
@@ -0,0 +1,375 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "crypto/subtle"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func testKeyValidator(c *echo.Context, key string, source ExtractorSource) (bool, error) {
+ // Use constant-time comparison to prevent timing attacks
+ if subtle.ConstantTimeCompare([]byte(key), []byte("valid-key")) == 1 {
+ return true, nil
+ }
+
+ // Special case for testing error handling
+ if key == "error-key" { // Error path doesn't need constant-time
+ return false, errors.New("some user defined error")
+ }
+
+ return false, nil
+}
+
+func TestKeyAuth(t *testing.T) {
+ handlerCalled := false
+ handler := func(c *echo.Context) error {
+ handlerCalled = true
+ return c.String(http.StatusOK, "test")
+ }
+ middlewareChain := KeyAuth(testKeyValidator)(handler)
+
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := middlewareChain(c)
+
+ assert.NoError(t, err)
+ assert.True(t, handlerCalled)
+}
+
+func TestKeyAuthWithConfig(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenRequestFunc func() *http.Request
+ givenRequest func(req *http.Request)
+ whenConfig func(conf *KeyAuthConfig)
+ expectHandlerCalled bool
+ expectError string
+ }{
+ {
+ name: "ok, defaults, key from header",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
+ },
+ expectHandlerCalled: true,
+ },
+ {
+ name: "ok, custom skipper",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.Skipper = func(context *echo.Context) bool {
+ return true
+ }
+ },
+ expectHandlerCalled: true,
+ },
+ {
+ name: "nok, defaults, invalid key from header, Authorization: Bearer",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
+ },
+ expectHandlerCalled: false,
+ expectError: "code=401, message=Unauthorized, err=code=401, message=invalid key",
+ },
+ {
+ name: "nok, defaults, invalid scheme in header",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "Bear valid-key")
+ },
+ expectHandlerCalled: false,
+ expectError: "code=401, message=missing key, err=invalid value in request header",
+ },
+ {
+ name: "nok, defaults, missing header",
+ givenRequest: func(req *http.Request) {},
+ expectHandlerCalled: false,
+ expectError: "code=401, message=missing key, err=missing value in request header",
+ },
+ {
+ name: "ok, custom key lookup, header",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set("API-Key", "valid-key")
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "header:API-Key"
+ },
+ expectHandlerCalled: true,
+ },
+ {
+ name: "nok, custom key lookup, missing header",
+ givenRequest: func(req *http.Request) {
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "header:API-Key"
+ },
+ expectHandlerCalled: false,
+ expectError: "code=401, message=missing key, err=missing value in request header",
+ },
+ {
+ name: "ok, custom key lookup, query",
+ givenRequest: func(req *http.Request) {
+ q := req.URL.Query()
+ q.Add("key", "valid-key")
+ req.URL.RawQuery = q.Encode()
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "query:key"
+ },
+ expectHandlerCalled: true,
+ },
+ {
+ name: "nok, custom key lookup, missing query param",
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "query:key"
+ },
+ expectHandlerCalled: false,
+ expectError: "code=401, message=missing key, err=missing value in the query string",
+ },
+ {
+ name: "ok, custom key lookup, form",
+ givenRequestFunc: func() *http.Request {
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("key=valid-key"))
+ req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
+ return req
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "form:key"
+ },
+ expectHandlerCalled: true,
+ },
+ {
+ name: "nok, custom key lookup, missing key in form",
+ givenRequestFunc: func() *http.Request {
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("xxx=valid-key"))
+ req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
+ return req
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "form:key"
+ },
+ expectHandlerCalled: false,
+ expectError: "code=401, message=missing key, err=missing value in the form",
+ },
+ {
+ name: "ok, custom key lookup, cookie",
+ givenRequest: func(req *http.Request) {
+ req.AddCookie(&http.Cookie{
+ Name: "key",
+ Value: "valid-key",
+ })
+ q := req.URL.Query()
+ q.Add("key", "valid-key")
+ req.URL.RawQuery = q.Encode()
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "cookie:key"
+ },
+ expectHandlerCalled: true,
+ },
+ {
+ name: "nok, custom key lookup, missing cookie param",
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "cookie:key"
+ },
+ expectHandlerCalled: false,
+ expectError: "code=401, message=missing key, err=missing value in cookies",
+ },
+ {
+ name: "nok, custom errorHandler, error from extractor",
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "header:token"
+ conf.ErrorHandler = func(c *echo.Context, err error) error {
+ return echo.NewHTTPError(http.StatusTeapot, "custom").Wrap(err)
+ }
+ },
+ expectHandlerCalled: false,
+ expectError: "code=418, message=custom, err=missing value in request header",
+ },
+ {
+ name: "nok, custom errorHandler, error from validator",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.ErrorHandler = func(c *echo.Context, err error) error {
+ return echo.NewHTTPError(http.StatusTeapot, "custom").Wrap(err)
+ }
+ },
+ expectHandlerCalled: false,
+ expectError: "code=418, message=custom, err=some user defined error",
+ },
+ {
+ name: "nok, defaults, error from validator",
+ givenRequest: func(req *http.Request) {
+ req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
+ },
+ whenConfig: func(conf *KeyAuthConfig) {},
+ expectHandlerCalled: false,
+ expectError: "code=401, message=Unauthorized, err=some user defined error",
+ },
+ {
+ name: "ok, custom validator checks source",
+ givenRequest: func(req *http.Request) {
+ q := req.URL.Query()
+ q.Add("key", "valid-key")
+ req.URL.RawQuery = q.Encode()
+ },
+ whenConfig: func(conf *KeyAuthConfig) {
+ conf.KeyLookup = "query:key"
+ conf.Validator = func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
+ if source == ExtractorSourceQuery {
+ return true, nil
+ }
+ return false, errors.New("invalid source")
+ }
+
+ },
+ expectHandlerCalled: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ handlerCalled := false
+ handler := func(c *echo.Context) error {
+ handlerCalled = true
+ return c.String(http.StatusOK, "test")
+ }
+ config := KeyAuthConfig{
+ Validator: testKeyValidator,
+ }
+ if tc.whenConfig != nil {
+ tc.whenConfig(&config)
+ }
+ middlewareChain := KeyAuthWithConfig(config)(handler)
+
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ if tc.givenRequestFunc != nil {
+ req = tc.givenRequestFunc()
+ }
+ if tc.givenRequest != nil {
+ tc.givenRequest(req)
+ }
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := middlewareChain(c)
+
+ assert.Equal(t, tc.expectHandlerCalled, handlerCalled)
+ if tc.expectError != "" {
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestKeyAuthWithConfig_errors(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenConfig KeyAuthConfig
+ expectError string
+ }{
+ {
+ name: "ok, no error",
+ whenConfig: KeyAuthConfig{
+ Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
+ return false, nil
+ },
+ },
+ },
+ {
+ name: "ok, missing validator func",
+ whenConfig: KeyAuthConfig{
+ Validator: nil,
+ },
+ expectError: "echo key-auth middleware requires a validator function",
+ },
+ {
+ name: "ok, extractor source can not be split",
+ whenConfig: KeyAuthConfig{
+ KeyLookup: "nope",
+ Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
+ return false, nil
+ },
+ },
+ expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope",
+ },
+ {
+ name: "ok, no extractors",
+ whenConfig: KeyAuthConfig{
+ KeyLookup: "nope:nope",
+ Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) {
+ return false, nil
+ },
+ },
+ expectError: "echo key-auth middleware could not create extractors from KeyLookup string",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ mw, err := tc.whenConfig.ToMiddleware()
+ if tc.expectError != "" {
+ assert.Nil(t, mw)
+ assert.EqualError(t, err, tc.expectError)
+ } else {
+ assert.NotNil(t, mw)
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestMustKeyAuthWithConfig_panic(t *testing.T) {
+ assert.Panics(t, func() {
+ KeyAuthWithConfig(KeyAuthConfig{})
+ })
+}
+
+func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) {
+ handlerCalled := false
+ var authValue string
+ handler := func(c *echo.Context) error {
+ handlerCalled = true
+ authValue = c.Get("auth").(string)
+ return c.String(http.StatusOK, "test")
+ }
+ middlewareChain := KeyAuthWithConfig(KeyAuthConfig{
+ Validator: testKeyValidator,
+ ErrorHandler: func(c *echo.Context, err error) error {
+ // could check error to decide if we can swallow the error
+ c.Set("auth", "public")
+ return nil
+ },
+ ContinueOnIgnoredError: true,
+ })(handler)
+
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ // no auth header this time
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := middlewareChain(c)
+
+ assert.NoError(t, err)
+ assert.True(t, handlerCalled)
+ assert.Equal(t, "public", authValue)
+}
diff --git a/middleware/logger.go b/middleware/logger.go
deleted file mode 100644
index 19d1301d5..000000000
--- a/middleware/logger.go
+++ /dev/null
@@ -1,41 +0,0 @@
-package middleware
-
-import (
- "log"
- "time"
-
- "github.com/labstack/echo"
- "github.com/labstack/gommon/color"
-)
-
-func Logger() echo.MiddlewareFunc {
- return func(h echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) error {
- start := time.Now()
- if err := h(c); err != nil {
- c.Error(err)
- }
- stop := time.Now()
- method := c.Request().Method
- path := c.Request().URL.Path
- if path == "" {
- path = "/"
- }
- size := c.Response().Size()
-
- n := c.Response().Status()
- code := color.Green(n)
- switch {
- case n >= 500:
- code = color.Red(n)
- case n >= 400:
- code = color.Yellow(n)
- case n >= 300:
- code = color.Cyan(n)
- }
-
- log.Printf("%s %s %s %s %d", method, path, code, stop.Sub(start), size)
- return nil
- }
- }
-}
diff --git a/middleware/logger_test.go b/middleware/logger_test.go
deleted file mode 100644
index e46019f3d..000000000
--- a/middleware/logger_test.go
+++ /dev/null
@@ -1,47 +0,0 @@
-package middleware
-
-import (
- "errors"
- "github.com/labstack/echo"
- "net/http"
- "net/http/httptest"
- "testing"
-)
-
-func TestLogger(t *testing.T) {
- e := echo.New()
- req, _ := http.NewRequest(echo.GET, "/", nil)
- rec := httptest.NewRecorder()
- c := echo.NewContext(req, echo.NewResponse(rec), e)
-
- // Status 2xx
- h := func(c *echo.Context) error {
- return c.String(http.StatusOK, "test")
- }
- Logger()(h)(c)
-
- // Status 3xx
- rec = httptest.NewRecorder()
- c = echo.NewContext(req, echo.NewResponse(rec), e)
- h = func(c *echo.Context) error {
- return c.String(http.StatusTemporaryRedirect, "test")
- }
- Logger()(h)(c)
-
- // Status 4xx
- rec = httptest.NewRecorder()
- c = echo.NewContext(req, echo.NewResponse(rec), e)
- h = func(c *echo.Context) error {
- return c.String(http.StatusNotFound, "test")
- }
- Logger()(h)(c)
-
- // Status 5xx with empty path
- req, _ = http.NewRequest(echo.GET, "", nil)
- rec = httptest.NewRecorder()
- c = echo.NewContext(req, echo.NewResponse(rec), e)
- h = func(c *echo.Context) error {
- return errors.New("error")
- }
- Logger()(h)(c)
-}
diff --git a/middleware/method_override.go b/middleware/method_override.go
new file mode 100644
index 000000000..25ec1f935
--- /dev/null
+++ b/middleware/method_override.go
@@ -0,0 +1,95 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "net/http"
+
+ "github.com/labstack/echo/v5"
+)
+
+// MethodOverrideConfig defines the config for MethodOverride middleware.
+type MethodOverrideConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // Getter is a function that gets overridden method from the request.
+ // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride).
+ Getter MethodOverrideGetter
+}
+
+// MethodOverrideGetter is a function that gets overridden method from the request
+type MethodOverrideGetter func(c *echo.Context) string
+
+// DefaultMethodOverrideConfig is the default MethodOverride middleware config.
+var DefaultMethodOverrideConfig = MethodOverrideConfig{
+ Skipper: DefaultSkipper,
+ Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
+}
+
+// MethodOverride returns a MethodOverride middleware.
+// MethodOverride middleware checks for the overridden method from the request and
+// uses it instead of the original method.
+//
+// For security reasons, only `POST` method can be overridden.
+func MethodOverride() echo.MiddlewareFunc {
+ return MethodOverrideWithConfig(DefaultMethodOverrideConfig)
+}
+
+// MethodOverrideWithConfig returns a Method Override middleware with config or panics on invalid configuration.
+func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration
+func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ // Defaults
+ if config.Skipper == nil {
+ config.Skipper = DefaultMethodOverrideConfig.Skipper
+ }
+ if config.Getter == nil {
+ config.Getter = DefaultMethodOverrideConfig.Getter
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ req := c.Request()
+ if req.Method == http.MethodPost {
+ m := config.Getter(c)
+ if m != "" {
+ req.Method = m
+ }
+ }
+ return next(c)
+ }
+ }, nil
+}
+
+// MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from
+// the request header.
+func MethodFromHeader(header string) MethodOverrideGetter {
+ return func(c *echo.Context) string {
+ return c.Request().Header.Get(header)
+ }
+}
+
+// MethodFromForm is a `MethodOverrideGetter` that gets overridden method from the
+// form parameter.
+func MethodFromForm(param string) MethodOverrideGetter {
+ return func(c *echo.Context) string {
+ return c.FormValue(param)
+ }
+}
+
+// MethodFromQuery is a `MethodOverrideGetter` that gets overridden method from
+// the query parameter.
+func MethodFromQuery(param string) MethodOverrideGetter {
+ return func(c *echo.Context) string {
+ return c.QueryParam(param)
+ }
+}
diff --git a/middleware/method_override_test.go b/middleware/method_override_test.go
new file mode 100644
index 000000000..525ad10ba
--- /dev/null
+++ b/middleware/method_override_test.go
@@ -0,0 +1,92 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "bytes"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestMethodOverride(t *testing.T) {
+ e := echo.New()
+ m := MethodOverride()
+ h := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ // Override with http header
+ req := httptest.NewRequest(http.MethodPost, "/", nil)
+ rec := httptest.NewRecorder()
+ req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
+ c := e.NewContext(req, rec)
+
+ err := m(h)(c)
+ assert.NoError(t, err)
+
+ assert.Equal(t, http.MethodDelete, req.Method)
+
+}
+
+func TestMethodOverride_formParam(t *testing.T) {
+ e := echo.New()
+ h := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ // Override with form parameter
+ m, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware()
+ assert.NoError(t, err)
+ req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete)))
+ rec := httptest.NewRecorder()
+ req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
+ c := e.NewContext(req, rec)
+
+ err = m(h)(c)
+ assert.NoError(t, err)
+
+ assert.Equal(t, http.MethodDelete, req.Method)
+}
+
+func TestMethodOverride_queryParam(t *testing.T) {
+ e := echo.New()
+ h := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ // Override with query parameter
+ m, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware()
+ assert.NoError(t, err)
+ req := httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err = m(h)(c)
+ assert.NoError(t, err)
+
+ assert.Equal(t, http.MethodDelete, req.Method)
+}
+
+func TestMethodOverride_ignoreGet(t *testing.T) {
+ e := echo.New()
+ m := MethodOverride()
+ h := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ // Ignore `GET`
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := m(h)(c)
+ assert.NoError(t, err)
+
+ assert.Equal(t, http.MethodGet, req.Method)
+}
diff --git a/middleware/middleware.go b/middleware/middleware.go
new file mode 100644
index 000000000..4562d03b5
--- /dev/null
+++ b/middleware/middleware.go
@@ -0,0 +1,97 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "net/http"
+ "regexp"
+ "strconv"
+ "strings"
+
+ "github.com/labstack/echo/v5"
+)
+
+// Skipper defines a function to skip middleware. Returning true skips processing the middleware.
+type Skipper func(c *echo.Context) bool
+
+// BeforeFunc defines a function which is executed just before the middleware.
+type BeforeFunc func(c *echo.Context)
+
+func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
+ groups := pattern.FindAllStringSubmatch(input, -1)
+ if groups == nil {
+ return nil
+ }
+ values := groups[0][1:]
+ replace := make([]string, 2*len(values))
+ for i, v := range values {
+ j := 2 * i
+ replace[j] = "$" + strconv.Itoa(i+1)
+ replace[j+1] = v
+ }
+ return strings.NewReplacer(replace...)
+}
+
+func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
+ // Initialize
+ rulesRegex := map[*regexp.Regexp]string{}
+ for k, v := range rewrite {
+ k = regexp.QuoteMeta(k)
+ k = strings.ReplaceAll(k, `\*`, "(.*?)")
+ if strings.HasPrefix(k, `\^`) {
+ k = strings.ReplaceAll(k, `\^`, "^")
+ }
+ k = k + "$"
+ rulesRegex[regexp.MustCompile(k)] = v
+ }
+ return rulesRegex
+}
+
+func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error {
+ if len(rewriteRegex) == 0 {
+ return nil
+ }
+
+ // Depending how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path.
+ // We only want to use path part for rewriting and therefore trim prefix if it exists
+ rawURI := req.RequestURI
+ if rawURI != "" && rawURI[0] != '/' {
+ prefix := ""
+ if req.URL.Scheme != "" {
+ prefix = req.URL.Scheme + "://"
+ }
+ if req.URL.Host != "" {
+ prefix += req.URL.Host // host or host:port
+ }
+ if prefix != "" {
+ rawURI = strings.TrimPrefix(rawURI, prefix)
+ }
+ }
+
+ for k, v := range rewriteRegex {
+ if replacer := captureTokens(k, rawURI); replacer != nil {
+ url, err := req.URL.Parse(replacer.Replace(v))
+ if err != nil {
+ return err
+ }
+ req.URL = url
+
+ return nil // rewrite only once
+ }
+ }
+ return nil
+}
+
+// DefaultSkipper returns false which processes the middleware.
+func DefaultSkipper(c *echo.Context) bool {
+ return false
+}
+
+func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc {
+ mw, err := config.ToMiddleware()
+ if err != nil {
+ panic(err)
+ }
+ return mw
+}
diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go
new file mode 100644
index 000000000..28407ed5c
--- /dev/null
+++ b/middleware/middleware_test.go
@@ -0,0 +1,136 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "bufio"
+ "errors"
+ "github.com/stretchr/testify/assert"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "regexp"
+ "testing"
+)
+
+func TestRewriteURL(t *testing.T) {
+ var testCases = []struct {
+ whenURL string
+ expectPath string
+ expectRawPath string
+ expectQuery string
+ expectErr string
+ }{
+ {
+ whenURL: "http://localhost:8080/old",
+ expectPath: "/new",
+ expectRawPath: "",
+ },
+ { // encoded `ol%64` (decoded `old`) should not be rewritten to `/new`
+ whenURL: "/ol%64", // `%64` is decoded `d`
+ expectPath: "/old",
+ expectRawPath: "/ol%64",
+ },
+ {
+ whenURL: "http://localhost:8080/users/+_+/orders/___++++?test=1",
+ expectPath: "/user/+_+/order/___++++",
+ expectRawPath: "",
+ expectQuery: "test=1",
+ },
+ {
+ whenURL: "http://localhost:8080/users/%20a/orders/%20aa",
+ expectPath: "/user/ a/order/ aa",
+ expectRawPath: "",
+ },
+ {
+ whenURL: "http://localhost:8080/%47%6f%2f?test=1",
+ expectPath: "/Go/",
+ expectRawPath: "/%47%6f%2f",
+ expectQuery: "test=1",
+ },
+ {
+ whenURL: "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F",
+ expectPath: "/user/jill/order/T/cO4lW/t/Vp/",
+ expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ },
+ { // do nothing, replace nothing
+ whenURL: "http://localhost:8080/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ expectPath: "/user/jill/order/T/cO4lW/t/Vp/",
+ expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ },
+ {
+ whenURL: "http://localhost:8080/static",
+ expectPath: "/static/path",
+ expectRawPath: "",
+ expectQuery: "role=AUTHOR&limit=1000",
+ },
+ {
+ whenURL: "/static",
+ expectPath: "/static/path",
+ expectRawPath: "",
+ expectQuery: "role=AUTHOR&limit=1000",
+ },
+ }
+
+ rules := map[*regexp.Regexp]string{
+ regexp.MustCompile("^/old$"): "/new",
+ regexp.MustCompile("^/users/(.*?)/orders/(.*?)$"): "/user/$1/order/$2",
+ regexp.MustCompile("^/static$"): "/static/path?role=AUTHOR&limit=1000",
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.whenURL, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+
+ err := rewriteURL(rules, req)
+
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tc.expectPath, req.URL.Path) // Path field is stored in decoded form: /%47%6f%2f becomes /Go/.
+ assert.Equal(t, tc.expectRawPath, req.URL.RawPath) // RawPath, an optional field which only gets set if the default encoding is different from Path.
+ assert.Equal(t, tc.expectQuery, req.URL.RawQuery)
+ })
+ }
+}
+
+type testResponseWriterNoFlushHijack struct {
+}
+
+func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) {
+}
+func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) {
+ return 0, nil
+}
+func (w *testResponseWriterNoFlushHijack) Header() http.Header {
+ return nil
+}
+
+type testResponseWriterUnwrapper struct {
+ unwrapCalled int
+ rw http.ResponseWriter
+}
+
+func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) {
+}
+func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) {
+ return 0, nil
+}
+func (w *testResponseWriterUnwrapper) Header() http.Header {
+ return nil
+}
+func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter {
+ w.unwrapCalled++
+ return w.rw
+}
+
+type testResponseWriterUnwrapperHijack struct {
+ testResponseWriterUnwrapper
+}
+
+func (w *testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ return nil, nil, errors.New("can hijack")
+}
diff --git a/middleware/proxy.go b/middleware/proxy.go
new file mode 100644
index 000000000..1996032f7
--- /dev/null
+++ b/middleware/proxy.go
@@ -0,0 +1,441 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io"
+ "math/rand"
+ "net"
+ "net/http"
+ "net/http/httputil"
+ "net/url"
+ "regexp"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/labstack/echo/v5"
+)
+
+// TODO: Handle TLS proxy
+
+// ProxyConfig defines the config for Proxy middleware.
+type ProxyConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // Balancer defines a load balancing technique.
+ // Required.
+ Balancer ProxyBalancer
+
+ // RetryCount defines the number of times a failed proxied request should be retried
+ // using the next available ProxyTarget. Defaults to 0, meaning requests are never retried.
+ RetryCount int
+
+ // RetryFilter defines a function used to determine if a failed request to a
+ // ProxyTarget should be retried. The RetryFilter will only be called when the number
+ // of previous retries is less than RetryCount. If the function returns true, the
+ // request will be retried. The provided error indicates the reason for the request
+ // failure. When the ProxyTarget is unavailable, the error will be an instance of
+ // echo.HTTPError with a code of http.StatusBadGateway. In all other cases, the error
+ // will indicate an internal error in the Proxy middleware. When a RetryFilter is not
+ // specified, all requests that fail with http.StatusBadGateway will be retried. A custom
+ // RetryFilter can be provided to only retry specific requests. Note that RetryFilter is
+ // only called when the request to the target fails, or an internal error in the Proxy
+ // middleware has occurred. Successful requests that return a non-200 response code cannot
+ // be retried.
+ RetryFilter func(c *echo.Context, e error) bool
+
+ // ErrorHandler defines a function which can be used to return custom errors from
+ // the Proxy middleware. ErrorHandler is only invoked when there has been
+ // either an internal error in the Proxy middleware or the ProxyTarget is
+ // unavailable. Due to the way requests are proxied, ErrorHandler is not invoked
+ // when a ProxyTarget returns a non-200 response. In these cases, the response
+ // is already written so errors cannot be modified. ErrorHandler is only
+ // invoked after all retry attempts have been exhausted.
+ ErrorHandler func(c *echo.Context, err error) error
+
+ // Rewrite defines URL path rewrite rules. The values captured in asterisk can be
+ // retrieved by index e.g. $1, $2 and so on.
+ // Examples:
+ // "/old": "/new",
+ // "/api/*": "/$1",
+ // "/js/*": "/public/javascripts/$1",
+ // "/users/*/orders/*": "/user/$1/order/$2",
+ Rewrite map[string]string
+
+ // RegexRewrite defines rewrite rules using regexp.Rexexp with captures
+ // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
+ // Example:
+ // "^/old/[0.9]+/": "/new",
+ // "^/api/.+?/(.*)": "/v2/$1",
+ RegexRewrite map[*regexp.Regexp]string
+
+ // Context key to store selected ProxyTarget into context.
+ // Optional. Default value "target".
+ ContextKey string
+
+ // To customize the transport to remote.
+ // Examples: If custom TLS certificates are required.
+ Transport http.RoundTripper
+
+ // ModifyResponse defines function to modify response from ProxyTarget.
+ ModifyResponse func(*http.Response) error
+}
+
+// ProxyTarget defines the upstream target.
+type ProxyTarget struct {
+ Name string
+ URL *url.URL
+ Meta map[string]any
+}
+
+// ProxyBalancer defines an interface to implement a load balancing technique.
+type ProxyBalancer interface {
+ AddTarget(target *ProxyTarget) bool
+ RemoveTarget(targetName string) bool
+ Next(c *echo.Context) (*ProxyTarget, error)
+}
+
+type commonBalancer struct {
+ targets []*ProxyTarget
+ mutex sync.Mutex
+}
+
+// RandomBalancer implements a random load balancing technique.
+type randomBalancer struct {
+ commonBalancer
+ random *rand.Rand
+}
+
+// RoundRobinBalancer implements a round-robin load balancing technique.
+type roundRobinBalancer struct {
+ commonBalancer
+ // tracking the index on `targets` slice for the next `*ProxyTarget` to be used
+ i int
+}
+
+// DefaultProxyConfig is the default Proxy middleware config.
+var DefaultProxyConfig = ProxyConfig{
+ Skipper: DefaultSkipper,
+ ContextKey: "target",
+}
+
+func proxyRaw(c *echo.Context, t *ProxyTarget, config ProxyConfig) http.Handler {
+ var dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
+ if transport, ok := config.Transport.(*http.Transport); ok {
+ if transport.TLSClientConfig != nil {
+ d := tls.Dialer{
+ Config: transport.TLSClientConfig,
+ }
+ dialFunc = d.DialContext
+ }
+ }
+ if dialFunc == nil {
+ var d net.Dialer
+ dialFunc = d.DialContext
+ }
+
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ in, _, err := http.NewResponseController(w).Hijack()
+ if err != nil {
+ c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL))
+ return
+ }
+ defer in.Close()
+
+ out, err := dialFunc(c.Request().Context(), "tcp", t.URL.Host)
+ if err != nil {
+ c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL)))
+ return
+ }
+ defer out.Close()
+
+ // Write header
+ err = r.Write(out)
+ if err != nil {
+ c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", err, t.URL)))
+ return
+ }
+
+ errCh := make(chan error, 2)
+ cp := func(dst io.Writer, src io.Reader) {
+ _, copyErr := io.Copy(dst, src)
+ errCh <- copyErr
+ }
+
+ go cp(out, in)
+ go cp(in, out)
+
+ // Wait for BOTH goroutines to complete
+ err1 := <-errCh
+ err2 := <-errCh
+
+ if err1 != nil && err1 != io.EOF {
+ c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err1, t.URL))
+ } else if err2 != nil && err2 != io.EOF {
+ c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err2, t.URL))
+ }
+ })
+}
+
+// NewRandomBalancer returns a random proxy balancer.
+func NewRandomBalancer(targets []*ProxyTarget) ProxyBalancer {
+ b := randomBalancer{}
+ b.targets = targets
+ // G404 (CWE-338): Use of weak random number generator (math/rand or math/rand/v2 instead of crypto/rand)
+ // this random is used to select next target. I can not think of reason this must be cryptographically safe. If you can - please open PR.
+ b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) // #nosec G404
+ return &b
+}
+
+// NewRoundRobinBalancer returns a round-robin proxy balancer.
+func NewRoundRobinBalancer(targets []*ProxyTarget) ProxyBalancer {
+ b := roundRobinBalancer{}
+ b.targets = targets
+ return &b
+}
+
+// AddTarget adds an upstream target to the list and returns `true`.
+//
+// However, if a target with the same name already exists then the operation is aborted returning `false`.
+func (b *commonBalancer) AddTarget(target *ProxyTarget) bool {
+ b.mutex.Lock()
+ defer b.mutex.Unlock()
+ for _, t := range b.targets {
+ if t.Name == target.Name {
+ return false
+ }
+ }
+ b.targets = append(b.targets, target)
+ return true
+}
+
+// RemoveTarget removes an upstream target from the list by name.
+//
+// Returns `true` on success, `false` if no target with the name is found.
+func (b *commonBalancer) RemoveTarget(name string) bool {
+ b.mutex.Lock()
+ defer b.mutex.Unlock()
+ for i, t := range b.targets {
+ if t.Name == name {
+ b.targets = append(b.targets[:i], b.targets[i+1:]...)
+ return true
+ }
+ }
+ return false
+}
+
+// Next randomly returns an upstream target.
+//
+// Note: `nil` is returned in case upstream target list is empty.
+func (b *randomBalancer) Next(c *echo.Context) (*ProxyTarget, error) {
+ b.mutex.Lock()
+ defer b.mutex.Unlock()
+ if len(b.targets) == 0 {
+ return nil, nil
+ } else if len(b.targets) == 1 {
+ return b.targets[0], nil
+ }
+ return b.targets[b.random.Intn(len(b.targets))], nil
+}
+
+// Next returns an upstream target using round-robin technique. In the case
+// where a previously failed request is being retried, the round-robin
+// balancer will attempt to use the next target relative to the original
+// request. If the list of targets held by the balancer is modified while a
+// failed request is being retried, it is possible that the balancer will
+// return the original failed target.
+//
+// Note: `nil` is returned in case upstream target list is empty.
+func (b *roundRobinBalancer) Next(c *echo.Context) (*ProxyTarget, error) {
+ b.mutex.Lock()
+ defer b.mutex.Unlock()
+ if len(b.targets) == 0 {
+ return nil, nil
+ } else if len(b.targets) == 1 {
+ return b.targets[0], nil
+ }
+
+ var i int
+ const lastIdxKey = "_round_robin_last_index"
+ // This request is a retry, start from the index of the previous
+ // target to ensure we don't attempt to retry the request with
+ // the same failed target
+ if c.Get(lastIdxKey) != nil {
+ i = c.Get(lastIdxKey).(int)
+ i++
+ if i >= len(b.targets) {
+ i = 0
+ }
+ } else {
+ // This is a first time request, use the global index
+ if b.i >= len(b.targets) {
+ b.i = 0
+ }
+ i = b.i
+ b.i++
+ }
+ c.Set(lastIdxKey, i)
+ return b.targets[i], nil
+}
+
+// Proxy returns a Proxy middleware.
+//
+// Proxy middleware forwards the request to upstream server using a configured load balancing technique.
+func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc {
+ c := DefaultProxyConfig
+ c.Balancer = balancer
+ return ProxyWithConfig(c)
+}
+
+// ProxyWithConfig returns a Proxy middleware or panics if configuration is invalid.
+//
+// Proxy middleware forwards the request to upstream server using a configured load balancing technique.
+func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration
+func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Skipper == nil {
+ config.Skipper = DefaultProxyConfig.Skipper
+ }
+ if config.ContextKey == "" {
+ config.ContextKey = DefaultProxyConfig.ContextKey
+ }
+ if config.Balancer == nil {
+ return nil, errors.New("echo proxy middleware requires balancer")
+ }
+ if config.RetryFilter == nil {
+ config.RetryFilter = func(c *echo.Context, e error) bool {
+ if httpErr, ok := e.(*echo.HTTPError); ok {
+ return httpErr.Code == http.StatusBadGateway
+ }
+ return false
+ }
+ }
+ if config.ErrorHandler == nil {
+ config.ErrorHandler = func(c *echo.Context, err error) error {
+ return err
+ }
+ }
+
+ if config.Rewrite != nil {
+ if config.RegexRewrite == nil {
+ config.RegexRewrite = make(map[*regexp.Regexp]string)
+ }
+ for k, v := range rewriteRulesRegex(config.Rewrite) {
+ config.RegexRewrite[k] = v
+ }
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) (err error) {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ req := c.Request()
+ res := c.Response()
+ if err := rewriteURL(config.RegexRewrite, req); err != nil {
+ return config.ErrorHandler(c, err)
+ }
+
+ // Fix header
+ // Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.
+ // However, for backward compatibility, legacy behavior is preserved unless you configure Echo#IPExtractor.
+ if req.Header.Get(echo.HeaderXRealIP) == "" || c.Echo().IPExtractor != nil {
+ req.Header.Set(echo.HeaderXRealIP, c.RealIP())
+ }
+ if req.Header.Get(echo.HeaderXForwardedProto) == "" {
+ req.Header.Set(echo.HeaderXForwardedProto, c.Scheme())
+ }
+ if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy.
+ req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
+ }
+
+ retries := config.RetryCount
+ for {
+ tgt, err := config.Balancer.Next(c)
+ if err != nil {
+ return config.ErrorHandler(c, err)
+ }
+
+ c.Set(config.ContextKey, tgt)
+
+ //If retrying a failed request, clear any previous errors from
+ //context here so that balancers have the option to check for
+ //errors that occurred using previous target
+ if retries < config.RetryCount {
+ c.Set("_error", nil)
+ }
+
+ // This is needed for ProxyConfig.ModifyResponse and/or ProxyConfig.Transport to be able to process the Request
+ // that Balancer may have replaced with c.SetRequest.
+ req = c.Request()
+
+ // Proxy
+ switch {
+ case c.IsWebSocket():
+ proxyRaw(c, tgt, config).ServeHTTP(res, req)
+ default: // even SSE requests
+ proxyHTTP(c, tgt, config).ServeHTTP(res, req)
+ }
+
+ err, hasError := c.Get("_error").(error)
+ if !hasError {
+ return nil
+ }
+
+ retry := retries > 0 && config.RetryFilter(c, err)
+ if !retry {
+ return config.ErrorHandler(c, err)
+ }
+
+ retries--
+ }
+ }
+ }, nil
+}
+
+// StatusCodeContextCanceled is a custom HTTP status code for situations
+// where a client unexpectedly closed the connection to the server.
+// As there is no standard error code for "client closed connection", but
+// various well-known HTTP clients and server implement this HTTP code we use
+// 499 too instead of the more problematic 5xx, which does not allow to detect this situation
+const StatusCodeContextCanceled = 499
+
+func proxyHTTP(c *echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler {
+ proxy := httputil.NewSingleHostReverseProxy(tgt.URL)
+ proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) {
+ desc := tgt.URL.String()
+ if tgt.Name != "" {
+ desc = fmt.Sprintf("%s(%s)", tgt.Name, tgt.URL.String())
+ }
+ // If the client canceled the request (usually by closing the connection), we can report a
+ // client error (4xx) instead of a server error (5xx) to correctly identify the situation.
+ // The Go standard library (at of late 2020) wraps the exported, standard
+ // context. Canceled error with unexported garbage value requiring a substring check, see
+ // https://github.com/golang/go/blob/6965b01ea248cabb70c3749fd218b36089a21efb/src/net/net.go#L416-L430
+ // From Caddy https://github.com/caddyserver/caddy/blob/afa778ae05503f563af0d1015cdf7e5e78b1eeec/modules/caddyhttp/reverseproxy/reverseproxy.go#L1352
+ if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "operation was canceled") {
+ httpError := echo.NewHTTPError(StatusCodeContextCanceled, "client closed connection").Wrap(err)
+ c.Set("_error", httpError)
+ } else {
+ httpError := echo.NewHTTPError(
+ http.StatusBadGateway,
+ "remote server unreachable, could not proxy request",
+ ).Wrap(fmt.Errorf("server: %s, err: %w", desc, err))
+ c.Set("_error", httpError)
+ }
+ }
+ proxy.Transport = config.Transport
+ proxy.ModifyResponse = config.ModifyResponse
+ return proxy
+}
diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go
new file mode 100644
index 000000000..420be3240
--- /dev/null
+++ b/middleware/proxy_test.go
@@ -0,0 +1,1055 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "regexp"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+ "golang.org/x/net/websocket"
+)
+
+// Assert expected with url.EscapedPath method to obtain the path.
+func TestProxy(t *testing.T) {
+ // Setup
+ t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, "target 1")
+ }))
+ defer t1.Close()
+ url1, _ := url.Parse(t1.URL)
+ t2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, "target 2")
+ }))
+ defer t2.Close()
+ url2, _ := url.Parse(t2.URL)
+
+ targets := []*ProxyTarget{
+ {
+ Name: "target 1",
+ URL: url1,
+ },
+ {
+ Name: "target 2",
+ URL: url2,
+ },
+ }
+ rb := NewRandomBalancer(nil)
+ // must add targets:
+ for _, target := range targets {
+ assert.True(t, rb.AddTarget(target))
+ }
+
+ // must ignore duplicates:
+ for _, target := range targets {
+ assert.False(t, rb.AddTarget(target))
+ }
+
+ // Random
+ e := echo.New()
+ e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ body := rec.Body.String()
+ expected := map[string]bool{
+ "target 1": true,
+ "target 2": true,
+ }
+ assert.Condition(t, func() bool {
+ return expected[body]
+ })
+
+ for _, target := range targets {
+ assert.True(t, rb.RemoveTarget(target.Name))
+ }
+
+ assert.False(t, rb.RemoveTarget("unknown target"))
+
+ // Round-robin
+ rrb := NewRoundRobinBalancer(targets)
+ e = echo.New()
+ e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
+
+ rec = httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ body = rec.Body.String()
+ assert.Equal(t, "target 1", body)
+
+ rec = httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ body = rec.Body.String()
+ assert.Equal(t, "target 2", body)
+
+ // ModifyResponse
+ e = echo.New()
+ e.Use(ProxyWithConfig(ProxyConfig{
+ Balancer: rrb,
+ ModifyResponse: func(res *http.Response) error {
+ res.Body = io.NopCloser(bytes.NewBuffer([]byte("modified")))
+ res.Header.Set("X-Modified", "1")
+ return nil
+ },
+ }))
+
+ rec = httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, "modified", rec.Body.String())
+ assert.Equal(t, "1", rec.Header().Get("X-Modified"))
+
+ // ProxyTarget is set in context
+ contextObserver := func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) (err error) {
+ next(c)
+ assert.Contains(t, targets, c.Get("target"), "target is not set in context")
+ return nil
+ }
+ }
+
+ e = echo.New()
+ e.Use(contextObserver)
+ e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)}))
+ rec = httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+}
+
+func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) {
+ assert.Panics(t, func() {
+ ProxyWithConfig(ProxyConfig{Balancer: nil})
+ })
+}
+
+func TestProxyRealIPHeader(t *testing.T) {
+ // Setup
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
+ defer upstream.Close()
+ url, _ := url.Parse(upstream.URL)
+ rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}})
+ e := echo.New()
+ e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb}))
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ remoteAddrIP, _, _ := net.SplitHostPort(req.RemoteAddr)
+ realIPHeaderIP := "203.0.113.1"
+ extractedRealIP := "203.0.113.10"
+ tests := []*struct {
+ hasRealIPheader bool
+ hasIPExtractor bool
+ expectedXRealIP string
+ }{
+ {false, false, remoteAddrIP},
+ {false, true, extractedRealIP},
+ {true, false, realIPHeaderIP},
+ {true, true, extractedRealIP},
+ }
+
+ for _, tt := range tests {
+ if tt.hasRealIPheader {
+ req.Header.Set(echo.HeaderXRealIP, realIPHeaderIP)
+ } else {
+ req.Header.Del(echo.HeaderXRealIP)
+ }
+ if tt.hasIPExtractor {
+ e.IPExtractor = func(*http.Request) string {
+ return extractedRealIP
+ }
+ } else {
+ e.IPExtractor = nil
+ }
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, tt.expectedXRealIP, req.Header.Get(echo.HeaderXRealIP), "hasRealIPheader: %t / hasIPExtractor: %t", tt.hasRealIPheader, tt.hasIPExtractor)
+ }
+}
+
+func TestProxyRewrite(t *testing.T) {
+ var testCases = []struct {
+ whenPath string
+ expectProxiedURI string
+ expectStatus int
+ }{
+ {
+ whenPath: "/api/users",
+ expectProxiedURI: "/users",
+ expectStatus: http.StatusOK,
+ },
+ {
+ whenPath: "/js/main.js",
+ expectProxiedURI: "/public/javascripts/main.js",
+ expectStatus: http.StatusOK,
+ },
+ {
+ whenPath: "/old",
+ expectProxiedURI: "/new",
+ expectStatus: http.StatusOK,
+ },
+ {
+ whenPath: "/users/jack/orders/1",
+ expectProxiedURI: "/user/jack/order/1",
+ expectStatus: http.StatusOK,
+ },
+ {
+ whenPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ expectProxiedURI: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ expectStatus: http.StatusOK,
+ },
+ { // ` ` (space) is encoded by httpClient to `%20` when doing request to Echo. `%20` should not be double escaped when proxying request
+ whenPath: "/api/new users",
+ expectProxiedURI: "/new%20users",
+ expectStatus: http.StatusOK,
+ },
+ { // query params should be proxied and not be modified
+ whenPath: "/api/users?limit=10",
+ expectProxiedURI: "/users?limit=10",
+ expectStatus: http.StatusOK,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.whenPath, func(t *testing.T) {
+ receivedRequestURI := make(chan string, 1)
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server
+ // we need unmodified target to see if we are encoding/decoding the url in addition to rewrite/replace logic
+ // if original request had `%2F` we should not magically decode it to `/` as it would change what was requested
+ receivedRequestURI <- r.RequestURI
+ }))
+ defer upstream.Close()
+ serverURL, _ := url.Parse(upstream.URL)
+ rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: serverURL}})
+
+ // Rewrite
+ e := echo.New()
+ e.Use(ProxyWithConfig(ProxyConfig{
+ Balancer: rrb,
+ Rewrite: map[string]string{
+ "/old": "/new",
+ "/api/*": "/$1",
+ "/js/*": "/public/javascripts/$1",
+ "/users/*/orders/*": "/user/$1/order/$2",
+ },
+ }))
+
+ targetURL, _ := serverURL.Parse(tc.whenPath)
+ req := httptest.NewRequest(http.MethodGet, targetURL.String(), nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ actualRequestURI := <-receivedRequestURI
+ assert.Equal(t, tc.expectProxiedURI, actualRequestURI)
+ })
+ }
+}
+
+func TestProxyRewriteRegex(t *testing.T) {
+ // Setup
+ receivedRequestURI := make(chan string, 1)
+ upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server
+ // we need unmodified target to see if we are encoding/decoding the url in addition to rewrite/replace logic
+ // if original request had `%2F` we should not magically decode it to `/` as it would change what was requested
+ receivedRequestURI <- r.RequestURI
+ }))
+ defer upstream.Close()
+ tmpUrL, _ := url.Parse(upstream.URL)
+ rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: tmpUrL}})
+
+ // Rewrite
+ e := echo.New()
+ e.Use(ProxyWithConfig(ProxyConfig{
+ Balancer: rrb,
+ Rewrite: map[string]string{
+ "^/a/*": "/v1/$1",
+ "^/b/*/c/*": "/v2/$2/$1",
+ "^/c/*/*": "/v3/$2",
+ },
+ RegexRewrite: map[*regexp.Regexp]string{
+ regexp.MustCompile("^/x/.+?/(.*)"): "/v4/$1",
+ regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1",
+ },
+ }))
+
+ testCases := []struct {
+ requestPath string
+ statusCode int
+ expectPath string
+ }{
+ {"/unmatched", http.StatusOK, "/unmatched"},
+ {"/a/test", http.StatusOK, "/v1/test"},
+ {"/b/foo/c/bar/baz", http.StatusOK, "/v2/bar/baz/foo"},
+ {"/c/ignore/test", http.StatusOK, "/v3/test"},
+ {"/c/ignore1/test/this", http.StatusOK, "/v3/test/this"},
+ {"/x/ignore/test", http.StatusOK, "/v4/test"},
+ {"/y/foo/bar", http.StatusOK, "/v5/bar/foo"},
+ // NB: fragment is not added by golang httputil.NewSingleHostReverseProxy implementation
+ // $2 = `bar?q=1#frag`, $1 = `foo`. replaced uri = `/v5/bar?q=1#frag/foo` but httputil.NewSingleHostReverseProxy does not send `#frag/foo` (currently)
+ {"/y/foo/bar?q=1#frag", http.StatusOK, "/v5/bar?q=1"},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.requestPath, func(t *testing.T) {
+ targetURL, _ := url.Parse(tc.requestPath)
+ req := httptest.NewRequest(http.MethodGet, targetURL.String(), nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ actualRequestURI := <-receivedRequestURI
+ assert.Equal(t, tc.expectPath, actualRequestURI)
+ assert.Equal(t, tc.statusCode, rec.Code)
+ })
+ }
+}
+
+func TestProxyError(t *testing.T) {
+ // Setup
+ url1, _ := url.Parse("http://127.0.0.1:27121")
+ url2, _ := url.Parse("http://127.0.0.1:27122")
+
+ targets := []*ProxyTarget{
+ {
+ Name: "target 1",
+ URL: url1,
+ },
+ {
+ Name: "target 2",
+ URL: url2,
+ },
+ }
+ rb := NewRandomBalancer(nil)
+ // must add targets:
+ for _, target := range targets {
+ assert.True(t, rb.AddTarget(target))
+ }
+
+ // must ignore duplicates:
+ for _, target := range targets {
+ assert.False(t, rb.AddTarget(target))
+ }
+
+ // Random
+ e := echo.New()
+ e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+
+ // Remote unreachable
+ rec := httptest.NewRecorder()
+ req.URL.Path = "/api/users"
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, "/api/users", req.URL.Path)
+ assert.Equal(t, http.StatusBadGateway, rec.Code)
+}
+
+func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) {
+ var timeoutStop sync.WaitGroup
+ timeoutStop.Add(1)
+ HTTPTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ timeoutStop.Wait() // wait until we have canceled the request
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer HTTPTarget.Close()
+ targetURL, _ := url.Parse(HTTPTarget.URL)
+ target := &ProxyTarget{
+ Name: "target",
+ URL: targetURL,
+ }
+ rb := NewRandomBalancer(nil)
+ assert.True(t, rb.AddTarget(target))
+ e := echo.New()
+ e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb}))
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ ctx, cancel := context.WithCancel(req.Context())
+ req = req.WithContext(ctx)
+ go func() {
+ time.Sleep(10 * time.Millisecond)
+ cancel()
+ }()
+ e.ServeHTTP(rec, req)
+ timeoutStop.Done()
+ assert.Equal(t, 499, rec.Code)
+}
+
+type testProvider struct {
+ commonBalancer
+ target *ProxyTarget
+ err error
+}
+
+func (p *testProvider) Next(c *echo.Context) (*ProxyTarget, error) {
+ return p.target, p.err
+}
+
+func TestTargetProvider(t *testing.T) {
+ t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, "target 1")
+ }))
+ defer t1.Close()
+ url1, _ := url.Parse(t1.URL)
+
+ e := echo.New()
+ tp := &testProvider{}
+ tp.target = &ProxyTarget{Name: "target 1", URL: url1}
+ e.Use(Proxy(tp))
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ e.ServeHTTP(rec, req)
+ body := rec.Body.String()
+ assert.Equal(t, "target 1", body)
+}
+
+func TestFailNextTarget(t *testing.T) {
+ url1, err := url.Parse("http://dummy:8080")
+ assert.Nil(t, err)
+
+ e := echo.New()
+ tp := &testProvider{}
+ tp.target = &ProxyTarget{Name: "target 1", URL: url1}
+ tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target")
+
+ e.Use(Proxy(tp))
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ e.ServeHTTP(rec, req)
+ body := rec.Body.String()
+ assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body)
+}
+
+func TestRandomBalancerWithNoTargets(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ // Assert balancer with empty targets does return `nil` on `Next()`
+ rb := NewRandomBalancer(nil)
+ target, err := rb.Next(c)
+ assert.Nil(t, target)
+ assert.NoError(t, err)
+}
+
+func TestRoundRobinBalancerWithNoTargets(t *testing.T) {
+ // Assert balancer with empty targets does return `nil` on `Next()`
+ rrb := NewRoundRobinBalancer([]*ProxyTarget{})
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ target, err := rrb.Next(c)
+ assert.Nil(t, target)
+ assert.NoError(t, err)
+}
+
+func TestProxyRetries(t *testing.T) {
+ newServer := func(res int) (*url.URL, *httptest.Server) {
+ server := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(res)
+ }),
+ )
+ targetURL, _ := url.Parse(server.URL)
+ return targetURL, server
+ }
+
+ targetURL, server := newServer(http.StatusOK)
+ defer server.Close()
+ goodTarget := &ProxyTarget{
+ Name: "Good",
+ URL: targetURL,
+ }
+
+ targetURL, server = newServer(http.StatusBadRequest)
+ defer server.Close()
+ goodTargetWith40X := &ProxyTarget{
+ Name: "Good with 40X",
+ URL: targetURL,
+ }
+
+ targetURL, _ = url.Parse("http://127.0.0.1:27121")
+ badTarget := &ProxyTarget{
+ Name: "Bad",
+ URL: targetURL,
+ }
+
+ alwaysRetryFilter := func(c *echo.Context, e error) bool { return true }
+ neverRetryFilter := func(c *echo.Context, e error) bool { return false }
+
+ testCases := []struct {
+ name string
+ retryCount int
+ retryFilters []func(c *echo.Context, e error) bool
+ targets []*ProxyTarget
+ expectedResponse int
+ }{
+ {
+ name: "retry count 0 does not attempt retry on fail",
+ targets: []*ProxyTarget{
+ badTarget,
+ goodTarget,
+ },
+ expectedResponse: http.StatusBadGateway,
+ },
+ {
+ name: "retry count 1 does not attempt retry on success",
+ retryCount: 1,
+ targets: []*ProxyTarget{
+ goodTarget,
+ },
+ expectedResponse: http.StatusOK,
+ },
+ {
+ name: "retry count 1 does retry on handler return true",
+ retryCount: 1,
+ retryFilters: []func(c *echo.Context, e error) bool{
+ alwaysRetryFilter,
+ },
+ targets: []*ProxyTarget{
+ badTarget,
+ goodTarget,
+ },
+ expectedResponse: http.StatusOK,
+ },
+ {
+ name: "retry count 1 does not retry on handler return false",
+ retryCount: 1,
+ retryFilters: []func(c *echo.Context, e error) bool{
+ neverRetryFilter,
+ },
+ targets: []*ProxyTarget{
+ badTarget,
+ goodTarget,
+ },
+ expectedResponse: http.StatusBadGateway,
+ },
+ {
+ name: "retry count 2 returns error when no more retries left",
+ retryCount: 2,
+ retryFilters: []func(c *echo.Context, e error) bool{
+ alwaysRetryFilter,
+ alwaysRetryFilter,
+ },
+ targets: []*ProxyTarget{
+ badTarget,
+ badTarget,
+ badTarget,
+ goodTarget, //Should never be reached as only 2 retries
+ },
+ expectedResponse: http.StatusBadGateway,
+ },
+ {
+ name: "retry count 2 returns error when retries left but handler returns false",
+ retryCount: 3,
+ retryFilters: []func(c *echo.Context, e error) bool{
+ alwaysRetryFilter,
+ alwaysRetryFilter,
+ neverRetryFilter,
+ },
+ targets: []*ProxyTarget{
+ badTarget,
+ badTarget,
+ badTarget,
+ goodTarget, //Should never be reached as retry handler returns false on 2nd check
+ },
+ expectedResponse: http.StatusBadGateway,
+ },
+ {
+ name: "retry count 3 succeeds",
+ retryCount: 3,
+ retryFilters: []func(c *echo.Context, e error) bool{
+ alwaysRetryFilter,
+ alwaysRetryFilter,
+ alwaysRetryFilter,
+ },
+ targets: []*ProxyTarget{
+ badTarget,
+ badTarget,
+ badTarget,
+ goodTarget,
+ },
+ expectedResponse: http.StatusOK,
+ },
+ {
+ name: "40x responses are not retried",
+ retryCount: 1,
+ targets: []*ProxyTarget{
+ goodTargetWith40X,
+ goodTarget,
+ },
+ expectedResponse: http.StatusBadRequest,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+
+ retryFilterCall := 0
+ retryFilter := func(c *echo.Context, e error) bool {
+ if len(tc.retryFilters) == 0 {
+ assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall))
+ }
+
+ retryFilterCall++
+
+ nextRetryFilter := tc.retryFilters[0]
+ tc.retryFilters = tc.retryFilters[1:]
+
+ return nextRetryFilter(c, e)
+ }
+
+ e := echo.New()
+ e.Use(ProxyWithConfig(
+ ProxyConfig{
+ Balancer: NewRoundRobinBalancer(tc.targets),
+ RetryCount: tc.retryCount,
+ RetryFilter: retryFilter,
+ },
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectedResponse, rec.Code)
+ if len(tc.retryFilters) > 0 {
+ assert.FailNow(t, fmt.Sprintf("expected %d more retry handler calls", len(tc.retryFilters)))
+ }
+ })
+ }
+}
+
+func TestProxyRetryWithBackendTimeout(t *testing.T) {
+
+ transport := http.DefaultTransport.(*http.Transport).Clone()
+ transport.ResponseHeaderTimeout = time.Millisecond * 500
+
+ timeoutBackend := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ time.Sleep(1 * time.Second)
+ w.WriteHeader(404)
+ }),
+ )
+ defer timeoutBackend.Close()
+
+ timeoutTargetURL, _ := url.Parse(timeoutBackend.URL)
+ goodBackend := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(200)
+ }),
+ )
+ defer goodBackend.Close()
+
+ goodTargetURL, _ := url.Parse(goodBackend.URL)
+ e := echo.New()
+ e.Use(ProxyWithConfig(
+ ProxyConfig{
+ Transport: transport,
+ Balancer: NewRoundRobinBalancer([]*ProxyTarget{
+ {
+ Name: "Timeout",
+ URL: timeoutTargetURL,
+ },
+ {
+ Name: "Good",
+ URL: goodTargetURL,
+ },
+ }),
+ RetryCount: 1,
+ },
+ ))
+
+ var wg sync.WaitGroup
+ for i := 0; i < 20; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, 200, rec.Code)
+ }()
+ }
+
+ wg.Wait()
+
+}
+
+func TestProxyErrorHandler(t *testing.T) {
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ goodURL, _ := url.Parse(server.URL)
+ defer server.Close()
+ goodTarget := &ProxyTarget{
+ Name: "Good",
+ URL: goodURL,
+ }
+
+ badURL, _ := url.Parse("http://127.0.0.1:27121")
+ badTarget := &ProxyTarget{
+ Name: "Bad",
+ URL: badURL,
+ }
+
+ transformedError := errors.New("a new error")
+
+ testCases := []struct {
+ name string
+ target *ProxyTarget
+ errorHandler func(c *echo.Context, e error) error
+ expectFinalError func(t *testing.T, err error)
+ }{
+ {
+ name: "Error handler not invoked when request success",
+ target: goodTarget,
+ errorHandler: func(c *echo.Context, e error) error {
+ assert.FailNow(t, "error handler should not be invoked")
+ return e
+ },
+ },
+ {
+ name: "Error handler invoked when request fails",
+ target: badTarget,
+ errorHandler: func(c *echo.Context, e error) error {
+ httpErr, ok := e.(*echo.HTTPError)
+ assert.True(t, ok, "expected http error to be passed to handler")
+ assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler")
+ return transformedError
+ },
+ expectFinalError: func(t *testing.T, err error) {
+ assert.Equal(t, transformedError, err, "transformed error not returned from proxy")
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+ e.Use(ProxyWithConfig(
+ ProxyConfig{
+ Balancer: NewRoundRobinBalancer([]*ProxyTarget{tc.target}),
+ ErrorHandler: tc.errorHandler,
+ },
+ ))
+
+ errorHandlerCalled := false
+ dheh := echo.DefaultHTTPErrorHandler(false)
+ e.HTTPErrorHandler = func(c *echo.Context, err error) {
+ errorHandlerCalled = true
+ tc.expectFinalError(t, err)
+ dheh(c, err)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ if !errorHandlerCalled && tc.expectFinalError != nil {
+ t.Fatalf("error handler was not called")
+ }
+
+ })
+ }
+}
+
+type testContextKey string
+type customBalancer struct {
+ target *ProxyTarget
+}
+
+func (b *customBalancer) AddTarget(target *ProxyTarget) bool {
+ return false
+}
+func (b *customBalancer) RemoveTarget(name string) bool {
+ return false
+}
+
+func (b *customBalancer) Next(c *echo.Context) (*ProxyTarget, error) {
+ ctx := context.WithValue(c.Request().Context(), testContextKey("FROM_BALANCER"), "CUSTOM_BALANCER")
+ c.SetRequest(c.Request().WithContext(ctx))
+ return b.target, nil
+}
+
+func TestModifyResponseUseContext(t *testing.T) {
+ server := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte("OK"))
+ }),
+ )
+ defer server.Close()
+ targetURL, _ := url.Parse(server.URL)
+ e := echo.New()
+ e.Use(ProxyWithConfig(
+ ProxyConfig{
+ Balancer: &customBalancer{
+ target: &ProxyTarget{
+ Name: "tst",
+ URL: targetURL,
+ },
+ },
+ RetryCount: 1,
+ ModifyResponse: func(res *http.Response) error {
+ val := res.Request.Context().Value(testContextKey("FROM_BALANCER"))
+ if valStr, ok := val.(string); ok {
+ res.Header.Set("FROM_BALANCER", valStr)
+ }
+ return nil
+ },
+ },
+ ))
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "OK", rec.Body.String())
+ assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER"))
+}
+
+func createSimpleWebSocketServer(serveTLS bool) *httptest.Server {
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ wsHandler := func(conn *websocket.Conn) {
+ defer conn.Close()
+ for {
+ var msg string
+ err := websocket.Message.Receive(conn, &msg)
+ if err != nil {
+ return
+ }
+ // message back to the client
+ websocket.Message.Send(conn, msg)
+ }
+ }
+ websocket.Server{Handler: wsHandler}.ServeHTTP(w, r)
+ })
+ if serveTLS {
+ return httptest.NewTLSServer(handler)
+ }
+ return httptest.NewServer(handler)
+}
+
+func createSimpleProxyServer(t *testing.T, srv *httptest.Server, serveTLS bool, toTLS bool) *httptest.Server {
+ e := echo.New()
+
+ if toTLS {
+ // proxy to tls target
+ tgtURL, _ := url.Parse(srv.URL)
+ tgtURL.Scheme = "wss"
+ balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
+
+ defaultTransport, ok := http.DefaultTransport.(*http.Transport)
+ if !ok {
+ t.Fatal("Default transport is not of type *http.Transport")
+ }
+ transport := defaultTransport.Clone()
+ transport.TLSClientConfig = &tls.Config{
+ InsecureSkipVerify: true,
+ }
+ e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer, Transport: transport}))
+ } else {
+ // proxy to non-TLS target
+ tgtURL, _ := url.Parse(srv.URL)
+ balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
+ e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer}))
+ }
+
+ if serveTLS {
+ // serve proxy server with TLS
+ ts := httptest.NewTLSServer(e)
+ return ts
+ }
+ // serve proxy server without TLS
+ ts := httptest.NewServer(e)
+ return ts
+}
+
+// TestProxyWithConfigWebSocketNonTLS2NonTLS tests the proxy with non-TLS to non-TLS WebSocket connection.
+func TestProxyWithConfigWebSocketNonTLS2NonTLS(t *testing.T) {
+ /*
+ Arrange
+ */
+ // Create a WebSocket test server (non-TLS)
+ srv := createSimpleWebSocketServer(false)
+ defer srv.Close()
+
+ // create proxy server (non-TLS to non-TLS)
+ ts := createSimpleProxyServer(t, srv, false, false)
+ defer ts.Close()
+
+ tsURL, _ := url.Parse(ts.URL)
+ tsURL.Scheme = "ws"
+ tsURL.Path = "/"
+
+ /*
+ Act
+ */
+
+ // Connect to the proxy WebSocket
+ wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
+ assert.NoError(t, err)
+ defer wsConn.Close()
+
+ // Send message
+ sendMsg := "Hello, Non TLS WebSocket!"
+ err = websocket.Message.Send(wsConn, sendMsg)
+ assert.NoError(t, err)
+
+ /*
+ Assert
+ */
+ // Read response
+ var recvMsg string
+ err = websocket.Message.Receive(wsConn, &recvMsg)
+ assert.NoError(t, err)
+ assert.Equal(t, sendMsg, recvMsg)
+}
+
+// TestProxyWithConfigWebSocketTLS2TLS tests the proxy with TLS to TLS WebSocket connection.
+func TestProxyWithConfigWebSocketTLS2TLS(t *testing.T) {
+ /*
+ Arrange
+ */
+ // Create a WebSocket test server (TLS)
+ srv := createSimpleWebSocketServer(true)
+ defer srv.Close()
+
+ // create proxy server (TLS to TLS)
+ ts := createSimpleProxyServer(t, srv, true, true)
+ defer ts.Close()
+
+ tsURL, _ := url.Parse(ts.URL)
+ tsURL.Scheme = "wss"
+ tsURL.Path = "/"
+
+ /*
+ Act
+ */
+ origin, err := url.Parse(ts.URL)
+ assert.NoError(t, err)
+ config := &websocket.Config{
+ Location: tsURL,
+ Origin: origin,
+ TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
+ Version: websocket.ProtocolVersionHybi13,
+ }
+ wsConn, err := websocket.DialConfig(config)
+ assert.NoError(t, err)
+ defer wsConn.Close()
+
+ // Send message
+ sendMsg := "Hello, TLS to TLS WebSocket!"
+ err = websocket.Message.Send(wsConn, sendMsg)
+ assert.NoError(t, err)
+
+ // Read response
+ var recvMsg string
+ err = websocket.Message.Receive(wsConn, &recvMsg)
+ assert.NoError(t, err)
+ assert.Equal(t, sendMsg, recvMsg)
+}
+
+// TestProxyWithConfigWebSocketNonTLS2TLS tests the proxy with non-TLS to TLS WebSocket connection.
+func TestProxyWithConfigWebSocketNonTLS2TLS(t *testing.T) {
+ /*
+ Arrange
+ */
+
+ // Create a WebSocket test server (TLS)
+ srv := createSimpleWebSocketServer(true)
+ defer srv.Close()
+
+ // create proxy server (Non-TLS to TLS)
+ ts := createSimpleProxyServer(t, srv, false, true)
+ defer ts.Close()
+
+ tsURL, _ := url.Parse(ts.URL)
+ tsURL.Scheme = "ws"
+ tsURL.Path = "/"
+
+ /*
+ Act
+ */
+ // Connect to the proxy WebSocket
+ wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/")
+ assert.NoError(t, err)
+ defer wsConn.Close()
+
+ // Send message
+ sendMsg := "Hello, Non TLS to TLS WebSocket!"
+ err = websocket.Message.Send(wsConn, sendMsg)
+ assert.NoError(t, err)
+
+ /*
+ Assert
+ */
+ // Read response
+ var recvMsg string
+ err = websocket.Message.Receive(wsConn, &recvMsg)
+ assert.NoError(t, err)
+ assert.Equal(t, sendMsg, recvMsg)
+}
+
+// TestProxyWithConfigWebSocketTLSToNoneTLS tests the proxy with TLS to non-TLS WebSocket connection. (TLS termination)
+func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) {
+ /*
+ Arrange
+ */
+
+ // Create a WebSocket test server (non-TLS)
+ srv := createSimpleWebSocketServer(false)
+ defer srv.Close()
+
+ // create proxy server (TLS to non-TLS)
+ ts := createSimpleProxyServer(t, srv, true, false)
+ defer ts.Close()
+
+ tsURL, _ := url.Parse(ts.URL)
+ tsURL.Scheme = "wss"
+ tsURL.Path = "/"
+
+ /*
+ Act
+ */
+ origin, err := url.Parse(ts.URL)
+ assert.NoError(t, err)
+ config := &websocket.Config{
+ Location: tsURL,
+ Origin: origin,
+ TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing
+ Version: websocket.ProtocolVersionHybi13,
+ }
+ wsConn, err := websocket.DialConfig(config)
+ assert.NoError(t, err)
+ defer wsConn.Close()
+
+ // Send message
+ sendMsg := "Hello, TLS to NoneTLS WebSocket!"
+ err = websocket.Message.Send(wsConn, sendMsg)
+ assert.NoError(t, err)
+
+ // Read response
+ var recvMsg string
+ err = websocket.Message.Receive(wsConn, &recvMsg)
+ assert.NoError(t, err)
+ assert.Equal(t, sendMsg, recvMsg)
+}
diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go
new file mode 100644
index 000000000..c04ae157d
--- /dev/null
+++ b/middleware/rate_limiter.go
@@ -0,0 +1,263 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "errors"
+ "math"
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/labstack/echo/v5"
+ "golang.org/x/time/rate"
+)
+
+// RateLimiterStore is the interface to be implemented by custom stores.
+type RateLimiterStore interface {
+ Allow(identifier string) (bool, error)
+}
+
+// RateLimiterConfig defines the configuration for the rate limiter
+type RateLimiterConfig struct {
+ Skipper Skipper
+ BeforeFunc BeforeFunc
+ // IdentifierExtractor uses *echo.Context to extract the identifier for a visitor
+ IdentifierExtractor Extractor
+ // Store defines a store for the rate limiter
+ Store RateLimiterStore
+ // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error
+ ErrorHandler func(c *echo.Context, err error) error
+ // DenyHandler provides a handler to be called when RateLimiter denies access
+ DenyHandler func(c *echo.Context, identifier string, err error) error
+}
+
+// Extractor is used to extract data from *echo.Context
+type Extractor func(c *echo.Context) (string, error)
+
+// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
+var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
+
+// ErrExtractorError denotes an error raised when extractor function is unsuccessful
+var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
+
+// DefaultRateLimiterConfig defines default values for RateLimiterConfig
+var DefaultRateLimiterConfig = RateLimiterConfig{
+ Skipper: DefaultSkipper,
+ IdentifierExtractor: func(ctx *echo.Context) (string, error) {
+ id := ctx.RealIP()
+ return id, nil
+ },
+ ErrorHandler: func(c *echo.Context, err error) error {
+ return ErrExtractorError.Wrap(err)
+ },
+ DenyHandler: func(c *echo.Context, identifier string, err error) error {
+ return ErrRateLimitExceeded.Wrap(err)
+ },
+}
+
+/*
+RateLimiter returns a rate limiting middleware
+
+ e := echo.New()
+
+ limiterStore := middleware.NewRateLimiterMemoryStore(20)
+
+ e.GET("/rate-limited", func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }, RateLimiter(limiterStore))
+*/
+func RateLimiter(store RateLimiterStore) echo.MiddlewareFunc {
+ config := DefaultRateLimiterConfig
+ config.Store = store
+
+ return RateLimiterWithConfig(config)
+}
+
+/*
+RateLimiterWithConfig returns a rate limiting middleware
+
+ e := echo.New()
+
+ config := middleware.RateLimiterConfig{
+ Skipper: DefaultSkipper,
+ Store: middleware.NewRateLimiterMemoryStore(
+ middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute}
+ )
+ IdentifierExtractor: func(ctx *echo.Context) (string, error) {
+ id := ctx.RealIP()
+ return id, nil
+ },
+ ErrorHandler: func(ctx *echo.Context, err error) error {
+ return context.JSON(http.StatusTooManyRequests, nil)
+ },
+ DenyHandler: func(ctx *echo.Context, identifier string, err error) error {
+ return context.JSON(http.StatusForbidden, nil)
+ },
+ }
+
+ e.GET("/rate-limited", func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }, middleware.RateLimiterWithConfig(config))
+*/
+func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid configuration
+func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Skipper == nil {
+ config.Skipper = DefaultRateLimiterConfig.Skipper
+ }
+ if config.IdentifierExtractor == nil {
+ config.IdentifierExtractor = DefaultRateLimiterConfig.IdentifierExtractor
+ }
+ if config.ErrorHandler == nil {
+ config.ErrorHandler = DefaultRateLimiterConfig.ErrorHandler
+ }
+ if config.DenyHandler == nil {
+ config.DenyHandler = DefaultRateLimiterConfig.DenyHandler
+ }
+ if config.Store == nil {
+ return nil, errors.New("echo rate limiter store configuration must be provided")
+ }
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+ if config.BeforeFunc != nil {
+ config.BeforeFunc(c)
+ }
+
+ identifier, err := config.IdentifierExtractor(c)
+ if err != nil {
+ return config.ErrorHandler(c, err)
+ }
+
+ if allow, allowErr := config.Store.Allow(identifier); !allow {
+ return config.DenyHandler(c, identifier, allowErr)
+ }
+ return next(c)
+ }
+ }, nil
+}
+
+// RateLimiterMemoryStore is the built-in store implementation for RateLimiter
+type RateLimiterMemoryStore struct {
+ visitors map[string]*Visitor
+ mutex sync.Mutex
+ rate float64 // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit
+ burst int
+ expiresIn time.Duration
+ lastCleanup time.Time
+
+ timeNow func() time.Time
+}
+
+// Visitor signifies a unique user's limiter details
+type Visitor struct {
+ *rate.Limiter
+ lastSeen time.Time
+}
+
+/*
+NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with
+the provided rate (as req/s).
+for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
+
+Burst and ExpiresIn will be set to default values.
+
+Note that if the provided rate is a float number and Burst is zero, Burst will be treated as the rounded down value of the rate.
+
+Example (with 20 requests/sec):
+
+ limiterStore := middleware.NewRateLimiterMemoryStore(20)
+*/
+func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore) {
+ return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
+ Rate: rateLimit,
+ })
+}
+
+/*
+NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore
+with the provided configuration. Rate must be provided. Burst will be set to the rounded down value of
+the configured rate if not provided or set to 0.
+
+The built-in memory store is usually capable for modest loads. For higher loads other
+store implementations should be considered.
+
+Characteristics:
+* Concurrency above 100 parallel requests may causes measurable lock contention
+* A high number of different IP addresses (above 16000) may be impacted by the internally used Go map
+* A high number of requests from a single IP address may cause lock contention
+
+Example:
+
+ limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig(
+ middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minute},
+ )
+*/
+func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) {
+ store = &RateLimiterMemoryStore{}
+
+ store.rate = config.Rate
+ store.burst = config.Burst
+ store.expiresIn = config.ExpiresIn
+ if config.ExpiresIn == 0 {
+ store.expiresIn = DefaultRateLimiterMemoryStoreConfig.ExpiresIn
+ }
+ if config.Burst == 0 {
+ store.burst = int(math.Max(1, math.Ceil(float64(config.Rate))))
+ }
+ store.visitors = make(map[string]*Visitor)
+ store.timeNow = time.Now
+ store.lastCleanup = store.timeNow()
+ return
+}
+
+// RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore
+type RateLimiterMemoryStoreConfig struct {
+ Rate float64 // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
+ Burst int // Burst is maximum number of requests to pass at the same moment. It additionally allows a number of requests to pass when rate limit is reached.
+ ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up
+}
+
+// DefaultRateLimiterMemoryStoreConfig provides default configuration values for RateLimiterMemoryStore
+var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{
+ ExpiresIn: 3 * time.Minute,
+}
+
+// Allow implements RateLimiterStore.Allow
+func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
+ store.mutex.Lock()
+ limiter, exists := store.visitors[identifier]
+ if !exists {
+ limiter = new(Visitor)
+ limiter.Limiter = rate.NewLimiter(rate.Limit(store.rate), store.burst)
+ store.visitors[identifier] = limiter
+ }
+ now := store.timeNow()
+ limiter.lastSeen = now
+ if now.Sub(store.lastCleanup) > store.expiresIn {
+ store.cleanupStaleVisitors(now)
+ }
+ allowed := limiter.AllowN(now, 1)
+ store.mutex.Unlock()
+ return allowed, nil
+}
+
+/*
+cleanupStaleVisitors helps manage the size of the visitors map by removing stale records
+of users who haven't visited again after the configured expiry time has elapsed
+*/
+func (store *RateLimiterMemoryStore) cleanupStaleVisitors(now time.Time) {
+ for id, visitor := range store.visitors {
+ if now.Sub(visitor.lastSeen) > store.expiresIn {
+ delete(store.visitors, id)
+ }
+ }
+ store.lastCleanup = now
+}
diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go
new file mode 100644
index 000000000..c591d2b19
--- /dev/null
+++ b/middleware/rate_limiter_test.go
@@ -0,0 +1,648 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "errors"
+ "math/rand"
+ "net/http"
+ "net/http/httptest"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+ "golang.org/x/time/rate"
+)
+
+func TestRateLimiter(t *testing.T) {
+ e := echo.New()
+
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
+
+ mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
+
+ testCases := []struct {
+ id string
+ expectErr string
+ }{
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ }
+
+ for _, tc := range testCases {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRealIP, tc.id)
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := mw(handler)(c)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, http.StatusOK, rec.Code)
+ }
+}
+
+func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) {
+ var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
+
+ assert.Panics(t, func() {
+ RateLimiterWithConfig(RateLimiterConfig{})
+ })
+
+ assert.NotPanics(t, func() {
+ RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
+ })
+}
+
+func TestRateLimiterWithConfig(t *testing.T) {
+ var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
+
+ e := echo.New()
+
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ mw, err := RateLimiterConfig{
+ IdentifierExtractor: func(c *echo.Context) (string, error) {
+ id := c.Request().Header.Get(echo.HeaderXRealIP)
+ if id == "" {
+ return "", errors.New("invalid identifier")
+ }
+ return id, nil
+ },
+ DenyHandler: func(ctx *echo.Context, identifier string, err error) error {
+ return ctx.JSON(http.StatusForbidden, nil)
+ },
+ ErrorHandler: func(ctx *echo.Context, err error) error {
+ return ctx.JSON(http.StatusBadRequest, nil)
+ },
+ Store: inMemoryStore,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ testCases := []struct {
+ id string
+ code int
+ }{
+ {"127.0.0.1", http.StatusOK},
+ {"127.0.0.1", http.StatusOK},
+ {"127.0.0.1", http.StatusOK},
+ {"127.0.0.1", http.StatusForbidden},
+ {"", http.StatusBadRequest},
+ {"127.0.0.1", http.StatusForbidden},
+ {"127.0.0.1", http.StatusForbidden},
+ }
+
+ for _, tc := range testCases {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRealIP, tc.id)
+
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec)
+
+ err := mw(handler)(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, tc.code, rec.Code)
+ }
+}
+
+func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
+ var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
+
+ e := echo.New()
+
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ mw, err := RateLimiterConfig{
+ IdentifierExtractor: func(c *echo.Context) (string, error) {
+ id := c.Request().Header.Get(echo.HeaderXRealIP)
+ if id == "" {
+ return "", errors.New("invalid identifier")
+ }
+ return id, nil
+ },
+ Store: inMemoryStore,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ testCases := []struct {
+ id string
+ expectErr string
+ }{
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {expectErr: "code=403, message=error while extracting identifier, err=invalid identifier"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ }
+
+ for _, tc := range testCases {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRealIP, tc.id)
+
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec)
+
+ err := mw(handler)(c)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, http.StatusOK, rec.Code)
+ }
+}
+
+func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
+ {
+ var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
+
+ e := echo.New()
+
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ mw, err := RateLimiterConfig{
+ Store: inMemoryStore,
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ testCases := []struct {
+ id string
+ expectErr string
+ }{
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
+ }
+
+ for _, tc := range testCases {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRealIP, tc.id)
+
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec)
+
+ err := mw(handler)(c)
+ if tc.expectErr != "" {
+ assert.EqualError(t, err, tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, http.StatusOK, rec.Code)
+ }
+ }
+}
+
+func TestRateLimiterWithConfig_skipper(t *testing.T) {
+ e := echo.New()
+
+ var beforeFuncRan bool
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+ var inMemoryStore = NewRateLimiterMemoryStore(5)
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
+
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec)
+
+ mw, err := RateLimiterConfig{
+ Skipper: func(c *echo.Context) bool {
+ return true
+ },
+ BeforeFunc: func(c *echo.Context) {
+ beforeFuncRan = true
+ },
+ Store: inMemoryStore,
+ IdentifierExtractor: func(ctx *echo.Context) (string, error) {
+ return "127.0.0.1", nil
+ },
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(handler)(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, false, beforeFuncRan)
+}
+
+func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
+ e := echo.New()
+
+ var beforeFuncRan bool
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+ var inMemoryStore = NewRateLimiterMemoryStore(5)
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
+
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec)
+
+ mw, err := RateLimiterConfig{
+ Skipper: func(c *echo.Context) bool {
+ return false
+ },
+ BeforeFunc: func(c *echo.Context) {
+ beforeFuncRan = true
+ },
+ Store: inMemoryStore,
+ IdentifierExtractor: func(ctx *echo.Context) (string, error) {
+ return "127.0.0.1", nil
+ },
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ _ = mw(handler)(c)
+
+ assert.Equal(t, true, beforeFuncRan)
+}
+
+func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
+ e := echo.New()
+
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ var beforeRan bool
+ var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
+
+ rec := httptest.NewRecorder()
+
+ c := e.NewContext(req, rec)
+
+ mw, err := RateLimiterConfig{
+ BeforeFunc: func(c *echo.Context) {
+ beforeRan = true
+ },
+ Store: inMemoryStore,
+ IdentifierExtractor: func(ctx *echo.Context) (string, error) {
+ return "127.0.0.1", nil
+ },
+ }.ToMiddleware()
+ assert.NoError(t, err)
+
+ err = mw(handler)(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, true, beforeRan)
+}
+
+func TestRateLimiterMemoryStore_Allow(t *testing.T) {
+ var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3, ExpiresIn: 2 * time.Second})
+ testCases := []struct {
+ id string
+ allowed bool
+ }{
+ {"127.0.0.1", true}, // 0 ms
+ {"127.0.0.1", true}, // 220 ms burst #2
+ {"127.0.0.1", true}, // 440 ms burst #3
+ {"127.0.0.1", false}, // 660 ms block
+ {"127.0.0.1", false}, // 880 ms block
+ {"127.0.0.1", true}, // 1100 ms next second #1
+ {"127.0.0.2", true}, // 1320 ms allow other ip
+ {"127.0.0.1", false}, // 1540 ms no burst
+ {"127.0.0.1", false}, // 1760 ms no burst
+ {"127.0.0.1", false}, // 1980 ms no burst
+ {"127.0.0.1", true}, // 2200 ms no burst
+ {"127.0.0.1", false}, // 2420 ms no burst
+ {"127.0.0.1", false}, // 2640 ms no burst
+ {"127.0.0.1", false}, // 2860 ms no burst
+ {"127.0.0.1", true}, // 3080 ms no burst
+ {"127.0.0.1", false}, // 3300 ms no burst
+ {"127.0.0.1", false}, // 3520 ms no burst
+ {"127.0.0.1", false}, // 3740 ms no burst
+ {"127.0.0.1", false}, // 3960 ms no burst
+ {"127.0.0.1", true}, // 4180 ms no burst
+ {"127.0.0.1", false}, // 4400 ms no burst
+ {"127.0.0.1", false}, // 4620 ms no burst
+ {"127.0.0.1", false}, // 4840 ms no burst
+ {"127.0.0.1", true}, // 5060 ms no burst
+ }
+
+ for i, tc := range testCases {
+ t.Logf("Running testcase #%d => %v", i, time.Duration(i)*220*time.Millisecond)
+ inMemoryStore.timeNow = func() time.Time {
+ return time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).Add(time.Duration(i) * 220 * time.Millisecond)
+ }
+ allowed, _ := inMemoryStore.Allow(tc.id)
+ assert.Equal(t, tc.allowed, allowed)
+ }
+}
+
+func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) {
+ var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
+ inMemoryStore.visitors = map[string]*Visitor{
+ "A": {
+ Limiter: rate.NewLimiter(1, 3),
+ lastSeen: time.Now(),
+ },
+ "B": {
+ Limiter: rate.NewLimiter(1, 3),
+ lastSeen: time.Now().Add(-1 * time.Minute),
+ },
+ "C": {
+ Limiter: rate.NewLimiter(1, 3),
+ lastSeen: time.Now().Add(-5 * time.Minute),
+ },
+ "D": {
+ Limiter: rate.NewLimiter(1, 3),
+ lastSeen: time.Now().Add(-10 * time.Minute),
+ },
+ }
+
+ inMemoryStore.Allow("D")
+ inMemoryStore.cleanupStaleVisitors(time.Now())
+
+ var exists bool
+
+ _, exists = inMemoryStore.visitors["A"]
+ assert.Equal(t, true, exists)
+
+ _, exists = inMemoryStore.visitors["B"]
+ assert.Equal(t, true, exists)
+
+ _, exists = inMemoryStore.visitors["C"]
+ assert.Equal(t, false, exists)
+
+ _, exists = inMemoryStore.visitors["D"]
+ assert.Equal(t, true, exists)
+}
+
+func TestNewRateLimiterMemoryStore(t *testing.T) {
+ testCases := []struct {
+ rate float64
+ burst int
+ expiresIn time.Duration
+ expectedExpiresIn time.Duration
+ }{
+ {1, 3, 5 * time.Second, 5 * time.Second},
+ {2, 4, 0, 3 * time.Minute},
+ {1, 5, 10 * time.Minute, 10 * time.Minute},
+ {3, 7, 0, 3 * time.Minute},
+ }
+
+ for _, tc := range testCases {
+ store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: tc.rate, Burst: tc.burst, ExpiresIn: tc.expiresIn})
+ assert.Equal(t, tc.rate, store.rate)
+ assert.Equal(t, tc.burst, store.burst)
+ assert.Equal(t, tc.expectedExpiresIn, store.expiresIn)
+ }
+}
+
+func TestRateLimiterMemoryStore_FractionalRateDefaultBurst(t *testing.T) {
+ store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
+ Rate: 0.5, // fractional rate should get a burst of at least 1
+ })
+
+ base := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
+ store.timeNow = func() time.Time {
+ return base
+ }
+
+ allowed, err := store.Allow("user")
+ assert.NoError(t, err)
+ assert.True(t, allowed, "first request should not be blocked")
+
+ allowed, err = store.Allow("user")
+ assert.NoError(t, err)
+ assert.False(t, allowed, "burst token should be consumed immediately")
+
+ store.timeNow = func() time.Time {
+ return base.Add(2 * time.Second)
+ }
+
+ allowed, err = store.Allow("user")
+ assert.NoError(t, err)
+ assert.True(t, allowed, "token should refill for fractional rate after time passes")
+}
+
+func generateAddressList(count int) []string {
+ addrs := make([]string, count)
+ for i := 0; i < count; i++ {
+ addrs[i] = randomString(15)
+ }
+ return addrs
+}
+
+func run(wg *sync.WaitGroup, store RateLimiterStore, addrs []string, max int, b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ store.Allow(addrs[rand.Intn(max)])
+ }
+ wg.Done()
+}
+
+func benchmarkStore(store RateLimiterStore, parallel int, max int, b *testing.B) {
+ addrs := generateAddressList(max)
+ wg := &sync.WaitGroup{}
+ for i := 0; i < parallel; i++ {
+ wg.Add(1)
+ go run(wg, store, addrs, max, b)
+ }
+ wg.Wait()
+}
+
+const (
+ testExpiresIn = 1000 * time.Millisecond
+)
+
+func BenchmarkRateLimiterMemoryStore_1000(b *testing.B) {
+ var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
+ benchmarkStore(store, 10, 1000, b)
+}
+
+func BenchmarkRateLimiterMemoryStore_10000(b *testing.B) {
+ var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
+ benchmarkStore(store, 10, 10000, b)
+}
+
+func BenchmarkRateLimiterMemoryStore_100000(b *testing.B) {
+ var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
+ benchmarkStore(store, 10, 100000, b)
+}
+
+func BenchmarkRateLimiterMemoryStore_conc100_10000(b *testing.B) {
+ var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
+ benchmarkStore(store, 100, 10000, b)
+}
+
+// TestRateLimiterMemoryStore_TOCTOUFix verifies that the TOCTOU race condition is fixed
+// by ensuring timeNow() is only called once per Allow() call
+func TestRateLimiterMemoryStore_TOCTOUFix(t *testing.T) {
+ t.Parallel()
+
+ store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
+ Rate: 1,
+ Burst: 1,
+ ExpiresIn: 2 * time.Second,
+ })
+
+ // Track time calls to verify we use the same time value
+ timeCallCount := 0
+ baseTime := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
+
+ store.timeNow = func() time.Time {
+ timeCallCount++
+ return baseTime
+ }
+
+ // First request - should succeed
+ allowed, err := store.Allow("127.0.0.1")
+ assert.NoError(t, err)
+ assert.True(t, allowed, "First request should be allowed")
+
+ // Verify timeNow() was only called once
+ assert.Equal(t, 1, timeCallCount, "timeNow() should only be called once per Allow()")
+}
+
+// TestRateLimiterMemoryStore_ConcurrentAccess verifies rate limiting correctness under concurrent load
+func TestRateLimiterMemoryStore_ConcurrentAccess(t *testing.T) {
+ t.Parallel()
+
+ store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
+ Rate: 10,
+ Burst: 5,
+ ExpiresIn: 5 * time.Second,
+ })
+
+ const goroutines = 50
+ const requestsPerGoroutine = 20
+
+ var wg sync.WaitGroup
+ var allowedCount, deniedCount int32
+
+ for i := 0; i < goroutines; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for j := 0; j < requestsPerGoroutine; j++ {
+ allowed, err := store.Allow("test-user")
+ assert.NoError(t, err)
+ if allowed {
+ atomic.AddInt32(&allowedCount, 1)
+ } else {
+ atomic.AddInt32(&deniedCount, 1)
+ }
+ time.Sleep(time.Millisecond)
+ }
+ }()
+ }
+
+ wg.Wait()
+
+ totalRequests := goroutines * requestsPerGoroutine
+ allowed := int(allowedCount)
+ denied := int(deniedCount)
+
+ assert.Equal(t, totalRequests, allowed+denied, "All requests should be processed")
+ assert.Greater(t, denied, 0, "Some requests should be denied due to rate limiting")
+ assert.Greater(t, allowed, 0, "Some requests should be allowed")
+}
+
+// TestRateLimiterMemoryStore_RaceDetection verifies no data races with high concurrency
+// Run with: go test -race ./middleware -run TestRateLimiterMemoryStore_RaceDetection
+func TestRateLimiterMemoryStore_RaceDetection(t *testing.T) {
+ t.Parallel()
+
+ store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
+ Rate: 100,
+ Burst: 200,
+ ExpiresIn: 1 * time.Second,
+ })
+
+ const goroutines = 100
+ const requestsPerGoroutine = 100
+
+ var wg sync.WaitGroup
+ identifiers := []string{"user1", "user2", "user3", "user4", "user5"}
+
+ for i := 0; i < goroutines; i++ {
+ wg.Add(1)
+ go func(routineID int) {
+ defer wg.Done()
+ for j := 0; j < requestsPerGoroutine; j++ {
+ identifier := identifiers[routineID%len(identifiers)]
+ _, err := store.Allow(identifier)
+ assert.NoError(t, err)
+ }
+ }(i)
+ }
+
+ wg.Wait()
+}
+
+// TestRateLimiterMemoryStore_TimeOrdering verifies time ordering consistency in rate limiting decisions
+func TestRateLimiterMemoryStore_TimeOrdering(t *testing.T) {
+ t.Parallel()
+
+ store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
+ Rate: 1,
+ Burst: 2,
+ ExpiresIn: 5 * time.Second,
+ })
+
+ currentTime := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
+ store.timeNow = func() time.Time {
+ return currentTime
+ }
+
+ // First two requests should succeed (burst=2)
+ allowed1, _ := store.Allow("user1")
+ assert.True(t, allowed1, "Request 1 should be allowed (burst)")
+
+ allowed2, _ := store.Allow("user1")
+ assert.True(t, allowed2, "Request 2 should be allowed (burst)")
+
+ // Third request should be denied
+ allowed3, _ := store.Allow("user1")
+ assert.False(t, allowed3, "Request 3 should be denied (burst exhausted)")
+
+ // Advance time by 1 second
+ currentTime = currentTime.Add(1 * time.Second)
+
+ // Fourth request should succeed
+ allowed4, _ := store.Allow("user1")
+ assert.True(t, allowed4, "Request 4 should be allowed (1 token available)")
+}
diff --git a/middleware/recover.go b/middleware/recover.go
index b62edb526..01fde5152 100644
--- a/middleware/recover.go
+++ b/middleware/recover.go
@@ -1,28 +1,103 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
"fmt"
-
+ "net/http"
"runtime"
- "github.com/labstack/echo"
+ "github.com/labstack/echo/v5"
)
+// RecoverConfig defines the config for Recover middleware.
+type RecoverConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // Size of the stack to be printed.
+ // Optional. Default value 4KB.
+ StackSize int
+
+ // DisableStackAll disables formatting stack traces of all other goroutines
+ // into buffer after the trace for the current goroutine.
+ // Optional. Default value false.
+ DisableStackAll bool
+
+ // DisablePrintStack disables printing stack trace.
+ // Optional. Default value as false.
+ DisablePrintStack bool
+}
+
+// DefaultRecoverConfig is the default Recover middleware config.
+var DefaultRecoverConfig = RecoverConfig{
+ Skipper: DefaultSkipper,
+ StackSize: 4 << 10, // 4 KB
+ DisableStackAll: false,
+ DisablePrintStack: false,
+}
+
// Recover returns a middleware which recovers from panics anywhere in the chain
// and handles the control to the centralized HTTPErrorHandler.
func Recover() echo.MiddlewareFunc {
- // TODO: Provide better stack trace `https://github.com/go-errors/errors` `https://github.com/docker/libcontainer/tree/master/stacktrace`
- return func(h echo.HandlerFunc) echo.HandlerFunc {
- return func(c *echo.Context) error {
+ return RecoverWithConfig(DefaultRecoverConfig)
+}
+
+// RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration.
+func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration
+func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ // Defaults
+ if config.Skipper == nil {
+ config.Skipper = DefaultRecoverConfig.Skipper
+ }
+ if config.StackSize == 0 {
+ config.StackSize = DefaultRecoverConfig.StackSize
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) (err error) {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
defer func() {
- if err := recover(); err != nil {
- trace := make([]byte, 1<<16)
- n := runtime.Stack(trace, true)
- c.Error(fmt.Errorf("echo => panic recover\n %v\n stack trace %d bytes\n %s",
- err, n, trace[:n]))
+ if r := recover(); r != nil {
+ if r == http.ErrAbortHandler {
+ panic(r)
+ }
+ tmpErr, ok := r.(error)
+ if !ok {
+ tmpErr = fmt.Errorf("%v", r)
+ }
+ if !config.DisablePrintStack {
+ stack := make([]byte, config.StackSize)
+ length := runtime.Stack(stack, !config.DisableStackAll)
+ tmpErr = &PanicStackError{Stack: stack[:length], Err: tmpErr}
+ }
+ err = tmpErr
}
}()
- return h(c)
+ return next(c)
}
- }
+ }, nil
+}
+
+// PanicStackError is an error type that wraps an error along with its stack trace.
+// It is returned when config.DisablePrintStack is set to false.
+type PanicStackError struct {
+ Stack []byte
+ Err error
+}
+
+func (e *PanicStackError) Error() string {
+ return fmt.Sprintf("[PANIC RECOVER] %s %s", e.Err.Error(), e.Stack)
+}
+
+func (e *PanicStackError) Unwrap() error {
+ return e.Err
}
diff --git a/middleware/recover_test.go b/middleware/recover_test.go
index 003c75191..719e0cc3d 100644
--- a/middleware/recover_test.go
+++ b/middleware/recover_test.go
@@ -1,24 +1,150 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
package middleware
import (
+ "bytes"
+ "errors"
+ "log/slog"
"net/http"
"net/http/httptest"
"testing"
- "github.com/labstack/echo"
+ "github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
)
func TestRecover(t *testing.T) {
e := echo.New()
- e.SetDebug(true)
- req, _ := http.NewRequest(echo.GET, "/", nil)
+ buf := new(bytes.Buffer)
+ e.Logger = slog.New(&discardHandler{})
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
- c := echo.NewContext(req, echo.NewResponse(rec), e)
- h := func(c *echo.Context) error {
+ c := e.NewContext(req, rec)
+ h := Recover()(func(c *echo.Context) error {
panic("test")
+ })
+ err := h(c)
+ assert.Contains(t, err.Error(), "[PANIC RECOVER] test goroutine")
+
+ var pse *PanicStackError
+ if errors.As(err, &pse) {
+ assert.Contains(t, string(pse.Stack), "middleware/recover.go")
+ } else {
+ assert.Fail(t, "not of type PanicStackError")
+ }
+
+ assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
+ assert.Contains(t, buf.String(), "") // nothing is logged
+}
+
+func TestRecover_skipper(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ config := RecoverConfig{
+ Skipper: func(c *echo.Context) bool {
+ return true
+ },
}
- Recover()(h)(c)
+ h := RecoverWithConfig(config)(func(c *echo.Context) error {
+ panic("testPANIC")
+ })
+
+ var err error
+ assert.Panics(t, func() {
+ err = h(c)
+ })
+
+ assert.NoError(t, err)
+ assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
+}
+
+func TestRecoverErrAbortHandler(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ h := Recover()(func(c *echo.Context) error {
+ panic(http.ErrAbortHandler)
+ })
+ defer func() {
+ r := recover()
+ if r == nil {
+ assert.Fail(t, "expecting `http.ErrAbortHandler`, got `nil`")
+ } else {
+ if err, ok := r.(error); ok {
+ assert.ErrorIs(t, err, http.ErrAbortHandler)
+ } else {
+ assert.Fail(t, "not of error type")
+ }
+ }
+ }()
+
+ hErr := h(c)
+
assert.Equal(t, http.StatusInternalServerError, rec.Code)
- assert.Contains(t, rec.Body.String(), "panic recover")
+ assert.NotContains(t, hErr.Error(), "PANIC RECOVER")
+}
+
+func TestRecoverWithConfig(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenNoPanic bool
+ whenConfig RecoverConfig
+ expectErrContain string
+ expectErr string
+ }{
+ {
+ name: "ok, default config",
+ whenConfig: DefaultRecoverConfig,
+ expectErrContain: "[PANIC RECOVER] testPANIC goroutine",
+ },
+ {
+ name: "ok, no panic",
+ givenNoPanic: true,
+ whenConfig: DefaultRecoverConfig,
+ expectErrContain: "",
+ },
+ {
+ name: "ok, DisablePrintStack",
+ whenConfig: RecoverConfig{
+ DisablePrintStack: true,
+ },
+ expectErr: "testPANIC",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ config := tc.whenConfig
+ h := RecoverWithConfig(config)(func(c *echo.Context) error {
+ if tc.givenNoPanic {
+ return nil
+ }
+ panic("testPANIC")
+ })
+
+ err := h(c)
+
+ if tc.expectErrContain != "" {
+ assert.Contains(t, err.Error(), tc.expectErrContain)
+ } else if tc.expectErr != "" {
+ assert.Contains(t, err.Error(), tc.expectErr)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain
+ })
+ }
}
diff --git a/middleware/redirect.go b/middleware/redirect.go
new file mode 100644
index 000000000..bb7045cfe
--- /dev/null
+++ b/middleware/redirect.go
@@ -0,0 +1,184 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "errors"
+ "net/http"
+ "strings"
+
+ "github.com/labstack/echo/v5"
+)
+
+// RedirectConfig defines the config for Redirect middleware.
+type RedirectConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper
+
+ // Status code to be used when redirecting the request.
+ // Optional. Default value http.StatusMovedPermanently.
+ Code int
+
+ redirect redirectLogic
+}
+
+// redirectLogic represents a function that given a scheme, host and uri
+// can both: 1) determine if redirect is needed (will set ok accordingly) and
+// 2) return the appropriate redirect url.
+type redirectLogic func(scheme, host, uri string) (ok bool, url string)
+
+const www = "www."
+
+// RedirectHTTPSConfig is the HTTPS Redirect middleware config.
+var RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS}
+
+// RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config.
+var RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW}
+
+// RedirectNonHTTPSWWWConfig is the non HTTPS WWW Redirect middleware config.
+var RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW}
+
+// RedirectWWWConfig is the WWW Redirect middleware config.
+var RedirectWWWConfig = RedirectConfig{redirect: redirectWWW}
+
+// RedirectNonWWWConfig is the non WWW Redirect middleware config.
+var RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW}
+
+// HTTPSRedirect redirects http requests to https.
+// For example, http://labstack.com will be redirect to https://labstack.com.
+//
+// Usage `Echo#Pre(HTTPSRedirect())`
+func HTTPSRedirect() echo.MiddlewareFunc {
+ return HTTPSRedirectWithConfig(RedirectHTTPSConfig)
+}
+
+// HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration.
+func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
+ config.redirect = redirectHTTPS
+ return toMiddlewareOrPanic(config)
+}
+
+// HTTPSWWWRedirect redirects http requests to https www.
+// For example, http://labstack.com will be redirect to https://www.labstack.com.
+//
+// Usage `Echo#Pre(HTTPSWWWRedirect())`
+func HTTPSWWWRedirect() echo.MiddlewareFunc {
+ return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig)
+}
+
+// HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration.
+func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
+ config.redirect = redirectHTTPSWWW
+ return toMiddlewareOrPanic(config)
+}
+
+// HTTPSNonWWWRedirect redirects http requests to https non www.
+// For example, http://www.labstack.com will be redirect to https://labstack.com.
+//
+// Usage `Echo#Pre(HTTPSNonWWWRedirect())`
+func HTTPSNonWWWRedirect() echo.MiddlewareFunc {
+ return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig)
+}
+
+// HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration.
+func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
+ config.redirect = redirectNonHTTPSWWW
+ return toMiddlewareOrPanic(config)
+}
+
+// WWWRedirect redirects non www requests to www.
+// For example, http://labstack.com will be redirect to http://www.labstack.com.
+//
+// Usage `Echo#Pre(WWWRedirect())`
+func WWWRedirect() echo.MiddlewareFunc {
+ return WWWRedirectWithConfig(RedirectWWWConfig)
+}
+
+// WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration.
+func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
+ config.redirect = redirectWWW
+ return toMiddlewareOrPanic(config)
+}
+
+// NonWWWRedirect redirects www requests to non www.
+// For example, http://www.labstack.com will be redirect to http://labstack.com.
+//
+// Usage `Echo#Pre(NonWWWRedirect())`
+func NonWWWRedirect() echo.MiddlewareFunc {
+ return NonWWWRedirectWithConfig(RedirectNonWWWConfig)
+}
+
+// NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration.
+func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
+ config.redirect = redirectNonWWW
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration
+func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Skipper == nil {
+ config.Skipper = DefaultSkipper
+ }
+ if config.Code == 0 {
+ config.Code = http.StatusMovedPermanently
+ }
+ if config.redirect == nil {
+ return nil, errors.New("redirectConfig is missing redirect function")
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ req, scheme := c.Request(), c.Scheme()
+ host := req.Host
+ if ok, url := config.redirect(scheme, host, req.RequestURI); ok {
+ return c.Redirect(config.Code, url)
+ }
+
+ return next(c)
+ }
+ }, nil
+}
+
+var redirectHTTPS = func(scheme, host, uri string) (bool, string) {
+ if scheme != "https" {
+ return true, "https://" + host + uri
+ }
+ return false, ""
+}
+
+var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) {
+ // Redirect if not HTTPS OR missing www prefix (needs either fix)
+ if scheme != "https" || !strings.HasPrefix(host, www) {
+ host = strings.TrimPrefix(host, www) // Remove www if present to avoid duplication
+ return true, "https://www." + host + uri
+ }
+ return false, ""
+}
+
+var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) {
+ // Redirect if not HTTPS OR has www prefix (needs either fix)
+ if scheme != "https" || strings.HasPrefix(host, www) {
+ host = strings.TrimPrefix(host, www)
+ return true, "https://" + host + uri
+ }
+ return false, ""
+}
+
+var redirectWWW = func(scheme, host, uri string) (bool, string) {
+ if !strings.HasPrefix(host, www) {
+ return true, scheme + "://www." + host + uri
+ }
+ return false, ""
+}
+
+var redirectNonWWW = func(scheme, host, uri string) (bool, string) {
+ if strings.HasPrefix(host, www) {
+ return true, scheme + "://" + host[4:] + uri
+ }
+ return false, ""
+}
diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go
new file mode 100644
index 000000000..a127ca40c
--- /dev/null
+++ b/middleware/redirect_test.go
@@ -0,0 +1,295 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+type middlewareGenerator func() echo.MiddlewareFunc
+
+func TestRedirectHTTPSRedirect(t *testing.T) {
+ var testCases = []struct {
+ whenHost string
+ whenHeader http.Header
+ expectLocation string
+ expectStatusCode int
+ }{
+ {
+ whenHost: "labstack.com",
+ expectLocation: "https://labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "labstack.com",
+ whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
+ expectLocation: "",
+ expectStatusCode: http.StatusOK,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.whenHost, func(t *testing.T) {
+ res := redirectTest(HTTPSRedirect, tc.whenHost, tc.whenHeader)
+
+ assert.Equal(t, tc.expectStatusCode, res.Code)
+ assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
+ })
+ }
+}
+
+func TestRedirectHTTPSWWWRedirect(t *testing.T) {
+ var testCases = []struct {
+ whenHost string
+ whenHeader http.Header
+ expectLocation string
+ expectStatusCode int
+ }{
+ {
+ whenHost: "labstack.com",
+ expectLocation: "https://www.labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "www.labstack.com",
+ expectLocation: "https://www.labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "a.com",
+ expectLocation: "https://www.a.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "ip",
+ expectLocation: "https://www.ip/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "labstack.com",
+ whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
+ expectLocation: "https://www.labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "www.labstack.com",
+ whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
+ expectLocation: "",
+ expectStatusCode: http.StatusOK,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.whenHost, func(t *testing.T) {
+ res := redirectTest(HTTPSWWWRedirect, tc.whenHost, tc.whenHeader)
+
+ assert.Equal(t, tc.expectStatusCode, res.Code)
+ assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
+ })
+ }
+}
+
+func TestRedirectHTTPSNonWWWRedirect(t *testing.T) {
+ var testCases = []struct {
+ whenHost string
+ whenHeader http.Header
+ expectLocation string
+ expectStatusCode int
+ }{
+ {
+ whenHost: "www.labstack.com",
+ expectLocation: "https://labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "a.com",
+ expectLocation: "https://a.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "ip",
+ expectLocation: "https://ip/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "www.labstack.com",
+ whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
+ expectLocation: "https://labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "labstack.com",
+ whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
+ expectLocation: "",
+ expectStatusCode: http.StatusOK,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.whenHost, func(t *testing.T) {
+ res := redirectTest(HTTPSNonWWWRedirect, tc.whenHost, tc.whenHeader)
+
+ assert.Equal(t, tc.expectStatusCode, res.Code)
+ assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
+ })
+ }
+}
+
+func TestRedirectWWWRedirect(t *testing.T) {
+ var testCases = []struct {
+ whenHost string
+ whenHeader http.Header
+ expectLocation string
+ expectStatusCode int
+ }{
+ {
+ whenHost: "labstack.com",
+ expectLocation: "http://www.labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "a.com",
+ expectLocation: "http://www.a.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "ip",
+ expectLocation: "http://www.ip/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "a.com",
+ whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
+ expectLocation: "https://www.a.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "www.ip",
+ expectLocation: "",
+ expectStatusCode: http.StatusOK,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.whenHost, func(t *testing.T) {
+ res := redirectTest(WWWRedirect, tc.whenHost, tc.whenHeader)
+
+ assert.Equal(t, tc.expectStatusCode, res.Code)
+ assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
+ })
+ }
+}
+
+func TestRedirectNonWWWRedirect(t *testing.T) {
+ var testCases = []struct {
+ whenHost string
+ whenHeader http.Header
+ expectLocation string
+ expectStatusCode int
+ }{
+ {
+ whenHost: "www.labstack.com",
+ expectLocation: "http://labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "www.a.com",
+ expectLocation: "http://a.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "www.a.com",
+ whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}},
+ expectLocation: "https://a.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ whenHost: "ip",
+ expectLocation: "",
+ expectStatusCode: http.StatusOK,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.whenHost, func(t *testing.T) {
+ res := redirectTest(NonWWWRedirect, tc.whenHost, tc.whenHeader)
+
+ assert.Equal(t, tc.expectStatusCode, res.Code)
+ assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
+ })
+ }
+}
+
+func TestNonWWWRedirectWithConfig(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenCode int
+ givenSkipFunc func(c *echo.Context) bool
+ whenHost string
+ whenHeader http.Header
+ expectLocation string
+ expectStatusCode int
+ }{
+ {
+ name: "usual redirect",
+ whenHost: "www.labstack.com",
+ expectLocation: "http://labstack.com/",
+ expectStatusCode: http.StatusMovedPermanently,
+ },
+ {
+ name: "redirect is skipped",
+ givenSkipFunc: func(c *echo.Context) bool {
+ return true // skip always
+ },
+ whenHost: "www.labstack.com",
+ expectLocation: "",
+ expectStatusCode: http.StatusOK,
+ },
+ {
+ name: "redirect with custom status code",
+ givenCode: http.StatusSeeOther,
+ whenHost: "www.labstack.com",
+ expectLocation: "http://labstack.com/",
+ expectStatusCode: http.StatusSeeOther,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.whenHost, func(t *testing.T) {
+ middleware := func() echo.MiddlewareFunc {
+ return NonWWWRedirectWithConfig(RedirectConfig{
+ Skipper: tc.givenSkipFunc,
+ Code: tc.givenCode,
+ })
+ }
+ res := redirectTest(middleware, tc.whenHost, tc.whenHeader)
+
+ assert.Equal(t, tc.expectStatusCode, res.Code)
+ assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation))
+ })
+ }
+}
+
+func redirectTest(fn middlewareGenerator, host string, header http.Header) *httptest.ResponseRecorder {
+ e := echo.New()
+ next := func(c *echo.Context) (err error) {
+ return c.NoContent(http.StatusOK)
+ }
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Host = host
+ if header != nil {
+ req.Header = header
+ }
+ res := httptest.NewRecorder()
+ c := e.NewContext(req, res)
+
+ fn()(next)(c)
+
+ return res
+}
diff --git a/middleware/request_id.go b/middleware/request_id.go
new file mode 100644
index 000000000..b3de40d19
--- /dev/null
+++ b/middleware/request_id.go
@@ -0,0 +1,73 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "github.com/labstack/echo/v5"
+)
+
+// RequestIDConfig defines the config for RequestID middleware.
+type RequestIDConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // Generator defines a function to generate an ID.
+ // Optional. Default value random.String(32).
+ Generator func() string
+
+ // RequestIDHandler defines a function which is executed for a request id.
+ RequestIDHandler func(c *echo.Context, requestID string)
+
+ // TargetHeader defines what header to look for to populate the id.
+ // Optional. Default value is `X-Request-Id`
+ TargetHeader string
+}
+
+// RequestID returns a middleware that reads RequestIDConfig.TargetHeader (`X-Request-ID`) header value or when
+// the header value is empty, generates that value and sets request ID to response
+// as RequestIDConfig.TargetHeader (`X-Request-Id`) value.
+func RequestID() echo.MiddlewareFunc {
+ return RequestIDWithConfig(RequestIDConfig{})
+}
+
+// RequestIDWithConfig returns a middleware with given valid config or panics on invalid configuration.
+// The middleware reads RequestIDConfig.TargetHeader (`X-Request-ID`) header value or when the header value is empty,
+// generates that value and sets request ID to response as RequestIDConfig.TargetHeader (`X-Request-Id`) value.
+func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration
+func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Skipper == nil {
+ config.Skipper = DefaultSkipper
+ }
+ if config.Generator == nil {
+ config.Generator = createRandomStringGenerator(32)
+ }
+ if config.TargetHeader == "" {
+ config.TargetHeader = echo.HeaderXRequestID
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ req := c.Request()
+ res := c.Response()
+ rid := req.Header.Get(config.TargetHeader)
+ if rid == "" {
+ rid = config.Generator()
+ }
+ res.Header().Set(config.TargetHeader, rid)
+ if config.RequestIDHandler != nil {
+ config.RequestIDHandler(c, rid)
+ }
+
+ return next(c)
+ }
+ }, nil
+}
diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go
new file mode 100644
index 000000000..465e6fc42
--- /dev/null
+++ b/middleware/request_id_test.go
@@ -0,0 +1,170 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestRequestID(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ rid := RequestID()
+ h := rid(handler)
+ err := h(c)
+ assert.NoError(t, err)
+ assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
+}
+
+func TestMustRequestIDWithConfig_skipper(t *testing.T) {
+ e := echo.New()
+ e.GET("/", func(c *echo.Context) error {
+ return c.String(http.StatusTeapot, "test")
+ })
+
+ generatorCalled := false
+ e.Use(RequestIDWithConfig(RequestIDConfig{
+ Skipper: func(c *echo.Context) bool {
+ return true
+ },
+ Generator: func() string {
+ generatorCalled = true
+ return "customGenerator"
+ },
+ }))
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ res := httptest.NewRecorder()
+ e.ServeHTTP(res, req)
+
+ assert.Equal(t, http.StatusTeapot, res.Code)
+ assert.Equal(t, "test", res.Body.String())
+
+ assert.Equal(t, res.Header().Get(echo.HeaderXRequestID), "")
+ assert.False(t, generatorCalled)
+}
+
+func TestMustRequestIDWithConfig_customGenerator(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ rid := RequestIDWithConfig(RequestIDConfig{
+ Generator: func() string { return "customGenerator" },
+ })
+ h := rid(handler)
+ err := h(c)
+ assert.NoError(t, err)
+ assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
+}
+
+func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ called := false
+ rid := RequestIDWithConfig(RequestIDConfig{
+ Generator: func() string { return "customGenerator" },
+ RequestIDHandler: func(c *echo.Context, s string) {
+ called = true
+ },
+ })
+ h := rid(handler)
+ err := h(c)
+ assert.NoError(t, err)
+ assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
+ assert.True(t, called)
+}
+
+func TestRequestIDWithConfig(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ rid, err := RequestIDConfig{}.ToMiddleware()
+ assert.NoError(t, err)
+ h := rid(handler)
+ h(c)
+ assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32)
+
+ // Custom generator
+ rid = RequestIDWithConfig(RequestIDConfig{
+ Generator: func() string { return "customGenerator" },
+ })
+ h = rid(handler)
+ h(c)
+ assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator")
+}
+
+func TestRequestID_IDNotAltered(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Add(echo.HeaderXRequestID, "")
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ rid := RequestIDWithConfig(RequestIDConfig{})
+ h := rid(handler)
+ _ = h(c)
+ assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "")
+}
+
+func TestRequestIDConfigDifferentHeader(t *testing.T) {
+ e := echo.New()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ handler := func(c *echo.Context) error {
+ return c.String(http.StatusOK, "test")
+ }
+
+ rid := RequestIDWithConfig(RequestIDConfig{TargetHeader: echo.HeaderXCorrelationID})
+ h := rid(handler)
+ h(c)
+ assert.Len(t, rec.Header().Get(echo.HeaderXCorrelationID), 32)
+
+ // Custom generator and handler
+ customID := "customGenerator"
+ calledHandler := false
+ rid = RequestIDWithConfig(RequestIDConfig{
+ Generator: func() string { return customID },
+ TargetHeader: echo.HeaderXCorrelationID,
+ RequestIDHandler: func(_ *echo.Context, id string) {
+ calledHandler = true
+ assert.Equal(t, customID, id)
+ },
+ })
+ h = rid(handler)
+ h(c)
+ assert.Equal(t, rec.Header().Get(echo.HeaderXCorrelationID), "customGenerator")
+ assert.True(t, calledHandler)
+}
diff --git a/middleware/request_logger.go b/middleware/request_logger.go
new file mode 100644
index 000000000..76903c62a
--- /dev/null
+++ b/middleware/request_logger.go
@@ -0,0 +1,462 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "context"
+ "errors"
+ "log/slog"
+ "net/http"
+ "time"
+
+ "github.com/labstack/echo/v5"
+)
+
+// Example for `slog` https://pkg.go.dev/log/slog
+// logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))
+// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
+// LogStatus: true,
+// LogURI: true,
+// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
+// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
+// if v.Error == nil {
+// logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST",
+// slog.String("uri", v.URI),
+// slog.Int("status", v.Status),
+// )
+// } else {
+// logger.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR",
+// slog.String("uri", v.URI),
+// slog.Int("status", v.Status),
+// slog.String("err", v.Error.Error()),
+// )
+// }
+// return nil
+// },
+// }))
+//
+// Example for `fmt.Printf`
+// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
+// LogStatus: true,
+// LogURI: true,
+// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
+// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
+// if v.Error == nil {
+// fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status)
+// } else {
+// fmt.Printf("REQUEST_ERROR: uri: %v, status: %v, err: %v\n", v.URI, v.Status, v.Error)
+// }
+// return nil
+// },
+// }))
+//
+// Example for Zerolog (https://github.com/rs/zerolog)
+// logger := zerolog.New(os.Stdout)
+// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
+// LogURI: true,
+// LogStatus: true,
+// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
+// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
+// if v.Error == nil {
+// logger.Info().
+// Str("URI", v.URI).
+// Int("status", v.Status).
+// Msg("request")
+// } else {
+// logger.Error().
+// Err(v.Error).
+// Str("URI", v.URI).
+// Int("status", v.Status).
+// Msg("request error")
+// }
+// return nil
+// },
+// }))
+//
+// Example for Zap (https://github.com/uber-go/zap)
+// logger, _ := zap.NewProduction()
+// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
+// LogURI: true,
+// LogStatus: true,
+// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
+// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
+// if v.Error == nil {
+// logger.Info("request",
+// zap.String("URI", v.URI),
+// zap.Int("status", v.Status),
+// )
+// } else {
+// logger.Error("request error",
+// zap.String("URI", v.URI),
+// zap.Int("status", v.Status),
+// zap.Error(v.Error),
+// )
+// }
+// return nil
+// },
+// }))
+//
+// Example for Logrus (https://github.com/sirupsen/logrus)
+// log := logrus.New()
+// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{
+// LogURI: true,
+// LogStatus: true,
+// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code
+// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error {
+// if v.Error == nil {
+// log.WithFields(logrus.Fields{
+// "URI": v.URI,
+// "status": v.Status,
+// }).Info("request")
+// } else {
+// log.WithFields(logrus.Fields{
+// "URI": v.URI,
+// "status": v.Status,
+// "error": v.Error,
+// }).Error("request error")
+// }
+// return nil
+// },
+// }))
+
+// RequestLoggerConfig is configuration for Request Logger middleware.
+type RequestLoggerConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // BeforeNextFunc defines a function that is called before next middleware or handler is called in chain.
+ BeforeNextFunc func(c *echo.Context)
+ // LogValuesFunc defines a function that is called with values extracted by logger from request/response.
+ // Mandatory.
+ LogValuesFunc func(c *echo.Context, v RequestLoggerValues) error
+
+ // HandleError instructs logger to call global error handler when next middleware/handler returns an error.
+ // This is useful when you have custom error handler that can decide to use different status codes.
+ //
+ // A side-effect of calling global error handler is that now Response has been committed and sent to the client
+ // and middlewares up in chain can not change Response status code or response body.
+ HandleError bool
+
+ // LogLatency instructs logger to record duration it took to execute rest of the handler chain (next(c) call).
+ LogLatency bool
+ // LogProtocol instructs logger to extract request protocol (i.e. `HTTP/1.1` or `HTTP/2`)
+ LogProtocol bool
+ // LogRemoteIP instructs logger to extract request remote IP. See `echo.Context.RealIP()` for implementation details.
+ LogRemoteIP bool
+ // LogHost instructs logger to extract request host value (i.e. `example.com`)
+ LogHost bool
+ // LogMethod instructs logger to extract request method value (i.e. `GET` etc)
+ LogMethod bool
+ // LogURI instructs logger to extract request URI (i.e. `/list?lang=en&page=1`)
+ LogURI bool
+ // LogURIPath instructs logger to extract request URI path part (i.e. `/list`)
+ LogURIPath bool
+ // LogRoutePath instructs logger to extract route path part to which request was matched to (i.e. `/user/:id`)
+ LogRoutePath bool
+ // LogRequestID instructs logger to extract request ID from request `X-Request-ID` header or response if request did not have value.
+ LogRequestID bool
+ // LogReferer instructs logger to extract request referer values.
+ LogReferer bool
+ // LogUserAgent instructs logger to extract request user agent values.
+ LogUserAgent bool
+ // LogStatus instructs logger to extract response status code. If handler chain returns an echo.HTTPError,
+ // the status code is extracted from the echo.HTTPError returned
+ LogStatus bool
+ // LogContentLength instructs logger to extract content length header value. Note: this value could be different from
+ // actual request body size as it could be spoofed etc.
+ LogContentLength bool
+ // LogResponseSize instructs logger to extract response content length value. Note: when used with Gzip middleware
+ // this value may not be always correct.
+ LogResponseSize bool
+ // LogHeaders instructs logger to extract given list of headers from request. Note: request can contain more than
+ // one header with same value so slice of values is been logger for each given header.
+ //
+ // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header
+ // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding".
+ LogHeaders []string
+ // LogQueryParams instructs logger to extract given list of query parameters from request URI. Note: request can
+ // contain more than one query parameter with same name so slice of values is been logger for each given query param name.
+ LogQueryParams []string
+ // LogFormValues instructs logger to extract given list of form values from request body+URI. Note: request can
+ // contain more than one form value with same name so slice of values is been logger for each given form value name.
+ LogFormValues []string
+
+ timeNow func() time.Time
+}
+
+// RequestLoggerValues contains extracted values from logger.
+type RequestLoggerValues struct {
+ // StartTime is time recorded before next middleware/handler is executed.
+ StartTime time.Time
+ // Latency is duration it took to execute rest of the handler chain (next(c) call).
+ Latency time.Duration
+ // Protocol is request protocol (i.e. `HTTP/1.1` or `HTTP/2`)
+ Protocol string
+ // RemoteIP is request remote IP. See `echo.Context.RealIP()` for implementation details.
+ RemoteIP string
+ // Host is request host value (i.e. `example.com`)
+ Host string
+ // Method is request method value (i.e. `GET` etc)
+ Method string
+ // URI is request URI (i.e. `/list?lang=en&page=1`)
+ URI string
+ // URIPath is request URI path part (i.e. `/list`)
+ URIPath string
+ // RoutePath is route path part to which request was matched to (i.e. `/user/:id`)
+ RoutePath string
+ // RequestID is request ID from request `X-Request-ID` header or response if request did not have value.
+ RequestID string
+ // Referer is request referer values.
+ Referer string
+ // UserAgent is request user agent values.
+ UserAgent string
+ // Status is response status code. Then handler returns an echo.HTTPError then code from there.
+ Status int
+ // Error is error returned from executed handler chain.
+ Error error
+ // ContentLength is content length header value. Note: this value could be different from actual request body size
+ // as it could be spoofed etc.
+ ContentLength string
+ // ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct.
+ ResponseSize int64
+ // Headers are list of headers from request. Note: request can contain more than one header with same value so slice
+ // of values is what will be returned/logged for each given header.
+ // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header
+ // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding".
+ Headers map[string][]string
+ // QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter
+ // with same name so slice of values is what will be returned/logged for each given query param name.
+ QueryParams map[string][]string
+ // FormValues are list of form values from request body+URI. Note: request can contain more than one form value with
+ // same name so slice of values is what will be returned/logged for each given form value name.
+ FormValues map[string][]string
+}
+
+// RequestLoggerWithConfig returns a RequestLogger middleware with config.
+func RequestLoggerWithConfig(config RequestLoggerConfig) echo.MiddlewareFunc {
+ mw, err := config.ToMiddleware()
+ if err != nil {
+ panic(err)
+ }
+ return mw
+}
+
+// ToMiddleware converts RequestLoggerConfig into middleware or returns an error for invalid configuration.
+func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Skipper == nil {
+ config.Skipper = DefaultSkipper
+ }
+ now := time.Now
+ if config.timeNow != nil {
+ now = config.timeNow
+ }
+
+ if config.LogValuesFunc == nil {
+ return nil, errors.New("missing LogValuesFunc callback function for request logger middleware")
+ }
+
+ logHeaders := len(config.LogHeaders) > 0
+ headers := append([]string(nil), config.LogHeaders...)
+ for i, v := range headers {
+ headers[i] = http.CanonicalHeaderKey(v)
+ }
+
+ logQueryParams := len(config.LogQueryParams) > 0
+ logFormValues := len(config.LogFormValues) > 0
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) error {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ req := c.Request()
+ res := c.Response()
+ start := now()
+
+ if config.BeforeNextFunc != nil {
+ config.BeforeNextFunc(c)
+ }
+ err := next(c)
+ if err != nil && config.HandleError {
+ // When global error handler writes the error to the client the Response gets "committed". This state can be
+ // checked with `c.Response().Committed` field.
+ c.Echo().HTTPErrorHandler(c, err)
+ }
+
+ v := RequestLoggerValues{
+ StartTime: start,
+ }
+ if config.LogLatency {
+ v.Latency = now().Sub(start)
+ }
+ if config.LogProtocol {
+ v.Protocol = req.Proto
+ }
+ if config.LogRemoteIP {
+ v.RemoteIP = c.RealIP()
+ }
+ if config.LogHost {
+ v.Host = req.Host
+ }
+ if config.LogMethod {
+ v.Method = req.Method
+ }
+ if config.LogURI {
+ v.URI = req.RequestURI
+ }
+ if config.LogURIPath {
+ p := req.URL.Path
+ if p == "" {
+ p = "/"
+ }
+ v.URIPath = p
+ }
+ if config.LogRoutePath {
+ v.RoutePath = c.Path()
+ }
+ if config.LogRequestID {
+ id := req.Header.Get(echo.HeaderXRequestID)
+ if id == "" {
+ id = res.Header().Get(echo.HeaderXRequestID)
+ }
+ v.RequestID = id
+ }
+ if config.LogReferer {
+ v.Referer = req.Referer()
+ }
+ if config.LogUserAgent {
+ v.UserAgent = req.UserAgent()
+ }
+
+ var resp *echo.Response
+ if config.LogStatus || config.LogResponseSize {
+ if r, err := echo.UnwrapResponse(res); err != nil {
+ c.Logger().Error("can not determine response status and/or size. ResponseWriter in context does not implement unwrapper interface")
+ } else {
+ resp = r
+ }
+ }
+
+ if config.LogStatus {
+ v.Status = -1
+ if resp != nil {
+ v.Status = resp.Status
+ }
+ if err != nil && !config.HandleError {
+ // this block should not be executed in case of HandleError=true as the global error handler will decide
+ // the status code. In that case status code could be different from what err contains.
+ var hsc echo.HTTPStatusCoder
+ if errors.As(err, &hsc) {
+ v.Status = hsc.StatusCode()
+ }
+ }
+ }
+ if err != nil {
+ v.Error = err
+ }
+ if config.LogContentLength {
+ v.ContentLength = req.Header.Get(echo.HeaderContentLength)
+ }
+ if config.LogResponseSize {
+ v.ResponseSize = -1
+ if resp != nil {
+ v.ResponseSize = resp.Size
+ }
+ }
+ if logHeaders {
+ v.Headers = map[string][]string{}
+ for _, header := range headers {
+ if values, ok := req.Header[header]; ok {
+ v.Headers[header] = values
+ }
+ }
+ }
+ if logQueryParams {
+ queryParams := c.QueryParams()
+ v.QueryParams = map[string][]string{}
+ for _, param := range config.LogQueryParams {
+ if values, ok := queryParams[param]; ok {
+ v.QueryParams[param] = values
+ }
+ }
+ }
+ if logFormValues {
+ v.FormValues = map[string][]string{}
+ for _, formValue := range config.LogFormValues {
+ if values, ok := req.Form[formValue]; ok {
+ v.FormValues[formValue] = values
+ }
+ }
+ }
+
+ if errOnLog := config.LogValuesFunc(c, v); errOnLog != nil {
+ return errOnLog
+ }
+ // in case of HandleError=true we are returning the error that we already have handled with global error handler
+ // this is deliberate as this error could be useful for upstream middlewares and default global error handler
+ // will ignore that error when it bubbles up in middleware chain.
+ // Committed response can be checked in custom error handler with following logic
+ //
+ // if r, _ := echo.UnwrapResponse(c.Response()); r != nil && r.Committed {
+ // return
+ // }
+ return err
+ }
+ }, nil
+}
+
+// RequestLogger creates Request Logger middleware with Echo default settings that uses Context.Logger() as logger.
+func RequestLogger() echo.MiddlewareFunc {
+ return RequestLoggerWithConfig(RequestLoggerConfig{
+ LogLatency: true,
+ LogRemoteIP: true,
+ LogHost: true,
+ LogMethod: true,
+ LogURI: true,
+ LogRequestID: true,
+ LogUserAgent: true,
+ LogStatus: true,
+ LogContentLength: true,
+ LogResponseSize: true,
+ // forwards error to the global error handler, so it can decide appropriate status code.
+ // NB: side-effect of that is - request is now "commited" written to the client. Middlewares up in chain can not
+ // change Response status code or response body.
+ HandleError: true,
+ LogValuesFunc: func(c *echo.Context, v RequestLoggerValues) error {
+ logger := c.Logger()
+ if v.Error == nil {
+ logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST",
+ slog.String("method", v.Method),
+ slog.String("uri", v.URI),
+ slog.Int("status", v.Status),
+ slog.Duration("latency", v.Latency),
+ slog.String("host", v.Host),
+ slog.String("bytes_in", v.ContentLength),
+ slog.Int64("bytes_out", v.ResponseSize),
+ slog.String("user_agent", v.UserAgent),
+ slog.String("remote_ip", v.RemoteIP),
+ slog.String("request_id", v.RequestID),
+ )
+ return nil
+ }
+
+ logger.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR",
+ slog.String("method", v.Method),
+ slog.String("uri", v.URI),
+ slog.Int("status", v.Status),
+ slog.Duration("latency", v.Latency),
+ slog.String("host", v.Host),
+ slog.String("bytes_in", v.ContentLength),
+ slog.Int64("bytes_out", v.ResponseSize),
+ slog.String("user_agent", v.UserAgent),
+ slog.String("remote_ip", v.RemoteIP),
+ slog.String("request_id", v.RequestID),
+
+ slog.String("error", v.Error.Error()),
+ )
+ return nil
+ },
+ })
+}
diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go
new file mode 100644
index 000000000..af39eb32a
--- /dev/null
+++ b/middleware/request_logger_test.go
@@ -0,0 +1,630 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "log/slog"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestRequestLoggerOK(t *testing.T) {
+ old := slog.Default()
+ t.Cleanup(func() {
+ slog.SetDefault(old)
+ })
+
+ e := echo.New()
+ buf := new(bytes.Buffer)
+ e.Logger = slog.New(slog.NewJSONHandler(buf, nil))
+ e.Use(RequestLogger())
+
+ e.POST("/test", func(c *echo.Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ reader := strings.NewReader(`{"foo":"bar"}`)
+ req := httptest.NewRequest(http.MethodPost, "/test", reader)
+ req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size())))
+ req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
+ req.Header.Set(echo.HeaderXRealIP, "8.8.8.8")
+ req.Header.Set("User-Agent", "curl/7.68.0")
+
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ logAttrs := map[string]interface{}{}
+ assert.NoError(t, json.Unmarshal(buf.Bytes(), &logAttrs))
+ logAttrs["latency"] = 123
+ logAttrs["time"] = "x"
+
+ expect := map[string]interface{}{
+ "level": "INFO",
+ "msg": "REQUEST",
+ "method": "POST",
+ "uri": "/test",
+ "status": float64(418),
+ "bytes_in": "13",
+ "host": "example.com",
+ "bytes_out": float64(2),
+ "user_agent": "curl/7.68.0",
+ "remote_ip": "8.8.8.8",
+ "request_id": "",
+
+ "time": "x",
+ "latency": 123,
+ }
+ assert.Equal(t, expect, logAttrs)
+}
+
+func TestRequestLoggerError(t *testing.T) {
+ old := slog.Default()
+ t.Cleanup(func() {
+ slog.SetDefault(old)
+ })
+
+ e := echo.New()
+ buf := new(bytes.Buffer)
+ e.Logger = slog.New(slog.NewJSONHandler(buf, nil))
+ e.Use(RequestLogger())
+
+ e.GET("/test", func(c *echo.Context) error {
+ return errors.New("nope")
+ })
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+
+ logAttrs := map[string]interface{}{}
+ assert.NoError(t, json.Unmarshal(buf.Bytes(), &logAttrs))
+ logAttrs["latency"] = 123
+ logAttrs["time"] = "x"
+
+ expect := map[string]interface{}{
+ "level": "ERROR",
+ "msg": "REQUEST_ERROR",
+ "method": "GET",
+ "uri": "/test",
+ "status": float64(500),
+ "bytes_in": "",
+ "host": "example.com",
+ "bytes_out": float64(36.0),
+ "user_agent": "",
+ "remote_ip": "192.0.2.1",
+ "request_id": "",
+ "error": "nope",
+
+ "latency": 123,
+ "time": "x",
+ }
+ assert.Equal(t, expect, logAttrs)
+}
+
+func TestRequestLoggerWithConfig(t *testing.T) {
+ e := echo.New()
+
+ var expect RequestLoggerValues
+ e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
+ LogRoutePath: true,
+ LogURI: true,
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ expect = values
+ return nil
+ },
+ }))
+
+ e.GET("/test", func(c *echo.Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusTeapot, rec.Code)
+ assert.Equal(t, "/test", expect.RoutePath)
+}
+
+func TestRequestLoggerWithConfig_missingOnLogValuesPanics(t *testing.T) {
+ assert.Panics(t, func() {
+ RequestLoggerWithConfig(RequestLoggerConfig{
+ LogValuesFunc: nil,
+ })
+ })
+}
+
+func TestRequestLogger_skipper(t *testing.T) {
+ e := echo.New()
+
+ loggerCalled := false
+ e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
+ Skipper: func(c *echo.Context) bool {
+ return true
+ },
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ loggerCalled = true
+ return nil
+ },
+ }))
+
+ e.GET("/test", func(c *echo.Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusTeapot, rec.Code)
+ assert.False(t, loggerCalled)
+}
+
+func TestRequestLogger_beforeNextFunc(t *testing.T) {
+ e := echo.New()
+
+ var myLoggerInstance int
+ e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
+ BeforeNextFunc: func(c *echo.Context) {
+ c.Set("myLoggerInstance", 42)
+ },
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ myLoggerInstance = c.Get("myLoggerInstance").(int)
+ return nil
+ },
+ }))
+
+ e.GET("/test", func(c *echo.Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusTeapot, rec.Code)
+ assert.Equal(t, 42, myLoggerInstance)
+}
+
+func TestRequestLogger_logError(t *testing.T) {
+ e := echo.New()
+
+ var actual RequestLoggerValues
+ e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
+ LogStatus: true,
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ actual = values
+ return nil
+ },
+ }))
+
+ e.GET("/test", func(c *echo.Context) error {
+ return echo.NewHTTPError(http.StatusNotAcceptable, "nope")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusNotAcceptable, rec.Code)
+ assert.Equal(t, http.StatusNotAcceptable, actual.Status)
+ assert.EqualError(t, actual.Error, "code=406, message=nope")
+}
+
+func TestRequestLogger_HandleError(t *testing.T) {
+ e := echo.New()
+
+ var actual RequestLoggerValues
+ e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
+ timeNow: func() time.Time {
+ return time.Unix(1631045377, 0).UTC()
+ },
+ HandleError: true,
+ LogStatus: true,
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ actual = values
+ return nil
+ },
+ }))
+
+ // to see if "HandleError" works we create custom error handler that uses its own status codes
+ e.HTTPErrorHandler = func(c *echo.Context, err error) {
+ if r, _ := echo.UnwrapResponse(c.Response()); r != nil && r.Committed {
+ return
+ }
+ c.JSON(http.StatusTeapot, "custom error handler")
+ }
+
+ e.GET("/test", func(c *echo.Context) error {
+ return echo.NewHTTPError(http.StatusForbidden, "nope")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusTeapot, rec.Code)
+
+ expect := RequestLoggerValues{
+ StartTime: time.Unix(1631045377, 0).UTC(),
+ Status: http.StatusTeapot,
+ Error: echo.NewHTTPError(http.StatusForbidden, "nope"),
+ }
+ assert.Equal(t, expect, actual)
+}
+
+func TestRequestLogger_LogValuesFuncError(t *testing.T) {
+ e := echo.New()
+
+ var expect RequestLoggerValues
+ e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
+ LogStatus: true,
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ expect = values
+ return echo.NewHTTPError(http.StatusNotAcceptable, "LogValuesFuncError")
+ },
+ }))
+
+ e.GET("/test", func(c *echo.Context) error {
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ // NOTE: when global error handler received error returned from middleware the status has already
+ // been written to the client and response has been "committed" therefore global error handler does not do anything
+ // and error that bubbled up in middleware chain will not be reflected in response code.
+ assert.Equal(t, http.StatusTeapot, rec.Code)
+ assert.Equal(t, http.StatusTeapot, expect.Status)
+}
+
+func TestRequestLogger_ID(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenFromRequest bool
+ expect string
+ }{
+ {
+ name: "ok, ID is provided from request headers",
+ whenFromRequest: true,
+ expect: "123",
+ },
+ {
+ name: "ok, ID is from response headers",
+ whenFromRequest: false,
+ expect: "321",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ var expect RequestLoggerValues
+ e.Use(RequestLoggerWithConfig(RequestLoggerConfig{
+ LogRequestID: true,
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ expect = values
+ return nil
+ },
+ }))
+
+ e.GET("/test", func(c *echo.Context) error {
+ c.Response().Header().Set(echo.HeaderXRequestID, "321")
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ if tc.whenFromRequest {
+ req.Header.Set(echo.HeaderXRequestID, "123")
+ }
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusTeapot, rec.Code)
+ assert.Equal(t, tc.expect, expect.RequestID)
+ })
+ }
+}
+
+func TestRequestLogger_headerIsCaseInsensitive(t *testing.T) {
+ e := echo.New()
+
+ var expect RequestLoggerValues
+ mw := RequestLoggerWithConfig(RequestLoggerConfig{
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ expect = values
+ return nil
+ },
+ LogHeaders: []string{"referer", "User-Agent"},
+ })(func(c *echo.Context) error {
+ c.Request().Header.Set(echo.HeaderXRequestID, "123")
+ c.FormValue("to force parse form")
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test?lang=en&checked=1&checked=2", nil)
+ req.Header.Set("referer", "https://echo.labstack.com/")
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ err := mw(c)
+
+ assert.NoError(t, err)
+ assert.Len(t, expect.Headers, 1)
+ assert.Equal(t, []string{"https://echo.labstack.com/"}, expect.Headers["Referer"])
+}
+
+func TestRequestLogger_allFields(t *testing.T) {
+ e := echo.New()
+
+ isFirstNowCall := true
+ var expect RequestLoggerValues
+ mw := RequestLoggerWithConfig(RequestLoggerConfig{
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ expect = values
+ return nil
+ },
+ LogLatency: true,
+ LogProtocol: true,
+ LogRemoteIP: true,
+ LogHost: true,
+ LogMethod: true,
+ LogURI: true,
+ LogURIPath: true,
+ LogRoutePath: true,
+ LogRequestID: true,
+ LogReferer: true,
+ LogUserAgent: true,
+ LogStatus: true,
+ LogContentLength: true,
+ LogResponseSize: true,
+ LogHeaders: []string{"accept-encoding", "User-Agent"},
+ LogQueryParams: []string{"lang", "checked"},
+ LogFormValues: []string{"csrf", "multiple"},
+ timeNow: func() time.Time {
+ if isFirstNowCall {
+ isFirstNowCall = false
+ return time.Unix(1631045377, 0)
+ }
+ return time.Unix(1631045377+10, 0)
+ },
+ })(func(c *echo.Context) error {
+ c.Request().Header.Set(echo.HeaderXRequestID, "123")
+ c.FormValue("to force parse form")
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ f := make(url.Values)
+ f.Set("csrf", "token")
+ f.Set("multiple", "1")
+ f.Add("multiple", "2")
+ reader := strings.NewReader(f.Encode())
+ req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", reader)
+ req.Header.Set("Referer", "https://echo.labstack.com/")
+ req.Header.Set("User-Agent", "curl/7.68.0")
+ req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size())))
+ req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
+ req.Header.Set(echo.HeaderXRealIP, "8.8.8.8")
+
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+
+ c.SetPath("/test*")
+
+ err := mw(c)
+
+ assert.NoError(t, err)
+ assert.Equal(t, time.Unix(1631045377, 0), expect.StartTime)
+ assert.Equal(t, 10*time.Second, expect.Latency)
+ assert.Equal(t, "HTTP/1.1", expect.Protocol)
+ assert.Equal(t, "8.8.8.8", expect.RemoteIP)
+ assert.Equal(t, "example.com", expect.Host)
+ assert.Equal(t, http.MethodPost, expect.Method)
+ assert.Equal(t, "/test?lang=en&checked=1&checked=2", expect.URI)
+ assert.Equal(t, "/test", expect.URIPath)
+ assert.Equal(t, "/test*", expect.RoutePath)
+ assert.Equal(t, "123", expect.RequestID)
+ assert.Equal(t, "https://echo.labstack.com/", expect.Referer)
+ assert.Equal(t, "curl/7.68.0", expect.UserAgent)
+ assert.Equal(t, 418, expect.Status)
+ assert.Equal(t, nil, expect.Error)
+ assert.Equal(t, "32", expect.ContentLength)
+ assert.Equal(t, int64(2), expect.ResponseSize)
+
+ assert.Len(t, expect.Headers, 1)
+ assert.Equal(t, []string{"curl/7.68.0"}, expect.Headers["User-Agent"])
+
+ assert.Len(t, expect.QueryParams, 2)
+ assert.Equal(t, []string{"en"}, expect.QueryParams["lang"])
+ assert.Equal(t, []string{"1", "2"}, expect.QueryParams["checked"])
+
+ assert.Len(t, expect.FormValues, 2)
+ assert.Equal(t, []string{"token"}, expect.FormValues["csrf"])
+ assert.Equal(t, []string{"1", "2"}, expect.FormValues["multiple"])
+}
+
+func TestTestRequestLogger(t *testing.T) {
+ var testCases = []struct {
+ name string
+ whenStatus int
+ whenError error
+ expectStatus string
+ expectError string
+ }{
+ {
+ name: "ok",
+ whenStatus: http.StatusTeapot,
+ expectStatus: "418",
+ },
+ {
+ name: "error",
+ whenError: echo.NewHTTPError(http.StatusBadGateway, "bad gw"),
+ expectStatus: "502",
+ expectError: `"error":"code=502, message=bad gw"`,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+ buf := new(bytes.Buffer)
+ e.Logger = slog.New(slog.NewJSONHandler(buf, nil))
+
+ e.Use(RequestLogger())
+ e.POST("/test", func(c *echo.Context) error {
+ if tc.whenError != nil {
+ return tc.whenError
+ }
+ return c.String(tc.whenStatus, "OK")
+ })
+
+ f := make(url.Values)
+ f.Set("csrf", "token")
+ f.Set("multiple", "1")
+ f.Add("multiple", "2")
+ reader := strings.NewReader(f.Encode())
+ req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", reader)
+ req.Header.Set("Referer", "https://echo.labstack.com/")
+ req.Header.Set("User-Agent", "curl/7.68.0")
+ req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size())))
+ req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
+ req.Header.Set(echo.HeaderXRealIP, "8.8.8.8")
+ req.Header.Set(echo.HeaderXRequestID, "MY_ID")
+
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ rawlog := buf.Bytes()
+ if tc.expectError != "" {
+ assert.Contains(t, string(rawlog), `"level":"ERROR"`)
+ assert.Contains(t, string(rawlog), `"msg":"REQUEST_ERROR"`)
+ assert.Contains(t, string(rawlog), tc.expectError)
+ } else {
+ assert.Contains(t, string(rawlog), `"level":"INFO"`)
+ assert.Contains(t, string(rawlog), `"msg":"REQUEST"`)
+ }
+ assert.Contains(t, string(rawlog), `"status":`+tc.expectStatus)
+ assert.Contains(t, string(rawlog), `"method":"POST"`)
+ assert.Contains(t, string(rawlog), `"uri":"/test?lang=en&checked=1&checked=2"`)
+ assert.Contains(t, string(rawlog), `"latency":`) // this value varies
+ assert.Contains(t, string(rawlog), `"request_id":"MY_ID"`)
+ assert.Contains(t, string(rawlog), `"remote_ip":"8.8.8.8"`)
+ assert.Contains(t, string(rawlog), `"host":"example.com"`)
+ assert.Contains(t, string(rawlog), `"user_agent":"curl/7.68.0"`)
+ assert.Contains(t, string(rawlog), `"bytes_in":"32"`)
+ assert.Contains(t, string(rawlog), `"bytes_out":2`)
+ })
+ }
+}
+
+func BenchmarkRequestLogger_withoutMapFields(b *testing.B) {
+ e := echo.New()
+
+ mw := RequestLoggerWithConfig(RequestLoggerConfig{
+ Skipper: nil,
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ return nil
+ },
+ LogLatency: true,
+ LogProtocol: true,
+ LogRemoteIP: true,
+ LogHost: true,
+ LogMethod: true,
+ LogURI: true,
+ LogURIPath: true,
+ LogRoutePath: true,
+ LogRequestID: true,
+ LogReferer: true,
+ LogUserAgent: true,
+ LogStatus: true,
+ LogContentLength: true,
+ LogResponseSize: true,
+ })(func(c *echo.Context) error {
+ c.Request().Header.Set(echo.HeaderXRequestID, "123")
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test?lang=en", nil)
+ req.Header.Set("Referer", "https://echo.labstack.com/")
+ req.Header.Set("User-Agent", "curl/7.68.0")
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ mw(c)
+ }
+}
+
+func BenchmarkRequestLogger_withMapFields(b *testing.B) {
+ e := echo.New()
+
+ mw := RequestLoggerWithConfig(RequestLoggerConfig{
+ LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error {
+ return nil
+ },
+ LogLatency: true,
+ LogProtocol: true,
+ LogRemoteIP: true,
+ LogHost: true,
+ LogMethod: true,
+ LogURI: true,
+ LogURIPath: true,
+ LogRoutePath: true,
+ LogRequestID: true,
+ LogReferer: true,
+ LogUserAgent: true,
+ LogStatus: true,
+ LogContentLength: true,
+ LogResponseSize: true,
+ LogHeaders: []string{"accept-encoding", "User-Agent"},
+ LogQueryParams: []string{"lang", "checked"},
+ LogFormValues: []string{"csrf", "multiple"},
+ })(func(c *echo.Context) error {
+ c.Request().Header.Set(echo.HeaderXRequestID, "123")
+ c.FormValue("to force parse form")
+ return c.String(http.StatusTeapot, "OK")
+ })
+
+ f := make(url.Values)
+ f.Set("csrf", "token")
+ f.Add("multiple", "1")
+ f.Add("multiple", "2")
+ req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", strings.NewReader(f.Encode()))
+ req.Header.Set("Referer", "https://echo.labstack.com/")
+ req.Header.Set("User-Agent", "curl/7.68.0")
+ req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ rec := httptest.NewRecorder()
+ c := e.NewContext(req, rec)
+ mw(c)
+ }
+}
diff --git a/middleware/rewrite.go b/middleware/rewrite.go
new file mode 100644
index 000000000..ea58091b0
--- /dev/null
+++ b/middleware/rewrite.go
@@ -0,0 +1,80 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "errors"
+ "regexp"
+
+ "github.com/labstack/echo/v5"
+)
+
+// RewriteConfig defines the config for Rewrite middleware.
+type RewriteConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // Rules defines the URL path rewrite rules. The values captured in asterisk can be
+ // retrieved by index e.g. $1, $2 and so on.
+ // Example:
+ // "/old": "/new",
+ // "/api/*": "/$1",
+ // "/js/*": "/public/javascripts/$1",
+ // "/users/*/orders/*": "/user/$1/order/$2",
+ // Required.
+ Rules map[string]string
+
+ // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures
+ // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
+ // Example:
+ // "^/old/[0.9]+/": "/new",
+ // "^/api/.+?/(.*)": "/v2/$1",
+ RegexRules map[*regexp.Regexp]string
+}
+
+// Rewrite returns a Rewrite middleware.
+//
+// Rewrite middleware rewrites the URL path based on the provided rules.
+func Rewrite(rules map[string]string) echo.MiddlewareFunc {
+ c := RewriteConfig{}
+ c.Rules = rules
+ return RewriteWithConfig(c)
+}
+
+// RewriteWithConfig returns a Rewrite middleware or panics on invalid configuration.
+//
+// Rewrite middleware rewrites the URL path based on the provided rules.
+func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
+ return toMiddlewareOrPanic(config)
+}
+
+// ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration
+func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
+ if config.Skipper == nil {
+ config.Skipper = DefaultSkipper
+ }
+ if config.Rules == nil && config.RegexRules == nil {
+ return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules")
+ }
+
+ if config.RegexRules == nil {
+ config.RegexRules = make(map[*regexp.Regexp]string)
+ }
+ for k, v := range rewriteRulesRegex(config.Rules) {
+ config.RegexRules[k] = v
+ }
+
+ return func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c *echo.Context) (err error) {
+ if config.Skipper(c) {
+ return next(c)
+ }
+
+ if err := rewriteURL(config.RegexRules, c.Request()); err != nil {
+ return err
+ }
+ return next(c)
+ }
+ }, nil
+}
diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go
new file mode 100644
index 000000000..f45b8d98a
--- /dev/null
+++ b/middleware/rewrite_test.go
@@ -0,0 +1,317 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "regexp"
+ "testing"
+
+ "github.com/labstack/echo/v5"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestRewriteAfterRouting(t *testing.T) {
+ e := echo.New()
+ // middlewares added with `Use()` are executed after routing is done and do not affect which route handler is matched
+ e.Use(RewriteWithConfig(RewriteConfig{
+ Rules: map[string]string{
+ "/old": "/new",
+ "/api/*": "/$1",
+ "/js/*": "/public/javascripts/$1",
+ "/users/*/orders/*": "/user/$1/order/$2",
+ },
+ }))
+ e.GET("/public/*", func(c *echo.Context) error {
+ return c.String(http.StatusOK, c.Param("*"))
+ })
+ e.GET("/*", func(c *echo.Context) error {
+ return c.String(http.StatusOK, c.Param("*"))
+ })
+
+ var testCases = []struct {
+ whenPath string
+ expectRoutePath string
+ expectRequestPath string
+ expectRequestRawPath string
+ }{
+ {
+ whenPath: "/api/users",
+ expectRoutePath: "api/users",
+ expectRequestPath: "/users",
+ expectRequestRawPath: "",
+ },
+ {
+ whenPath: "/js/main.js",
+ expectRoutePath: "js/main.js",
+ expectRequestPath: "/public/javascripts/main.js",
+ expectRequestRawPath: "",
+ },
+ {
+ whenPath: "/users/jack/orders/1",
+ expectRoutePath: "users/jack/orders/1",
+ expectRequestPath: "/user/jack/order/1",
+ expectRequestRawPath: "",
+ },
+ { // no rewrite rule matched. already encoded URL should not be double encoded or changed in any way
+ whenPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ expectRoutePath: "user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ expectRequestPath: "/user/jill/order/T/cO4lW/t/Vp/", // this is equal to `url.Parse(tc.whenPath)` result
+ expectRequestRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ },
+ { // just rewrite but do not touch encoding. already encoded URL should not be double encoded
+ whenPath: "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F",
+ expectRoutePath: "users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F",
+ expectRequestPath: "/user/jill/order/T/cO4lW/t/Vp/", // this is equal to `url.Parse(tc.whenPath)` result
+ expectRequestRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F",
+ },
+ { // ` ` (space) is encoded by httpClient to `%20` when doing request to Echo. `%20` should not be double escaped or changed in any way when rewriting request
+ whenPath: "/api/new users",
+ expectRoutePath: "api/new users",
+ expectRequestPath: "/new users",
+ expectRequestRawPath: "",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.whenPath, func(t *testing.T) {
+ target, _ := url.Parse(tc.whenPath)
+ req := httptest.NewRequest(http.MethodGet, target.String(), nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, tc.expectRoutePath, rec.Body.String())
+ assert.Equal(t, tc.expectRequestPath, req.URL.Path)
+ assert.Equal(t, tc.expectRequestRawPath, req.URL.RawPath)
+ })
+ }
+}
+
+func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) {
+ assert.Panics(t, func() {
+ RewriteWithConfig(RewriteConfig{})
+ })
+}
+
+func TestMustRewriteWithConfig_skipper(t *testing.T) {
+ var testCases = []struct {
+ name string
+ givenSkipper func(c *echo.Context) bool
+ whenURL string
+ expectURL string
+ expectStatus int
+ }{
+ {
+ name: "not skipped",
+ whenURL: "/old",
+ expectURL: "/new",
+ expectStatus: http.StatusOK,
+ },
+ {
+ name: "skipped",
+ givenSkipper: func(c *echo.Context) bool {
+ return true
+ },
+ whenURL: "/old",
+ expectURL: "/old",
+ expectStatus: http.StatusNotFound,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ e := echo.New()
+
+ e.Pre(RewriteWithConfig(
+ RewriteConfig{
+ Skipper: tc.givenSkipper,
+ Rules: map[string]string{"/old": "/new"}},
+ ))
+
+ e.GET("/new", func(c *echo.Context) error {
+ return c.NoContent(http.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil)
+ rec := httptest.NewRecorder()
+
+ e.ServeHTTP(rec, req)
+
+ assert.Equal(t, tc.expectURL, req.URL.EscapedPath())
+ assert.Equal(t, tc.expectStatus, rec.Code)
+ })
+ }
+}
+
+// Issue #1086
+func TestEchoRewritePreMiddleware(t *testing.T) {
+ e := echo.New()
+
+ // Rewrite old url to new one
+ // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
+ e.Pre(RewriteWithConfig(RewriteConfig{
+ Rules: map[string]string{"/old": "/new"}}),
+ )
+
+ // Route
+ e.Add(http.MethodGet, "/new", func(c *echo.Context) error {
+ return c.NoContent(http.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/old", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, "/new", req.URL.EscapedPath())
+ assert.Equal(t, http.StatusOK, rec.Code)
+}
+
+// Issue #1143
+func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) {
+ e := echo.New()
+
+ // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches
+ e.Pre(RewriteWithConfig(RewriteConfig{
+ Rules: map[string]string{
+ "/api/*/mgmt/proj/*/agt": "/api/$1/hosts/$2",
+ "/api/*/mgmt/proj": "/api/$1/eng",
+ },
+ }))
+
+ e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c *echo.Context) error {
+ return c.String(http.StatusOK, "hosts")
+ })
+ e.Add(http.MethodGet, "/api/:version/eng", func(c *echo.Context) error {
+ return c.String(http.StatusOK, "eng")
+ })
+
+ for i := 0; i < 100; i++ {
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/mgmt/proj/test/agt", nil)
+ rec := httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, "/api/v1/hosts/test", req.URL.EscapedPath())
+ assert.Equal(t, http.StatusOK, rec.Code)
+
+ defer rec.Result().Body.Close()
+ bodyBytes, _ := io.ReadAll(rec.Result().Body)
+ assert.Equal(t, "hosts", string(bodyBytes))
+ }
+}
+
+// Issue #1573
+func TestEchoRewriteWithCaret(t *testing.T) {
+ e := echo.New()
+
+ e.Pre(RewriteWithConfig(RewriteConfig{
+ Rules: map[string]string{
+ "^/abc/*": "/v1/abc/$1",
+ },
+ }))
+
+ rec := httptest.NewRecorder()
+
+ var req *http.Request
+
+ req = httptest.NewRequest(http.MethodGet, "/abc/test", nil)
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, "/v1/abc/test", req.URL.Path)
+
+ req = httptest.NewRequest(http.MethodGet, "/v1/abc/test", nil)
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, "/v1/abc/test", req.URL.Path)
+
+ req = httptest.NewRequest(http.MethodGet, "/v2/abc/test", nil)
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, "/v2/abc/test", req.URL.Path)
+}
+
+// Verify regex used with rewrite
+func TestEchoRewriteWithRegexRules(t *testing.T) {
+ e := echo.New()
+
+ e.Pre(RewriteWithConfig(RewriteConfig{
+ Rules: map[string]string{
+ "^/a/*": "/v1/$1",
+ "^/b/*/c/*": "/v2/$2/$1",
+ "^/c/*/*": "/v3/$2",
+ },
+ RegexRules: map[*regexp.Regexp]string{
+ regexp.MustCompile("^/x/.+?/(.*)"): "/v4/$1",
+ regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1",
+ },
+ }))
+
+ var rec *httptest.ResponseRecorder
+ var req *http.Request
+
+ testCases := []struct {
+ requestPath string
+ expectPath string
+ }{
+ {"/unmatched", "/unmatched"},
+ {"/a/test", "/v1/test"},
+ {"/b/foo/c/bar/baz", "/v2/bar/baz/foo"},
+ {"/c/ignore/test", "/v3/test"},
+ {"/c/ignore1/test/this", "/v3/test/this"},
+ {"/x/ignore/test", "/v4/test"},
+ {"/y/foo/bar", "/v5/bar/foo"},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.requestPath, func(t *testing.T) {
+ req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil)
+ rec = httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, tc.expectPath, req.URL.EscapedPath())
+ })
+ }
+}
+
+// Ensure correct escaping as defined in replacement (issue #1798)
+func TestEchoRewriteReplacementEscaping(t *testing.T) {
+ e := echo.New()
+
+ // NOTE: these are incorrect regexps as they do not factor in that URI we are replacing could contain ? (query) and # (fragment) parts
+ // so in reality they append query and fragment part as `$1` matches everything after that prefix
+ e.Pre(RewriteWithConfig(RewriteConfig{
+ Rules: map[string]string{
+ "^/a/*": "/$1?query=param",
+ "^/b/*": "/$1;part#one",
+ },
+ RegexRules: map[*regexp.Regexp]string{
+ regexp.MustCompile("^/x/(.*)"): "/$1?query=param",
+ regexp.MustCompile("^/y/(.*)"): "/$1;part#one",
+ regexp.MustCompile("^/z/(.*)"): "/$1?test=1#escaped%20test",
+ },
+ }))
+
+ var rec *httptest.ResponseRecorder
+ var req *http.Request
+
+ testCases := []struct {
+ requestPath string
+ expect string
+ }{
+ {"/unmatched", "/unmatched"},
+ {"/a/test", "/test?query=param"},
+ {"/b/foo/bar", "/foo/bar;part#one"},
+ {"/x/test", "/test?query=param"},
+ {"/y/foo/bar", "/foo/bar;part#one"},
+ {"/z/foo/b%20ar", "/foo/b%20ar?test=1#escaped%20test"},
+ {"/z/foo/b%20ar?nope=1#yes", "/foo/b%20ar?nope=1#yes?test=1%23escaped%20test"}, // example of appending
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.requestPath, func(t *testing.T) {
+ req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil)
+ rec = httptest.NewRecorder()
+ e.ServeHTTP(rec, req)
+ assert.Equal(t, tc.expect, req.URL.String())
+ })
+ }
+}
diff --git a/middleware/secure.go b/middleware/secure.go
new file mode 100644
index 000000000..bd389f7ae
--- /dev/null
+++ b/middleware/secure.go
@@ -0,0 +1,148 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
+
+package middleware
+
+import (
+ "fmt"
+
+ "github.com/labstack/echo/v5"
+)
+
+// SecureConfig defines the config for Secure middleware.
+type SecureConfig struct {
+ // Skipper defines a function to skip middleware.
+ Skipper Skipper
+
+ // XSSProtection provides protection against cross-site scripting attack (XSS)
+ // by setting the `X-XSS-Protection` header.
+ // Optional. Default value "1; mode=block".
+ XSSProtection string
+
+ // ContentTypeNosniff provides protection against overriding Content-Type
+ // header by setting the `X-Content-Type-Options` header.
+ // Optional. Default value "nosniff".
+ ContentTypeNosniff string
+
+ // XFrameOptions can be used to indicate whether or not a browser should
+ // be allowed to render a page in a ,