diff --git a/.gitattributes b/.gitattributes index 49b63e526..28981b84a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -13,8 +13,3 @@ *.js text eol=lf *.json text eol=lf LICENSE text eol=lf - -# Exclude `website` and `cookbook` from GitHub's language statistics -# https://github.com/github/linguist#using-gitattributes -cookbook/* linguist-documentation -website/* linguist-documentation diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..af410716d --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,12 @@ +# These are supported funding model platforms + +github: [labstack] +patreon: # Replace with a single Patreon username +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index ee6f33ef8..1a76adca7 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -1,23 +1,32 @@ ### Issue Description -### Checklist +### Working code to debug -- [ ] Dependencies installed -- [ ] No typos -- [ ] Searched existing issues and docs +```go +package main -### Expected behaviour +import ( + "github.com/labstack/echo/v5" + "net/http" + "net/http/httptest" + "testing" +) -### Actual behaviour +func TestExample(t *testing.T) { + e := echo.New() -### Steps to reproduce + e.GET("/", func(c *echo.Context) error { + return c.String(http.StatusOK, "Hello, World!") + }) -### Working code to debug + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() -```go -package main + e.ServeHTTP(rec, req) -func main() { + if rec.Code != http.StatusOK { + t.Errorf("got %d, want %d", rec.Code, http.StatusOK) + } } ``` diff --git a/.github/stale.yml b/.github/stale.yml index d9f656321..04dd169cd 100644 --- a/.github/stale.yml +++ b/.github/stale.yml @@ -1,17 +1,19 @@ # Number of days of inactivity before an issue becomes stale daysUntilStale: 60 # Number of days of inactivity before a stale issue is closed -daysUntilClose: 7 +daysUntilClose: 30 # Issues with these labels will never be considered stale exemptLabels: - pinned - security + - bug + - enhancement # Label to use when marking an issue as stale -staleLabel: wontfix +staleLabel: stale # Comment to post when marking an issue as stale. Set to `false` to disable markComment: > This issue has been automatically marked as stale because it has not had - recent activity. It will be closed if no further activity occurs. Thank you - for your contributions. + recent activity. It will be closed within a month if no further activity occurs. + Thank you for your contributions. # Comment to post when closing a stale issue. Set to `false` to disable -closeComment: false \ No newline at end of file +closeComment: false diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml new file mode 100644 index 000000000..8f4eff96e --- /dev/null +++ b/.github/workflows/checks.yml @@ -0,0 +1,47 @@ +name: Run checks + +on: + push: + branches: + - master + pull_request: + branches: + - master + workflow_dispatch: + +permissions: + contents: read # to fetch code (actions/checkout) + +env: + # run static analysis only with the latest Go version + LATEST_GO_VERSION: "1.26" + +jobs: + check: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@v5 + + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v5 + with: + go-version: ${{ env.LATEST_GO_VERSION }} + check-latest: true + + - name: Run golint + run: | + go install golang.org/x/lint/golint@latest + golint -set_exit_status ./... + + - name: Run staticcheck + run: | + go install honnef.co/go/tools/cmd/staticcheck@latest + staticcheck ./... + + - name: Run govulncheck + run: | + go version + go install golang.org/x/vuln/cmd/govulncheck@latest + govulncheck ./... + diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml new file mode 100644 index 000000000..b92c70c1b --- /dev/null +++ b/.github/workflows/echo.yml @@ -0,0 +1,86 @@ +name: Run Tests + +on: + push: + branches: + - master + pull_request: + branches: + - master + workflow_dispatch: + +permissions: + contents: read # to fetch code (actions/checkout) + +env: + # run coverage and benchmarks only with the latest Go version + LATEST_GO_VERSION: "1.26" + +jobs: + test: + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + # Each major Go release is supported until there are two newer major releases. https://golang.org/doc/devel/release.html#policy + # Echo tests with last four major releases (unless there are pressing vulnerabilities) + # As we depend on `golang.org/x/` libraries which only support the last 2 Go releases, we could have situations when + # we derive from the last four major releases promise. + go: ["1.25", "1.26"] + name: ${{ matrix.os }} @ Go ${{ matrix.go }} + runs-on: ${{ matrix.os }} + steps: + - name: Checkout Code + uses: actions/checkout@v5 + + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go }} + + - name: Run Tests + run: go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./... + + - name: Upload coverage to Codecov + if: success() && matrix.go == env.LATEST_GO_VERSION && matrix.os == 'ubuntu-latest' + uses: codecov/codecov-action@v5 + with: + token: + fail_ci_if_error: false + + benchmark: + needs: test + name: Benchmark comparison + runs-on: ubuntu-latest + steps: + - name: Checkout Code (Previous) + uses: actions/checkout@v5 + with: + ref: ${{ github.base_ref }} + path: previous + + - name: Checkout Code (New) + uses: actions/checkout@v5 + with: + path: new + + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v5 + with: + go-version: ${{ env.LATEST_GO_VERSION }} + + - name: Install Dependencies + run: go install golang.org/x/perf/cmd/benchstat@latest + + - name: Run Benchmark (Previous) + run: | + cd previous + go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt + + - name: Run Benchmark (New) + run: | + cd new + go test -run="-" -bench=".*" -count=8 ./... > benchmark.txt + + - name: Run Benchstat + run: | + benchstat previous/benchmark.txt new/benchmark.txt diff --git a/.gitignore b/.gitignore index dd74acca4..dbadf3bd0 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ vendor .idea *.iml *.out +.vscode diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index a1fc87684..000000000 --- a/.travis.yml +++ /dev/null @@ -1,17 +0,0 @@ -language: go -go: - - 1.12.x - - 1.13.x - - tip -env: - - GO111MODULE=on -install: - - go get -v golang.org/x/lint/golint -script: - - golint -set_exit_status ./... - - go test -race -coverprofile=coverage.txt -covermode=atomic ./... -after_success: - - bash <(curl -s https://codecov.io/bash) -matrix: - allow_failures: - - go: tip diff --git a/API_CHANGES_V5.md b/API_CHANGES_V5.md new file mode 100644 index 000000000..d3ca81560 --- /dev/null +++ b/API_CHANGES_V5.md @@ -0,0 +1,1178 @@ +# Echo v5 Public API Changes + +**Comparison between `master` (v4.15.0) and `v5` (v5.0.0-alpha) branches** + +Generated: 2026-01-01 + +--- + +## Executive Summary (by authors) + +Echo `v5` is maintenance release with **major breaking changes** +- `Context` is now struct instead of interface and we can add method to it in the future in minor versions. +- Adds new `Router` interface for possible new routing implementations. +- Drops old logging interface and uses moderm `log/slog` instead. +- Rearranges alot of methods/function signatures to make them more consistent. + +## Executive Summary (by LLMs) + +Echo v5 represents a **major breaking release** with significant architectural changes focused on: +- **Updated generic helpers** to take `*Context` and rename form helpers to `FormValue*` +- **Simplified API surface** by moving Context from interface to concrete struct +- **Modern Go patterns** including slog.Logger integration +- **Enhanced routing** with explicit RouteInfo and Routes types +- **Better error handling** with simplified HTTPError +- **New test helpers** via the `echotest` package + +### Change Statistics + +- **Major Breaking Changes**: 15+ +- **New Functions Added**: 30+ +- **Type Signature Changes**: 20+ +- **Removed APIs**: 10+ +- **New Packages Added**: 1 (`echotest`) +- **Version Change**: `4.15.0` → `5.0.0-alpha` + +--- + +## Critical Breaking Changes + +### 1. **Context: Interface → Concrete Struct** + +**v4 (master):** +```go +type Context interface { + Request() *http.Request + // ... many methods +} + +// Handler signature +func handler(c echo.Context) error +``` + +**v5:** +```go +type Context struct { + // Has unexported fields +} + +// Handler signature - NOW USES POINTER! +func handler(c *echo.Context) error +``` + +**Impact:** 🔴 **CRITICAL BREAKING CHANGE** +- ALL handlers must change from `echo.Context` to `*echo.Context` +- Context is now a concrete struct, not an interface +- This affects every single handler function in user code + +**Migration:** +```go +// Before (v4) +func MyHandler(c echo.Context) error { + return c.JSON(200, map[string]string{"hello": "world"}) +} + +// After (v5) +func MyHandler(c *echo.Context) error { + return c.JSON(200, map[string]string{"hello": "world"}) +} +``` + +--- + +### 2. **Logger: Custom Interface → slog.Logger** + +**v4:** +```go +type Echo struct { + Logger Logger // Custom interface with Print, Debug, Info, etc. +} + +type Logger interface { + Output() io.Writer + SetOutput(w io.Writer) + Prefix() string + // ... many custom methods +} + +// Context returns Logger interface +func (c Context) Logger() Logger +``` + +**v5:** +```go +type Echo struct { + Logger *slog.Logger // Standard library structured logger +} + +// Context returns slog.Logger +func (c *Context) Logger() *slog.Logger +func (c *Context) SetLogger(logger *slog.Logger) +``` + +**Impact:** 🔴 **BREAKING CHANGE** +- Must use Go's standard `log/slog` package +- Logger interface completely removed +- All logging code needs updating + +--- + +### 3. **Router: From Router to DefaultRouter** + +**v4:** +```go +type Router struct { ... } + +func NewRouter(e *Echo) *Router +func (e *Echo) Router() *Router +``` + +**v5:** +```go +type DefaultRouter struct { ... } + +func NewRouter(config RouterConfig) *DefaultRouter +func (e *Echo) Router() Router // Returns interface +``` + +**Changes:** +- New `Router` interface introduced +- `DefaultRouter` is the concrete implementation +- `NewRouter()` now takes `RouterConfig` instead of `*Echo` +- Added `NewConcurrentRouter(r Router) Router` for thread-safe routing + +--- + +### 4. **Route Return Types Changed** + +**v4:** +```go +func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route +func (e *Echo) Any(path string, h HandlerFunc, m ...MiddlewareFunc) []*Route +func (e *Echo) Routes() []*Route +``` + +**v5:** +```go +func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo +func (e *Echo) Any(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo +func (e *Echo) Match(...) Routes // Returns Routes type +func (e *Echo) Router() Router // Returns interface +``` + +**New Types:** +```go +type RouteInfo struct { + Name string + Method string + Path string + Parameters []string +} + +type Routes []RouteInfo // Collection with helper methods +``` + +**Impact:** 🔴 **BREAKING CHANGE** +- Route registration methods return `RouteInfo` instead of `*Route` +- New `Routes` collection type with filtering methods +- `Route` struct still exists but used differently + +--- + +### 5. **Response Type Changed** + +**v4:** +```go +func (c Context) Response() *Response +type Response struct { + Writer http.ResponseWriter + Status int + Size int64 + Committed bool +} +func NewResponse(w http.ResponseWriter, e *Echo) *Response +``` + +**v5:** +```go +func (c *Context) Response() http.ResponseWriter +type Response struct { + http.ResponseWriter // Embedded + Status int + Size int64 + Committed bool +} +func NewResponse(w http.ResponseWriter, logger *slog.Logger) *Response +func UnwrapResponse(rw http.ResponseWriter) (*Response, error) +``` + +**Changes:** +- Context.Response() returns `http.ResponseWriter` instead of `*Response` +- Response now embeds `http.ResponseWriter` +- NewResponse takes `*slog.Logger` instead of `*Echo` +- New `UnwrapResponse()` helper function + +--- + +### 6. **HTTPError Simplified** + +**v4:** +```go +type HTTPError struct { + Internal error + Message interface{} // Can be any type + Code int +} + +func NewHTTPError(code int, message ...interface{}) *HTTPError +``` + +**v5:** +```go +type HTTPError struct { + Code int + Message string // Now string only + // Has unexported fields (Internal moved) +} + +func NewHTTPError(code int, message string) *HTTPError +func (he HTTPError) Wrap(err error) error // New method +func (he *HTTPError) StatusCode() int // Implements HTTPStatusCoder +``` + +**Changes:** +- `Message` field changed from `interface{}` to `string` +- `NewHTTPError()` now takes `string` instead of `...interface{}` +- Added `HTTPStatusCoder` interface and `StatusCode()` method +- Added `Wrap(err error)` method for error wrapping + +--- + +### 7. **HTTPErrorHandler Signature Changed** + +**v4:** +```go +type HTTPErrorHandler func(err error, c Context) + +func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) +``` + +**v5:** +```go +type HTTPErrorHandler func(c *Context, err error) // Parameters swapped! + +func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler // Now a factory +``` + +**Impact:** 🔴 **BREAKING CHANGE** +- Parameter order reversed: `(c *Context, err error)` instead of `(err error, c Context)` +- DefaultHTTPErrorHandler is now a factory function that returns HTTPErrorHandler +- Takes `exposeError` bool to control error message exposure + +--- + +## Notable API Changes in v5 + +### 1. **Generic Parameter Extraction Functions (Updated Signatures)** + +These helpers keep the same generic API but now accept `*Context`, and the +form helpers are renamed from `FormParam*` to `FormValue*`: + +```go +// Query Parameters +func QueryParam[T any](c *Context, key string, opts ...any) (T, error) +func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) +func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error) +func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) + +// Path Parameters +func PathParam[T any](c *Context, paramName string, opts ...any) (T, error) +func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...any) (T, error) + +// Form Values +func FormValue[T any](c *Context, key string, opts ...any) (T, error) +func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) +func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) +func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) + +// Generic Parsing +func ParseValue[T any](value string, opts ...any) (T, error) +func ParseValueOr[T any](value string, defaultValue T, opts ...any) (T, error) +func ParseValues[T any](values []string, opts ...any) ([]T, error) +func ParseValuesOr[T any](values []string, defaultValue []T, opts ...any) ([]T, error) +``` + +`FormParam*` was renamed to `FormValue*`; the rest keep names but now take `*Context`. + +**Supported Types:** +- bool, string +- int, int8, int16, int32, int64 +- uint, uint8, uint16, uint32, uint64 +- float32, float64 +- time.Time, time.Duration +- BindUnmarshaler, encoding.TextUnmarshaler, json.Unmarshaler + +**Example Usage:** +```go +// v5 - Type-safe parameter binding +id, err := echo.PathParam[int](c, "id") +page, err := echo.QueryParamOr[int](c, "page", 1) +tags, err := echo.QueryParams[string](c, "tags") +``` + +--- + +### 2. **Context Store Helpers Now Use `*Context`** + +```go +// Type-safe context value retrieval +func ContextGet[T any](c *Context, key string) (T, error) +func ContextGetOr[T any](c *Context, key string, defaultValue T) (T, error) + +// Error types +var ErrNonExistentKey = errors.New("non existent key") +var ErrInvalidKeyType = errors.New("invalid key type") +``` + +These helpers existed in v4 with `Context` and now accept `*Context`. + +**Example:** +```go +// v5 +user, err := echo.ContextGet[*User](c, "user") +count, err := echo.ContextGetOr[int](c, "count", 0) +``` + +--- + +### 3. **PathValues Type** + +New structured path parameter handling: + +```go +type PathValue struct { + Name string + Value string +} + +type PathValues []PathValue + +func (p PathValues) Get(name string) (string, bool) +func (p PathValues) GetOr(name string, defaultValue string) string + +// Context methods +func (c *Context) PathValues() PathValues +func (c *Context) SetPathValues(pathValues PathValues) +``` + +--- + +### 4. **Time Parsing Options** + +```go +type TimeLayout string + +const ( + TimeLayoutUnixTime = TimeLayout("UnixTime") + TimeLayoutUnixTimeMilli = TimeLayout("UnixTimeMilli") + TimeLayoutUnixTimeNano = TimeLayout("UnixTimeNano") +) + +type TimeOpts struct { + Layout TimeLayout + ParseInLocation *time.Location + ToInLocation *time.Location +} +``` + +--- + +### 5. **StartConfig for Server Configuration** + +```go +type StartConfig struct { + Address string + HideBanner bool + HidePort bool + CertFilesystem fs.FS + TLSConfig *tls.Config + ListenerNetwork string + ListenerAddrFunc func(addr net.Addr) + GracefulTimeout time.Duration + OnShutdownError func(err error) + BeforeServeFunc func(s *http.Server) error +} + +func (sc StartConfig) Start(ctx context.Context, h http.Handler) error +func (sc StartConfig) StartTLS(ctx context.Context, h http.Handler, certFile, keyFile any) error +``` + +**Example:** +```go +// v5 - More control over server startup +ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) +defer cancel() + +sc := echo.StartConfig{ + Address: ":8080", + GracefulTimeout: 10 * time.Second, +} +if err := sc.Start(ctx, e); err != nil { + log.Fatal(err) +} +``` + +--- + +### 6. **Echo Config and Constructors** + +```go +type Config struct { + // Configuration for Echo (logger, binder, renderer, etc.) +} + +func NewWithConfig(config Config) *Echo +``` + +This adds a configuration struct for creating an `Echo` instance without +mutating fields after `New()`. + +--- + +### 7. **Enhanced Routing Features** + +```go +// New route methods +func (e *Echo) AddRoute(route Route) (RouteInfo, error) +func (e *Echo) Middlewares() []MiddlewareFunc +func (e *Echo) PreMiddlewares() []MiddlewareFunc +type AddRouteError struct{ ... } + +// Routes collection with filters +type Routes []RouteInfo + +func (r Routes) Clone() Routes +func (r Routes) FilterByMethod(method string) (Routes, error) +func (r Routes) FilterByName(name string) (Routes, error) +func (r Routes) FilterByPath(path string) (Routes, error) +func (r Routes) FindByMethodPath(method string, path string) (RouteInfo, error) +func (r Routes) Reverse(routeName string, pathValues ...any) (string, error) + +// RouteInfo operations +func (r RouteInfo) Clone() RouteInfo +func (r RouteInfo) Reverse(pathValues ...any) string +``` + +--- + +### 8. **Middleware Configuration Interface** + +```go +type MiddlewareConfigurator interface { + ToMiddleware() (MiddlewareFunc, error) +} +``` + +Allows middleware configs to be converted to middleware without panicking. + +--- + +### 9. **New Context Methods** + +```go +// v5 additions +func (c *Context) FileFS(file string, filesystem fs.FS) error +func (c *Context) FormValueOr(name, defaultValue string) string +func (c *Context) InitializeRoute(ri *RouteInfo, pathValues *PathValues) +func (c *Context) ParamOr(name, defaultValue string) string +func (c *Context) QueryParamOr(name, defaultValue string) string +func (c *Context) RouteInfo() RouteInfo +``` + +--- + +### 10. **Virtual Host Support** + +```go +func NewVirtualHostHandler(vhosts map[string]*Echo) *Echo +``` + +Creates an Echo instance that routes requests to different Echo instances based on host. + +--- + +### 11. **New Binder Functions** + +```go +func BindBody(c *Context, target any) error +func BindHeaders(c *Context, target any) error +func BindPathValues(c *Context, target any) error // Renamed from BindPathParams +func BindQueryParams(c *Context, target any) error +``` + +Top-level binding functions that work with `*Context`. + +--- + +### 12. **New echotest Package** + +```go +package echotest // import "github.com/labstack/echo/v5/echotest" + +func LoadBytes(t *testing.T, name string, opts ...loadBytesOpts) []byte +func TrimNewlineEnd(bytes []byte) []byte +type ContextConfig struct{ ... } +type MultipartForm struct{ ... } +type MultipartFormFile struct{ ... } +``` + +Helpers for loading fixtures and constructing test contexts. + +--- + +## Removed APIs in v5 + +### Constants + +```go +// v4 - Removed in v5 +const CONNECT = http.MethodConnect // Use http.MethodConnect directly +``` + +**Reason:** Deprecated in v4, use stdlib `http.Method*` constants instead. + +--- + +### Constants Added in v5 + +```go +// v5 additions +const ( + NotFoundRouteName = "echo_route_not_found_name" +) +``` + +--- + +### Error Variable Changes + +**v4 exports:** +```go +ErrBadRequest +ErrInvalidKeyType +ErrNonExistentKey +``` + +**v5 exports:** +```go +ErrBadRequest // Now backed by unexported httpError type +ErrValidatorNotRegistered // New +ErrInvalidKeyType +ErrNonExistentKey +``` + +**Reason:** v5 centralizes on `NewHTTPError(code, message)` rather than a broad set +of predefined HTTP error variables. + +--- + +### Functions Removed + +```go +// v4 - Removed in v5 +func GetPath(r *http.Request) string // Use r.URL.Path or r.URL.RawPath +``` + +### Variables Removed + +```go +// v4 - Removed in v5 +var MethodNotAllowedHandler = func(c Context) error { ... } +var NotFoundHandler = func(c Context) error { ... } +``` + +### Functions Renamed + +```go +// v4 +func FormParam[T any](c Context, key string, opts ...any) (T, error) +func FormParamOr[T any](c Context, key string, defaultValue T, opts ...any) (T, error) +func FormParams[T any](c Context, key string, opts ...any) ([]T, error) +func FormParamsOr[T any](c Context, key string, defaultValue []T, opts ...any) ([]T, error) + +// v5 +func FormValue[T any](c *Context, key string, opts ...any) (T, error) +func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) +func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) +func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) +``` + +--- + +### Type Methods Removed/Changed + +**Echo struct changes:** +```go +// v4 fields removed in v5 +type Echo struct { + StdLogger *stdLog.Logger // Removed + Server *http.Server // Removed (use StartConfig) + TLSServer *http.Server // Removed (use StartConfig) + Listener net.Listener // Removed (use StartConfig) + TLSListener net.Listener // Removed (use StartConfig) + AutoTLSManager autocert.Manager // Removed + ListenerNetwork string // Removed + OnAddRouteHandler func(...) // Changed to OnAddRoute + DisableHTTP2 bool // Removed (use StartConfig) + Debug bool // Removed + HideBanner bool // Removed (use StartConfig) + HidePort bool // Removed (use StartConfig) +} + +// v5 Echo struct (simplified) +type Echo struct { + Binder Binder + Filesystem fs.FS // NEW + Renderer Renderer + Validator Validator + JSONSerializer JSONSerializer + IPExtractor IPExtractor + OnAddRoute func(route Route) error // Simplified + HTTPErrorHandler HTTPErrorHandler + Logger *slog.Logger // Changed from Logger interface +} +``` + +--- + +**Context interface → struct:** +```go +// v4 +type Context interface { + // Had: SetResponse(*Response) + Response() *Response + + // Had: ParamNames(), SetParamNames(), ParamValues(), SetParamValues() + // These are removed in v5 (use PathValues() instead) +} + +// v5 +type Context struct { + // Concrete struct with unexported fields +} + +func (c *Context) Response() http.ResponseWriter // Changed return type +func (c *Context) PathValues() PathValues // Replaces ParamNames/Values +``` + +--- + +**Types removed:** +```go +// v4 +type Map map[string]interface{} +``` + +**Group changes:** +```go +// v4 +func (g *Group) File(path, file string) // No return value +func (g *Group) Static(pathPrefix, fsRoot string) // No return value +func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) // No return value + +// v5 +func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo +func (g *Group) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo +func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo +``` + +Now return `RouteInfo` and accept middleware. + +--- + +### Value Binder Factory Name Changes + +```go +// v4 +func PathParamsBinder(c Context) *ValueBinder +func QueryParamsBinder(c Context) *ValueBinder +func FormFieldBinder(c Context) *ValueBinder + +// v5 +func PathValuesBinder(c *Context) *ValueBinder // Renamed +func QueryParamsBinder(c *Context) *ValueBinder +func FormFieldBinder(c *Context) *ValueBinder +``` + +--- + +## Type Signature Changes + +### Binder Interface + +```go +// v4 +type Binder interface { + Bind(i interface{}, c Context) error +} + +// v5 +type Binder interface { + Bind(c *Context, target any) error // Parameters swapped! +} +``` + +--- + +### DefaultBinder Methods + +```go +// v4 +func (b *DefaultBinder) Bind(i interface{}, c Context) error +func (b *DefaultBinder) BindBody(c Context, i interface{}) error +func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error + +// v5 +func (b *DefaultBinder) Bind(c *Context, target any) error // Swapped params +// BindBody, BindPathParams, etc. are now top-level functions +``` + +--- + +### JSONSerializer Interface + +```go +// v4 +type JSONSerializer interface { + Serialize(c Context, i interface{}, indent string) error + Deserialize(c Context, i interface{}) error +} + +// v5 +type JSONSerializer interface { + Serialize(c *Context, target any, indent string) error + Deserialize(c *Context, target any) error +} +``` + +--- + +### Renderer Interface + +```go +// v4 +type Renderer interface { + Render(io.Writer, string, interface{}, Context) error +} + +// v5 +type Renderer interface { + Render(c *Context, w io.Writer, templateName string, data any) error +} +``` + +Parameters reordered with Context first. + +--- + +### NewBindingError + +```go +// v4 +func NewBindingError(sourceParam string, values []string, message interface{}, internalError error) error + +// v5 +func NewBindingError(sourceParam string, values []string, message string, err error) error +``` + +Message parameter changed from `interface{}` to `string`. + +--- + +### HandlerName + +```go +// v5 only +func HandlerName(h HandlerFunc) string +``` + +New utility function to get handler function name. + +--- + +## Middleware Package Changes + +### Signature and Type Updates + +```go +// CORS now accepts optional allow-origins +func CORS(allowOrigins ...string) echo.MiddlewareFunc + +// BodyLimit now accepts bytes +func BodyLimit(limitBytes int64) echo.MiddlewareFunc + +// DefaultSkipper now uses *echo.Context +func DefaultSkipper(c *echo.Context) bool + +// Trailing slash configs renamed/split +func AddTrailingSlashWithConfig(config AddTrailingSlashConfig) echo.MiddlewareFunc +func RemoveTrailingSlashWithConfig(config RemoveTrailingSlashConfig) echo.MiddlewareFunc +type AddTrailingSlashConfig struct{ ... } +type RemoveTrailingSlashConfig struct{ ... } + +// Auth + extractor signatures now use *echo.Context and add ExtractorSource +type BasicAuthValidator func(c *echo.Context, user string, password string) (bool, error) +type Extractor func(c *echo.Context) (string, error) +type ExtractorSource string +type KeyAuthValidator func(c *echo.Context, key string, source ExtractorSource) (bool, error) +type KeyAuthErrorHandler func(c *echo.Context, err error) error + +// BodyDump handler now includes err +type BodyDumpHandler func(c *echo.Context, reqBody []byte, resBody []byte, err error) + +// ValuesExtractor now returns extractor source and CreateExtractors takes a limit +type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error) +func CreateExtractors(lookups string, limit uint) ([]ValuesExtractor, error) +type ValueExtractorError struct{ ... } + +// New constants +const KB = 1024 + +// Rate limiter store now takes a float64 limit +func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore) +``` + +### Added Middleware Exports + +```go +var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key") +var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key") +var RedirectHTTPSConfig = RedirectConfig{ ... } +var RedirectHTTPSWWWConfig = RedirectConfig{ ... } +var RedirectNonHTTPSWWWConfig = RedirectConfig{ ... } +var RedirectNonWWWConfig = RedirectConfig{ ... } +var RedirectWWWConfig = RedirectConfig{ ... } +``` + +### Removed/Consolidated Middleware Exports + +```go +// Removed in v5 +func Logger() echo.MiddlewareFunc +func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc +func Timeout() echo.MiddlewareFunc +func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc +type ErrKeyAuthMissing struct{ ... } +type CSRFErrorHandler func(err error, c echo.Context) error +type LoggerConfig struct{ ... } +type LogErrorFunc func(c echo.Context, err error, stack []byte) error +type TargetProvider interface{ ... } +type TrailingSlashConfig struct{ ... } +type TimeoutConfig struct{ ... } +``` + +Also removed defaults: `DefaultBasicAuthConfig`, `DefaultBodyDumpConfig`, `DefaultBodyLimitConfig`, +`DefaultCORSConfig`, `DefaultDecompressConfig`, `DefaultGzipConfig`, `DefaultLoggerConfig`, +`DefaultRedirectConfig`, `DefaultRequestIDConfig`, `DefaultRewriteConfig`, `DefaultTimeoutConfig`, +`DefaultTrailingSlashConfig`. + +--- + +## Router Interface Changes + +### v4 Router (Concrete Struct) + +```go +type Router struct { ... } + +func NewRouter(e *Echo) *Router +func (r *Router) Add(method, path string, h HandlerFunc) +func (r *Router) Find(method, path string, c Context) +func (r *Router) Reverse(name string, params ...interface{}) string +func (r *Router) Routes() []*Route +``` + +### v5 Router (Interface + DefaultRouter) + +```go +type Router interface { + Add(routable Route) (RouteInfo, error) + Remove(method string, path string) error + Routes() Routes + Route(c *Context) HandlerFunc +} + +type DefaultRouter struct { ... } + +func NewRouter(config RouterConfig) *DefaultRouter +func NewConcurrentRouter(r Router) Router // NEW + +type RouterConfig struct { + NotFoundHandler HandlerFunc + MethodNotAllowedHandler HandlerFunc + OptionsMethodHandler HandlerFunc + AllowOverwritingRoute bool + UnescapePathParamValues bool + UseEscapedPathForMatching bool +} +``` + +**Key Changes:** +- Router is now an interface +- DefaultRouter is the concrete implementation +- Add() returns `(RouteInfo, error)` instead of being void +- New `Remove()` method +- New `Route()` method replaces `Find()` +- Configuration through `RouterConfig` + +--- + +## Echo Instance Method Changes + +### Route Registration + +```go +// v4 +func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route + +// v5 +func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo +func (e *Echo) AddRoute(route Route) (RouteInfo, error) // NEW +``` + +### Static File Serving + +```go +// v4 +func (e *Echo) Static(pathPrefix, fsRoot string) *Route +func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route +func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route +func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route + +// v5 +func (e *Echo) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo +func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo +func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo +func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo +``` + +Return type changed from `*Route` to `RouteInfo`. + +### Server Management + +```go +// v4 +func (e *Echo) Start(address string) error +func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) error +func (e *Echo) StartAutoTLS(address string) error +func (e *Echo) StartH2CServer(address string, h2s *http2.Server) error +func (e *Echo) StartServer(s *http.Server) error +func (e *Echo) Shutdown(ctx context.Context) error +func (e *Echo) Close() error +func (e *Echo) ListenerAddr() net.Addr +func (e *Echo) TLSListenerAddr() net.Addr +func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) + +// v5 +func (e *Echo) Start(address string) error // Simplified +func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) + +// Removed: StartTLS, StartAutoTLS, StartH2CServer, StartServer +// Use StartConfig instead for advanced server configuration +// Removed: Shutdown, Close, ListenerAddr, TLSListenerAddr +// Removed: DefaultHTTPErrorHandler (now a top-level factory function) +``` + +**v5 provides** `StartConfig` type for all advanced server configuration. + +### Router Access + +```go +// v4 +func (e *Echo) Router() *Router +func (e *Echo) Routers() map[string]*Router // For multi-host +func (e *Echo) Routes() []*Route +func (e *Echo) Reverse(name string, params ...interface{}) string +func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string +func (e *Echo) URL(h HandlerFunc, params ...interface{}) string +func (e *Echo) Host(name string, m ...MiddlewareFunc) *Group + +// v5 +func (e *Echo) Router() Router // Returns interface +// Removed: Routers(), Reverse(), URI(), URL(), Host() +// Use router.Routes() and Routes.Reverse() instead +``` + +--- + +## NewContext Changes + +```go +// v4 +func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context +func NewResponse(w http.ResponseWriter, e *Echo) *Response + +// v5 +func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) *Context +func NewContext(r *http.Request, w http.ResponseWriter, opts ...any) *Context // Standalone +func NewResponse(w http.ResponseWriter, logger *slog.Logger) *Response +``` + +--- + +## Migration Guide Summary + +If you are using Linux you can migrate easier parts like that: +```bash +find . -type f -name "*.go" -exec sed -i 's/ echo.Context/ *echo.Context/g' {} + +find . -type f -name "*.go" -exec sed -i 's/echo\/v4/echo\/v5/g' {} + +``` +or in your favorite IDE + +Replace all: +1. ` echo.Context` -> ` *echo.Context` +2. `echo/v4` -> `echo/v5` + + +### 1. Update All Handler Signatures + +```go +// Before +func MyHandler(c echo.Context) error { ... } + +// After +func MyHandler(c *echo.Context) error { ... } +``` + +### 2. Update Logger Usage + +```go +// Before +e.Logger.Info("Server started") +c.Logger().Error("Something went wrong") + +// After +e.Logger.Info("Server started") +c.Logger().Error("Something went wrong") // Same API, different logger +``` + +### 3. Use Type-Safe Parameter Extraction + +```go +// Before +idStr := c.Param("id") +id, err := strconv.Atoi(idStr) + +// After +id, err := echo.PathParam[int](c, "id") +``` + +### 4. Update Error Handler + +```go +// Before +e.HTTPErrorHandler = func(err error, c echo.Context) { + // handle error +} + +// After +e.HTTPErrorHandler = func(c *echo.Context, err error) { // Swapped! + // handle error +} + +// Or use factory +e.HTTPErrorHandler = echo.DefaultHTTPErrorHandler(true) // exposeError=true +``` + +### 5. Update Server Startup + +```go +// Before +e.Start(":8080") +e.StartTLS(":443", "cert.pem", "key.pem") + +// After +// Simple +e.Start(":8080") + +// Advanced with graceful shutdown +ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) +defer cancel() +sc := echo.StartConfig{Address: ":8080"} +sc.Start(ctx, e) +``` + +### 6. Update Route Info Access + +```go +// Before +routes := e.Routes() +for _, r := range routes { + fmt.Println(r.Method, r.Path) +} + +// After +routes := e.Router().Routes() +for _, r := range routes { + fmt.Println(r.Method, r.Path) +} +``` + +### 7. Update HTTPError Creation + +```go +// Before +return echo.NewHTTPError(400, "invalid request", someDetail) + +// After +return echo.NewHTTPError(400, "invalid request") +``` + +### 8. Update Custom Binder + +```go +// Before +type MyBinder struct{} +func (b *MyBinder) Bind(i interface{}, c echo.Context) error { ... } + +// After +type MyBinder struct{} +func (b *MyBinder) Bind(c *echo.Context, target any) error { ... } // Swapped! +``` + +### 9. Path Parameters + +```go +// Before +names := c.ParamNames() +values := c.ParamValues() + +// After +pathValues := c.PathValues() +for _, pv := range pathValues { + fmt.Println(pv.Name, pv.Value) +} +``` + +### 10. Response Access + +```go +// Before +resp := c.Response() +resp.Header().Set("X-Custom", "value") + +// After +c.Response().Header().Set("X-Custom", "value") // Returns http.ResponseWriter + +// To get *echo.Response +resp, err := echo.UnwrapResponse(c.Response()) +``` + +### Go Version Requirements + +- **v4**: Go 1.24.0 (per `go.mod`) +- **v5**: Go 1.25.0 (per `go.mod`) + +--- + +**Generated by comparing `go doc` output from master (v4.15.0) and v5 (v5.0.0-alpha) branches** diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..37d1adb66 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,838 @@ +# Changelog + +## v5.0.3 - 2026-02-06 + +**Security** + +* Fix directory traversal vulnerability under Windows in Static middleware when default Echo filesystem is used. Reported by @shblue21. + +This applies to cases when: +- Windows is used as OS +- `middleware.StaticConfig.Filesystem` is `nil` (default) +- `echo.Filesystem` is has not been set explicitly (default) + +Exposure is restricted to the active process working directory and its subfolders. + + +## v5.0.2 - 2026-02-02 + +**Security** + +* Fix Static middleware with `config.Browse=true` lists all files/subfolders from `config.Filesystem` root and not starting from `config.Root` in https://github.com/labstack/echo/pull/2887 + + +## v5.0.1 - 2026-01-28 + +* Panic MW: will now return a custom PanicStackError with stack trace by @aldas in https://github.com/labstack/echo/pull/2871 +* Docs: add missing err parameter to DenyHandler example by @cgalibern in https://github.com/labstack/echo/pull/2878 +* improve: improve websocket checks in IsWebSocket() [per RFC 6455] by @raju-mechatronics in https://github.com/labstack/echo/pull/2875 +* fix: Context.Json() should not send status code before serialization is complete by @aldas in https://github.com/labstack/echo/pull/2877 + + +## v5.0.0 - 2026-01-18 + +Echo `v5` is maintenance release with **major breaking changes** +- `Context` is now struct instead of interface and we can add method to it in the future in minor versions. +- Adds new `Router` interface for possible new routing implementations. +- Drops old logging interface and uses moderm `log/slog` instead. +- Rearranges alot of methods/function signatures to make them more consistent. + +Upgrade notes and `v4` support: +- Echo `v4` is supported with **security*** updates and **bug** fixes until **2026-12-31** +- If you are using Echo in a production environment, it is recommended to wait until after 2026-03-31 before upgrading. +- Until 2026-03-31, any critical issues requiring breaking `v5` API changes will be addressed, even if this violates semantic versioning. + +See [API_CHANGES_V5.md](./API_CHANGES_V5.md) for public API changes between `v4` and `v5`, notes on **upgrading**. + +Upgrading TLDR: + +If you are using Linux you can migrate easier parts like that: +```bash +find . -type f -name "*.go" -exec sed -i 's/ echo.Context/ *echo.Context/g' {} + +find . -type f -name "*.go" -exec sed -i 's/echo\/v4/echo\/v5/g' {} + +``` +macOS +```bash +find . -type f -name "*.go" -exec sed -i '' 's/ echo.Context/ *echo.Context/g' {} + +find . -type f -name "*.go" -exec sed -i '' 's/echo\/v4/echo\/v5/g' {} + +``` + +or in your favorite IDE + +Replace all: +1. ` echo.Context` -> ` *echo.Context` +2. `echo/v4` -> `echo/v5` + +This should solve most of the issues. Probably the hardest part is updating all the tests. + + +## v4.15.0 - 2026-01-01 + + +**Security** + +NB: **If your application relies on cross-origin or same-site (same subdomain) requests do not blindly push this version to production** + + +The CSRF middleware now supports the [**Sec-Fetch-Site**](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Sec-Fetch-Site) header as a modern, defense-in-depth approach to [CSRF +protection](https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers), implementing the OWASP-recommended Fetch Metadata API alongside the traditional token-based mechanism. + +**How it works:** + +Modern browsers automatically send the `Sec-Fetch-Site` header with all requests, indicating the relationship +between the request origin and the target. The middleware uses this to make security decisions: + +- **`same-origin`** or **`none`**: Requests are allowed (exact origin match or direct user navigation) +- **`same-site`**: Falls back to token validation (e.g., subdomain to main domain) +- **`cross-site`**: Blocked by default with 403 error for unsafe methods (POST, PUT, DELETE, PATCH) + +For browsers that don't send this header (older browsers), the middleware seamlessly falls back to +traditional token-based CSRF protection. + +**New Configuration Options:** +- `TrustedOrigins []string`: Allowlist specific origins for cross-site requests (useful for OAuth callbacks, webhooks) +- `AllowSecFetchSiteFunc func(echo.Context) (bool, error)`: Custom logic for same-site/cross-site request validation + +**Example:** + ```go + e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{ + // Allow OAuth callbacks from trusted provider + TrustedOrigins: []string{"https://oauth-provider.com"}, + + // Custom validation for same-site requests + AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { + // Your custom authorization logic here + return validateCustomAuth(c), nil + // return true, err // blocks request with error + // return true, nil // allows CSRF request through + // return false, nil // falls back to legacy token logic + }, + })) + ``` +PR: https://github.com/labstack/echo/pull/2858 + +**Type-Safe Generic Parameter Binding** + +* Added generic functions for type-safe parameter extraction and context access by @aldas in https://github.com/labstack/echo/pull/2856 + + Echo now provides generic functions for extracting path, query, and form parameters with automatic type conversion, + eliminating manual string parsing and type assertions. + + **New Functions:** + - Path parameters: `PathParam[T]`, `PathParamOr[T]` + - Query parameters: `QueryParam[T]`, `QueryParamOr[T]`, `QueryParams[T]`, `QueryParamsOr[T]` + - Form values: `FormParam[T]`, `FormParamOr[T]`, `FormParams[T]`, `FormParamsOr[T]` + - Context store: `ContextGet[T]`, `ContextGetOr[T]` + + **Supported Types:** + Primitives (`bool`, `string`, `int`/`uint` variants, `float32`/`float64`), `time.Duration`, `time.Time` + (with custom layouts and Unix timestamp support), and custom types implementing `BindUnmarshaler`, + `TextUnmarshaler`, or `JSONUnmarshaler`. + + **Example:** + ```go + // Before: Manual parsing + idStr := c.Param("id") + id, err := strconv.Atoi(idStr) + + // After: Type-safe with automatic parsing + id, err := echo.PathParam[int](c, "id") + + // With default values + page, err := echo.QueryParamOr[int](c, "page", 1) + limit, err := echo.QueryParamOr[int](c, "limit", 20) + + // Type-safe context access (no more panics from type assertions) + user, err := echo.ContextGet[*User](c, "user") + ``` + +PR: https://github.com/labstack/echo/pull/2856 + + + +**DEPRECATION NOTICE** Timeout Middleware Deprecated - Use ContextTimeout Instead + +The `middleware.Timeout` middleware has been **deprecated** due to fundamental architectural issues that cause +data races. Use `middleware.ContextTimeout` or `middleware.ContextTimeoutWithConfig` instead. + +**Why is this being deprecated?** + +The Timeout middleware manipulates response writers across goroutine boundaries, which causes data races that +cannot be reliably fixed without a complete architectural redesign. The middleware: + +- Swaps the response writer using `http.TimeoutHandler` +- Must be the first middleware in the chain (fragile constraint) +- Can cause races with other middleware (Logger, metrics, custom middleware) +- Has been the source of multiple race condition fixes over the years + +**What should you use instead?** + +The `ContextTimeout` middleware (available since v4.12.0) provides timeout functionality using Go's standard +context mechanism. It is: + +- Race-free by design +- Can be placed anywhere in the middleware chain +- Simpler and more maintainable +- Compatible with all other middleware + +**Migration Guide:** + +```go +// Before (deprecated): +e.Use(middleware.Timeout()) + +// After (recommended): +e.Use(middleware.ContextTimeout(30 * time.Second)) +``` + +**Important Behavioral Differences:** + +1. **Handler cooperation required**: With ContextTimeout, your handlers must check `context.Done()` for cooperative + cancellation. The old Timeout middleware would send a 503 response regardless of handler cooperation, but had + data race issues. + +2. **Error handling**: ContextTimeout returns errors through the standard error handling flow. Handlers that receive + `context.DeadlineExceeded` should handle it appropriately: + +```go +e.GET("/long-task", func(c echo.Context) error { + ctx := c.Request().Context() + + // Example: database query with context + result, err := db.QueryContext(ctx, "SELECT * FROM large_table") + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + // Handle timeout + return echo.NewHTTPError(http.StatusServiceUnavailable, "Request timeout") + } + return err + } + + return c.JSON(http.StatusOK, result) +}) +``` + +3. **Background tasks**: For long-running background tasks, use goroutines with context: + +```go +e.GET("/async-task", func(c echo.Context) error { + ctx := c.Request().Context() + + resultCh := make(chan Result, 1) + errCh := make(chan error, 1) + + go func() { + result, err := performLongTask(ctx) + if err != nil { + errCh <- err + return + } + resultCh <- result + }() + + select { + case result := <-resultCh: + return c.JSON(http.StatusOK, result) + case err := <-errCh: + return err + case <-ctx.Done(): + return echo.NewHTTPError(http.StatusServiceUnavailable, "Request timeout") + } +}) +``` + +**Enhancements** + +* Fixes by @aldas in https://github.com/labstack/echo/pull/2852 +* Generic functions by @aldas in https://github.com/labstack/echo/pull/2856 +* CRSF with Sec-Fetch-Site checks by @aldas in https://github.com/labstack/echo/pull/2858 + + +## v4.14.0 - 2025-12-11 + +`middleware.Logger` has been deprecated. For request logging, use `middleware.RequestLogger` or +`middleware.RequestLoggerWithConfig`. + +`middleware.RequestLogger` replaces `middleware.Logger`, offering comparable configuration while relying on the +Go standard library’s new `slog` logger. + +The previous default output format was JSON. The new default follows the standard `slog` logger settings. +To continue emitting request logs in JSON, configure `slog` accordingly: +```go +slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil))) +e.Use(middleware.RequestLogger()) +``` + + +**Security** + +* Logger middleware json string escaping and deprecation by @aldas in https://github.com/labstack/echo/pull/2849 + + + +**Enhancements** + +* Update deps by @aldas in https://github.com/labstack/echo/pull/2807 +* refactor to use reflect.TypeFor by @cuiweixie in https://github.com/labstack/echo/pull/2812 +* Use Go 1.25 in CI by @aldas in https://github.com/labstack/echo/pull/2810 +* Modernize context.go by replacing interface{} with any by @vishr in https://github.com/labstack/echo/pull/2822 +* Fix typo in SetParamValues comment by @vishr in https://github.com/labstack/echo/pull/2828 +* Fix typo in ContextTimeout middleware comment by @vishr in https://github.com/labstack/echo/pull/2827 +* Improve BasicAuth middleware: use strings.Cut and RFC compliance by @vishr in https://github.com/labstack/echo/pull/2825 +* Fix duplicate plus operator in router backtracking logic by @yuya-morimoto in https://github.com/labstack/echo/pull/2832 +* Replace custom private IP range check with built-in net.IP.IsPrivate by @kumapower17 in https://github.com/labstack/echo/pull/2835 +* Ensure proxy connection is closed in proxyRaw function(#2837) by @kumapower17 in https://github.com/labstack/echo/pull/2838 +* Update deps by @aldas in https://github.com/labstack/echo/pull/2843 +* Update golang.org/x/* deps by @aldas in https://github.com/labstack/echo/pull/2850 + + + +## v4.13.4 - 2025-05-22 + +**Enhancements** + +* chore: fix some typos in comment by @zhuhaicity in https://github.com/labstack/echo/pull/2735 +* CI: test with Go 1.24 by @aldas in https://github.com/labstack/echo/pull/2748 +* Add support for TLS WebSocket proxy by @t-ibayashi-safie in https://github.com/labstack/echo/pull/2762 + +**Security** + +* Update dependencies for [GO-2025-3487](https://pkg.go.dev/vuln/GO-2025-3487), [GO-2025-3503](https://pkg.go.dev/vuln/GO-2025-3503) and [GO-2025-3595](https://pkg.go.dev/vuln/GO-2025-3595) in https://github.com/labstack/echo/pull/2780 + + +## v4.13.3 - 2024-12-19 + +**Security** + +* Update golang.org/x/net dependency [GO-2024-3333](https://pkg.go.dev/vuln/GO-2024-3333) in https://github.com/labstack/echo/pull/2722 + + +## v4.13.2 - 2024-12-12 + +**Security** + +* Update dependencies (dependabot reports [GO-2024-3321](https://pkg.go.dev/vuln/GO-2024-3321)) in https://github.com/labstack/echo/pull/2721 + + +## v4.13.1 - 2024-12-11 + +**Fixes** + +* Fix BindBody ignoring `Transfer-Encoding: chunked` requests by @178inaba in https://github.com/labstack/echo/pull/2717 + + + +## v4.13.0 - 2024-12-04 + +**BREAKING CHANGE** JWT Middleware Removed from Core use [labstack/echo-jwt](https://github.com/labstack/echo-jwt) instead + +The JWT middleware has been **removed from Echo core** due to another security vulnerability, [CVE-2024-51744](https://nvd.nist.gov/vuln/detail/CVE-2024-51744). For more details, refer to issue [#2699](https://github.com/labstack/echo/issues/2699). A drop-in replacement is available in the [labstack/echo-jwt](https://github.com/labstack/echo-jwt) repository. + +**Important**: Direct assignments like `token := c.Get("user").(*jwt.Token)` will now cause a panic due to an invalid cast. Update your code accordingly. Replace the current imports from `"github.com/golang-jwt/jwt"` in your handlers to the new middleware version using `"github.com/golang-jwt/jwt/v5"`. + + +Background: + +The version of `golang-jwt/jwt` (v3.2.2) previously used in Echo core has been in an unmaintained state for some time. This is not the first vulnerability affecting this library; earlier issues were addressed in [PR #1946](https://github.com/labstack/echo/pull/1946). +JWT middleware was marked as deprecated in Echo core as of [v4.10.0](https://github.com/labstack/echo/releases/tag/v4.10.0) on 2022-12-27. If you did not notice that, consider leveraging tools like [Staticcheck](https://staticcheck.dev/) to catch such deprecations earlier in you dev/CI flow. For bonus points - check out [gosec](https://github.com/securego/gosec). + +We sincerely apologize for any inconvenience caused by this change. While we strive to maintain backward compatibility within Echo core, recurring security issues with third-party dependencies have forced this decision. + +**Enhancements** + +* remove jwt middleware by @stevenwhitehead in https://github.com/labstack/echo/pull/2701 +* optimization: struct alignment by @behnambm in https://github.com/labstack/echo/pull/2636 +* bind: Maintain backwards compatibility for map[string]interface{} binding by @thesaltree in https://github.com/labstack/echo/pull/2656 +* Add Go 1.23 to CI by @aldas in https://github.com/labstack/echo/pull/2675 +* improve `MultipartForm` test by @martinyonatann in https://github.com/labstack/echo/pull/2682 +* `bind` : add support of multipart multi files by @martinyonatann in https://github.com/labstack/echo/pull/2684 +* Add TemplateRenderer struct to ease creating renderers for `html/template` and `text/template` packages. by @aldas in https://github.com/labstack/echo/pull/2690 +* Refactor TestBasicAuth to utilize table-driven test format by @ErikOlson in https://github.com/labstack/echo/pull/2688 +* Remove broken header by @aldas in https://github.com/labstack/echo/pull/2705 +* fix(bind body): content-length can be -1 by @phamvinhdat in https://github.com/labstack/echo/pull/2710 +* CORS middleware should compile allowOrigin regexp at creation by @aldas in https://github.com/labstack/echo/pull/2709 +* Shorten Github issue template and add test example by @aldas in https://github.com/labstack/echo/pull/2711 + + +## v4.12.0 - 2024-04-15 + +**Security** + +* Update golang.org/x/net dep because of [GO-2024-2687](https://pkg.go.dev/vuln/GO-2024-2687) by @aldas in https://github.com/labstack/echo/pull/2625 + + +**Enhancements** + +* binder: make binding to Map work better with string destinations by @aldas in https://github.com/labstack/echo/pull/2554 +* README.md: add Encore as sponsor by @marcuskohlberg in https://github.com/labstack/echo/pull/2579 +* Reorder paragraphs in README.md by @aldas in https://github.com/labstack/echo/pull/2581 +* CI: upgrade actions/checkout to v4 by @aldas in https://github.com/labstack/echo/pull/2584 +* Remove default charset from 'application/json' Content-Type header by @doortts in https://github.com/labstack/echo/pull/2568 +* CI: Use Go 1.22 by @aldas in https://github.com/labstack/echo/pull/2588 +* binder: allow binding to a nil map by @georgmu in https://github.com/labstack/echo/pull/2574 +* Add Skipper Unit Test In BasicBasicAuthConfig and Add More Detail Explanation regarding BasicAuthValidator by @RyoKusnadi in https://github.com/labstack/echo/pull/2461 +* fix some typos by @teslaedison in https://github.com/labstack/echo/pull/2603 +* fix: some typos by @pomadev in https://github.com/labstack/echo/pull/2596 +* Allow ResponseWriters to unwrap writers when flushing/hijacking by @aldas in https://github.com/labstack/echo/pull/2595 +* Add SPDX licence comments to files. by @aldas in https://github.com/labstack/echo/pull/2604 +* Upgrade deps by @aldas in https://github.com/labstack/echo/pull/2605 +* Change type definition blocks to single declarations. This helps copy… by @aldas in https://github.com/labstack/echo/pull/2606 +* Fix Real IP logic by @cl-bvl in https://github.com/labstack/echo/pull/2550 +* Default binder can use `UnmarshalParams(params []string) error` inter… by @aldas in https://github.com/labstack/echo/pull/2607 +* Default binder can bind pointer to slice as struct field. For example `*[]string` by @aldas in https://github.com/labstack/echo/pull/2608 +* Remove maxparam dependence from Context by @aldas in https://github.com/labstack/echo/pull/2611 +* When route is registered with empty path it is normalized to `/`. by @aldas in https://github.com/labstack/echo/pull/2616 +* proxy middleware should use httputil.ReverseProxy for SSE requests by @aldas in https://github.com/labstack/echo/pull/2624 + + +## v4.11.4 - 2023-12-20 + +**Security** + +* Upgrade golang.org/x/crypto to v0.17.0 to fix vulnerability [issue](https://pkg.go.dev/vuln/GO-2023-2402) [#2562](https://github.com/labstack/echo/pull/2562) + +**Enhancements** + +* Update deps and mark Go version to 1.18 as this is what golang.org/x/* use [#2563](https://github.com/labstack/echo/pull/2563) +* Request logger: add example for Slog https://pkg.go.dev/log/slog [#2543](https://github.com/labstack/echo/pull/2543) + + +## v4.11.3 - 2023-11-07 + +**Security** + +* 'c.Attachment' and 'c.Inline' should escape filename in 'Content-Disposition' header to avoid 'Reflect File Download' vulnerability. [#2541](https://github.com/labstack/echo/pull/2541) + +**Enhancements** + +* Tests: refactor context tests to be separate functions [#2540](https://github.com/labstack/echo/pull/2540) +* Proxy middleware: reuse echo request context [#2537](https://github.com/labstack/echo/pull/2537) +* Mark unmarshallable yaml struct tags as ignored [#2536](https://github.com/labstack/echo/pull/2536) + + +## v4.11.2 - 2023-10-11 + +**Security** + +* Bump golang.org/x/net to prevent CVE-2023-39325 / CVE-2023-44487 HTTP/2 Rapid Reset Attack [#2527](https://github.com/labstack/echo/pull/2527) +* fix(sec): randomString bias introduced by #2490 [#2492](https://github.com/labstack/echo/pull/2492) +* CSRF/RequestID mw: switch math/random usage to crypto/random [#2490](https://github.com/labstack/echo/pull/2490) + +**Enhancements** + +* Delete unused context in body_limit.go [#2483](https://github.com/labstack/echo/pull/2483) +* Use Go 1.21 in CI [#2505](https://github.com/labstack/echo/pull/2505) +* Fix some typos [#2511](https://github.com/labstack/echo/pull/2511) +* Allow CORS middleware to send Access-Control-Max-Age: 0 [#2518](https://github.com/labstack/echo/pull/2518) +* Bump dependancies [#2522](https://github.com/labstack/echo/pull/2522) + +## v4.11.1 - 2023-07-16 + +**Fixes** + +* Fix `Gzip` middleware not sending response code for no content responses (404, 301/302 redirects etc) [#2481](https://github.com/labstack/echo/pull/2481) + + +## v4.11.0 - 2023-07-14 + + +**Fixes** + +* Fixes the proxy middleware concurrency issue of calling the Next() proxy target on Round Robin Balancer [#2409](https://github.com/labstack/echo/pull/2409) +* Fix `group.RouteNotFound` not working when group has attached middlewares [#2411](https://github.com/labstack/echo/pull/2411) +* Fix global error handler return error message when message is an error [#2456](https://github.com/labstack/echo/pull/2456) +* Do not use global timeNow variables [#2477](https://github.com/labstack/echo/pull/2477) + + +**Enhancements** + +* Added a optional config variable to disable centralized error handler in recovery middleware [#2410](https://github.com/labstack/echo/pull/2410) +* refactor: use `strings.ReplaceAll` directly [#2424](https://github.com/labstack/echo/pull/2424) +* Add support for Go1.20 `http.rwUnwrapper` to Response struct [#2425](https://github.com/labstack/echo/pull/2425) +* Check whether is nil before invoking centralized error handling [#2429](https://github.com/labstack/echo/pull/2429) +* Proper colon support in `echo.Reverse` method [#2416](https://github.com/labstack/echo/pull/2416) +* Fix misuses of a vs an in documentation comments [#2436](https://github.com/labstack/echo/pull/2436) +* Add link to slog.Handler library for Echo logging into README.md [#2444](https://github.com/labstack/echo/pull/2444) +* In proxy middleware Support retries of failed proxy requests [#2414](https://github.com/labstack/echo/pull/2414) +* gofmt fixes to comments [#2452](https://github.com/labstack/echo/pull/2452) +* gzip response only if it exceeds a minimal length [#2267](https://github.com/labstack/echo/pull/2267) +* Upgrade packages [#2475](https://github.com/labstack/echo/pull/2475) + + +## v4.10.2 - 2023-02-22 + +**Security** + +* `filepath.Clean` behaviour has changed in Go 1.20 - adapt to it [#2406](https://github.com/labstack/echo/pull/2406) +* Add `middleware.CORSConfig.UnsafeWildcardOriginWithAllowCredentials` to make UNSAFE usages of wildcard origin + allow cretentials less likely [#2405](https://github.com/labstack/echo/pull/2405) + +**Enhancements** + +* Add more HTTP error values [#2277](https://github.com/labstack/echo/pull/2277) + + +## v4.10.1 - 2023-02-19 + +**Security** + +* Upgrade deps due to the latest golang.org/x/net vulnerability [#2402](https://github.com/labstack/echo/pull/2402) + + +**Enhancements** + +* Add new JWT repository to the README [#2377](https://github.com/labstack/echo/pull/2377) +* Return an empty string for ctx.path if there is no registered path [#2385](https://github.com/labstack/echo/pull/2385) +* Add context timeout middleware [#2380](https://github.com/labstack/echo/pull/2380) +* Update link to jaegertracing [#2394](https://github.com/labstack/echo/pull/2394) + + +## v4.10.0 - 2022-12-27 + +**Security** + +* We are deprecating JWT middleware in this repository. Please use https://github.com/labstack/echo-jwt instead. + + JWT middleware is moved to separate repository to allow us to bump/upgrade version of JWT implementation (`github.com/golang-jwt/jwt`) we are using +which we can not do in Echo core because this would break backwards compatibility guarantees we try to maintain. + +* This minor version bumps minimum Go version to 1.17 (from 1.16) due `golang.org/x/` packages we depend on. There are + several vulnerabilities fixed in these libraries. + + Echo still tries to support last 4 Go versions but there are occasions we can not guarantee this promise. + + +**Enhancements** + +* Bump x/text to 0.3.8 [#2305](https://github.com/labstack/echo/pull/2305) +* Bump dependencies and add notes about Go releases we support [#2336](https://github.com/labstack/echo/pull/2336) +* Add helper interface for ProxyBalancer interface [#2316](https://github.com/labstack/echo/pull/2316) +* Expose `middleware.CreateExtractors` function so we can use it from echo-contrib repository [#2338](https://github.com/labstack/echo/pull/2338) +* Refactor func(Context) error to HandlerFunc [#2315](https://github.com/labstack/echo/pull/2315) +* Improve function comments [#2329](https://github.com/labstack/echo/pull/2329) +* Add new method HTTPError.WithInternal [#2340](https://github.com/labstack/echo/pull/2340) +* Replace io/ioutil package usages [#2342](https://github.com/labstack/echo/pull/2342) +* Add staticcheck to CI flow [#2343](https://github.com/labstack/echo/pull/2343) +* Replace relative path determination from proprietary to std [#2345](https://github.com/labstack/echo/pull/2345) +* Remove square brackets from ipv6 addresses in XFF (X-Forwarded-For header) [#2182](https://github.com/labstack/echo/pull/2182) +* Add testcases for some BodyLimit middleware configuration options [#2350](https://github.com/labstack/echo/pull/2350) +* Additional configuration options for RequestLogger and Logger middleware [#2341](https://github.com/labstack/echo/pull/2341) +* Add route to request log [#2162](https://github.com/labstack/echo/pull/2162) +* GitHub Workflows security hardening [#2358](https://github.com/labstack/echo/pull/2358) +* Add govulncheck to CI and bump dependencies [#2362](https://github.com/labstack/echo/pull/2362) +* Fix rate limiter docs [#2366](https://github.com/labstack/echo/pull/2366) +* Refactor how `e.Routes()` work and introduce `e.OnAddRouteHandler` callback [#2337](https://github.com/labstack/echo/pull/2337) + + +## v4.9.1 - 2022-10-12 + +**Fixes** + +* Fix logger panicing (when template is set to empty) by bumping dependency version [#2295](https://github.com/labstack/echo/issues/2295) + +**Enhancements** + +* Improve CORS documentation [#2272](https://github.com/labstack/echo/pull/2272) +* Update readme about supported Go versions [#2291](https://github.com/labstack/echo/pull/2291) +* Tests: improve error handling on closing body [#2254](https://github.com/labstack/echo/pull/2254) +* Tests: refactor some of the assertions in tests [#2275](https://github.com/labstack/echo/pull/2275) +* Tests: refactor assertions [#2301](https://github.com/labstack/echo/pull/2301) + +## v4.9.0 - 2022-09-04 + +**Security** + +* Fix open redirect vulnerability in handlers serving static directories (e.Static, e.StaticFs, echo.StaticDirectoryHandler) [#2260](https://github.com/labstack/echo/pull/2260) + +**Enhancements** + +* Allow configuring ErrorHandler in CSRF middleware [#2257](https://github.com/labstack/echo/pull/2257) +* Replace HTTP method constants in tests with stdlib constants [#2247](https://github.com/labstack/echo/pull/2247) + + +## v4.8.0 - 2022-08-10 + +**Most notable things** + +You can now add any arbitrary HTTP method type as a route [#2237](https://github.com/labstack/echo/pull/2237) +```go +e.Add("COPY", "/*", func(c echo.Context) error + return c.String(http.StatusOK, "OK COPY") +}) +``` + +You can add custom 404 handler for specific paths [#2217](https://github.com/labstack/echo/pull/2217) +```go +e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) }) + +g := e.Group("/images") +g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) }) +``` + +**Enhancements** + +* Add new value binding methods (UnixTimeMilli,TextUnmarshaler,JSONUnmarshaler) to Valuebinder [#2127](https://github.com/labstack/echo/pull/2127) +* Refactor: body_limit middleware unit test [#2145](https://github.com/labstack/echo/pull/2145) +* Refactor: Timeout mw: rework how test waits for timeout. [#2187](https://github.com/labstack/echo/pull/2187) +* BasicAuth middleware returns 500 InternalServerError on invalid base64 strings but should return 400 [#2191](https://github.com/labstack/echo/pull/2191) +* Refactor: duplicated findStaticChild process at findChildWithLabel [#2176](https://github.com/labstack/echo/pull/2176) +* Allow different param names in different methods with same path scheme [#2209](https://github.com/labstack/echo/pull/2209) +* Add support for registering handlers for different 404 routes [#2217](https://github.com/labstack/echo/pull/2217) +* Middlewares should use errors.As() instead of type assertion on HTTPError [#2227](https://github.com/labstack/echo/pull/2227) +* Allow arbitrary HTTP method types to be added as routes [#2237](https://github.com/labstack/echo/pull/2237) + +## v4.7.2 - 2022-03-16 + +**Fixes** + +* Fix nil pointer exception when calling Start again after address binding error [#2131](https://github.com/labstack/echo/pull/2131) +* Fix CSRF middleware not being able to extract token from multipart/form-data form [#2136](https://github.com/labstack/echo/pull/2136) +* Fix Timeout middleware write race [#2126](https://github.com/labstack/echo/pull/2126) + +**Enhancements** + +* Recover middleware should not log panic for aborted handler [#2134](https://github.com/labstack/echo/pull/2134) + + +## v4.7.1 - 2022-03-13 + +**Fixes** + +* Fix `e.Static`, `.File()`, `c.Attachment()` being picky with paths starting with `./`, `../` and `/` after 4.7.0 introduced echo.Filesystem support (Go1.16+) [#2123](https://github.com/labstack/echo/pull/2123) + +**Enhancements** + +* Remove some unused code [#2116](https://github.com/labstack/echo/pull/2116) + + +## v4.7.0 - 2022-03-01 + +**Enhancements** + +* Add JWT, KeyAuth, CSRF multivalue extractors [#2060](https://github.com/labstack/echo/pull/2060) +* Add LogErrorFunc to recover middleware [#2072](https://github.com/labstack/echo/pull/2072) +* Add support for HEAD method query params binding [#2027](https://github.com/labstack/echo/pull/2027) +* Improve filesystem support with echo.FileFS, echo.StaticFS, group.FileFS, group.StaticFS [#2064](https://github.com/labstack/echo/pull/2064) + +**Fixes** + +* Fix X-Real-IP bug, improve tests [#2007](https://github.com/labstack/echo/pull/2007) +* Minor syntax fixes [#1994](https://github.com/labstack/echo/pull/1994), [#2102](https://github.com/labstack/echo/pull/2102), [#2102](https://github.com/labstack/echo/pull/2102) + +**General** + +* Add cache-control and connection headers [#2103](https://github.com/labstack/echo/pull/2103) +* Add Retry-After header constant [#2078](https://github.com/labstack/echo/pull/2078) +* Upgrade `go` directive in `go.mod` to 1.17 [#2049](https://github.com/labstack/echo/pull/2049) +* Add Pagoda [#2077](https://github.com/labstack/echo/pull/2077) and Souin [#2069](https://github.com/labstack/echo/pull/2069) to 3rd-party middlewares in README + +## v4.6.3 - 2022-01-10 + +**Fixes** + +* Fixed Echo version number in greeting message which was not incremented to `4.6.2` [#2066](https://github.com/labstack/echo/issues/2066) + + +## v4.6.2 - 2022-01-08 + +**Fixes** + +* Fixed route containing escaped colon should be matchable but is not matched to request path [#2047](https://github.com/labstack/echo/pull/2047) +* Fixed a problem that returned wrong content-encoding when the gzip compressed content was empty. [#1921](https://github.com/labstack/echo/pull/1921) +* Update (test) dependencies [#2021](https://github.com/labstack/echo/pull/2021) + + +**Enhancements** + +* Add support for configurable target header for the request_id middleware [#2040](https://github.com/labstack/echo/pull/2040) +* Change decompress middleware to use stream decompression instead of buffering [#2018](https://github.com/labstack/echo/pull/2018) +* Documentation updates + + +## v4.6.1 - 2021-09-26 + +**Enhancements** + +* Add start time to request logger middleware values [#1991](https://github.com/labstack/echo/pull/1991) + +## v4.6.0 - 2021-09-20 + +Introduced a new [request logger](https://github.com/labstack/echo/blob/master/middleware/request_logger.go) middleware +to help with cases when you want to use some other logging library in your application. + +**Fixes** + +* fix timeout middleware warning: superfluous response.WriteHeader [#1905](https://github.com/labstack/echo/issues/1905) + +**Enhancements** + +* Add Cookie to KeyAuth middleware's KeyLookup [#1929](https://github.com/labstack/echo/pull/1929) +* JWT middleware should ignore case of auth scheme in request header [#1951](https://github.com/labstack/echo/pull/1951) +* Refactor default error handler to return first if response is already committed [#1956](https://github.com/labstack/echo/pull/1956) +* Added request logger middleware which helps to use custom logger library for logging requests. [#1980](https://github.com/labstack/echo/pull/1980) +* Allow escaping of colon in route path so Google Cloud API "custom methods" could be implemented [#1988](https://github.com/labstack/echo/pull/1988) + +## v4.5.0 - 2021-08-01 + +**Important notes** + +A **BREAKING CHANGE** is introduced for JWT middleware users. +The JWT library used for the JWT middleware had to be changed from [github.com/dgrijalva/jwt-go](https://github.com/dgrijalva/jwt-go) to +[github.com/golang-jwt/jwt](https://github.com/golang-jwt/jwt) due former library being unmaintained and affected by security +issues. +The [github.com/golang-jwt/jwt](https://github.com/golang-jwt/jwt) project is a drop-in replacement, but supports only the latest 2 Go versions. +So for JWT middleware users Go 1.15+ is required. For detailed information please read [#1940](https://github.com/labstack/echo/discussions/) + +To change the library imports in all .go files in your project replace all occurrences of `dgrijalva/jwt-go` with `golang-jwt/jwt`. + +For Linux CLI you can use: +```bash +find -type f -name "*.go" -exec sed -i "s/dgrijalva\/jwt-go/golang-jwt\/jwt/g" {} \; +go mod tidy +``` + +**Fixes** + +* Change JWT library to `github.com/golang-jwt/jwt` [#1946](https://github.com/labstack/echo/pull/1946) + +## v4.4.0 - 2021-07-12 + +**Fixes** + +* Split HeaderXForwardedFor header only by comma [#1878](https://github.com/labstack/echo/pull/1878) +* Fix Timeout middleware Context propagation [#1910](https://github.com/labstack/echo/pull/1910) + +**Enhancements** + +* Bind data using headers as source [#1866](https://github.com/labstack/echo/pull/1866) +* Adds JWTConfig.ParseTokenFunc to JWT middleware to allow different libraries implementing JWT parsing. [#1887](https://github.com/labstack/echo/pull/1887) +* Adding tests for Echo#Host [#1895](https://github.com/labstack/echo/pull/1895) +* Adds RequestIDHandler function to RequestID middleware [#1898](https://github.com/labstack/echo/pull/1898) +* Allow for custom JSON encoding implementations [#1880](https://github.com/labstack/echo/pull/1880) + +## v4.3.0 - 2021-05-08 + +**Important notes** + +* Route matching has improvements for following cases: + 1. Correctly match routes with parameter part as last part of route (with trailing backslash) + 2. Considering handlers when resolving routes and search for matching http method handler +* Echo minimal Go version is now 1.13. + +**Fixes** + +* When url ends with slash first param route is the match [#1804](https://github.com/labstack/echo/pull/1812) +* Router should check if node is suitable as matching route by path+method and if not then continue search in tree [#1808](https://github.com/labstack/echo/issues/1808) +* Fix timeout middleware not writing response correctly when handler panics [#1864](https://github.com/labstack/echo/pull/1864) +* Fix binder not working with embedded pointer structs [#1861](https://github.com/labstack/echo/pull/1861) +* Add Go 1.16 to CI and drop 1.12 specific code [#1850](https://github.com/labstack/echo/pull/1850) + +**Enhancements** + +* Make KeyFunc public in JWT middleware [#1756](https://github.com/labstack/echo/pull/1756) +* Add support for optional filesystem to the static middleware [#1797](https://github.com/labstack/echo/pull/1797) +* Add a custom error handler to key-auth middleware [#1847](https://github.com/labstack/echo/pull/1847) +* Allow JWT token to be looked up from multiple sources [#1845](https://github.com/labstack/echo/pull/1845) + +## v4.2.2 - 2021-04-07 + +**Fixes** + +* Allow proxy middleware to use query part in rewrite (#1802) +* Fix timeout middleware not sending status code when handler returns an error (#1805) +* Fix Bind() when target is array/slice and path/query params complains bind target not being struct (#1835) +* Fix panic in redirect middleware on short host name (#1813) +* Fix timeout middleware docs (#1836) + +## v4.2.1 - 2021-03-08 + +**Important notes** + +Due to a datarace the config parameters for the newly added timeout middleware required a change. +See the [docs](https://echo.labstack.com/middleware/timeout). +A performance regression has been fixed, even bringing better performance than before for some routing scenarios. + +**Fixes** + +* Fix performance regression caused by path escaping (#1777, #1798, #1799, aldas) +* Avoid context canceled errors (#1789, clwluvw) +* Improve router to use on stack backtracking (#1791, aldas, stffabi) +* Fix panic in timeout middleware not being not recovered and cause application crash (#1794, aldas) +* Fix Echo.Serve() not serving on HTTP port correctly when TLSListener is used (#1785, #1793, aldas) +* Apply go fmt (#1788, Le0tk0k) +* Uses strings.Equalfold (#1790, rkilingr) +* Improve code quality (#1792, withshubh) + +This release was made possible by our **contributors**: +aldas, clwluvw, lammel, Le0tk0k, maciej-jezierski, rkilingr, stffabi, withshubh + +## v4.2.0 - 2021-02-11 + +**Important notes** + +The behaviour for binding data has been reworked for compatibility with echo before v4.1.11 by +enforcing `explicit tagging` for processing parameters. This **may break** your code if you +expect combined handling of query/path/form params. +Please see the updated documentation for [request](https://echo.labstack.com/guide/request) and [binding](https://echo.labstack.com/guide/request) + +The handling for rewrite rules has been slightly adjusted to expand `*` to a non-greedy `(.*?)` capture group. This is only relevant if multiple asterisks are used in your rules. +Please see [rewrite](https://echo.labstack.com/middleware/rewrite) and [proxy](https://echo.labstack.com/middleware/proxy) for details. + +**Security** + +* Fix directory traversal vulnerability for Windows (#1718, little-cui) +* Fix open redirect vulnerability with trailing slash (#1771,#1775 aldas,GeoffreyFrogeye) + +**Enhancements** + +* Add Echo#ListenerNetwork as configuration (#1667, pafuent) +* Add ability to change the status code using response beforeFuncs (#1706, RashadAnsari) +* Echo server startup to allow data race free access to listener address +* Binder: Restore pre v4.1.11 behaviour for c.Bind() to use query params only for GET or DELETE methods (#1727, aldas) +* Binder: Add separate methods to bind only query params, path params or request body (#1681, aldas) +* Binder: New fluent binder for query/path/form parameter binding (#1717, #1736, aldas) +* Router: Performance improvements for missed routes (#1689, pafuent) +* Router: Improve performance for Real-IP detection using IndexByte instead of Split (#1640, imxyb) +* Middleware: Support real regex rules for rewrite and proxy middleware (#1767) +* Middleware: New rate limiting middleware (#1724, iambenkay) +* Middleware: New timeout middleware implementation for go1.13+ (#1743, ) +* Middleware: Allow regex pattern for CORS middleware (#1623, KlotzAndrew) +* Middleware: Add IgnoreBase parameter to static middleware (#1701, lnenad, iambenkay) +* Middleware: Add an optional custom function to CORS middleware to validate origin (#1651, curvegrid) +* Middleware: Support form fields in JWT middleware (#1704, rkfg) +* Middleware: Use sync.Pool for (de)compress middleware to improve performance (#1699, #1672, pafuent) +* Middleware: Add decompress middleware to support gzip compressed requests (#1687, arun0009) +* Middleware: Add ErrJWTInvalid for JWT middleware (#1627, juanbelieni) +* Middleware: Add SameSite mode for CSRF cookies to support iframes (#1524, pr0head) + +**Fixes** + +* Fix handling of special trailing slash case for partial prefix (#1741, stffabi) +* Fix handling of static routes with trailing slash (#1747) +* Fix Static files route not working (#1671, pwli0755, lammel) +* Fix use of caret(^) in regex for rewrite middleware (#1588, chotow) +* Fix Echo#Reverse for Any type routes (#1695, pafuent) +* Fix Router#Find panic with infinite loop (#1661, pafuent) +* Fix Router#Find panic fails on Param paths (#1659, pafuent) +* Fix DefaultHTTPErrorHandler with Debug=true (#1477, lammel) +* Fix incorrect CORS headers (#1669, ulasakdeniz) +* Fix proxy middleware rewritePath to use url with updated tests (#1630, arun0009) +* Fix rewritePath for proxy middleware to use escaped path in (#1628, arun0009) +* Remove unless defer (#1656, imxyb) + +**General** + +* New maintainers for Echo: Roland Lammel (@lammel) and Pablo Andres Fuente (@pafuent) +* Add GitHub action to compare benchmarks (#1702, pafuent) +* Binding query/path params and form fields to struct only works for explicit tags (#1729,#1734, aldas) +* Add support for Go 1.15 in CI (#1683, asahasrabuddhe) +* Add test for request id to remain unchanged if provided (#1719, iambenkay) +* Refactor echo instance listener access and startup to speed up testing (#1735, aldas) +* Refactor and improve various tests for binding and routing +* Run test workflow only for relevant changes (#1637, #1636, pofl) +* Update .travis.yml (#1662, santosh653) +* Update README.md with an recents framework benchmark (#1679, pafuent) + +This release was made possible by **over 100 commits** from more than **20 contributors**: +asahasrabuddhe, aldas, AndrewKlotz, arun0009, chotow, curvegrid, iambenkay, imxyb, +juanbelieni, lammel, little-cui, lnenad, pafuent, pofl, pr0head, pwli, RashadAnsari, +rkfg, santosh653, segfiner, stffabi, ulasakdeniz diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..decbf0792 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,99 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## About This Project + +Echo is a high performance, minimalist Go web framework. This is the main repository for Echo v4, which is available as a Go module at `github.com/labstack/echo/v4`. + +## Development Commands + +The project uses a Makefile for common development tasks: + +- `make check` - Run linting, vetting, and race condition tests (default target) +- `make init` - Install required linting tools (golint, staticcheck) +- `make lint` - Run staticcheck and golint +- `make vet` - Run go vet +- `make test` - Run short tests +- `make race` - Run tests with race detector +- `make benchmark` - Run benchmarks + +Example commands for development: +```bash +# Setup development environment +make init + +# Run all checks (lint, vet, race) +make check + +# Run specific tests +go test ./middleware/... +go test -race ./... + +# Run benchmarks +make benchmark +``` + +## Code Architecture + +### Core Components + +**Echo Instance (`echo.go`)** +- The `Echo` struct is the top-level framework instance +- Contains router, middleware stacks, and server configuration +- Not goroutine-safe for mutations after server start + +**Context (`context.go`)** +- The `Context` interface represents HTTP request/response context +- Provides methods for request/response handling, path parameters, data binding +- Core abstraction for request processing + +**Router (`router.go`)** +- Radix tree-based HTTP router with smart route prioritization +- Supports static routes, parameterized routes (`/users/:id`), and wildcard routes (`/static/*`) +- Each HTTP method has its own routing tree + +**Middleware (`middleware/`)** +- Extensive middleware system with 50+ built-in middlewares +- Middleware can be applied at Echo, Group, or individual route level +- Common middleware: Logger, Recover, CORS, JWT, Rate Limiting, etc. + +### Key Patterns + +**Middleware Chain** +- Pre-middleware runs before routing +- Regular middleware runs after routing but before handlers +- Middleware functions have signature `func(next echo.HandlerFunc) echo.HandlerFunc` + +**Route Groups** +- Routes can be grouped with common prefixes and middleware +- Groups support nested sub-groups +- Defined in `group.go` + +**Data Binding** +- Automatic binding of request data (JSON, XML, form) to Go structs +- Implemented in `binder.go` with support for custom binders + +**Error Handling** +- Centralized error handling via `HTTPErrorHandler` +- Automatic panic recovery with stack traces + +## File Organization + +- Root directory: Core Echo functionality (echo.go, context.go, router.go, etc.) +- `middleware/`: All built-in middleware implementations +- `_test/`: Test fixtures and utilities +- `_fixture/`: Test data files + +## Code Style + +- Go code uses tabs for indentation (per .editorconfig) +- Follows standard Go conventions and formatting +- Uses gofmt, golint, and staticcheck for code quality + +## Testing + +- Standard Go testing with `testing` package +- Tests include unit tests, integration tests, and benchmarks +- Race condition testing is required (`make race`) +- Test files follow `*_test.go` naming convention \ No newline at end of file diff --git a/LICENSE b/LICENSE index b5b006b4e..2f18411bd 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2017 LabStack +Copyright (c) 2022 LabStack Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/Makefile b/Makefile index dfcb6c02b..bd075bbae 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,32 @@ -tag: - @git tag `grep -P '^\tversion = ' echo.go|cut -f2 -d'"'` - @git tag|grep -v ^v +PKG := "github.com/labstack/echo" +PKG_LIST := $(shell go list ${PKG}/...) + +.DEFAULT_GOAL := check +check: lint vet race ## Check project + +init: + @go install golang.org/x/lint/golint@latest + @go install honnef.co/go/tools/cmd/staticcheck@latest + +lint: ## Lint the files + @staticcheck ${PKG_LIST} + @golint -set_exit_status ${PKG_LIST} + +vet: ## Vet the files + @go vet ${PKG_LIST} + +test: ## Run tests + @go test -short ${PKG_LIST} + +race: ## Run tests with data race detector + @go test -race ${PKG_LIST} + +benchmark: ## Run benchmarks + @go test -run="-" -benchmem -bench=".*" ${PKG_LIST} + +help: ## Display this help screen + @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +goversion ?= "1.25" +test_version: ## Run tests inside Docker with given version (defaults to 1.25 oldest supported). Example: make test_version goversion=1.25 + @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" diff --git a/README.md b/README.md index 0da031225..ca6dfbf5d 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,24 @@ - - [![Sourcegraph](https://sourcegraph.com/github.com/labstack/echo/-/badge.svg?style=flat-square)](https://sourcegraph.com/github.com/labstack/echo?badge) -[![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](http://godoc.org/github.com/labstack/echo) +[![GoDoc](http://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://pkg.go.dev/github.com/labstack/echo/v4) [![Go Report Card](https://goreportcard.com/badge/github.com/labstack/echo?style=flat-square)](https://goreportcard.com/report/github.com/labstack/echo) -[![Build Status](http://img.shields.io/travis/labstack/echo.svg?style=flat-square)](https://travis-ci.org/labstack/echo) +[![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/labstack/echo/echo.yml?style=flat-square)](https://github.com/labstack/echo/actions) [![Codecov](https://img.shields.io/codecov/c/github/labstack/echo.svg?style=flat-square)](https://codecov.io/gh/labstack/echo) -[![Join the chat at https://gitter.im/labstack/echo](https://img.shields.io/badge/gitter-join%20chat-brightgreen.svg?style=flat-square)](https://gitter.im/labstack/echo) -[![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://forum.labstack.com) +[![Forum](https://img.shields.io/badge/community-forum-00afd1.svg?style=flat-square)](https://github.com/labstack/echo/discussions) [![Twitter](https://img.shields.io/badge/twitter-@labstack-55acee.svg?style=flat-square)](https://twitter.com/labstack) [![License](http://img.shields.io/badge/license-mit-blue.svg?style=flat-square)](https://raw.githubusercontent.com/labstack/echo/master/LICENSE) -## Supported Go versions +## Echo -As of version 4.0.0, Echo is available as a [Go module](https://github.com/golang/go/wiki/Modules). -Therefore a Go version capable of understanding /vN suffixed imports is required: +High performance, extensible, minimalist Go web framework. -- 1.9.7+ -- 1.10.3+ -- 1.11+ +* [Official website](https://echo.labstack.com) +* [Quick start](https://echo.labstack.com/docs/quick-start) +* [Middlewares](https://echo.labstack.com/docs/category/middleware) -Any of these versions will allow you to import Echo as `github.com/labstack/echo/v4` which is the recommended -way of using Echo going forward. +Help and questions: [Github Discussions](https://github.com/labstack/echo/discussions) -For older versions, please use the latest v3 tag. -## Feature Overview +### Feature Overview - Optimized HTTP router which smartly prioritize routes - Build robust and scalable RESTful APIs @@ -40,25 +34,48 @@ For older versions, please use the latest v3 tag. - Automatic TLS via Let’s Encrypt - HTTP/2 support -## Benchmarks +## Sponsors -Date: 2018/03/15
-Source: https://github.com/vishr/web-framework-benchmark
-Lower is better! +
+ + encore icon + Encore – the platform for building Go-based cloud backends + +
+
- +Click [here](https://github.com/sponsors/labstack) for more information on sponsorship. ## [Guide](https://echo.labstack.com/guide) +### Supported Echo versions + +- Latest major version of Echo is `v5` as of 2026-01-18. + - Until 2026-03-31, any critical issues requiring breaking API changes will be addressed, even if this violates semantic versioning. + - See [API_CHANGES_V5.md](./API_CHANGES_V5.md) for public API changes between `v4` and `v5`, notes on upgrading. + - If you are using Echo in a production environment, it is recommended to wait until after 2026-03-31 before upgrading. +- Echo `v4` is supported with **security*** updates and **bug** fixes until **2026-12-31** + + +### Installation + +```sh +// go get github.com/labstack/echo/{version} +go get github.com/labstack/echo/v5 +``` +Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions. + + ### Example ```go package main import ( + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" + "log/slog" "net/http" - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" ) func main() { @@ -66,26 +83,50 @@ func main() { e := echo.New() // Middleware - e.Use(middleware.Logger()) - e.Use(middleware.Recover()) + e.Use(middleware.RequestLogger()) // use the RequestLogger middleware with slog logger + e.Use(middleware.Recover()) // recover panics as errors for proper error handling // Routes e.GET("/", hello) // Start server - e.Logger.Fatal(e.Start(":1323")) + if err := e.Start(":8080"); err != nil { + slog.Error("failed to start server", "error", err) + } } // Handler -func hello(c echo.Context) error { +func hello(c *echo.Context) error { return c.String(http.StatusOK, "Hello, World!") } ``` -## Help +# Official middleware repositories + +Following list of middleware is maintained by Echo team. + +| Repository | Description | +|------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [github.com/labstack/echo-jwt](https://github.com/labstack/echo-jwt) | [JWT](https://github.com/golang-jwt/jwt) middleware | +| [github.com/labstack/echo-contrib](https://github.com/labstack/echo-contrib) | [casbin](https://github.com/casbin/casbin), [gorilla/sessions](https://github.com/gorilla/sessions), [jaegertracing](https://github.com/uber/jaeger-client-go), [prometheus](https://github.com/prometheus/client_golang/), [pprof](https://pkg.go.dev/net/http/pprof), [zipkin](https://github.com/openzipkin/zipkin-go) middlewares | + +# Third-party middleware repositories + +Be careful when adding 3rd party middleware. Echo teams does not have time or manpower to guarantee safety and quality +of middlewares in this list. + +| Repository | Description | +|------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [oapi-codegen/oapi-codegen](https://github.com/oapi-codegen/oapi-codegen) | Automatically generate RESTful API documentation with [OpenAPI](https://swagger.io/specification/) Client and Server Code Generator | +| [github.com/swaggo/echo-swagger](https://github.com/swaggo/echo-swagger) | Automatically generate RESTful API documentation with [Swagger](https://swagger.io/) 2.0. | +| [github.com/ziflex/lecho](https://github.com/ziflex/lecho) | [Zerolog](https://github.com/rs/zerolog) logging library wrapper for Echo logger interface. | +| [github.com/brpaz/echozap](https://github.com/brpaz/echozap) | Uber´s [Zap](https://github.com/uber-go/zap) logging library wrapper for Echo logger interface. | +| [github.com/samber/slog-echo](https://github.com/samber/slog-echo) | Go [slog](https://pkg.go.dev/golang.org/x/exp/slog) logging library wrapper for Echo logger interface. | +| [github.com/darkweak/souin/plugins/echo](https://github.com/darkweak/souin/tree/master/plugins/echo) | HTTP cache system based on [Souin](https://github.com/darkweak/souin) to automatically get your endpoints cached. It supports some distributed and non-distributed storage systems depending your needs. | +| [github.com/mikestefanello/pagoda](https://github.com/mikestefanello/pagoda) | Rapid, easy full-stack web development starter kit built with Echo. | +| [github.com/go-woo/protoc-gen-echo](https://github.com/go-woo/protoc-gen-echo) | ProtoBuf generate Echo server side code | -- [Forum](https://forum.labstack.com) -- [Chat](https://gitter.im/labstack/echo) +Please send a PR to add your own library here. ## Contribute @@ -104,8 +145,11 @@ func hello(c echo.Context) error { ## Credits -- [Vishal Rana](https://github.com/vishr) - Author -- [Nitin Rana](https://github.com/nr17) - Consultant +- [Vishal Rana](https://github.com/vishr) (Author) +- [Nitin Rana](https://github.com/nr17) (Consultant) +- [Roland Lammel](https://github.com/lammel) (Maintainer) +- [Martti T.](https://github.com/aldas) (Maintainer) +- [Pablo Andres Fuente](https://github.com/pafuent) (Maintainer) - [Contributors](https://github.com/labstack/echo/graphs/contributors) ## License diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..efb618697 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,15 @@ +# Security Policy + +## Supported Versions + +| Version | Supported | +|-----------|-------------------------------------| +| 5.x.x | :white_check_mark: | +| >= 4.15.x | :white_check_mark: until 2026.12.31 | +| < 4.15 | :x: | + +## Reporting a Vulnerability + +https://github.com/labstack/echo/security/advisories/new + +or look for maintainers email(s) in commits and email them. diff --git a/_fixture/_fixture/README.md b/_fixture/_fixture/README.md new file mode 100644 index 000000000..21a785851 --- /dev/null +++ b/_fixture/_fixture/README.md @@ -0,0 +1 @@ +This directory is used for the static middleware test \ No newline at end of file diff --git a/_fixture/certs/README.md b/_fixture/certs/README.md new file mode 100644 index 000000000..e27d4b139 --- /dev/null +++ b/_fixture/certs/README.md @@ -0,0 +1,13 @@ +To generate a valid certificate and private key use the following command: + +```bash +# In OpenSSL ≥ 1.1.1 +openssl req -x509 -newkey rsa:4096 -sha256 -days 9999 -nodes \ + -keyout key.pem -out cert.pem -subj "/CN=localhost" \ + -addext "subjectAltName=DNS:localhost,IP:127.0.0.1,IP:::1" +``` + +To check a certificate use the following command: +```bash +openssl x509 -in cert.pem -text +``` diff --git a/_fixture/certs/cert.pem b/_fixture/certs/cert.pem index c58f13fa6..d88cf3fec 100644 --- a/_fixture/certs/cert.pem +++ b/_fixture/certs/cert.pem @@ -1,18 +1,30 @@ -----BEGIN CERTIFICATE----- -MIIC+TCCAeGgAwIBAgIQe/dw9alKTWAPhsHoLdkn+TANBgkqhkiG9w0BAQsFADAS -MRAwDgYDVQQKEwdBY21lIENvMB4XDTE2MDkyNTAwNDcxN1oXDTE3MDkyNTAwNDcx -N1owEjEQMA4GA1UEChMHQWNtZSBDbzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC -AQoCggEBAL8WwhLGbK8HkiEDKV0JbjtWp3/EWKhKFW3YtKtPfPOgoZejdNn9VE0B -IlQ4rwa1wmsM9NDKC0m60oiNOYeyugx9PoFI3RXzuKVX2x7E5LTW0sv0LC9PCggZ -MZTih1AiYtwJIZl+aK6s4dTb/PUOLDdcRTZTF2egkdAicbUlQT4Kn+A3jHiE+ATC -h3MlV2BHarhAhWb0FrOg2bEtFrMyFDaLbHI7xbj+vB9CkGB9L5tObP2M9lQCxH8d -ElWkJjxg7vdkhJ5+sWNaY80utNipUdVO845tIERwRXRRviFYpOcuNfnJYC9kwRjv -CRanh3epWhG0cFQVV5d45sHf6t5F+jsCAwEAAaNLMEkwDgYDVR0PAQH/BAQDAgWg -MBMGA1UdJQQMMAoGCCsGAQUFBwMBMAwGA1UdEwEB/wQCMAAwFAYDVR0RBA0wC4IJ -bG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IBAQAdd3ZW6R4cImmxIzfoz7Ttq862 -oOiyzFnisCxgNdA78epit49zg0CgF7q9guTEArXJLI+/qnjPPObPOlTlsEyomb2F -UOS+2hn/ZyU5/tUxhkeOBYqdEaryk6zF6vPLUJ5IphJgOg00uIQGL0UvupBLEyIG -Rsa/lKEtW5Z9PbIi9GeVn51U+9VMCYft/T7SDziKl7OcE/qoVh1G0/tTRkAqOqpZ -bzc8ssEhJVNZ/DO+uYHNYf/waB6NjfXQuTegU/SyxnawvQ4oBHIzyuWplGCcTlfT -IXsOQdJo2xuu8807d+rO1FpN8yWi5OF/0sif0RrocSskLAIL/PI1qfWuuPck +MIIFODCCAyCgAwIBAgIUaTvDluaMf+VJgYHQ0HFTS3yuCHYwDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIxMDIyNzIxMzQ0MVoXDTQ4MDcx +NDIxMzQ0MVowFDESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEF +AAOCAg8AMIICCgKCAgEAnqyyAAnWFH2TH7Epj5yfZxYrBvizydZe1Wo/1WpGR2IK +QT+qIul5sEKX/ERqEOXsawSrL3fw9cuSM8Z2vD/57ZZdoSR7XIdVaMDEQenJ968a +HObu4D27uBQwIwrM5ELgnd+fC4gis64nIu+2GSfHumZXi7lLW7DbNm8oWkMqI6tY +2s2wx2hwGYNVJrwSn4WGnkzhQ5U5mkcsLELMx7GR0Qnv6P7sNGZVeqMU7awkcSpR +crKR1OUP7XCJkEq83WLHSx50+QZv7LiyDmGnujHevRbdSHlcFfHZtaufYat+qICe +S3XADwRQe/0VSsmja6u3DAHy7VmL8PNisAdkopQZrhiI9OvGrpGZffs9zn+s/jeX +N1bqVDihCMiEjqXMlHx2oj3AXrZTFxb7y7Ap9C07nf70lpxQWW9SjMYRF98JBiHF +eJbQkNVkmz6T8ielQbX0l46F2SGK98oyFCGNIAZBUdj5CcS1E6w/lk4t58/em0k7 +3wFC5qg0g0wfIbNSmxljBNxnaBYUqyaaAJJhpaEoOebm4RYV58hQ0FbMfpnLnSh4 +dYStsk6i1PumWoa7D45DTtxF3kH7TB3YOB5aWaNGAPQC1m4Qcd23YB5Rd/ABirSp +ux6/cFGosjSfJ/G+G0RhNUpmcbDJvFSOhD2WCuieVhCTAzp+VPIA9bSqD+InlT0C +AwEAAaOBgTB/MB0GA1UdDgQWBBQZyM//SvzYKokQZI/0MVGb6PkH+zAfBgNVHSME +GDAWgBQZyM//SvzYKokQZI/0MVGb6PkH+zAPBgNVHRMBAf8EBTADAQH/MCwGA1Ud +EQQlMCOCCWxvY2FsaG9zdIcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG +9w0BAQsFAAOCAgEAKGAJQmQ/KLw8iMb5QsyxxAonVjJ1eDAhNM3GWdHpM0/GFamO +vVtATLQQldwDiZJvrsCQPEc8ctZ2Utvg/StLQ3+rZpsvt0+gcUlLJK61qguwYqb2 ++T7VK5s7V/OyI/tsuboOW50Pka9vQHV+Z0aM06Yu+HNDAq/UTpEOb/3MQvZd6Ooy +PTpZtFb/+5jIQa1dIsfFWmpBxF0+wUd9GEkX3j7nekwoZfJ8Ze4GWYERZbOFpDAQ +rIHdthH5VJztnpQJmaKqzgIOF+Rurwlp5ecSC33xNNjDaYtuf/fiWnoKGhHVSBhT +61+0yxn3rTgh/Dsm95xY00rSX6lmcvI+kRNTUc8GGPz0ajBH6xyY7bNhfMjmnSW/ +C/XTEDbTAhT7ndWC5vvzp7ZU0TvN+WY6A0f2kxSnnrEk6QRUvRtKkjAkmAFz8exi +ttBBW0I3E5HNIC5CYRimq/9z+3clM/P1KbNblwuC65bL+PZ+nzFnn5hFaK9eLPol +OwZQXv7IvAw8GfgLTrEUT7eBCQwe1IqesA7NTxF1BVwmNUb2XamvQZ7ly67QybRw +0uJq80XjpVjBWYTTQy1dsnC2OTKdqGsV9TVIDR+UGfIG9cxL70pEbiSH2AX+IDCy +i3kNIvpXgBliAyOjW6Hj1fv6dNfAat/hqEfnquWkfvcs3HNrG/InwpwNAUs= -----END CERTIFICATE----- diff --git a/_fixture/certs/key.pem b/_fixture/certs/key.pem index 9c75e7ca8..0276c224e 100644 --- a/_fixture/certs/key.pem +++ b/_fixture/certs/key.pem @@ -1,27 +1,52 @@ ------BEGIN RSA PRIVATE KEY----- -MIIEpAIBAAKCAQEAvxbCEsZsrweSIQMpXQluO1anf8RYqEoVbdi0q09886Chl6N0 -2f1UTQEiVDivBrXCawz00MoLSbrSiI05h7K6DH0+gUjdFfO4pVfbHsTktNbSy/Qs -L08KCBkxlOKHUCJi3AkhmX5orqzh1Nv89Q4sN1xFNlMXZ6CR0CJxtSVBPgqf4DeM -eIT4BMKHcyVXYEdquECFZvQWs6DZsS0WszIUNotscjvFuP68H0KQYH0vm05s/Yz2 -VALEfx0SVaQmPGDu92SEnn6xY1pjzS602KlR1U7zjm0gRHBFdFG+IVik5y41+clg -L2TBGO8JFqeHd6laEbRwVBVXl3jmwd/q3kX6OwIDAQABAoIBAQCR69EcAUZxinh+ -mSl3EIKK8atLGCcTrC8dCQU+ZJ7odFuxrnLHHHrJqvoKEpclqprioKw63G8uSGoJ -OL8b7tHAQ8v9ciTSZKE2Mhb0MirsJbgnYzhykAr7EDIanbny6a9Qk/CChFNwQDjc -EXnjsIT3aZC44U7YJXfz1rm6OM7Pjn6z8H4vYGRDOsYkhXvPfnPW8C2LFJVr9nvE -0gIAOVoGejEJrsJVK3Uj/nPcqSQYXmwEmtjtzOw7u6yp1b2VZEK7tR47HwJt6ltG -Z9zhpwhpvdOuXNMqMOYRf9bLBWnSqIlTHOO0UlAnyRCY1HxluZB7ZSg9VnoJDrD7 -w+JqAGnBAoGBAO5qyIzjldwR004YjepmZfuX3PnGLZhzhmTTC7Pl9gqv1TvxfxvD -6yBFL2GrN1IcnrX9Qk2xncUAbpM989MF+EC7I4++1t1I6akUKFEDkfvQwQjCXfPS -Jv2rkwIVSkt8F0X/tOb13OeIiHuFVI/Bb9VoJSP/k4DfPV+/HnwBxvzLAoGBAM0u -b/rYfm5rb20/PKClUs154s0eKSokVogqiJkf+5qLsV+TD50JVZBVw8s4XM79iwQI -PyGY9nI1AvqG7yIzxSy5/Qk1+ZVdVYpmWIO5PnJ8TVraDVhCQ3fVz1uWtcyaqPVr -3QzdyvsEgFUGFItmRdhSvA8RGrpVCHTBzrDj3jpRAoGBAKNaSLS3jkstb3D3w+yR -YliisYX1cfIdXTyhmUgWTKD/3oLmsSdt8iC3JoKt1AaPk3Kv5ojjJG0BIcIC1ZeF -ZJW9Yt0vbXpKZcYyCHmRj6lQW6JLwiG3oH133A62VaQojq2oSONiG4wL8S9oqAqj -B6PZanEiwIaw7hU3FoTylstHAoGAFYvE0pCdZjb98njrgusZcN5VxLhgFj7On2no -AjxrjWUR8TleMF1kkM2Qy+xVQp85U+kRyBNp/cA3WduFjQ/mqrW1LpxuYxL0Ap6Q -uPRg7GDFNr8jG5uJvjHDnpiK6rtq9qqnAczgnc9xMnx699B7kSXO/b4MEnkPdENN -0yF6mqECgYA88UELxbhqMSdG24DX0zHXvkXLIml2JNVb54glFByIIem+acff9oG9 -X5GajlBroPoKk7FgA9ouqcQMH66UnFi6qh07l0J2xb0aXP8yzLAGauVGTTNIQCR4 -VpqyDpjlc1ZqfZWOrvwSrUH1mEkxbeVvQsOUja2Jvu+lc3Zo099ILw== ------END RSA PRIVATE KEY----- +-----BEGIN PRIVATE KEY----- +MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQCerLIACdYUfZMf +sSmPnJ9nFisG+LPJ1l7Vaj/VakZHYgpBP6oi6XmwQpf8RGoQ5exrBKsvd/D1y5Iz +xna8P/ntll2hJHtch1VowMRB6cn3rxoc5u7gPbu4FDAjCszkQuCd358LiCKzrici +77YZJ8e6ZleLuUtbsNs2byhaQyojq1jazbDHaHAZg1UmvBKfhYaeTOFDlTmaRyws +QszHsZHRCe/o/uw0ZlV6oxTtrCRxKlFyspHU5Q/tcImQSrzdYsdLHnT5Bm/suLIO +Yae6Md69Ft1IeVwV8dm1q59hq36ogJ5LdcAPBFB7/RVKyaNrq7cMAfLtWYvw82Kw +B2SilBmuGIj068aukZl9+z3Of6z+N5c3VupUOKEIyISOpcyUfHaiPcBetlMXFvvL +sCn0LTud/vSWnFBZb1KMxhEX3wkGIcV4ltCQ1WSbPpPyJ6VBtfSXjoXZIYr3yjIU +IY0gBkFR2PkJxLUTrD+WTi3nz96bSTvfAULmqDSDTB8hs1KbGWME3GdoFhSrJpoA +kmGloSg55ubhFhXnyFDQVsx+mcudKHh1hK2yTqLU+6ZahrsPjkNO3EXeQftMHdg4 +HlpZo0YA9ALWbhBx3bdgHlF38AGKtKm7Hr9wUaiyNJ8n8b4bRGE1SmZxsMm8VI6E +PZYK6J5WEJMDOn5U8gD1tKoP4ieVPQIDAQABAoICAEHF2CsH6MOpofi7GT08cR7s +I33KTcxWngzc9ATk/qjMTO/rEf1Sxmx3zkR1n3nNtQhPcR5GG43nin0HwWQbKOCB +OeJ4GuKp/o9jiHbCEEQpQyvD1jUBofSV+bYs3e2ogy8t6OGA1tGgWPy0XMlkoff0 +QEnczw3864FO5m0z9h2/Ax//r02ZTw5kUEG0KAwT709jEuVO0AfRhM/8CKKmSola +EyaDtSmrWbdyLlSuzJRUNFrVBno3UTjdM0iqkks6jN3ojBhFwNNhY/1uIXafAXNk +LOnD1JYMIHCb6X809VWnqvYgozIWWb5rlA3iM2mITmId1LLqMYX5fWj2R5LUzSek +H+XG+F9FIouTaL1ACoXr0zyeY5N5YJdyXYa1tThdW+axX9ZrnPgeiQrmxzKPIyb7 +LLlVtNBQUg/t5tX80KyYjkNUu4j3oq/uBYPi0m//ovwMyi9bSbbyPT+cDXuXX5Bc +oY7wyn3evXX0c1R7vdJLZLkLu+ctVex/9hvMjeW/mMasDjLnqY7pF3Skct1SX5N2 +U8YVU9bGvFpLEwM9lmi/T7bcv+zbmGPlfTsZiFrCsixPLn7sX7y5M4L8au8O0jh0 +nHm/8rWVg1Qw0Hobg3tA8FjeMa8Sr2fYmkNLVKFzhuJLxknTJLaUbX5CymNqWP4H +OctvfSY0nSZ1eQpBkQaJAoIBAQDTb/NhYCfaJBLXHVMy/VYd7kWGZ+I87artcE/l +8u0pJ8XOP4kp0otFIumpHUFodysAeP6HrI79MuJB40fy91HzWZC+NrPufFFFuZ0z +Ld1o3Y5nAeoZmMlf1F12Oe3OQZy7nm9eNNkfeoVtKqDv4FhAqk+aoMor86HscKsR +C6HlZFdGc7kX0ylrQAXPq9KLhcvUU9oAUpbqTbhYK83IebRJgFDG45HkVo9SUHpF +dmCFSb91eZpRGpdfNLCuLiSu52TebayaUCnceeAt8SyeiChJ/TwWmRRDJS0QUv6h +s3Wdp+cx9ANoujA4XzAs8Fld5IZ4bcG5jjwD62/tJyWrCC5DAoIBAQDAHfHjrYCK +GHBrMj+MA7cK7fCJUn/iJLSLGgo2ANYF5oq9gaCwHCtKIyB9DN/KiY0JpJ6PWg+Q +9Difq23YXiJjNEBS5EFTu9UwWAr1RhSAegrfHxm0sDbcAx31NtDYvBsADCWQYmzc +KPfBshf5K4g/VCIj2VzC2CE6kNtdhqLU6AV2Pi1Tl1S82xWoAjHy91tDmlFQNWCj +B2ZnZ7tY9zuwDfeBBOVCPHICgl5Q4PrY1KEWEXiNxgbtkNmOPAsY9WSqgOsP9pWK +J924gdCCvovINzZtgRisxKth6Fkhra+VCsheg9SWvgR09Deo6CCoSwYxOSb0cjh2 +oyX5Rb1kJ7Z/AoIBAQCX2iNVoBV/GcFeNXV3fXLH9ESCj0FwuNC1zp/TanDhyerK +gd8k5k2Xzcc66gP73vpHUJ6dGlVni4/r+ivGV9HHkF/f/LGlaiuEhBZel2YY1mZb +nIhg8dZOuNqW+mvMYlsKdHNPmW0GqpwBF0iWfu1jI+4gA7Kvdj6o7RIvH8eaVEJK +GvqoHcP1fvmteJ2yDtmhGMfMy4QPqtnmmS8l+CJ/V2SsMuyorXIpkBsAoFAZ6ilT +WY53CT4F5nWt4v39j7pl9SatfT1TV0SmOjvtb6Rf3zu0jyR6RMzkmHa/839ZRylI +OxPntzDCi7qxy7yjLmlVPJ6RgZGgzwqHrEHlX+65AoIBAQCEzu6d3x5B2N02LZli +eFr8MjqbI64GLiulEY5HgNJzZ8k3cjocJI0Ehj36VIEMaYRXSzbVkIO8SCgwsPiR +n5mUDNX+t441jV62Odbxcc3Qdw226rABieOSupDmKEu92GOt57e8FV5939BOVYhf +FunsJYQoViXbCEAIVYVgJSfBmNfVwuvgonfQyn8xErtm4/pyRGa71PqGGSKAj2Qi +/16CuVUFGtZFsLV76JW8wZqHdI4bTF6TW3cEmaLbwcRGL7W0bMSS13rO8/pBh3QW +PhUxhoGYt6rQHHEBkPa04nXDyZ10QRwgTSGVnBIyMK4KyTpxorm8OI2x7dzdcomX +iCCPAoIBAETwfr2JKPb/AzrKhhbZgU+sLVn3WH/nb68VheNEmGOzsqXaSHCR2NOq +/ow7bawjc8yUIhBRzokR4F/7jGolOmfdq0MYFb6/YokssKfv1ugxBhmvOxpZ6F6E +cERJ8Ex/ffQU053gLR/0ammddVuS1GR5I/jEdP0lJVh0xapoZNUlT5dWYCgo20hY +ZAmKpU+veyUn+5Li0pmm959vnLK5LJzEA5mpz3w1QPPtVwQs05dwmEV3CRAcCeeh +8sXp49WNCSW4I3BxuTZzRV845SGIFhZwgVV42PTp2LPKl2p6E7Bk8xpUCCvBpALp +QmA5yIMx+u2Jpr7fUsXEXEPTEhvjff0= +-----END PRIVATE KEY----- diff --git a/_fixture/dist/private.txt b/_fixture/dist/private.txt new file mode 100644 index 000000000..0f9d2435b --- /dev/null +++ b/_fixture/dist/private.txt @@ -0,0 +1 @@ +private file diff --git a/_fixture/dist/public/assets/readme.md b/_fixture/dist/public/assets/readme.md new file mode 100644 index 000000000..50590f554 --- /dev/null +++ b/_fixture/dist/public/assets/readme.md @@ -0,0 +1 @@ +readme in assets diff --git a/_fixture/dist/public/assets/subfolder/subfolder.md b/_fixture/dist/public/assets/subfolder/subfolder.md new file mode 100644 index 000000000..74c928b2f --- /dev/null +++ b/_fixture/dist/public/assets/subfolder/subfolder.md @@ -0,0 +1 @@ +file inside subfolder diff --git a/_fixture/dist/public/index.html b/_fixture/dist/public/index.html new file mode 100644 index 000000000..df6d9015a --- /dev/null +++ b/_fixture/dist/public/index.html @@ -0,0 +1 @@ +

Hello from index

diff --git a/_fixture/dist/public/test.txt b/_fixture/dist/public/test.txt new file mode 100644 index 000000000..dd937160d --- /dev/null +++ b/_fixture/dist/public/test.txt @@ -0,0 +1 @@ +test.txt contents diff --git a/bind.go b/bind.go index c8c88bb20..050e8973b 100644 --- a/bind.go +++ b/bind.go @@ -1,137 +1,237 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( "encoding" - "encoding/json" "encoding/xml" "errors" - "fmt" + "mime/multipart" "net/http" "reflect" "strconv" "strings" + "time" ) -type ( - // Binder is the interface that wraps the Bind method. - Binder interface { - Bind(i interface{}, c Context) error - } +// Binder is the interface that wraps the Bind method. +type Binder interface { + Bind(c *Context, target any) error +} - // DefaultBinder is the default implementation of the Binder interface. - DefaultBinder struct{} +// DefaultBinder is the default implementation of the Binder interface. +type DefaultBinder struct{} - // BindUnmarshaler is the interface used to wrap the UnmarshalParam method. - // Types that don't implement this, but do implement encoding.TextUnmarshaler - // will use that interface instead. - BindUnmarshaler interface { - // UnmarshalParam decodes and assigns a value from an form or query param. - UnmarshalParam(param string) error - } -) +// BindUnmarshaler is the interface used to wrap the UnmarshalParam method. +// Types that don't implement this, but do implement encoding.TextUnmarshaler +// will use that interface instead. +type BindUnmarshaler interface { + // UnmarshalParam decodes and assigns a value from an form or query param. + UnmarshalParam(param string) error +} -// Bind implements the `Binder#Bind` function. -func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { - req := c.Request() +// bindMultipleUnmarshaler is used by binder to unmarshal multiple values from request at once to +// type implementing this interface. For example request could have multiple query fields `?a=1&a=2&b=test` in that case +// for `a` following slice `["1", "2"] will be passed to unmarshaller. +type bindMultipleUnmarshaler interface { + UnmarshalParams(params []string) error +} - names := c.ParamNames() - values := c.ParamValues() +// BindPathValues binds path parameter values to bindable object +func BindPathValues(c *Context, target any) error { params := map[string][]string{} - for i, name := range names { - params[name] = []string{values[i]} + for _, param := range c.PathValues() { + params[param.Name] = []string{param.Value} } - if err := b.bindData(i, params, "param"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + if err := bindData(target, params, "param", nil); err != nil { + return ErrBadRequest.Wrap(err) } - if err = b.bindData(i, c.QueryParams(), "query"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return nil +} + +// BindQueryParams binds query params to bindable object +func BindQueryParams(c *Context, target any) error { + if err := bindData(target, c.QueryParams(), "query", nil); err != nil { + return ErrBadRequest.Wrap(err) } + return nil +} + +// BindBody binds request body contents to bindable object +// NB: then binding forms take note that this implementation uses standard library form parsing +// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm +// See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm +// See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm +func BindBody(c *Context, target any) (err error) { + req := c.Request() if req.ContentLength == 0 { return } - ctype := req.Header.Get(HeaderContentType) - switch { - case strings.HasPrefix(ctype, MIMEApplicationJSON): - if err = json.NewDecoder(req.Body).Decode(i); err != nil { - if ute, ok := err.(*json.UnmarshalTypeError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err) - } else if se, ok := err.(*json.SyntaxError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err) + + // mediatype is found like `mime.ParseMediaType()` does it + base, _, _ := strings.Cut(req.Header.Get(HeaderContentType), ";") + mediatype := strings.TrimSpace(base) + + switch mediatype { + case MIMEApplicationJSON: + if err = c.Echo().JSONSerializer.Deserialize(c, target); err != nil { + var hErr *HTTPError + if errors.As(err, &hErr) { + return err } - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return ErrBadRequest.Wrap(err) } - case strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML): - if err = xml.NewDecoder(req.Body).Decode(i); err != nil { - if ute, ok := err.(*xml.UnsupportedTypeError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err) - } else if se, ok := err.(*xml.SyntaxError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err) - } - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + case MIMEApplicationXML, MIMETextXML: + if err = xml.NewDecoder(req.Body).Decode(target); err != nil { + return ErrBadRequest.Wrap(err) + } + case MIMEApplicationForm: + params, err := c.FormValues() + if err != nil { + return ErrBadRequest.Wrap(err) + } + if err = bindData(target, params, "form", nil); err != nil { + return ErrBadRequest.Wrap(err) } - case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): - params, err := c.FormParams() + case MIMEMultipartForm: + params, err := c.MultipartForm() if err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return ErrBadRequest.Wrap(err) } - if err = b.bindData(i, params, "form"); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + if err = bindData(target, params.Value, "form", params.File); err != nil { + return ErrBadRequest.Wrap(err) } default: - return ErrUnsupportedMediaType + return &HTTPError{Code: http.StatusUnsupportedMediaType} } - return + return nil +} + +// BindHeaders binds HTTP headers to a bindable object +func BindHeaders(c *Context, target any) error { + if err := bindData(target, c.Request().Header, "header", nil); err != nil { + return ErrBadRequest.Wrap(err) + } + return nil +} + +// Bind implements the `Binder#Bind` function. +// Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous +// step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathValues. +func (b *DefaultBinder) Bind(c *Context, target any) error { + if err := BindPathValues(c, target); err != nil { + return err + } + // Only bind query parameters for GET/DELETE/HEAD to avoid unexpected behavior with destination struct binding from body. + // For example a request URL `&id=1&lang=en` with body `{"id":100,"lang":"de"}` would lead to precedence issues. + // The HTTP method check restores pre-v4.1.11 behavior to avoid these problems (see issue #1670) + method := c.Request().Method + if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead { + if err := BindQueryParams(c, target); err != nil { + return err + } + } + return BindBody(c, target) } -func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag string) error { - if ptr == nil || len(data) == 0 { +// bindData will bind data ONLY fields in destination struct that have EXPLICIT tag +func bindData(destination any, data map[string][]string, tag string, dataFiles map[string][]*multipart.FileHeader) error { + if destination == nil || (len(data) == 0 && len(dataFiles) == 0) { return nil } - typ := reflect.TypeOf(ptr).Elem() - val := reflect.ValueOf(ptr).Elem() + hasFiles := len(dataFiles) > 0 + typ := reflect.TypeOf(destination).Elem() + val := reflect.ValueOf(destination).Elem() - // Map - if typ.Kind() == reflect.Map { + // Support binding to limited Map destinations: + // - map[string][]string, + // - map[string]string <-- (binds first value from data slice) + // - map[string]any + // You are better off binding to struct but there are user who want this map feature. Source of data for these cases are: + // params,query,header,form as these sources produce string values, most of the time slice of strings, actually. + if typ.Kind() == reflect.Map && typ.Key().Kind() == reflect.String { + k := typ.Elem().Kind() + isElemInterface := k == reflect.Interface + isElemString := k == reflect.String + isElemSliceOfStrings := k == reflect.Slice && typ.Elem().Elem().Kind() == reflect.String + if !(isElemSliceOfStrings || isElemString || isElemInterface) { + return nil + } + if val.IsNil() { + val.Set(reflect.MakeMap(typ)) + } for k, v := range data { - val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0])) + if isElemString { + val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0])) + } else if isElemInterface { + // To maintain backward compatibility, we always bind to the first string value + // and not the slice of strings when dealing with map[string]any{} + val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0])) + } else { + val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v)) + } } return nil } // !struct if typ.Kind() != reflect.Struct { + if tag == "param" || tag == "query" || tag == "header" { + // incompatible type, data is probably to be found in the body + return nil + } return errors.New("binding element must be a struct") } - for i := 0; i < typ.NumField(); i++ { + for i := 0; i < typ.NumField(); i++ { // iterate over all destination fields typeField := typ.Field(i) structField := val.Field(i) + if typeField.Anonymous { + if structField.Kind() == reflect.Ptr { + structField = structField.Elem() + } + } if !structField.CanSet() { continue } structFieldKind := structField.Kind() inputFieldName := typeField.Tag.Get(tag) + if typeField.Anonymous && structFieldKind == reflect.Struct && inputFieldName != "" { + // if anonymous struct with query/param/form tags, report an error + return errors.New("query/param/form tags are not allowed with anonymous struct field") + } if inputFieldName == "" { - inputFieldName = typeField.Name - // If tag is nil, we inspect if the field is a struct. - if _, ok := bindUnmarshaler(structField); !ok && structFieldKind == reflect.Struct { - if err := b.bindData(structField.Addr().Interface(), data, tag); err != nil { + // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contain fields with tags). + // structs that implement BindUnmarshaler are bound only when they have explicit tag + if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { + if err := bindData(structField.Addr().Interface(), data, tag, dataFiles); err != nil { return err } - continue + } + // does not have explicit tag and is not an ordinary struct - so move to next field + continue + } + + if hasFiles { + if ok, err := isFieldMultipartFile(structField.Type()); err != nil { + return err + } else if ok { + if ok := setMultipartFileHeaderTypes(structField, inputFieldName, dataFiles); ok { + continue + } } } inputValue, exists := data[inputFieldName] if !exists { - // Go json.Unmarshal supports case insensitive binding. However the - // url params are bound case sensitive which is inconsistent. To + // Go json.Unmarshal supports case-insensitive binding. However the + // url params are bound case-sensitive which is inconsistent. To // fix this we must check all of the map values in a // case-insensitive search. - inputFieldName = strings.ToLower(inputFieldName) for k, v := range data { - if strings.ToLower(k) == inputFieldName { + if strings.EqualFold(k, inputFieldName) { inputValue = v exists = true break @@ -143,27 +243,47 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag continue } - // Call this first, in case we're dealing with an alias to an array type - if ok, err := unmarshalField(typeField.Type.Kind(), inputValue[0], structField); ok { + // NOTE: algorithm here is not particularly sophisticated. It probably does not work with absurd types like `**[]*int` + // but it is smart enough to handle niche cases like `*int`,`*[]string`,`[]*int` . + + // try unmarshalling first, in case we're dealing with an alias to an array type + if ok, err := unmarshalInputsToField(typeField.Type.Kind(), inputValue, structField); ok { + if err != nil { + return err + } + continue + } + + formatTag := typeField.Tag.Get("format") + if ok, err := unmarshalInputToField(typeField.Type.Kind(), inputValue[0], structField, formatTag); ok { if err != nil { return err } continue } - numElems := len(inputValue) - if structFieldKind == reflect.Slice && numElems > 0 { + // we could be dealing with pointer to slice `*[]string` so dereference it. There are weird OpenAPI generators + // that could create struct fields like that. + if structFieldKind == reflect.Pointer { + structFieldKind = structField.Elem().Kind() + structField = structField.Elem() + } + + if structFieldKind == reflect.Slice { sliceOf := structField.Type().Elem().Kind() + numElems := len(inputValue) slice := reflect.MakeSlice(structField.Type(), numElems, numElems) for j := 0; j < numElems; j++ { if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil { return err } } - val.Field(i).Set(slice) - } else if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil { - return err + structField.Set(slice) + continue + } + if err := setWithProperType(structFieldKind, inputValue[0], structField); err != nil { + return err } } return nil @@ -171,7 +291,8 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error { // But also call it here, in case we're dealing with an array of BindUnmarshalers - if ok, err := unmarshalField(valueKind, val, structField); ok { + // Note: format tag not available in this context, so empty string is passed + if ok, err := unmarshalInputToField(valueKind, val, structField, ""); ok { return err } @@ -212,62 +333,53 @@ func setWithProperType(valueKind reflect.Kind, val string, structField reflect.V return nil } -func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) { - switch valueKind { - case reflect.Ptr: - return unmarshalFieldPtr(val, field) - default: - return unmarshalFieldNonPtr(val, field) +func unmarshalInputsToField(valueKind reflect.Kind, values []string, field reflect.Value) (bool, error) { + if valueKind == reflect.Ptr { + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + field = field.Elem() } -} -// bindUnmarshaler attempts to unmarshal a reflect.Value into a BindUnmarshaler -func bindUnmarshaler(field reflect.Value) (BindUnmarshaler, bool) { - ptr := reflect.New(field.Type()) - if ptr.CanInterface() { - iface := ptr.Interface() - if unmarshaler, ok := iface.(BindUnmarshaler); ok { - return unmarshaler, ok - } + fieldIValue := field.Addr().Interface() + unmarshaler, ok := fieldIValue.(bindMultipleUnmarshaler) + if !ok { + return false, nil } - return nil, false + return true, unmarshaler.UnmarshalParams(values) } -// textUnmarshaler attempts to unmarshal a reflect.Value into a TextUnmarshaler -func textUnmarshaler(field reflect.Value) (encoding.TextUnmarshaler, bool) { - ptr := reflect.New(field.Type()) - if ptr.CanInterface() { - iface := ptr.Interface() - if unmarshaler, ok := iface.(encoding.TextUnmarshaler); ok { - return unmarshaler, ok +func unmarshalInputToField(valueKind reflect.Kind, val string, field reflect.Value, formatTag string) (bool, error) { + if valueKind == reflect.Ptr { + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) } + field = field.Elem() } - return nil, false -} -func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) { - if unmarshaler, ok := bindUnmarshaler(field); ok { - err := unmarshaler.UnmarshalParam(value) - field.Set(reflect.ValueOf(unmarshaler).Elem()) - return true, err + fieldIValue := field.Addr().Interface() + // Handle time.Time with custom format tag + if formatTag != "" { + if _, isTime := fieldIValue.(*time.Time); isTime { + t, err := time.Parse(formatTag, val) + if err != nil { + return true, err + } + field.Set(reflect.ValueOf(t)) + return true, nil + } } - if unmarshaler, ok := textUnmarshaler(field); ok { - err := unmarshaler.UnmarshalText([]byte(value)) - field.Set(reflect.ValueOf(unmarshaler).Elem()) - return true, err + + switch unmarshaler := fieldIValue.(type) { + case BindUnmarshaler: + return true, unmarshaler.UnmarshalParam(val) + case encoding.TextUnmarshaler: + return true, unmarshaler.UnmarshalText([]byte(val)) } return false, nil } -func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) { - if field.IsNil() { - // Initialize the pointer to a nil value - field.Set(reflect.New(field.Type().Elem())) - } - return unmarshalFieldNonPtr(value, field.Elem()) -} - func setIntField(value string, bitSize int, field reflect.Value) error { if value == "" { value = "0" @@ -311,3 +423,50 @@ func setFloatField(value string, bitSize int, field reflect.Value) error { } return err } + +var ( + // NOT supported by bind as you can NOT check easily empty struct being actual file or not + multipartFileHeaderType = reflect.TypeFor[multipart.FileHeader]() + // supported by bind as you can check by nil value if file existed or not + multipartFileHeaderPointerType = reflect.TypeFor[*multipart.FileHeader]() + multipartFileHeaderSliceType = reflect.TypeFor[[]multipart.FileHeader]() + multipartFileHeaderPointerSliceType = reflect.TypeFor[[]*multipart.FileHeader]() +) + +func isFieldMultipartFile(field reflect.Type) (bool, error) { + switch field { + case multipartFileHeaderPointerType, + multipartFileHeaderSliceType, + multipartFileHeaderPointerSliceType: + return true, nil + case multipartFileHeaderType: + return true, errors.New("binding to multipart.FileHeader struct is not supported, use pointer to struct") + default: + return false, nil + } +} + +func setMultipartFileHeaderTypes(structField reflect.Value, inputFieldName string, files map[string][]*multipart.FileHeader) bool { + fileHeaders := files[inputFieldName] + if len(fileHeaders) == 0 { + return false + } + + result := true + switch structField.Type() { + case multipartFileHeaderPointerSliceType: + structField.Set(reflect.ValueOf(fileHeaders)) + case multipartFileHeaderSliceType: + headers := make([]multipart.FileHeader, len(fileHeaders)) + for i, fileHeader := range fileHeaders { + headers[i] = *fileHeader + } + structField.Set(reflect.ValueOf(headers)) + case multipartFileHeaderPointerType: + structField.Set(reflect.ValueOf(fileHeaders[0])) + default: + result = false + } + + return result +} diff --git a/bind_test.go b/bind_test.go index 84ac8710e..1d5f8ca41 100644 --- a/bind_test.go +++ b/bind_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( @@ -5,10 +8,13 @@ import ( "encoding/json" "encoding/xml" "errors" + "fmt" "io" "mime/multipart" "net/http" "net/http/httptest" + "net/http/httputil" + "net/url" "reflect" "strconv" "strings" @@ -18,51 +24,91 @@ import ( "github.com/stretchr/testify/assert" ) -type ( - bindTestStruct struct { - I int - PtrI *int - I8 int8 - PtrI8 *int8 - I16 int16 - PtrI16 *int16 - I32 int32 - PtrI32 *int32 - I64 int64 - PtrI64 *int64 - UI uint - PtrUI *uint - UI8 uint8 - PtrUI8 *uint8 - UI16 uint16 - PtrUI16 *uint16 - UI32 uint32 - PtrUI32 *uint32 - UI64 uint64 - PtrUI64 *uint64 - B bool - PtrB *bool - F32 float32 - PtrF32 *float32 - F64 float64 - PtrF64 *float64 - S string - PtrS *string - cantSet string - DoesntExist string - GoT time.Time - GoTptr *time.Time - T Timestamp - Tptr *Timestamp - SA StringArray - } - Timestamp time.Time - TA []Timestamp - StringArray []string - Struct struct { - Foo string - } -) +type bindTestStruct struct { + T Timestamp + GoT time.Time + PtrI16 *int16 + PtrUI *uint + Tptr *Timestamp + PtrF32 *float32 + PtrB *bool + PtrI32 *int32 + GoTptr *time.Time + PtrI64 *int64 + PtrI *int + PtrI8 *int8 + PtrF64 *float64 + PtrUI8 *uint8 + PtrUI64 *uint64 + PtrUI16 *uint16 + PtrS *string + PtrUI32 *uint32 + S string + cantSet string + DoesntExist string + SA StringArray + F64 float64 + I int + UI64 uint64 + UI uint + I64 int64 + F32 float32 + UI32 uint32 + I32 int32 + UI16 uint16 + I16 int16 + B bool + UI8 uint8 + I8 int8 +} + +type bindTestStructWithTags struct { + T Timestamp `json:"T" form:"T"` + GoT time.Time `json:"GoT" form:"GoT"` + PtrI16 *int16 `json:"PtrI16" form:"PtrI16"` + PtrUI *uint `json:"PtrUI" form:"PtrUI"` + Tptr *Timestamp `json:"Tptr" form:"Tptr"` + PtrF32 *float32 `json:"PtrF32" form:"PtrF32"` + PtrB *bool `json:"PtrB" form:"PtrB"` + PtrI32 *int32 `json:"PtrI32" form:"PtrI32"` + GoTptr *time.Time `json:"GoTptr" form:"GoTptr"` + PtrI64 *int64 `json:"PtrI64" form:"PtrI64"` + PtrI *int `json:"PtrI" form:"PtrI"` + PtrI8 *int8 `json:"PtrI8" form:"PtrI8"` + PtrF64 *float64 `json:"PtrF64" form:"PtrF64"` + PtrUI8 *uint8 `json:"PtrUI8" form:"PtrUI8"` + PtrUI64 *uint64 `json:"PtrUI64" form:"PtrUI64"` + PtrUI16 *uint16 `json:"PtrUI16" form:"PtrUI16"` + PtrS *string `json:"PtrS" form:"PtrS"` + PtrUI32 *uint32 `json:"PtrUI32" form:"PtrUI32"` + S string `json:"S" form:"S"` + cantSet string + DoesntExist string `json:"DoesntExist" form:"DoesntExist"` + SA StringArray `json:"SA" form:"SA"` + F64 float64 `json:"F64" form:"F64"` + I int `json:"I" form:"I"` + UI64 uint64 `json:"UI64" form:"UI64"` + UI uint `json:"UI" form:"UI"` + I64 int64 `json:"I64" form:"I64"` + F32 float32 `json:"F32" form:"F32"` + UI32 uint32 `json:"UI32" form:"UI32"` + I32 int32 `json:"I32" form:"I32"` + UI16 uint16 `json:"UI16" form:"UI16"` + I16 int16 `json:"I16" form:"I16"` + B bool `json:"B" form:"B"` + UI8 uint8 `json:"UI8" form:"UI8"` + I8 int8 `json:"I8" form:"I8"` +} + +type Timestamp time.Time +type TA []Timestamp +type StringArray []string +type Struct struct { + Foo string +} +type Bar struct { + Baz int `json:"baz" query:"baz"` +} func (t *Timestamp) UnmarshalParam(src string) error { ts, err := time.Parse(time.RFC3339, src) @@ -123,37 +169,71 @@ var values = map[string][]string{ "ST": {"bar"}, } +// ptr return pointer to value. This is useful as `v := []*int8{&int8(1)}` will not compile +func ptr[T any](value T) *T { + return &value +} + +func TestToMultipleFields(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?id=1&ID=2", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + type Root struct { + ID int64 `query:"id"` + Child2 struct { + ID int64 + } + Child1 struct { + ID int64 `query:"id"` + } + } + + u := new(Root) + err := c.Bind(u) + if assert.NoError(t, err) { + assert.Equal(t, int64(1), u.ID) // perfectly reasonable + assert.Equal(t, int64(1), u.Child1.ID) // untagged struct containing tagged field gets filled (by tag) + assert.Equal(t, int64(0), u.Child2.ID) // untagged struct containing untagged field should not be bind + } +} + func TestBindJSON(t *testing.T) { - assert := assert.New(t) - testBindOkay(assert, strings.NewReader(userJSON), MIMEApplicationJSON) - testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) - testBindError(assert, strings.NewReader(userJSONInvalidType), MIMEApplicationJSON, &json.UnmarshalTypeError{}) + testBindOkay(t, strings.NewReader(userJSON), nil, MIMEApplicationJSON) + testBindOkay(t, strings.NewReader(userJSON), dummyQuery, MIMEApplicationJSON) + testBindArrayOkay(t, strings.NewReader(usersJSON), nil, MIMEApplicationJSON) + testBindArrayOkay(t, strings.NewReader(usersJSON), dummyQuery, MIMEApplicationJSON) + testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) + testBindError(t, strings.NewReader(userJSONInvalidType), MIMEApplicationJSON, &json.UnmarshalTypeError{}) } func TestBindXML(t *testing.T) { - assert := assert.New(t) - - testBindOkay(assert, strings.NewReader(userXML), MIMEApplicationXML) - testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationXML, errors.New("")) - testBindError(assert, strings.NewReader(userXMLConvertNumberError), MIMEApplicationXML, &strconv.NumError{}) - testBindError(assert, strings.NewReader(userXMLUnsupportedTypeError), MIMEApplicationXML, &xml.SyntaxError{}) - testBindOkay(assert, strings.NewReader(userXML), MIMETextXML) - testBindError(assert, strings.NewReader(invalidContent), MIMETextXML, errors.New("")) - testBindError(assert, strings.NewReader(userXMLConvertNumberError), MIMETextXML, &strconv.NumError{}) - testBindError(assert, strings.NewReader(userXMLUnsupportedTypeError), MIMETextXML, &xml.SyntaxError{}) + testBindOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML) + testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) + testBindArrayOkay(t, strings.NewReader(userXML), nil, MIMEApplicationXML) + testBindArrayOkay(t, strings.NewReader(userXML), dummyQuery, MIMEApplicationXML) + testBindError(t, strings.NewReader(invalidContent), MIMEApplicationXML, errors.New("")) + testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMEApplicationXML, &strconv.NumError{}) + testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMEApplicationXML, &xml.SyntaxError{}) + testBindOkay(t, strings.NewReader(userXML), nil, MIMETextXML) + testBindOkay(t, strings.NewReader(userXML), dummyQuery, MIMETextXML) + testBindError(t, strings.NewReader(invalidContent), MIMETextXML, errors.New("")) + testBindError(t, strings.NewReader(userXMLConvertNumberError), MIMETextXML, &strconv.NumError{}) + testBindError(t, strings.NewReader(userXMLUnsupportedTypeError), MIMETextXML, &xml.SyntaxError{}) } func TestBindForm(t *testing.T) { - assert := assert.New(t) - testBindOkay(assert, strings.NewReader(userForm), MIMEApplicationForm) + testBindOkay(t, strings.NewReader(userForm), nil, MIMEApplicationForm) + testBindOkay(t, strings.NewReader(userForm), dummyQuery, MIMEApplicationForm) e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userForm)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(HeaderContentType, MIMEApplicationForm) err := c.Bind(&[]struct{ Field string }{}) - assert.Error(err) + assert.Error(t, err) } func TestBindQueryParams(t *testing.T) { @@ -195,40 +275,74 @@ func TestBindQueryParamsCaseSensitivePrioritized(t *testing.T) { } } +func TestBindHeaderParam(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Name", "Jon Doe") + req.Header.Set("Id", "2") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + u := new(user) + err := BindHeaders(c, u) + if assert.NoError(t, err) { + assert.Equal(t, 2, u.ID) + assert.Equal(t, "Jon Doe", u.Name) + } +} + +func TestBindHeaderParamBadType(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Id", "salamander") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + u := new(user) + err := BindHeaders(c, u) + assert.Error(t, err) + + httpErr, ok := err.(*HTTPError) + if assert.True(t, ok) { + assert.Equal(t, http.StatusBadRequest, httpErr.Code) + } +} + func TestBindUnmarshalParam(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { - T Timestamp `query:"ts"` + T Timestamp `query:"ts"` + ST Struct + StWithTag struct { + Foo string `query:"st"` + } TA []Timestamp `query:"ta"` SA StringArray `query:"sa"` - ST Struct }{} err := c.Bind(&result) ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)) - assert := assert.New(t) - if assert.NoError(err) { + if assert.NoError(t, err) { // assert.Equal( Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T) - assert.Equal(ts, result.T) - assert.Equal(StringArray([]string{"one", "two", "three"}), result.SA) - assert.Equal([]Timestamp{ts, ts}, result.TA) - assert.Equal(Struct{"baz"}, result.ST) + assert.Equal(t, ts, result.T) + assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA) + assert.Equal(t, []Timestamp{ts, ts}, result.TA) + assert.Equal(t, Struct{""}, result.ST) // child struct does not have a field with matching tag + assert.Equal(t, "baz", result.StWithTag.Foo) // child struct has field with matching tag } } func TestBindUnmarshalText(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) + req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { - T time.Time `query:"ts"` + T time.Time `query:"ts"` + ST Struct TA []time.Time `query:"ta"` SA StringArray `query:"sa"` - ST Struct }{} err := c.Bind(&result) ts := time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC) @@ -237,7 +351,7 @@ func TestBindUnmarshalText(t *testing.T) { assert.Equal(t, ts, result.T) assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA) assert.Equal(t, []time.Time{ts, ts}, result.TA) - assert.Equal(t, Struct{"baz"}, result.ST) + assert.Equal(t, Struct{""}, result.ST) // field in child struct does not have tag } } @@ -255,9 +369,49 @@ func TestBindUnmarshalParamPtr(t *testing.T) { } } +func TestBindUnmarshalParamAnonymousFieldPtr(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?baz=1", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + *Bar + }{&Bar{}} + err := c.Bind(&result) + if assert.NoError(t, err) { + assert.Equal(t, 1, result.Baz) + } +} + +func TestBindUnmarshalParamAnonymousFieldPtrNil(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/?baz=1", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + *Bar + }{} + err := c.Bind(&result) + if assert.NoError(t, err) { + assert.Nil(t, result.Bar) + } +} + +func TestBindUnmarshalParamAnonymousFieldPtrCustomTag(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, `/?bar={"baz":100}&baz=1`, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + result := struct { + *Bar `json:"bar" query:"bar"` + }{&Bar{}} + err := c.Bind(&result) + assert.Contains(t, err.Error(), "query/param/form tags are not allowed with anonymous struct field") +} + func TestBindUnmarshalTextPtr(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil) + req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { @@ -270,37 +424,158 @@ func TestBindUnmarshalTextPtr(t *testing.T) { } func TestBindMultipartForm(t *testing.T) { - body := new(bytes.Buffer) - mw := multipart.NewWriter(body) + bodyBuffer := new(bytes.Buffer) + mw := multipart.NewWriter(bodyBuffer) mw.WriteField("id", "1") mw.WriteField("name", "Jon Snow") mw.Close() + body := bodyBuffer.Bytes() - assert := assert.New(t) - testBindOkay(assert, body, mw.FormDataContentType()) + testBindOkay(t, bytes.NewReader(body), nil, mw.FormDataContentType()) + testBindOkay(t, bytes.NewReader(body), dummyQuery, mw.FormDataContentType()) } func TestBindUnsupportedMediaType(t *testing.T) { - assert := assert.New(t) - testBindError(assert, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) + testBindError(t, strings.NewReader(invalidContent), MIMEApplicationJSON, &json.SyntaxError{}) +} + +func TestDefaultBinder_bindDataToMap(t *testing.T) { + exampleData := map[string][]string{ + "multiple": {"1", "2"}, + "single": {"3"}, + } + + t.Run("ok, bind to map[string]string", func(t *testing.T) { + dest := map[string]string{} + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, + map[string]string{ + "multiple": "1", + "single": "3", + }, + dest, + ) + }) + + t.Run("ok, bind to map[string]string with nil map", func(t *testing.T) { + var dest map[string]string + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, + map[string]string{ + "multiple": "1", + "single": "3", + }, + dest, + ) + }) + + t.Run("ok, bind to map[string][]string", func(t *testing.T) { + dest := map[string][]string{} + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, + map[string][]string{ + "multiple": {"1", "2"}, + "single": {"3"}, + }, + dest, + ) + }) + + t.Run("ok, bind to map[string][]string with nil map", func(t *testing.T) { + var dest map[string][]string + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, + map[string][]string{ + "multiple": {"1", "2"}, + "single": {"3"}, + }, + dest, + ) + }) + + t.Run("ok, bind to map[string]interface", func(t *testing.T) { + dest := map[string]any{} + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, + map[string]any{ + "multiple": "1", + "single": "3", + }, + dest, + ) + }) + + t.Run("ok, bind to map[string]interface with nil map", func(t *testing.T) { + var dest map[string]any + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, + map[string]any{ + "multiple": "1", + "single": "3", + }, + dest, + ) + }) + + t.Run("ok, bind to map[string]int skips", func(t *testing.T) { + dest := map[string]int{} + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, map[string]int{}, dest) + }) + + t.Run("ok, bind to map[string]int skips with nil map", func(t *testing.T) { + var dest map[string]int + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, map[string]int(nil), dest) + }) + + t.Run("ok, bind to map[string][]int skips", func(t *testing.T) { + dest := map[string][]int{} + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, map[string][]int{}, dest) + }) + + t.Run("ok, bind to map[string][]int skips with nil map", func(t *testing.T) { + var dest map[string][]int + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) + assert.Equal(t, map[string][]int(nil), dest) + }) } func TestBindbindData(t *testing.T) { - assert := assert.New(t) ts := new(bindTestStruct) - b := new(DefaultBinder) - b.bindData(ts, values, "form") - assertBindTestStruct(assert, ts) + err := bindData(ts, values, "form", nil) + assert.NoError(t, err) + + assert.Equal(t, 0, ts.I) + assert.Equal(t, int8(0), ts.I8) + assert.Equal(t, int16(0), ts.I16) + assert.Equal(t, int32(0), ts.I32) + assert.Equal(t, int64(0), ts.I64) + assert.Equal(t, uint(0), ts.UI) + assert.Equal(t, uint8(0), ts.UI8) + assert.Equal(t, uint16(0), ts.UI16) + assert.Equal(t, uint32(0), ts.UI32) + assert.Equal(t, uint64(0), ts.UI64) + assert.Equal(t, false, ts.B) + assert.Equal(t, float32(0), ts.F32) + assert.Equal(t, float64(0), ts.F64) + assert.Equal(t, "", ts.S) + assert.Equal(t, "", ts.cantSet) } func TestBindParam(t *testing.T) { e := New() - req := httptest.NewRequest(GET, "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - c.SetPath("/users/:id/:name") - c.SetParamNames("id", "name") - c.SetParamValues("1", "Jon Snow") + c.InitializeRoute( + &RouteInfo{Path: "/users/:id/:name"}, + &PathValues{ + {Name: "id", Value: "1"}, + {Name: "name", Value: "Jon Snow"}, + }, + ) u := new(user) err := c.Bind(u) @@ -311,9 +586,12 @@ func TestBindParam(t *testing.T) { // Second test for the absence of a param c2 := e.NewContext(req, rec) - c2.SetPath("/users/:id") - c2.SetParamNames("id") - c2.SetParamValues("1") + c2.InitializeRoute( + &RouteInfo{Path: "/users/:id"}, + &PathValues{ + {Name: "id", Value: "1"}, + }, + ) u = new(user) err = c2.Bind(u) @@ -325,15 +603,18 @@ func TestBindParam(t *testing.T) { // Bind something with param and post data payload body := bytes.NewBufferString(`{ "name": "Jon Snow" }`) e2 := New() - req2 := httptest.NewRequest(POST, "/", body) + req2 := httptest.NewRequest(http.MethodPost, "/", body) req2.Header.Set(HeaderContentType, MIMEApplicationJSON) rec2 := httptest.NewRecorder() c3 := e2.NewContext(req2, rec2) - c3.SetPath("/users/:id") - c3.SetParamNames("id") - c3.SetParamValues("1") + c3.InitializeRoute( + &RouteInfo{Path: "/users/:id"}, + &PathValues{ + {Name: "id", Value: "1"}, + }, + ) u = new(user) err = c3.Bind(u) @@ -355,13 +636,10 @@ func TestBindUnmarshalTypeError(t *testing.T) { err := c.Bind(u) - he := &HTTPError{Code: http.StatusBadRequest, Message: "Unmarshal type error: expected=int, got=string, field=id, offset=14", Internal: err.(*HTTPError).Internal} - - assert.Equal(t, he, err) + assert.EqualError(t, err, `code=400, message=Bad Request, err=json: cannot unmarshal string into Go struct field user.id of type int`) } func TestBindSetWithProperType(t *testing.T) { - assert := assert.New(t) ts := new(bindTestStruct) typ := reflect.TypeOf(ts).Elem() val := reflect.ValueOf(ts).Elem() @@ -376,9 +654,9 @@ func TestBindSetWithProperType(t *testing.T) { } val := values[typeField.Name][0] err := setWithProperType(typeField.Type.Kind(), val, structField) - assert.NoError(err) + assert.NoError(t, err) } - assertBindTestStruct(assert, ts) + assertBindTestStruct(t, ts) type foo struct { Bar bytes.Buffer @@ -386,86 +664,77 @@ func TestBindSetWithProperType(t *testing.T) { v := &foo{} typ = reflect.TypeOf(v).Elem() val = reflect.ValueOf(v).Elem() - assert.Error(setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) + assert.Error(t, setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) } -func TestBindSetFields(t *testing.T) { - assert := assert.New(t) - - ts := new(bindTestStruct) - val := reflect.ValueOf(ts).Elem() - // Int - if assert.NoError(setIntField("5", 0, val.FieldByName("I"))) { - assert.Equal(5, ts.I) - } - if assert.NoError(setIntField("", 0, val.FieldByName("I"))) { - assert.Equal(0, ts.I) - } - - // Uint - if assert.NoError(setUintField("10", 0, val.FieldByName("UI"))) { - assert.Equal(uint(10), ts.UI) - } - if assert.NoError(setUintField("", 0, val.FieldByName("UI"))) { - assert.Equal(uint(0), ts.UI) - } - - // Float - if assert.NoError(setFloatField("15.5", 0, val.FieldByName("F32"))) { - assert.Equal(float32(15.5), ts.F32) - } - if assert.NoError(setFloatField("", 0, val.FieldByName("F32"))) { - assert.Equal(float32(0.0), ts.F32) - } - - // Bool - if assert.NoError(setBoolField("true", val.FieldByName("B"))) { - assert.Equal(true, ts.B) - } - if assert.NoError(setBoolField("", val.FieldByName("B"))) { - assert.Equal(false, ts.B) - } - - ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T")) - if assert.NoError(err) { - assert.Equal(ok, true) - assert.Equal(Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) +func BenchmarkBindbindDataWithTags(b *testing.B) { + b.ReportAllocs() + ts := new(bindTestStructWithTags) + var err error + b.ResetTimer() + for i := 0; i < b.N; i++ { + err = bindData(ts, values, "form", nil) } + assert.NoError(b, err) + assertBindTestStruct(b, (*bindTestStruct)(ts)) } -func assertBindTestStruct(a *assert.Assertions, ts *bindTestStruct) { - a.Equal(0, ts.I) - a.Equal(int8(8), ts.I8) - a.Equal(int16(16), ts.I16) - a.Equal(int32(32), ts.I32) - a.Equal(int64(64), ts.I64) - a.Equal(uint(0), ts.UI) - a.Equal(uint8(8), ts.UI8) - a.Equal(uint16(16), ts.UI16) - a.Equal(uint32(32), ts.UI32) - a.Equal(uint64(64), ts.UI64) - a.Equal(true, ts.B) - a.Equal(float32(32.5), ts.F32) - a.Equal(float64(64.5), ts.F64) - a.Equal("test", ts.S) - a.Equal("", ts.GetCantSet()) +func assertBindTestStruct(tb testing.TB, ts *bindTestStruct) { + assert.Equal(tb, 0, ts.I) + assert.Equal(tb, int8(8), ts.I8) + assert.Equal(tb, int16(16), ts.I16) + assert.Equal(tb, int32(32), ts.I32) + assert.Equal(tb, int64(64), ts.I64) + assert.Equal(tb, uint(0), ts.UI) + assert.Equal(tb, uint8(8), ts.UI8) + assert.Equal(tb, uint16(16), ts.UI16) + assert.Equal(tb, uint32(32), ts.UI32) + assert.Equal(tb, uint64(64), ts.UI64) + assert.Equal(tb, true, ts.B) + assert.Equal(tb, float32(32.5), ts.F32) + assert.Equal(tb, float64(64.5), ts.F64) + assert.Equal(tb, "test", ts.S) + assert.Equal(tb, "", ts.GetCantSet()) } -func testBindOkay(assert *assert.Assertions, r io.Reader, ctype string) { +func testBindOkay(t *testing.T, r io.Reader, query url.Values, ctype string) { e := New() - req := httptest.NewRequest(http.MethodPost, "/", r) + path := "/" + if len(query) > 0 { + path += "?" + query.Encode() + } + req := httptest.NewRequest(http.MethodPost, path, r) rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(HeaderContentType, ctype) u := new(user) err := c.Bind(u) - if assert.NoError(err) { - assert.Equal(1, u.ID) - assert.Equal("Jon Snow", u.Name) + if assert.Equal(t, nil, err) { + assert.Equal(t, 1, u.ID) + assert.Equal(t, "Jon Snow", u.Name) + } +} + +func testBindArrayOkay(t *testing.T, r io.Reader, query url.Values, ctype string) { + e := New() + path := "/" + if len(query) > 0 { + path += "?" + query.Encode() + } + req := httptest.NewRequest(http.MethodPost, path, r) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + req.Header.Set(HeaderContentType, ctype) + u := []user{} + err := c.Bind(&u) + if assert.NoError(t, err) { + assert.Equal(t, 1, len(u)) + assert.Equal(t, 1, u[0].ID) + assert.Equal(t, "Jon Snow", u[0].Name) } } -func testBindError(assert *assert.Assertions, r io.Reader, ctype string, expectedInternal error) { +func testBindError(t *testing.T, r io.Reader, ctype string, expectedInternal error) { e := New() req := httptest.NewRequest(http.MethodPost, "/", r) rec := httptest.NewRecorder() @@ -477,14 +746,948 @@ func testBindError(assert *assert.Assertions, r io.Reader, ctype string, expecte switch { case strings.HasPrefix(ctype, MIMEApplicationJSON), strings.HasPrefix(ctype, MIMEApplicationXML), strings.HasPrefix(ctype, MIMETextXML), strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): - if assert.IsType(new(HTTPError), err) { - assert.Equal(http.StatusBadRequest, err.(*HTTPError).Code) - assert.IsType(expectedInternal, err.(*HTTPError).Internal) + if assert.IsType(t, new(HTTPError), err) { + assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).Code) + assert.IsType(t, expectedInternal, err.(*HTTPError).Unwrap()) } default: - if assert.IsType(new(HTTPError), err) { - assert.Equal(ErrUnsupportedMediaType, err) - assert.IsType(expectedInternal, err.(*HTTPError).Internal) + if assert.IsType(t, new(HTTPError), err) { + assert.Equal(t, ErrUnsupportedMediaType, err) + assert.IsType(t, expectedInternal, err.(*HTTPError).Unwrap()) } } } + +func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { + // tests to check binding behaviour when multiple sources (path params, query params and request body) are in use + // binding is done in steps and one source could overwrite previous source bound data + // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed + + type Opts struct { + Node string `json:"node" form:"node" query:"node" param:"node"` + Lang string + ID int `json:"id" form:"id" query:"id"` + } + + var testCases = []struct { + givenContent io.Reader + whenBindTarget any + expect any + name string + givenURL string + givenMethod string + expectError string + whenNoPathValues bool + }{ + { + name: "ok, POST bind to struct with: path param + query param + body", + givenMethod: http.MethodPost, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{"id": 1}`), + expect: &Opts{ID: 1, Node: "node_from_path"}, // query params are not used, node is filled from path + }, + { + name: "ok, PUT bind to struct with: path param + query param + body", + givenMethod: http.MethodPut, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{"id": 1}`), + expect: &Opts{ID: 1, Node: "node_from_path"}, // query params are not used + }, + { + name: "ok, GET bind to struct with: path param + query param + body", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{"id": 1}`), + expect: &Opts{ID: 1, Node: "xxx"}, // query overwrites previous path value + }, + { + name: "ok, GET bind to struct with: path param + query param + body", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), + expect: &Opts{ID: 1, Node: "zzz"}, // body is bound last and overwrites previous (path,query) values + }, + { + name: "ok, DELETE bind to struct with: path param + query param + body", + givenMethod: http.MethodDelete, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), + expect: &Opts{ID: 1, Node: "zzz"}, // for DELETE body is bound after query params + }, + { + name: "ok, POST bind to struct with: path param + body", + givenMethod: http.MethodPost, + givenURL: "/api/real_node/endpoint", + givenContent: strings.NewReader(`{"id": 1}`), + expect: &Opts{ID: 1, Node: "node_from_path"}, + }, + { + name: "ok, POST bind to struct with path + query + body = body has priority", + givenMethod: http.MethodPost, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), + expect: &Opts{ID: 1, Node: "zzz"}, // field value from content has higher priority + }, + { + name: "nok, POST body bind failure", + givenMethod: http.MethodPost, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`{`), + expect: &Opts{ID: 0, Node: "node_from_path"}, // query binding has already modified bind target + expectError: "code=400, message=Bad Request, err=unexpected EOF", + }, + { + name: "nok, GET with body bind failure when types are not convertible", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint?id=nope", + givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), + expect: &Opts{ID: 0, Node: "node_from_path"}, // path params binding has already modified bind target + expectError: `code=400, message=Bad Request, err=strconv.ParseInt: parsing "nope": invalid syntax`, + }, + { + name: "nok, GET body bind failure - trying to bind json array to struct", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`[{"id": 1}]`), + expect: &Opts{ID: 0, Node: "xxx"}, // query binding has already modified bind target + expectError: `code=400, message=Bad Request, err=json: cannot unmarshal array into Go value of type echo.Opts`, + }, + { // query param is ignored as we do not know where exactly to bind it in slice + name: "ok, GET bind to struct slice, ignore query param", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`[{"id": 1}]`), + whenNoPathValues: true, + whenBindTarget: &[]Opts{}, + expect: &[]Opts{ + {ID: 1, Node: ""}, + }, + }, + { // binding query params interferes with body. b.BindBody() should be used to bind only body to slice + name: "ok, POST binding to slice should not be affected query params types", + givenMethod: http.MethodPost, + givenURL: "/api/real_node/endpoint?id=nope&node=xxx", + givenContent: strings.NewReader(`[{"id": 1}]`), + whenNoPathValues: true, + whenBindTarget: &[]Opts{}, + expect: &[]Opts{{ID: 1}}, + expectError: "", + }, + { // path param is ignored as we do not know where exactly to bind it in slice + name: "ok, GET bind to struct slice, ignore path param", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint?node=xxx", + givenContent: strings.NewReader(`[{"id": 1}]`), + whenBindTarget: &[]Opts{}, + expect: &[]Opts{ + {ID: 1, Node: ""}, + }, + }, + { + name: "ok, GET body bind json array to slice", + givenMethod: http.MethodGet, + givenURL: "/api/real_node/endpoint", + givenContent: strings.NewReader(`[{"id": 1}]`), + whenNoPathValues: true, + whenBindTarget: &[]Opts{}, + expect: &[]Opts{{ID: 1, Node: ""}}, + expectError: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + // assume route we are testing is "/api/:node/endpoint?some_query_params=here" + req := httptest.NewRequest(tc.givenMethod, tc.givenURL, tc.givenContent) + req.Header.Set(HeaderContentType, MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + if !tc.whenNoPathValues { + c.SetPathValues(PathValues{ + {Name: "node", Value: "node_from_path"}, + }) + } + + var bindTarget any + if tc.whenBindTarget != nil { + bindTarget = tc.whenBindTarget + } else { + bindTarget = &Opts{} + } + b := new(DefaultBinder) + + err := b.Bind(c, bindTarget) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, bindTarget) + }) + } +} + +func TestDefaultBinder_BindBody(t *testing.T) { + // tests to check binding behaviour when multiple sources (path params, query params and request body) are in use + // generally when binding from request body - URL and path params are ignored - unless form is being bound. + // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed + + type Node struct { + Node string `json:"node" xml:"node" form:"node" query:"node" param:"node"` + ID int `json:"id" xml:"id" form:"id" query:"id"` + } + type Nodes struct { + Nodes []Node `xml:"node" form:"node"` + } + + var testCases = []struct { + givenContent io.Reader + whenBindTarget any + expect any + name string + givenURL string + givenMethod string + givenContentType string + expectError string + whenNoPathValues bool + whenChunkedBody bool + }{ + { + name: "ok, JSON POST bind to struct with: path + query + empty field in body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationJSON, + givenContent: strings.NewReader(`{"id": 1}`), + expect: &Node{ID: 1, Node: ""}, // path params or query params should not interfere with body + }, + { + name: "ok, JSON POST bind to struct with: path + query + body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationJSON, + givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), + expect: &Node{ID: 1, Node: "zzz"}, // field value from content has higher priority + }, + { + name: "ok, JSON POST body bind json array to slice (has matching path/query params)", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationJSON, + givenContent: strings.NewReader(`[{"id": 1}]`), + whenNoPathValues: true, + whenBindTarget: &[]Node{}, + expect: &[]Node{{ID: 1, Node: ""}}, + expectError: "", + }, + { // rare case as GET is not usually used to send request body + name: "ok, JSON GET bind to struct with: path + query + empty field in body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodGet, + givenContentType: MIMEApplicationJSON, + givenContent: strings.NewReader(`{"id": 1}`), + expect: &Node{ID: 1, Node: ""}, // path params or query params should not interfere with body + }, + { // rare case as GET is not usually used to send request body + name: "ok, JSON GET bind to struct with: path + query + body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodGet, + givenContentType: MIMEApplicationJSON, + givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), + expect: &Node{ID: 1, Node: "zzz"}, // field value from content has higher priority + }, + { + name: "nok, JSON POST body bind failure", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationJSON, + givenContent: strings.NewReader(`{`), + expect: &Node{ID: 0, Node: ""}, + expectError: "code=400, message=Bad Request, err=unexpected EOF", + }, + { + name: "ok, XML POST bind to struct with: path + query + empty body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationXML, + givenContent: strings.NewReader(`1yyy`), + expect: &Node{ID: 1, Node: "yyy"}, + }, + { + name: "ok, XML POST bind array to slice with: path + query + body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationXML, + givenContent: strings.NewReader(`1yyy`), + whenBindTarget: &Nodes{}, + expect: &Nodes{Nodes: []Node{{ID: 1, Node: "yyy"}}}, + }, + { + name: "nok, XML POST bind failure", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationXML, + givenContent: strings.NewReader(`<`), + expect: &Node{ID: 0, Node: ""}, + expectError: "code=400, message=Bad Request, err=XML syntax error on line 1: unexpected EOF", + }, + { + name: "ok, FORM POST bind to struct with: path + query + body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationForm, + givenContent: strings.NewReader(`id=1&node=yyy`), + expect: &Node{ID: 1, Node: "yyy"}, + }, + { + // NB: form values are taken from BOTH body and query for POST/PUT/PATCH by standard library implementation + // See: https://golang.org/pkg/net/http/#Request.ParseForm + name: "ok, FORM POST bind to struct with: path + query + empty field in body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationForm, + givenContent: strings.NewReader(`id=1`), + expect: &Node{ID: 1, Node: "xxx"}, + }, + { + // NB: form values are taken from query by standard library implementation + // See: https://golang.org/pkg/net/http/#Request.ParseForm + name: "ok, FORM GET bind to struct with: path + query + empty field in body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodGet, + givenContentType: MIMEApplicationForm, + givenContent: strings.NewReader(`id=1`), + expect: &Node{ID: 0, Node: "xxx"}, // 'xxx' is taken from URL and body is not used with GET by implementation + }, + { + name: "nok, unsupported content type", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMETextPlain, + givenContent: strings.NewReader(``), + expect: &Node{ID: 0, Node: ""}, + expectError: "code=415, message=Unsupported Media Type", + }, + // FIXME: REASON in Go 1.24 and earlier http.NoBody would result ContentLength=-1 + // but as of Go 1.25 http.NoBody would result ContentLength=0 + // I am too lazy to bother documenting this as 2 version specific tests. + //{ + // name: "nok, JSON POST with http.NoBody", + // givenURL: "/api/real_node/endpoint?node=xxx", + // givenMethod: http.MethodPost, + // givenContentType: MIMEApplicationJSON, + // givenContent: http.NoBody, + // expect: &Node{ID: 0, Node: ""}, + // expectError: "code=400, message=EOF, internal=EOF", + //}, + { + name: "ok, JSON POST with empty body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationJSON, + givenContent: strings.NewReader(""), + expect: &Node{ID: 0, Node: ""}, + }, + { + name: "ok, JSON POST bind to struct with: path + query + chunked body", + givenURL: "/api/real_node/endpoint?node=xxx", + givenMethod: http.MethodPost, + givenContentType: MIMEApplicationJSON, + givenContent: httputil.NewChunkedReader(strings.NewReader("18\r\n" + `{"id": 1, "node": "zzz"}` + "\r\n0\r\n")), + whenChunkedBody: true, + expect: &Node{ID: 1, Node: "zzz"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + // assume route we are testing is "/api/:node/endpoint?some_query_params=here" + req := httptest.NewRequest(tc.givenMethod, tc.givenURL, tc.givenContent) + switch tc.givenContentType { + case MIMEApplicationXML: + req.Header.Set(HeaderContentType, MIMEApplicationXML) + case MIMEApplicationForm: + req.Header.Set(HeaderContentType, MIMEApplicationForm) + case MIMEApplicationJSON: + req.Header.Set(HeaderContentType, MIMEApplicationJSON) + } + if tc.whenChunkedBody { + req.ContentLength = -1 + req.TransferEncoding = append(req.TransferEncoding, "chunked") + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + if !tc.whenNoPathValues { + c.SetPathValues(PathValues{ + {Name: "node", Value: "real_node"}, + }) + } + + var bindTarget any + if tc.whenBindTarget != nil { + bindTarget = tc.whenBindTarget + } else { + bindTarget = &Node{} + } + + err := BindBody(c, bindTarget) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, bindTarget) + }) + } +} + +func testBindURL(queryString string, target any) error { + e := New() + req := httptest.NewRequest(http.MethodGet, queryString, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + return c.Bind(target) +} + +type unixTimestamp struct { + Time time.Time +} + +func (t *unixTimestamp) UnmarshalParam(param string) error { + n, err := strconv.ParseInt(param, 10, 64) + if err != nil { + return fmt.Errorf("'%s' is not an integer", param) + } + *t = unixTimestamp{Time: time.Unix(n, 0)} + return err +} + +type IntArrayA []int + +// UnmarshalParam converts value to *Int64Slice. This allows the API to accept +// a comma-separated list of integers as a query parameter. +func (i *IntArrayA) UnmarshalParam(value string) error { + var values = strings.Split(value, ",") + var numbers = make([]int, 0, len(values)) + + for _, v := range values { + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return fmt.Errorf("'%s' is not an integer", v) + } + + numbers = append(numbers, int(n)) + } + + *i = append(*i, numbers...) + return nil +} + +func TestBindUnmarshalParamExtras(t *testing.T) { + // this test documents how bind handles `BindUnmarshaler` interface: + // NOTE: BindUnmarshaler chooses first input value to be bound. + + t.Run("nok, unmarshalling fails", func(t *testing.T) { + result := struct { + V unixTimestamp `query:"t"` + }{} + err := testBindURL("/?t=xxxx", &result) + + assert.EqualError(t, err, `code=400, message=Bad Request, err='xxxx' is not an integer`) + }) + + t.Run("ok, target is struct", func(t *testing.T) { + result := struct { + V unixTimestamp `query:"t"` + }{} + err := testBindURL("/?t=1710095540&t=1710095541", &result) + + assert.NoError(t, err) + expect := unixTimestamp{ + Time: time.Unix(1710095540, 0), + } + assert.Equal(t, expect, result.V) + }) + + t.Run("ok, target is an alias to slice and is nil, append only values from first", func(t *testing.T) { + result := struct { + V IntArrayA `query:"a"` + }{} + err := testBindURL("/?a=1,2,3&a=4,5,6", &result) + + assert.NoError(t, err) + assert.Equal(t, IntArrayA([]int{1, 2, 3}), result.V) + }) + + t.Run("ok, target is an alias to slice and is nil, single input", func(t *testing.T) { + result := struct { + V IntArrayA `query:"a"` + }{} + err := testBindURL("/?a=1,2", &result) + + assert.NoError(t, err) + assert.Equal(t, IntArrayA([]int{1, 2}), result.V) + }) + + t.Run("ok, target is pointer an alias to slice and is nil", func(t *testing.T) { + result := struct { + V *IntArrayA `query:"a"` + }{} + err := testBindURL("/?a=1&a=4,5,6", &result) + + assert.NoError(t, err) + var expected = IntArrayA([]int{1}) + assert.Equal(t, &expected, result.V) + }) + + t.Run("ok, target is pointer an alias to slice and is NOT nil", func(t *testing.T) { + result := struct { + V *IntArrayA `query:"a"` + }{} + result.V = new(IntArrayA) // NOT nil + + err := testBindURL("/?a=1&a=4,5,6", &result) + + assert.NoError(t, err) + var expected = IntArrayA([]int{1}) + assert.Equal(t, &expected, result.V) + }) +} + +type unixTimestampLast struct { + Time time.Time +} + +// this is silly example for `bindMultipleUnmarshaler` for type that uses last input value for unmarshalling +func (t *unixTimestampLast) UnmarshalParams(params []string) error { + lastInput := params[len(params)-1] + n, err := strconv.ParseInt(lastInput, 10, 64) + if err != nil { + return fmt.Errorf("'%s' is not an integer", lastInput) + } + *t = unixTimestampLast{Time: time.Unix(n, 0)} + return err +} + +type IntArrayB []int + +func (i *IntArrayB) UnmarshalParams(params []string) error { + var numbers = make([]int, 0, len(params)) + + for _, param := range params { + var values = strings.Split(param, ",") + for _, v := range values { + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return fmt.Errorf("'%s' is not an integer", v) + } + numbers = append(numbers, int(n)) + } + } + + *i = append(*i, numbers...) + return nil +} + +func TestBindUnmarshalParams(t *testing.T) { + // this test documents how bind handles `bindMultipleUnmarshaler` interface: + + t.Run("nok, unmarshalling fails", func(t *testing.T) { + result := struct { + V unixTimestampLast `query:"t"` + }{} + err := testBindURL("/?t=xxxx", &result) + + assert.EqualError(t, err, "code=400, message=Bad Request, err='xxxx' is not an integer") + }) + + t.Run("ok, target is struct", func(t *testing.T) { + result := struct { + V unixTimestampLast `query:"t"` + }{} + err := testBindURL("/?t=1710095540&t=1710095541", &result) + + assert.NoError(t, err) + expect := unixTimestampLast{ + Time: time.Unix(1710095541, 0), + } + assert.Equal(t, expect, result.V) + }) + + t.Run("ok, target is an alias to slice and is nil, append multiple inputs", func(t *testing.T) { + result := struct { + V IntArrayB `query:"a"` + }{} + err := testBindURL("/?a=1,2,3&a=4,5,6", &result) + + assert.NoError(t, err) + assert.Equal(t, IntArrayB([]int{1, 2, 3, 4, 5, 6}), result.V) + }) + + t.Run("ok, target is an alias to slice and is nil, single input", func(t *testing.T) { + result := struct { + V IntArrayB `query:"a"` + }{} + err := testBindURL("/?a=1,2", &result) + + assert.NoError(t, err) + assert.Equal(t, IntArrayB([]int{1, 2}), result.V) + }) + + t.Run("ok, target is pointer an alias to slice and is nil", func(t *testing.T) { + result := struct { + V *IntArrayB `query:"a"` + }{} + err := testBindURL("/?a=1&a=4,5,6", &result) + + assert.NoError(t, err) + var expected = IntArrayB([]int{1, 4, 5, 6}) + assert.Equal(t, &expected, result.V) + }) + + t.Run("ok, target is pointer an alias to slice and is NOT nil", func(t *testing.T) { + result := struct { + V *IntArrayB `query:"a"` + }{} + result.V = new(IntArrayB) // NOT nil + + err := testBindURL("/?a=1&a=4,5,6", &result) + assert.NoError(t, err) + var expected = IntArrayB([]int{1, 4, 5, 6}) + assert.Equal(t, &expected, result.V) + }) +} + +func TestBindInt8(t *testing.T) { + t.Run("nok, binding fails", func(t *testing.T) { + type target struct { + V int8 `query:"v"` + } + p := target{} + err := testBindURL("/?v=x&v=2", &p) + assert.EqualError(t, err, `code=400, message=Bad Request, err=strconv.ParseInt: parsing "x": invalid syntax`) + }) + + t.Run("nok, int8 embedded in struct", func(t *testing.T) { + type target struct { + int8 `query:"v"` // embedded field is `Anonymous`. We can only set public fields + } + p := target{} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + assert.Equal(t, target{0}, p) + }) + + t.Run("nok, pointer to int8 embedded in struct", func(t *testing.T) { + type target struct { + *int8 `query:"v"` // embedded field is `Anonymous`. We can only set public fields + } + p := target{} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + + assert.Equal(t, target{int8: nil}, p) + }) + + t.Run("ok, bind int8 as struct field", func(t *testing.T) { + type target struct { + V int8 `query:"v"` + } + p := target{V: 127} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + assert.Equal(t, target{V: 1}, p) + }) + + t.Run("ok, bind pointer to int8 as struct field, value is nil", func(t *testing.T) { + type target struct { + V *int8 `query:"v"` + } + p := target{} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + assert.Equal(t, target{V: ptr(int8(1))}, p) + }) + + t.Run("ok, bind pointer to int8 as struct field, value is set", func(t *testing.T) { + type target struct { + V *int8 `query:"v"` + } + p := target{V: ptr(int8(127))} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + assert.Equal(t, target{V: ptr(int8(1))}, p) + }) + + t.Run("ok, bind int8 slice as struct field, value is nil", func(t *testing.T) { + type target struct { + V []int8 `query:"v"` + } + p := target{} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + assert.Equal(t, target{V: []int8{1, 2}}, p) + }) + + t.Run("ok, bind slice of int8 as struct field, value is set", func(t *testing.T) { + type target struct { + V []int8 `query:"v"` + } + p := target{V: []int8{111}} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + assert.Equal(t, target{V: []int8{1, 2}}, p) + }) + + t.Run("ok, bind slice of pointer to int8 as struct field, value is set", func(t *testing.T) { + type target struct { + V []*int8 `query:"v"` + } + p := target{V: []*int8{ptr(int8(127))}} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + assert.Equal(t, target{V: []*int8{ptr(int8(1)), ptr(int8(2))}}, p) + }) + + t.Run("ok, bind pointer to slice of int8 as struct field, value is set", func(t *testing.T) { + type target struct { + V *[]int8 `query:"v"` + } + p := target{V: &[]int8{111}} + err := testBindURL("/?v=1&v=2", &p) + assert.NoError(t, err) + assert.Equal(t, target{V: &[]int8{1, 2}}, p) + }) +} + +func TestBindMultipartFormFiles(t *testing.T) { + file1 := createTestFormFile("file", "file1.txt") + file11 := createTestFormFile("file", "file11.txt") + file2 := createTestFormFile("file2", "file2.txt") + filesA := createTestFormFile("files", "filesA.txt") + filesB := createTestFormFile("files", "filesB.txt") + + t.Run("nok, can not bind to multipart file struct", func(t *testing.T) { + var target struct { + File multipart.FileHeader `form:"file"` + } + err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored + + assert.EqualError(t, err, `code=400, message=Bad Request, err=binding to multipart.FileHeader struct is not supported, use pointer to struct`) + }) + + t.Run("ok, bind single multipart file to pointer to multipart file", func(t *testing.T) { + var target struct { + File *multipart.FileHeader `form:"file"` + } + err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored + + assert.NoError(t, err) + assertMultipartFileHeader(t, target.File, file1) + }) + + t.Run("ok, bind multiple multipart files to pointer to multipart file", func(t *testing.T) { + var target struct { + File *multipart.FileHeader `form:"file"` + } + err := bindMultipartFiles(t, &target, file1, file11) + + assert.NoError(t, err) + assertMultipartFileHeader(t, target.File, file1) // should choose first one + }) + + t.Run("ok, bind multiple multipart files to slice of multipart file", func(t *testing.T) { + var target struct { + Files []multipart.FileHeader `form:"files"` + } + err := bindMultipartFiles(t, &target, filesA, filesB, file1) + + assert.NoError(t, err) + + assert.Len(t, target.Files, 2) + assertMultipartFileHeader(t, &target.Files[0], filesA) + assertMultipartFileHeader(t, &target.Files[1], filesB) + }) + + t.Run("ok, bind multiple multipart files to slice of pointer to multipart file", func(t *testing.T) { + var target struct { + Files []*multipart.FileHeader `form:"files"` + } + err := bindMultipartFiles(t, &target, filesA, filesB, file1) + + assert.NoError(t, err) + + assert.Len(t, target.Files, 2) + assertMultipartFileHeader(t, target.Files[0], filesA) + assertMultipartFileHeader(t, target.Files[1], filesB) + }) +} + +type testFormFile struct { + Fieldname string + Filename string + Content []byte +} + +func createTestFormFile(formFieldName string, filename string) testFormFile { + return testFormFile{ + Fieldname: formFieldName, + Filename: filename, + Content: []byte(strings.Repeat(filename, 10)), + } +} + +func bindMultipartFiles(t *testing.T, target any, files ...testFormFile) error { + var body bytes.Buffer + mw := multipart.NewWriter(&body) + + for _, file := range files { + fw, err := mw.CreateFormFile(file.Fieldname, file.Filename) + assert.NoError(t, err) + + n, err := fw.Write(file.Content) + assert.NoError(t, err) + assert.Equal(t, len(file.Content), n) + } + + err := mw.Close() + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, "/", &body) + assert.NoError(t, err) + req.Header.Set("Content-Type", mw.FormDataContentType()) + + rec := httptest.NewRecorder() + + e := New() + c := e.NewContext(req, rec) + return c.Bind(target) +} + +func assertMultipartFileHeader(t *testing.T, fh *multipart.FileHeader, file testFormFile) { + assert.Equal(t, file.Filename, fh.Filename) + assert.Equal(t, int64(len(file.Content)), fh.Size) + fl, err := fh.Open() + assert.NoError(t, err) + body, err := io.ReadAll(fl) + assert.NoError(t, err) + assert.Equal(t, string(file.Content), string(body)) + err = fl.Close() + assert.NoError(t, err) +} + +func TestTimeFormatBinding(t *testing.T) { + type TestStruct struct { + DateTimeLocal time.Time `form:"datetime_local" format:"2006-01-02T15:04"` + Date time.Time `query:"date" format:"2006-01-02"` + CustomFormat time.Time `form:"custom" format:"01/02/2006 15:04:05"` + DefaultTime time.Time `form:"default_time"` // No format tag - should use default parsing + PtrTime *time.Time `query:"ptr_time" format:"2006-01-02"` + } + + testCases := []struct { + name string + contentType string + data string + queryParams string + expect TestStruct + expectError bool + }{ + { + name: "ok, datetime-local format binding", + contentType: MIMEApplicationForm, + data: "datetime_local=2023-12-25T14:30&default_time=2023-12-25T14:30:45Z", + expect: TestStruct{ + DateTimeLocal: time.Date(2023, 12, 25, 14, 30, 0, 0, time.UTC), + DefaultTime: time.Date(2023, 12, 25, 14, 30, 45, 0, time.UTC), + }, + }, + { + name: "ok, date format binding via query params", + queryParams: "?date=2023-01-15&ptr_time=2023-02-20", + expect: TestStruct{ + Date: time.Date(2023, 1, 15, 0, 0, 0, 0, time.UTC), + PtrTime: &time.Time{}, + }, + }, + { + name: "ok, custom format via form data", + contentType: MIMEApplicationForm, + data: "custom=12/25/2023 14:30:45", + expect: TestStruct{ + CustomFormat: time.Date(2023, 12, 25, 14, 30, 45, 0, time.UTC), + }, + }, + { + name: "nok, invalid format should fail", + contentType: MIMEApplicationForm, + data: "datetime_local=invalid-date", + expectError: true, + }, + { + name: "nok, wrong format should fail", + contentType: MIMEApplicationForm, + data: "datetime_local=2023-12-25", // Missing time part + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + var req *http.Request + + if tc.contentType == MIMEApplicationJSON { + req = httptest.NewRequest(http.MethodPost, "/"+tc.queryParams, strings.NewReader(tc.data)) + req.Header.Set(HeaderContentType, tc.contentType) + } else if tc.contentType == MIMEApplicationForm { + req = httptest.NewRequest(http.MethodPost, "/"+tc.queryParams, strings.NewReader(tc.data)) + req.Header.Set(HeaderContentType, tc.contentType) + } else { + req = httptest.NewRequest(http.MethodGet, "/"+tc.queryParams, nil) + } + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + var result TestStruct + err := c.Bind(&result) + + if tc.expectError { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + + // Check individual fields since time comparison can be tricky + if !tc.expect.DateTimeLocal.IsZero() { + assert.True(t, tc.expect.DateTimeLocal.Equal(result.DateTimeLocal), + "DateTimeLocal: expected %v, got %v", tc.expect.DateTimeLocal, result.DateTimeLocal) + } + if !tc.expect.Date.IsZero() { + assert.True(t, tc.expect.Date.Equal(result.Date), + "Date: expected %v, got %v", tc.expect.Date, result.Date) + } + if !tc.expect.CustomFormat.IsZero() { + assert.True(t, tc.expect.CustomFormat.Equal(result.CustomFormat), + "CustomFormat: expected %v, got %v", tc.expect.CustomFormat, result.CustomFormat) + } + if !tc.expect.DefaultTime.IsZero() { + assert.True(t, tc.expect.DefaultTime.Equal(result.DefaultTime), + "DefaultTime: expected %v, got %v", tc.expect.DefaultTime, result.DefaultTime) + } + if tc.expect.PtrTime != nil { + assert.NotNil(t, result.PtrTime) + if result.PtrTime != nil { + expectedPtr := time.Date(2023, 2, 20, 0, 0, 0, 0, time.UTC) + assert.True(t, expectedPtr.Equal(*result.PtrTime), + "PtrTime: expected %v, got %v", expectedPtr, *result.PtrTime) + } + } + }) + } +} diff --git a/binder.go b/binder.go new file mode 100644 index 000000000..32029ec0f --- /dev/null +++ b/binder.go @@ -0,0 +1,1329 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "encoding" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "time" +) + +/** + Following functions provide handful of methods for binding to Go native types from request query or path parameters. + * QueryParamsBinder(c) - binds query parameters (source URL) + * PathValuesBinder(c) - binds path parameters (source URL) + * FormFieldBinder(c) - binds form fields (source URL + body) + + Example: + ```go + var length int64 + err := echo.QueryParamsBinder(c).Int64("length", &length).BindError() + ``` + + For every supported type there are following methods: + * ("param", &destination) - if parameter value exists then binds it to given destination of that type i.e Int64(...). + * Must("param", &destination) - parameter value is required to exist, binds it to given destination of that type i.e MustInt64(...). + * s("param", &destination) - (for slices) if parameter values exists then binds it to given destination of that type i.e Int64s(...). + * Musts("param", &destination) - (for slices) parameter value is required to exist, binds it to given destination of that type i.e MustInt64s(...). + + for some slice types `BindWithDelimiter("param", &dest, ",")` supports splitting parameter values before type conversion is done + i.e. URL `/api/search?id=1,2,3&id=1` can be bind to `[]int64{1,2,3,1}` + + `FailFast` flags binder to stop binding after first bind error during binder call chain. Enabled by default. + `BindError()` returns first bind error from binder and resets errors in binder. Useful along with `FailFast()` method + to do binding and returns on first problem + `BindErrors()` returns all bind errors from binder and resets errors in binder. + + Types that are supported: + * bool + * float32 + * float64 + * int + * int8 + * int16 + * int32 + * int64 + * uint + * uint8/byte (does not support `bytes()`. Use BindUnmarshaler/CustomFunc to convert value from base64 etc to []byte{}) + * uint16 + * uint32 + * uint64 + * string + * time + * duration + * BindUnmarshaler() interface + * TextUnmarshaler() interface + * JSONUnmarshaler() interface + * UnixTime() - converts unix time (integer) to time.Time + * UnixTimeMilli() - converts unix time with millisecond precision (integer) to time.Time + * UnixTimeNano() - converts unix time with nanosecond precision (integer) to time.Time + * CustomFunc() - callback function for your custom conversion logic. Signature `func(values []string) []error` +*/ + +// BindingError represents an error that occurred while binding request data. +type BindingError struct { + // Field is the field name where value binding failed + Field string `json:"field"` + *HTTPError + // Values of parameter that failed to bind. + Values []string `json:"-"` +} + +// NewBindingError creates new instance of binding error +func NewBindingError(sourceParam string, values []string, message string, err error) error { + return &BindingError{ + Field: sourceParam, + Values: values, + HTTPError: &HTTPError{Code: http.StatusBadRequest, Message: message, err: err}, + } +} + +// Error returns error message +func (be *BindingError) Error() string { + return fmt.Sprintf("%s, field=%s", be.HTTPError.Error(), be.Field) +} + +// ValueBinder provides utility methods for binding query or path parameter to various Go built-in types +type ValueBinder struct { + // ValueFunc is used to get single parameter (first) value from request + ValueFunc func(sourceParam string) string + // ValuesFunc is used to get all values for parameter from request. i.e. `/api/search?ids=1&ids=2` + ValuesFunc func(sourceParam string) []string + // ErrorFunc is used to create errors. Allows you to use your own error type, that for example marshals to your specific json response + ErrorFunc func(sourceParam string, values []string, message string, internalError error) error + errors []error + // failFast is flag for binding methods to return without attempting to bind when previous binding already failed + failFast bool +} + +// QueryParamsBinder creates query parameter value binder +func QueryParamsBinder(c *Context) *ValueBinder { + return &ValueBinder{ + failFast: true, + ValueFunc: c.QueryParam, + ValuesFunc: func(sourceParam string) []string { + values, ok := c.QueryParams()[sourceParam] + if !ok { + return nil + } + return values + }, + ErrorFunc: NewBindingError, + } +} + +// PathValuesBinder creates path parameter value binder +func PathValuesBinder(c *Context) *ValueBinder { + return &ValueBinder{ + failFast: true, + ValueFunc: c.Param, + ValuesFunc: func(sourceParam string) []string { + // path parameter should not have multiple values so getting values does not make sense but lets not error out here + value := c.Param(sourceParam) + if value == "" { + return nil + } + return []string{value} + }, + ErrorFunc: NewBindingError, + } +} + +// FormFieldBinder creates form field value binder +// For all requests, FormFieldBinder parses the raw query from the URL and uses query params as form fields +// +// For POST, PUT, and PATCH requests, it also reads the request body, parses it +// as a form and uses query params as form fields. Request body parameters take precedence over URL query +// string values in r.Form. +// +// NB: when binding forms take note that this implementation uses standard library form parsing +// which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm +// See https://golang.org/pkg/net/http/#Request.ParseForm +func FormFieldBinder(c *Context) *ValueBinder { + vb := &ValueBinder{ + failFast: true, + ValueFunc: func(sourceParam string) string { + return c.Request().FormValue(sourceParam) + }, + ErrorFunc: NewBindingError, + } + vb.ValuesFunc = func(sourceParam string) []string { + if c.Request().Form == nil { + // this is same as `Request().FormValue()` does internally + _, _ = c.MultipartForm() // we want to trigger c.request.ParseMultipartForm(c.formParseMaxMemory) + } + values, ok := c.Request().Form[sourceParam] + if !ok { + return nil + } + return values + } + + return vb +} + +// FailFast set internal flag to indicate if binding methods will return early (without binding) when previous bind failed +// NB: call this method before any other binding methods as it modifies binding methods behaviour +func (b *ValueBinder) FailFast(value bool) *ValueBinder { + b.failFast = value + return b +} + +func (b *ValueBinder) setError(err error) { + if b.errors == nil { + b.errors = []error{err} + return + } + b.errors = append(b.errors, err) +} + +// BindError returns first seen bind error and resets/empties binder errors for further calls +func (b *ValueBinder) BindError() error { + if b.errors == nil { + return nil + } + err := b.errors[0] + b.errors = nil // reset errors so next chain will start from zero + return err +} + +// BindErrors returns all bind errors and resets/empties binder errors for further calls +func (b *ValueBinder) BindErrors() []error { + if b.errors == nil { + return nil + } + errors := b.errors + b.errors = nil // reset errors so next chain will start from zero + return errors +} + +// CustomFunc binds parameter values with Func. Func is called only when parameter values exist. +func (b *ValueBinder) CustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder { + return b.customFunc(sourceParam, customFunc, false) +} + +// MustCustomFunc requires parameter values to exist to bind with Func. Returns error when value does not exist. +func (b *ValueBinder) MustCustomFunc(sourceParam string, customFunc func(values []string) []error) *ValueBinder { + return b.customFunc(sourceParam, customFunc, true) +} + +func (b *ValueBinder) customFunc(sourceParam string, customFunc func(values []string) []error, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + values := b.ValuesFunc(sourceParam) + if len(values) == 0 { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + if errs := customFunc(values); errs != nil { + b.errors = append(b.errors, errs...) + } + return b +} + +// String binds parameter to string variable +func (b *ValueBinder) String(sourceParam string, dest *string) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + return b + } + *dest = value + return b +} + +// MustString requires parameter value to exist to bind to string variable. Returns error when value does not exist +func (b *ValueBinder) MustString(sourceParam string, dest *string) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil)) + return b + } + *dest = value + return b +} + +// Strings binds parameter values to slice of string +func (b *ValueBinder) Strings(sourceParam string, dest *[]string) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValuesFunc(sourceParam) + if value == nil { + return b + } + *dest = value + return b +} + +// MustStrings requires parameter values to exist to bind to slice of string variables. Returns error when value does not exist +func (b *ValueBinder) MustStrings(sourceParam string, dest *[]string) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValuesFunc(sourceParam) + if value == nil { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + return b + } + *dest = value + return b +} + +// BindUnmarshaler binds parameter to destination implementing BindUnmarshaler interface +func (b *ValueBinder) BindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + return b + } + + if err := dest.UnmarshalParam(tmp); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to BindUnmarshaler interface", err)) + } + return b +} + +// MustBindUnmarshaler requires parameter value to exist to bind to destination implementing BindUnmarshaler interface. +// Returns error when value does not exist +func (b *ValueBinder) MustBindUnmarshaler(sourceParam string, dest BindUnmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil)) + return b + } + + if err := dest.UnmarshalParam(value); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to BindUnmarshaler interface", err)) + } + return b +} + +// JSONUnmarshaler binds parameter to destination implementing json.Unmarshaler interface +func (b *ValueBinder) JSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + return b + } + + if err := dest.UnmarshalJSON([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err)) + } + return b +} + +// MustJSONUnmarshaler requires parameter value to exist to bind to destination implementing json.Unmarshaler interface. +// Returns error when value does not exist +func (b *ValueBinder) MustJSONUnmarshaler(sourceParam string, dest json.Unmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil)) + return b + } + + if err := dest.UnmarshalJSON([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to json.Unmarshaler interface", err)) + } + return b +} + +// TextUnmarshaler binds parameter to destination implementing encoding.TextUnmarshaler interface +func (b *ValueBinder) TextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + return b + } + + if err := dest.UnmarshalText([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err)) + } + return b +} + +// MustTextUnmarshaler requires parameter value to exist to bind to destination implementing encoding.TextUnmarshaler interface. +// Returns error when value does not exist +func (b *ValueBinder) MustTextUnmarshaler(sourceParam string, dest encoding.TextUnmarshaler) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + tmp := b.ValueFunc(sourceParam) + if tmp == "" { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "required field value is empty", nil)) + return b + } + + if err := dest.UnmarshalText([]byte(tmp)); err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{tmp}, "failed to bind field value to encoding.TextUnmarshaler interface", err)) + } + return b +} + +// BindWithDelimiter binds parameter to destination by suitable conversion function. +// Delimiter is used before conversion to split parameter value to separate values +func (b *ValueBinder) BindWithDelimiter(sourceParam string, dest any, delimiter string) *ValueBinder { + return b.bindWithDelimiter(sourceParam, dest, delimiter, false) +} + +// MustBindWithDelimiter requires parameter value to exist to bind destination by suitable conversion function. +// Delimiter is used before conversion to split parameter value to separate values +func (b *ValueBinder) MustBindWithDelimiter(sourceParam string, dest any, delimiter string) *ValueBinder { + return b.bindWithDelimiter(sourceParam, dest, delimiter, true) +} + +func (b *ValueBinder) bindWithDelimiter(sourceParam string, dest any, delimiter string, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + values := b.ValuesFunc(sourceParam) + if len(values) == 0 { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + tmpValues := make([]string, 0, len(values)) + for _, v := range values { + tmpValues = append(tmpValues, strings.Split(v, delimiter)...) + } + + switch d := dest.(type) { + case *[]string: + *d = tmpValues + return b + case *[]bool: + return b.bools(sourceParam, tmpValues, d) + case *[]int64, *[]int32, *[]int16, *[]int8, *[]int: + return b.ints(sourceParam, tmpValues, d) + case *[]uint64, *[]uint32, *[]uint16, *[]uint8, *[]uint: // *[]byte is same as *[]uint8 + return b.uints(sourceParam, tmpValues, d) + case *[]float64, *[]float32: + return b.floats(sourceParam, tmpValues, d) + case *[]time.Duration: + return b.durations(sourceParam, tmpValues, d) + default: + // support only cases when destination is slice + // does not support time.Time as it needs argument (layout) for parsing or BindUnmarshaler + b.setError(b.ErrorFunc(sourceParam, []string{}, "unsupported bind type", nil)) + return b + } +} + +// Int64 binds parameter to int64 variable +func (b *ValueBinder) Int64(sourceParam string, dest *int64) *ValueBinder { + return b.intValue(sourceParam, dest, 64, false) +} + +// MustInt64 requires parameter value to exist to bind to int64 variable. Returns error when value does not exist +func (b *ValueBinder) MustInt64(sourceParam string, dest *int64) *ValueBinder { + return b.intValue(sourceParam, dest, 64, true) +} + +// Int32 binds parameter to int32 variable +func (b *ValueBinder) Int32(sourceParam string, dest *int32) *ValueBinder { + return b.intValue(sourceParam, dest, 32, false) +} + +// MustInt32 requires parameter value to exist to bind to int32 variable. Returns error when value does not exist +func (b *ValueBinder) MustInt32(sourceParam string, dest *int32) *ValueBinder { + return b.intValue(sourceParam, dest, 32, true) +} + +// Int16 binds parameter to int16 variable +func (b *ValueBinder) Int16(sourceParam string, dest *int16) *ValueBinder { + return b.intValue(sourceParam, dest, 16, false) +} + +// MustInt16 requires parameter value to exist to bind to int16 variable. Returns error when value does not exist +func (b *ValueBinder) MustInt16(sourceParam string, dest *int16) *ValueBinder { + return b.intValue(sourceParam, dest, 16, true) +} + +// Int8 binds parameter to int8 variable +func (b *ValueBinder) Int8(sourceParam string, dest *int8) *ValueBinder { + return b.intValue(sourceParam, dest, 8, false) +} + +// MustInt8 requires parameter value to exist to bind to int8 variable. Returns error when value does not exist +func (b *ValueBinder) MustInt8(sourceParam string, dest *int8) *ValueBinder { + return b.intValue(sourceParam, dest, 8, true) +} + +// Int binds parameter to int variable +func (b *ValueBinder) Int(sourceParam string, dest *int) *ValueBinder { + return b.intValue(sourceParam, dest, 0, false) +} + +// MustInt requires parameter value to exist to bind to int variable. Returns error when value does not exist +func (b *ValueBinder) MustInt(sourceParam string, dest *int) *ValueBinder { + return b.intValue(sourceParam, dest, 0, true) +} + +func (b *ValueBinder) intValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + + return b.int(sourceParam, value, dest, bitSize) +} + +func (b *ValueBinder) int(sourceParam string, value string, dest any, bitSize int) *ValueBinder { + n, err := strconv.ParseInt(value, 10, bitSize) + if err != nil { + if bitSize == 0 { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to int", err)) + } else { + b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to int%v", bitSize), err)) + } + return b + } + + switch d := dest.(type) { + case *int64: + *d = n + case *int32: + *d = int32(n) // #nosec G115 + case *int16: + *d = int16(n) // #nosec G115 + case *int8: + *d = int8(n) // #nosec G115 + case *int: + *d = int(n) + } + return b +} + +func (b *ValueBinder) intsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + values := b.ValuesFunc(sourceParam) + if len(values) == 0 { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, values, "required field value is empty", nil)) + } + return b + } + return b.ints(sourceParam, values, dest) +} + +func (b *ValueBinder) ints(sourceParam string, values []string, dest any) *ValueBinder { + switch d := dest.(type) { + case *[]int64: + tmp := make([]int64, len(values)) + for i, v := range values { + b.int(sourceParam, v, &tmp[i], 64) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]int32: + tmp := make([]int32, len(values)) + for i, v := range values { + b.int(sourceParam, v, &tmp[i], 32) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]int16: + tmp := make([]int16, len(values)) + for i, v := range values { + b.int(sourceParam, v, &tmp[i], 16) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]int8: + tmp := make([]int8, len(values)) + for i, v := range values { + b.int(sourceParam, v, &tmp[i], 8) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]int: + tmp := make([]int, len(values)) + for i, v := range values { + b.int(sourceParam, v, &tmp[i], 0) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + } + return b +} + +// Int64s binds parameter to slice of int64 +func (b *ValueBinder) Int64s(sourceParam string, dest *[]int64) *ValueBinder { + return b.intsValue(sourceParam, dest, false) +} + +// MustInt64s requires parameter value to exist to bind to int64 slice variable. Returns error when value does not exist +func (b *ValueBinder) MustInt64s(sourceParam string, dest *[]int64) *ValueBinder { + return b.intsValue(sourceParam, dest, true) +} + +// Int32s binds parameter to slice of int32 +func (b *ValueBinder) Int32s(sourceParam string, dest *[]int32) *ValueBinder { + return b.intsValue(sourceParam, dest, false) +} + +// MustInt32s requires parameter value to exist to bind to int32 slice variable. Returns error when value does not exist +func (b *ValueBinder) MustInt32s(sourceParam string, dest *[]int32) *ValueBinder { + return b.intsValue(sourceParam, dest, true) +} + +// Int16s binds parameter to slice of int16 +func (b *ValueBinder) Int16s(sourceParam string, dest *[]int16) *ValueBinder { + return b.intsValue(sourceParam, dest, false) +} + +// MustInt16s requires parameter value to exist to bind to int16 slice variable. Returns error when value does not exist +func (b *ValueBinder) MustInt16s(sourceParam string, dest *[]int16) *ValueBinder { + return b.intsValue(sourceParam, dest, true) +} + +// Int8s binds parameter to slice of int8 +func (b *ValueBinder) Int8s(sourceParam string, dest *[]int8) *ValueBinder { + return b.intsValue(sourceParam, dest, false) +} + +// MustInt8s requires parameter value to exist to bind to int8 slice variable. Returns error when value does not exist +func (b *ValueBinder) MustInt8s(sourceParam string, dest *[]int8) *ValueBinder { + return b.intsValue(sourceParam, dest, true) +} + +// Ints binds parameter to slice of int +func (b *ValueBinder) Ints(sourceParam string, dest *[]int) *ValueBinder { + return b.intsValue(sourceParam, dest, false) +} + +// MustInts requires parameter value to exist to bind to int slice variable. Returns error when value does not exist +func (b *ValueBinder) MustInts(sourceParam string, dest *[]int) *ValueBinder { + return b.intsValue(sourceParam, dest, true) +} + +// Uint64 binds parameter to uint64 variable +func (b *ValueBinder) Uint64(sourceParam string, dest *uint64) *ValueBinder { + return b.uintValue(sourceParam, dest, 64, false) +} + +// MustUint64 requires parameter value to exist to bind to uint64 variable. Returns error when value does not exist +func (b *ValueBinder) MustUint64(sourceParam string, dest *uint64) *ValueBinder { + return b.uintValue(sourceParam, dest, 64, true) +} + +// Uint32 binds parameter to uint32 variable +func (b *ValueBinder) Uint32(sourceParam string, dest *uint32) *ValueBinder { + return b.uintValue(sourceParam, dest, 32, false) +} + +// MustUint32 requires parameter value to exist to bind to uint32 variable. Returns error when value does not exist +func (b *ValueBinder) MustUint32(sourceParam string, dest *uint32) *ValueBinder { + return b.uintValue(sourceParam, dest, 32, true) +} + +// Uint16 binds parameter to uint16 variable +func (b *ValueBinder) Uint16(sourceParam string, dest *uint16) *ValueBinder { + return b.uintValue(sourceParam, dest, 16, false) +} + +// MustUint16 requires parameter value to exist to bind to uint16 variable. Returns error when value does not exist +func (b *ValueBinder) MustUint16(sourceParam string, dest *uint16) *ValueBinder { + return b.uintValue(sourceParam, dest, 16, true) +} + +// Uint8 binds parameter to uint8 variable +func (b *ValueBinder) Uint8(sourceParam string, dest *uint8) *ValueBinder { + return b.uintValue(sourceParam, dest, 8, false) +} + +// MustUint8 requires parameter value to exist to bind to uint8 variable. Returns error when value does not exist +func (b *ValueBinder) MustUint8(sourceParam string, dest *uint8) *ValueBinder { + return b.uintValue(sourceParam, dest, 8, true) +} + +// Byte binds parameter to byte variable +func (b *ValueBinder) Byte(sourceParam string, dest *byte) *ValueBinder { + return b.uintValue(sourceParam, dest, 8, false) +} + +// MustByte requires parameter value to exist to bind to byte variable. Returns error when value does not exist +func (b *ValueBinder) MustByte(sourceParam string, dest *byte) *ValueBinder { + return b.uintValue(sourceParam, dest, 8, true) +} + +// Uint binds parameter to uint variable +func (b *ValueBinder) Uint(sourceParam string, dest *uint) *ValueBinder { + return b.uintValue(sourceParam, dest, 0, false) +} + +// MustUint requires parameter value to exist to bind to uint variable. Returns error when value does not exist +func (b *ValueBinder) MustUint(sourceParam string, dest *uint) *ValueBinder { + return b.uintValue(sourceParam, dest, 0, true) +} + +func (b *ValueBinder) uintValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + + return b.uint(sourceParam, value, dest, bitSize) +} + +func (b *ValueBinder) uint(sourceParam string, value string, dest any, bitSize int) *ValueBinder { + n, err := strconv.ParseUint(value, 10, bitSize) + if err != nil { + if bitSize == 0 { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to uint", err)) + } else { + b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to uint%v", bitSize), err)) + } + return b + } + + switch d := dest.(type) { + case *uint64: + *d = n + case *uint32: + *d = uint32(n) // #nosec G115 + case *uint16: + *d = uint16(n) // #nosec G115 + case *uint8: // byte is alias to uint8 + *d = uint8(n) // #nosec G115 + case *uint: + *d = uint(n) // #nosec G115 + } + return b +} + +func (b *ValueBinder) uintsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + values := b.ValuesFunc(sourceParam) + if len(values) == 0 { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, values, "required field value is empty", nil)) + } + return b + } + return b.uints(sourceParam, values, dest) +} + +func (b *ValueBinder) uints(sourceParam string, values []string, dest any) *ValueBinder { + switch d := dest.(type) { + case *[]uint64: + tmp := make([]uint64, len(values)) + for i, v := range values { + b.uint(sourceParam, v, &tmp[i], 64) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]uint32: + tmp := make([]uint32, len(values)) + for i, v := range values { + b.uint(sourceParam, v, &tmp[i], 32) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]uint16: + tmp := make([]uint16, len(values)) + for i, v := range values { + b.uint(sourceParam, v, &tmp[i], 16) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]uint8: // byte is alias to uint8 + tmp := make([]uint8, len(values)) + for i, v := range values { + b.uint(sourceParam, v, &tmp[i], 8) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]uint: + tmp := make([]uint, len(values)) + for i, v := range values { + b.uint(sourceParam, v, &tmp[i], 0) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + } + return b +} + +// Uint64s binds parameter to slice of uint64 +func (b *ValueBinder) Uint64s(sourceParam string, dest *[]uint64) *ValueBinder { + return b.uintsValue(sourceParam, dest, false) +} + +// MustUint64s requires parameter value to exist to bind to uint64 slice variable. Returns error when value does not exist +func (b *ValueBinder) MustUint64s(sourceParam string, dest *[]uint64) *ValueBinder { + return b.uintsValue(sourceParam, dest, true) +} + +// Uint32s binds parameter to slice of uint32 +func (b *ValueBinder) Uint32s(sourceParam string, dest *[]uint32) *ValueBinder { + return b.uintsValue(sourceParam, dest, false) +} + +// MustUint32s requires parameter value to exist to bind to uint32 slice variable. Returns error when value does not exist +func (b *ValueBinder) MustUint32s(sourceParam string, dest *[]uint32) *ValueBinder { + return b.uintsValue(sourceParam, dest, true) +} + +// Uint16s binds parameter to slice of uint16 +func (b *ValueBinder) Uint16s(sourceParam string, dest *[]uint16) *ValueBinder { + return b.uintsValue(sourceParam, dest, false) +} + +// MustUint16s requires parameter value to exist to bind to uint16 slice variable. Returns error when value does not exist +func (b *ValueBinder) MustUint16s(sourceParam string, dest *[]uint16) *ValueBinder { + return b.uintsValue(sourceParam, dest, true) +} + +// Uint8s binds parameter to slice of uint8 +func (b *ValueBinder) Uint8s(sourceParam string, dest *[]uint8) *ValueBinder { + return b.uintsValue(sourceParam, dest, false) +} + +// MustUint8s requires parameter value to exist to bind to uint8 slice variable. Returns error when value does not exist +func (b *ValueBinder) MustUint8s(sourceParam string, dest *[]uint8) *ValueBinder { + return b.uintsValue(sourceParam, dest, true) +} + +// Uints binds parameter to slice of uint +func (b *ValueBinder) Uints(sourceParam string, dest *[]uint) *ValueBinder { + return b.uintsValue(sourceParam, dest, false) +} + +// MustUints requires parameter value to exist to bind to uint slice variable. Returns error when value does not exist +func (b *ValueBinder) MustUints(sourceParam string, dest *[]uint) *ValueBinder { + return b.uintsValue(sourceParam, dest, true) +} + +// Bool binds parameter to bool variable +func (b *ValueBinder) Bool(sourceParam string, dest *bool) *ValueBinder { + return b.boolValue(sourceParam, dest, false) +} + +// MustBool requires parameter value to exist to bind to bool variable. Returns error when value does not exist +func (b *ValueBinder) MustBool(sourceParam string, dest *bool) *ValueBinder { + return b.boolValue(sourceParam, dest, true) +} + +func (b *ValueBinder) boolValue(sourceParam string, dest *bool, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + return b.bool(sourceParam, value, dest) +} + +func (b *ValueBinder) bool(sourceParam string, value string, dest *bool) *ValueBinder { + n, err := strconv.ParseBool(value) + if err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to bool", err)) + return b + } + + *dest = n + return b +} + +func (b *ValueBinder) boolsValue(sourceParam string, dest *[]bool, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + values := b.ValuesFunc(sourceParam) + if len(values) == 0 { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + return b.bools(sourceParam, values, dest) +} + +func (b *ValueBinder) bools(sourceParam string, values []string, dest *[]bool) *ValueBinder { + tmp := make([]bool, len(values)) + for i, v := range values { + b.bool(sourceParam, v, &tmp[i]) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *dest = tmp + } + return b +} + +// Bools binds parameter values to slice of bool variables +func (b *ValueBinder) Bools(sourceParam string, dest *[]bool) *ValueBinder { + return b.boolsValue(sourceParam, dest, false) +} + +// MustBools requires parameter values to exist to bind to slice of bool variables. Returns error when values does not exist +func (b *ValueBinder) MustBools(sourceParam string, dest *[]bool) *ValueBinder { + return b.boolsValue(sourceParam, dest, true) +} + +// Float64 binds parameter to float64 variable +func (b *ValueBinder) Float64(sourceParam string, dest *float64) *ValueBinder { + return b.floatValue(sourceParam, dest, 64, false) +} + +// MustFloat64 requires parameter value to exist to bind to float64 variable. Returns error when value does not exist +func (b *ValueBinder) MustFloat64(sourceParam string, dest *float64) *ValueBinder { + return b.floatValue(sourceParam, dest, 64, true) +} + +// Float32 binds parameter to float32 variable +func (b *ValueBinder) Float32(sourceParam string, dest *float32) *ValueBinder { + return b.floatValue(sourceParam, dest, 32, false) +} + +// MustFloat32 requires parameter value to exist to bind to float32 variable. Returns error when value does not exist +func (b *ValueBinder) MustFloat32(sourceParam string, dest *float32) *ValueBinder { + return b.floatValue(sourceParam, dest, 32, true) +} + +func (b *ValueBinder) floatValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + + return b.float(sourceParam, value, dest, bitSize) +} + +func (b *ValueBinder) float(sourceParam string, value string, dest any, bitSize int) *ValueBinder { + n, err := strconv.ParseFloat(value, bitSize) + if err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to float%v", bitSize), err)) + return b + } + + switch d := dest.(type) { + case *float64: + *d = n + case *float32: + *d = float32(n) + } + return b +} + +func (b *ValueBinder) floatsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + values := b.ValuesFunc(sourceParam) + if len(values) == 0 { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + return b.floats(sourceParam, values, dest) +} + +func (b *ValueBinder) floats(sourceParam string, values []string, dest any) *ValueBinder { + switch d := dest.(type) { + case *[]float64: + tmp := make([]float64, len(values)) + for i, v := range values { + b.float(sourceParam, v, &tmp[i], 64) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + case *[]float32: + tmp := make([]float32, len(values)) + for i, v := range values { + b.float(sourceParam, v, &tmp[i], 32) + if b.failFast && b.errors != nil { + return b + } + } + if b.errors == nil { + *d = tmp + } + } + return b +} + +// Float64s binds parameter values to slice of float64 variables +func (b *ValueBinder) Float64s(sourceParam string, dest *[]float64) *ValueBinder { + return b.floatsValue(sourceParam, dest, false) +} + +// MustFloat64s requires parameter values to exist to bind to slice of float64 variables. Returns error when values does not exist +func (b *ValueBinder) MustFloat64s(sourceParam string, dest *[]float64) *ValueBinder { + return b.floatsValue(sourceParam, dest, true) +} + +// Float32s binds parameter values to slice of float32 variables +func (b *ValueBinder) Float32s(sourceParam string, dest *[]float32) *ValueBinder { + return b.floatsValue(sourceParam, dest, false) +} + +// MustFloat32s requires parameter values to exist to bind to slice of float32 variables. Returns error when values does not exist +func (b *ValueBinder) MustFloat32s(sourceParam string, dest *[]float32) *ValueBinder { + return b.floatsValue(sourceParam, dest, true) +} + +// Time binds parameter to time.Time variable +func (b *ValueBinder) Time(sourceParam string, dest *time.Time, layout string) *ValueBinder { + return b.time(sourceParam, dest, layout, false) +} + +// MustTime requires parameter value to exist to bind to time.Time variable. Returns error when value does not exist +func (b *ValueBinder) MustTime(sourceParam string, dest *time.Time, layout string) *ValueBinder { + return b.time(sourceParam, dest, layout, true) +} + +func (b *ValueBinder) time(sourceParam string, dest *time.Time, layout string, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil)) + } + return b + } + t, err := time.Parse(layout, value) + if err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Time", err)) + return b + } + *dest = t + return b +} + +// Times binds parameter values to slice of time.Time variables +func (b *ValueBinder) Times(sourceParam string, dest *[]time.Time, layout string) *ValueBinder { + return b.times(sourceParam, dest, layout, false) +} + +// MustTimes requires parameter values to exist to bind to slice of time.Time variables. Returns error when values does not exist +func (b *ValueBinder) MustTimes(sourceParam string, dest *[]time.Time, layout string) *ValueBinder { + return b.times(sourceParam, dest, layout, true) +} + +func (b *ValueBinder) times(sourceParam string, dest *[]time.Time, layout string, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + values := b.ValuesFunc(sourceParam) + if len(values) == 0 { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + + tmp := make([]time.Time, len(values)) + for i, v := range values { + t, err := time.Parse(layout, v) + if err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{v}, "failed to bind field value to Time", err)) + if b.failFast { + return b + } + continue + } + tmp[i] = t + } + if b.errors == nil { + *dest = tmp + } + return b +} + +// Duration binds parameter to time.Duration variable +func (b *ValueBinder) Duration(sourceParam string, dest *time.Duration) *ValueBinder { + return b.duration(sourceParam, dest, false) +} + +// MustDuration requires parameter value to exist to bind to time.Duration variable. Returns error when value does not exist +func (b *ValueBinder) MustDuration(sourceParam string, dest *time.Duration) *ValueBinder { + return b.duration(sourceParam, dest, true) +} + +func (b *ValueBinder) duration(sourceParam string, dest *time.Duration, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil)) + } + return b + } + t, err := time.ParseDuration(value) + if err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Duration", err)) + return b + } + *dest = t + return b +} + +// Durations binds parameter values to slice of time.Duration variables +func (b *ValueBinder) Durations(sourceParam string, dest *[]time.Duration) *ValueBinder { + return b.durationsValue(sourceParam, dest, false) +} + +// MustDurations requires parameter values to exist to bind to slice of time.Duration variables. Returns error when values does not exist +func (b *ValueBinder) MustDurations(sourceParam string, dest *[]time.Duration) *ValueBinder { + return b.durationsValue(sourceParam, dest, true) +} + +func (b *ValueBinder) durationsValue(sourceParam string, dest *[]time.Duration, valueMustExist bool) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + values := b.ValuesFunc(sourceParam) + if len(values) == 0 { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{}, "required field value is empty", nil)) + } + return b + } + return b.durations(sourceParam, values, dest) +} + +func (b *ValueBinder) durations(sourceParam string, values []string, dest *[]time.Duration) *ValueBinder { + tmp := make([]time.Duration, len(values)) + for i, v := range values { + t, err := time.ParseDuration(v) + if err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{v}, "failed to bind field value to Duration", err)) + if b.failFast { + return b + } + continue + } + tmp[i] = t + } + if b.errors == nil { + *dest = tmp + } + return b +} + +// UnixTime binds parameter to time.Time variable (in local Time corresponding to the given Unix time). +// +// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 +// +// Note: +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +func (b *ValueBinder) UnixTime(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, false, time.Second) +} + +// MustUnixTime requires parameter value to exist to bind to time.Duration variable (in local time corresponding +// to the given Unix time). Returns error when value does not exist. +// +// Example: 1609180603 bind to 2020-12-28T18:36:43.000000000+00:00 +// +// Note: +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +func (b *ValueBinder) MustUnixTime(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, true, time.Second) +} + +// UnixTimeMilli binds parameter to time.Time variable (in local time corresponding to the given Unix time in millisecond precision). +// +// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 +// +// Note: +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +func (b *ValueBinder) UnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, false, time.Millisecond) +} + +// MustUnixTimeMilli requires parameter value to exist to bind to time.Duration variable (in local time corresponding +// to the given Unix time in millisecond precision). Returns error when value does not exist. +// +// Example: 1647184410140 bind to 2022-03-13T15:13:30.140000000+00:00 +// +// Note: +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +func (b *ValueBinder) MustUnixTimeMilli(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, true, time.Millisecond) +} + +// UnixTimeNano binds parameter to time.Time variable (in local time corresponding to the given Unix time in nanosecond precision). +// +// Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 +// Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00 +// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00 +// +// Note: +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. +func (b *ValueBinder) UnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, false, time.Nanosecond) +} + +// MustUnixTimeNano requires parameter value to exist to bind to time.Duration variable (in local Time corresponding +// to the given Unix time value in nano second precision). Returns error when value does not exist. +// +// Example: 1609180603123456789 binds to 2020-12-28T18:36:43.123456789+00:00 +// Example: 1000000000 binds to 1970-01-01T00:00:01.000000000+00:00 +// Example: 999999999 binds to 1970-01-01T00:00:00.999999999+00:00 +// +// Note: +// - time.Time{} (param is empty) and time.Unix(0,0) (param = "0") are not equal +// - Javascript's Number type only has about 53 bits of precision (Number.MAX_SAFE_INTEGER = 9007199254740991). Compare it to 1609180603123456789 in example. +func (b *ValueBinder) MustUnixTimeNano(sourceParam string, dest *time.Time) *ValueBinder { + return b.unixTime(sourceParam, dest, true, time.Nanosecond) +} + +func (b *ValueBinder) unixTime(sourceParam string, dest *time.Time, valueMustExist bool, precision time.Duration) *ValueBinder { + if b.failFast && b.errors != nil { + return b + } + + value := b.ValueFunc(sourceParam) + if value == "" { + if valueMustExist { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "required field value is empty", nil)) + } + return b + } + + n, err := strconv.ParseInt(value, 10, 64) + if err != nil { + b.setError(b.ErrorFunc(sourceParam, []string{value}, "failed to bind field value to Time", err)) + return b + } + + switch precision { + case time.Second: + *dest = time.Unix(n, 0) + case time.Millisecond: + *dest = time.UnixMilli(n) + case time.Nanosecond: + *dest = time.Unix(0, n) + } + return b +} diff --git a/binder_external_test.go b/binder_external_test.go new file mode 100644 index 000000000..d83c891b3 --- /dev/null +++ b/binder_external_test.go @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +// run tests as external package to get real feel for API +package echo_test + +import ( + "encoding/base64" + "fmt" + "log" + "net/http" + "net/http/httptest" + + "github.com/labstack/echo/v5" +) + +func ExampleValueBinder_BindErrors() { + // example route function that binds query params to different destinations and returns all bind errors in one go + routeFunc := func(c *echo.Context) error { + var opts struct { + IDs []int64 + Active bool + } + length := int64(50) // default length is 50 + + b := echo.QueryParamsBinder(c) + + errs := b.Int64("length", &length). + Int64s("ids", &opts.IDs). + Bool("active", &opts.Active). + BindErrors() // returns all errors + if errs != nil { + for _, err := range errs { + bErr := err.(*echo.BindingError) + log.Printf("in case you want to access what field: %s values: %v failed", bErr.Field, bErr.Values) + } + return fmt.Errorf("%v fields failed to bind", len(errs)) + } + fmt.Printf("active = %v, length = %v, ids = %v", opts.Active, length, opts.IDs) + + return c.JSON(http.StatusOK, opts) + } + + e := echo.New() + c := e.NewContext( + httptest.NewRequest(http.MethodGet, "/api/endpoint?active=true&length=25&ids=1&ids=2&ids=3", nil), + httptest.NewRecorder(), + ) + + _ = routeFunc(c) + + // Output: active = true, length = 25, ids = [1 2 3] +} + +func ExampleValueBinder_BindError() { + // example route function that binds query params to different destinations and stops binding on first bind error + failFastRouteFunc := func(c *echo.Context) error { + var opts struct { + IDs []int64 + Active bool + } + length := int64(50) // default length is 50 + + // create binder that stops binding at first error + b := echo.QueryParamsBinder(c) + + err := b.Int64("length", &length). + Int64s("ids", &opts.IDs). + Bool("active", &opts.Active). + BindError() // returns first binding error + if err != nil { + bErr := err.(*echo.BindingError) + return fmt.Errorf("my own custom error for field: %s values: %v", bErr.Field, bErr.Values) + } + fmt.Printf("active = %v, length = %v, ids = %v\n", opts.Active, length, opts.IDs) + + return c.JSON(http.StatusOK, opts) + } + + e := echo.New() + c := e.NewContext( + httptest.NewRequest(http.MethodGet, "/api/endpoint?active=true&length=25&ids=1&ids=2&ids=3", nil), + httptest.NewRecorder(), + ) + + _ = failFastRouteFunc(c) + + // Output: active = true, length = 25, ids = [1 2 3] +} + +func ExampleValueBinder_CustomFunc() { + // example route function that binds query params using custom function closure + routeFunc := func(c *echo.Context) error { + length := int64(50) // default length is 50 + var binary []byte + + b := echo.QueryParamsBinder(c) + errs := b.Int64("length", &length). + CustomFunc("base64", func(values []string) []error { + if len(values) == 0 { + return nil + } + decoded, err := base64.URLEncoding.DecodeString(values[0]) + if err != nil { + // in this example we use only first param value but url could contain multiple params in reality and + // therefore in theory produce multiple binding errors + return []error{echo.NewBindingError("base64", values[0:1], "failed to decode base64", err)} + } + binary = decoded + return nil + }). + BindErrors() // returns all errors + + if errs != nil { + for _, err := range errs { + bErr := err.(*echo.BindingError) + log.Printf("in case you want to access what field: %s values: %v failed", bErr.Field, bErr.Values) + } + return fmt.Errorf("%v fields failed to bind", len(errs)) + } + fmt.Printf("length = %v, base64 = %s", length, binary) + + return c.JSON(http.StatusOK, "ok") + } + + e := echo.New() + c := e.NewContext( + httptest.NewRequest(http.MethodGet, "/api/endpoint?length=25&base64=SGVsbG8gV29ybGQ%3D", nil), + httptest.NewRecorder(), + ) + _ = routeFunc(c) + + // Output: length = 25, base64 = Hello World +} diff --git a/binder_generic.go b/binder_generic.go new file mode 100644 index 000000000..0c0eb9089 --- /dev/null +++ b/binder_generic.go @@ -0,0 +1,563 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "encoding" + "encoding/json" + "fmt" + "strconv" + "time" +) + +// TimeLayout specifies the format for parsing time values in request parameters. +// It can be a standard Go time layout string or one of the special Unix time layouts. +type TimeLayout string + +// TimeOpts is options for parsing time.Time values +type TimeOpts struct { + // Layout specifies the format for parsing time values in request parameters. + // It can be a standard Go time layout string or one of the special Unix time layouts. + // + // Parsing layout defaults to: echo.TimeLayout(time.RFC3339Nano) + // - To convert to custom layout use `echo.TimeLayout("2006-01-02")` + // - To convert unix timestamp (integer) to time.Time use `echo.TimeLayoutUnixTime` + // - To convert unix timestamp in milliseconds to time.Time use `echo.TimeLayoutUnixTimeMilli` + // - To convert unix timestamp in nanoseconds to time.Time use `echo.TimeLayoutUnixTimeNano` + Layout TimeLayout + + // ParseInLocation is location used with time.ParseInLocation for layout that do not contain + // timezone information to set output time in given location. + // Defaults to time.UTC + ParseInLocation *time.Location + + // ToInLocation is location to which parsed time is converted to after parsing. + // The parsed time will be converted using time.In(ToInLocation). + // Defaults to time.UTC + ToInLocation *time.Location +} + +// TimeLayout constants for parsing Unix timestamps in different precisions. +const ( + TimeLayoutUnixTime = TimeLayout("UnixTime") // Unix timestamp in seconds + TimeLayoutUnixTimeMilli = TimeLayout("UnixTimeMilli") // Unix timestamp in milliseconds + TimeLayoutUnixTimeNano = TimeLayout("UnixTimeNano") // Unix timestamp in nanoseconds +) + +// PathParam extracts and parses a path parameter from the context by name. +// It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found. +// +// Empty String Handling: +// If the parameter exists but has an empty value, the zero value of type T is returned +// with no error. For example, a path parameter with value "" returns (0, nil) for int types. +// This differs from standard library behavior where parsing empty strings returns errors. +// To treat empty values as errors, validate the result separately or check the raw value. +// +// See ParseValue for supported types and options +func PathParam[T any](c *Context, paramName string, opts ...any) (T, error) { + for _, pv := range c.PathValues() { + if pv.Name == paramName { + v, err := ParseValue[T](pv.Value, opts...) + if err != nil { + return v, NewBindingError(paramName, []string{pv.Value}, "path value", err) + } + return v, nil + } + } + var zero T + return zero, ErrNonExistentKey +} + +// PathParamOr extracts and parses a path parameter from the context by name. +// Returns defaultValue if the parameter is not found or has an empty value. +// Returns an error only if parsing fails (e.g., "abc" for int type). +// +// Example: +// id, err := echo.PathParamOr[int](c, "id", 0) +// // If "id" is missing: returns (0, nil) +// // If "id" is "123": returns (123, nil) +// // If "id" is "abc": returns (0, BindingError) +// +// See ParseValue for supported types and options +func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...any) (T, error) { + for _, pv := range c.PathValues() { + if pv.Name == paramName { + v, err := ParseValueOr[T](pv.Value, defaultValue, opts...) + if err != nil { + return v, NewBindingError(paramName, []string{pv.Value}, "path value", err) + } + return v, nil + } + } + return defaultValue, nil +} + +// QueryParam extracts and parses a single query parameter from the request by key. +// It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found. +// +// Empty String Handling: +// If the parameter exists but has an empty value (?key=), the zero value of type T is returned +// with no error. For example, "?count=" returns (0, nil) for int types. +// This differs from standard library behavior where parsing empty strings returns errors. +// To treat empty values as errors, validate the result separately or check the raw value. +// +// Behavior Summary: +// - Missing key (?other=value): returns (zero, ErrNonExistentKey) +// - Empty value (?key=): returns (zero, nil) +// - Invalid value (?key=abc for int): returns (zero, BindingError) +// +// See ParseValue for supported types and options +func QueryParam[T any](c *Context, key string, opts ...any) (T, error) { + values, ok := c.QueryParams()[key] + if !ok { + var zero T + return zero, ErrNonExistentKey + } + if len(values) == 0 { + var zero T + return zero, nil + } + value := values[0] + v, err := ParseValue[T](value, opts...) + if err != nil { + return v, NewBindingError(key, []string{value}, "query param", err) + } + return v, nil +} + +// QueryParamOr extracts and parses a single query parameter from the request by key. +// Returns defaultValue if the parameter is not found or has an empty value. +// Returns an error only if parsing fails (e.g., "abc" for int type). +// +// Example: +// page, err := echo.QueryParamOr[int](c, "page", 1) +// // If "page" is missing: returns (1, nil) +// // If "page" is "5": returns (5, nil) +// // If "page" is "abc": returns (1, BindingError) +// +// See ParseValue for supported types and options +func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) { + values, ok := c.QueryParams()[key] + if !ok { + return defaultValue, nil + } + if len(values) == 0 { + return defaultValue, nil + } + value := values[0] + v, err := ParseValueOr[T](value, defaultValue, opts...) + if err != nil { + return v, NewBindingError(key, []string{value}, "query param", err) + } + return v, nil +} + +// QueryParams extracts and parses all values for a query parameter key as a slice. +// It returns the typed slice and an error if binding any value fails. Returns ErrNonExistentKey if parameter not found. +// +// See ParseValues for supported types and options +func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error) { + values, ok := c.QueryParams()[key] + if !ok { + return nil, ErrNonExistentKey + } + + result, err := ParseValues[T](values, opts...) + if err != nil { + return nil, NewBindingError(key, values, "query params", err) + } + return result, nil +} + +// QueryParamsOr extracts and parses all values for a query parameter key as a slice. +// Returns defaultValue if the parameter is not found. +// Returns an error only if parsing any value fails. +// +// Example: +// ids, err := echo.QueryParamsOr[int](c, "ids", []int{}) +// // If "ids" is missing: returns ([], nil) +// // If "ids" is "1&ids=2": returns ([1, 2], nil) +// // If "ids" contains "abc": returns ([], BindingError) +// +// See ParseValues for supported types and options +func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) { + values, ok := c.QueryParams()[key] + if !ok { + return defaultValue, nil + } + + result, err := ParseValuesOr[T](values, defaultValue, opts...) + if err != nil { + return nil, NewBindingError(key, values, "query params", err) + } + return result, nil +} + +// FormValue extracts and parses a single form value from the request by key. +// It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found. +// +// Empty String Handling: +// If the form field exists but has an empty value, the zero value of type T is returned +// with no error. For example, an empty form field returns (0, nil) for int types. +// This differs from standard library behavior where parsing empty strings returns errors. +// To treat empty values as errors, validate the result separately or check the raw value. +// +// See ParseValue for supported types and options +func FormValue[T any](c *Context, key string, opts ...any) (T, error) { + formValues, err := c.FormValues() + if err != nil { + var zero T + return zero, fmt.Errorf("failed to parse form value, key: %s, err: %w", key, err) + } + values, ok := formValues[key] + if !ok { + var zero T + return zero, ErrNonExistentKey + } + if len(values) == 0 { + var zero T + return zero, nil + } + value := values[0] + v, err := ParseValue[T](value, opts...) + if err != nil { + return v, NewBindingError(key, []string{value}, "form value", err) + } + return v, nil +} + +// FormValueOr extracts and parses a single form value from the request by key. +// Returns defaultValue if the parameter is not found or has an empty value. +// Returns an error only if parsing fails or form parsing errors occur. +// +// Example: +// limit, err := echo.FormValueOr[int](c, "limit", 100) +// // If "limit" is missing: returns (100, nil) +// // If "limit" is "50": returns (50, nil) +// // If "limit" is "abc": returns (100, BindingError) +// +// See ParseValue for supported types and options +func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) { + formValues, err := c.FormValues() + if err != nil { + var zero T + return zero, fmt.Errorf("failed to parse form value, key: %s, err: %w", key, err) + } + values, ok := formValues[key] + if !ok { + return defaultValue, nil + } + if len(values) == 0 { + return defaultValue, nil + } + value := values[0] + v, err := ParseValueOr[T](value, defaultValue, opts...) + if err != nil { + return v, NewBindingError(key, []string{value}, "form value", err) + } + return v, nil +} + +// FormValues extracts and parses all values for a form values key as a slice. +// It returns the typed slice and an error if binding any value fails. Returns ErrNonExistentKey if parameter not found. +// +// See ParseValues for supported types and options +func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) { + formValues, err := c.FormValues() + if err != nil { + return nil, fmt.Errorf("failed to parse form values, key: %s, err: %w", key, err) + } + values, ok := formValues[key] + if !ok { + return nil, ErrNonExistentKey + } + result, err := ParseValues[T](values, opts...) + if err != nil { + return nil, NewBindingError(key, values, "form values", err) + } + return result, nil +} + +// FormValuesOr extracts and parses all values for a form values key as a slice. +// Returns defaultValue if the parameter is not found. +// Returns an error only if parsing any value fails or form parsing errors occur. +// +// Example: +// tags, err := echo.FormValuesOr[string](c, "tags", []string{}) +// // If "tags" is missing: returns ([], nil) +// // If form parsing fails: returns (nil, error) +// +// See ParseValues for supported types and options +func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) { + formValues, err := c.FormValues() + if err != nil { + return nil, fmt.Errorf("failed to parse form values, key: %s, err: %w", key, err) + } + values, ok := formValues[key] + if !ok { + return defaultValue, nil + } + result, err := ParseValuesOr[T](values, defaultValue, opts...) + if err != nil { + return nil, NewBindingError(key, values, "form values", err) + } + return result, nil +} + +// ParseValues parses value to generic type slice. Same types are supported as ParseValue +// function but the result type is slice instead of scalar value. +// +// See ParseValue for supported types and options +func ParseValues[T any](values []string, opts ...any) ([]T, error) { + var zero []T + return ParseValuesOr(values, zero, opts...) +} + +// ParseValuesOr parses value to generic type slice, when value is empty defaultValue is returned. +// Same types are supported as ParseValue function but the result type is slice instead of scalar value. +// +// See ParseValue for supported types and options +func ParseValuesOr[T any](values []string, defaultValue []T, opts ...any) ([]T, error) { + if len(values) == 0 { + return defaultValue, nil + } + result := make([]T, 0, len(values)) + for _, v := range values { + tmp, err := ParseValue[T](v, opts...) + if err != nil { + return nil, err + } + result = append(result, tmp) + } + return result, nil +} + +// ParseValue parses value to generic type +// +// Types that are supported: +// - bool +// - float32 +// - float64 +// - int +// - int8 +// - int16 +// - int32 +// - int64 +// - uint +// - uint8/byte +// - uint16 +// - uint32 +// - uint64 +// - string +// - echo.BindUnmarshaler interface +// - encoding.TextUnmarshaler interface +// - json.Unmarshaler interface +// - time.Duration +// - time.Time use echo.TimeOpts or echo.TimeLayout to set time parsing configuration +func ParseValue[T any](value string, opts ...any) (T, error) { + var zero T + return ParseValueOr(value, zero, opts...) +} + +// ParseValueOr parses value to generic type, when value is empty defaultValue is returned. +// +// Types that are supported: +// - bool +// - float32 +// - float64 +// - int +// - int8 +// - int16 +// - int32 +// - int64 +// - uint +// - uint8/byte +// - uint16 +// - uint32 +// - uint64 +// - string +// - echo.BindUnmarshaler interface +// - encoding.TextUnmarshaler interface +// - json.Unmarshaler interface +// - time.Duration +// - time.Time use echo.TimeOpts or echo.TimeLayout to set time parsing configuration +func ParseValueOr[T any](value string, defaultValue T, opts ...any) (T, error) { + if len(value) == 0 { + return defaultValue, nil + } + var tmp T + if err := bindValue(value, &tmp, opts...); err != nil { + var zero T + return zero, fmt.Errorf("failed to parse value, err: %w", err) + } + return tmp, nil +} + +func bindValue(value string, dest any, opts ...any) error { + // NOTE: if this function is ever made public the dest should be checked for nil + // values when dealing with interfaces + if len(opts) > 0 { + if _, isTime := dest.(*time.Time); !isTime { + return fmt.Errorf("options are only supported for time.Time, got %T", dest) + } + } + + switch d := dest.(type) { + case *bool: + n, err := strconv.ParseBool(value) + if err != nil { + return err + } + *d = n + case *float32: + n, err := strconv.ParseFloat(value, 32) + if err != nil { + return err + } + *d = float32(n) + case *float64: + n, err := strconv.ParseFloat(value, 64) + if err != nil { + return err + } + *d = n + case *int: + n, err := strconv.ParseInt(value, 10, 0) + if err != nil { + return err + } + *d = int(n) + case *int8: + n, err := strconv.ParseInt(value, 10, 8) + if err != nil { + return err + } + *d = int8(n) + case *int16: + n, err := strconv.ParseInt(value, 10, 16) + if err != nil { + return err + } + *d = int16(n) + case *int32: + n, err := strconv.ParseInt(value, 10, 32) + if err != nil { + return err + } + *d = int32(n) + case *int64: + n, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + *d = n + case *uint: + n, err := strconv.ParseUint(value, 10, 0) + if err != nil { + return err + } + *d = uint(n) + case *uint8: + n, err := strconv.ParseUint(value, 10, 8) + if err != nil { + return err + } + *d = uint8(n) + case *uint16: + n, err := strconv.ParseUint(value, 10, 16) + if err != nil { + return err + } + *d = uint16(n) + case *uint32: + n, err := strconv.ParseUint(value, 10, 32) + if err != nil { + return err + } + *d = uint32(n) + case *uint64: + n, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return err + } + *d = n + case *string: + *d = value + case *time.Duration: + t, err := time.ParseDuration(value) + if err != nil { + return err + } + *d = t + case *time.Time: + to := TimeOpts{ + Layout: TimeLayout(time.RFC3339Nano), + ParseInLocation: time.UTC, + ToInLocation: time.UTC, + } + for _, o := range opts { + switch v := o.(type) { + case TimeOpts: + if v.Layout != "" { + to.Layout = v.Layout + } + if v.ParseInLocation != nil { + to.ParseInLocation = v.ParseInLocation + } + if v.ToInLocation != nil { + to.ToInLocation = v.ToInLocation + } + case TimeLayout: + to.Layout = v + } + } + var t time.Time + var err error + switch to.Layout { + case TimeLayoutUnixTime: + n, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + t = time.Unix(n, 0) + case TimeLayoutUnixTimeMilli: + n, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + t = time.UnixMilli(n) + case TimeLayoutUnixTimeNano: + n, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return err + } + t = time.Unix(0, n) + default: + if to.ParseInLocation != nil { + t, err = time.ParseInLocation(string(to.Layout), value, to.ParseInLocation) + } else { + t, err = time.Parse(string(to.Layout), value) + } + if err != nil { + return err + } + } + *d = t.In(to.ToInLocation) + case BindUnmarshaler: + if err := d.UnmarshalParam(value); err != nil { + return err + } + case encoding.TextUnmarshaler: + if err := d.UnmarshalText([]byte(value)); err != nil { + return err + } + case json.Unmarshaler: + if err := d.UnmarshalJSON([]byte(value)); err != nil { + return err + } + default: + return fmt.Errorf("unsupported value type: %T", dest) + } + return nil +} diff --git a/binder_generic_test.go b/binder_generic_test.go new file mode 100644 index 000000000..849d75962 --- /dev/null +++ b/binder_generic_test.go @@ -0,0 +1,1616 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "cmp" + "encoding/json" + "fmt" + "math" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// TextUnmarshalerType implements encoding.TextUnmarshaler but NOT BindUnmarshaler +type TextUnmarshalerType struct { + Value string +} + +func (t *TextUnmarshalerType) UnmarshalText(data []byte) error { + s := string(data) + if s == "invalid" { + return fmt.Errorf("invalid value: %s", s) + } + t.Value = strings.ToUpper(s) + return nil +} + +// JSONUnmarshalerType implements json.Unmarshaler but NOT BindUnmarshaler or TextUnmarshaler +type JSONUnmarshalerType struct { + Value string +} + +func (j *JSONUnmarshalerType) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &j.Value) +} + +func TestPathParam(t *testing.T) { + var testCases = []struct { + name string + givenKey string + givenValue string + expect bool + expectErr string + }{ + { + name: "ok", + givenValue: "true", + expect: true, + }, + { + name: "nok, non existent key", + givenKey: "missing", + givenValue: "true", + expect: false, + expectErr: ErrNonExistentKey.Error(), + }, + { + name: "nok, invalid value", + givenValue: "can_parse_me", + expect: false, + expectErr: `code=400, message=path value, err=failed to parse value, err: strconv.ParseBool: parsing "can_parse_me": invalid syntax, field=key`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := NewContext(nil, nil) + c.SetPathValues(PathValues{{ + Name: cmp.Or(tc.givenKey, "key"), + Value: tc.givenValue, + }}) + + v, err := PathParam[bool](c, "key") + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestPathParam_UnsupportedType(t *testing.T) { + c := NewContext(nil, nil) + c.SetPathValues(PathValues{{Name: "key", Value: "true"}}) + + v, err := PathParam[[]bool](c, "key") + + expectErr := "code=400, message=path value, err=failed to parse value, err: unsupported value type: *[]bool, field=key" + assert.EqualError(t, err, expectErr) + assert.Equal(t, []bool(nil), v) +} + +func TestQueryParam(t *testing.T) { + var testCases = []struct { + name string + givenURL string + expect bool + expectErr string + }{ + { + name: "ok", + givenURL: "/?key=true", + expect: true, + }, + { + name: "nok, non existent key", + givenURL: "/?different=true", + expect: false, + expectErr: ErrNonExistentKey.Error(), + }, + { + name: "nok, invalid value", + givenURL: "/?key=invalidbool", + expect: false, + expectErr: `code=400, message=query param, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) + c := NewContext(req, nil) + + v, err := QueryParam[bool](c, "key") + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestQueryParam_UnsupportedType(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil) + c := NewContext(req, nil) + + v, err := QueryParam[[]bool](c, "key") + + expectErr := "code=400, message=query param, err=failed to parse value, err: unsupported value type: *[]bool, field=key" + assert.EqualError(t, err, expectErr) + assert.Equal(t, []bool(nil), v) +} + +func TestQueryParams(t *testing.T) { + var testCases = []struct { + name string + givenURL string + expect []bool + expectErr string + }{ + { + name: "ok", + givenURL: "/?key=true&key=false", + expect: []bool{true, false}, + }, + { + name: "nok, non existent key", + givenURL: "/?different=true", + expect: []bool(nil), + expectErr: ErrNonExistentKey.Error(), + }, + { + name: "nok, invalid value", + givenURL: "/?key=true&key=invalidbool", + expect: []bool(nil), + expectErr: `code=400, message=query params, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) + c := NewContext(req, nil) + + v, err := QueryParams[bool](c, "key") + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestQueryParams_UnsupportedType(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil) + c := NewContext(req, nil) + + v, err := QueryParams[[]bool](c, "key") + + expectErr := "code=400, message=query params, err=failed to parse value, err: unsupported value type: *[]bool, field=key" + assert.EqualError(t, err, expectErr) + assert.Equal(t, [][]bool(nil), v) +} + +func TestFormValue(t *testing.T) { + var testCases = []struct { + name string + givenURL string + expect bool + expectErr string + }{ + { + name: "ok", + givenURL: "/?key=true", + expect: true, + }, + { + name: "nok, non existent key", + givenURL: "/?different=true", + expect: false, + expectErr: ErrNonExistentKey.Error(), + }, + { + name: "nok, invalid value", + givenURL: "/?key=invalidbool", + expect: false, + expectErr: `code=400, message=form value, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) + c := NewContext(req, nil) + + v, err := FormValue[bool](c, "key") + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestFormValue_UnsupportedType(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil) + c := NewContext(req, nil) + + v, err := FormValue[[]bool](c, "key") + + expectErr := "code=400, message=form value, err=failed to parse value, err: unsupported value type: *[]bool, field=key" + assert.EqualError(t, err, expectErr) + assert.Equal(t, []bool(nil), v) +} + +func TestFormValues(t *testing.T) { + var testCases = []struct { + name string + givenURL string + expect []bool + expectErr string + }{ + { + name: "ok", + givenURL: "/?key=true&key=false", + expect: []bool{true, false}, + }, + { + name: "nok, non existent key", + givenURL: "/?different=true", + expect: []bool(nil), + expectErr: ErrNonExistentKey.Error(), + }, + { + name: "nok, invalid value", + givenURL: "/?key=true&key=invalidbool", + expect: []bool(nil), + expectErr: `code=400, message=form values, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) + c := NewContext(req, nil) + + v, err := FormValues[bool](c, "key") + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestFormValues_UnsupportedType(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil) + c := NewContext(req, nil) + + v, err := FormValues[[]bool](c, "key") + + expectErr := "code=400, message=form values, err=failed to parse value, err: unsupported value type: *[]bool, field=key" + assert.EqualError(t, err, expectErr) + assert.Equal(t, [][]bool(nil), v) +} + +func TestParseValue_bool(t *testing.T) { + var testCases = []struct { + name string + when string + expect bool + expectErr error + }{ + { + name: "ok, true", + when: "true", + expect: true, + }, + { + name: "ok, false", + when: "false", + expect: false, + }, + { + name: "ok, 1", + when: "1", + expect: true, + }, + { + name: "ok, 0", + when: "0", + expect: false, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[bool](tc.when) + if tc.expectErr != nil { + assert.ErrorIs(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValue_float32(t *testing.T) { + var testCases = []struct { + name string + when string + expect float32 + expectErr string + }{ + { + name: "ok, 123.345", + when: "123.345", + expect: 123.345, + }, + { + name: "ok, 0", + when: "0", + expect: 0, + }, + { + name: "ok, Inf", + when: "+Inf", + expect: float32(math.Inf(1)), + }, + { + name: "ok, Inf", + when: "-Inf", + expect: float32(math.Inf(-1)), + }, + { + name: "ok, NaN", + when: "NaN", + expect: float32(math.NaN()), + }, + { + name: "ok, invalid value", + when: "X", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseFloat: parsing "X": invalid syntax`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[float32](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + if math.IsNaN(float64(tc.expect)) { + if !math.IsNaN(float64(v)) { + t.Fatal("expected NaN but got non NaN") + } + } else { + assert.Equal(t, tc.expect, v) + } + }) + } +} + +func TestParseValue_float64(t *testing.T) { + var testCases = []struct { + name string + when string + expect float64 + expectErr string + }{ + { + name: "ok, 123.345", + when: "123.345", + expect: 123.345, + }, + { + name: "ok, 0", + when: "0", + expect: 0, + }, + { + name: "ok, Inf", + when: "+Inf", + expect: math.Inf(1), + }, + { + name: "ok, Inf", + when: "-Inf", + expect: math.Inf(-1), + }, + { + name: "ok, NaN", + when: "NaN", + expect: math.NaN(), + }, + { + name: "ok, invalid value", + when: "X", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseFloat: parsing "X": invalid syntax`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[float64](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + if math.IsNaN(tc.expect) { + if !math.IsNaN(v) { + t.Fatal("expected NaN but got non NaN") + } + } else { + assert.Equal(t, tc.expect, v) + } + }) + } +} + +func TestParseValue_int(t *testing.T) { + var testCases = []struct { + name string + when string + expect int + expectErr string + }{ + { + name: "ok, 0", + when: "0", + expect: 0, + }, + { + name: "ok, 1", + when: "1", + expect: 1, + }, + { + name: "ok, -1", + when: "-1", + expect: -1, + }, + { + name: "ok, max int (64bit)", + when: "9223372036854775807", + expect: 9223372036854775807, + }, + { + name: "ok, min int (64bit)", + when: "-9223372036854775808", + expect: -9223372036854775808, + }, + { + name: "ok, overflow max int (64bit)", + when: "9223372036854775808", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "9223372036854775808": value out of range`, + }, + { + name: "ok, underflow min int (64bit)", + when: "-9223372036854775809", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-9223372036854775809": value out of range`, + }, + { + name: "ok, invalid value", + when: "X", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[int](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValue_uint(t *testing.T) { + var testCases = []struct { + name string + when string + expect uint + expectErr string + }{ + { + name: "ok, 0", + when: "0", + expect: 0, + }, + { + name: "ok, 1", + when: "1", + expect: 1, + }, + { + name: "ok, max uint (64bit)", + when: "18446744073709551615", + expect: 18446744073709551615, + }, + { + name: "nok, overflow max uint (64bit)", + when: "18446744073709551616", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseUint: parsing "18446744073709551616": value out of range`, + }, + { + name: "nok, negative value", + when: "-1", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`, + }, + { + name: "nok, invalid value", + when: "X", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[uint](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValue_int8(t *testing.T) { + var testCases = []struct { + name string + when string + expect int8 + expectErr string + }{ + { + name: "ok, 0", + when: "0", + expect: 0, + }, + { + name: "ok, 1", + when: "1", + expect: 1, + }, + { + name: "ok, -1", + when: "-1", + expect: -1, + }, + { + name: "ok, max int8", + when: "127", + expect: 127, + }, + { + name: "ok, min int8", + when: "-128", + expect: -128, + }, + { + name: "nok, overflow max int8", + when: "128", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "128": value out of range`, + }, + { + name: "nok, underflow min int8", + when: "-129", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-129": value out of range`, + }, + { + name: "nok, invalid value", + when: "X", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[int8](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValue_int16(t *testing.T) { + var testCases = []struct { + name string + when string + expect int16 + expectErr string + }{ + { + name: "ok, 0", + when: "0", + expect: 0, + }, + { + name: "ok, 1", + when: "1", + expect: 1, + }, + { + name: "ok, -1", + when: "-1", + expect: -1, + }, + { + name: "ok, max int16", + when: "32767", + expect: 32767, + }, + { + name: "ok, min int16", + when: "-32768", + expect: -32768, + }, + { + name: "nok, overflow max int16", + when: "32768", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "32768": value out of range`, + }, + { + name: "nok, underflow min int16", + when: "-32769", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-32769": value out of range`, + }, + { + name: "nok, invalid value", + when: "X", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[int16](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValue_int32(t *testing.T) { + var testCases = []struct { + name string + when string + expect int32 + expectErr string + }{ + { + name: "ok, 0", + when: "0", + expect: 0, + }, + { + name: "ok, 1", + when: "1", + expect: 1, + }, + { + name: "ok, -1", + when: "-1", + expect: -1, + }, + { + name: "ok, max int32", + when: "2147483647", + expect: 2147483647, + }, + { + name: "ok, min int32", + when: "-2147483648", + expect: -2147483648, + }, + { + name: "nok, overflow max int32", + when: "2147483648", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "2147483648": value out of range`, + }, + { + name: "nok, underflow min int32", + when: "-2147483649", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-2147483649": value out of range`, + }, + { + name: "nok, invalid value", + when: "X", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[int32](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValue_int64(t *testing.T) { + var testCases = []struct { + name string + when string + expect int64 + expectErr string + }{ + { + name: "ok, 0", + when: "0", + expect: 0, + }, + { + name: "ok, 1", + when: "1", + expect: 1, + }, + { + name: "ok, -1", + when: "-1", + expect: -1, + }, + { + name: "ok, max int64", + when: "9223372036854775807", + expect: 9223372036854775807, + }, + { + name: "ok, min int64", + when: "-9223372036854775808", + expect: -9223372036854775808, + }, + { + name: "nok, overflow max int64", + when: "9223372036854775808", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "9223372036854775808": value out of range`, + }, + { + name: "nok, underflow min int64", + when: "-9223372036854775809", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "-9223372036854775809": value out of range`, + }, + { + name: "nok, invalid value", + when: "X", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "X": invalid syntax`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[int64](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValue_uint8(t *testing.T) { + var testCases = []struct { + name string + when string + expect uint8 + expectErr string + }{ + { + name: "ok, 0", + when: "0", + expect: 0, + }, + { + name: "ok, 1", + when: "1", + expect: 1, + }, + { + name: "ok, max uint8", + when: "255", + expect: 255, + }, + { + name: "nok, overflow max uint8", + when: "256", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseUint: parsing "256": value out of range`, + }, + { + name: "nok, negative value", + when: "-1", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`, + }, + { + name: "nok, invalid value", + when: "X", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[uint8](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValue_uint16(t *testing.T) { + var testCases = []struct { + name string + when string + expect uint16 + expectErr string + }{ + { + name: "ok, 0", + when: "0", + expect: 0, + }, + { + name: "ok, 1", + when: "1", + expect: 1, + }, + { + name: "ok, max uint16", + when: "65535", + expect: 65535, + }, + { + name: "nok, overflow max uint16", + when: "65536", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseUint: parsing "65536": value out of range`, + }, + { + name: "nok, negative value", + when: "-1", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`, + }, + { + name: "nok, invalid value", + when: "X", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[uint16](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValue_uint32(t *testing.T) { + var testCases = []struct { + name string + when string + expect uint32 + expectErr string + }{ + { + name: "ok, 0", + when: "0", + expect: 0, + }, + { + name: "ok, 1", + when: "1", + expect: 1, + }, + { + name: "ok, max uint32", + when: "4294967295", + expect: 4294967295, + }, + { + name: "nok, overflow max uint32", + when: "4294967296", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseUint: parsing "4294967296": value out of range`, + }, + { + name: "nok, negative value", + when: "-1", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`, + }, + { + name: "nok, invalid value", + when: "X", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[uint32](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValue_uint64(t *testing.T) { + var testCases = []struct { + name string + when string + expect uint64 + expectErr string + }{ + { + name: "ok, 0", + when: "0", + expect: 0, + }, + { + name: "ok, 1", + when: "1", + expect: 1, + }, + { + name: "ok, max uint64", + when: "18446744073709551615", + expect: 18446744073709551615, + }, + { + name: "nok, overflow max uint64", + when: "18446744073709551616", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseUint: parsing "18446744073709551616": value out of range`, + }, + { + name: "nok, negative value", + when: "-1", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseUint: parsing "-1": invalid syntax`, + }, + { + name: "nok, invalid value", + when: "X", + expect: 0, + expectErr: `failed to parse value, err: strconv.ParseUint: parsing "X": invalid syntax`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[uint64](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValue_string(t *testing.T) { + var testCases = []struct { + name string + when string + expect string + expectErr string + }{ + { + name: "ok, my", + when: "my", + expect: "my", + }, + { + name: "ok, empty", + when: "", + expect: "", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[string](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValue_Duration(t *testing.T) { + var testCases = []struct { + name string + when string + expect time.Duration + expectErr string + }{ + { + name: "ok, 10h11m01s", + when: "10h11m01s", + expect: 10*time.Hour + 11*time.Minute + 1*time.Second, + }, + { + name: "ok, empty", + when: "", + expect: 0, + }, + { + name: "ok, invalid", + when: "0x0", + expect: 0, + expectErr: `failed to parse value, err: time: unknown unit "x" in duration "0x0"`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[time.Duration](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValue_Time(t *testing.T) { + tallinn, err := time.LoadLocation("Europe/Tallinn") + if err != nil { + t.Fatal(err) + } + berlin, err := time.LoadLocation("Europe/Berlin") + if err != nil { + t.Fatal(err) + } + + parse := func(t *testing.T, layout string, s string) time.Time { + result, err := time.Parse(layout, s) + if err != nil { + t.Fatal(err) + } + return result + } + + parseInLoc := func(t *testing.T, layout string, s string, loc *time.Location) time.Time { + result, err := time.ParseInLocation(layout, s, loc) + if err != nil { + t.Fatal(err) + } + return result + } + + var testCases = []struct { + name string + when string + whenLayout TimeLayout + whenTimeOpts *TimeOpts + expect time.Time + expectErr string + }{ + { + name: "ok, defaults to RFC3339Nano", + when: "2006-01-02T15:04:05.999999999Z", + expect: parse(t, time.RFC3339Nano, "2006-01-02T15:04:05.999999999Z"), + }, + { + name: "ok, custom TimeOpt", + when: "2006-01-02", + whenTimeOpts: &TimeOpts{ + Layout: time.DateOnly, + ParseInLocation: tallinn, + ToInLocation: berlin, + }, + expect: parseInLoc(t, time.DateTime, "2006-01-01 23:00:00", berlin), + }, + { + name: "ok, custom layout", + when: "2006-01-02", + whenLayout: TimeLayout(time.DateOnly), + expect: parse(t, time.DateOnly, "2006-01-02"), + }, + { + name: "ok, TimeLayoutUnixTime", + when: "1766604665", + whenLayout: TimeLayoutUnixTime, + expect: parse(t, time.RFC3339Nano, "2025-12-24T19:31:05Z"), + }, + { + name: "nok, TimeLayoutUnixTime, invalid value", + when: "176x6604665", + whenLayout: TimeLayoutUnixTime, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "176x6604665": invalid syntax`, + }, + { + name: "ok, TimeLayoutUnixTimeMilli", + when: "1766604665123", + whenLayout: TimeLayoutUnixTimeMilli, + expect: parse(t, time.RFC3339Nano, "2025-12-24T19:31:05.123Z"), + }, + { + name: "nok, TimeLayoutUnixTimeMilli, invalid value", + when: "1x766604665123", + whenLayout: TimeLayoutUnixTimeMilli, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "1x766604665123": invalid syntax`, + }, + { + name: "ok, TimeLayoutUnixTimeMilli", + when: "1766604665999999999", + whenLayout: TimeLayoutUnixTimeNano, + expect: parse(t, time.RFC3339Nano, "2025-12-24T19:31:05.999999999Z"), + }, + { + name: "nok, TimeLayoutUnixTimeMilli, invalid value", + when: "1x766604665999999999", + whenLayout: TimeLayoutUnixTimeNano, + expectErr: `failed to parse value, err: strconv.ParseInt: parsing "1x766604665999999999": invalid syntax`, + }, + { + name: "ok, invalid", + when: "xx", + expect: time.Time{}, + expectErr: `failed to parse value, err: parsing time "xx" as "2006-01-02T15:04:05.999999999Z07:00": cannot parse "xx" as "2006"`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var opts []any + if tc.whenLayout != "" { + opts = append(opts, tc.whenLayout) + } + if tc.whenTimeOpts != nil { + opts = append(opts, *tc.whenTimeOpts) + } + v, err := ParseValue[time.Time](tc.when, opts...) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValue_OptionsOnlyForTime(t *testing.T) { + _, err := ParseValue[int]("test", TimeLayoutUnixTime) + assert.EqualError(t, err, `failed to parse value, err: options are only supported for time.Time, got *int`) +} + +func TestParseValue_BindUnmarshaler(t *testing.T) { + exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00") + + var testCases = []struct { + name string + when string + expect Timestamp + expectErr string + }{ + { + name: "ok", + when: "2020-12-23T09:45:31+02:00", + expect: Timestamp(exampleTime), + }, + { + name: "nok, invalid value", + when: "2020-12-23T09:45:3102:00", + expect: Timestamp{}, + expectErr: `failed to parse value, err: parsing time "2020-12-23T09:45:3102:00" as "2006-01-02T15:04:05Z07:00": cannot parse "02:00" as "Z07:00"`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[Timestamp](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValue_TextUnmarshaler(t *testing.T) { + var testCases = []struct { + name string + when string + expect TextUnmarshalerType + expectErr string + }{ + { + name: "ok, converts to uppercase", + when: "hello", + expect: TextUnmarshalerType{Value: "HELLO"}, + }, + { + name: "ok, empty string", + when: "", + expect: TextUnmarshalerType{Value: ""}, + }, + { + name: "nok, invalid value", + when: "invalid", + expect: TextUnmarshalerType{}, + expectErr: "failed to parse value, err: invalid value: invalid", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[TextUnmarshalerType](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValue_JSONUnmarshaler(t *testing.T) { + var testCases = []struct { + name string + when string + expect JSONUnmarshalerType + expectErr string + }{ + { + name: "ok, valid JSON string", + when: `"hello"`, + expect: JSONUnmarshalerType{Value: "hello"}, + }, + { + name: "ok, empty JSON string", + when: `""`, + expect: JSONUnmarshalerType{Value: ""}, + }, + { + name: "nok, invalid JSON", + when: "not-json", + expect: JSONUnmarshalerType{}, + expectErr: "failed to parse value, err: invalid character 'o' in literal null (expecting 'u')", + }, + { + name: "nok, unquoted string", + when: "hello", + expect: JSONUnmarshalerType{}, + expectErr: "failed to parse value, err: invalid character 'h' looking for beginning of value", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValue[JSONUnmarshalerType](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestParseValues_bools(t *testing.T) { + var testCases = []struct { + name string + when []string + expect []bool + expectErr string + }{ + { + name: "ok", + when: []string{"true", "0", "false", "1"}, + expect: []bool{true, false, false, true}, + }, + { + name: "nok", + when: []string{"true", "10"}, + expect: nil, + expectErr: `failed to parse value, err: strconv.ParseBool: parsing "10": invalid syntax`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v, err := ParseValues[bool](tc.when) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestPathParamOr(t *testing.T) { + var testCases = []struct { + name string + givenKey string + givenValue string + defaultValue int + expect int + expectErr string + }{ + { + name: "ok, param exists", + givenKey: "id", + givenValue: "123", + defaultValue: 999, + expect: 123, + }, + { + name: "ok, param missing - returns default", + givenKey: "other", + givenValue: "123", + defaultValue: 999, + expect: 999, + }, + { + name: "ok, param exists but empty - returns default", + givenKey: "id", + givenValue: "", + defaultValue: 999, + expect: 999, + }, + { + name: "nok, invalid value", + givenKey: "id", + givenValue: "invalid", + defaultValue: 999, + expectErr: "code=400, message=path value, err=failed to parse value", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := NewContext(nil, nil) + c.SetPathValues(PathValues{{Name: tc.givenKey, Value: tc.givenValue}}) + + v, err := PathParamOr[int](c, "id", tc.defaultValue) + if tc.expectErr != "" { + assert.ErrorContains(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestQueryParamOr(t *testing.T) { + var testCases = []struct { + name string + givenURL string + defaultValue int + expect int + expectErr string + }{ + { + name: "ok, param exists", + givenURL: "/?key=42", + defaultValue: 999, + expect: 42, + }, + { + name: "ok, param missing - returns default", + givenURL: "/?other=42", + defaultValue: 999, + expect: 999, + }, + { + name: "ok, param exists but empty - returns default", + givenURL: "/?key=", + defaultValue: 999, + expect: 999, + }, + { + name: "nok, invalid value", + givenURL: "/?key=invalid", + defaultValue: 999, + expectErr: "code=400, message=query param", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + c := NewContext(req, nil) + + v, err := QueryParamOr[int](c, "key", tc.defaultValue) + if tc.expectErr != "" { + assert.ErrorContains(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestQueryParamsOr(t *testing.T) { + var testCases = []struct { + name string + givenURL string + defaultValue []int + expect []int + expectErr string + }{ + { + name: "ok, params exist", + givenURL: "/?key=1&key=2&key=3", + defaultValue: []int{999}, + expect: []int{1, 2, 3}, + }, + { + name: "ok, params missing - returns default", + givenURL: "/?other=1", + defaultValue: []int{7, 8, 9}, + expect: []int{7, 8, 9}, + }, + { + name: "nok, invalid value", + givenURL: "/?key=1&key=invalid", + defaultValue: []int{999}, + expectErr: "code=400, message=query params", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + c := NewContext(req, nil) + + v, err := QueryParamsOr[int](c, "key", tc.defaultValue) + if tc.expectErr != "" { + assert.ErrorContains(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestFormValueOr(t *testing.T) { + var testCases = []struct { + name string + givenURL string + defaultValue string + expect string + expectErr string + }{ + { + name: "ok, value exists", + givenURL: "/?name=john", + defaultValue: "default", + expect: "john", + }, + { + name: "ok, value missing - returns default", + givenURL: "/?other=john", + defaultValue: "default", + expect: "default", + }, + { + name: "ok, value exists but empty - returns default", + givenURL: "/?name=", + defaultValue: "default", + expect: "default", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) + c := NewContext(req, nil) + + v, err := FormValueOr[string](c, "name", tc.defaultValue) + if tc.expectErr != "" { + assert.ErrorContains(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} + +func TestFormValuesOr(t *testing.T) { + var testCases = []struct { + name string + givenURL string + defaultValue []string + expect []string + expectErr string + }{ + { + name: "ok, values exist", + givenURL: "/?tags=go&tags=rust&tags=python", + defaultValue: []string{"default"}, + expect: []string{"go", "rust", "python"}, + }, + { + name: "ok, values missing - returns default", + givenURL: "/?other=value", + defaultValue: []string{"a", "b"}, + expect: []string{"a", "b"}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) + c := NewContext(req, nil) + + v, err := FormValuesOr[string](c, "tags", tc.defaultValue) + if tc.expectErr != "" { + assert.ErrorContains(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expect, v) + }) + } +} diff --git a/binder_test.go b/binder_test.go new file mode 100644 index 000000000..8eced8208 --- /dev/null +++ b/binder_test.go @@ -0,0 +1,3252 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/stretchr/testify/assert" + "io" + "math/big" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + "time" +) + +func createTestContext(URL string, body io.Reader, pathValues map[string]string) *Context { + e := New() + req := httptest.NewRequest(http.MethodGet, URL, body) + if body != nil { + req.Header.Set(HeaderContentType, MIMEApplicationJSON) + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + if len(pathValues) > 0 { + params := make(PathValues, 0) + for name, value := range pathValues { + params = append(params, PathValue{ + Name: name, + Value: value, + }) + } + c.SetPathValues(params) + } + + return c +} + +func TestBindingError_Error(t *testing.T) { + err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error")) + assert.EqualError(t, err, `code=400, message=bind failed, err=internal error, field=id`) + + bErr := err.(*BindingError) + assert.Equal(t, 400, bErr.Code) + assert.Equal(t, "bind failed", bErr.Message) + assert.Equal(t, errors.New("internal error"), bErr.err) + + assert.Equal(t, "id", bErr.Field) + assert.Equal(t, []string{"1", "nope"}, bErr.Values) +} + +func TestBindingError_ErrorJSON(t *testing.T) { + err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error")) + + resp, _ := json.Marshal(err) + + assert.Equal(t, `{"field":"id","message":"bind failed"}`, string(resp)) +} + +func TestPathValuesBinder(t *testing.T) { + c := createTestContext("/api/user/999", nil, map[string]string{ + "id": "1", + "nr": "2", + "slice": "3", + }) + b := PathValuesBinder(c) + + id := int64(99) + nr := int64(88) + var slice = make([]int64, 0) + var notExisting = make([]int64, 0) + err := b.Int64("id", &id). + Int64("nr", &nr). + Int64s("slice", &slice). + Int64s("not_existing", ¬Existing). + BindError() + + assert.NoError(t, err) + assert.Equal(t, int64(1), id) + assert.Equal(t, int64(2), nr) + assert.Equal(t, []int64{3}, slice) // binding params to slice does not make sense but it should not panic either + assert.Equal(t, []int64{}, notExisting) // binding params to slice does not make sense but it should not panic either +} + +func TestQueryParamsBinder_FailFast(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectError []string + givenFailFast bool + }{ + { + name: "ok, FailFast=true stops at first error", + whenURL: "/api/user/999?nr=en&id=nope", + givenFailFast: true, + expectError: []string{ + `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "nope": invalid syntax, field=id`, + }, + }, + { + name: "ok, FailFast=false encounters all errors", + whenURL: "/api/user/999?nr=en&id=nope", + givenFailFast: false, + expectError: []string{ + `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "nope": invalid syntax, field=id`, + `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "en": invalid syntax, field=nr`, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, map[string]string{"id": "999"}) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + id := int64(99) + nr := int64(88) + errs := b.Int64("id", &id). + Int64("nr", &nr). + BindErrors() + + assert.Len(t, errs, len(tc.expectError)) + for _, err := range errs { + assert.Contains(t, tc.expectError, err.Error()) + } + }) + } +} + +func TestFormFieldBinder(t *testing.T) { + e := New() + body := `texta=foo&slice=5` + req := httptest.NewRequest(http.MethodPost, "/api/search?id=1&nr=2&slice=3&slice=4", strings.NewReader(body)) + req.Header.Set(HeaderContentLength, strconv.Itoa(len(body))) + req.Header.Set(HeaderContentType, MIMEApplicationForm) + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + b := FormFieldBinder(c) + + var texta string + id := int64(99) + nr := int64(88) + var slice = make([]int64, 0) + var notExisting = make([]int64, 0) + err := b. + Int64s("slice", &slice). + Int64("id", &id). + Int64("nr", &nr). + String("texta", &texta). + Int64s("notExisting", ¬Existing). + BindError() + + assert.NoError(t, err) + assert.Equal(t, "foo", texta) + assert.Equal(t, int64(1), id) + assert.Equal(t, int64(2), nr) + assert.Equal(t, []int64{5, 3, 4}, slice) + assert.Equal(t, []int64{}, notExisting) +} + +func TestValueBinder_errorStopsBinding(t *testing.T) { + // this test documents "feature" that binding multiple params can change destination if it was bound before + // failing parameter binding + + c := createTestContext("/api/user/999?id=1&nr=nope", nil, nil) + b := QueryParamsBinder(c) + + id := int64(99) // will be changed before nr binding fails + nr := int64(88) // will not be changed + err := b.Int64("id", &id). + Int64("nr", &nr). + BindError() + + assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=nr") + assert.Equal(t, int64(1), id) + assert.Equal(t, int64(88), nr) +} + +func TestValueBinder_BindError(t *testing.T) { + c := createTestContext("/api/user/999?nr=en&id=nope", nil, nil) + b := QueryParamsBinder(c) + + id := int64(99) + nr := int64(88) + err := b.Int64("id", &id). + Int64("nr", &nr). + BindError() + + assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=id") + assert.Nil(t, b.errors) + assert.Nil(t, b.BindError()) +} + +func TestValueBinder_GetValues(t *testing.T) { + var testCases = []struct { + whenValuesFunc func(sourceParam string) []string + name string + expectError string + expect []int64 + }{ + { + name: "ok, default implementation", + expect: []int64{1, 101}, + }, + { + name: "ok, values returns nil", + whenValuesFunc: func(sourceParam string) []string { + return nil + }, + expect: []int64(nil), + }, + { + name: "ok, values returns empty slice", + whenValuesFunc: func(sourceParam string) []string { + return []string{} + }, + expect: []int64(nil), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext("/search?nr=en&id=1&id=101", nil, nil) + b := QueryParamsBinder(c) + if tc.whenValuesFunc != nil { + b.ValuesFunc = tc.whenValuesFunc + } + + var IDs []int64 + err := b.Int64s("id", &IDs).BindError() + + assert.Equal(t, tc.expect, IDs) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_CustomFuncWithError(t *testing.T) { + c := createTestContext("/search?nr=en&id=1&id=101", nil, nil) + b := QueryParamsBinder(c) + + id := int64(99) + givenCustomFunc := func(values []string) []error { + assert.Equal(t, []string{"1", "101"}, values) + + return []error{ + errors.New("first error"), + errors.New("second error"), + } + } + err := b.CustomFunc("id", givenCustomFunc).BindError() + + assert.Equal(t, int64(99), id) + assert.EqualError(t, err, "first error") +} + +func TestValueBinder_CustomFunc(t *testing.T) { + var testCases = []struct { + expectValue any + name string + whenURL string + givenFuncErrors []error + expectParamValues []string + expectErrors []string + givenFailFast bool + }{ + { + name: "ok, binds value", + whenURL: "/search?nr=en&id=1&id=100", + expectParamValues: []string{"1", "100"}, + expectValue: int64(1000), + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nr=en", + expectParamValues: []string{}, + expectValue: int64(99), + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?nr=en&id=1&id=100", + expectParamValues: []string{"1", "100"}, + expectValue: int64(99), + expectErrors: []string{"previous error"}, + }, + { + name: "nok, func returns errors", + givenFuncErrors: []error{ + errors.New("first error"), + errors.New("second error"), + }, + whenURL: "/search?nr=en&id=1&id=100", + expectParamValues: []string{"1", "100"}, + expectValue: int64(99), + expectErrors: []string{"first error", "second error"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + id := int64(99) + givenCustomFunc := func(values []string) []error { + assert.Equal(t, tc.expectParamValues, values) + if tc.givenFuncErrors == nil { + id = 1000 // emulated conversion and setting value + return nil + } + return tc.givenFuncErrors + } + errs := b.CustomFunc("id", givenCustomFunc).BindErrors() + + assert.Equal(t, tc.expectValue, id) + if tc.expectErrors != nil { + assert.Len(t, errs, len(tc.expectErrors)) + for _, err := range errs { + assert.Contains(t, tc.expectErrors, err.Error()) + } + } else { + assert.Nil(t, errs) + } + }) + } +} + +func TestValueBinder_MustCustomFunc(t *testing.T) { + var testCases = []struct { + expectValue any + name string + whenURL string + givenFuncErrors []error + expectParamValues []string + expectErrors []string + givenFailFast bool + }{ + { + name: "ok, binds value", + whenURL: "/search?nr=en&id=1&id=100", + expectParamValues: []string{"1", "100"}, + expectValue: int64(1000), + }, + { + name: "nok, params values empty, returns error, value is not changed", + whenURL: "/search?nr=en", + expectParamValues: []string{}, + expectValue: int64(99), + expectErrors: []string{"code=400, message=required field value is empty, field=id"}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?nr=en&id=1&id=100", + expectParamValues: []string{"1", "100"}, + expectValue: int64(99), + expectErrors: []string{"previous error"}, + }, + { + name: "nok, func returns errors", + givenFuncErrors: []error{ + errors.New("first error"), + errors.New("second error"), + }, + whenURL: "/search?nr=en&id=1&id=100", + expectParamValues: []string{"1", "100"}, + expectValue: int64(99), + expectErrors: []string{"first error", "second error"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + id := int64(99) + givenCustomFunc := func(values []string) []error { + assert.Equal(t, tc.expectParamValues, values) + if tc.givenFuncErrors == nil { + id = 1000 // emulated conversion and setting value + return nil + } + return tc.givenFuncErrors + } + errs := b.MustCustomFunc("id", givenCustomFunc).BindErrors() + + assert.Equal(t, tc.expectValue, id) + if tc.expectErrors != nil { + assert.Len(t, errs, len(tc.expectErrors)) + for _, err := range errs { + assert.Contains(t, tc.expectErrors, err.Error()) + } + } else { + assert.Nil(t, errs) + } + }) + } +} + +func TestValueBinder_String(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectValue string + expectError string + givenBindErrors []error + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=en¶m=de", + expectValue: "en", + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nr=en", + expectValue: "default", + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?nr=en&id=1&id=100", + expectValue: "default", + expectError: "previous error", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=en¶m=de", + expectValue: "en", + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nr=en", + expectValue: "default", + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?nr=en&id=1&id=100", + expectValue: "default", + expectError: "previous error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := "default" + var err error + if tc.whenMust { + err = b.MustString("param", &dest).BindError() + } else { + err = b.String("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Strings(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectError string + givenBindErrors []error + expectValue []string + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=en¶m=de", + expectValue: []string{"en", "de"}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nr=en", + expectValue: []string{"default"}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?nr=en&id=1&id=100", + expectValue: []string{"default"}, + expectError: "previous error", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=en¶m=de", + expectValue: []string{"en", "de"}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nr=en", + expectValue: []string{"default"}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?nr=en&id=1&id=100", + expectValue: []string{"default"}, + expectError: "previous error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := []string{"default"} + var err error + if tc.whenMust { + err = b.MustStrings("param", &dest).BindError() + } else { + err = b.Strings("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Int64_intValue(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectError string + givenBindErrors []error + expectValue int64 + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=1¶m=100", + expectValue: 1, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: 99, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: 99, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: 99, + expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: 1, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: 99, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: 99, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: 99, + expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := int64(99) + var err error + if tc.whenMust { + err = b.MustInt64("param", &dest).BindError() + } else { + err = b.Int64("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Int_errorMessage(t *testing.T) { + // int/uint (without byte size) has a little bit different error message so test these separately + c := createTestContext("/search?param=nope", nil, nil) + b := QueryParamsBinder(c).FailFast(false) + + destInt := 99 + destUint := uint(98) + errs := b.Int("param", &destInt).Uint("param", &destUint).BindErrors() + + assert.Equal(t, 99, destInt) + assert.Equal(t, uint(98), destUint) + assert.EqualError(t, errs[0], `code=400, message=failed to bind field value to int, err=strconv.ParseInt: parsing "nope": invalid syntax, field=param`) + assert.EqualError(t, errs[1], `code=400, message=failed to bind field value to uint, err=strconv.ParseUint: parsing "nope": invalid syntax, field=param`) +} + +func TestValueBinder_Uint64_uintValue(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectError string + givenBindErrors []error + expectValue uint64 + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=1¶m=100", + expectValue: 1, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: 99, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: 99, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: 99, + expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: 1, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: 99, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: 99, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: 99, + expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := uint64(99) + var err error + if tc.whenMust { + err = b.MustUint64("param", &dest).BindError() + } else { + err = b.Uint64("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Int_Types(t *testing.T) { + type target struct { + int64 int64 + mustInt64 int64 + uint64 uint64 + mustUint64 uint64 + + int32 int32 + mustInt32 int32 + uint32 uint32 + mustUint32 uint32 + + int16 int16 + mustInt16 int16 + uint16 uint16 + mustUint16 uint16 + + int8 int8 + mustInt8 int8 + uint8 uint8 + mustUint8 uint8 + + byte byte + mustByte byte + + int int + mustInt int + uint uint + mustUint uint + } + types := []string{ + "int64=1", + "mustInt64=2", + "uint64=3", + "mustUint64=4", + + "int32=5", + "mustInt32=6", + "uint32=7", + "mustUint32=8", + + "int16=9", + "mustInt16=10", + "uint16=11", + "mustUint16=12", + + "int8=13", + "mustInt8=14", + "uint8=15", + "mustUint8=16", + + "byte=17", + "mustByte=18", + + "int=19", + "mustInt=20", + "uint=21", + "mustUint=22", + } + c := createTestContext("/search?"+strings.Join(types, "&"), nil, nil) + b := QueryParamsBinder(c) + + dest := target{} + err := b. + Int64("int64", &dest.int64). + MustInt64("mustInt64", &dest.mustInt64). + Uint64("uint64", &dest.uint64). + MustUint64("mustUint64", &dest.mustUint64). + Int32("int32", &dest.int32). + MustInt32("mustInt32", &dest.mustInt32). + Uint32("uint32", &dest.uint32). + MustUint32("mustUint32", &dest.mustUint32). + Int16("int16", &dest.int16). + MustInt16("mustInt16", &dest.mustInt16). + Uint16("uint16", &dest.uint16). + MustUint16("mustUint16", &dest.mustUint16). + Int8("int8", &dest.int8). + MustInt8("mustInt8", &dest.mustInt8). + Uint8("uint8", &dest.uint8). + MustUint8("mustUint8", &dest.mustUint8). + Byte("byte", &dest.byte). + MustByte("mustByte", &dest.mustByte). + Int("int", &dest.int). + MustInt("mustInt", &dest.mustInt). + Uint("uint", &dest.uint). + MustUint("mustUint", &dest.mustUint). + BindError() + + assert.NoError(t, err) + assert.Equal(t, int64(1), dest.int64) + assert.Equal(t, int64(2), dest.mustInt64) + assert.Equal(t, uint64(3), dest.uint64) + assert.Equal(t, uint64(4), dest.mustUint64) + + assert.Equal(t, int32(5), dest.int32) + assert.Equal(t, int32(6), dest.mustInt32) + assert.Equal(t, uint32(7), dest.uint32) + assert.Equal(t, uint32(8), dest.mustUint32) + + assert.Equal(t, int16(9), dest.int16) + assert.Equal(t, int16(10), dest.mustInt16) + assert.Equal(t, uint16(11), dest.uint16) + assert.Equal(t, uint16(12), dest.mustUint16) + + assert.Equal(t, int8(13), dest.int8) + assert.Equal(t, int8(14), dest.mustInt8) + assert.Equal(t, uint8(15), dest.uint8) + assert.Equal(t, uint8(16), dest.mustUint8) + + assert.Equal(t, uint8(17), dest.byte) + assert.Equal(t, uint8(18), dest.mustByte) + + assert.Equal(t, 19, dest.int) + assert.Equal(t, 20, dest.mustInt) + assert.Equal(t, uint(21), dest.uint) + assert.Equal(t, uint(22), dest.mustUint) +} + +func TestValueBinder_Int64s_intsValue(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectError string + givenBindErrors []error + expectValue []int64 + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=1¶m=2¶m=1", + expectValue: []int64{1, 2, 1}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: []int64{99}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []int64{99}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []int64{99}, + expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1¶m=2¶m=1", + expectValue: []int64{1, 2, 1}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: []int64{99}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: []int64{99}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []int64{99}, + expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := []int64{99} // when values are set with bind - contents before bind is gone + var err error + if tc.whenMust { + err = b.MustInt64s("param", &dest).BindError() + } else { + err = b.Int64s("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Uint64s_uintsValue(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectError string + givenBindErrors []error + expectValue []uint64 + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=1¶m=2¶m=1", + expectValue: []uint64{1, 2, 1}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: []uint64{99}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []uint64{99}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []uint64{99}, + expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1¶m=2¶m=1", + expectValue: []uint64{1, 2, 1}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: []uint64{99}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: []uint64{99}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []uint64{99}, + expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := []uint64{99} // when values are set with bind - contents before bind is gone + var err error + if tc.whenMust { + err = b.MustUint64s("param", &dest).BindError() + } else { + err = b.Uint64s("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Ints_Types(t *testing.T) { + type target struct { + int64 []int64 + mustInt64 []int64 + uint64 []uint64 + mustUint64 []uint64 + + int32 []int32 + mustInt32 []int32 + uint32 []uint32 + mustUint32 []uint32 + + int16 []int16 + mustInt16 []int16 + uint16 []uint16 + mustUint16 []uint16 + + int8 []int8 + mustInt8 []int8 + uint8 []uint8 + mustUint8 []uint8 + + int []int + mustInt []int + uint []uint + mustUint []uint + } + types := []string{ + "int64=1", + "mustInt64=2", + "uint64=3", + "mustUint64=4", + + "int32=5", + "mustInt32=6", + "uint32=7", + "mustUint32=8", + + "int16=9", + "mustInt16=10", + "uint16=11", + "mustUint16=12", + + "int8=13", + "mustInt8=14", + "uint8=15", + "mustUint8=16", + + "int=19", + "mustInt=20", + "uint=21", + "mustUint=22", + } + url := "/search?" + for _, v := range types { + url = url + "&" + v + "&" + v + } + c := createTestContext(url, nil, nil) + b := QueryParamsBinder(c) + + dest := target{} + err := b. + Int64s("int64", &dest.int64). + MustInt64s("mustInt64", &dest.mustInt64). + Uint64s("uint64", &dest.uint64). + MustUint64s("mustUint64", &dest.mustUint64). + Int32s("int32", &dest.int32). + MustInt32s("mustInt32", &dest.mustInt32). + Uint32s("uint32", &dest.uint32). + MustUint32s("mustUint32", &dest.mustUint32). + Int16s("int16", &dest.int16). + MustInt16s("mustInt16", &dest.mustInt16). + Uint16s("uint16", &dest.uint16). + MustUint16s("mustUint16", &dest.mustUint16). + Int8s("int8", &dest.int8). + MustInt8s("mustInt8", &dest.mustInt8). + Uint8s("uint8", &dest.uint8). + MustUint8s("mustUint8", &dest.mustUint8). + Ints("int", &dest.int). + MustInts("mustInt", &dest.mustInt). + Uints("uint", &dest.uint). + MustUints("mustUint", &dest.mustUint). + BindError() + + assert.NoError(t, err) + assert.Equal(t, []int64{1, 1}, dest.int64) + assert.Equal(t, []int64{2, 2}, dest.mustInt64) + assert.Equal(t, []uint64{3, 3}, dest.uint64) + assert.Equal(t, []uint64{4, 4}, dest.mustUint64) + + assert.Equal(t, []int32{5, 5}, dest.int32) + assert.Equal(t, []int32{6, 6}, dest.mustInt32) + assert.Equal(t, []uint32{7, 7}, dest.uint32) + assert.Equal(t, []uint32{8, 8}, dest.mustUint32) + + assert.Equal(t, []int16{9, 9}, dest.int16) + assert.Equal(t, []int16{10, 10}, dest.mustInt16) + assert.Equal(t, []uint16{11, 11}, dest.uint16) + assert.Equal(t, []uint16{12, 12}, dest.mustUint16) + + assert.Equal(t, []int8{13, 13}, dest.int8) + assert.Equal(t, []int8{14, 14}, dest.mustInt8) + assert.Equal(t, []uint8{15, 15}, dest.uint8) + assert.Equal(t, []uint8{16, 16}, dest.mustUint8) + + assert.Equal(t, []int{19, 19}, dest.int) + assert.Equal(t, []int{20, 20}, dest.mustInt) + assert.Equal(t, []uint{21, 21}, dest.uint) + assert.Equal(t, []uint{22, 22}, dest.mustUint) +} + +func TestValueBinder_Ints_Types_FailFast(t *testing.T) { + // FailFast() should stop parsing and return early + errTmpl := "code=400, message=failed to bind field value to %v, err=strconv.Parse%v: parsing \"nope\": invalid syntax, field=param" + c := createTestContext("/search?param=1¶m=nope¶m=2", nil, nil) + + var dest64 []int64 + err := QueryParamsBinder(c).FailFast(true).Int64s("param", &dest64).BindError() + assert.Equal(t, []int64(nil), dest64) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int64", "Int")) + + var dest32 []int32 + err = QueryParamsBinder(c).FailFast(true).Int32s("param", &dest32).BindError() + assert.Equal(t, []int32(nil), dest32) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int32", "Int")) + + var dest16 []int16 + err = QueryParamsBinder(c).FailFast(true).Int16s("param", &dest16).BindError() + assert.Equal(t, []int16(nil), dest16) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int16", "Int")) + + var dest8 []int8 + err = QueryParamsBinder(c).FailFast(true).Int8s("param", &dest8).BindError() + assert.Equal(t, []int8(nil), dest8) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int8", "Int")) + + var dest []int + err = QueryParamsBinder(c).FailFast(true).Ints("param", &dest).BindError() + assert.Equal(t, []int(nil), dest) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "int", "Int")) + + var destu64 []uint64 + err = QueryParamsBinder(c).FailFast(true).Uint64s("param", &destu64).BindError() + assert.Equal(t, []uint64(nil), destu64) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint64", "Uint")) + + var destu32 []uint32 + err = QueryParamsBinder(c).FailFast(true).Uint32s("param", &destu32).BindError() + assert.Equal(t, []uint32(nil), destu32) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint32", "Uint")) + + var destu16 []uint16 + err = QueryParamsBinder(c).FailFast(true).Uint16s("param", &destu16).BindError() + assert.Equal(t, []uint16(nil), destu16) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint16", "Uint")) + + var destu8 []uint8 + err = QueryParamsBinder(c).FailFast(true).Uint8s("param", &destu8).BindError() + assert.Equal(t, []uint8(nil), destu8) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint8", "Uint")) + + var destu []uint + err = QueryParamsBinder(c).FailFast(true).Uints("param", &destu).BindError() + assert.Equal(t, []uint(nil), destu) + assert.EqualError(t, err, fmt.Sprintf(errTmpl, "uint", "Uint")) +} + +func TestValueBinder_Bool(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectError string + givenBindErrors []error + givenFailFast bool + whenMust bool + expectValue bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=true¶m=1", + expectValue: true, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: false, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: false, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: false, + expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: true, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: false, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: false, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: false, + expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := false + var err error + if tc.whenMust { + err = b.MustBool("param", &dest).BindError() + } else { + err = b.Bool("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Bools(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectError string + givenBindErrors []error + expectValue []bool + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=true¶m=false¶m=1¶m=0", + expectValue: []bool{true, false, true, false}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: []bool(nil), + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenURL: "/search?param=1¶m=100", + expectValue: []bool(nil), + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=true¶m=nope¶m=100", + expectValue: []bool(nil), + expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "nok, conversion fails fast, value is not changed", + givenFailFast: true, + whenURL: "/search?param=true¶m=nope¶m=100", + expectValue: []bool(nil), + expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=true¶m=false¶m=1¶m=0", + expectValue: []bool{true, false, true, false}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: []bool(nil), + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: []bool(nil), + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []bool(nil), + expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + var dest []bool + var err error + if tc.whenMust { + err = b.MustBools("param", &dest).BindError() + } else { + err = b.Bools("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Float64(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectError string + givenBindErrors []error + expectValue float64 + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=4.3¶m=1", + expectValue: 4.3, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: 1.123, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: 1.123, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: 1.123, + expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=4.3¶m=100", + expectValue: 4.3, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: 1.123, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: 1.123, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: 1.123, + expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := 1.123 + var err error + if tc.whenMust { + err = b.MustFloat64("param", &dest).BindError() + } else { + err = b.Float64("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Float64s(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectError string + givenBindErrors []error + expectValue []float64 + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=4.3¶m=0", + expectValue: []float64{4.3, 0}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: []float64(nil), + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenURL: "/search?param=1¶m=100", + expectValue: []float64(nil), + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []float64(nil), + expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "nok, conversion fails fast, value is not changed", + givenFailFast: true, + whenURL: "/search?param=0¶m=nope¶m=100", + expectValue: []float64(nil), + expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=4.3¶m=0", + expectValue: []float64{4.3, 0}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: []float64(nil), + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: []float64(nil), + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []float64(nil), + expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + var dest []float64 + var err error + if tc.whenMust { + err = b.MustFloat64s("param", &dest).BindError() + } else { + err = b.Float64s("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Float32(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectError string + givenBindErrors []error + expectValue float32 + givenNoFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=4.3¶m=1", + expectValue: 4.3, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: 1.123, + }, + { + name: "nok, previous errors fail fast without binding value", + givenNoFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: 1.123, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: 1.123, + expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=4.3¶m=100", + expectValue: 4.3, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: 1.123, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenNoFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: 1.123, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: 1.123, + expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenNoFailFast) + if tc.givenNoFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := float32(1.123) + var err error + if tc.whenMust { + err = b.MustFloat32("param", &dest).BindError() + } else { + err = b.Float32("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Float32s(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectError string + givenBindErrors []error + expectValue []float32 + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=4.3¶m=0", + expectValue: []float32{4.3, 0}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: []float32(nil), + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenURL: "/search?param=1¶m=100", + expectValue: []float32(nil), + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []float32(nil), + expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "nok, conversion fails fast, value is not changed", + givenFailFast: true, + whenURL: "/search?param=0¶m=nope¶m=100", + expectValue: []float32(nil), + expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=4.3¶m=0", + expectValue: []float32{4.3, 0}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: []float32(nil), + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: []float32(nil), + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []float32(nil), + expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + var dest []float32 + var err error + if tc.whenMust { + err = b.MustFloat32s("param", &dest).BindError() + } else { + err = b.Float32s("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Time(t *testing.T) { + exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00") + var testCases = []struct { + expectValue time.Time + name string + whenURL string + whenLayout string + expectError string + givenBindErrors []error + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", + whenLayout: time.RFC3339, + expectValue: exampleTime, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: time.Time{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", + whenLayout: time.RFC3339, + expectValue: exampleTime, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: time.Time{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustTime("param", &dest, tc.whenLayout).BindError() + } else { + err = b.Time("param", &dest, tc.whenLayout).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Times(t *testing.T) { + exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00") + exampleTime2, _ := time.Parse(time.RFC3339, "2000-01-02T09:45:31+00:00") + var testCases = []struct { + name string + whenURL string + whenLayout string + expectError string + givenBindErrors []error + expectValue []time.Time + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", + whenLayout: time.RFC3339, + expectValue: []time.Time{exampleTime, exampleTime2}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: []time.Time(nil), + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Time(nil), + expectError: "previous error", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", + whenLayout: time.RFC3339, + expectValue: []time.Time{exampleTime, exampleTime2}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: []time.Time(nil), + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Time(nil), + expectError: "previous error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + layout := time.RFC3339 + if tc.whenLayout != "" { + layout = tc.whenLayout + } + + var dest []time.Time + var err error + if tc.whenMust { + err = b.MustTimes("param", &dest, layout).BindError() + } else { + err = b.Times("param", &dest, layout).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Duration(t *testing.T) { + example := 42 * time.Second + var testCases = []struct { + name string + whenURL string + expectError string + givenBindErrors []error + expectValue time.Duration + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=42s¶m=1ms", + expectValue: example, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: 0, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: 0, + expectError: "previous error", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=42s¶m=1ms", + expectValue: example, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: 0, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: 0, + expectError: "previous error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest time.Duration + var err error + if tc.whenMust { + err = b.MustDuration("param", &dest).BindError() + } else { + err = b.Duration("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_Durations(t *testing.T) { + exampleDuration := 42 * time.Second + exampleDuration2 := 1 * time.Millisecond + var testCases = []struct { + name string + whenURL string + expectError string + givenBindErrors []error + expectValue []time.Duration + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=42s¶m=1ms", + expectValue: []time.Duration{exampleDuration, exampleDuration2}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: []time.Duration(nil), + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Duration(nil), + expectError: "previous error", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=42s¶m=1ms", + expectValue: []time.Duration{exampleDuration, exampleDuration2}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: []time.Duration(nil), + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + givenBindErrors: []error{errors.New("previous error")}, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Duration(nil), + expectError: "previous error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + var dest []time.Duration + var err error + if tc.whenMust { + err = b.MustDurations("param", &dest).BindError() + } else { + err = b.Durations("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_BindUnmarshaler(t *testing.T) { + exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00") + + var testCases = []struct { + expectValue Timestamp + name string + whenURL string + expectError string + givenBindErrors []error + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", + expectValue: Timestamp(exampleTime), + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: Timestamp{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: Timestamp{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: Timestamp{}, + expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=2020-12-23T09:45:31%2B02:00¶m=2000-01-02T09:45:31%2B00:00", + expectValue: Timestamp(exampleTime), + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: Timestamp{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: Timestamp{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: Timestamp{}, + expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest Timestamp + var err error + if tc.whenMust { + err = b.MustBindUnmarshaler("param", &dest).BindError() + } else { + err = b.BindUnmarshaler("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_JSONUnmarshaler(t *testing.T) { + example := big.NewInt(999) + + var testCases = []struct { + name string + whenURL string + expectError string + expectValue big.Int + givenBindErrors []error + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: big.Int{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: big.Int{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=xxx", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest big.Int + var err error + if tc.whenMust { + err = b.MustJSONUnmarshaler("param", &dest).BindError() + } else { + err = b.JSONUnmarshaler("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_TextUnmarshaler(t *testing.T) { + example := big.NewInt(999) + + var testCases = []struct { + name string + whenURL string + expectError string + expectValue big.Int + givenBindErrors []error + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: big.Int{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=999¶m=998", + expectValue: *example, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: big.Int{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=xxx", + expectValue: big.Int{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=xxx", + expectValue: big.Int{}, + expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest big.Int + var err error + if tc.whenMust { + err = b.MustTextUnmarshaler("param", &dest).BindError() + } else { + err = b.TextUnmarshaler("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_BindWithDelimiter_types(t *testing.T) { + var testCases = []struct { + expect any + name string + whenURL string + }{ + { + name: "ok, strings", + expect: []string{"1", "2", "1"}, + }, + { + name: "ok, int64", + expect: []int64{1, 2, 1}, + }, + { + name: "ok, int32", + expect: []int32{1, 2, 1}, + }, + { + name: "ok, int16", + expect: []int16{1, 2, 1}, + }, + { + name: "ok, int8", + expect: []int8{1, 2, 1}, + }, + { + name: "ok, int", + expect: []int{1, 2, 1}, + }, + { + name: "ok, uint64", + expect: []uint64{1, 2, 1}, + }, + { + name: "ok, uint32", + expect: []uint32{1, 2, 1}, + }, + { + name: "ok, uint16", + expect: []uint16{1, 2, 1}, + }, + { + name: "ok, uint8", + expect: []uint8{1, 2, 1}, + }, + { + name: "ok, uint", + expect: []uint{1, 2, 1}, + }, + { + name: "ok, float64", + expect: []float64{1, 2, 1}, + }, + { + name: "ok, float32", + expect: []float32{1, 2, 1}, + }, + { + name: "ok, bool", + whenURL: "/search?param=1,false¶m=true", + expect: []bool{true, false, true}, + }, + { + name: "ok, Duration", + whenURL: "/search?param=1s,42s¶m=1ms", + expect: []time.Duration{1 * time.Second, 42 * time.Second, 1 * time.Millisecond}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + URL := "/search?param=1,2¶m=1" + if tc.whenURL != "" { + URL = tc.whenURL + } + c := createTestContext(URL, nil, nil) + b := QueryParamsBinder(c) + + switch tc.expect.(type) { + case []string: + var dest []string + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []int64: + var dest []int64 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []int32: + var dest []int32 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []int16: + var dest []int16 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []int8: + var dest []int8 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []int: + var dest []int + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []uint64: + var dest []uint64 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []uint32: + var dest []uint32 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []uint16: + var dest []uint16 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []uint8: + var dest []uint8 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []uint: + var dest []uint + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []float64: + var dest []float64 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []float32: + var dest []float32 + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []bool: + var dest []bool + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + case []time.Duration: + var dest []time.Duration + assert.NoError(t, b.BindWithDelimiter("param", &dest, ",").BindError()) + assert.Equal(t, tc.expect, dest) + default: + assert.Fail(t, "invalid type") + } + }) + } +} + +func TestValueBinder_BindWithDelimiter(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectError string + givenBindErrors []error + expectValue []int64 + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value", + whenURL: "/search?param=1,2¶m=1", + expectValue: []int64{1, 2, 1}, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: []int64(nil), + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []int64(nil), + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []int64(nil), + expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1,2¶m=1", + expectValue: []int64{1, 2, 1}, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: []int64(nil), + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: []int64(nil), + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []int64(nil), + expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest []int64 + var err error + if tc.whenMust { + err = b.MustBindWithDelimiter("param", &dest, ",").BindError() + } else { + err = b.BindWithDelimiter("param", &dest, ",").BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestBindWithDelimiter_invalidType(t *testing.T) { + c := createTestContext("/search?param=1¶m=100", nil, nil) + b := QueryParamsBinder(c) + + var dest []BindUnmarshaler + err := b.BindWithDelimiter("param", &dest, ",").BindError() + assert.Equal(t, []BindUnmarshaler(nil), dest) + assert.EqualError(t, err, "code=400, message=unsupported bind type, field=param") +} + +func TestValueBinder_UnixTime(t *testing.T) { + exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43+00:00") // => 1609180603 + var testCases = []struct { + expectValue time.Time + name string + whenURL string + expectError string + givenBindErrors []error + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value, unix time in seconds", + whenURL: "/search?param=1609180603¶m=1609180604", + expectValue: exampleTime, + }, + { + name: "ok, binds value, unix time over int32 value", + whenURL: "/search?param=2147483648¶m=1609180604", + expectValue: time.Unix(2147483648, 0), + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: time.Time{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1609180603¶m=1609180604", + expectValue: exampleTime, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: time.Time{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustUnixTime("param", &dest).BindError() + } else { + err = b.UnixTime("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano()) + assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC)) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_UnixTimeMilli(t *testing.T) { + exampleTime, _ := time.Parse(time.RFC3339Nano, "2022-03-13T15:13:30.140000000+00:00") // => 1647184410140 + var testCases = []struct { + expectValue time.Time + name string + whenURL string + expectError string + givenBindErrors []error + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value, unix time in milliseconds", + whenURL: "/search?param=1647184410140¶m=1647184410199", + expectValue: exampleTime, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: time.Time{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1647184410140¶m=1647184410199", + expectValue: exampleTime, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: time.Time{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustUnixTimeMilli("param", &dest).BindError() + } else { + err = b.UnixTimeMilli("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano()) + assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC)) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_UnixTimeNano(t *testing.T) { + exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43.000000000+00:00") // => 1609180603 + exampleTimeNano, _ := time.Parse(time.RFC3339Nano, "2020-12-28T18:36:43.123456789+00:00") // => 1609180603123456789 + exampleTimeNanoBelowSec, _ := time.Parse(time.RFC3339Nano, "1970-01-01T00:00:00.999999999+00:00") + var testCases = []struct { + expectValue time.Time + name string + whenURL string + expectError string + givenBindErrors []error + givenFailFast bool + whenMust bool + }{ + { + name: "ok, binds value, unix time in nano seconds (sec precision)", + whenURL: "/search?param=1609180603000000000¶m=1609180604", + expectValue: exampleTime, + }, + { + name: "ok, binds value, unix time in nano seconds", + whenURL: "/search?param=1609180603123456789¶m=1609180604", + expectValue: exampleTimeNano, + }, + { + name: "ok, binds value, unix time in nano seconds (below 1 sec)", + whenURL: "/search?param=999999999¶m=1609180604", + expectValue: exampleTimeNanoBelowSec, + }, + { + name: "ok, params values empty, value is not changed", + whenURL: "/search?nope=1", + expectValue: time.Time{}, + }, + { + name: "nok, previous errors fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + { + name: "ok (must), binds value", + whenMust: true, + whenURL: "/search?param=1609180603000000000¶m=1609180604", + expectValue: exampleTime, + }, + { + name: "ok (must), params values empty, returns error, value is not changed", + whenMust: true, + whenURL: "/search?nope=1", + expectValue: time.Time{}, + expectError: "code=400, message=required field value is empty, field=param", + }, + { + name: "nok (must), previous errors fail fast without binding value", + givenFailFast: true, + whenMust: true, + whenURL: "/search?param=1¶m=100", + expectValue: time.Time{}, + expectError: "previous error", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustUnixTimeNano("param", &dest).BindError() + } else { + err = b.UnixTimeNano("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue.UnixNano(), dest.UnixNano()) + assert.Equal(t, tc.expectValue.In(time.UTC), dest.In(time.UTC)) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) { + type Opts struct { + Param int64 `query:"param"` + } + c := createTestContext("/search?param=1¶m=100", nil, nil) + + b.ReportAllocs() + b.ResetTimer() + binder := new(DefaultBinder) + for i := 0; i < b.N; i++ { + var dest Opts + _ = binder.Bind(c, &dest) + } +} + +func BenchmarkValueBinder_BindInt64_single(b *testing.B) { + c := createTestContext("/search?param=1¶m=100", nil, nil) + + b.ReportAllocs() + b.ResetTimer() + type Opts struct { + Param int64 + } + binder := QueryParamsBinder(c) + for i := 0; i < b.N; i++ { + var dest Opts + _ = binder.Int64("param", &dest.Param).BindError() + } +} + +func BenchmarkRawFunc_Int64_single(b *testing.B) { + c := createTestContext("/search?param=1¶m=100", nil, nil) + + rawFunc := func(input string, defaultValue int64) (int64, bool) { + if input == "" { + return defaultValue, true + } + n, err := strconv.Atoi(input) + if err != nil { + return 0, false + } + return int64(n), true + } + + b.ReportAllocs() + b.ResetTimer() + type Opts struct { + Param int64 + } + for i := 0; i < b.N; i++ { + var dest Opts + if n, ok := rawFunc(c.QueryParam("param"), 1); ok { + dest.Param = n + } + } +} + +func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) { + type Opts struct { + String string `query:"string"` + Strings []string `query:"strings"` + Int64 int64 `query:"int64"` + Uint64 uint64 `query:"uint64"` + Int32 int32 `query:"int32"` + Uint32 uint32 `query:"uint32"` + Int16 int16 `query:"int16"` + Uint16 uint16 `query:"uint16"` + Int8 int8 `query:"int8"` + Uint8 uint8 `query:"uint8"` + } + c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil) + + b.ReportAllocs() + b.ResetTimer() + binder := new(DefaultBinder) + for i := 0; i < b.N; i++ { + var dest Opts + _ = binder.Bind(c, &dest) + if dest.Int64 != 1 { + b.Fatalf("int64!=1") + } + } +} + +func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) { + type Opts struct { + String string `query:"string"` + Strings []string `query:"strings"` + Int64 int64 `query:"int64"` + Uint64 uint64 `query:"uint64"` + Int32 int32 `query:"int32"` + Uint32 uint32 `query:"uint32"` + Int16 int16 `query:"int16"` + Uint16 uint16 `query:"uint16"` + Int8 int8 `query:"int8"` + Uint8 uint8 `query:"uint8"` + } + c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil) + + b.ReportAllocs() + b.ResetTimer() + binder := QueryParamsBinder(c) + for i := 0; i < b.N; i++ { + var dest Opts + _ = binder. + Int64("int64", &dest.Int64). + Int32("int32", &dest.Int32). + Int16("int16", &dest.Int16). + Int8("int8", &dest.Int8). + String("string", &dest.String). + Uint64("int64", &dest.Uint64). + Uint32("int32", &dest.Uint32). + Uint16("int16", &dest.Uint16). + Uint8("int8", &dest.Uint8). + Strings("strings", &dest.Strings). + BindError() + if dest.Int64 != 1 { + b.Fatalf("int64!=1") + } + } +} + +func TestValueBinder_TimeError(t *testing.T) { + var testCases = []struct { + expectValue time.Time + name string + whenURL string + whenLayout string + expectError string + givenBindErrors []error + givenFailFast bool + whenMust bool + }{ + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\": extra text: \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: time.Time{}, + expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\": extra text: \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + dest := time.Time{} + var err error + if tc.whenMust { + err = b.MustTime("param", &dest, tc.whenLayout).BindError() + } else { + err = b.Time("param", &dest, tc.whenLayout).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_TimesError(t *testing.T) { + var testCases = []struct { + name string + whenURL string + whenLayout string + expectError string + givenBindErrors []error + expectValue []time.Time + givenFailFast bool + whenMust bool + }{ + { + name: "nok, fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Time(nil), + expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + layout := time.RFC3339 + if tc.whenLayout != "" { + layout = tc.whenLayout + } + + var dest []time.Time + var err error + if tc.whenMust { + err = b.MustTimes("param", &dest, layout).BindError() + } else { + err = b.Times("param", &dest, layout).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_DurationError(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectError string + givenBindErrors []error + expectValue time.Duration + givenFailFast bool + whenMust bool + }{ + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: 0, + expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: 0, + expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + if tc.givenFailFast { + b.errors = []error{errors.New("previous error")} + } + + var dest time.Duration + var err error + if tc.whenMust { + err = b.MustDuration("param", &dest).BindError() + } else { + err = b.Duration("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValueBinder_DurationsError(t *testing.T) { + var testCases = []struct { + name string + whenURL string + expectError string + givenBindErrors []error + expectValue []time.Duration + givenFailFast bool + whenMust bool + }{ + { + name: "nok, fail fast without binding value", + givenFailFast: true, + whenURL: "/search?param=1¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, err=time: missing unit in duration \"1\", field=param", + }, + { + name: "nok, conversion fails, value is not changed", + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param", + }, + { + name: "nok (must), conversion fails, value is not changed", + whenMust: true, + whenURL: "/search?param=nope¶m=100", + expectValue: []time.Duration(nil), + expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := createTestContext(tc.whenURL, nil, nil) + b := QueryParamsBinder(c).FailFast(tc.givenFailFast) + b.errors = tc.givenBindErrors + + var dest []time.Duration + var err error + if tc.whenMust { + err = b.MustDurations("param", &dest).BindError() + } else { + err = b.Durations("param", &dest).BindError() + } + + assert.Equal(t, tc.expectValue, dest) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 000000000..0fa3a3f18 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,11 @@ +coverage: + status: + project: + default: + threshold: 1% + patch: + default: + threshold: 1% + +comment: + require_changes: true \ No newline at end of file diff --git a/context.go b/context.go index 27da5ffe3..f91ea7a60 100644 --- a/context.go +++ b/context.go @@ -1,254 +1,165 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( "bytes" - "encoding/json" "encoding/xml" + "errors" "fmt" "io" + "io/fs" + "log/slog" "mime/multipart" "net" "net/http" "net/url" - "os" + "path" "path/filepath" "strings" "sync" ) -type ( - // Context represents the context of the current HTTP request. It holds request and - // response objects, path, path parameters, data and registered handler. - Context interface { - // Request returns `*http.Request`. - Request() *http.Request - - // SetRequest sets `*http.Request`. - SetRequest(r *http.Request) - - // SetResponse sets `*Response`. - SetResponse(r *Response) - - // Response returns `*Response`. - Response() *Response - - // IsTLS returns true if HTTP connection is TLS otherwise false. - IsTLS() bool - - // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. - IsWebSocket() bool - - // Scheme returns the HTTP protocol scheme, `http` or `https`. - Scheme() string - - // RealIP returns the client's network address based on `X-Forwarded-For` - // or `X-Real-IP` request header. - RealIP() string - - // Path returns the registered path for the handler. - Path() string - - // SetPath sets the registered path for the handler. - SetPath(p string) - - // Param returns path parameter by name. - Param(name string) string - - // ParamNames returns path parameter names. - ParamNames() []string - - // SetParamNames sets path parameter names. - SetParamNames(names ...string) - - // ParamValues returns path parameter values. - ParamValues() []string - - // SetParamValues sets path parameter values. - SetParamValues(values ...string) - - // QueryParam returns the query param for the provided name. - QueryParam(name string) string - - // QueryParams returns the query parameters as `url.Values`. - QueryParams() url.Values - - // QueryString returns the URL query string. - QueryString() string - - // FormValue returns the form field value for the provided name. - FormValue(name string) string - - // FormParams returns the form parameters as `url.Values`. - FormParams() (url.Values, error) - - // FormFile returns the multipart form file for the provided name. - FormFile(name string) (*multipart.FileHeader, error) - - // MultipartForm returns the multipart form. - MultipartForm() (*multipart.Form, error) - - // Cookie returns the named cookie provided in the request. - Cookie(name string) (*http.Cookie, error) - - // SetCookie adds a `Set-Cookie` header in HTTP response. - SetCookie(cookie *http.Cookie) - - // Cookies returns the HTTP cookies sent with the request. - Cookies() []*http.Cookie - - // Get retrieves data from the context. - Get(key string) interface{} - - // Set saves data in the context. - Set(key string, val interface{}) - - // Bind binds the request body into provided type `i`. The default binder - // does it based on Content-Type header. - Bind(i interface{}) error - - // Validate validates provided `i`. It is usually called after `Context#Bind()`. - // Validator must be registered using `Echo#Validator`. - Validate(i interface{}) error - - // Render renders a template with data and sends a text/html response with status - // code. Renderer must be registered using `Echo.Renderer`. - Render(code int, name string, data interface{}) error - - // HTML sends an HTTP response with status code. - HTML(code int, html string) error - - // HTMLBlob sends an HTTP blob response with status code. - HTMLBlob(code int, b []byte) error - - // String sends a string response with status code. - String(code int, s string) error - - // JSON sends a JSON response with status code. - JSON(code int, i interface{}) error - - // JSONPretty sends a pretty-print JSON with status code. - JSONPretty(code int, i interface{}, indent string) error - - // JSONBlob sends a JSON blob response with status code. - JSONBlob(code int, b []byte) error - - // JSONP sends a JSONP response with status code. It uses `callback` to construct - // the JSONP payload. - JSONP(code int, callback string, i interface{}) error - - // JSONPBlob sends a JSONP blob response with status code. It uses `callback` - // to construct the JSONP payload. - JSONPBlob(code int, callback string, b []byte) error - - // XML sends an XML response with status code. - XML(code int, i interface{}) error - - // XMLPretty sends a pretty-print XML with status code. - XMLPretty(code int, i interface{}, indent string) error - - // XMLBlob sends an XML blob response with status code. - XMLBlob(code int, b []byte) error - - // Blob sends a blob response with status code and content type. - Blob(code int, contentType string, b []byte) error - - // Stream sends a streaming response with status code and content type. - Stream(code int, contentType string, r io.Reader) error - - // File sends a response with the content of the file. - File(file string) error - - // Attachment sends a response as attachment, prompting client to save the - // file. - Attachment(file string, name string) error - - // Inline sends a response as inline, opening the file in the browser. - Inline(file string, name string) error - - // NoContent sends a response with no body and a status code. - NoContent(code int) error - - // Redirect redirects the request to a provided URL with status code. - Redirect(code int, url string) error - - // Error invokes the registered HTTP error handler. Generally used by middleware. - Error(err error) - - // Handler returns the matched handler by router. - Handler() HandlerFunc - - // SetHandler sets the matched handler by router. - SetHandler(h HandlerFunc) - - // Logger returns the `Logger` instance. - Logger() Logger - - // Set the logger - SetLogger(l Logger) +const ( + // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain. + // Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests. + // It is added to context only when Router does not find matching method handler for request. + ContextKeyHeaderAllow = "echo_header_allow" +) - // Echo returns the `Echo` instance. - Echo() *Echo +const ( + // defaultMemory is default value for memory limit that is used when + // parsing multipart forms (See (*http.Request).ParseMultipartForm) + defaultMemory int64 = 32 << 20 // 32 MB + indexPage = "index.html" +) - // Reset resets the context after request completes. It must be called along - // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. - // See `Echo#ServeHTTP()` - Reset(r *http.Request, w http.ResponseWriter) +// Context represents the context of the current HTTP request. It holds request and +// response objects, path, path parameters, data and registered handler. +type Context struct { + request *http.Request + orgResponse *Response + response http.ResponseWriter + query url.Values + + // formParseMaxMemory is used for http.Request.ParseMultipartForm + formParseMaxMemory int64 + + route *RouteInfo + pathValues *PathValues + + store map[string]any + echo *Echo + logger *slog.Logger + + path string + lock sync.RWMutex +} + +// NewContext returns a new Context instance. +// +// Note: request,response and e can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway +// these arguments are useful when creating context for tests and cases like that. +func NewContext(r *http.Request, w http.ResponseWriter, opts ...any) *Context { + var e *Echo + for _, opt := range opts { + switch v := opt.(type) { + case *Echo: + e = v + } } + return newContext(r, w, e) +} - context struct { - request *http.Request - response *Response - path string - pnames []string - pvalues []string - query url.Values - handler HandlerFunc - store Map - echo *Echo - logger Logger - lock sync.RWMutex +func newContext(r *http.Request, w http.ResponseWriter, e *Echo) *Context { + c := &Context{ + pathValues: nil, + store: make(map[string]any), + echo: e, + logger: nil, } -) + var logger *slog.Logger + paramLen := int32(0) + formParseMaxMemory := defaultMemory + if e != nil { + paramLen = e.contextPathParamAllocSize.Load() + logger = e.Logger + formParseMaxMemory = e.formParseMaxMemory + } + if logger == nil { + logger = slog.Default() + } + c.logger = logger + p := make(PathValues, 0, paramLen) + c.pathValues = &p -const ( - defaultMemory = 32 << 20 // 32 MB - indexPage = "index.html" - defaultIndent = " " -) + c.SetRequest(r) + c.orgResponse = NewResponse(w, logger) + c.response = c.orgResponse + c.formParseMaxMemory = formParseMaxMemory + return c +} + +// Reset resets the context after request completes. It must be called along +// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. +// See `Echo#ServeHTTP()` +func (c *Context) Reset(r *http.Request, w http.ResponseWriter) { + c.request = r + c.orgResponse.reset(w) + c.response = c.orgResponse + c.query = nil + c.store = nil + c.logger = c.echo.Logger + + c.route = nil + c.path = "" + // NOTE: empty by setting length to 0. PathValues has to have capacity of c.echo.contextPathParamAllocSize at all times + *c.pathValues = (*c.pathValues)[:0] +} -func (c *context) writeContentType(value string) { - header := c.Response().Header() +func (c *Context) writeContentType(value string) { + header := c.response.Header() if header.Get(HeaderContentType) == "" { header.Set(HeaderContentType, value) } } -func (c *context) Request() *http.Request { +// Request returns `*http.Request`. +func (c *Context) Request() *http.Request { return c.request } -func (c *context) SetRequest(r *http.Request) { +// SetRequest sets `*http.Request`. +func (c *Context) SetRequest(r *http.Request) { c.request = r } -func (c *context) Response() *Response { +// Response returns `*Response`. +func (c *Context) Response() http.ResponseWriter { return c.response } -func (c *context) SetResponse(r *Response) { +// SetResponse sets `*http.ResponseWriter`. Some context methods and/or middleware require that given ResponseWriter implements following +// method `Unwrap() http.ResponseWriter` which eventually should return *echo.Response instance. +func (c *Context) SetResponse(r http.ResponseWriter) { c.response = r } -func (c *context) IsTLS() bool { +// IsTLS returns true if HTTP connection is TLS otherwise false. +func (c *Context) IsTLS() bool { return c.request.TLS != nil } -func (c *context) IsWebSocket() bool { +// IsWebSocket returns true if HTTP connection is WebSocket otherwise false. +func (c *Context) IsWebSocket() bool { upgrade := c.request.Header.Get(HeaderUpgrade) - return strings.ToLower(upgrade) == "websocket" + connection := c.request.Header.Get(HeaderConnection) + return strings.EqualFold(upgrade, "websocket") && strings.Contains(strings.ToLower(connection), "upgrade") } -func (c *context) Scheme() string { +// Scheme returns the HTTP protocol scheme, `http` or `https`. +func (c *Context) Scheme() string { // Can't use `r.Request.URL.Scheme` // See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0 if c.IsTLS() { @@ -269,77 +180,161 @@ func (c *context) Scheme() string { return "http" } -func (c *context) RealIP() string { +// RealIP returns the client's network address based on `X-Forwarded-For` +// or `X-Real-IP` request header. +// The behavior can be configured using `Echo#IPExtractor`. +func (c *Context) RealIP() string { + if c.echo != nil && c.echo.IPExtractor != nil { + return c.echo.IPExtractor(c.request) + } + // Fall back to legacy behavior if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" { - return strings.Split(ip, ", ")[0] + i := strings.IndexAny(ip, ",") + if i > 0 { + xffip := strings.TrimSpace(ip[:i]) + xffip = strings.TrimPrefix(xffip, "[") + xffip = strings.TrimSuffix(xffip, "]") + return xffip + } + return ip } if ip := c.request.Header.Get(HeaderXRealIP); ip != "" { + ip = strings.TrimPrefix(ip, "[") + ip = strings.TrimSuffix(ip, "]") return ip } ra, _, _ := net.SplitHostPort(c.request.RemoteAddr) return ra } -func (c *context) Path() string { +// Path returns the registered path for the handler. +func (c *Context) Path() string { return c.path } -func (c *context) SetPath(p string) { +// SetPath sets the registered path for the handler. +func (c *Context) SetPath(p string) { c.path = p } -func (c *context) Param(name string) string { - for i, n := range c.pnames { - if i < len(c.pvalues) { - if n == name { - return c.pvalues[i] - } - } +// RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route. +// +// RouteInfo returns generic "empty" struct for these cases: +// * Context is accessed before Routing is done. For example inside Pre middlewares (`e.Pre()`) +// * Router did not find matching route - 404 (route not found) +// * Router did not find matching route with same method - 405 (method not allowed) +func (c *Context) RouteInfo() RouteInfo { + if c.route != nil { + return c.route.Clone() } - return "" + return RouteInfo{} +} + +// Param returns path parameter by name. +func (c *Context) Param(name string) string { + return c.pathValues.GetOr(name, "") } -func (c *context) ParamNames() []string { - return c.pnames +// ParamOr returns the path parameter or default value for the provided name. +// +// Notes for DefaultRouter implementation: +// Path parameter could be empty for cases like that: +// * route `/release-:version/bin` and request URL is `/release-/bin` +// * route `/api/:version/image.jpg` and request URL is `/api//image.jpg` +// but not when path parameter is last part of route path +// * route `/download/file.:ext` will not match request `/download/file.` +func (c *Context) ParamOr(name, defaultValue string) string { + return c.pathValues.GetOr(name, defaultValue) } -func (c *context) SetParamNames(names ...string) { - c.pnames = names +// PathValues returns path parameter values. +func (c *Context) PathValues() PathValues { + return *c.pathValues } -func (c *context) ParamValues() []string { - return c.pvalues[:len(c.pnames)] +// SetPathValues sets path parameters for current request. +func (c *Context) SetPathValues(pathValues PathValues) { + if pathValues == nil { + panic("context SetPathValues called with nil PathValues") + } + c.setPathValues(&pathValues) } -func (c *context) SetParamValues(values ...string) { - c.pvalues = values +// InitializeRoute sets the route related variables of this request to the context. +func (c *Context) InitializeRoute(ri *RouteInfo, pathValues *PathValues) { + c.route = ri + c.path = ri.Path + c.setPathValues(pathValues) } -func (c *context) QueryParam(name string) string { +func (c *Context) setPathValues(pv *PathValues) { + // Router accesses c.pathValues by index and may resize it to full capacity during routing + // for that to work without going out-of-bounds we must make sure that c.pathValues slice is not replaced with smaller + // slice than Router can set when routing Route with maximum amount of parameters. + pathValues := c.pathValues + if cap(*c.pathValues) < len(*pv) { + // normally we should not end up here. pathValues is normally sized to Echo.contextPathParamAllocSize which should not + // be smaller than anything router knows as maximum path parameter count to be. + tmp := make(PathValues, len(*pv)) + c.pathValues = &tmp + pathValues = c.pathValues + } else if len(*c.pathValues) != len(*pv) { + *pathValues = (*pathValues)[0:len(*pv)] // resize slice to given params length for copy to work + } + copy(*pathValues, *pv) +} + +// QueryParam returns the query param for the provided name. +func (c *Context) QueryParam(name string) string { if c.query == nil { c.query = c.request.URL.Query() } return c.query.Get(name) } -func (c *context) QueryParams() url.Values { +// QueryParamOr returns the query param or default value for the provided name. +// Note: QueryParamOr does not distinguish if query had no value by that name or value was empty string +// This means URLs `/test?search=` and `/test` would both return `1` for `c.QueryParamOr("search", "1")` +func (c *Context) QueryParamOr(name, defaultValue string) string { + value := c.QueryParam(name) + if value == "" { + value = defaultValue + } + return value +} + +// QueryParams returns the query parameters as `url.Values`. +func (c *Context) QueryParams() url.Values { if c.query == nil { c.query = c.request.URL.Query() } return c.query } -func (c *context) QueryString() string { +// QueryString returns the URL query string. +func (c *Context) QueryString() string { return c.request.URL.RawQuery } -func (c *context) FormValue(name string) string { +// FormValue returns the form field value for the provided name. +func (c *Context) FormValue(name string) string { return c.request.FormValue(name) } -func (c *context) FormParams() (url.Values, error) { +// FormValueOr returns the form field value or default value for the provided name. +// Note: FormValueOr does not distinguish if form had no value by that name or value was empty string +func (c *Context) FormValueOr(name, defaultValue string) string { + value := c.FormValue(name) + if value == "" { + value = defaultValue + } + return value +} + +// FormValues returns the form field values as `url.Values`. +func (c *Context) FormValues() (url.Values, error) { if strings.HasPrefix(c.request.Header.Get(HeaderContentType), MIMEMultipartForm) { - if err := c.request.ParseMultipartForm(defaultMemory); err != nil { + if err := c.request.ParseMultipartForm(c.formParseMaxMemory); err != nil { return nil, err } } else { @@ -350,91 +345,115 @@ func (c *context) FormParams() (url.Values, error) { return c.request.Form, nil } -func (c *context) FormFile(name string) (*multipart.FileHeader, error) { +// FormFile returns the multipart form file for the provided name. +func (c *Context) FormFile(name string) (*multipart.FileHeader, error) { f, fh, err := c.request.FormFile(name) - defer f.Close() - return fh, err + if err != nil { + return nil, err + } + _ = f.Close() + return fh, nil } -func (c *context) MultipartForm() (*multipart.Form, error) { - err := c.request.ParseMultipartForm(defaultMemory) +// MultipartForm returns the multipart form. +func (c *Context) MultipartForm() (*multipart.Form, error) { + err := c.request.ParseMultipartForm(c.formParseMaxMemory) return c.request.MultipartForm, err } -func (c *context) Cookie(name string) (*http.Cookie, error) { +// Cookie returns the named cookie provided in the request. +func (c *Context) Cookie(name string) (*http.Cookie, error) { return c.request.Cookie(name) } -func (c *context) SetCookie(cookie *http.Cookie) { +// SetCookie adds a `Set-Cookie` header in HTTP response. +func (c *Context) SetCookie(cookie *http.Cookie) { http.SetCookie(c.Response(), cookie) } -func (c *context) Cookies() []*http.Cookie { +// Cookies returns the HTTP cookies sent with the request. +func (c *Context) Cookies() []*http.Cookie { return c.request.Cookies() } -func (c *context) Get(key string) interface{} { +// Get retrieves data from the context. +// Method returns any(nil) when key does not exist which is different from typed nil (eg. []byte(nil)). +func (c *Context) Get(key string) any { c.lock.RLock() defer c.lock.RUnlock() return c.store[key] } -func (c *context) Set(key string, val interface{}) { +// Set saves data in the context. +func (c *Context) Set(key string, val any) { c.lock.Lock() defer c.lock.Unlock() if c.store == nil { - c.store = make(Map) + c.store = make(map[string]any) } c.store[key] = val } -func (c *context) Bind(i interface{}) error { - return c.echo.Binder.Bind(i, c) +// Bind binds path params, query params and the request body into provided type `i`. The default binder +// binds body based on Content-Type header. +func (c *Context) Bind(i any) error { + return c.echo.Binder.Bind(c, i) } -func (c *context) Validate(i interface{}) error { +// Validate validates provided `i`. It is usually called after `Context#Bind()`. +// Validator must be registered using `Echo#Validator`. +func (c *Context) Validate(i any) error { if c.echo.Validator == nil { return ErrValidatorNotRegistered } return c.echo.Validator.Validate(i) } -func (c *context) Render(code int, name string, data interface{}) (err error) { +// Render renders a template with data and sends a text/html response with status +// code. Renderer must be registered using `Echo.Renderer`. +func (c *Context) Render(code int, name string, data any) (err error) { if c.echo.Renderer == nil { return ErrRendererNotRegistered } + // as Renderer.Render can fail, and in that case we need to delay sending status code to the client until + // (global) error handler decides the correct status code for the error to be sent to the client, so we need to write + // the rendered template to the buffer first. + // + // html.Template.ExecuteTemplate() documentations writes: + // > If an error occurs executing the template or writing its output, + // > execution stops, but partial results may already have been written to + // > the output writer. + buf := new(bytes.Buffer) - if err = c.echo.Renderer.Render(buf, name, data, c); err != nil { + if err = c.echo.Renderer.Render(c, buf, name, data); err != nil { return } return c.HTMLBlob(code, buf.Bytes()) } -func (c *context) HTML(code int, html string) (err error) { +// HTML sends an HTTP response with status code. +func (c *Context) HTML(code int, html string) (err error) { return c.HTMLBlob(code, []byte(html)) } -func (c *context) HTMLBlob(code int, b []byte) (err error) { +// HTMLBlob sends an HTTP blob response with status code. +func (c *Context) HTMLBlob(code int, b []byte) (err error) { return c.Blob(code, MIMETextHTMLCharsetUTF8, b) } -func (c *context) String(code int, s string) (err error) { +// String sends a string response with status code. +func (c *Context) String(code int, s string) (err error) { return c.Blob(code, MIMETextPlainCharsetUTF8, []byte(s)) } -func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error) { - enc := json.NewEncoder(c.response) - _, pretty := c.QueryParams()["pretty"] - if c.echo.Debug || pretty { - enc.SetIndent("", " ") - } +func (c *Context) jsonPBlob(code int, callback string, i any) (err error) { c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(callback + "(")); err != nil { return } - if err = enc.Encode(i); err != nil { + if err = c.echo.JSONSerializer.Serialize(c, i, ""); err != nil { return } if _, err = c.response.Write([]byte(");")); err != nil { @@ -443,37 +462,47 @@ func (c *context) jsonPBlob(code int, callback string, i interface{}) (err error return } -func (c *context) json(code int, i interface{}, indent string) error { - enc := json.NewEncoder(c.response) - if indent != "" { - enc.SetIndent("", indent) +func (c *Context) json(code int, i any, indent string) error { + c.writeContentType(MIMEApplicationJSON) + + // as JSONSerializer.Serialize can fail, and in that case we need to delay sending status code to the client until + // (global) error handler decides correct status code for the error to be sent to the client. + // For that we need to use writer that can store the proposed status code until the first Write is called. + if r, err := UnwrapResponse(c.response); err == nil { + r.Status = code + } else { + resp := c.Response() + c.SetResponse(&delayedStatusWriter{ResponseWriter: resp, status: code}) + defer c.SetResponse(resp) } - c.writeContentType(MIMEApplicationJSONCharsetUTF8) - c.response.Status = code - return enc.Encode(i) + + return c.echo.JSONSerializer.Serialize(c, i, indent) } -func (c *context) JSON(code int, i interface{}) (err error) { - indent := "" - if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { - indent = defaultIndent - } - return c.json(code, i, indent) +// JSON sends a JSON response with status code. +func (c *Context) JSON(code int, i any) (err error) { + return c.json(code, i, "") } -func (c *context) JSONPretty(code int, i interface{}, indent string) (err error) { +// JSONPretty sends a pretty-print JSON with status code. +func (c *Context) JSONPretty(code int, i any, indent string) (err error) { return c.json(code, i, indent) } -func (c *context) JSONBlob(code int, b []byte) (err error) { - return c.Blob(code, MIMEApplicationJSONCharsetUTF8, b) +// JSONBlob sends a JSON blob response with status code. +func (c *Context) JSONBlob(code int, b []byte) (err error) { + return c.Blob(code, MIMEApplicationJSON, b) } -func (c *context) JSONP(code int, callback string, i interface{}) (err error) { +// JSONP sends a JSONP response with status code. It uses `callback` to construct +// the JSONP payload. +func (c *Context) JSONP(code int, callback string, i any) (err error) { return c.jsonPBlob(code, callback, i) } -func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) { +// JSONPBlob sends a JSONP blob response with status code. It uses `callback` +// to construct the JSONP payload. +func (c *Context) JSONPBlob(code int, callback string, b []byte) (err error) { c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(callback + "(")); err != nil { @@ -486,7 +515,7 @@ func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) { return } -func (c *context) xml(code int, i interface{}, indent string) (err error) { +func (c *Context) xml(code int, i any, indent string) (err error) { c.writeContentType(MIMEApplicationXMLCharsetUTF8) c.response.WriteHeader(code) enc := xml.NewEncoder(c.response) @@ -499,19 +528,18 @@ func (c *context) xml(code int, i interface{}, indent string) (err error) { return enc.Encode(i) } -func (c *context) XML(code int, i interface{}) (err error) { - indent := "" - if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { - indent = defaultIndent - } - return c.xml(code, i, indent) +// XML sends an XML response with status code. +func (c *Context) XML(code int, i any) (err error) { + return c.xml(code, i, "") } -func (c *context) XMLPretty(code int, i interface{}, indent string) (err error) { +// XMLPretty sends a pretty-print XML with status code. +func (c *Context) XMLPretty(code int, i any, indent string) (err error) { return c.xml(code, i, indent) } -func (c *context) XMLBlob(code int, b []byte) (err error) { +// XMLBlob sends an XML blob response with status code. +func (c *Context) XMLBlob(code int, b []byte) (err error) { c.writeContentType(MIMEApplicationXMLCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(xml.Header)); err != nil { @@ -521,62 +549,89 @@ func (c *context) XMLBlob(code int, b []byte) (err error) { return } -func (c *context) Blob(code int, contentType string, b []byte) (err error) { +// Blob sends a blob response with status code and content type. +func (c *Context) Blob(code int, contentType string, b []byte) (err error) { c.writeContentType(contentType) c.response.WriteHeader(code) _, err = c.response.Write(b) return } -func (c *context) Stream(code int, contentType string, r io.Reader) (err error) { +// Stream sends a streaming response with status code and content type. +func (c *Context) Stream(code int, contentType string, r io.Reader) (err error) { c.writeContentType(contentType) c.response.WriteHeader(code) _, err = io.Copy(c.response, r) return } -func (c *context) File(file string) (err error) { - f, err := os.Open(file) +// File sends a response with the content of the file. +func (c *Context) File(file string) error { + return fsFile(c, file, c.echo.Filesystem) +} + +// FileFS serves file from given file system. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (c *Context) FileFS(file string, filesystem fs.FS) error { + return fsFile(c, file, filesystem) +} + +func fsFile(c *Context, file string, filesystem fs.FS) error { + file = path.Clean(file) // `os.Open` and `os.DirFs.Open()` behave differently, later does not like ``, `.`, `..` at all, but we allowed those now need to clean + f, err := filesystem.Open(file) if err != nil { - return NotFoundHandler(c) + return ErrNotFound } defer f.Close() fi, _ := f.Stat() if fi.IsDir() { - file = filepath.Join(file, indexPage) - f, err = os.Open(file) + file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect. + f, err = filesystem.Open(file) if err != nil { - return NotFoundHandler(c) + return ErrNotFound } defer f.Close() if fi, err = f.Stat(); err != nil { - return + return err } } - http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f) - return + ff, ok := f.(io.ReadSeeker) + if !ok { + return errors.New("file does not implement io.ReadSeeker") + } + http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) + return nil } -func (c *context) Attachment(file, name string) error { +// Attachment sends a response as attachment, prompting client to save the file. +func (c *Context) Attachment(file, name string) error { return c.contentDisposition(file, name, "attachment") } -func (c *context) Inline(file, name string) error { +// Inline sends a response as inline, opening the file in the browser. +func (c *Context) Inline(file, name string) error { return c.contentDisposition(file, name, "inline") } -func (c *context) contentDisposition(file, name, dispositionType string) error { - c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf("%s; filename=%q", dispositionType, name)) +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +func (c *Context) contentDisposition(file, name, dispositionType string) error { + c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf(`%s; filename="%s"`, dispositionType, quoteEscaper.Replace(name))) return c.File(file) } -func (c *context) NoContent(code int) error { +// NoContent sends a response with no body and a status code. +func (c *Context) NoContent(code int) error { c.response.WriteHeader(code) return nil } -func (c *context) Redirect(code int, url string) error { +// Redirect redirects the request to a provided URL with status code. +func (c *Context) Redirect(code int, url string) error { if code < 300 || code > 308 { return ErrInvalidRedirectCode } @@ -585,45 +640,20 @@ func (c *context) Redirect(code int, url string) error { return nil } -func (c *context) Error(err error) { - c.echo.HTTPErrorHandler(err, c) -} - -func (c *context) Echo() *Echo { - return c.echo -} - -func (c *context) Handler() HandlerFunc { - return c.handler -} - -func (c *context) SetHandler(h HandlerFunc) { - c.handler = h -} - -func (c *context) Logger() Logger { - res := c.logger - if res != nil { - return res +// Logger returns logger in Context +func (c *Context) Logger() *slog.Logger { + if c.logger != nil { + return c.logger } return c.echo.Logger } -func (c *context) SetLogger(l Logger) { - c.logger = l +// SetLogger sets logger in Context +func (c *Context) SetLogger(logger *slog.Logger) { + c.logger = logger } -func (c *context) Reset(r *http.Request, w http.ResponseWriter) { - c.request = r - c.response.reset(w) - c.query = nil - c.handler = NotFoundHandler - c.store = nil - c.path = "" - c.pnames = nil - c.logger = nil - // NOTE: Don't reset because it has to have length c.echo.maxParam at all times - for i := 0; i < *c.echo.maxParam; i++ { - c.pvalues[i] = "" - } +// Echo returns the `Echo` instance. +func (c *Context) Echo() *Echo { + return c.echo } diff --git a/context_generic.go b/context_generic.go new file mode 100644 index 000000000..7cf8b296c --- /dev/null +++ b/context_generic.go @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echo + +import "errors" + +// ErrNonExistentKey is error that is returned when key does not exist +var ErrNonExistentKey = errors.New("non existent key") + +// ErrInvalidKeyType is error that is returned when the value is not castable to expected type. +var ErrInvalidKeyType = errors.New("invalid key type") + +// ContextGet retrieves a value from the context store or ErrNonExistentKey error the key is missing. +// Returns ErrInvalidKeyType error if the value is not castable to type T. +func ContextGet[T any](c *Context, key string) (T, error) { + c.lock.RLock() + defer c.lock.RUnlock() + + val, ok := c.store[key] + if !ok { + var zero T + return zero, ErrNonExistentKey + } + + typed, ok := val.(T) + if !ok { + var zero T + return zero, ErrInvalidKeyType + } + + return typed, nil +} + +// ContextGetOr retrieves a value from the context store or returns a default value when the key +// is missing. Returns ErrInvalidKeyType error if the value is not castable to type T. +func ContextGetOr[T any](c *Context, key string, defaultValue T) (T, error) { + typed, err := ContextGet[T](c, key) + if err == ErrNonExistentKey { + return defaultValue, nil + } + return typed, err +} diff --git a/context_generic_test.go b/context_generic_test.go new file mode 100644 index 000000000..ce468ac3e --- /dev/null +++ b/context_generic_test.go @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestContextGetOK(t *testing.T) { + c := NewContext(nil, nil) + + c.Set("key", int64(123)) + + v, err := ContextGet[int64](c, "key") + assert.NoError(t, err) + assert.Equal(t, int64(123), v) +} + +func TestContextGetNonExistentKey(t *testing.T) { + c := NewContext(nil, nil) + + c.Set("key", int64(123)) + + v, err := ContextGet[int64](c, "nope") + assert.ErrorIs(t, err, ErrNonExistentKey) + assert.Equal(t, int64(0), v) +} + +func TestContextGetInvalidCast(t *testing.T) { + c := NewContext(nil, nil) + + c.Set("key", int64(123)) + + v, err := ContextGet[bool](c, "key") + assert.ErrorIs(t, err, ErrInvalidKeyType) + assert.Equal(t, false, v) +} + +func TestContextGetOrOK(t *testing.T) { + c := NewContext(nil, nil) + + c.Set("key", int64(123)) + + v, err := ContextGetOr[int64](c, "key", 999) + assert.NoError(t, err) + assert.Equal(t, int64(123), v) +} + +func TestContextGetOrNonExistentKey(t *testing.T) { + c := NewContext(nil, nil) + + c.Set("key", int64(123)) + + v, err := ContextGetOr[int64](c, "nope", 999) + assert.NoError(t, err) + assert.Equal(t, int64(999), v) +} + +func TestContextGetOrInvalidCast(t *testing.T) { + c := NewContext(nil, nil) + + c.Set("key", int64(123)) + + v, err := ContextGetOr[float32](c, "key", float32(999)) + assert.ErrorIs(t, err, ErrInvalidKeyType) + assert.Equal(t, float32(0), v) +} diff --git a/context_test.go b/context_test.go index 47be19cce..5945c9ecc 100644 --- a/context_test.go +++ b/context_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( @@ -5,36 +8,36 @@ import ( "crypto/tls" "encoding/json" "encoding/xml" - "errors" "fmt" - "github.com/labstack/gommon/log" "io" + "io/fs" + "log/slog" "math" "mime/multipart" "net/http" "net/http/httptest" "net/url" + "os" "strings" "testing" "text/template" "time" - testify "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/assert" ) -type ( - Template struct { - templates *template.Template - } -) +type Template struct { + templates *template.Template +} -var testUser = user{1, "Jon Snow"} +var testUser = user{ID: 1, Name: "Jon Snow"} func BenchmarkAllocJSONP(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + e.Logger = slog.New(slog.DiscardHandler) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) b.ResetTimer() b.ReportAllocs() @@ -46,9 +49,10 @@ func BenchmarkAllocJSONP(b *testing.B) { func BenchmarkAllocJSON(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + e.Logger = slog.New(slog.DiscardHandler) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) b.ResetTimer() b.ReportAllocs() @@ -60,9 +64,10 @@ func BenchmarkAllocJSON(b *testing.B) { func BenchmarkAllocXML(b *testing.B) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + e.Logger = slog.New(slog.DiscardHandler) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) b.ResetTimer() b.ReportAllocs() @@ -72,338 +77,430 @@ func BenchmarkAllocXML(b *testing.B) { } } -func (t *Template) Render(w io.Writer, name string, data interface{}, c Context) error { - return t.templates.ExecuteTemplate(w, name, data) +func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) { + c := Context{request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, + }} + for i := 0; i < b.N; i++ { + c.RealIP() + } } -type responseWriterErr struct { +func (t *Template) Render(c *Context, w io.Writer, name string, data any) error { + return t.templates.ExecuteTemplate(w, name, data) } -func (responseWriterErr) Header() http.Header { - return http.Header{} -} +func TestContextEcho(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) -func (responseWriterErr) Write([]byte) (int, error) { - return 0, errors.New("err") + assert.Equal(t, e, c.Echo()) } -func (responseWriterErr) WriteHeader(statusCode int) { +func TestContextRequest(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + assert.NotNil(t, c.Request()) + assert.Equal(t, req, c.Request()) } -func TestContext(t *testing.T) { +func TestContextResponse(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) - assert := testify.New(t) - - // Echo - assert.Equal(e, c.Echo()) + c := e.NewContext(req, rec) - // Request - assert.NotNil(c.Request()) + assert.NotNil(t, c.Response()) +} - // Response - assert.NotNil(c.Response()) +func TestContextRenderTemplate(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() - //-------- - // Render - //-------- + c := e.NewContext(req, rec) tmpl := &Template{ templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")), } - c.echo.Renderer = tmpl + c.Echo().Renderer = tmpl err := c.Render(http.StatusOK, "hello", "Jon Snow") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("Hello, Jon Snow!", rec.Body.String()) - } - - c.echo.Renderer = nil - err = c.Render(http.StatusOK, "hello", "Jon Snow") - assert.Error(err) - - // JSON - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON+"\n", rec.Body.String()) - } - - // JSON with "?pretty" - req = httptest.NewRequest(http.MethodGet, "/?pretty", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSONPretty+"\n", rec.Body.String()) - } - req = httptest.NewRequest(http.MethodGet, "/", nil) // reset - - // JSONPretty - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSONPretty+"\n", rec.Body.String()) - } - - // JSON (error) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.JSON(http.StatusOK, make(chan bool)) - assert.Error(err) - - // JSONP - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "Hello, Jon Snow!", rec.Body.String()) + } +} + +func TestContextRenderTemplateError(t *testing.T) { + // we test that when template rendering fails, no response is sent to the client yet, so the global error handler can decide what to do + e := New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + tmpl := &Template{ + templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")), + } + c.Echo().Renderer = tmpl + err := c.Render(http.StatusOK, "not_existing", "Jon Snow") + + assert.EqualError(t, err, `template: no template "not_existing" associated with template "hello"`) + assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client + assert.Empty(t, rec.Body.String()) // body must not be sent to the client +} + +func TestContextRenderErrorsOnNoRenderer(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + c.Echo().Renderer = nil + assert.Error(t, c.Render(http.StatusOK, "hello", "Jon Snow")) +} + +func TestContextStream(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec) + + r, w := io.Pipe() + go func() { + defer w.Close() + for i := 0; i < 3; i++ { + fmt.Fprintf(w, "data: index %v\n\n", i) + time.Sleep(5 * time.Millisecond) + } + }() + + err := c.Stream(http.StatusOK, "text/event-stream", r) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "text/event-stream", rec.Header().Get(HeaderContentType)) + assert.Equal(t, "data: index 0\n\ndata: index 1\n\ndata: index 2\n\n", rec.Body.String()) + } +} + +func TestContextHTML(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := NewContext(req, rec) + + err := c.HTML(http.StatusOK, "Hi, Jon Snow") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hi, Jon Snow", rec.Body.String()) + } +} + +func TestContextHTMLBlob(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := NewContext(req, rec) + + err := c.HTMLBlob(http.StatusOK, []byte("Hi, Jon Snow")) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hi, Jon Snow", rec.Body.String()) + } +} + +func TestContextJSON(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + c := e.NewContext(req, rec) + + err := c.JSON(http.StatusOK, user{ID: 1, Name: "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON+"\n", rec.Body.String()) + } +} + +func TestContextJSONErrorsOut(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + c := e.NewContext(req, rec) + + err := c.JSON(http.StatusOK, make(chan bool)) + assert.EqualError(t, err, "json: unsupported type: chan bool") + + assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client + assert.Empty(t, rec.Body.String()) // body must not be sent to the client +} + +func TestContextJSONWithNotEchoResponse(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + c := e.NewContext(req, rec) + + c.SetResponse(rec) + + err := c.JSON(http.StatusCreated, map[string]float64{"foo": math.NaN()}) + assert.EqualError(t, err, "json: unsupported value: NaN") + + assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client + assert.Empty(t, rec.Body.String()) // body must not be sent to the client +} + +func TestContextJSONPretty(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec) + + err := c.JSONPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, " ") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) + } +} + +func TestContextJSONWithEmptyIntent(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec) + + u := user{ID: 1, Name: "Jon Snow"} + emptyIndent := "" + buf := new(bytes.Buffer) + + enc := json.NewEncoder(buf) + enc.SetIndent(emptyIndent, emptyIndent) + _ = enc.Encode(u) + err := c.JSONPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, emptyIndent) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) + assert.Equal(t, buf.String(), rec.Body.String()) + } +} + +func TestContextJSONP(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec) + callback := "callback" - err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(callback+"("+userJSON+"\n);", rec.Body.String()) - } - - // XML - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.XML(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXML, rec.Body.String()) - } - - // XML with "?pretty" - req = httptest.NewRequest(http.MethodGet, "/?pretty", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.XML(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) - } - req = httptest.NewRequest(http.MethodGet, "/", nil) - - // XML (error) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.XML(http.StatusOK, make(chan bool)) - assert.Error(err) - - // XML response write error - c = e.NewContext(req, rec).(*context) - c.response.Writer = responseWriterErr{} - err = c.XML(0, 0) - testify.Error(t, err) - - // XMLPretty - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXMLPretty, rec.Body.String()) - } - - t.Run("empty indent", func(t *testing.T) { - var ( - u = user{1, "Jon Snow"} - buf = new(bytes.Buffer) - emptyIndent = "" - ) - - t.Run("json", func(t *testing.T) { - buf.Reset() - assert := testify.New(t) - - // New JSONBlob with empty indent - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - enc := json.NewEncoder(buf) - enc.SetIndent(emptyIndent, emptyIndent) - err = enc.Encode(u) - err = c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(buf.String(), rec.Body.String()) - } - }) + err := c.JSONP(http.StatusOK, callback, user{ID: 1, Name: "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, callback+"("+userJSON+"\n);", rec.Body.String()) + } +} - t.Run("xml", func(t *testing.T) { - buf.Reset() - assert := testify.New(t) - - // New XMLBlob with empty indent - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - enc := xml.NewEncoder(buf) - enc.Indent(emptyIndent, emptyIndent) - err = enc.Encode(u) - err = c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+buf.String(), rec.Body.String()) - } - }) - }) +func TestContextJSONBlob(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec) - // Legacy JSONBlob - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - data, err := json.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + data, err := json.Marshal(user{ID: 1, Name: "Jon Snow"}) + assert.NoError(t, err) err = c.JSONBlob(http.StatusOK, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON, rec.Body.String()) - } - - // Legacy JSONPBlob - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - callback = "callback" - data, err = json.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON, rec.Body.String()) + } +} + +func TestContextJSONPBlob(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec) + + callback := "callback" + data, err := json.Marshal(user{ID: 1, Name: "Jon Snow"}) + assert.NoError(t, err) err = c.JSONPBlob(http.StatusOK, callback, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(callback+"("+userJSON+");", rec.Body.String()) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, callback+"("+userJSON+");", rec.Body.String()) } +} - // Legacy XMLBlob - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - data, err = xml.Marshal(user{1, "Jon Snow"}) - assert.NoError(err) - err = c.XMLBlob(http.StatusOK, data) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(xml.Header+userXML, rec.Body.String()) - } - - // String - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.String(http.StatusOK, "Hello, World!") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal("Hello, World!", rec.Body.String()) - } - - // HTML - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.HTML(http.StatusOK, "Hello, World!") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal("Hello, World!", rec.Body.String()) - } - - // Stream - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - r := strings.NewReader("response from a stream") - err = c.Stream(http.StatusOK, "application/octet-stream", r) - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("application/octet-stream", rec.Header().Get(HeaderContentType)) - assert.Equal("response from a stream", rec.Body.String()) - } - - // Attachment - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.Attachment("_fixture/images/walle.png", "walle.png") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("attachment; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(219885, rec.Body.Len()) - } - - // Inline - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = c.Inline("_fixture/images/walle.png", "walle.png") - if assert.NoError(err) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("inline; filename=\"walle.png\"", rec.Header().Get(HeaderContentDisposition)) - assert.Equal(219885, rec.Body.Len()) - } - - // NoContent - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - c.NoContent(http.StatusOK) - assert.Equal(http.StatusOK, rec.Code) - - // Error - rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - c.Error(errors.New("error")) - assert.Equal(http.StatusInternalServerError, rec.Code) - - // Reset - c.SetParamNames("foo") - c.SetParamValues("bar") - c.Set("foe", "ban") - c.query = url.Values(map[string][]string{"fon": {"baz"}}) - c.Reset(req, httptest.NewRecorder()) - assert.Equal(0, len(c.ParamValues())) - assert.Equal(0, len(c.ParamNames())) - assert.Equal(0, len(c.store)) - assert.Equal("", c.Path()) - assert.Equal(0, len(c.QueryParams())) +func TestContextXML(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec) + + err := c.XML(http.StatusOK, user{ID: 1, Name: "Jon Snow"}) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXML, rec.Body.String()) + } } -func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { +func TestContextXMLPretty(t *testing.T) { e := New() + rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec) + + err := c.XMLPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, " ") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String()) + } +} + +func TestContextXMLBlob(t *testing.T) { + e := New() rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) - err := c.JSON(http.StatusCreated, user{1, "Jon Snow"}) + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec) + + data, err := xml.Marshal(user{ID: 1, Name: "Jon Snow"}) + assert.NoError(t, err) + err = c.XMLBlob(http.StatusOK, data) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+userXML, rec.Body.String()) + } +} + +func TestContextXMLWithEmptyIntent(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec) - assert := testify.New(t) - if assert.NoError(err) { - assert.Equal(http.StatusCreated, rec.Code) - assert.Equal(MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(userJSON+"\n", rec.Body.String()) + u := user{ID: 1, Name: "Jon Snow"} + emptyIndent := "" + buf := new(bytes.Buffer) + + enc := xml.NewEncoder(buf) + enc.Indent(emptyIndent, emptyIndent) + _ = enc.Encode(u) + err := c.XMLPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, emptyIndent) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, xml.Header+buf.String(), rec.Body.String()) } } -func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) { +func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) - err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()}) + c := e.NewContext(req, rec) + err := c.JSON(http.StatusCreated, user{ID: 1, Name: "Jon Snow"}) - assert := testify.New(t) - if assert.Error(err) { - assert.False(c.response.Committed) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusCreated, rec.Code) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON+"\n", rec.Body.String()) } } +func TestContextAttachment(t *testing.T) { + var testCases = []struct { + name string + whenName string + expectHeader string + }{ + { + name: "ok", + whenName: "walle.png", + expectHeader: `attachment; filename="walle.png"`, + }, + { + name: "ok, escape quotes in malicious filename", + whenName: `malicious.sh"; \"; dummy=.txt`, + expectHeader: `attachment; filename="malicious.sh\"; \\\"; dummy=.txt"`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec) + + err := c.Attachment("_fixture/images/walle.png", tc.whenName) + if assert.NoError(t, err) { + assert.Equal(t, tc.expectHeader, rec.Header().Get(HeaderContentDisposition)) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, 219885, rec.Body.Len()) + } + }) + } +} + +func TestContextInline(t *testing.T) { + var testCases = []struct { + name string + whenName string + expectHeader string + }{ + { + name: "ok", + whenName: "walle.png", + expectHeader: `inline; filename="walle.png"`, + }, + { + name: "ok, escape quotes in malicious filename", + whenName: `malicious.sh"; \"; dummy=.txt`, + expectHeader: `inline; filename="malicious.sh\"; \\\"; dummy=.txt"`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec) + + err := c.Inline("_fixture/images/walle.png", tc.whenName) + if assert.NoError(t, err) { + assert.Equal(t, tc.expectHeader, rec.Header().Get(HeaderContentDisposition)) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, 219885, rec.Body.Len()) + } + }) + } +} + +func TestContextNoContent(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) + c := e.NewContext(req, rec) + + c.NoContent(http.StatusOK) + assert.Equal(t, http.StatusOK, rec.Code) +} + func TestContextCookie(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -412,24 +509,22 @@ func TestContextCookie(t *testing.T) { req.Header.Add(HeaderCookie, theme) req.Header.Add(HeaderCookie, user) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) - - assert := testify.New(t) + c := e.NewContext(req, rec) // Read single cookie, err := c.Cookie("theme") - if assert.NoError(err) { - assert.Equal("theme", cookie.Name) - assert.Equal("light", cookie.Value) + if assert.NoError(t, err) { + assert.Equal(t, "theme", cookie.Name) + assert.Equal(t, "light", cookie.Value) } // Read multiple for _, cookie := range c.Cookies() { switch cookie.Name { case "theme": - assert.Equal("light", cookie.Value) + assert.Equal(t, "light", cookie.Value) case "user": - assert.Equal("Jon Snow", cookie.Value) + assert.Equal(t, "Jon Snow", cookie.Value) } } @@ -444,47 +539,244 @@ func TestContextCookie(t *testing.T) { HttpOnly: true, } c.SetCookie(cookie) - assert.Contains(rec.Header().Get(HeaderSetCookie), "SSID") - assert.Contains(rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") - assert.Contains(rec.Header().Get(HeaderSetCookie), "labstack.com") - assert.Contains(rec.Header().Get(HeaderSetCookie), "Secure") - assert.Contains(rec.Header().Get(HeaderSetCookie), "HttpOnly") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "labstack.com") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Secure") + assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly") } -func TestContextPath(t *testing.T) { - e := New() - r := e.Router() +func TestContext_PathValues(t *testing.T) { + var testCases = []struct { + name string + given PathValues + expect PathValues + }{ + { + name: "param exists", + given: PathValues{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + expect: PathValues{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + }, + { + name: "params is empty", + given: PathValues{}, + expect: PathValues{}, + }, + } - r.Add(http.MethodGet, "/users/:id", nil) - c := e.NewContext(nil, nil) - r.Find(http.MethodGet, "/users/1", c) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, nil) + + c.SetPathValues(tc.given) + + assert.EqualValues(t, tc.expect, c.PathValues()) + }) + } +} + +func TestContext_PathParam(t *testing.T) { + var testCases = []struct { + name string + given PathValues + whenParamName string + expect string + }{ + { + name: "param exists", + given: PathValues{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + expect: "101", + }, + { + name: "multiple same param values exists - return first", + given: PathValues{ + {Name: "uid", Value: "101"}, + {Name: "uid", Value: "202"}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + expect: "101", + }, + { + name: "param does not exists", + given: PathValues{ + {Name: "uid", Value: "101"}, + }, + whenParamName: "nope", + expect: "", + }, + } - assert := testify.New(t) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, nil) - assert.Equal("/users/:id", c.Path()) + c.SetPathValues(tc.given) - r.Add(http.MethodGet, "/users/:uid/files/:fid", nil) - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/users/1/files/1", c) - assert.Equal("/users/:uid/files/:fid", c.Path()) + assert.EqualValues(t, tc.expect, c.Param(tc.whenParamName)) + }) + } } -func TestContextPathParam(t *testing.T) { - e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, nil) +func TestContext_PathParamDefault(t *testing.T) { + var testCases = []struct { + name string + given PathValues + whenParamName string + whenDefaultValue string + expect string + }{ + { + name: "param exists", + given: PathValues{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + whenDefaultValue: "999", + expect: "101", + }, + { + name: "param exists and is empty", + given: PathValues{ + {Name: "uid", Value: ""}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + whenDefaultValue: "999", + expect: "", // <-- this is different from QueryParamOr behaviour + }, + { + name: "param does not exists", + given: PathValues{ + {Name: "uid", Value: "101"}, + }, + whenParamName: "nope", + whenDefaultValue: "999", + expect: "999", + }, + } - // ParamNames - c.SetParamNames("uid", "fid") - testify.EqualValues(t, []string{"uid", "fid"}, c.ParamNames()) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, nil) - // ParamValues - c.SetParamValues("101", "501") - testify.EqualValues(t, []string{"101", "501"}, c.ParamValues()) + c.SetPathValues(tc.given) - // Param - testify.Equal(t, "501", c.Param("fid")) - testify.Equal(t, "", c.Param("undefined")) + assert.EqualValues(t, tc.expect, c.ParamOr(tc.whenParamName, tc.whenDefaultValue)) + }) + } +} + +func TestContextGetAndSetPathValuesMutability(t *testing.T) { + t.Run("c.PathValues() does not return copy and modifying raw slice mutates value in context", func(t *testing.T) { + e := New() + e.contextPathParamAllocSize.Store(1) + + req := httptest.NewRequest(http.MethodGet, "/:foo", nil) + c := e.NewContext(req, nil) + + params := PathValues{{Name: "foo", Value: "101"}} + c.SetPathValues(params) + + // round-trip param values with modification + paramVals := c.PathValues() + assert.Equal(t, params, c.PathValues()) + + // PathValues() does not return copy and modifying raw slice mutates value in context + paramVals[0] = PathValue{Name: "xxx", Value: "yyy"} + assert.Equal(t, PathValues{PathValue{Name: "xxx", Value: "yyy"}}, c.PathValues()) + }) + + t.Run("calling SetPathValues with bigger size changes capacity in context", func(t *testing.T) { + e := New() + e.contextPathParamAllocSize.Store(1) + + req := httptest.NewRequest(http.MethodGet, "/:foo", nil) + c := e.NewContext(req, nil) + // increase path param capacity in context + pathValues := PathValues{ + {Name: "aaa", Value: "bbb"}, + {Name: "ccc", Value: "ddd"}, + } + c.SetPathValues(pathValues) + assert.Equal(t, pathValues, c.PathValues()) + + // shouldn't explode during Reset() afterwards! + assert.NotPanics(t, func() { + c.Reset(nil, nil) + }) + assert.Equal(t, PathValues{}, c.PathValues()) + assert.Len(t, *c.pathValues, 0) + assert.Equal(t, 2, cap(*c.pathValues)) + }) + + t.Run("calling SetPathValues with smaller size slice does not change capacity in context", func(t *testing.T) { + e := New() + + req := httptest.NewRequest(http.MethodGet, "/:foo", nil) + c := e.NewContext(req, nil) + c.pathValues = &PathValues{ + {Name: "aaa", Value: "bbb"}, + {Name: "ccc", Value: "ddd"}, + } + + pathValues := PathValues{ + {Name: "aaa", Value: "bbb"}, + } + // given pathValues slice is smaller. this should not decrease c.pathValues capacity + c.SetPathValues(pathValues) + assert.Equal(t, pathValues, c.PathValues()) + + // shouldn't explode during Reset() afterwards! + assert.NotPanics(t, func() { + c.Reset(nil, nil) + }) + assert.Equal(t, PathValues{}, c.PathValues()) + assert.Len(t, *c.pathValues, 0) + assert.Equal(t, 2, cap(*c.pathValues)) + }) + +} + +// Issue #1655 +func TestContext_SetParamNamesShouldNotModifyPathValuesCapacity(t *testing.T) { + e := New() + c := e.NewContext(nil, nil) + + assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load()) + expectedTwoParams := PathValues{ + {Name: "1", Value: "one"}, + {Name: "2", Value: "two"}, + } + c.SetPathValues(expectedTwoParams) + assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load()) + assert.Equal(t, expectedTwoParams, c.PathValues()) + + expectedThreeParams := PathValues{ + {Name: "1", Value: "one"}, + {Name: "2", Value: "two"}, + {Name: "3", Value: "three"}, + } + c.SetPathValues(expectedThreeParams) + assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load()) + assert.Equal(t, expectedThreeParams, c.PathValues()) } func TestContextFormValue(t *testing.T) { @@ -498,44 +790,154 @@ func TestContextFormValue(t *testing.T) { c := e.NewContext(req, nil) // FormValue - testify.Equal(t, "Jon Snow", c.FormValue("name")) - testify.Equal(t, "jon@labstack.com", c.FormValue("email")) + assert.Equal(t, "Jon Snow", c.FormValue("name")) + assert.Equal(t, "jon@labstack.com", c.FormValue("email")) + + // FormValueOr + assert.Equal(t, "Jon Snow", c.FormValueOr("name", "nope")) + assert.Equal(t, "default", c.FormValueOr("missing", "default")) - // FormParams - params, err := c.FormParams() - if testify.NoError(t, err) { - testify.Equal(t, url.Values{ + // FormValues + values, err := c.FormValues() + if assert.NoError(t, err) { + assert.Equal(t, url.Values{ "name": []string{"Jon Snow"}, "email": []string{"jon@labstack.com"}, - }, params) + }, values) } // Multipart FormParams error req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) req.Header.Add(HeaderContentType, MIMEMultipartForm) c = e.NewContext(req, nil) - params, err = c.FormParams() - testify.Nil(t, params) - testify.Error(t, err) + values, err = c.FormValues() + assert.Nil(t, values) + assert.Error(t, err) } -func TestContextQueryParam(t *testing.T) { - q := make(url.Values) - q.Set("name", "Jon Snow") - q.Set("email", "jon@labstack.com") - req := httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil) - e := New() - c := e.NewContext(req, nil) +func TestContext_QueryParams(t *testing.T) { + var testCases = []struct { + expect url.Values + name string + givenURL string + }{ + { + name: "multiple values in url", + givenURL: "/?test=1&test=2&email=jon%40labstack.com", + expect: url.Values{ + "test": []string{"1", "2"}, + "email": []string{"jon@labstack.com"}, + }, + }, + { + name: "single value in url", + givenURL: "/?nope=1", + expect: url.Values{ + "nope": []string{"1"}, + }, + }, + { + name: "no query params in url", + givenURL: "/?", + expect: url.Values{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + e := New() + c := e.NewContext(req, nil) + + assert.Equal(t, tc.expect, c.QueryParams()) + }) + } +} + +func TestContext_QueryParam(t *testing.T) { + var testCases = []struct { + name string + givenURL string + whenParamName string + expect string + }{ + { + name: "value exists in url", + givenURL: "/?test=1", + whenParamName: "test", + expect: "1", + }, + { + name: "multiple values exists in url", + givenURL: "/?test=9&test=8", + whenParamName: "test", + expect: "9", // <-- first value in returned + }, + { + name: "value does not exists in url", + givenURL: "/?nope=1", + whenParamName: "test", + expect: "", + }, + { + name: "value is empty in url", + givenURL: "/?test=", + whenParamName: "test", + expect: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + e := New() + c := e.NewContext(req, nil) + + assert.Equal(t, tc.expect, c.QueryParam(tc.whenParamName)) + }) + } +} + +func TestContext_QueryParamDefault(t *testing.T) { + var testCases = []struct { + name string + givenURL string + whenParamName string + whenDefaultValue string + expect string + }{ + { + name: "value exists in url", + givenURL: "/?test=1", + whenParamName: "test", + whenDefaultValue: "999", + expect: "1", + }, + { + name: "value does not exists in url", + givenURL: "/?nope=1", + whenParamName: "test", + whenDefaultValue: "999", + expect: "999", + }, + { + name: "value is empty in url", + givenURL: "/?test=", + whenParamName: "test", + whenDefaultValue: "999", + expect: "999", + }, + } - // QueryParam - testify.Equal(t, "Jon Snow", c.QueryParam("name")) - testify.Equal(t, "jon@labstack.com", c.QueryParam("email")) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + e := New() + c := e.NewContext(req, nil) - // QueryParams - testify.Equal(t, url.Values{ - "name": []string{"Jon Snow"}, - "email": []string{"jon@labstack.com"}, - }, c.QueryParams()) + assert.Equal(t, tc.expect, c.QueryParamOr(tc.whenParamName, tc.whenDefaultValue)) + }) + } } func TestContextFormFile(t *testing.T) { @@ -543,7 +945,7 @@ func TestContextFormFile(t *testing.T) { buf := new(bytes.Buffer) mr := multipart.NewWriter(buf) w, err := mr.CreateFormFile("file", "test") - if testify.NoError(t, err) { + if assert.NoError(t, err) { w.Write([]byte("test")) } mr.Close() @@ -552,8 +954,8 @@ func TestContextFormFile(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.FormFile("file") - if testify.NoError(t, err) { - testify.Equal(t, "test", f.Filename) + if assert.NoError(t, err) { + assert.Equal(t, "test", f.Filename) } } @@ -562,14 +964,26 @@ func TestContextMultipartForm(t *testing.T) { buf := new(bytes.Buffer) mw := multipart.NewWriter(buf) mw.WriteField("name", "Jon Snow") + fileContent := "This is a test file" + w, err := mw.CreateFormFile("file", "test.txt") + if assert.NoError(t, err) { + w.Write([]byte(fileContent)) + } mw.Close() req := httptest.NewRequest(http.MethodPost, "/", buf) req.Header.Set(HeaderContentType, mw.FormDataContentType()) rec := httptest.NewRecorder() c := e.NewContext(req, rec) f, err := c.MultipartForm() - if testify.NoError(t, err) { - testify.NotNil(t, f) + if assert.NoError(t, err) { + assert.NotNil(t, f) + + files := f.File["file"] + if assert.Len(t, files, 1) { + file := files[0] + assert.Equal(t, "test.txt", file.Filename) + assert.Equal(t, int64(len(fileContent)), file.Size) + } } } @@ -578,23 +992,53 @@ func TestContextRedirect(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - testify.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) - testify.Equal(t, http.StatusMovedPermanently, rec.Code) - testify.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) - testify.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) + assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) + assert.Equal(t, http.StatusMovedPermanently, rec.Code) + assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) + assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) } -func TestContextStore(t *testing.T) { - var c Context - c = new(context) - c.Set("name", "Jon Snow") - testify.Equal(t, "Jon Snow", c.Get("name")) +func TestContextGet(t *testing.T) { + var testCases = []struct { + name string + given any + whenKey string + expect any + }{ + { + name: "ok, value exist", + given: "Jon Snow", + whenKey: "key", + expect: "Jon Snow", + }, + { + name: "ok, value does not exist", + given: "Jon Snow", + whenKey: "nope", + expect: nil, + }, + { + name: "ok, value is nil value", + given: []byte(nil), + whenKey: "key", + expect: []byte(nil), + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var c = new(Context) + c.Set("key", tc.given) + + v := c.Get(tc.whenKey) + assert.Equal(t, tc.expect, v) + }) + } } func BenchmarkContext_Store(b *testing.B) { e := &Echo{} - c := &context{ + c := &Context{ echo: e, } @@ -606,47 +1050,9 @@ func BenchmarkContext_Store(b *testing.B) { } } -func TestContextHandler(t *testing.T) { - e := New() - r := e.Router() - b := new(bytes.Buffer) - - r.Add(http.MethodGet, "/handler", func(Context) error { - _, err := b.Write([]byte("handler")) - return err - }) - c := e.NewContext(nil, nil) - r.Find(http.MethodGet, "/handler", c) - err := c.Handler()(c) - testify.Equal(t, "handler", b.String()) - testify.NoError(t, err) -} - -func TestContext_SetHandler(t *testing.T) { - var c Context - c = new(context) - - testify.Nil(t, c.Handler()) - - c.SetHandler(func(c Context) error { - return nil - }) - testify.NotNil(t, c.Handler()) -} - -func TestContext_Path(t *testing.T) { - path := "/pa/th" - - var c Context - c = new(context) - - c.SetPath(path) - testify.Equal(t, path, c.Path()) -} - type validator struct{} -func (*validator) Validate(i interface{}) error { +func (*validator) Validate(i any) error { return nil } @@ -654,10 +1060,10 @@ func TestContext_Validate(t *testing.T) { e := New() c := e.NewContext(nil, nil) - testify.Error(t, c.Validate(struct{}{})) + assert.Error(t, c.Validate(struct{}{})) e.Validator = &validator{} - testify.NoError(t, c.Validate(struct{}{})) + assert.NoError(t, c.Validate(struct{}{})) } func TestContext_QueryString(t *testing.T) { @@ -665,31 +1071,30 @@ func TestContext_QueryString(t *testing.T) { queryString := "query=string&var=val" - req := httptest.NewRequest(GET, "/?"+queryString, nil) + req := httptest.NewRequest(http.MethodGet, "/?"+queryString, nil) c := e.NewContext(req, nil) - testify.Equal(t, queryString, c.QueryString()) + assert.Equal(t, queryString, c.QueryString()) } func TestContext_Request(t *testing.T) { - var c Context - c = new(context) + var c = new(Context) - testify.Nil(t, c.Request()) + assert.Nil(t, c.Request()) - req := httptest.NewRequest(GET, "/path", nil) + req := httptest.NewRequest(http.MethodGet, "/path", nil) c.SetRequest(req) - testify.Equal(t, req, c.Request()) + assert.Equal(t, req, c.Request()) } func TestContext_Scheme(t *testing.T) { tests := []struct { - c Context + c *Context s string }{ { - &context{ + &Context{ request: &http.Request{ TLS: &tls.ConnectionState{}, }, @@ -697,7 +1102,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedProto: []string{"https"}}, }, @@ -705,7 +1110,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedProtocol: []string{"http"}}, }, @@ -713,7 +1118,7 @@ func TestContext_Scheme(t *testing.T) { "http", }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedSsl: []string{"on"}}, }, @@ -721,7 +1126,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderXUrlScheme: []string{"https"}}, }, @@ -729,7 +1134,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &Context{ request: &http.Request{}, }, "http", @@ -737,44 +1142,61 @@ func TestContext_Scheme(t *testing.T) { } for _, tt := range tests { - testify.Equal(t, tt.s, tt.c.Scheme()) + assert.Equal(t, tt.s, tt.c.Scheme()) } } func TestContext_IsWebSocket(t *testing.T) { tests := []struct { - c Context - ws testify.BoolAssertionFunc + c *Context + ws assert.BoolAssertionFunc }{ { - &context{ + &Context{ request: &http.Request{ - Header: http.Header{HeaderUpgrade: []string{"websocket"}}, + Header: http.Header{ + HeaderUpgrade: []string{"websocket"}, + HeaderConnection: []string{"upgrade"}, + }, }, }, - testify.True, + assert.True, }, { - &context{ + &Context{ request: &http.Request{ - Header: http.Header{HeaderUpgrade: []string{"Websocket"}}, + Header: http.Header{ + HeaderUpgrade: []string{"Websocket"}, + HeaderConnection: []string{"Upgrade"}, + }, }, }, - testify.True, + assert.True, }, { - &context{ + &Context{ request: &http.Request{}, }, - testify.False, + assert.False, }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderUpgrade: []string{"other"}}, }, }, - testify.False, + assert.False, + }, + { + &Context{ + request: &http.Request{ + Header: http.Header{ + HeaderUpgrade: []string{"websocket"}, + HeaderConnection: []string{"close"}, + }, + }, + }, + assert.False, }, } @@ -787,39 +1209,23 @@ func TestContext_IsWebSocket(t *testing.T) { func TestContext_Bind(t *testing.T) { e := New() - req := httptest.NewRequest(POST, "/", strings.NewReader(userJSON)) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) c := e.NewContext(req, nil) u := new(user) req.Header.Add(HeaderContentType, MIMEApplicationJSON) err := c.Bind(u) - testify.NoError(t, err) - testify.Equal(t, &user{1, "Jon Snow"}, u) -} - -func TestContext_Logger(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) - - log1 := c.Logger() - testify.NotNil(t, log1) - - log2 := log.New("echo2") - c.SetLogger(log2) - testify.Equal(t, log2, c.Logger()) - - // Resetting the context returns the initial logger - c.Reset(nil, nil) - testify.Equal(t, log1, c.Logger()) + assert.NoError(t, err) + assert.Equal(t, &user{ID: 1, Name: "Jon Snow"}, u) } func TestContext_RealIP(t *testing.T) { tests := []struct { - c Context + c *Context s string }{ { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, }, @@ -827,7 +1233,47 @@ func TestContext_RealIP(t *testing.T) { "127.0.0.1", }, { - &context{ + &Context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1,127.0.1.1"}}, + }, + }, + "127.0.0.1", + }, + { + &Context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1"}}, + }, + }, + "127.0.0.1", + }, + { + &Context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348], 2001:db8::1, "}}, + }, + }, + "2001:db8:85a3:8d3:1319:8a2e:370:7348", + }, + { + &Context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348],[2001:db8::1]"}}, + }, + }, + "2001:db8:85a3:8d3:1319:8a2e:370:7348", + }, + { + &Context{ + request: &http.Request{ + Header: http.Header{HeaderXForwardedFor: []string{"2001:db8:85a3:8d3:1319:8a2e:370:7348"}}, + }, + }, + "2001:db8:85a3:8d3:1319:8a2e:370:7348", + }, + { + &Context{ request: &http.Request{ Header: http.Header{ "X-Real-Ip": []string{"192.168.0.1"}, @@ -837,7 +1283,18 @@ func TestContext_RealIP(t *testing.T) { "192.168.0.1", }, { - &context{ + &Context{ + request: &http.Request{ + Header: http.Header{ + "X-Real-Ip": []string{"[2001:db8::1]"}, + }, + }, + }, + "2001:db8::1", + }, + + { + &Context{ request: &http.Request{ RemoteAddr: "89.89.89.89:1654", }, @@ -847,6 +1304,173 @@ func TestContext_RealIP(t *testing.T) { } for _, tt := range tests { - testify.Equal(t, tt.s, tt.c.RealIP()) + assert.Equal(t, tt.s, tt.c.RealIP()) + } +} + +func TestContext_File(t *testing.T) { + var testCases = []struct { + whenFS fs.FS + name string + whenFile string + expectError string + expectStartsWith []byte + expectStatus int + }{ + { + name: "ok, from default file system", + whenFile: "_fixture/images/walle.png", + whenFS: nil, + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "ok, from custom file system", + whenFile: "walle.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, not existent file", + whenFile: "not.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: nil, + expectError: "Not Found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + if tc.whenFS != nil { + e.Filesystem = tc.whenFS + } + + handler := func(ec *Context) error { + return ec.File(tc.whenFile) + } + + req := httptest.NewRequest(http.MethodGet, "/match.png", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + + assert.Equal(t, tc.expectStatus, rec.Code) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestContext_FileFS(t *testing.T) { + var testCases = []struct { + whenFS fs.FS + name string + whenFile string + expectError string + expectStartsWith []byte + expectStatus int + }{ + { + name: "ok", + whenFile: "walle.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, not existent file", + whenFile: "not.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: nil, + expectError: "Not Found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + handler := func(ec *Context) error { + return ec.FileFS(tc.whenFile, tc.whenFS) + } + + req := httptest.NewRequest(http.MethodGet, "/match.png", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + + assert.Equal(t, tc.expectStatus, rec.Code) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestLogger(t *testing.T) { + e := New() + c := e.NewContext(nil, nil) + + log1 := c.Logger() + assert.NotNil(t, log1) + assert.Equal(t, e.Logger, log1) + + customLogger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + c.SetLogger(customLogger) + assert.Equal(t, customLogger, c.Logger()) + + // Resetting the context returns the initial Echo logger + c.Reset(nil, nil) + assert.Equal(t, e.Logger, c.Logger()) +} + +func TestRouteInfo(t *testing.T) { + e := New() + c := e.NewContext(nil, nil) + + orgRI := RouteInfo{ + Name: "root", + Method: http.MethodGet, + Path: "/*", + Parameters: []string{"*"}, } + c.route = &orgRI + ri := c.RouteInfo() + assert.Equal(t, orgRI, ri) + + // Test mutability when middlewares start to change things + + // RouteInfo inside context will not be affected when returned instance is changed + expect := orgRI.Clone() + ri.Path = "changed" + ri.Parameters[0] = "changed" + assert.Equal(t, expect, c.RouteInfo()) + + // RouteInfo inside context will not be affected when returned instance is changed + expect = c.RouteInfo() + orgRI.Name = "changed" + assert.NotEqual(t, expect, c.RouteInfo()) } diff --git a/echo.go b/echo.go index a6ac0fa80..4855e8429 100644 --- a/echo.go +++ b/echo.go @@ -1,153 +1,132 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + /* Package echo implements high performance, minimalist Go web framework. Example: - package main + package main - import ( - "net/http" + import ( + "log/slog" + "net/http" - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" - ) + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" + ) - // Handler - func hello(c echo.Context) error { - return c.String(http.StatusOK, "Hello, World!") - } + // Handler + func hello(c *echo.Context) error { + return c.String(http.StatusOK, "Hello, World!") + } - func main() { - // Echo instance - e := echo.New() + func main() { + // Echo instance + e := echo.New() - // Middleware - e.Use(middleware.Logger()) - e.Use(middleware.Recover()) + // Middleware + e.Use(middleware.RequestLogger()) + e.Use(middleware.Recover()) - // Routes - e.GET("/", hello) + // Routes + e.GET("/", hello) - // Start server - e.Logger.Fatal(e.Start(":1323")) - } + // Start server + if err := e.Start(":8080"); err != nil { + slog.Error("failed to start server", "error", err) + } + } Learn more at https://echo.labstack.com */ package echo import ( - "bytes" stdContext "context" - "crypto/tls" + "encoding/json" "errors" "fmt" - "io" - "io/ioutil" - stdLog "log" - "net" + "io/fs" + "log/slog" "net/http" "net/url" - "path" + "os" + "os/signal" "path/filepath" - "reflect" - "runtime" + "strings" "sync" - "time" - - "github.com/labstack/gommon/color" - "github.com/labstack/gommon/log" - "golang.org/x/crypto/acme" - "golang.org/x/crypto/acme/autocert" + "sync/atomic" + "syscall" ) -type ( - // Echo is the top-level framework instance. - Echo struct { - common - StdLogger *stdLog.Logger - colorer *color.Color - premiddleware []MiddlewareFunc - middleware []MiddlewareFunc - maxParam *int - router *Router - routers map[string]*Router - notFoundHandler HandlerFunc - pool sync.Pool - Server *http.Server - TLSServer *http.Server - Listener net.Listener - TLSListener net.Listener - AutoTLSManager autocert.Manager - DisableHTTP2 bool - Debug bool - HideBanner bool - HidePort bool - HTTPErrorHandler HTTPErrorHandler - Binder Binder - Validator Validator - Renderer Renderer - Logger Logger - } - - // Route contains a handler and information for matching against requests. - Route struct { - Method string `json:"method"` - Path string `json:"path"` - Name string `json:"name"` - } - - // HTTPError represents an error that occurred while handling a request. - HTTPError struct { - Code int `json:"-"` - Message interface{} `json:"message"` - Internal error `json:"-"` // Stores the error returned by an external dependency - } - - // MiddlewareFunc defines a function to process middleware. - MiddlewareFunc func(HandlerFunc) HandlerFunc - - // HandlerFunc defines a function to serve HTTP requests. - HandlerFunc func(Context) error - - // HTTPErrorHandler is a centralized HTTP error handler. - HTTPErrorHandler func(error, Context) - - // Validator is the interface that wraps the Validate function. - Validator interface { - Validate(i interface{}) error - } - - // Renderer is the interface that wraps the Render function. - Renderer interface { - Render(io.Writer, string, interface{}, Context) error - } - - // Map defines a generic map of type `map[string]interface{}`. - Map map[string]interface{} - - // Common struct for Echo & Group. - common struct{} -) +// Echo is the top-level framework instance. +// +// Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these +// fields from handlers/middlewares and changing field values at the same time leads to data-races. +// Same rule applies to adding new routes after server has been started - Adding a route is not Goroutine safe action. +type Echo struct { + serveHTTPFunc func(http.ResponseWriter, *http.Request) -// HTTP methods -// NOTE: Deprecated, please use the stdlib constants directly instead. -const ( - CONNECT = http.MethodConnect - DELETE = http.MethodDelete - GET = http.MethodGet - HEAD = http.MethodHead - OPTIONS = http.MethodOptions - PATCH = http.MethodPatch - POST = http.MethodPost - // PROPFIND = "PROPFIND" - PUT = http.MethodPut - TRACE = http.MethodTrace -) + Binder Binder + Filesystem fs.FS + Renderer Renderer + Validator Validator + JSONSerializer JSONSerializer + IPExtractor IPExtractor + OnAddRoute func(route Route) error + HTTPErrorHandler HTTPErrorHandler + Logger *slog.Logger + + contextPool sync.Pool + + router Router + + // premiddleware are middlewares that are called before routing is done + premiddleware []MiddlewareFunc + + // middleware are middlewares that are called after routing is done and before handler is called + middleware []MiddlewareFunc + + contextPathParamAllocSize atomic.Int32 + + // formParseMaxMemory is passed to Context for multipart form parsing (See http.Request.ParseMultipartForm) + formParseMaxMemory int64 +} + +// JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. +type JSONSerializer interface { + Serialize(c *Context, target any, indent string) error + Deserialize(c *Context, target any) error +} + +// HTTPErrorHandler is a centralized HTTP error handler. +type HTTPErrorHandler func(c *Context, err error) + +// HandlerFunc defines a function to serve HTTP requests. +type HandlerFunc func(c *Context) error + +// MiddlewareFunc defines a function to process middleware. +type MiddlewareFunc func(next HandlerFunc) HandlerFunc + +// MiddlewareConfigurator defines interface for creating middleware handlers with possibility to return configuration errors instead of panicking. +type MiddlewareConfigurator interface { + ToMiddleware() (MiddlewareFunc, error) +} + +// Validator is the interface that wraps the Validate function. +type Validator interface { + Validate(i any) error +} // MIME types const ( - MIMEApplicationJSON = "application/json" + // MIMEApplicationJSON JavaScript Object Notation (JSON) https://www.rfc-editor.org/rfc/rfc8259 + MIMEApplicationJSON = "application/json" + // Deprecated: Please use MIMEApplicationJSON instead. JSON should be encoded using UTF-8 by default. + // No "charset" parameter is defined for this registration. + // Adding one really has no effect on compliant recipients. + // See RFC 8259, section 8.1. https://datatracker.ietf.org/doc/html/rfc8259#section-8.1n" MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8 MIMEApplicationJavaScript = "application/javascript" MIMEApplicationJavaScriptCharsetUTF8 = MIMEApplicationJavaScript + "; " + charsetUTF8 @@ -172,12 +151,21 @@ const ( PROPFIND = "PROPFIND" // REPORT Method can be used to get information about a resource, see rfc 3253 REPORT = "REPORT" + // RouteNotFound is special method type for routes handling "route not found" (404) cases + RouteNotFound = "echo_route_not_found" + // RouteAny is special method type that matches any HTTP method in request. Any has lower + // priority that other methods that have been registered with Router to that path. + RouteAny = "echo_route_any" ) // Headers const ( - HeaderAccept = "Accept" - HeaderAcceptEncoding = "Accept-Encoding" + HeaderAccept = "Accept" + HeaderAcceptEncoding = "Accept-Encoding" + // HeaderAllow is the name of the "Allow" header field used to list the set of methods + // advertised as supported by the target resource. Returning an Allow header is mandatory + // for status 405 (method not found) and useful for the OPTIONS method in responses. + // See RFC 7231: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1 HeaderAllow = "Allow" HeaderAuthorization = "Authorization" HeaderContentDisposition = "Content-Disposition" @@ -189,6 +177,7 @@ const ( HeaderIfModifiedSince = "If-Modified-Since" HeaderLastModified = "Last-Modified" HeaderLocation = "Location" + HeaderRetryAfter = "Retry-After" HeaderUpgrade = "Upgrade" HeaderVary = "Vary" HeaderWWWAuthenticate = "WWW-Authenticate" @@ -198,11 +187,17 @@ const ( HeaderXForwardedSsl = "X-Forwarded-Ssl" HeaderXUrlScheme = "X-Url-Scheme" HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" - HeaderXRealIP = "X-Real-IP" - HeaderXRequestID = "X-Request-ID" + HeaderXRealIP = "X-Real-Ip" + HeaderXRequestID = "X-Request-Id" + HeaderXCorrelationID = "X-Correlation-Id" HeaderXRequestedWith = "X-Requested-With" HeaderServer = "Server" - HeaderOrigin = "Origin" + + // HeaderOrigin request header indicates the origin (scheme, hostname, and port) that caused the request. + // See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin + HeaderOrigin = "Origin" + HeaderCacheControl = "Cache-Control" + HeaderConnection = "Connection" // Access control HeaderAccessControlRequestMethod = "Access-Control-Request-Method" @@ -221,314 +216,425 @@ const ( HeaderXFrameOptions = "X-Frame-Options" HeaderContentSecurityPolicy = "Content-Security-Policy" HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only" - HeaderXCSRFToken = "X-CSRF-Token" + HeaderXCSRFToken = "X-CSRF-Token" // #nosec G101 HeaderReferrerPolicy = "Referrer-Policy" -) -const ( - // Version of Echo - Version = "4.1.13" - website = "https://echo.labstack.com" - // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo - banner = ` - ____ __ - / __/___/ / ___ - / _// __/ _ \/ _ \ -/___/\__/_//_/\___/ %s -High performance, minimalist Go web framework -%s -____________________________________O/_______ - O\ -` + // HeaderSecFetchSite fetch metadata request header indicates the relationship between a request initiator's + // origin and the origin of the requested resource. + // See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Sec-Fetch-Site + HeaderSecFetchSite = "Sec-Fetch-Site" ) -var ( - methods = [...]string{ - http.MethodConnect, - http.MethodDelete, - http.MethodGet, - http.MethodHead, - http.MethodOptions, - http.MethodPatch, - http.MethodPost, - PROPFIND, - http.MethodPut, - http.MethodTrace, - REPORT, - } -) +// Config is configuration for NewWithConfig function +type Config struct { + // Logger is the slog logger instance used for application-wide structured logging. + // If not set, a default TextHandler writing to stdout is created. + Logger *slog.Logger -// Errors -var ( - ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) - ErrNotFound = NewHTTPError(http.StatusNotFound) - ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) - ErrForbidden = NewHTTPError(http.StatusForbidden) - ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) - ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) - ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests) - ErrBadRequest = NewHTTPError(http.StatusBadRequest) - ErrBadGateway = NewHTTPError(http.StatusBadGateway) - ErrInternalServerError = NewHTTPError(http.StatusInternalServerError) - ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout) - ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable) - ErrValidatorNotRegistered = errors.New("validator not registered") - ErrRendererNotRegistered = errors.New("renderer not registered") - ErrInvalidRedirectCode = errors.New("invalid redirect status code") - ErrCookieNotFound = errors.New("cookie not found") - ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") -) + // HTTPErrorHandler is the centralized error handler that processes errors returned + // by handlers and middleware, converting them to appropriate HTTP responses. + // If not set, DefaultHTTPErrorHandler(false) is used. + HTTPErrorHandler HTTPErrorHandler -// Error handlers -var ( - NotFoundHandler = func(c Context) error { - return ErrNotFound - } + // Router is the HTTP request router responsible for matching URLs to handlers + // using a radix tree-based algorithm. + // If not set, NewRouter(RouterConfig{}) is used. + Router Router + + // OnAddRoute is an optional callback hook executed when routes are registered. + // Useful for route validation, logging, or custom route processing. + // If not set, no callback is executed. + OnAddRoute func(route Route) error + + // Filesystem is the fs.FS implementation used for serving static files. + // Supports os.DirFS, embed.FS, and custom implementations. + // If not set, defaults to current working directory. + Filesystem fs.FS + + // Binder handles automatic data binding from HTTP requests to Go structs. + // Supports JSON, XML, form data, query parameters, and path parameters. + // If not set, DefaultBinder is used. + Binder Binder + + // Validator provides optional struct validation after data binding. + // Commonly used with third-party validation libraries. + // If not set, Context.Validate() returns ErrValidatorNotRegistered. + Validator Validator - MethodNotAllowedHandler = func(c Context) error { - return ErrMethodNotAllowed + // Renderer provides template rendering for generating HTML responses. + // Requires integration with a template engine like html/template. + // If not set, Context.Render() returns ErrRendererNotRegistered. + Renderer Renderer + + // JSONSerializer handles JSON encoding and decoding for HTTP requests/responses. + // Can be replaced with faster alternatives like jsoniter or sonic. + // If not set, DefaultJSONSerializer using encoding/json is used. + JSONSerializer JSONSerializer + + // IPExtractor defines the strategy for extracting the real client IP address + // from requests, particularly important when behind proxies or load balancers. + // Used for rate limiting, access control, and logging. + // If not set, falls back to checking X-Forwarded-For and X-Real-IP headers. + IPExtractor IPExtractor + + // FormParseMaxMemory is default value for memory limit that is used + // when parsing multipart forms (See (*http.Request).ParseMultipartForm) + FormParseMaxMemory int64 +} + +// NewWithConfig creates an instance of Echo with given configuration. +func NewWithConfig(config Config) *Echo { + e := New() + if config.Logger != nil { + e.Logger = config.Logger } -) + if config.HTTPErrorHandler != nil { + e.HTTPErrorHandler = config.HTTPErrorHandler + } + if config.Router != nil { + e.router = config.Router + } + if config.OnAddRoute != nil { + e.OnAddRoute = config.OnAddRoute + } + if config.Filesystem != nil { + e.Filesystem = config.Filesystem + } + if config.Binder != nil { + e.Binder = config.Binder + } + if config.Validator != nil { + e.Validator = config.Validator + } + if config.Renderer != nil { + e.Renderer = config.Renderer + } + if config.JSONSerializer != nil { + e.JSONSerializer = config.JSONSerializer + } + if config.IPExtractor != nil { + e.IPExtractor = config.IPExtractor + } + if config.FormParseMaxMemory > 0 { + e.formParseMaxMemory = config.FormParseMaxMemory + } + return e +} // New creates an instance of Echo. -func New() (e *Echo) { - e = &Echo{ - Server: new(http.Server), - TLSServer: new(http.Server), - AutoTLSManager: autocert.Manager{ - Prompt: autocert.AcceptTOS, - }, - Logger: log.New("echo"), - colorer: color.New(), - maxParam: new(int), - } - e.Server.Handler = e - e.TLSServer.Handler = e - e.HTTPErrorHandler = e.DefaultHTTPErrorHandler - e.Binder = &DefaultBinder{} - e.Logger.SetLevel(log.ERROR) - e.StdLogger = stdLog.New(e.Logger.Output(), e.Logger.Prefix()+": ", 0) - e.pool.New = func() interface{} { - return e.NewContext(nil, nil) - } - e.router = NewRouter(e) - e.routers = map[string]*Router{} - return -} +func New() *Echo { + logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) + e := &Echo{ + Logger: logger, + Filesystem: newDefaultFS(), + Binder: &DefaultBinder{}, + JSONSerializer: &DefaultJSONSerializer{}, + formParseMaxMemory: defaultMemory, + } -// NewContext returns a Context instance. -func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context { - return &context{ - request: r, - response: NewResponse(w, e), - store: make(Map), - echo: e, - pvalues: make([]string, *e.maxParam), - handler: NotFoundHandler, + e.serveHTTPFunc = e.serveHTTP + e.router = NewRouter(RouterConfig{}) + e.HTTPErrorHandler = DefaultHTTPErrorHandler(false) + e.contextPool.New = func() any { + return newContext(nil, nil, e) } + return e +} + +// NewContext returns a new Context instance. +// +// Note: both request and response can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway +// these arguments are useful when creating context for tests and cases like that. +func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) *Context { + return newContext(r, w, e) } // Router returns the default router. -func (e *Echo) Router() *Router { +func (e *Echo) Router() Router { return e.router } -// Routers returns the map of host => router. -func (e *Echo) Routers() map[string]*Router { - return e.routers -} +// DefaultHTTPErrorHandler creates new default HTTP error handler implementation. It sends a JSON response +// with status code. `exposeError` parameter decides if returned message will contain also error message or not +// +// Note: DefaultHTTPErrorHandler does not log errors. Use middleware for it if errors need to be logged (separately) +// Note: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). +// When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from +// handler. Then the error that global error handler received will be ignored because we have already "committed" the +// response and status code header has been sent to the client. +func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler { + return func(c *Context, err error) { + if r, _ := UnwrapResponse(c.response); r != nil && r.Committed { + return + } -// DefaultHTTPErrorHandler is the default HTTP error handler. It sends a JSON response -// with status code. -func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { - he, ok := err.(*HTTPError) - if ok { - if he.Internal != nil { - if herr, ok := he.Internal.(*HTTPError); ok { - he = herr + code := http.StatusInternalServerError + var sc HTTPStatusCoder + if errors.As(err, &sc) { + if tmp := sc.StatusCode(); tmp != 0 { + code = tmp } } - } else { - he = &HTTPError{ - Code: http.StatusInternalServerError, - Message: http.StatusText(http.StatusInternalServerError), - } - } - // Issue #1426 - code := he.Code - message := he.Message - if e.Debug { - message = err.Error() - } else if m, ok := message.(string); ok { - message = Map{"message": m} - } + var result any + switch m := sc.(type) { + case json.Marshaler: // this type knows how to format itself to JSON + result = m + case *HTTPError: + sText := m.Message + if sText == "" { + sText = http.StatusText(code) + } + msg := map[string]any{"message": sText} + if exposeError { + if wrappedErr := m.Unwrap(); wrappedErr != nil { + msg["error"] = wrappedErr.Error() + } + } + result = msg + default: + msg := map[string]any{"message": http.StatusText(code)} + if exposeError { + msg["error"] = err.Error() + } + result = msg + } - // Send response - if !c.Response().Committed { + var cErr error if c.Request().Method == http.MethodHead { // Issue #608 - err = c.NoContent(he.Code) + cErr = c.NoContent(code) } else { - err = c.JSON(code, message) + cErr = c.JSON(code, result) } - if err != nil { - e.Logger.Error(err) + if cErr != nil { + c.Logger().Error("echo default error handler failed to send error to client", "error", cErr) // truly rare case. ala client already disconnected } } } -// Pre adds middleware to the chain which is run before router. +// Pre adds middleware to the chain which is run before router tries to find matching route. +// Meaning middleware is executed even for 404 (not found) cases. func (e *Echo) Pre(middleware ...MiddlewareFunc) { e.premiddleware = append(e.premiddleware, middleware...) } -// Use adds middleware to the chain which is run after router. +// Use adds middleware to the chain which is run after router has found matching route and before route/request handler method is executed. func (e *Echo) Use(middleware ...MiddlewareFunc) { e.middleware = append(e.middleware, middleware...) } // CONNECT registers a new CONNECT route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodConnect, path, h, m...) } // DELETE registers a new DELETE route for a path with matching handler in the router -// with optional route-level middleware. -func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// with optional route-level middleware. Panics on error. +func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodDelete, path, h, m...) } // GET registers a new GET route for a path with matching handler in the router -// with optional route-level middleware. -func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// with optional route-level middleware. Panics on error. +func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodGet, path, h, m...) } // HEAD registers a new HEAD route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodHead, path, h, m...) } // OPTIONS registers a new OPTIONS route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodOptions, path, h, m...) } // PATCH registers a new PATCH route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPatch, path, h, m...) } // POST registers a new POST route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPost, path, h, m...) } // PUT registers a new PUT route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPut, path, h, m...) } // TRACE registers a new TRACE route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodTrace, path, h, m...) } -// Any registers a new route for all HTTP methods and path with matching handler +// RouteNotFound registers a special-case route which is executed when no other route is found (i.e. HTTP 404 cases) +// for current request URL. +// Path supports static and named/any parameters just like other http method is defined. Generally path is ended with +// wildcard/match-any character (`/*`, `/download/*` etc). +// +// Example: `e.RouteNotFound("/*", func(c *echo.Context) error { return c.NoContent(http.StatusNotFound) })` +func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { + return e.Add(RouteNotFound, path, h, m...) +} + +// Any registers a new route for all HTTP methods (supported by Echo) and path with matching handler // in the router with optional route-level middleware. -func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = e.Add(m, path, handler, middleware...) - } - return routes +// +// Note: this method only adds specific set of supported HTTP methods as handler and is not true +// "catch-any-arbitrary-method" way of matching requests. +func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { + return e.Add(RouteAny, path, handler, middleware...) } // Match registers a new route for multiple HTTP methods and path with matching -// handler in the router with optional route-level middleware. -func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = e.Add(m, path, handler, middleware...) +// handler in the router with optional route-level middleware. Panics on error. +func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := e.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) } - return routes -} - -// Static registers a new route with path prefix to serve static files from the -// provided root directory. -func (e *Echo) Static(prefix, root string) *Route { - if root == "" { - root = "." // For security we want to restrict to CWD. + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage } - return e.static(prefix, root, e.GET) + return ris } -func (common) static(prefix, root string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route) *Route { - h := func(c Context) error { - p, err := url.PathUnescape(c.Param("*")) +// Static registers a new route with path prefix to serve static files from the provided root directory. +func (e *Echo) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo { + subFs := MustSubFS(e.Filesystem, fsRoot) + return e.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(subFs, false), + middleware..., + ) +} + +// StaticFS registers a new route with path prefix to serve static files from the provided file system. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo { + return e.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(filesystem, false), + middleware..., + ) +} + +// StaticDirectoryHandler creates handler function to serve files from provided file system +// When disablePathUnescaping is set then file name from path is not unescaped and is served as is. +func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc { + return func(c *Context) error { + p := c.Param("*") + if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice + tmpPath, err := url.PathUnescape(p) + if err != nil { + return fmt.Errorf("failed to unescape path variable: %w", err) + } + p = tmpPath + } + + // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid + name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/"))) + fi, err := fs.Stat(fileSystem, name) if err != nil { - return err + return ErrNotFound + } + + // If the request is for a directory and does not end with "/" + p = c.Request().URL.Path // path must not be empty. + if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { + // Redirect to ends with "/" + return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/")) } - name := filepath.Join(root, path.Clean("/"+p)) // "/"+ for security - return c.File(name) + return fsFile(c, name, fileSystem) } - if prefix == "/" { - return get(prefix+"*", h) +} + +// FileFS registers a new route with path to serve file from the provided file system. +func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo { + return e.GET(path, StaticFileHandler(file, filesystem), m...) +} + +// StaticFileHandler creates handler function to serve file from provided file system +func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { + return func(c *Context) error { + return fsFile(c, file, filesystem) } - return get(prefix+"/*", h) } -func (common) file(path, file string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route, - m ...MiddlewareFunc) *Route { - return get(path, func(c Context) error { +// File registers a new route with path to serve a static file with optional route-level middleware. Panics on error. +func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo { + handler := func(c *Context) error { return c.File(file) - }, m...) + } + return e.Add(http.MethodGet, path, handler, middleware...) } -// File registers a new route with path to serve a static file with optional route-level middleware. -func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route { - return e.file(path, file, e.GET, m...) +// AddRoute registers a new Route with default host Router +func (e *Echo) AddRoute(route Route) (RouteInfo, error) { + return e.add(route) } -func (e *Echo) add(host, method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - name := handlerName(handler) - router := e.findRouter(host) - router.Add(method, path, func(c Context) error { - h := handler - // Chain middleware - for i := len(middleware) - 1; i >= 0; i-- { - h = middleware[i](h) +func (e *Echo) add(route Route) (RouteInfo, error) { + if e.OnAddRoute != nil { + if err := e.OnAddRoute(route); err != nil { + return RouteInfo{}, err } - return h(c) - }) - r := &Route{ - Method: method, - Path: path, - Name: name, } - e.router.routes[method+path] = r - return r + + ri, err := e.router.Add(route) + if err != nil { + return RouteInfo{}, err + } + + paramsCount := int32(len(ri.Parameters)) // #nosec G115 + if paramsCount > e.contextPathParamAllocSize.Load() { + e.contextPathParamAllocSize.Store(paramsCount) + } + return ri, nil } // Add registers a new route for an HTTP method and path with matching handler // in the router with optional route-level middleware. -func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - return e.add("", method, path, handler, middleware...) -} - -// Host creates a new router group for the provided host and optional host-level middleware. -func (e *Echo) Host(name string, m ...MiddlewareFunc) (g *Group) { - e.routers[name] = NewRouter(e) - g = &Group{host: name, echo: e} - g.Use(m...) - return +func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { + ri, err := e.add( + Route{ + Method: method, + Path: path, + Handler: handler, + Middlewares: middleware, + Name: "", + }, + ) + if err != nil { + panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ri } // Group creates a new router group with prefix and optional group-level middleware. @@ -538,233 +644,102 @@ func (e *Echo) Group(prefix string, m ...MiddlewareFunc) (g *Group) { return } -// URI generates a URI from handler. -func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string { - name := handlerName(handler) - return e.Reverse(name, params...) -} - -// URL is an alias for `URI` function. -func (e *Echo) URL(h HandlerFunc, params ...interface{}) string { - return e.URI(h, params...) +// PreMiddlewares returns registered pre middlewares. These are middleware to the chain +// which are run before router tries to find matching route. +// Use this method to build your own ServeHTTP method. +// +// NOTE: returned slice is not a copy. Do not mutate. +func (e *Echo) PreMiddlewares() []MiddlewareFunc { + return e.premiddleware } -// Reverse generates an URL from route name and provided parameters. -func (e *Echo) Reverse(name string, params ...interface{}) string { - uri := new(bytes.Buffer) - ln := len(params) - n := 0 - for _, r := range e.router.routes { - if r.Name == name { - for i, l := 0, len(r.Path); i < l; i++ { - if r.Path[i] == ':' && n < ln { - for ; i < l && r.Path[i] != '/'; i++ { - } - uri.WriteString(fmt.Sprintf("%v", params[n])) - n++ - } - if i < l { - uri.WriteByte(r.Path[i]) - } - } - break - } - } - return uri.String() -} - -// Routes returns the registered routes. -func (e *Echo) Routes() []*Route { - routes := make([]*Route, 0, len(e.router.routes)) - for _, v := range e.router.routes { - routes = append(routes, v) - } - return routes +// Middlewares returns registered route level middlewares. Does not contain any group level +// middlewares. Use this method to build your own ServeHTTP method. +// +// NOTE: returned slice is not a copy. Do not mutate. +func (e *Echo) Middlewares() []MiddlewareFunc { + return e.middleware } // AcquireContext returns an empty `Context` instance from the pool. // You must return the context by calling `ReleaseContext()`. -func (e *Echo) AcquireContext() Context { - return e.pool.Get().(Context) +func (e *Echo) AcquireContext() *Context { + return e.contextPool.Get().(*Context) } // ReleaseContext returns the `Context` instance back to the pool. // You must call it after `AcquireContext()`. -func (e *Echo) ReleaseContext(c Context) { - e.pool.Put(c) +func (e *Echo) ReleaseContext(c *Context) { + e.contextPool.Put(c) } // ServeHTTP implements `http.Handler` interface, which serves HTTP requests. func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Acquire context - c := e.pool.Get().(*context) - c.Reset(r, w) + e.serveHTTPFunc(w, r) +} + +// serveHTTP implements `http.Handler` interface, which serves HTTP requests. +func (e *Echo) serveHTTP(w http.ResponseWriter, r *http.Request) { + c := e.contextPool.Get().(*Context) + defer e.contextPool.Put(c) - h := NotFoundHandler + c.Reset(r, w) + var h HandlerFunc if e.premiddleware == nil { - e.findRouter(r.Host).Find(r.Method, getPath(r), c) - h = c.Handler() - h = applyMiddleware(h, e.middleware...) + h = applyMiddleware(e.router.Route(c), e.middleware...) } else { - h = func(c Context) error { - e.findRouter(r.Host).Find(r.Method, getPath(r), c) - h := c.Handler() - h = applyMiddleware(h, e.middleware...) - return h(c) + h = func(cc *Context) error { + h1 := applyMiddleware(e.router.Route(cc), e.middleware...) + return h1(cc) } h = applyMiddleware(h, e.premiddleware...) } // Execute chain if err := h(c); err != nil { - e.HTTPErrorHandler(err, c) - } - - // Release context - e.pool.Put(c) -} - -// Start starts an HTTP server. + e.HTTPErrorHandler(c, err) + } +} + +// Start stars HTTP server on given address with Echo as a handler serving requests. The server can be shutdown by +// sending os.Interrupt signal with `ctrl+c`. Method returns only errors that are not http.ErrServerClosed. +// +// Note: this method is created for use in examples/demos and is deliberately simple without providing configuration +// options. +// +// In need of customization use: +// +// ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) +// defer cancel() +// sc := echo.StartConfig{Address: ":8080"} +// if err := sc.Start(ctx, e); err != nil && !errors.Is(err, http.ErrServerClosed) { +// slog.Error(err.Error()) +// } +// +// // or standard library `http.Server` +// +// s := http.Server{Addr: ":8080", Handler: e} +// if err := s.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { +// slog.Error(err.Error()) +// } func (e *Echo) Start(address string) error { - e.Server.Addr = address - return e.StartServer(e.Server) -} - -// StartTLS starts an HTTPS server. -// If `certFile` or `keyFile` is `string` the values are treated as file paths. -// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is. -func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err error) { - var cert []byte - if cert, err = filepathOrContent(certFile); err != nil { - return - } - - var key []byte - if key, err = filepathOrContent(keyFile); err != nil { - return - } - - s := e.TLSServer - s.TLSConfig = new(tls.Config) - s.TLSConfig.Certificates = make([]tls.Certificate, 1) - if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil { - return - } - - return e.startTLS(address) -} - -func filepathOrContent(fileOrContent interface{}) (content []byte, err error) { - switch v := fileOrContent.(type) { - case string: - return ioutil.ReadFile(v) - case []byte: - return v, nil - default: - return nil, ErrInvalidCertOrKeyType - } -} - -// StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org. -func (e *Echo) StartAutoTLS(address string) error { - s := e.TLSServer - s.TLSConfig = new(tls.Config) - s.TLSConfig.GetCertificate = e.AutoTLSManager.GetCertificate - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, acme.ALPNProto) - return e.startTLS(address) -} - -func (e *Echo) startTLS(address string) error { - s := e.TLSServer - s.Addr = address - if !e.DisableHTTP2 { - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2") - } - return e.StartServer(e.TLSServer) -} - -// StartServer starts a custom http server. -func (e *Echo) StartServer(s *http.Server) (err error) { - // Setup - e.colorer.SetOutput(e.Logger.Output()) - s.ErrorLog = e.StdLogger - s.Handler = e - if e.Debug { - e.Logger.SetLevel(log.DEBUG) - } - - if !e.HideBanner { - e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website)) - } - - if s.TLSConfig == nil { - if e.Listener == nil { - e.Listener, err = newListener(s.Addr) - if err != nil { - return err - } - } - if !e.HidePort { - e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) - } - return s.Serve(e.Listener) - } - if e.TLSListener == nil { - l, err := newListener(s.Addr) - if err != nil { - return err - } - e.TLSListener = tls.NewListener(l, s.TLSConfig) - } - if !e.HidePort { - e.colorer.Printf("⇨ https server started on %s\n", e.colorer.Green(e.TLSListener.Addr())) - } - return s.Serve(e.TLSListener) -} - -// Close immediately stops the server. -// It internally calls `http.Server#Close()`. -func (e *Echo) Close() error { - if err := e.TLSServer.Close(); err != nil { - return err - } - return e.Server.Close() -} - -// Shutdown stops the server gracefully. -// It internally calls `http.Server#Shutdown()`. -func (e *Echo) Shutdown(ctx stdContext.Context) error { - if err := e.TLSServer.Shutdown(ctx); err != nil { - return err - } - return e.Server.Shutdown(ctx) -} - -// NewHTTPError creates a new HTTPError instance. -func NewHTTPError(code int, message ...interface{}) *HTTPError { - he := &HTTPError{Code: code, Message: http.StatusText(code)} - if len(message) > 0 { - he.Message = message[0] - } - return he -} - -// Error makes it compatible with `error` interface. -func (he *HTTPError) Error() string { - return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal) -} - -// SetInternal sets error to HTTPError.Internal -func (he *HTTPError) SetInternal(err error) *HTTPError { - he.Internal = err - return he + sc := StartConfig{Address: address} + ctx, cancel := signal.NotifyContext(stdContext.Background(), os.Interrupt, syscall.SIGTERM) // start shutdown process on ctrl+c + defer cancel() + return sc.Start(ctx, e) } // WrapHandler wraps `http.Handler` into `echo.HandlerFunc`. func WrapHandler(h http.Handler) HandlerFunc { - return func(c Context) error { - h.ServeHTTP(c.Response(), c.Request()) + return func(c *Context) error { + req := c.Request() + req.Pattern = c.Path() + for _, p := range c.PathValues() { + req.SetPathValue(p.Name, p.Value) + } + + h.ServeHTTP(c.Response(), req) return nil } } @@ -772,77 +747,88 @@ func WrapHandler(h http.Handler) HandlerFunc { // WrapMiddleware wraps `func(http.Handler) http.Handler` into `echo.MiddlewareFunc` func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { return func(next HandlerFunc) HandlerFunc { - return func(c Context) (err error) { + return func(c *Context) (err error) { + req := c.Request() + req.Pattern = c.Path() + for _, p := range c.PathValues() { + req.SetPathValue(p.Name, p.Value) + } + m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c.SetRequest(r) - c.SetResponse(NewResponse(w, c.Echo())) + c.SetResponse(NewResponse(w, c.echo.Logger)) err = next(c) - })).ServeHTTP(c.Response(), c.Request()) + })).ServeHTTP(c.Response(), req) return } } } -func getPath(r *http.Request) string { - path := r.URL.RawPath - if path == "" { - path = r.URL.Path +func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc { + for i := len(middleware) - 1; i >= 0; i-- { + h = middleware[i](h) } - return path + return h } -func (e *Echo) findRouter(host string) *Router { - if len(e.routers) > 0 { - if r, ok := e.routers[host]; ok { - return r - } - } - return e.router +// defaultFS emulates os.Open behaviour with filesystem opened by `os.DirFs`. Difference between `os.Open` and `fs.Open` +// is that FS does not allow to open path that start with `..` or `/` etc. For example previously you could have `../images` +// in your application but `fs := os.DirFS("./")` would not allow you to use `fs.Open("../images")` and this would break +// all old applications that rely on being able to traverse up from current executable run path. +// NB: private because you really should use fs.FS implementation instances +type defaultFS struct { + fs fs.FS + prefix string } -func handlerName(h HandlerFunc) string { - t := reflect.ValueOf(h).Type() - if t.Kind() == reflect.Func { - return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name() +func newDefaultFS() *defaultFS { + dir, _ := os.Getwd() + return &defaultFS{ + prefix: dir, + fs: os.DirFS(dir), } - return t.String() -} - -// // PathUnescape is wraps `url.PathUnescape` -// func PathUnescape(s string) (string, error) { -// return url.PathUnescape(s) -// } - -// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted -// connections. It's used by ListenAndServe and ListenAndServeTLS so -// dead TCP connections (e.g. closing laptop mid-download) eventually -// go away. -type tcpKeepAliveListener struct { - *net.TCPListener } -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { - if c, err = ln.AcceptTCP(); err != nil { - return - } else if err = c.(*net.TCPConn).SetKeepAlive(true); err != nil { - return - } else if err = c.(*net.TCPConn).SetKeepAlivePeriod(3 * time.Minute); err != nil { - return - } - return +func (fs defaultFS) Open(name string) (fs.File, error) { + return fs.fs.Open(name) } -func newListener(address string) (*tcpKeepAliveListener, error) { - l, err := net.Listen("tcp", address) +func subFS(currentFs fs.FS, root string) (fs.FS, error) { + root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows + if dFS, ok := currentFs.(*defaultFS); ok { + // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. + // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we + // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs + if !filepath.IsAbs(root) { + root = filepath.Join(dFS.prefix, root) + } + return &defaultFS{ + prefix: root, + fs: os.DirFS(root), + }, nil + } + return fs.Sub(currentFs, root) +} + +// MustSubFS creates sub FS from current filesystem or panic on failure. +// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. +// +// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with +// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to +// create sub fs which uses necessary prefix for directory path. +func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { + subFs, err := subFS(currentFs, fsRoot) if err != nil { - return nil, err + panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err)) } - return &tcpKeepAliveListener{l.(*net.TCPListener)}, nil + return subFs } -func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc { - for i := len(middleware) - 1; i >= 0; i-- { - h = middleware[i](h) +func sanitizeURI(uri string) string { + // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri + // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash + if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') { + uri = "/" + strings.TrimLeft(uri, `/\`) } - return h + return uri } diff --git a/echo_test.go b/echo_test.go index 3f2e48e51..f26eed8e2 100644 --- a/echo_test.go +++ b/echo_test.go @@ -1,30 +1,36 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( "bytes" stdContext "context" "errors" - "io/ioutil" + "fmt" + "io/fs" + "log/slog" + "net" "net/http" "net/http/httptest" - "reflect" + "net/url" + "os" + "runtime" "strings" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -type ( - user struct { - ID int `json:"id" xml:"id" form:"id" query:"id" param:"id"` - Name string `json:"name" xml:"name" form:"name" query:"name" param:"name"` - } -) +type user struct { + ID int `json:"id" xml:"id" form:"id" query:"id" param:"id" header:"id"` + Name string `json:"name" xml:"name" form:"name" query:"name" param:"name" header:"name"` +} const ( userJSON = `{"id":1,"name":"Jon Snow"}` + usersJSON = `[{"id":1,"name":"Jon Snow"}]` userXML = `1Jon Snow` userForm = `id=1&name=Jon Snow` invalidContent = "invalid content" @@ -43,6 +49,8 @@ const userXMLPretty = ` Jon Snow ` +var dummyQuery = url.Values{"dummy": []string{"useless"}} + func TestEcho(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -52,50 +60,354 @@ func TestEcho(t *testing.T) { // Router assert.NotNil(t, e.Router()) - // DefaultHTTPErrorHandler - e.DefaultHTTPErrorHandler(errors.New("error"), c) + e.HTTPErrorHandler(c, errors.New("error")) + assert.Equal(t, http.StatusInternalServerError, rec.Code) } -func TestEchoStatic(t *testing.T) { - e := New() +func TestNewWithConfig(t *testing.T) { + e := NewWithConfig(Config{}) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e.GET("/", func(c *Context) error { + return c.String(http.StatusTeapot, "Hello, World!") + }) + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.Equal(t, `Hello, World!`, rec.Body.String()) +} + +func TestEcho_StaticFS(t *testing.T) { + var testCases = []struct { + givenFs fs.FS + name string + givenPrefix string + givenFsRoot string + whenURL string + expectHeaderLocation string + expectBodyStartsWith string + expectStatus int + }{ + { + name: "ok", + givenPrefix: "/images", + givenFs: os.DirFS("./_fixture/images"), + whenURL: "/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "ok, from sub fs", + givenPrefix: "/images", + givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"), + whenURL: "/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "No file", + givenPrefix: "/images", + givenFs: os.DirFS("_fixture/scripts"), + whenURL: "/images/bolt.png", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory", + givenPrefix: "/images", + givenFs: os.DirFS("_fixture/images"), + whenURL: "/images/", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory Redirect", + givenPrefix: "/", + givenFs: os.DirFS("_fixture/"), + whenURL: "/folder", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory Redirect with non-root path", + givenPrefix: "/static", + givenFs: os.DirFS("_fixture"), + whenURL: "/static", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/static/", + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory 404 (request URL without slash)", + givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" + givenFs: os.DirFS("_fixture"), + whenURL: "/folder", // no trailing slash + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Prefixed directory redirect (without slash redirect to slash)", + givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* + givenFs: os.DirFS("_fixture"), + whenURL: "/folder", // no trailing slash + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory with index.html", + givenPrefix: "/", + givenFs: os.DirFS("_fixture"), + whenURL: "/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending with slash)", + givenPrefix: "/assets/", + givenFs: os.DirFS("_fixture"), + whenURL: "/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending without slash)", + givenPrefix: "/assets", + givenFs: os.DirFS("_fixture"), + whenURL: "/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Sub-directory with index.html", + givenPrefix: "/", + givenFs: os.DirFS("_fixture"), + whenURL: "/folder/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "do not allow directory traversal (backslash - windows separator)", + givenPrefix: "/", + givenFs: os.DirFS("_fixture/"), + whenURL: `/..\\middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "do not allow directory traversal (slash - unix separator)", + givenPrefix: "/", + givenFs: os.DirFS("_fixture/"), + whenURL: `/../middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "open redirect vulnerability", + givenPrefix: "/", + givenFs: os.DirFS("_fixture/"), + whenURL: "/open.redirect.hackercom%2f..", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/open.redirect.hackercom/../", // location starting with `//open` would be very bad + expectBodyStartsWith: "", + }, + } - assert := assert.New(t) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() - // OK - e.Static("/images", "_fixture/images") - c, b := request(http.MethodGet, "/images/walle.png", e) - assert.Equal(http.StatusOK, c) - assert.NotEmpty(b) + tmpFs := tc.givenFs + if tc.givenFsRoot != "" { + tmpFs = MustSubFS(tmpFs, tc.givenFsRoot) + } + e.StaticFS(tc.givenPrefix, tmpFs) - // No file - e.Static("/images", "_fixture/scripts") - c, _ = request(http.MethodGet, "/images/bolt.png", e) - assert.Equal(http.StatusNotFound, c) + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() - // Directory - e.Static("/images", "_fixture/images") - c, _ = request(http.MethodGet, "/images", e) - assert.Equal(http.StatusNotFound, c) + e.ServeHTTP(rec, req) - // Directory with index.html - e.Static("/", "_fixture") - c, r := request(http.MethodGet, "/", e) - assert.Equal(http.StatusOK, c) - assert.Equal(true, strings.HasPrefix(r, "")) + assert.Equal(t, tc.expectStatus, rec.Code) + body := rec.Body.String() + if tc.expectBodyStartsWith != "" { + assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) + } else { + assert.Equal(t, "", body) + } - // Sub-directory with index.html - c, r = request(http.MethodGet, "/folder", e) - assert.Equal(http.StatusOK, c) - assert.Equal(true, strings.HasPrefix(r, "")) + if tc.expectHeaderLocation != "" { + assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) + } else { + _, ok := rec.Result().Header["Location"] + assert.False(t, ok) + } + }) + } } -func TestEchoFile(t *testing.T) { +func TestEcho_FileFS(t *testing.T) { + var testCases = []struct { + whenFS fs.FS + name string + whenPath string + whenFile string + givenURL string + expectStartsWith []byte + expectCode int + }{ + { + name: "ok", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/walle", + expectCode: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, requesting invalid path", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/walle.png", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + { + name: "nok, serving not existent file from filesystem", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "not-existent.png", + givenURL: "/walle", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) + + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestEcho_StaticPanic(t *testing.T) { + var testCases = []struct { + name string + givenRoot string + }{ + { + name: "panics for ../", + givenRoot: "../assets", + }, + { + name: "panics for /", + givenRoot: "/assets", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Filesystem = os.DirFS("./") + + assert.Panics(t, func() { + e.Static("../assets", tc.givenRoot) + }) + }) + } +} + +func TestEchoStaticRedirectIndex(t *testing.T) { e := New() - e.File("/walle", "_fixture/images/walle.png") - c, b := request(http.MethodGet, "/walle", e) - assert.Equal(t, http.StatusOK, c) - assert.NotEmpty(t, b) + + // HandlerFunc + ri := e.Static("/static", "_fixture") + assert.Equal(t, http.MethodGet, ri.Method) + assert.Equal(t, "/static*", ri.Path) + assert.Equal(t, "GET:/static*", ri.Name) + assert.Equal(t, []string{"*"}, ri.Parameters) + + ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) + defer cancel() + addr, err := startOnRandomPort(ctx, e) + if err != nil { + assert.Fail(t, err.Error()) + } + + code, body, err := doGet(fmt.Sprintf("http://%v/static", addr)) + assert.NoError(t, err) + assert.True(t, strings.HasPrefix(body, "")) + assert.Equal(t, http.StatusOK, code) +} + +func TestEchoFile(t *testing.T) { + var testCases = []struct { + name string + givenPath string + givenFile string + whenPath string + expectStartsWith string + expectCode int + }{ + { + name: "ok", + givenPath: "/walle", + givenFile: "_fixture/images/walle.png", + whenPath: "/walle", + expectCode: http.StatusOK, + expectStartsWith: string([]byte{0x89, 0x50, 0x4e}), + }, + { + name: "ok with relative path", + givenPath: "/", + givenFile: "./go.mod", + whenPath: "/", + expectCode: http.StatusOK, + expectStartsWith: "module github.com/labstack/echo/v", + }, + { + name: "nok file does not exist", + givenPath: "/", + givenFile: "./this-file-does-not-exist", + whenPath: "/", + expectCode: http.StatusNotFound, + expectStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() // we are using echo.defaultFS instance + e.File(tc.givenPath, tc.givenFile) + + c, b := request(http.MethodGet, tc.whenPath, e) + assert.Equal(t, tc.expectCode, c) + + if len(b) > len(tc.expectStartsWith) { + b = b[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, b) + }) + } } func TestEchoMiddleware(t *testing.T) { @@ -103,36 +415,37 @@ func TestEchoMiddleware(t *testing.T) { buf := new(bytes.Buffer) e.Pre(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - assert.Empty(t, c.Path()) + return func(c *Context) error { + // before route match is found RouteInfo does not exist + assert.Equal(t, RouteInfo{}, c.RouteInfo()) buf.WriteString("-1") return next(c) } }) e.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { buf.WriteString("1") return next(c) } }) e.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { buf.WriteString("2") return next(c) } }) e.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { buf.WriteString("3") return next(c) } }) // Route - e.GET("/", func(c Context) error { + e.GET("/", func(c *Context) error { return c.String(http.StatusOK, "OK") }) @@ -145,11 +458,11 @@ func TestEchoMiddleware(t *testing.T) { func TestEchoMiddlewareError(t *testing.T) { e := New() e.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { return errors.New("error") } }) - e.GET("/", NotFoundHandler) + e.GET("/", notFoundHandler) c, _ := request(http.MethodGet, "/", e) assert.Equal(t, http.StatusInternalServerError, c) } @@ -158,7 +471,7 @@ func TestEchoHandler(t *testing.T) { e := New() // HandlerFunc - e.GET("/ok", func(c Context) error { + e.GET("/ok", func(c *Context) error { return c.String(http.StatusOK, "OK") }) @@ -169,171 +482,302 @@ func TestEchoHandler(t *testing.T) { func TestEchoWrapHandler(t *testing.T) { e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := WrapHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + var actualID string + var actualPattern string + e.GET("/:id", WrapHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("test")) - })) - if assert.NoError(t, h(c)) { - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "test", rec.Body.String()) - } + actualID = r.PathValue("id") + actualPattern = r.Pattern + }))) + + req := httptest.NewRequest(http.MethodGet, "/123", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "test", rec.Body.String()) + assert.Equal(t, "123", actualID) + assert.Equal(t, "/:id", actualPattern) } func TestEchoWrapMiddleware(t *testing.T) { e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - buf := new(bytes.Buffer) - mw := WrapMiddleware(func(h http.Handler) http.Handler { + + var actualID string + var actualPattern string + e.Use(WrapMiddleware(func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - buf.Write([]byte("mw")) + actualID = r.PathValue("id") + actualPattern = r.Pattern h.ServeHTTP(w, r) }) + })) + + e.GET("/:id", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") }) - h := mw(func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - if assert.NoError(t, h(c)) { - assert.Equal(t, "mw", buf.String()) - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "OK", rec.Body.String()) - } + + req := httptest.NewRequest(http.MethodGet, "/123", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.Equal(t, "OK", rec.Body.String()) + assert.Equal(t, "123", actualID) + assert.Equal(t, "/:id", actualPattern) } func TestEchoConnect(t *testing.T) { e := New() - testMethod(t, http.MethodConnect, "/", e) + + ri := e.CONNECT("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodConnect, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodConnect+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodConnect, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoDelete(t *testing.T) { e := New() - testMethod(t, http.MethodDelete, "/", e) + + ri := e.DELETE("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodDelete, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodDelete+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodDelete, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoGet(t *testing.T) { e := New() - testMethod(t, http.MethodGet, "/", e) + + ri := e.GET("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodGet, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodGet+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodGet, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoHead(t *testing.T) { e := New() - testMethod(t, http.MethodHead, "/", e) + + ri := e.HEAD("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodHead, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodHead+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodHead, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoOptions(t *testing.T) { e := New() - testMethod(t, http.MethodOptions, "/", e) + + ri := e.OPTIONS("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodOptions, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodOptions+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodOptions, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPatch(t *testing.T) { e := New() - testMethod(t, http.MethodPatch, "/", e) + + ri := e.PATCH("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPatch, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodPatch+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodPatch, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPost(t *testing.T) { e := New() - testMethod(t, http.MethodPost, "/", e) + + ri := e.POST("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPost, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodPost+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodPost, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPut(t *testing.T) { e := New() - testMethod(t, http.MethodPut, "/", e) + + ri := e.PUT("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPut, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodPut+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodPut, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoTrace(t *testing.T) { e := New() - testMethod(t, http.MethodTrace, "/", e) -} -func TestEchoAny(t *testing.T) { // JFC - e := New() - e.Any("/", func(c Context) error { - return c.String(http.StatusOK, "Any") + ri := e.TRACE("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") }) + + assert.Equal(t, http.MethodTrace, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodTrace+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodTrace, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } -func TestEchoMatch(t *testing.T) { // JFC +func TestEcho_Any(t *testing.T) { e := New() - e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error { - return c.String(http.StatusOK, "Match") + + ri := e.Any("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK from ANY") }) + + assert.Equal(t, RouteAny, ri.Method) + assert.Equal(t, "/activate", ri.Path) + assert.Equal(t, RouteAny+":/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodTrace, "/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK from ANY`, body) } -func TestEchoURL(t *testing.T) { +func TestEcho_Any_hasLowerPriority(t *testing.T) { e := New() - static := func(Context) error { return nil } - getUser := func(Context) error { return nil } - getFile := func(Context) error { return nil } - e.GET("/static/file", static) - e.GET("/users/:id", getUser) - g := e.Group("/group") - g.GET("/users/:uid/files/:fid", getFile) + e.Any("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "ANY") + }) + e.GET("/activate", func(c *Context) error { + return c.String(http.StatusLocked, "GET") + }) - assert := assert.New(t) + status, body := request(http.MethodTrace, "/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `ANY`, body) - assert.Equal("/static/file", e.URL(static)) - assert.Equal("/users/:id", e.URL(getUser)) - assert.Equal("/users/1", e.URL(getUser, "1")) - assert.Equal("/group/users/1/files/:fid", e.URL(getFile, "1")) - assert.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1")) + status, body = request(http.MethodGet, "/activate", e) + assert.Equal(t, http.StatusLocked, status) + assert.Equal(t, `GET`, body) } -func TestEchoRoutes(t *testing.T) { +func TestEchoMatch(t *testing.T) { // JFC e := New() - routes := []*Route{ - {http.MethodGet, "/users/:user/events", ""}, - {http.MethodGet, "/users/:user/events/public", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, - } - for _, r := range routes { - e.Add(r.Method, r.Path, func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - } - - if assert.Equal(t, len(routes), len(e.Routes())) { - for _, r := range e.Routes() { - found := false - for _, rr := range routes { - if r.Method == rr.Method && r.Path == rr.Path { - found = true - break - } - } - if !found { - t.Errorf("Route %s %s not found", r.Method, r.Path) - } - } - } + ris := e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c *Context) error { + return c.String(http.StatusOK, "Match") + }) + assert.Len(t, ris, 2) } -func TestEchoEncodedPath(t *testing.T) { +func TestEchoServeHTTPPathEncoding(t *testing.T) { e := New() - e.GET("/:id", func(c Context) error { - return c.NoContent(http.StatusOK) + e.GET("/with/slash", func(c *Context) error { + return c.String(http.StatusOK, "/with/slash") }) - req := httptest.NewRequest(http.MethodGet, "/with%2Fslash", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) + e.GET("/:id", func(c *Context) error { + return c.String(http.StatusOK, c.Param("id")) + }) + + var testCases = []struct { + name string + whenURL string + expectURL string + expectStatus int + }{ + { + name: "url with encoding is not decoded for routing", + whenURL: "/with%2Fslash", + expectURL: "with%2Fslash", // `%2F` is not decoded to `/` for routing + expectStatus: http.StatusOK, + }, + { + name: "url without encoding is used as is", + whenURL: "/with/slash", + expectURL: "/with/slash", + expectStatus: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectStatus, rec.Code) + assert.Equal(t, tc.expectURL, rec.Body.String()) + }) + } } func TestEchoGroup(t *testing.T) { e := New() buf := new(bytes.Buffer) e.Use(MiddlewareFunc(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { buf.WriteString("0") return next(c) } })) - h := func(c Context) error { + h := func(c *Context) error { return c.NoContent(http.StatusOK) } @@ -346,7 +790,7 @@ func TestEchoGroup(t *testing.T) { // Group g1 := e.Group("/group1") g1.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { buf.WriteString("1") return next(c) } @@ -356,14 +800,14 @@ func TestEchoGroup(t *testing.T) { // Nested groups with middleware g2 := e.Group("/group2") g2.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { buf.WriteString("2") return next(c) } }) g3 := g2.Group("/group3") g3.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { buf.WriteString("3") return next(c) } @@ -382,6 +826,70 @@ func TestEchoGroup(t *testing.T) { assert.Equal(t, "023", buf.String()) } +func TestEcho_RouteNotFound(t *testing.T) { + var testCases = []struct { + expectRoute any + name string + whenURL string + expectCode int + }{ + { + name: "404, route to static not found handler /a/c/xx", + whenURL: "/a/c/xx", + expectRoute: "GET /a/c/xx", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to path param not found handler /a/:file", + whenURL: "/a/echo.exe", + expectRoute: "GET /a/:file", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to any not found handler /*", + whenURL: "/b/echo.exe", + expectRoute: "GET /*", + expectCode: http.StatusNotFound, + }, + { + name: "200, route /a/c/df to /a/c/df", + whenURL: "/a/c/df", + expectRoute: "GET /a/c/df", + expectCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + okHandler := func(c *Context) error { + return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) + } + notFoundHandler := func(c *Context) error { + return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) + } + + e.GET("/", okHandler) + e.GET("/a/c/df", okHandler) + e.GET("/a/b*", okHandler) + e.PUT("/*", okHandler) + + e.RouteNotFound("/a/c/xx", notFoundHandler) // static + e.RouteNotFound("/a/:file", notFoundHandler) // param + e.RouteNotFound("/*", notFoundHandler) // any + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectRoute, rec.Body.String()) + }) + } +} + func TestEchoNotFound(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/files", nil) @@ -392,189 +900,328 @@ func TestEchoNotFound(t *testing.T) { func TestEchoMethodNotAllowed(t *testing.T) { e := New() - e.GET("/", func(c Context) error { + + e.GET("/", func(c *Context) error { return c.String(http.StatusOK, "Echo!") }) req := httptest.NewRequest(http.MethodPost, "/", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusMethodNotAllowed, rec.Code) + assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow)) +} + +func TestEcho_OnAddRoute(t *testing.T) { + exampleRoute := Route{ + Method: http.MethodGet, + Path: "/api/files/:id", + Handler: notFoundHandler, + Middlewares: nil, + Name: "x", + } + + var testCases = []struct { + whenRoute Route + whenError error + name string + expectError string + expectAdded []string + expectLen int + }{ + { + name: "ok", + whenRoute: exampleRoute, + whenError: nil, + expectAdded: []string{"/static", "/api/files/:id"}, + expectError: "", + expectLen: 2, + }, + { + name: "nok, error is returned", + whenRoute: exampleRoute, + whenError: errors.New("nope"), + expectAdded: []string{"/static"}, + expectError: "nope", + expectLen: 1, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + e := New() + + added := make([]string, 0) + cnt := 0 + e.OnAddRoute = func(route Route) error { + if cnt > 0 && tc.whenError != nil { // we want to GET /static to succeed for nok tests + return tc.whenError + } + cnt++ + added = append(added, route.Path) + return nil + } + + e.GET("/static", notFoundHandler) + + var err error + _, err = e.AddRoute(tc.whenRoute) + + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + assert.Len(t, e.Router().Routes(), tc.expectLen) + assert.Equal(t, tc.expectAdded, added) + }) + } } func TestEchoContext(t *testing.T) { e := New() c := e.AcquireContext() - assert.IsType(t, new(context), c) + assert.IsType(t, new(Context), c) e.ReleaseContext(c) } -func TestEchoStart(t *testing.T) { +func TestPreMiddlewares(t *testing.T) { e := New() - go func() { - assert.NoError(t, e.Start(":0")) - }() - time.Sleep(200 * time.Millisecond) + assert.Equal(t, 0, len(e.PreMiddlewares())) + + e.Pre(func(next HandlerFunc) HandlerFunc { + return func(c *Context) error { + return next(c) + } + }) + + assert.Equal(t, 1, len(e.PreMiddlewares())) } -func TestEchoStartTLS(t *testing.T) { +func TestMiddlewares(t *testing.T) { e := New() - go func() { - err := e.StartTLS(":0", "_fixture/certs/cert.pem", "_fixture/certs/key.pem") - // Prevent the test to fail after closing the servers - if err != http.ErrServerClosed { - assert.NoError(t, err) + assert.Equal(t, 0, len(e.Middlewares())) + + e.Use(func(next HandlerFunc) HandlerFunc { + return func(c *Context) error { + return next(c) } + }) + + assert.Equal(t, 1, len(e.Middlewares())) +} + +func TestEcho_Start(t *testing.T) { + e := New() + e.GET("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + rndPort, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + defer rndPort.Close() + errChan := make(chan error, 1) + go func() { + errChan <- e.Start(rndPort.Addr().String()) }() - time.Sleep(200 * time.Millisecond) - e.Close() + select { + case <-time.After(250 * time.Millisecond): + t.Fatal("start did not error out") + case err := <-errChan: + expectContains := "bind: address already in use" + if runtime.GOOS == "windows" { + expectContains = "bind: Only one usage of each socket address" + } + assert.Contains(t, err.Error(), expectContains) + } +} + +func request(method, path string, e *Echo) (int, string) { + req := httptest.NewRequest(method, path, nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + return rec.Code, rec.Body.String() +} + +type customError struct { + Code int + Message string } -func TestEchoStartTLSByteString(t *testing.T) { - cert, err := ioutil.ReadFile("_fixture/certs/cert.pem") - require.NoError(t, err) - key, err := ioutil.ReadFile("_fixture/certs/key.pem") - require.NoError(t, err) +func (ce *customError) StatusCode() int { + return ce.Code +} - testCases := []struct { - cert interface{} - key interface{} - expectedErr error - name string +func (ce *customError) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`{"x":"%v"}`, ce.Message)), nil +} + +func (ce *customError) Error() string { + return ce.Message +} + +func TestDefaultHTTPErrorHandler(t *testing.T) { + var testCases = []struct { + whenError error + name string + whenMethod string + expectBody string + expectLogged string + expectStatus int + givenExposeError bool + givenLoggerFunc bool }{ { - cert: "_fixture/certs/cert.pem", - key: "_fixture/certs/key.pem", - expectedErr: nil, - name: `ValidCertAndKeyFilePath`, + name: "ok, expose error = true, HTTPError, no wrapped err", + givenExposeError: true, + whenError: &HTTPError{Code: http.StatusTeapot, Message: "my_error"}, + expectStatus: http.StatusTeapot, + expectBody: `{"message":"my_error"}` + "\n", + }, + { + name: "ok, expose error = true, HTTPError + wrapped error", + givenExposeError: true, + whenError: HTTPError{Code: http.StatusTeapot, Message: "my_error"}.Wrap(errors.New("internal_error")), + expectStatus: http.StatusTeapot, + expectBody: `{"error":"internal_error","message":"my_error"}` + "\n", + }, + { + name: "ok, expose error = true, HTTPError + wrapped HTTPError", + givenExposeError: true, + whenError: HTTPError{Code: http.StatusTeapot, Message: "my_error"}.Wrap(&HTTPError{Code: http.StatusTeapot, Message: "early_error"}), + expectStatus: http.StatusTeapot, + expectBody: `{"error":"code=418, message=early_error","message":"my_error"}` + "\n", + }, + { + name: "ok, expose error = false, HTTPError", + whenError: &HTTPError{Code: http.StatusTeapot, Message: "my_error"}, + expectStatus: http.StatusTeapot, + expectBody: `{"message":"my_error"}` + "\n", + }, + { + name: "ok, expose error = false, HTTPError, no message", + whenError: &HTTPError{Code: http.StatusTeapot, Message: ""}, + expectStatus: http.StatusTeapot, + expectBody: `{"message":"I'm a teapot"}` + "\n", + }, + { + name: "ok, expose error = false, HTTPError + internal HTTPError", + whenError: HTTPError{Code: http.StatusTooEarly, Message: "my_error"}.Wrap(&HTTPError{Code: http.StatusTeapot, Message: "early_error"}), + expectStatus: http.StatusTooEarly, + expectBody: `{"message":"my_error"}` + "\n", }, { - cert: cert, - key: key, - expectedErr: nil, - name: `ValidCertAndKeyByteString`, + name: "ok, expose error = true, Error", + givenExposeError: true, + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: `{"error":"my errors wraps: internal_error","message":"Internal Server Error"}` + "\n", }, { - cert: cert, - key: 1, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidKeyType`, + name: "ok, expose error = false, Error", + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: `{"message":"Internal Server Error"}` + "\n", }, { - cert: 0, - key: key, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidCertType`, + name: "ok, http.HEAD, expose error = true, Error", + givenExposeError: true, + whenMethod: http.MethodHead, + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: ``, }, { - cert: 0, - key: 1, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidCertAndKeyTypes`, + name: "ok, custom error implement MarshalJSON + HTTPStatusCoder", + whenMethod: http.MethodGet, + whenError: &customError{Code: http.StatusTeapot, Message: "custom error msg"}, + expectStatus: http.StatusTeapot, + expectBody: `{"x":"custom error msg"}` + "\n", }, } - for _, test := range testCases { - test := test - t.Run(test.name, func(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) e := New() - e.HideBanner = true - - go func() { - err := e.StartTLS(":0", test.cert, test.key) - if test.expectedErr != nil { - require.EqualError(t, err, test.expectedErr.Error()) - } else if err != http.ErrServerClosed { // Prevent the test to fail after closing the servers - require.NoError(t, err) - } - }() - time.Sleep(200 * time.Millisecond) + e.Logger = slog.New(slog.DiscardHandler) + e.Any("/path", func(c *Context) error { + return tc.whenError + }) + + e.HTTPErrorHandler = DefaultHTTPErrorHandler(tc.givenExposeError) + + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod + } + c, b := request(method, "/path", e) - require.NoError(t, e.Close()) + assert.Equal(t, tc.expectStatus, c) + assert.Equal(t, tc.expectBody, b) + assert.Equal(t, tc.expectLogged, buf.String()) }) } } -func TestEchoStartAutoTLS(t *testing.T) { +func TestDefaultHTTPErrorHandler_CommitedResponse(t *testing.T) { e := New() - errChan := make(chan error, 0) - - go func() { - errChan <- e.StartAutoTLS(":0") - }() - time.Sleep(200 * time.Millisecond) - - select { - case err := <-errChan: - assert.NoError(t, err) - default: - assert.NoError(t, e.Close()) - } -} + req := httptest.NewRequest(http.MethodGet, "/", nil) + resp := httptest.NewRecorder() + c := e.NewContext(req, resp) -func testMethod(t *testing.T, method, path string, e *Echo) { - p := reflect.ValueOf(path) - h := reflect.ValueOf(func(c Context) error { - return c.String(http.StatusOK, method) - }) - i := interface{}(e) - reflect.ValueOf(i).MethodByName(method).Call([]reflect.Value{p, h}) - _, body := request(method, path, e) - assert.Equal(t, method, body) -} + c.orgResponse.Committed = true + errHandler := DefaultHTTPErrorHandler(false) -func request(method, path string, e *Echo) (int, string) { - req := httptest.NewRequest(method, path, nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - return rec.Code, rec.Body.String() + errHandler(c, errors.New("my_error")) + assert.Equal(t, http.StatusOK, resp.Code) } -func TestHTTPError(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - assert.Equal(t, "code=400, message=map[code:12], internal=", err.Error()) -} - -func TestEchoClose(t *testing.T) { +func benchmarkEchoRoutes(b *testing.B, routes []testRoute) { e := New() - errCh := make(chan error) - - go func() { - errCh <- e.Start(":0") - }() + req := httptest.NewRequest(http.MethodGet, "/", nil) + u := req.URL + w := httptest.NewRecorder() - time.Sleep(200 * time.Millisecond) + b.ReportAllocs() - if err := e.Close(); err != nil { - t.Fatal(err) + // Add routes + for _, route := range routes { + e.Add(route.Method, route.Path, func(c *Context) error { + return nil + }) } - assert.NoError(t, e.Close()) - - err := <-errCh - assert.Equal(t, err.Error(), "http: Server closed") + // Find routes + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, route := range routes { + req.Method = route.Method + u.Path = route.Path + e.ServeHTTP(w, req) + } + } } -func TestEchoShutdown(t *testing.T) { - e := New() - errCh := make(chan error) - - go func() { - errCh <- e.Start(":0") - }() +func BenchmarkEchoStaticRoutes(b *testing.B) { + benchmarkEchoRoutes(b, staticRoutes) +} - time.Sleep(200 * time.Millisecond) +func BenchmarkEchoStaticRoutesMisses(b *testing.B) { + benchmarkEchoRoutes(b, staticRoutes) +} - if err := e.Close(); err != nil { - t.Fatal(err) - } +func BenchmarkEchoGitHubAPI(b *testing.B) { + benchmarkEchoRoutes(b, gitHubAPI) +} - ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 10*time.Second) - defer cancel() - assert.NoError(t, e.Shutdown(ctx)) +func BenchmarkEchoGitHubAPIMisses(b *testing.B) { + benchmarkEchoRoutes(b, gitHubAPI) +} - err := <-errCh - assert.Equal(t, err.Error(), "http: Server closed") +func BenchmarkEchoParseAPI(b *testing.B) { + benchmarkEchoRoutes(b, parseAPI) } diff --git a/echotest/context.go b/echotest/context.go new file mode 100644 index 000000000..2f665705d --- /dev/null +++ b/echotest/context.go @@ -0,0 +1,183 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echotest + +import ( + "bytes" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/labstack/echo/v5" +) + +// ContextConfig is configuration for creating echo.Context for testing purposes. +type ContextConfig struct { + // Request will be used instead of default `httptest.NewRequest(http.MethodGet, "/", nil)` + Request *http.Request + + // Response will be used instead of default `httptest.NewRecorder()` + Response *httptest.ResponseRecorder + + // QueryValues wil be set as Request.URL.RawQuery value + QueryValues url.Values + + // Headers wil be set as Request.Header value + Headers http.Header + + // PathValues initializes context.PathValues with given value. + PathValues echo.PathValues + + // RouteInfo initializes context.RouteInfo() with given value + RouteInfo *echo.RouteInfo + + // FormValues creates form-urlencoded form out of given values. If there is no + // `content-type` header it will be set to `application/x-www-form-urlencoded` + // In case Request was not set the Request.Method is set to `POST` + // + // FormValues, MultipartForm and JSONBody are mutually exclusive. + FormValues url.Values + + // MultipartForm creates multipart form out of given value. If there is no + // `content-type` header it will be set to `multipart/form-data` + // In case Request was not set the Request.Method is set to `POST` + // + // FormValues, MultipartForm and JSONBody are mutually exclusive. + MultipartForm *MultipartForm + + // JSONBody creates JSON body out of given bytes. If there is no + // `content-type` header it will be set to `application/json` + // In case Request was not set the Request.Method is set to `POST` + // + // FormValues, MultipartForm and JSONBody are mutually exclusive. + JSONBody []byte +} + +// MultipartForm is used to create multipart form out of given value +type MultipartForm struct { + Fields map[string]string + Files []MultipartFormFile +} + +// MultipartFormFile is used to create file in multipart form out of given value +type MultipartFormFile struct { + Fieldname string + Filename string + Content []byte +} + +// ToContext converts ContextConfig to echo.Context +func (conf ContextConfig) ToContext(t *testing.T) *echo.Context { + c, _ := conf.ToContextRecorder(t) + return c +} + +// ToContextRecorder converts ContextConfig to echo.Context and httptest.ResponseRecorder +func (conf ContextConfig) ToContextRecorder(t *testing.T) (*echo.Context, *httptest.ResponseRecorder) { + if conf.Response == nil { + conf.Response = httptest.NewRecorder() + } + isDefaultRequest := false + if conf.Request == nil { + isDefaultRequest = true + conf.Request = httptest.NewRequest(http.MethodGet, "/", nil) + } + + if len(conf.QueryValues) > 0 { + conf.Request.URL.RawQuery = conf.QueryValues.Encode() + } + if len(conf.Headers) > 0 { + conf.Request.Header = conf.Headers + } + if len(conf.FormValues) > 0 { + body := strings.NewReader(url.Values(conf.FormValues).Encode()) + conf.Request.Body = io.NopCloser(body) + conf.Request.ContentLength = int64(body.Len()) + + if conf.Request.Header.Get(echo.HeaderContentType) == "" { + conf.Request.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) + } + if isDefaultRequest { + conf.Request.Method = http.MethodPost + } + } else if conf.MultipartForm != nil { + var body bytes.Buffer + mw := multipart.NewWriter(&body) + for field, value := range conf.MultipartForm.Fields { + if err := mw.WriteField(field, value); err != nil { + t.Fatal(err) + } + } + for _, file := range conf.MultipartForm.Files { + fw, err := mw.CreateFormFile(file.Fieldname, file.Filename) + if err != nil { + t.Fatal(err) + } + if _, err = fw.Write(file.Content); err != nil { + t.Fatal(err) + } + } + if err := mw.Close(); err != nil { + t.Fatal(err) + } + + conf.Request.Body = io.NopCloser(&body) + conf.Request.ContentLength = int64(body.Len()) + if conf.Request.Header.Get(echo.HeaderContentType) == "" { + conf.Request.Header.Set(echo.HeaderContentType, mw.FormDataContentType()) + } + if isDefaultRequest { + conf.Request.Method = http.MethodPost + } + } else if conf.JSONBody != nil { + body := bytes.NewReader(conf.JSONBody) + conf.Request.Body = io.NopCloser(body) + conf.Request.ContentLength = int64(body.Len()) + + if conf.Request.Header.Get(echo.HeaderContentType) == "" { + conf.Request.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + } + if isDefaultRequest { + conf.Request.Method = http.MethodPost + } + } + + ec := echo.NewContext(conf.Request, conf.Response, echo.New()) + if conf.RouteInfo == nil { + conf.RouteInfo = &echo.RouteInfo{ + Name: "", + Method: conf.Request.Method, + Path: "/test", + Parameters: []string{}, + } + for _, p := range conf.PathValues { + conf.RouteInfo.Parameters = append(conf.RouteInfo.Parameters, p.Name) + } + } + ec.InitializeRoute(conf.RouteInfo, &conf.PathValues) + return ec, conf.Response +} + +// ServeWithHandler serves ContextConfig with given handler and returns httptest.ResponseRecorder for response checking +func (conf ContextConfig) ServeWithHandler(t *testing.T, handler echo.HandlerFunc, opts ...any) *httptest.ResponseRecorder { + c, rec := conf.ToContextRecorder(t) + + errHandler := echo.DefaultHTTPErrorHandler(false) + for _, opt := range opts { + switch o := opt.(type) { + case echo.HTTPErrorHandler: + errHandler = o + } + } + + err := handler(c) + if err != nil { + errHandler(c, err) + } + return rec +} diff --git a/echotest/context_external_test.go b/echotest/context_external_test.go new file mode 100644 index 000000000..d98257148 --- /dev/null +++ b/echotest/context_external_test.go @@ -0,0 +1,27 @@ +package echotest_test + +import ( + "net/http" + "testing" + + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/echotest" + "github.com/stretchr/testify/assert" +) + +func TestToContext_JSONBody(t *testing.T) { + c := echotest.ContextConfig{ + JSONBody: echotest.LoadBytes(t, "testdata/test.json"), + }.ToContext(t) + + payload := struct { + Field string `json:"field"` + }{} + if err := c.Bind(&payload); err != nil { + t.Fatal(err) + } + + assert.Equal(t, "value", payload.Field) + assert.Equal(t, http.MethodPost, c.Request().Method) + assert.Equal(t, echo.MIMEApplicationJSON, c.Request().Header.Get(echo.HeaderContentType)) +} diff --git a/echotest/context_test.go b/echotest/context_test.go new file mode 100644 index 000000000..66815e4b0 --- /dev/null +++ b/echotest/context_test.go @@ -0,0 +1,157 @@ +package echotest + +import ( + "net/http" + "net/url" + "strings" + "testing" + + "github.com/labstack/echo/v5" + "github.com/stretchr/testify/assert" +) + +func TestServeWithHandler(t *testing.T) { + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, c.QueryParam("key")) + } + testConf := ContextConfig{ + QueryValues: url.Values{"key": []string{"value"}}, + } + + resp := testConf.ServeWithHandler(t, handler) + + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, "value", resp.Body.String()) +} + +func TestServeWithHandler_error(t *testing.T) { + handler := func(c *echo.Context) error { + return echo.NewHTTPError(http.StatusBadRequest, "something went wrong") + } + testConf := ContextConfig{ + QueryValues: url.Values{"key": []string{"value"}}, + } + + customErrHandler := echo.DefaultHTTPErrorHandler(true) + + resp := testConf.ServeWithHandler(t, handler, customErrHandler) + + assert.Equal(t, http.StatusBadRequest, resp.Code) + assert.Equal(t, `{"message":"something went wrong"}`+"\n", resp.Body.String()) +} + +func TestToContext_QueryValues(t *testing.T) { + testConf := ContextConfig{ + QueryValues: url.Values{"t": []string{"2006-01-02"}}, + } + c := testConf.ToContext(t) + + v, err := echo.QueryParam[string](c, "t") + + assert.NoError(t, err) + assert.Equal(t, "2006-01-02", v) +} + +func TestToContext_Headers(t *testing.T) { + testConf := ContextConfig{ + Headers: http.Header{echo.HeaderXRequestID: []string{"ABC"}}, + } + c := testConf.ToContext(t) + + id := c.Request().Header.Get(echo.HeaderXRequestID) + + assert.Equal(t, "ABC", id) +} + +func TestToContext_PathValues(t *testing.T) { + testConf := ContextConfig{ + PathValues: echo.PathValues{{ + Name: "key", + Value: "value", + }}, + } + c := testConf.ToContext(t) + + key := c.Param("key") + + assert.Equal(t, "value", key) +} + +func TestToContext_RouteInfo(t *testing.T) { + testConf := ContextConfig{ + RouteInfo: &echo.RouteInfo{ + Name: "my_route", + Method: http.MethodGet, + Path: "/:id", + Parameters: []string{"id"}, + }, + } + c := testConf.ToContext(t) + + ri := c.RouteInfo() + + assert.Equal(t, echo.RouteInfo{ + Name: "my_route", + Method: http.MethodGet, + Path: "/:id", + Parameters: []string{"id"}, + }, ri) +} + +func TestToContext_FormValues(t *testing.T) { + testConf := ContextConfig{ + FormValues: url.Values{"key": []string{"value"}}, + } + c := testConf.ToContext(t) + + assert.Equal(t, "value", c.FormValue("key")) + assert.Equal(t, http.MethodPost, c.Request().Method) + assert.Equal(t, echo.MIMEApplicationForm, c.Request().Header.Get(echo.HeaderContentType)) +} + +func TestToContext_MultipartForm(t *testing.T) { + testConf := ContextConfig{ + MultipartForm: &MultipartForm{ + Fields: map[string]string{ + "key": "value", + }, + Files: []MultipartFormFile{ + { + Fieldname: "file", + Filename: "test.json", + Content: LoadBytes(t, "testdata/test.json"), + }, + }, + }, + } + c := testConf.ToContext(t) + + assert.Equal(t, "value", c.FormValue("key")) + assert.Equal(t, http.MethodPost, c.Request().Method) + assert.Equal(t, true, strings.HasPrefix(c.Request().Header.Get(echo.HeaderContentType), "multipart/form-data; boundary=")) + + fv, err := c.FormFile("file") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "test.json", fv.Filename) + assert.Equal(t, int64(23), fv.Size) +} + +func TestToContext_JSONBody(t *testing.T) { + testConf := ContextConfig{ + JSONBody: LoadBytes(t, "testdata/test.json"), + } + c := testConf.ToContext(t) + + payload := struct { + Field string `json:"field"` + }{} + if err := c.Bind(&payload); err != nil { + t.Fatal(err) + } + + assert.Equal(t, "value", payload.Field) + assert.Equal(t, http.MethodPost, c.Request().Method) + assert.Equal(t, echo.MIMEApplicationJSON, c.Request().Header.Get(echo.HeaderContentType)) +} diff --git a/echotest/reader.go b/echotest/reader.go new file mode 100644 index 000000000..0caceca02 --- /dev/null +++ b/echotest/reader.go @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echotest + +import ( + "os" + "path/filepath" + "runtime" + "testing" +) + +type loadBytesOpts func([]byte) []byte + +// TrimNewlineEnd instructs LoadBytes to remove `\n` from the end of loaded file. +func TrimNewlineEnd(bytes []byte) []byte { + bLen := len(bytes) + if bLen > 1 && bytes[bLen-1] == '\n' { + bytes = bytes[:bLen-1] + } + return bytes +} + +// LoadBytes is helper to load file contents relative to current (where test file is) package +// directory. +func LoadBytes(t *testing.T, name string, opts ...loadBytesOpts) []byte { + bytes := loadBytes(t, name, 2) + + for _, f := range opts { + bytes = f(bytes) + } + + return bytes +} + +func loadBytes(t *testing.T, name string, callDepth int) []byte { + _, b, _, _ := runtime.Caller(callDepth) + basepath := filepath.Dir(b) + + path := filepath.Join(basepath, name) // relative path + bytes, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + return bytes[:] +} diff --git a/echotest/reader_external_test.go b/echotest/reader_external_test.go new file mode 100644 index 000000000..43fd57416 --- /dev/null +++ b/echotest/reader_external_test.go @@ -0,0 +1,25 @@ +package echotest_test + +import ( + "strings" + "testing" + + "github.com/labstack/echo/v5/echotest" + "github.com/stretchr/testify/assert" +) + +const testJSONContent = `{ + "field": "value" +}` + +func TestLoadBytesOK(t *testing.T) { + data := echotest.LoadBytes(t, "testdata/test.json") + assert.Equal(t, []byte(testJSONContent+"\n"), data) +} + +func TestLoadBytes_custom(t *testing.T) { + data := echotest.LoadBytes(t, "testdata/test.json", func(bytes []byte) []byte { + return []byte(strings.ToUpper(string(bytes))) + }) + assert.Equal(t, []byte(strings.ToUpper(testJSONContent)+"\n"), data) +} diff --git a/echotest/reader_test.go b/echotest/reader_test.go new file mode 100644 index 000000000..23b3c2dd2 --- /dev/null +++ b/echotest/reader_test.go @@ -0,0 +1,21 @@ +package echotest + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +const testJSONContent = `{ + "field": "value" +}` + +func TestLoadBytesOK(t *testing.T) { + data := LoadBytes(t, "testdata/test.json") + assert.Equal(t, []byte(testJSONContent+"\n"), data) +} + +func TestLoadBytesOK_TrimNewlineEnd(t *testing.T) { + data := LoadBytes(t, "testdata/test.json", TrimNewlineEnd) + assert.Equal(t, []byte(testJSONContent), data) +} diff --git a/echotest/testdata/test.json b/echotest/testdata/test.json new file mode 100644 index 000000000..94ae65f17 --- /dev/null +++ b/echotest/testdata/test.json @@ -0,0 +1,3 @@ +{ + "field": "value" +} diff --git a/go.mod b/go.mod index c5db2ae1a..a2480a285 100644 --- a/go.mod +++ b/go.mod @@ -1,17 +1,16 @@ -module github.com/labstack/echo/v4 +module github.com/labstack/echo/v5 -go 1.12 +go 1.25.0 require ( - github.com/dgrijalva/jwt-go v3.2.0+incompatible - github.com/labstack/echo v3.3.10+incompatible // indirect - github.com/labstack/gommon v0.3.0 - github.com/mattn/go-colorable v0.1.4 // indirect - github.com/mattn/go-isatty v0.0.11 // indirect - github.com/stretchr/testify v1.4.0 - github.com/valyala/fasttemplate v1.1.0 - golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 - golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 // indirect - golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8 // indirect - golang.org/x/text v0.3.2 // indirect + github.com/stretchr/testify v1.11.1 + golang.org/x/net v0.49.0 + golang.org/x/time v0.14.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/text v0.33.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 57c79877e..f1e80fc13 100644 --- a/go.sum +++ b/go.sum @@ -1,62 +1,16 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= -github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= -github.com/labstack/echo v3.3.10+incompatible h1:pGRcYk231ExFAyoAjAfD85kQzRJCRI8bbnE7CX5OEgg= -github.com/labstack/echo v3.3.10+incompatible/go.mod h1:0INS7j/VjnFxD4E2wkz67b8cVwCLbBmJyDaka6Cmk1s= -github.com/labstack/gommon v0.3.0 h1:JEeO0bvc78PKdyHxloTKiF8BD5iGrH8T6MSeGvSgob0= -github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= -github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= -github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= -github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.9 h1:d5US/mDsogSGW37IV293h//ZFaeajb69h+EHFsv2xGg= -github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= -github.com/mattn/go-isatty v0.0.10 h1:qxFzApOv4WsAL965uUPIsXzAKCZxN2p9UqdhFS4ZW10= -github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= -github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM= -github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= -github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasttemplate v1.0.1 h1:tY9CJiPnMXf1ERmG2EyK7gNUd+c6RKGD0IfU8WdUSz8= -github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= -github.com/valyala/fasttemplate v1.1.0 h1:RZqt0yGBsps8NGvLSGW804QQqCUYYLsaOjTVHy1Ocw4= -github.com/valyala/fasttemplate v1.1.0/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 h1:HuIa8hRrWRSrqYzx1qI49NNxhdi2PrY7gxVSq1JjLDc= -golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 h1:sKJQZMuxjOAR/Uo2LBfU90onWEf1dF4C+0hPJCc9Mpc= -golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20191021144547-ec77196f6094 h1:5O4U9trLjNpuhpynaDsqwCk+Tw6seqJz1EbqbnzHrc8= -golang.org/x/net v0.0.0-20191021144547-ec77196f6094/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 h1:efeOvDhwQ29Dj3SdAV/MJf8oukgn+8D8WgaCaRMchF8= -golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a h1:aYOabOQFp6Vj6W1F80affTUvO9UxmJRx8K0gsfABByQ= -golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191024172528-b4ff53e7a1cb h1:ZxSglHghKPYD8WDeRUzRJrUJtDF0PxsTUSxyqr9/5BI= -golang.org/x/sys v0.0.0-20191024172528-b4ff53e7a1cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8 h1:JA8d3MPx/IToSyXZG/RhwYEtfrKO1Fxrqe8KrkiLXKM= -golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/group.go b/group.go index 5d9582535..d81cd9163 100644 --- a/group.go +++ b/group.go @@ -1,124 +1,172 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( + "io/fs" "net/http" ) -type ( - // Group is a set of sub-routes for a specified route. It can be used for inner - // routes that share a common middleware or functionality that should be separate - // from the parent echo instance while still inheriting from it. - Group struct { - common - host string - prefix string - middleware []MiddlewareFunc - echo *Echo - } -) +// Group is a set of sub-routes for a specified route. It can be used for inner +// routes that share a common middleware or functionality that should be separate +// from the parent echo instance while still inheriting from it. +type Group struct { + echo *Echo + prefix string + middleware []MiddlewareFunc +} // Use implements `Echo#Use()` for sub-routes within the Group. +// Group middlewares are not executed on request when there is no matching route found. func (g *Group) Use(middleware ...MiddlewareFunc) { g.middleware = append(g.middleware, middleware...) - if len(g.middleware) == 0 { - return - } - // Allow all requests to reach the group as they might get dropped if router - // doesn't find a match, making none of the group middleware process. - g.Any("", NotFoundHandler) - g.Any("/*", NotFoundHandler) } -// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. -func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error. +func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodConnect, path, h, m...) } -// DELETE implements `Echo#DELETE()` for sub-routes within the Group. -func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error. +func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodDelete, path, h, m...) } -// GET implements `Echo#GET()` for sub-routes within the Group. -func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// GET implements `Echo#GET()` for sub-routes within the Group. Panics on error. +func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodGet, path, h, m...) } -// HEAD implements `Echo#HEAD()` for sub-routes within the Group. -func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error. +func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodHead, path, h, m...) } -// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. -func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error. +func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodOptions, path, h, m...) } -// PATCH implements `Echo#PATCH()` for sub-routes within the Group. -func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error. +func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPatch, path, h, m...) } -// POST implements `Echo#POST()` for sub-routes within the Group. -func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// POST implements `Echo#POST()` for sub-routes within the Group. Panics on error. +func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPost, path, h, m...) } -// PUT implements `Echo#PUT()` for sub-routes within the Group. -func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error. +func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPut, path, h, m...) } -// TRACE implements `Echo#TRACE()` for sub-routes within the Group. -func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error. +func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodTrace, path, h, m...) } -// Any implements `Echo#Any()` for sub-routes within the Group. -func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = g.Add(m, path, handler, middleware...) +// Any implements `Echo#Any()` for sub-routes within the Group. Panics on error. +func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { + return g.Add(RouteAny, path, handler, middleware...) +} + +// Match implements `Echo#Match()` for sub-routes within the Group. Panics on error. +func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := g.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) } - return routes -} - -// Match implements `Echo#Match()` for sub-routes within the Group. -func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = g.Add(m, path, handler, middleware...) + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage } - return routes + return ris } // Group creates a new sub-group with prefix and optional sub-group-level middleware. +// Important! Group middlewares are only executed in case there was exact route match and not +// for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add +// a catch-all route `/*` for the group which handler returns always 404 func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) { m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) m = append(m, g.middleware...) m = append(m, middleware...) sg = g.echo.Group(g.prefix+prefix, m...) - sg.host = g.host return } // Static implements `Echo#Static()` for sub-routes within the Group. -func (g *Group) Static(prefix, root string) { - g.static(prefix, root, g.GET) -} - -// File implements `Echo#File()` for sub-routes within the Group. -func (g *Group) File(path, file string) { - g.file(g.prefix+path, file, g.GET) +func (g *Group) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo { + subFs := MustSubFS(g.echo.Filesystem, fsRoot) + return g.StaticFS(pathPrefix, subFs, middleware...) +} + +// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo { + return g.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(filesystem, false), + middleware..., + ) +} + +// FileFS implements `Echo#FileFS()` for sub-routes within the Group. +func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo { + return g.GET(path, StaticFileHandler(file, filesystem), m...) +} + +// File implements `Echo#File()` for sub-routes within the Group. Panics on error. +func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo { + handler := func(c *Context) error { + return c.File(file) + } + return g.Add(http.MethodGet, path, handler, middleware...) +} + +// RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group. +// +// Example: `g.RouteNotFound("/*", func(c *echo.Context) error { return c.NoContent(http.StatusNotFound) })` +func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { + return g.Add(RouteNotFound, path, h, m...) +} + +// Add implements `Echo#Add()` for sub-routes within the Group. Panics on error. +func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { + ri, err := g.AddRoute(Route{ + Method: method, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ri } -// Add implements `Echo#Add()` for sub-routes within the Group. -func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - // Combine into a new slice to avoid accidentally passing the same slice for +// AddRoute registers a new Routable with Router +func (g *Group) AddRoute(route Route) (RouteInfo, error) { + // Combine middleware into a new slice to avoid accidentally passing the same slice for // multiple routes, which would lead to later add() calls overwriting the // middleware from earlier calls. - m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) - m = append(m, g.middleware...) - m = append(m, middleware...) - return g.echo.add(g.host, method, g.prefix+path, handler, m...) + groupRoute := route.WithPrefix(g.prefix, append([]MiddlewareFunc{}, g.middleware...)) + return g.echo.add(groupRoute) } diff --git a/group_test.go b/group_test.go index 342cd29e2..7078b6497 100644 --- a/group_test.go +++ b/group_test.go @@ -1,58 +1,115 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package echo import ( + "io/fs" "net/http" + "net/http/httptest" + "os" + "strings" "testing" "github.com/stretchr/testify/assert" ) -// TODO: Fix me -func TestGroup(t *testing.T) { - g := New().Group("/group") - h := func(Context) error { return nil } - g.CONNECT("/", h) - g.DELETE("/", h) - g.GET("/", h) - g.HEAD("/", h) - g.OPTIONS("/", h) - g.PATCH("/", h) - g.POST("/", h) - g.PUT("/", h) - g.TRACE("/", h) - g.Any("/", h) - g.Match([]string{http.MethodGet, http.MethodPost}, "/", h) - g.Static("/static", "/tmp") - g.File("/walle", "_fixture/images//walle.png") +func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) { + e := New() + + called := false + mw := func(next HandlerFunc) HandlerFunc { + return func(c *Context) error { + called = true + return c.NoContent(http.StatusTeapot) + } + } + // even though group has middleware it will not be executed when there are no routes under that group + _ = e.Group("/group", mw) + + status, body := request(http.MethodGet, "/group/nope", e) + assert.Equal(t, http.StatusNotFound, status) + assert.Equal(t, `{"message":"Not Found"}`+"\n", body) + + assert.False(t, called) +} + +func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) { + e := New() + + called := false + mw := func(next HandlerFunc) HandlerFunc { + return func(c *Context) error { + called = true + return c.NoContent(http.StatusTeapot) + } + } + // even though group has middleware and routes when we have no match on some route the middlewares for that + // group will not be executed + g := e.Group("/group", mw) + g.GET("/yes", handlerFunc) + + status, body := request(http.MethodGet, "/group/nope", e) + assert.Equal(t, http.StatusNotFound, status) + assert.Equal(t, `{"message":"Not Found"}`+"\n", body) + + assert.False(t, called) +} + +func TestGroup_multiLevelGroup(t *testing.T) { + e := New() + + api := e.Group("/api") + users := api.Group("/users") + users.GET("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + status, body := request(http.MethodGet, "/api/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroupFile(t *testing.T) { + e := New() + g := e.Group("/group") + g.File("/walle", "_fixture/images/walle.png") + expectedData, err := os.ReadFile("_fixture/images/walle.png") + assert.Nil(t, err) + req := httptest.NewRequest(http.MethodGet, "/group/walle", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, expectedData, rec.Body.Bytes()) } func TestGroupRouteMiddleware(t *testing.T) { // Ensure middleware slices are not re-used e := New() g := e.Group("/group") - h := func(Context) error { return nil } + h := func(*Context) error { return nil } m1 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { return next(c) } } m2 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { return next(c) } } m3 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { return next(c) } } m4 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { return c.NoContent(404) } } m5 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { return c.NoContent(405) } } @@ -71,17 +128,17 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { e := New() g := e.Group("/group") m1 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { return next(c) } } m2 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - return c.String(http.StatusOK, c.Path()) + return func(c *Context) error { + return c.String(http.StatusOK, c.RouteInfo().Path) } } - h := func(c Context) error { - return c.String(http.StatusOK, c.Path()) + h := func(c *Context) error { + return c.String(http.StatusOK, c.RouteInfo().Path) } g.Use(m1) g.GET("/help", h, m2) @@ -104,3 +161,654 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { assert.Equal(t, "/*", m) } + +func TestGroup_CONNECT(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.CONNECT("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodConnect, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodConnect, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_DELETE(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.DELETE("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodDelete, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodDelete, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_HEAD(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.HEAD("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodHead, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, http.MethodHead+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodHead, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_OPTIONS(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.OPTIONS("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodOptions, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodOptions, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_PATCH(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.PATCH("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPatch, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodPatch, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_POST(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.POST("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPost, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, http.MethodPost+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodPost, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_PUT(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.PUT("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPut, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, http.MethodPut+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodPut, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_TRACE(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.TRACE("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodTrace, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodTrace, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_RouteNotFound(t *testing.T) { + var testCases = []struct { + expectRoute any + name string + whenURL string + expectCode int + }{ + { + name: "404, route to static not found handler /group/a/c/xx", + whenURL: "/group/a/c/xx", + expectRoute: "GET /group/a/c/xx", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to path param not found handler /group/a/:file", + whenURL: "/group/a/echo.exe", + expectRoute: "GET /group/a/:file", + expectCode: http.StatusNotFound, + }, + { + name: "404, route to any not found handler /group/*", + whenURL: "/group/b/echo.exe", + expectRoute: "GET /group/*", + expectCode: http.StatusNotFound, + }, + { + name: "200, route /group/a/c/df to /group/a/c/df", + whenURL: "/group/a/c/df", + expectRoute: "GET /group/a/c/df", + expectCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + g := e.Group("/group") + + okHandler := func(c *Context) error { + return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) + } + notFoundHandler := func(c *Context) error { + return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) + } + + g.GET("/", okHandler) + g.GET("/a/c/df", okHandler) + g.GET("/a/b*", okHandler) + g.PUT("/*", okHandler) + + g.RouteNotFound("/a/c/xx", notFoundHandler) // static + g.RouteNotFound("/a/:file", notFoundHandler) // param + g.RouteNotFound("/*", notFoundHandler) // any + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectRoute, rec.Body.String()) + }) + } +} + +func TestGroup_Any(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.Any("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK from ANY") + }) + + assert.Equal(t, RouteAny, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, RouteAny+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodTrace, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK from ANY`, body) +} + +func TestGroup_Match(t *testing.T) { + e := New() + + myMethods := []string{http.MethodGet, http.MethodPost} + users := e.Group("/users") + ris := users.Match(myMethods, "/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + assert.Len(t, ris, 2) + + for _, m := range myMethods { + status, body := request(m, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_MatchWithErrors(t *testing.T) { + e := New() + + users := e.Group("/users") + users.GET("/activate", func(c *Context) error { + return c.String(http.StatusOK, "OK") + }) + myMethods := []string{http.MethodGet, http.MethodPost} + + errs := func() (errs []error) { + defer func() { + if r := recover(); r != nil { + if tmpErr, ok := r.([]error); ok { + errs = tmpErr + return + } + panic(r) + } + }() + + users.Match(myMethods, "/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + return nil + }() + assert.Len(t, errs, 1) + assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed") + + for _, m := range myMethods { + status, body := request(m, "/users/activate", e) + + expect := http.StatusTeapot + if m == http.MethodGet { + expect = http.StatusOK + } + assert.Equal(t, expect, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_Static(t *testing.T) { + e := New() + + g := e.Group("/books") + ri := g.Static("/download", "_fixture") + assert.Equal(t, http.MethodGet, ri.Method) + assert.Equal(t, "/books/download*", ri.Path) + assert.Equal(t, "GET:/books/download*", ri.Name) + assert.Equal(t, []string{"*"}, ri.Parameters) + + req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + assert.True(t, strings.HasPrefix(body, "")) +} + +func TestGroup_StaticMultiTest(t *testing.T) { + var testCases = []struct { + name string + givenPrefix string + givenRoot string + whenURL string + expectHeaderLocation string + expectBodyStartsWith string + expectBodyNotContains string + expectStatus int + }{ + { + name: "ok", + givenPrefix: "/images", + givenRoot: "_fixture/images", + whenURL: "/test/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "ok, without prefix", + givenPrefix: "", + givenRoot: "_fixture/images", + whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png` + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "nok, without prefix does not serve dir index", + givenPrefix: "", + givenRoot: "_fixture/images", + whenURL: "/test/", // `/test` + `*` creates route `/test*` + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "No file", + givenPrefix: "/images", + givenRoot: "_fixture/scripts", + whenURL: "/test/images/bolt.png", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory", + givenPrefix: "/images", + givenRoot: "_fixture/images", + whenURL: "/test/images/", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory Redirect", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/folder", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory Redirect with non-root path", + givenPrefix: "/static", + givenRoot: "_fixture", + whenURL: "/test/static", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/static/", + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory 404 (request URL without slash)", + givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" + givenRoot: "_fixture", + whenURL: "/test/folder", // no trailing slash + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Prefixed directory redirect (without slash redirect to slash)", + givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* + givenRoot: "_fixture", + whenURL: "/test/folder", // no trailing slash + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory with index.html", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending with slash)", + givenPrefix: "/assets/", + givenRoot: "_fixture", + whenURL: "/test/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending without slash)", + givenPrefix: "/assets", + givenRoot: "_fixture", + whenURL: "/test/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Sub-directory with index.html", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/folder/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "nok, URL encoded path traversal (single encoding, slash - unix separator)", + givenRoot: "_fixture/dist/public", + whenURL: "/%2e%2e%2fprivate.txt", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + expectBodyNotContains: `private file`, + }, + { + name: "nok, URL encoded path traversal (single encoding, backslash - windows separator)", + givenRoot: "_fixture/dist/public", + whenURL: "/%2e%2e%5cprivate.txt", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + expectBodyNotContains: `private file`, + }, + { + name: "do not allow directory traversal (backslash - windows separator)", + givenPrefix: "/", + givenRoot: "_fixture/", + whenURL: `/test/..\\middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "do not allow directory traversal (slash - unix separator)", + givenPrefix: "/", + givenRoot: "_fixture/", + whenURL: `/test/../middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + g := e.Group("/test") + g.Static(tc.givenPrefix, tc.givenRoot) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectStatus, rec.Code) + body := rec.Body.String() + if tc.expectBodyStartsWith != "" { + assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) + } else { + assert.Equal(t, "", body) + } + if tc.expectBodyNotContains != "" { + assert.NotContains(t, body, tc.expectBodyNotContains) + } + + if tc.expectHeaderLocation != "" { + assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) + } else { + _, ok := rec.Result().Header["Location"] + assert.False(t, ok) + } + }) + } +} + +func TestGroup_FileFS(t *testing.T) { + var testCases = []struct { + whenFS fs.FS + name string + whenPath string + whenFile string + givenURL string + expectStartsWith []byte + expectCode int + }{ + { + name: "ok", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/assets/walle", + expectCode: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, requesting invalid path", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/assets/walle.png", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + { + name: "nok, serving not existent file from filesystem", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "not-existent.png", + givenURL: "/assets/walle", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + g := e.Group("/assets") + g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) + + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestGroup_StaticPanic(t *testing.T) { + var testCases = []struct { + name string + givenRoot string + }{ + { + name: "panics for ../", + givenRoot: "../images", + }, + { + name: "panics for /", + givenRoot: "/images", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Filesystem = os.DirFS("./") + + g := e.Group("/assets") + + assert.Panics(t, func() { + g.Static("/images", tc.givenRoot) + }) + }) + } +} + +func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) { + var testCases = []struct { + expectBody any + name string + whenURL string + expectCode int + givenCustom404 bool + expectMiddlewareCalled bool + }{ + { + name: "ok, custom 404 handler is called with middleware", + givenCustom404: true, + whenURL: "/group/test3", + expectBody: "404 GET /group/*", + expectCode: http.StatusNotFound, + expectMiddlewareCalled: true, // because RouteNotFound is added after middleware is added + }, + { + name: "ok, default group 404 handler is not called with middleware", + givenCustom404: false, + whenURL: "/group/test3", + expectBody: "404 GET /*", + expectCode: http.StatusNotFound, + expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added + }, + { + name: "ok, (no slash) default group 404 handler is called with middleware", + givenCustom404: false, + whenURL: "/group", + expectBody: "404 GET /*", + expectCode: http.StatusNotFound, + expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + okHandler := func(c *Context) error { + return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) + } + notFoundHandler := func(c *Context) error { + return c.String(http.StatusNotFound, "404 "+c.Request().Method+" "+c.Path()) + } + + e := New() + e.GET("/test1", okHandler) + e.RouteNotFound("/*", notFoundHandler) + + g := e.Group("/group") + g.GET("/test1", okHandler) + + middlewareCalled := false + g.Use(func(next HandlerFunc) HandlerFunc { + return func(c *Context) error { + middlewareCalled = true + return next(c) + } + }) + if tc.givenCustom404 { + g.RouteNotFound("/*", notFoundHandler) + } + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectMiddlewareCalled, middlewareCalled) + assert.Equal(t, tc.expectCode, rec.Code) + assert.Equal(t, tc.expectBody, rec.Body.String()) + }) + } +} diff --git a/httperror.go b/httperror.go new file mode 100644 index 000000000..6e14da3d9 --- /dev/null +++ b/httperror.go @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "errors" + "fmt" + "net/http" +) + +// HTTPStatusCoder is interface that errors can implement to produce status code for HTTP response +type HTTPStatusCoder interface { + StatusCode() int +} + +// StatusCode returns status code from error if it implements HTTPStatusCoder interface. +// If error does not implement the interface it returns 0. +func StatusCode(err error) int { + var sc HTTPStatusCoder + if errors.As(err, &sc) { + return sc.StatusCode() + } + return 0 +} + +// Following errors can produce HTTP status code by implementing HTTPStatusCoder interface +var ( + ErrBadRequest = &httpError{http.StatusBadRequest} // 400 + ErrUnauthorized = &httpError{http.StatusUnauthorized} // 401 + ErrForbidden = &httpError{http.StatusForbidden} // 403 + ErrNotFound = &httpError{http.StatusNotFound} // 404 + ErrMethodNotAllowed = &httpError{http.StatusMethodNotAllowed} // 405 + ErrRequestTimeout = &httpError{http.StatusRequestTimeout} // 408 + ErrStatusRequestEntityTooLarge = &httpError{http.StatusRequestEntityTooLarge} // 413 + ErrUnsupportedMediaType = &httpError{http.StatusUnsupportedMediaType} // 415 + ErrTooManyRequests = &httpError{http.StatusTooManyRequests} // 429 + ErrInternalServerError = &httpError{http.StatusInternalServerError} // 500 + ErrBadGateway = &httpError{http.StatusBadGateway} // 502 + ErrServiceUnavailable = &httpError{http.StatusServiceUnavailable} // 503 +) + +// Following errors fall into 500 (InternalServerError) category +var ( + ErrValidatorNotRegistered = errors.New("validator not registered") + ErrRendererNotRegistered = errors.New("renderer not registered") + ErrInvalidRedirectCode = errors.New("invalid redirect status code") + ErrCookieNotFound = errors.New("cookie not found") + ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") + ErrInvalidListenerNetwork = errors.New("invalid listener network") +) + +// NewHTTPError creates new instance of HTTPError +func NewHTTPError(code int, message string) *HTTPError { + return &HTTPError{ + Code: code, + Message: message, + } +} + +// HTTPError represents an error that occurred while handling a request. +type HTTPError struct { + // Code is status code for HTTP response + Code int `json:"-"` + Message string `json:"message"` + err error +} + +// StatusCode returns status code for HTTP response +func (he *HTTPError) StatusCode() int { + return he.Code +} + +// Error makes it compatible with `error` interface. +func (he *HTTPError) Error() string { + msg := he.Message + if msg == "" { + msg = http.StatusText(he.Code) + } + if he.err == nil { + return fmt.Sprintf("code=%d, message=%v", he.Code, msg) + } + return fmt.Sprintf("code=%d, message=%v, err=%v", he.Code, msg, he.err.Error()) +} + +// Wrap eturns new HTTPError with given errors wrapped inside +func (he HTTPError) Wrap(err error) error { + return &HTTPError{ + Code: he.Code, + Message: he.Message, + err: err, + } +} + +func (he *HTTPError) Unwrap() error { + return he.err +} + +type httpError struct { + code int +} + +func (he httpError) StatusCode() int { + return he.code +} + +func (he httpError) Error() string { + return http.StatusText(he.code) // does not include status code +} + +func (he httpError) Wrap(err error) error { + return &HTTPError{ + Code: he.code, + Message: http.StatusText(he.code), + err: err, + } +} diff --git a/httperror_external_test.go b/httperror_external_test.go new file mode 100644 index 000000000..91acdca25 --- /dev/null +++ b/httperror_external_test.go @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +// run tests as external package to get real feel for API +package echo_test + +import ( + "encoding/json" + "fmt" + "github.com/labstack/echo/v5" + "net/http" + "net/http/httptest" +) + +func ExampleDefaultHTTPErrorHandler() { + e := echo.New() + e.GET("/api/endpoint", func(c *echo.Context) error { + return &apiError{ + Code: http.StatusBadRequest, + Body: map[string]any{"message": "custom error"}, + } + }) + + req := httptest.NewRequest(http.MethodGet, "/api/endpoint?err=1", nil) + resp := httptest.NewRecorder() + + e.ServeHTTP(resp, req) + + fmt.Printf("%d %s", resp.Code, resp.Body.String()) + + // Output: 400 {"error":{"message":"custom error"}} +} + +type apiError struct { + Code int + Body any +} + +func (e *apiError) StatusCode() int { + return e.Code +} + +func (e *apiError) MarshalJSON() ([]byte, error) { + type body struct { + Error any `json:"error"` + } + return json.Marshal(body{Error: e.Body}) +} + +func (e *apiError) Error() string { + return http.StatusText(e.Code) +} diff --git a/httperror_test.go b/httperror_test.go new file mode 100644 index 000000000..0a91bbc9c --- /dev/null +++ b/httperror_test.go @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "errors" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHTTPError_StatusCode(t *testing.T) { + var err error = &HTTPError{Code: http.StatusBadRequest, Message: "my error message"} + + code := 0 + var sc HTTPStatusCoder + if errors.As(err, &sc) { + code = sc.StatusCode() + } + assert.Equal(t, http.StatusBadRequest, code) +} + +func TestHTTPError_Error(t *testing.T) { + var testCases = []struct { + name string + error error + expect string + }{ + { + name: "ok, without message", + error: &HTTPError{Code: http.StatusBadRequest}, + expect: "code=400, message=Bad Request", + }, + { + name: "ok, with message", + error: &HTTPError{Code: http.StatusBadRequest, Message: "my error message"}, + expect: "code=400, message=my error message", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expect, tc.error.Error()) + }) + } +} + +func TestHTTPError_WrapUnwrap(t *testing.T) { + err := &HTTPError{Code: http.StatusBadRequest, Message: "bad"} + wrapped := err.Wrap(errors.New("my_error")).(*HTTPError) + + err.Code = http.StatusOK + err.Message = "changed" + + assert.Equal(t, http.StatusBadRequest, wrapped.Code) + assert.Equal(t, "bad", wrapped.Message) + + assert.Equal(t, errors.New("my_error"), wrapped.Unwrap()) + assert.Equal(t, "code=400, message=bad, err=my_error", wrapped.Error()) +} + +func TestNewHTTPError(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, "bad") + err2 := &HTTPError{Code: http.StatusBadRequest, Message: "bad"} + + assert.Equal(t, err2, err) +} + +func TestStatusCode(t *testing.T) { + var testCases = []struct { + name string + err error + expect int + }{ + { + name: "ok, HTTPError", + err: &HTTPError{Code: http.StatusNotFound}, + expect: http.StatusNotFound, + }, + { + name: "ok, sentinel error", + err: ErrNotFound, + expect: http.StatusNotFound, + }, + { + name: "ok, wrapped HTTPError", + err: fmt.Errorf("wrapped: %w", &HTTPError{Code: http.StatusTeapot}), + expect: http.StatusTeapot, + }, + { + name: "nok, normal error", + err: errors.New("error"), + expect: 0, + }, + { + name: "nok, nil", + err: nil, + expect: 0, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expect, StatusCode(tc.err)) + }) + } +} diff --git a/ip.go b/ip.go new file mode 100644 index 000000000..e2b287bfd --- /dev/null +++ b/ip.go @@ -0,0 +1,267 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "net" + "net/http" + "strings" +) + +/** +By: https://github.com/tmshn (See: https://github.com/labstack/echo/pull/1478 , https://github.com/labstack/echox/pull/134 ) +Source: https://echo.labstack.com/guide/ip-address/ + +IP address plays fundamental role in HTTP; it's used for access control, auditing, geo-based access analysis and more. +Echo provides handy method [`Context#RealIP()`](https://godoc.org/github.com/labstack/echo#Context) for that. + +However, it is not trivial to retrieve the _real_ IP address from requests especially when you put L7 proxies before the application. +In such situation, _real_ IP needs to be relayed on HTTP layer from proxies to your app, but you must not trust HTTP headers unconditionally. +Otherwise, you might give someone a chance of deceiving you. **A security risk!** + +To retrieve IP address reliably/securely, you must let your application be aware of the entire architecture of your infrastructure. +In Echo, this can be done by configuring `Echo#IPExtractor` appropriately. +This guides show you why and how. + +> Note: if you don't set `Echo#IPExtractor` explicitly, Echo fallback to legacy behavior, which is not a good choice. + +Let's start from two questions to know the right direction: + +1. Do you put any HTTP (L7) proxy in front of the application? + - It includes both cloud solutions (such as AWS ALB or GCP HTTP LB) and OSS ones (such as Nginx, Envoy or Istio ingress gateway). +2. If yes, what HTTP header do your proxies use to pass client IP to the application? + +## Case 1. With no proxy + +If you put no proxy (e.g.: directory facing to the internet), all you need to (and have to) see is IP address from network layer. +Any HTTP header is untrustable because the clients have full control what headers to be set. + +In this case, use `echo.ExtractIPDirect()`. + +```go +e.IPExtractor = echo.ExtractIPDirect() +``` + +## Case 2. With proxies using `X-Forwarded-For` header + +[`X-Forwared-For` (XFF)](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For) is the popular header +to relay clients' IP addresses. +At each hop on the proxies, they append the request IP address at the end of the header. + +Following example diagram illustrates this behavior. + +```text +┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ +│ "Origin" │───────────>│ Proxy 1 │───────────>│ Proxy 2 │───────────>│ Your app │ +│ (IP: a) │ │ (IP: b) │ │ (IP: c) │ │ │ +└──────────┘ └──────────┘ └──────────┘ └──────────┘ + +Case 1. +XFF: "" "a" "a, b" + ~~~~~~ +Case 2. +XFF: "x" "x, a" "x, a, b" + ~~~~~~~~~ + ↑ What your app will see +``` + +In this case, use **first _untrustable_ IP reading from right**. Never use first one reading from left, as it is +configurable by client. Here "trustable" means "you are sure the IP address belongs to your infrastructure". +In above example, if `b` and `c` are trustable, the IP address of the client is `a` for both cases, never be `x`. + +In Echo, use `ExtractIPFromXFFHeader(...TrustOption)`. + +```go +e.IPExtractor = echo.ExtractIPFromXFFHeader() +``` + +By default, it trusts internal IP addresses (loopback, link-local unicast, private-use and unique local address +from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and +[RFC4193](https://tools.ietf.org/html/rfc4193)). +To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s. + +E.g.: + +```go +e.IPExtractor = echo.ExtractIPFromXFFHeader( + TrustLinkLocal(false), + TrustIPRanges(lbIPRange), +) +``` + +- Ref: https://godoc.org/github.com/labstack/echo#TrustOption + +## Case 3. With proxies using `X-Real-IP` header + +`X-Real-IP` is another HTTP header to relay clients' IP addresses, but it carries only one address unlike XFF. + +If your proxies set this header, use `ExtractIPFromRealIPHeader(...TrustOption)`. + +```go +e.IPExtractor = echo.ExtractIPFromRealIPHeader() +``` + +Again, it trusts internal IP addresses by default (loopback, link-local unicast, private-use and unique local address +from [RFC6890](https://tools.ietf.org/html/rfc6890), [RFC4291](https://tools.ietf.org/html/rfc4291) and +[RFC4193](https://tools.ietf.org/html/rfc4193)). +To control this behavior, use [`TrustOption`](https://godoc.org/github.com/labstack/echo#TrustOption)s. + +- Ref: https://godoc.org/github.com/labstack/echo#TrustOption + +> **Never forget** to configure the outermost proxy (i.e.; at the edge of your infrastructure) **not to pass through incoming headers**. +> Otherwise there is a chance of fraud, as it is what clients can control. + +## About default behavior + +In default behavior, Echo sees all of first XFF header, X-Real-IP header and IP from network layer. + +As you might already notice, after reading this article, this is not good. +Sole reason this is default is just backward compatibility. + +## Private IP ranges + +See: https://en.wikipedia.org/wiki/Private_network + +Private IPv4 address ranges (RFC 1918): +* 10.0.0.0 – 10.255.255.255 (24-bit block) +* 172.16.0.0 – 172.31.255.255 (20-bit block) +* 192.168.0.0 – 192.168.255.255 (16-bit block) + +Private IPv6 address ranges: +* fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA) + +*/ + +type ipChecker struct { + trustExtraRanges []*net.IPNet + trustLoopback bool + trustLinkLocal bool + trustPrivateNet bool +} + +// TrustOption is config for which IP address to trust +type TrustOption func(*ipChecker) + +// TrustLoopback configures if you trust loopback address (default: true). +func TrustLoopback(v bool) TrustOption { + return func(c *ipChecker) { + c.trustLoopback = v + } +} + +// TrustLinkLocal configures if you trust link-local address (default: true). +func TrustLinkLocal(v bool) TrustOption { + return func(c *ipChecker) { + c.trustLinkLocal = v + } +} + +// TrustPrivateNet configures if you trust private network address (default: true). +func TrustPrivateNet(v bool) TrustOption { + return func(c *ipChecker) { + c.trustPrivateNet = v + } +} + +// TrustIPRange add trustable IP ranges using CIDR notation. +func TrustIPRange(ipRange *net.IPNet) TrustOption { + return func(c *ipChecker) { + c.trustExtraRanges = append(c.trustExtraRanges, ipRange) + } +} + +func newIPChecker(configs []TrustOption) *ipChecker { + checker := &ipChecker{trustLoopback: true, trustLinkLocal: true, trustPrivateNet: true} + for _, configure := range configs { + configure(checker) + } + return checker +} + +func (c *ipChecker) trust(ip net.IP) bool { + if c.trustLoopback && ip.IsLoopback() { + return true + } + if c.trustLinkLocal && ip.IsLinkLocalUnicast() { + return true + } + if c.trustPrivateNet && ip.IsPrivate() { + return true + } + for _, trustedRange := range c.trustExtraRanges { + if trustedRange.Contains(ip) { + return true + } + } + return false +} + +// IPExtractor is a function to extract IP addr from http.Request. +// Set appropriate one to Echo#IPExtractor. +// See https://echo.labstack.com/guide/ip-address for more details. +type IPExtractor func(*http.Request) string + +// ExtractIPDirect extracts IP address using actual IP address. +// Use this if your server faces to internet directory (i.e.: uses no proxy). +func ExtractIPDirect() IPExtractor { + return extractIP +} + +func extractIP(req *http.Request) string { + host, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + if net.ParseIP(req.RemoteAddr) != nil { + return req.RemoteAddr + } + return "" + } + return host +} + +// ExtractIPFromRealIPHeader extracts IP address using x-real-ip header. +// Use this if you put proxy which uses this header. +func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { + checker := newIPChecker(options) + return func(req *http.Request) string { + realIP := req.Header.Get(HeaderXRealIP) + if realIP != "" { + realIP = strings.TrimPrefix(realIP, "[") + realIP = strings.TrimSuffix(realIP, "]") + if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) { + return realIP + } + } + return extractIP(req) + } +} + +// ExtractIPFromXFFHeader extracts IP address using x-forwarded-for header. +// Use this if you put proxy which uses this header. +// This returns nearest untrustable IP. If all IPs are trustable, returns furthest one (i.e.: XFF[0]). +func ExtractIPFromXFFHeader(options ...TrustOption) IPExtractor { + checker := newIPChecker(options) + return func(req *http.Request) string { + directIP := extractIP(req) + xffs := req.Header[HeaderXForwardedFor] + if len(xffs) == 0 { + return directIP + } + ips := append(strings.Split(strings.Join(xffs, ","), ","), directIP) + for i := len(ips) - 1; i >= 0; i-- { + ips[i] = strings.TrimSpace(ips[i]) + ips[i] = strings.TrimPrefix(ips[i], "[") + ips[i] = strings.TrimSuffix(ips[i], "]") + ip := net.ParseIP(ips[i]) + if ip == nil { + // Unable to parse IP; cannot trust entire records + return directIP + } + if !checker.trust(ip) { + return ip.String() + } + } + // All of the IPs are trusted; return first element because it is furthest from server (best effort strategy). + return strings.TrimSpace(ips[0]) + } +} diff --git a/ip_test.go b/ip_test.go new file mode 100644 index 000000000..29bf6afde --- /dev/null +++ b/ip_test.go @@ -0,0 +1,716 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "net" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func mustParseCIDR(s string) *net.IPNet { + _, IPNet, err := net.ParseCIDR(s) + if err != nil { + panic(err) + } + return IPNet +} + +func TestIPChecker_TrustOption(t *testing.T) { + var testCases = []struct { + name string + whenIP string + givenOptions []TrustOption + expect bool + }{ + { + name: "ip is within trust range, trusts additional private IPV6 network", + givenOptions: []TrustOption{ + TrustLoopback(false), + TrustLinkLocal(false), + TrustPrivateNet(false), + // this is private IPv6 ip + // CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48 + // Address: 2001:0db8:0000:0000:0000:0000:0000:0103 + // Range start: 2001:0db8:0000:0000:0000:0000:0000:0000 + // Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff + TrustIPRange(mustParseCIDR("2001:db8::103/48")), + }, + whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", + expect: true, + }, + { + name: "ip is within trust range, trusts additional private IPV6 network", + givenOptions: []TrustOption{ + TrustIPRange(mustParseCIDR("2001:db8::103/48")), + }, + whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", + expect: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker(tc.givenOptions) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustIPRange(t *testing.T) { + var testCases = []struct { + name string + givenRange string + whenIP string + expect bool + }{ + { + name: "ip is within trust range, IPV6 network range", + // CIDR Notation: 2001:0db8:0000:0000:0000:0000:0000:0000/48 + // Address: 2001:0db8:0000:0000:0000:0000:0000:0103 + // Range start: 2001:0db8:0000:0000:0000:0000:0000:0000 + // Range end: 2001:0db8:0000:ffff:ffff:ffff:ffff:ffff + givenRange: "2001:db8::103/48", + whenIP: "2001:0db8:0000:0000:0000:0000:0000:0103", + expect: true, + }, + { + name: "ip is outside (upper bounds) of trust range, IPV6 network range", + givenRange: "2001:db8::103/48", + whenIP: "2001:0db8:0001:0000:0000:0000:0000:0000", + expect: false, + }, + { + name: "ip is outside (lower bounds) of trust range, IPV6 network range", + givenRange: "2001:db8::103/48", + whenIP: "2001:0db7:ffff:ffff:ffff:ffff:ffff:ffff", + expect: false, + }, + { + name: "ip is within trust range, IPV4 network range", + // CIDR Notation: 8.8.8.8/24 + // Address: 8.8.8.8 + // Range start: 8.8.8.0 + // Range end: 8.8.8.255 + givenRange: "8.8.8.0/24", + whenIP: "8.8.8.8", + expect: true, + }, + { + name: "ip is within trust range, IPV4 network range", + // CIDR Notation: 8.8.8.8/24 + // Address: 8.8.8.8 + // Range start: 8.8.8.0 + // Range end: 8.8.8.255 + givenRange: "8.8.8.0/24", + whenIP: "8.8.8.8", + expect: true, + }, + { + name: "ip is outside (upper bounds) of trust range, IPV4 network range", + givenRange: "8.8.8.0/24", + whenIP: "8.8.9.0", + expect: false, + }, + { + name: "ip is outside (lower bounds) of trust range, IPV4 network range", + givenRange: "8.8.8.0/24", + whenIP: "8.8.7.255", + expect: false, + }, + { + name: "public ip, trust everything in IPV4 network range", + givenRange: "0.0.0.0/0", + whenIP: "8.8.8.8", + expect: true, + }, + { + name: "internal ip, trust everything in IPV4 network range", + givenRange: "0.0.0.0/0", + whenIP: "127.0.10.1", + expect: true, + }, + { + name: "public ip, trust everything in IPV6 network range", + givenRange: "::/0", + whenIP: "2a00:1450:4026:805::200e", + expect: true, + }, + { + name: "internal ip, trust everything in IPV6 network range", + givenRange: "::/0", + whenIP: "0:0:0:0:0:0:0:1", + expect: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cidr := mustParseCIDR(tc.givenRange) + checker := newIPChecker([]TrustOption{ + TrustLoopback(false), // disable to avoid interference + TrustLinkLocal(false), // disable to avoid interference + TrustPrivateNet(false), // disable to avoid interference + + TrustIPRange(cidr), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustPrivateNet(t *testing.T) { + var testCases = []struct { + name string + whenIP string + expect bool + }{ + { + name: "do not trust public IPv4 address", + whenIP: "8.8.8.8", + expect: false, + }, + { + name: "do not trust public IPv6 address", + whenIP: "2a00:1450:4026:805::200e", + expect: false, + }, + + { // Class A: 10.0.0.0 — 10.255.255.255 + name: "do not trust IPv4 just outside of class A (lower bounds)", + whenIP: "9.255.255.255", + expect: false, + }, + { + name: "do not trust IPv4 just outside of class A (upper bounds)", + whenIP: "11.0.0.0", + expect: false, + }, + { + name: "trust IPv4 of class A (lower bounds)", + whenIP: "10.0.0.0", + expect: true, + }, + { + name: "trust IPv4 of class A (upper bounds)", + whenIP: "10.255.255.255", + expect: true, + }, + + { // Class B: 172.16.0.0 — 172.31.255.255 + name: "do not trust IPv4 just outside of class B (lower bounds)", + whenIP: "172.15.255.255", + expect: false, + }, + { + name: "do not trust IPv4 just outside of class B (upper bounds)", + whenIP: "172.32.0.0", + expect: false, + }, + { + name: "trust IPv4 of class B (lower bounds)", + whenIP: "172.16.0.0", + expect: true, + }, + { + name: "trust IPv4 of class B (upper bounds)", + whenIP: "172.31.255.255", + expect: true, + }, + + { // Class C: 192.168.0.0 — 192.168.255.255 + name: "do not trust IPv4 just outside of class C (lower bounds)", + whenIP: "192.167.255.255", + expect: false, + }, + { + name: "do not trust IPv4 just outside of class C (upper bounds)", + whenIP: "192.169.0.0", + expect: false, + }, + { + name: "trust IPv4 of class C (lower bounds)", + whenIP: "192.168.0.0", + expect: true, + }, + { + name: "trust IPv4 of class C (upper bounds)", + whenIP: "192.168.255.255", + expect: true, + }, + + { // fc00::/7 address block = RFC 4193 Unique Local Addresses (ULA) + // splits the address block in two equally sized halves, fc00::/8 and fd00::/8. + // https://en.wikipedia.org/wiki/Unique_local_address + name: "trust IPv6 private address", + whenIP: "fdfc:3514:2cb3:4bd5::", + expect: true, + }, + { + name: "do not trust IPv6 just out of /fd (upper bounds)", + whenIP: "/fe00:0000:0000:0000:0000", + expect: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker([]TrustOption{ + TrustLoopback(false), // disable to avoid interference + TrustLinkLocal(false), // disable to avoid interference + + TrustPrivateNet(true), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustLinkLocal(t *testing.T) { + var testCases = []struct { + name string + whenIP string + expect bool + }{ + { + name: "trust link local IPv4 address (lower bounds)", + whenIP: "169.254.0.0", + expect: true, + }, + { + name: "trust link local IPv4 address (upper bounds)", + whenIP: "169.254.255.255", + expect: true, + }, + { + name: "do not trust link local IPv4 address (outside of lower bounds)", + whenIP: "169.253.255.255", + expect: false, + }, + { + name: "do not trust link local IPv4 address (outside of upper bounds)", + whenIP: "169.255.0.0", + expect: false, + }, + { + name: "trust link local IPv6 address ", + whenIP: "fe80::1", + expect: true, + }, + { + name: "do not trust link local IPv6 address ", + whenIP: "fec0::1", + expect: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker([]TrustOption{ + TrustLoopback(false), // disable to avoid interference + TrustPrivateNet(false), // disable to avoid interference + + TrustLinkLocal(true), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestTrustLoopback(t *testing.T) { + var testCases = []struct { + name string + whenIP string + expect bool + }{ + { + name: "trust IPv4 as localhost", + whenIP: "127.0.0.1", + expect: true, + }, + { + name: "trust IPv6 as localhost", + whenIP: "::1", + expect: true, + }, + { + name: "do not trust public ip as localhost", + whenIP: "8.8.8.8", + expect: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + checker := newIPChecker([]TrustOption{ + TrustLinkLocal(false), // disable to avoid interference + TrustPrivateNet(false), // disable to avoid interference + + TrustLoopback(true), + }) + + result := checker.trust(net.ParseIP(tc.whenIP)) + assert.Equal(t, tc.expect, result) + }) + } +} + +func TestExtractIPDirect(t *testing.T) { + var testCases = []struct { + name string + whenRequest http.Request + expectIP string + }{ + { + name: "request has no headers, extracts IP from request remote addr", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "remote addr is IP without port, extracts IP directly", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1", + }, + expectIP: "203.0.113.1", + }, + { + name: "remote addr is IPv6 without port, extracts IP directly", + whenRequest: http.Request{ + RemoteAddr: "2001:db8::1", + }, + expectIP: "2001:db8::1", + }, + { + name: "remote addr is IPv6 with port", + whenRequest: http.Request{ + RemoteAddr: "[2001:db8::1]:8080", + }, + expectIP: "2001:db8::1", + }, + { + name: "remote addr is invalid, returns empty string", + whenRequest: http.Request{ + RemoteAddr: "invalid-ip-format", + }, + expectIP: "", + }, + { + name: "request is from external IP has X-Real-Ip header, extractor still extracts IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.10"}, + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from internal IP and has Real-IP header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.10"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + { + name: "request is from external IP and has XFF + Real-IP header, extractor still extracts external IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.10"}, + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from internal IP and has XFF + Real-IP header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"127.0.0.1"}, + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + { + name: "request is from external IP and has XFF header, extractor still extracts external IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from internal IP and has XFF header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"192.0.2.106, 198.51.100.105, fc00::104, 2001:db8::103, 192.168.0.102, 169.254.0.101"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + { + name: "request is from internal IP and has INVALID XFF header, extractor still extracts internal IP from request remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"this.is.broken.lol, 169.254.0.101"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + extractedIP := ExtractIPDirect()(&tc.whenRequest) + assert.Equal(t, tc.expectIP, extractedIP) + }) + } +} + +func TestExtractIPFromRealIPHeader(t *testing.T) { + _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") + _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64") + + var testCases = []struct { + whenRequest http.Request + name string + expectIP string + givenTrustOptions []TrustOption + }{ + { + name: "request has no headers, extracts IP from request remote addr", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has INVALID external X-Real-Ip header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"xxx.yyy.zzz.ccc"}, // <-- this is invalid + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.199"}, // <-- this is untrusted + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"[2001:db8::113:199]"}, // <-- this is untrusted + }, + RemoteAddr: "[2001:db8::113:1]:8080", + }, + expectIP: "2001:db8::113:1", + }, + { + name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.199"}, + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.199", + }, + { + name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"[2001:db8::113:199]"}, + }, + RemoteAddr: "[2001:db8::113:1]:8080", + }, + expectIP: "2001:db8::113:199", + }, + { + name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"203.0.113.199"}, + HeaderXForwardedFor: []string{"203.0.113.198, 203.0.113.197"}, // <-- should not affect anything + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.199", + }, + { + name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", + givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" + TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64" + }, + whenRequest: http.Request{ + Header: http.Header{ + HeaderXRealIP: []string{"[2001:db8::113:199]"}, + HeaderXForwardedFor: []string{"[2001:db8::113:198], [2001:db8::113:197]"}, // <-- should not affect anything + }, + RemoteAddr: "[2001:db8::113:1]:8080", + }, + expectIP: "2001:db8::113:199", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + extractedIP := ExtractIPFromRealIPHeader(tc.givenTrustOptions...)(&tc.whenRequest) + assert.Equal(t, tc.expectIP, extractedIP) + }) + } +} + +func TestExtractIPFromXFFHeader(t *testing.T) { + _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") + _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64") + + var testCases = []struct { + whenRequest http.Request + name string + expectIP string + givenTrustOptions []TrustOption + }{ + { + name: "request has no headers, extracts IP from request remote addr", + whenRequest: http.Request{ + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request has INVALID external XFF header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"xxx.yyy.zzz.ccc, 127.0.0.2"}, // <-- this is invalid + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.1", + }, + { + name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"127.0.0.3, 127.0.0.2, 127.0.0.1"}, + }, + RemoteAddr: "127.0.0.1:8080", + }, + expectIP: "127.0.0.3", + }, + { + name: "request trusts all IPs in XFF header, extract IP from furthest in XFF chain", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"[fe80::3], [fe80::2], [fe80::1]"}, + }, + RemoteAddr: "[fe80::1]:8080", + }, + expectIP: "fe80::3", + }, + { + name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"203.0.113.199"}, // <-- this is untrusted + }, + RemoteAddr: "203.0.113.1:8080", + }, + expectIP: "203.0.113.1", + }, + { + name: "request is from external IP has valid + UNTRUSTED external XFF header, extract IP from remote addr", + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"[2001:db8::1]"}, // <-- this is untrusted + }, + RemoteAddr: "[2001:db8::2]:8080", + }, + expectIP: "2001:db8::2", + }, + { + name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header", + givenTrustOptions: []TrustOption{ + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" + }, + // from request its seems that request has been proxied through 6 servers. + // 1) 203.0.1.100 (this is external IP set by 203.0.100.100 which we do not trust - could be spoofed) + // 2) 203.0.100.100 (this is outside of our network but set by 203.0.113.199 which we trust to set correct IPs) + // 3) 203.0.113.199 (we trust, for example maybe our proxy from some other office) + // 4) 192.168.1.100 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products) + // 5) 127.0.0.1 (is proxy on localhost. maybe we have Nginx in front of our Echo instance doing some routing) + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"203.0.1.100, 203.0.100.100, 203.0.113.199, 192.168.1.100"}, + }, + RemoteAddr: "127.0.0.1:8080", // IP of proxy upstream of our APP + }, + expectIP: "203.0.100.100", // this is first trusted IP in XFF chain + }, + { + name: "request is from external IP is valid and has some IPs TRUSTED XFF header, extract IP from XFF header", + givenTrustOptions: []TrustOption{ + TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "2001:db8::/64" + }, + // from request its seems that request has been proxied through 6 servers. + // 1) 2001:db8:1::1:100 (this is external IP set by 2001:db8:2::100:100 which we do not trust - could be spoofed) + // 2) 2001:db8:2::100:100 (this is outside of our network but set by 2001:db8::113:199 which we trust to set correct IPs) + // 3) 2001:db8::113:199 (we trust, for example maybe our proxy from some other office) + // 4) fd12:3456:789a:1::1 (internal IP, some internal upstream loadbalancer ala SSL offloading with F5 products) + // 5) fe80::1 (is proxy on localhost. maybe we have Nginx in front of our Echo instance doing some routing) + whenRequest: http.Request{ + Header: http.Header{ + HeaderXForwardedFor: []string{"[2001:db8:1::1:100], [2001:db8:2::100:100], [2001:db8::113:199], [fd12:3456:789a:1::1]"}, + }, + RemoteAddr: "[fe80::1]:8080", // IP of proxy upstream of our APP + }, + expectIP: "2001:db8:2::100:100", // this is first trusted IP in XFF chain + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + extractedIP := ExtractIPFromXFFHeader(tc.givenTrustOptions...)(&tc.whenRequest) + assert.Equal(t, tc.expectIP, extractedIP) + }) + } +} diff --git a/json.go b/json.go new file mode 100644 index 000000000..a969ccb8c --- /dev/null +++ b/json.go @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "encoding/json" +) + +// DefaultJSONSerializer implements JSON encoding using encoding/json. +type DefaultJSONSerializer struct{} + +// Serialize converts an interface into a json and writes it to the response. +// You can optionally use the indent parameter to produce pretty JSONs. +func (d DefaultJSONSerializer) Serialize(c *Context, target any, indent string) error { + enc := json.NewEncoder(c.Response()) + if indent != "" { + enc.SetIndent("", indent) + } + return enc.Encode(target) +} + +// Deserialize reads a JSON from a request body and converts it into an interface. +func (d DefaultJSONSerializer) Deserialize(c *Context, target any) error { + if err := json.NewDecoder(c.Request().Body).Decode(target); err != nil { + return ErrBadRequest.Wrap(err) + } + return nil +} diff --git a/json_test.go b/json_test.go new file mode 100644 index 000000000..1804b3e82 --- /dev/null +++ b/json_test.go @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// Note this test is deliberately simple as there's not a lot to test. +// Just need to ensure it writes JSONs. The heavy work is done by the context methods. +func TestDefaultJSONCodec_Encode(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodPost, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Echo + assert.Equal(t, e, c.Echo()) + + // Request + assert.NotNil(t, c.Request()) + + // Response + assert.NotNil(t, c.Response()) + + //-------- + // Default JSON encoder + //-------- + + enc := new(DefaultJSONSerializer) + + err := enc.Serialize(c, user{ID: 1, Name: "Jon Snow"}, "") + if assert.NoError(t, err) { + assert.Equal(t, userJSON+"\n", rec.Body.String()) + } + + req = httptest.NewRequest(http.MethodPost, "/", nil) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + err = enc.Serialize(c, user{ID: 1, Name: "Jon Snow"}, " ") + if assert.NoError(t, err) { + assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) + } +} + +// Note this test is deliberately simple as there's not a lot to test. +// Just need to ensure it writes JSONs. The heavy work is done by the context methods. +func TestDefaultJSONCodec_Decode(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Echo + assert.Equal(t, e, c.Echo()) + + // Request + assert.NotNil(t, c.Request()) + + // Response + assert.NotNil(t, c.Response()) + + //-------- + // Default JSON encoder + //-------- + + enc := new(DefaultJSONSerializer) + + var u = user{} + err := enc.Deserialize(c, &u) + if assert.NoError(t, err) { + assert.Equal(t, u, user{ID: 1, Name: "Jon Snow"}) + } + + var userUnmarshalSyntaxError = user{} + req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(invalidContent)) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + err = enc.Deserialize(c, &userUnmarshalSyntaxError) + assert.IsType(t, &HTTPError{}, err) + assert.EqualError(t, err, "code=400, message=Bad Request, err=invalid character 'i' looking for beginning of value") + + var userUnmarshalTypeError = struct { + ID string `json:"id"` + Name string `json:"name"` + }{} + + req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + err = enc.Deserialize(c, &userUnmarshalTypeError) + assert.IsType(t, &HTTPError{}, err) + assert.EqualError(t, err, "code=400, message=Bad Request, err=json: cannot unmarshal number into Go struct field .id of type string") + +} diff --git a/log.go b/log.go deleted file mode 100644 index 3f8de5904..000000000 --- a/log.go +++ /dev/null @@ -1,41 +0,0 @@ -package echo - -import ( - "io" - - "github.com/labstack/gommon/log" -) - -type ( - // Logger defines the logging interface. - Logger interface { - Output() io.Writer - SetOutput(w io.Writer) - Prefix() string - SetPrefix(p string) - Level() log.Lvl - SetLevel(v log.Lvl) - SetHeader(h string) - Print(i ...interface{}) - Printf(format string, args ...interface{}) - Printj(j log.JSON) - Debug(i ...interface{}) - Debugf(format string, args ...interface{}) - Debugj(j log.JSON) - Info(i ...interface{}) - Infof(format string, args ...interface{}) - Infoj(j log.JSON) - Warn(i ...interface{}) - Warnf(format string, args ...interface{}) - Warnj(j log.JSON) - Error(i ...interface{}) - Errorf(format string, args ...interface{}) - Errorj(j log.JSON) - Fatal(i ...interface{}) - Fatalj(j log.JSON) - Fatalf(format string, args ...interface{}) - Panic(i ...interface{}) - Panicj(j log.JSON) - Panicf(format string, args ...interface{}) - } -) diff --git a/middleware/DEVELOPMENT.md b/middleware/DEVELOPMENT.md new file mode 100644 index 000000000..77cb226dd --- /dev/null +++ b/middleware/DEVELOPMENT.md @@ -0,0 +1,11 @@ +# Development Guidelines for middlewares + +## Best practices: + +* Do not use `panic` in middleware creator functions in case of invalid configuration. +* In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead + because previous middlewares up in call chain could have logic for dealing with returned errors. +* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they + want to create middleware with panics or with returning errors on configuration errors. +* When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same. + diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 76ba24206..e0a284c67 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -1,106 +1,156 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( + "bytes" + "cmp" "encoding/base64" + "errors" "strconv" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // BasicAuthConfig defines the config for BasicAuth middleware. - BasicAuthConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// BasicAuthConfig defines the config for BasicAuthWithConfig middleware. +// +// SECURITY: The Validator function is responsible for securely comparing credentials. +// See BasicAuthValidator documentation for guidance on preventing timing attacks. +type BasicAuthConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Validator is a function to validate BasicAuth credentials. - // Required. - Validator BasicAuthValidator + // Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers + // this function would be called once for each header until first valid result is returned + // Required. + Validator BasicAuthValidator - // Realm is a string to define realm attribute of BasicAuth. - // Default value "Restricted". - Realm string - } + // Realm is a string to define realm attribute of BasicAuthWithConfig. + // Default value "Restricted". + Realm string - // BasicAuthValidator defines a function to validate BasicAuth credentials. - BasicAuthValidator func(string, string, echo.Context) (bool, error) -) + // AllowedCheckLimit set how many headers are allowed to be checked. This is useful + // environments like corporate test environments with application proxies restricting + // access to environment with their own auth scheme. + // Defaults to 1. + AllowedCheckLimit uint +} + +// BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials. +// +// SECURITY WARNING: To prevent timing attacks that could allow attackers to enumerate +// valid usernames or passwords, validator implementations MUST use constant-time +// comparison for credential checking. Use crypto/subtle.ConstantTimeCompare instead +// of standard string equality (==) or switch statements. +// +// Example of SECURE implementation: +// +// import "crypto/subtle" +// +// validator := func(c *echo.Context, username, password string) (bool, error) { +// // Fetch expected credentials from database/config +// expectedUser := "admin" +// expectedPass := "secretpassword" +// +// // Use constant-time comparison to prevent timing attacks +// userMatch := subtle.ConstantTimeCompare([]byte(username), []byte(expectedUser)) == 1 +// passMatch := subtle.ConstantTimeCompare([]byte(password), []byte(expectedPass)) == 1 +// +// if userMatch && passMatch { +// return true, nil +// } +// return false, nil +// } +// +// Example of INSECURE implementation (DO NOT USE): +// +// // VULNERABLE TO TIMING ATTACKS - DO NOT USE +// validator := func(c *echo.Context, username, password string) (bool, error) { +// if username == "admin" && password == "secret" { // Timing leak! +// return true, nil +// } +// return false, nil +// } +type BasicAuthValidator func(c *echo.Context, user string, password string) (bool, error) const ( basic = "basic" defaultRealm = "Restricted" ) -var ( - // DefaultBasicAuthConfig is the default BasicAuth middleware config. - DefaultBasicAuthConfig = BasicAuthConfig{ - Skipper: DefaultSkipper, - Realm: defaultRealm, - } -) - // BasicAuth returns an BasicAuth middleware. // // For valid credentials it calls the next handler. // For missing or invalid credentials, it sends "401 - Unauthorized" response. func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { - c := DefaultBasicAuthConfig - c.Validator = fn - return BasicAuthWithConfig(c) + return BasicAuthWithConfig(BasicAuthConfig{Validator: fn}) } -// BasicAuthWithConfig returns an BasicAuth middleware with config. -// See `BasicAuth()`. +// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config. func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration +func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Validator == nil { - panic("echo: basic-auth middleware requires a validator function") + return nil, errors.New("echo basic-auth middleware requires a validator function") } if config.Skipper == nil { - config.Skipper = DefaultBasicAuthConfig.Skipper + config.Skipper = DefaultSkipper } - if config.Realm == "" { - config.Realm = defaultRealm + realm := defaultRealm + if config.Realm != "" { + realm = config.Realm } + realm = strconv.Quote(realm) + limit := cmp.Or(config.AllowedCheckLimit, 1) return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } - auth := c.Request().Header.Get(echo.HeaderAuthorization) + var lastError error l := len(basic) + i := uint(0) + for _, auth := range c.Request().Header[echo.HeaderAuthorization] { + if i >= limit { + break + } + if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) { + continue + } + i++ - if len(auth) > l+1 && strings.ToLower(auth[:l]) == basic { - b, err := base64.StdEncoding.DecodeString(auth[l+1:]) - if err != nil { - return err + // Invalid base64 shouldn't be treated as error + // instead should be treated as invalid client input + b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:]) + if errDecode != nil { + lastError = echo.ErrBadRequest.Wrap(errDecode) + continue } - cred := string(b) - for i := 0; i < len(cred); i++ { - if cred[i] == ':' { - // Verify credentials - valid, err := config.Validator(cred[:i], cred[i+1:], c) - if err != nil { - return err - } else if valid { - return next(c) - } - break + idx := bytes.IndexByte(b, ':') + if idx >= 0 { + valid, errValidate := config.Validator(c, string(b[:idx]), string(b[idx+1:])) + if errValidate != nil { + lastError = errValidate + } else if valid { + return next(c) } } } - realm := defaultRealm - if config.Realm != defaultRealm { - realm = strconv.Quote(config.Realm) + if lastError != nil { + return lastError } // Need to return `401` for browsers to pop-up login box. c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm) return echo.ErrUnauthorized } - } + }, nil } diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 76039db0a..42386354f 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -1,71 +1,240 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( + "crypto/subtle" "encoding/base64" + "errors" "net/http" "net/http/httptest" "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestBasicAuth(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - c := e.NewContext(req, res) - f := func(u, p string, c echo.Context) (bool, error) { - if u == "joe" && p == "secret" { + validatorFunc := func(c *echo.Context, u, p string) (bool, error) { + // Use constant-time comparison to prevent timing attacks + userMatch := subtle.ConstantTimeCompare([]byte(u), []byte("joe")) == 1 + passMatch := subtle.ConstantTimeCompare([]byte(p), []byte("secret")) == 1 + + if userMatch && passMatch { return true, nil } + + // Special case for testing error handling + if u == "error" { + return false, errors.New(p) + } + return false, nil } - h := BasicAuth(f)(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) + defaultConfig := BasicAuthConfig{Validator: validatorFunc} + + var testCases = []struct { + name string + givenConfig BasicAuthConfig + whenAuth []string + expectHeader string + expectErr string + }{ + { + name: "ok", + givenConfig: defaultConfig, + whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "ok, multiple", + givenConfig: BasicAuthConfig{Validator: validatorFunc, AllowedCheckLimit: 2}, + whenAuth: []string{ + "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), + basic + " NOT_BASE64", + basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + }, + }, + { + name: "nok, multiple, valid out of limit", + givenConfig: BasicAuthConfig{Validator: validatorFunc, AllowedCheckLimit: 1}, + whenAuth: []string{ + "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), + basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid_password")), + // limit only check first and should not check auth below + basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + }, + expectHeader: basic + ` realm="Restricted"`, + expectErr: "Unauthorized", + }, + { + name: "nok, invalid Authorization header", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + expectHeader: basic + ` realm="Restricted"`, + expectErr: "Unauthorized", + }, + { + name: "nok, not base64 Authorization header", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"}, + expectErr: "code=400, message=Bad Request, err=illegal base64 data at input byte 3", + }, + { + name: "nok, missing Authorization header", + givenConfig: defaultConfig, + expectHeader: basic + ` realm="Restricted"`, + expectErr: "Unauthorized", + }, + { + name: "ok, realm", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "ok, realm, case-insensitive header scheme", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "nok, realm, invalid Authorization header", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + expectHeader: basic + ` realm="someRealm"`, + expectErr: "Unauthorized", + }, + { + name: "nok, validator func returns an error", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))}, + expectErr: "my_error", + }, + { + name: "ok, skipped", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c *echo.Context) bool { + return true + }}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + c := e.NewContext(req, res) + + config := tc.givenConfig + + mw, err := config.ToMiddleware() + assert.NoError(t, err) - assert := assert.New(t) + h := mw(func(c *echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) - // Valid credentials - auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + if len(tc.whenAuth) != 0 { + for _, a := range tc.whenAuth { + req.Header.Add(echo.HeaderAuthorization, a) + } + } + err = h(c) - h = BasicAuthWithConfig(BasicAuthConfig{ - Skipper: nil, - Validator: f, - Realm: "someRealm", - })(func(c echo.Context) error { - return c.String(http.StatusOK, "test") + if tc.expectErr != "" { + assert.Equal(t, http.StatusOK, res.Code) + assert.EqualError(t, err, tc.expectErr) + } else { + assert.Equal(t, http.StatusTeapot, res.Code) + assert.NoError(t, err) + } + if tc.expectHeader != "" { + assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate)) + } + }) + } +} + +func TestBasicAuth_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BasicAuth(nil) + assert.NotNil(t, mw) }) - // Valid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) - - // Case-insensitive header scheme - auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) - - // Invalid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")) - req.Header.Set(echo.HeaderAuthorization, auth) - he := h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) - assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate)) - - // Missing Authorization header - req.Header.Del(echo.HeaderAuthorization) - he = h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) - - // Invalid Authorization header - auth = base64.StdEncoding.EncodeToString([]byte("invalid")) - req.Header.Set(echo.HeaderAuthorization, auth) - he = h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) + mw := BasicAuth(func(c *echo.Context, user string, password string) (bool, error) { + return true, nil + }) + assert.NotNil(t, mw) +} + +func TestBasicAuthWithConfig_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil}) + assert.NotNil(t, mw) + }) + + mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c *echo.Context, user string, password string) (bool, error) { + return true, nil + }}) + assert.NotNil(t, mw) +} + +func TestBasicAuthRealm(t *testing.T) { + e := echo.New() + mockValidator := func(c *echo.Context, u, p string) (bool, error) { + return false, nil // Always fail to trigger WWW-Authenticate header + } + + tests := []struct { + name string + realm string + expectedAuth string + }{ + { + name: "Default realm", + realm: "Restricted", + expectedAuth: `basic realm="Restricted"`, + }, + { + name: "Custom realm", + realm: "My API", + expectedAuth: `basic realm="My API"`, + }, + { + name: "Realm with special characters", + realm: `Realm with "quotes" and \backslashes`, + expectedAuth: `basic realm="Realm with \"quotes\" and \\backslashes"`, + }, + { + name: "Empty realm (falls back to default)", + realm: "", + expectedAuth: `basic realm="Restricted"`, + }, + { + name: "Realm with unicode", + realm: "测试领域", + expectedAuth: `basic realm="测试领域"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + c := e.NewContext(req, res) + + h := BasicAuthWithConfig(BasicAuthConfig{ + Validator: mockValidator, + Realm: tt.realm, + })(func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + err := h(c) + + assert.Equal(t, echo.ErrUnauthorized, err) + assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate)) + }) + } } diff --git a/middleware/body_dump.go b/middleware/body_dump.go index ebd0d0ab2..d5c823c9b 100644 --- a/middleware/body_dump.go +++ b/middleware/body_dump.go @@ -1,93 +1,146 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( "bufio" "bytes" + "errors" "io" - "io/ioutil" "net" "net/http" + "sync" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // BodyDumpConfig defines the config for BodyDump middleware. - BodyDumpConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Handler receives request and response payload. - // Required. - Handler BodyDumpHandler - } - - // BodyDumpHandler receives the request and response payload. - BodyDumpHandler func(echo.Context, []byte, []byte) +// BodyDumpConfig defines the config for BodyDump middleware. +type BodyDumpConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Handler receives request, response payloads and handler error if there are any. + // Required. + Handler BodyDumpHandler + + // MaxRequestBytes limits how much of the request body to dump. + // If the request body exceeds this limit, only the first MaxRequestBytes + // are dumped. The handler callback receives truncated data. + // Default: 5 * MB (5,242,880 bytes) + // Set to -1 to disable limits (not recommended in production). + MaxRequestBytes int64 + + // MaxResponseBytes limits how much of the response body to dump. + // If the response body exceeds this limit, only the first MaxResponseBytes + // are dumped. The handler callback receives truncated data. + // Default: 5 * MB (5,242,880 bytes) + // Set to -1 to disable limits (not recommended in production). + MaxResponseBytes int64 +} - bodyDumpResponseWriter struct { - io.Writer - http.ResponseWriter - } -) +// BodyDumpHandler receives the request and response payload. +type BodyDumpHandler func(c *echo.Context, reqBody []byte, resBody []byte, err error) -var ( - // DefaultBodyDumpConfig is the default BodyDump middleware config. - DefaultBodyDumpConfig = BodyDumpConfig{ - Skipper: DefaultSkipper, - } -) +type bodyDumpResponseWriter struct { + io.Writer + http.ResponseWriter +} // BodyDump returns a BodyDump middleware. // // BodyDump middleware captures the request and response payload and calls the // registered handler. +// +// SECURITY: By default, this limits dumped bodies to 5MB to prevent memory exhaustion +// attacks. To customize limits, use BodyDumpWithConfig. To disable limits (not recommended +// in production), explicitly set MaxRequestBytes and MaxResponseBytes to -1. func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc { - c := DefaultBodyDumpConfig - c.Handler = handler - return BodyDumpWithConfig(c) + return BodyDumpWithConfig(BodyDumpConfig{Handler: handler}) } // BodyDumpWithConfig returns a BodyDump middleware with config. // See: `BodyDump()`. +// +// SECURITY: If MaxRequestBytes and MaxResponseBytes are not set (zero values), they default +// to 5MB each to prevent DoS attacks via large payloads. Set them explicitly to -1 to disable +// limits if needed for your use case. func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration +func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Handler == nil { - panic("echo: body-dump middleware requires a handler function") + return nil, errors.New("echo body-dump middleware requires a handler function") } if config.Skipper == nil { - config.Skipper = DefaultBodyDumpConfig.Skipper + config.Skipper = DefaultSkipper + } + if config.MaxRequestBytes == 0 { + config.MaxRequestBytes = 5 * MB + } + if config.MaxResponseBytes == 0 { + config.MaxResponseBytes = 5 * MB } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } - // Request - reqBody := []byte{} - if c.Request().Body != nil { // Read - reqBody, _ = ioutil.ReadAll(c.Request().Body) - } - c.Request().Body = ioutil.NopCloser(bytes.NewBuffer(reqBody)) // Reset + reqBuf := bodyDumpBufferPool.Get().(*bytes.Buffer) + reqBuf.Reset() + defer bodyDumpBufferPool.Put(reqBuf) - // Response - resBody := new(bytes.Buffer) - mw := io.MultiWriter(c.Response().Writer, resBody) - writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer} - c.Response().Writer = writer + var bodyReader io.Reader = c.Request().Body + if config.MaxRequestBytes > 0 { + bodyReader = io.LimitReader(c.Request().Body, config.MaxRequestBytes) + } + _, readErr := io.Copy(reqBuf, bodyReader) + if readErr != nil && readErr != io.EOF { + return readErr + } + if config.MaxRequestBytes > 0 { + // Drain any remaining body data to prevent connection issues + _, _ = io.Copy(io.Discard, c.Request().Body) + _ = c.Request().Body.Close() + } - if err = next(c); err != nil { - c.Error(err) + reqBody := make([]byte, reqBuf.Len()) + copy(reqBody, reqBuf.Bytes()) + c.Request().Body = io.NopCloser(bytes.NewReader(reqBody)) + + // response part + resBuf := bodyDumpBufferPool.Get().(*bytes.Buffer) + resBuf.Reset() + defer bodyDumpBufferPool.Put(resBuf) + + var respWriter io.Writer + if config.MaxResponseBytes > 0 { + respWriter = &limitedWriter{ + response: c.Response(), + dumpBuf: resBuf, + limit: config.MaxResponseBytes, + } + } else { + respWriter = io.MultiWriter(c.Response(), resBuf) + } + writer := &bodyDumpResponseWriter{ + Writer: respWriter, + ResponseWriter: c.Response(), } + c.SetResponse(writer) + + err := next(c) // Callback - config.Handler(c, reqBody, resBody.Bytes()) + config.Handler(c, reqBody, resBuf.Bytes(), err) - return + return err } - } + }, nil } func (w *bodyDumpResponseWriter) WriteHeader(code int) { @@ -99,9 +152,50 @@ func (w *bodyDumpResponseWriter) Write(b []byte) (int, error) { } func (w *bodyDumpResponseWriter) Flush() { - w.ResponseWriter.(http.Flusher).Flush() + err := http.NewResponseController(w.ResponseWriter).Flush() + if err != nil && errors.Is(err, http.ErrNotSupported) { + panic(errors.New("response writer flushing is not supported")) + } } func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return w.ResponseWriter.(http.Hijacker).Hijack() + return http.NewResponseController(w.ResponseWriter).Hijack() +} + +func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + +var bodyDumpBufferPool = sync.Pool{ + New: func() any { + return new(bytes.Buffer) + }, +} + +type limitedWriter struct { + response http.ResponseWriter + dumpBuf *bytes.Buffer + dumped int64 + limit int64 +} + +func (w *limitedWriter) Write(b []byte) (n int, err error) { + // Always write full data to actual response (don't truncate client response) + n, err = w.response.Write(b) + if err != nil { + return n, err + } + + // Write to dump buffer only up to limit + if w.dumped < w.limit { + remaining := w.limit - w.dumped + toDump := int64(n) + if toDump > remaining { + toDump = remaining + } + w.dumpBuf.Write(b[:toDump]) + w.dumped += toDump + } + + return n, nil } diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go index e6e00f726..f493e75c8 100644 --- a/middleware/body_dump_test.go +++ b/middleware/body_dump_test.go @@ -1,14 +1,17 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( "errors" - "io/ioutil" + "io" "net/http" "net/http/httptest" "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -18,8 +21,8 @@ func TestBodyDump(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := func(c echo.Context) error { - body, err := ioutil.ReadAll(c.Request().Body) + h := func(c *echo.Context) error { + body, err := io.ReadAll(c.Request().Body) if err != nil { return err } @@ -28,64 +31,551 @@ func TestBodyDump(t *testing.T) { requestBody := "" responseBody := "" - mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) { + mw, err := BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { requestBody = string(reqBody) responseBody = string(resBody) - }) - - assert := assert.New(t) + }}.ToMiddleware() + assert.NoError(t, err) - if assert.NoError(mw(h)(c)) { - assert.Equal(requestBody, hw) - assert.Equal(responseBody, hw) - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(hw, rec.Body.String()) + if assert.NoError(t, mw(h)(c)) { + assert.Equal(t, requestBody, hw) + assert.Equal(t, responseBody, hw) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.String()) } - // Must set default skipper - BodyDumpWithConfig(BodyDumpConfig{ - Skipper: nil, - Handler: func(c echo.Context, reqBody, resBody []byte) { - requestBody = string(reqBody) - responseBody = string(resBody) +} + +func TestBodyDump_skipper(t *testing.T) { + e := echo.New() + + isCalled := false + mw, err := BodyDumpConfig{ + Skipper: func(c *echo.Context) bool { + return true }, - }) + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + isCalled = true + }, + }.ToMiddleware() + assert.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c *echo.Context) error { + return errors.New("some error") + } + + err = mw(h)(c) + assert.EqualError(t, err, "some error") + assert.False(t, isCalled) } -func TestBodyDumpFails(t *testing.T) { +func TestBodyDump_fails(t *testing.T) { e := echo.New() hw := "Hello, World!" req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := func(c echo.Context) error { + h := func(c *echo.Context) error { return errors.New("some error") } - mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {}) + mw, err := BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}}.ToMiddleware() + assert.NoError(t, err) - if !assert.Error(t, mw(h)(c)) { - t.FailNow() - } + err = mw(h)(c) + assert.EqualError(t, err, "some error") + assert.Equal(t, http.StatusOK, rec.Code) + +} +func TestBodyDumpWithConfig_panic(t *testing.T) { assert.Panics(t, func() { - mw = BodyDumpWithConfig(BodyDumpConfig{ + mw := BodyDumpWithConfig(BodyDumpConfig{ Skipper: nil, Handler: nil, }) + assert.NotNil(t, mw) }) assert.NotPanics(t, func() { - mw = BodyDumpWithConfig(BodyDumpConfig{ - Skipper: func(c echo.Context) bool { - return true - }, - Handler: func(c echo.Context, reqBody, resBody []byte) { - }, - }) + mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}}) + assert.NotNil(t, mw) + }) +} - if !assert.Error(t, mw(h)(c)) { - t.FailNow() - } +func TestBodyDump_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BodyDump(nil) + assert.NotNil(t, mw) + }) + + assert.NotPanics(t, func() { + BodyDump(func(c *echo.Context, reqBody, resBody []byte, err error) {}) + }) +} + +func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) { + bdrw := bodyDumpResponseWriter{ + ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush + } + assert.PanicsWithError(t, "response writer flushing is not supported", func() { + bdrw.Flush() + }) +} + +func TestBodyDumpResponseWriter_CanFlush(t *testing.T) { + trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}} + bdrw := bodyDumpResponseWriter{ + ResponseWriter: &trwu, + } + bdrw.Flush() + assert.Equal(t, 1, trwu.unwrapCalled) +} + +func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) { + trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()} + bdrw := bodyDumpResponseWriter{ + ResponseWriter: trwu, + } + result := bdrw.Unwrap() + assert.Equal(t, trwu, result) +} + +func TestBodyDumpResponseWriter_CanHijack(t *testing.T) { + trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}} + bdrw := bodyDumpResponseWriter{ + ResponseWriter: &trwu, // this RW supports hijacking through unwrapping + } + _, _, err := bdrw.Hijack() + assert.EqualError(t, err, "can hijack") +} + +func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) { + trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()} + bdrw := bodyDumpResponseWriter{ + ResponseWriter: &trwu, // this RW supports hijacking through unwrapping + } + _, _, err := bdrw.Hijack() + assert.EqualError(t, err, "feature not supported") +} + +func TestBodyDump_ReadError(t *testing.T) { + e := echo.New() + + // Create a reader that fails during read + failingReader := &failingReadCloser{ + data: []byte("partial data"), + failAt: 7, // Fail after 7 bytes + failWith: errors.New("connection reset"), + } + + req := httptest.NewRequest(http.MethodPost, "/", failingReader) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + // This handler should not be reached if body read fails + body, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(body)) + } + + requestBodyReceived := "" + mw := BodyDump(func(c *echo.Context, reqBody, resBody []byte, err error) { + requestBodyReceived = string(reqBody) }) + + err := mw(h)(c) + + // Verify error is propagated + assert.Error(t, err) + assert.Contains(t, err.Error(), "connection reset") + + // Verify handler was not executed (callback wouldn't have received data) + assert.Empty(t, requestBodyReceived) +} + +// failingReadCloser is a helper type for testing read errors +type failingReadCloser struct { + data []byte + pos int + failAt int + failWith error +} + +func (f *failingReadCloser) Read(p []byte) (n int, err error) { + if f.pos >= f.failAt { + return 0, f.failWith + } + + n = copy(p, f.data[f.pos:]) + f.pos += n + + if f.pos >= f.failAt { + return n, f.failWith + } + + return n, nil +} + +func (f *failingReadCloser) Close() error { + return nil +} + +func TestBodyDump_RequestWithinLimit(t *testing.T) { + e := echo.New() + requestData := "Hello, World!" + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + body, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(body)) + } + + requestBodyDumped := "" + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + requestBodyDumped = string(reqBody) + }, + MaxRequestBytes: 1 * MB, // 1MB limit + MaxResponseBytes: 1 * MB, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, requestData, requestBodyDumped, "Small request should be fully dumped") + assert.Equal(t, requestData, rec.Body.String(), "Handler should receive full request") +} + +func TestBodyDump_RequestExceedsLimit(t *testing.T) { + e := echo.New() + // Create 2KB of data but limit to 1KB + largeData := strings.Repeat("A", 2*1024) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeData)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + body, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(body)) + } + + requestBodyDumped := "" + limit := int64(1024) // 1KB limit + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + requestBodyDumped = string(reqBody) + }, + MaxRequestBytes: limit, + MaxResponseBytes: 1 * MB, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, int(limit), len(requestBodyDumped), "Dumped request should be truncated to limit") + assert.Equal(t, strings.Repeat("A", 1024), requestBodyDumped, "Dumped data should match first N bytes") + // Handler should receive truncated data (what was dumped) + assert.Equal(t, strings.Repeat("A", 1024), rec.Body.String()) +} + +func TestBodyDump_RequestAtExactLimit(t *testing.T) { + e := echo.New() + exactData := strings.Repeat("B", 1024) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(exactData)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + body, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(body)) + } + + requestBodyDumped := "" + limit := int64(1024) + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + requestBodyDumped = string(reqBody) + }, + MaxRequestBytes: limit, + MaxResponseBytes: 1 * MB, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, int(limit), len(requestBodyDumped), "Exact limit should dump full data") + assert.Equal(t, exactData, requestBodyDumped) +} + +func TestBodyDump_ResponseWithinLimit(t *testing.T) { + e := echo.New() + responseData := "Response data" + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + return c.String(http.StatusOK, responseData) + } + + responseBodyDumped := "" + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + responseBodyDumped = string(resBody) + }, + MaxRequestBytes: 1 * MB, + MaxResponseBytes: 1 * MB, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, responseData, responseBodyDumped, "Small response should be fully dumped") + assert.Equal(t, responseData, rec.Body.String(), "Client should receive full response") +} + +func TestBodyDump_ResponseExceedsLimit(t *testing.T) { + e := echo.New() + largeResponse := strings.Repeat("X", 2*1024) // 2KB + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + return c.String(http.StatusOK, largeResponse) + } + + responseBodyDumped := "" + limit := int64(1024) // 1KB limit + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + responseBodyDumped = string(resBody) + }, + MaxRequestBytes: 1 * MB, + MaxResponseBytes: limit, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + // Dump should be truncated + assert.Equal(t, int(limit), len(responseBodyDumped), "Dumped response should be truncated to limit") + assert.Equal(t, strings.Repeat("X", 1024), responseBodyDumped) + // Client should still receive full response! + assert.Equal(t, largeResponse, rec.Body.String(), "Client must receive full response despite dump truncation") +} + +func TestBodyDump_ClientGetsFullResponse(t *testing.T) { + e := echo.New() + // This is critical - even when dump is limited, client gets everything + largeResponse := strings.Repeat("DATA", 500) // 2KB + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + // Write response in chunks to test incremental writes + for i := 0; i < 4; i++ { + c.Response().Write([]byte(strings.Repeat("DATA", 125))) + } + return nil + } + + responseBodyDumped := "" + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + responseBodyDumped = string(resBody) + }, + MaxRequestBytes: 1 * MB, + MaxResponseBytes: 512, // Very small limit + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, 512, len(responseBodyDumped), "Dump should be limited") + assert.Equal(t, largeResponse, rec.Body.String(), "Client must get full response") +} + +func TestBodyDump_BothLimitsSimultaneous(t *testing.T) { + e := echo.New() + largeRequest := strings.Repeat("Q", 2*1024) + largeResponse := strings.Repeat("R", 2*1024) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeRequest)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + io.ReadAll(c.Request().Body) // Consume request + return c.String(http.StatusOK, largeResponse) + } + + requestBodyDumped := "" + responseBodyDumped := "" + limit := int64(1024) + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + requestBodyDumped = string(reqBody) + responseBodyDumped = string(resBody) + }, + MaxRequestBytes: limit, + MaxResponseBytes: limit, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, int(limit), len(requestBodyDumped), "Request dump should be limited") + assert.Equal(t, int(limit), len(responseBodyDumped), "Response dump should be limited") +} + +func TestBodyDump_DefaultConfig(t *testing.T) { + e := echo.New() + smallData := "test" + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(smallData)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + body, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(body)) + } + + requestBodyDumped := "" + // Use default config which should have 1MB limits + config := BodyDumpConfig{} + config.Handler = func(c *echo.Context, reqBody, resBody []byte, err error) { + requestBodyDumped = string(reqBody) + } + mw, err := config.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, smallData, requestBodyDumped) +} + +func TestBodyDump_LargeRequestDosPrevention(t *testing.T) { + e := echo.New() + // Simulate a very large request (10MB) that could cause OOM + largeSize := 10 * 1024 * 1024 // 10MB + largeData := strings.Repeat("M", largeSize) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeData)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + body, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(body)) + } + + requestBodyDumped := "" + limit := int64(1 * MB) // Only dump 1MB max + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + requestBodyDumped = string(reqBody) + }, + MaxRequestBytes: limit, + MaxResponseBytes: 1 * MB, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + // Verify only 1MB was dumped, not 10MB + assert.Equal(t, int(limit), len(requestBodyDumped), "Should only dump up to limit, preventing DoS") + assert.Less(t, len(requestBodyDumped), largeSize, "Dump should be much smaller than full request") +} + +func TestBodyDump_LargeResponseDosPrevention(t *testing.T) { + e := echo.New() + // Simulate a very large response (10MB) + largeSize := 10 * 1024 * 1024 // 10MB + largeResponse := strings.Repeat("R", largeSize) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + return c.String(http.StatusOK, largeResponse) + } + + responseBodyDumped := "" + limit := int64(1 * MB) // Only dump 1MB max + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + responseBodyDumped = string(resBody) + }, + MaxRequestBytes: 1 * MB, + MaxResponseBytes: limit, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + // Verify only 1MB was dumped, not 10MB + assert.Equal(t, int(limit), len(responseBodyDumped), "Should only dump up to limit, preventing DoS") + assert.Less(t, len(responseBodyDumped), largeSize, "Dump should be much smaller than full response") + // Client still gets full response + assert.Equal(t, largeSize, rec.Body.Len(), "Client must receive full response") +} + +func BenchmarkBodyDump_WithLimit(b *testing.B) { + e := echo.New() + requestData := strings.Repeat("data", 256) // 1KB + responseData := strings.Repeat("resp", 256) // 1KB + + h := func(c *echo.Context) error { + io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, responseData) + } + + mw, _ := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + // Simulate logging + _ = len(reqBody) + len(resBody) + }, + MaxRequestBytes: 1 * MB, + MaxResponseBytes: 1 * MB, + }.ToMiddleware() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + mw(h)(c) + } +} + +func BenchmarkBodyDump_BufferPooling(b *testing.B) { + e := echo.New() + requestData := strings.Repeat("x", 1024) + responseData := "response" + + h := func(c *echo.Context) error { + io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, responseData) + } + + mw, _ := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}, + MaxRequestBytes: 1 * MB, + MaxResponseBytes: 1 * MB, + }.ToMiddleware() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + mw(h)(c) + } } diff --git a/middleware/body_limit.go b/middleware/body_limit.go index b436bd595..4f1963e18 100644 --- a/middleware/body_limit.go +++ b/middleware/body_limit.go @@ -1,98 +1,89 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( - "fmt" "io" + "net/http" "sync" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/bytes" + "github.com/labstack/echo/v5" ) -type ( - // BodyLimitConfig defines the config for BodyLimit middleware. - BodyLimitConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Maximum allowed size for a request body, it can be specified - // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P. - Limit string `yaml:"limit"` - limit int64 - } +// BodyLimitConfig defines the config for BodyLimitWithConfig middleware. +type BodyLimitConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - limitedReader struct { - BodyLimitConfig - reader io.ReadCloser - read int64 - context echo.Context - } -) + // LimitBytes is maximum allowed size in bytes for a request body + LimitBytes int64 +} -var ( - // DefaultBodyLimitConfig is the default BodyLimit middleware config. - DefaultBodyLimitConfig = BodyLimitConfig{ - Skipper: DefaultSkipper, - } -) +type limitedReader struct { + BodyLimitConfig + reader io.ReadCloser + read int64 +} // BodyLimit returns a BodyLimit middleware. // -// BodyLimit middleware sets the maximum allowed size for a request body, if the -// size exceeds the configured limit, it sends "413 - Request Entity Too Large" -// response. The BodyLimit is determined based on both `Content-Length` request +// BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it +// sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request // header and actual content read, which makes it super secure. -// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M, -// G, T or P. -func BodyLimit(limit string) echo.MiddlewareFunc { - c := DefaultBodyLimitConfig - c.Limit = limit - return BodyLimitWithConfig(c) +func BodyLimit(limitBytes int64) echo.MiddlewareFunc { + return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes}) } -// BodyLimitWithConfig returns a BodyLimit middleware with config. -// See: `BodyLimit()`. +// BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for +// a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response. +// The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which +// makes it super secure. func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration +func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultBodyLimitConfig.Skipper + config.Skipper = DefaultSkipper } - - limit, err := bytes.Parse(config.Limit) - if err != nil { - panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit)) + pool := sync.Pool{ + New: func() any { + return &limitedReader{BodyLimitConfig: config} + }, } - config.limit = limit - pool := limitedReaderPool(config) return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } - req := c.Request() // Based on content length - if req.ContentLength > config.limit { + if req.ContentLength > config.LimitBytes { return echo.ErrStatusRequestEntityTooLarge } // Based on content read - r := pool.Get().(*limitedReader) - r.Reset(req.Body, c) + r, ok := pool.Get().(*limitedReader) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, "invalid pool object") + } + r.Reset(req.Body) defer pool.Put(r) req.Body = r return next(c) } - } + }, nil } func (r *limitedReader) Read(b []byte) (n int, err error) { n, err = r.reader.Read(b) r.read += int64(n) - if r.read > r.limit { + if r.read > r.LimitBytes { return n, echo.ErrStatusRequestEntityTooLarge } return @@ -102,16 +93,7 @@ func (r *limitedReader) Close() error { return r.reader.Close() } -func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) { +func (r *limitedReader) Reset(reader io.ReadCloser) { r.reader = reader - r.context = context r.read = 0 } - -func limitedReaderPool(c BodyLimitConfig) sync.Pool { - return sync.Pool{ - New: func() interface{} { - return &limitedReader{BodyLimitConfig: c} - }, - } -} diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index 0e8642a06..5529f5d84 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -1,85 +1,166 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( "bytes" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) -func TestBodyLimit(t *testing.T) { +func TestBodyLimitConfig_ToMiddleware(t *testing.T) { e := echo.New() hw := []byte("Hello, World!") req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := func(c echo.Context) error { - body, err := ioutil.ReadAll(c.Request().Body) + h := func(c *echo.Context) error { + body, err := io.ReadAll(c.Request().Body) if err != nil { return err } return c.String(http.StatusOK, string(body)) } - assert := assert.New(t) - // Based on content length (within limit) - if assert.NoError(BodyLimit("2M")(h)(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal(hw, rec.Body.Bytes()) + mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) } // Based on content read (overlimit) - he := BodyLimit("2B")(h)(c).(*echo.HTTPError) - assert.Equal(http.StatusRequestEntityTooLarge, he.Code) + mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware() + assert.NoError(t, err) + he := mw(h)(c).(echo.HTTPStatusCoder) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) // Based on content read (within limit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) - if assert.NoError(BodyLimit("2M")(h)(c)) { - assert.Equal(http.StatusOK, rec.Code) - assert.Equal("Hello, World!", rec.Body.String()) - } + + mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware() + assert.NoError(t, err) + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "Hello, World!", rec.Body.String()) // Based on content read (overlimit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) - he = BodyLimit("2B")(h)(c).(*echo.HTTPError) - assert.Equal(http.StatusRequestEntityTooLarge, he.Code) + mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware() + assert.NoError(t, err) + he = mw(h)(c).(echo.HTTPStatusCoder) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) } func TestBodyLimitReader(t *testing.T) { hw := []byte("Hello, World!") - e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) - rec := httptest.NewRecorder() config := BodyLimitConfig{ - Skipper: DefaultSkipper, - Limit: "2B", - limit: 2, + Skipper: DefaultSkipper, + LimitBytes: 2, } reader := &limitedReader{ BodyLimitConfig: config, - reader: ioutil.NopCloser(bytes.NewReader(hw)), - context: e.NewContext(req, rec), + reader: io.NopCloser(bytes.NewReader(hw)), } // read all should return ErrStatusRequestEntityTooLarge - _, err := ioutil.ReadAll(reader) - he := err.(*echo.HTTPError) - assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) + _, err := io.ReadAll(reader) + he := err.(echo.HTTPStatusCoder) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) // reset reader and read two bytes must succeed bt := make([]byte, 2) - reader.Reset(ioutil.NopCloser(bytes.NewReader(hw)), e.NewContext(req, rec)) + reader.Reset(io.NopCloser(bytes.NewReader(hw))) n, err := reader.Read(bt) assert.Equal(t, 2, n) assert.Equal(t, nil, err) } + +func TestBodyLimit_skipper(t *testing.T) { + e := echo.New() + h := func(c *echo.Context) error { + body, err := io.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + mw, err := BodyLimitConfig{ + Skipper: func(c *echo.Context) bool { + return true + }, + LimitBytes: 2, + }.ToMiddleware() + assert.NoError(t, err) + + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) +} + +func TestBodyLimitWithConfig(t *testing.T) { + e := echo.New() + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c *echo.Context) error { + body, err := io.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + + mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB}) + + err := mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) +} + +func TestBodyLimit(t *testing.T) { + e := echo.New() + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c *echo.Context) error { + body, err := io.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + + mw := BodyLimit(2 * MB) + + err := mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) +} diff --git a/middleware/compress.go b/middleware/compress.go index 89da16efe..7754d5db8 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -1,65 +1,90 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( "bufio" + "bytes" "compress/gzip" + "errors" "io" - "io/ioutil" "net" "net/http" "strings" + "sync" - "github.com/labstack/echo/v4" -) - -type ( - // GzipConfig defines the config for Gzip middleware. - GzipConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Gzip compression level. - // Optional. Default value -1. - Level int `yaml:"level"` - } - - gzipResponseWriter struct { - io.Writer - http.ResponseWriter - } + "github.com/labstack/echo/v5" ) const ( gzipScheme = "gzip" ) -var ( - // DefaultGzipConfig is the default Gzip middleware config. - DefaultGzipConfig = GzipConfig{ - Skipper: DefaultSkipper, - Level: -1, - } -) +// GzipConfig defines the config for Gzip middleware. +type GzipConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Gzip compression level. + // Optional. Default value -1. + Level int + + // Length threshold before gzip compression is applied. + // Optional. Default value 0. + // + // Most of the time you will not need to change the default. Compressing + // a short response might increase the transmitted data because of the + // gzip format overhead. Compressing the response will also consume CPU + // and time on the server and the client (for decompressing). Depending on + // your use case such a threshold might be useful. + // + // See also: + // https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits + MinLength int +} + +type gzipResponseWriter struct { + io.Writer + http.ResponseWriter + wroteHeader bool + wroteBody bool + minLength int + minLengthExceeded bool + buffer *bytes.Buffer + code int +} -// Gzip returns a middleware which compresses HTTP response using gzip compression -// scheme. +// Gzip returns a middleware which compresses HTTP response using gzip compression scheme. func Gzip() echo.MiddlewareFunc { - return GzipWithConfig(DefaultGzipConfig) + return GzipWithConfig(GzipConfig{}) } -// GzipWithConfig return Gzip middleware with config. -// See: `Gzip()`. +// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme. func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration +func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultGzipConfig.Skipper + config.Skipper = DefaultSkipper + } + if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression + return nil, errors.New("invalid gzip level") } if config.Level == 0 { - config.Level = DefaultGzipConfig.Level + config.Level = -1 } + if config.MinLength < 0 { + config.MinLength = 0 + } + + pool := gzipCompressPool(config) + bpool := bufferPool() return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } @@ -67,55 +92,144 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { res := c.Response() res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) { - res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 - rw := res.Writer - w, err := gzip.NewWriterLevel(rw, config.Level) - if err != nil { - return err + i := pool.Get() + w, ok := i.(*gzip.Writer) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, "invalid pool object") } + rw := res + w.Reset(rw) + buf := bpool.Get().(*bytes.Buffer) + buf.Reset() + + grw := &gzipResponseWriter{ + Writer: w, + ResponseWriter: rw, + minLength: config.MinLength, + buffer: buf, + } + c.SetResponse(grw) defer func() { - if res.Size == 0 { + // There are different reasons for cases when we have not yet written response to the client and now need to do so. + // a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now. + // b) body is shorter than our minimum length threshold and being buffered currently and needs to be written + if !grw.wroteBody { if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { res.Header().Del(echo.HeaderContentEncoding) } + if grw.wroteHeader { + rw.WriteHeader(grw.code) + } // We have to reset response to it's pristine state when // nothing is written to body or error is returned. // See issue #424, #407. - res.Writer = rw - w.Reset(ioutil.Discard) + c.SetResponse(rw) + w.Reset(io.Discard) + } else if !grw.minLengthExceeded { + // Write uncompressed response + c.SetResponse(rw) + if grw.wroteHeader { + grw.ResponseWriter.WriteHeader(grw.code) + } + _, _ = grw.buffer.WriteTo(rw) + w.Reset(io.Discard) } - w.Close() + _ = w.Close() + bpool.Put(buf) + pool.Put(w) }() - grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw} - res.Writer = grw } return next(c) } - } + }, nil } func (w *gzipResponseWriter) WriteHeader(code int) { - if code == http.StatusNoContent { // Issue #489 - w.ResponseWriter.Header().Del(echo.HeaderContentEncoding) - } w.Header().Del(echo.HeaderContentLength) // Issue #444 - w.ResponseWriter.WriteHeader(code) + + w.wroteHeader = true + + // Delay writing of the header until we know if we'll actually compress the response + w.code = code } func (w *gzipResponseWriter) Write(b []byte) (int, error) { if w.Header().Get(echo.HeaderContentType) == "" { w.Header().Set(echo.HeaderContentType, http.DetectContentType(b)) } + w.wroteBody = true + + if !w.minLengthExceeded { + n, err := w.buffer.Write(b) + + if w.buffer.Len() >= w.minLength { + w.minLengthExceeded = true + + // The minimum length is exceeded, add Content-Encoding header and write the header + w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 + if w.wroteHeader { + w.ResponseWriter.WriteHeader(w.code) + } + + return w.Writer.Write(w.buffer.Bytes()) + } + + return n, err + } + return w.Writer.Write(b) } func (w *gzipResponseWriter) Flush() { - w.Writer.(*gzip.Writer).Flush() - if flusher, ok := w.ResponseWriter.(http.Flusher); ok { - flusher.Flush() + if !w.minLengthExceeded { + // Enforce compression because we will not know how much more data will come + w.minLengthExceeded = true + w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 + if w.wroteHeader { + w.ResponseWriter.WriteHeader(w.code) + } + + _, _ = w.Writer.Write(w.buffer.Bytes()) + } + + if gw, ok := w.Writer.(*gzip.Writer); ok { + gw.Flush() } + _ = http.NewResponseController(w.ResponseWriter).Flush() } func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return w.ResponseWriter.(http.Hijacker).Hijack() + return http.NewResponseController(w.ResponseWriter).Hijack() +} + +func (w *gzipResponseWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + +func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { + if p, ok := w.ResponseWriter.(http.Pusher); ok { + return p.Push(target, opts) + } + return http.ErrNotSupported +} + +func gzipCompressPool(config GzipConfig) sync.Pool { + return sync.Pool{ + New: func() any { + w, err := gzip.NewWriterLevel(io.Discard, config.Level) + if err != nil { + return err + } + return w + }, + } +} + +func bufferPool() sync.Pool { + return sync.Pool{ + New: func() any { + b := &bytes.Buffer{} + return b + }, + } } diff --git a/middleware/compress_test.go b/middleware/compress_test.go index ac5b6c3bb..084ffc9c7 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -1,102 +1,140 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( "bytes" "compress/gzip" "io" - "io/ioutil" "net/http" "net/http/httptest" + "os" "testing" + "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) -func TestGzip(t *testing.T) { +func TestGzip_NoAcceptEncodingHeader(t *testing.T) { + // Skip if no Accept-Encoding header + h := Gzip()(func(c *echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - // Skip if no Accept-Encoding header - h := Gzip()(func(c echo.Context) error { + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, "test", rec.Body.String()) +} + +func TestMustGzipWithConfig_panics(t *testing.T) { + assert.Panics(t, func() { + GzipWithConfig(GzipConfig{Level: 999}) + }) +} + +func TestGzip_AcceptEncodingHeader(t *testing.T) { + h := Gzip()(func(c *echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) - h(c) - assert := assert.New(t) + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - assert.Equal("test", rec.Body.String()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - // Gzip - req = httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h(c) - assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) - assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) - r, err := gzip.NewReader(rec.Body) - if assert.NoError(err) { - buf := new(bytes.Buffer) - defer r.Close() - buf.ReadFrom(r) - assert.Equal("test", buf.String()) - } + err := h(c) + assert.NoError(t, err) - chunkBuf := make([]byte, 5) + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) - // Gzip chunked - req = httptest.NewRequest(http.MethodGet, "/", nil) + r, err := gzip.NewReader(rec.Body) + assert.NoError(t, err) + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal(t, "test", buf.String()) +} + +func TestGzip_chunked(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec = httptest.NewRecorder() + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - c = e.NewContext(req, rec) - Gzip()(func(c echo.Context) error { + chunkChan := make(chan struct{}) + waitChan := make(chan struct{}) + h := Gzip()(func(c *echo.Context) error { + rc := http.NewResponseController(c.Response()) c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Transfer-Encoding", "chunked") // Write and flush the first part of the data - c.Response().Write([]byte("test\n")) - c.Response().Flush() - - // Read the first part of the data - assert.True(rec.Flushed) - assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) - r.Reset(rec.Body) + c.Response().Write([]byte("first\n")) + rc.Flush() - _, err = io.ReadFull(r, chunkBuf) - assert.NoError(err) - assert.Equal("test\n", string(chunkBuf)) + chunkChan <- struct{}{} + <-waitChan // Write and flush the second part of the data - c.Response().Write([]byte("test\n")) - c.Response().Flush() + c.Response().Write([]byte("second\n")) + rc.Flush() - _, err = io.ReadFull(r, chunkBuf) - assert.NoError(err) - assert.Equal("test\n", string(chunkBuf)) + chunkChan <- struct{}{} + <-waitChan // Write the final part of the data and return - c.Response().Write([]byte("test")) + c.Response().Write([]byte("third")) + + chunkChan <- struct{}{} return nil - })(c) + }) + + go func() { + err := h(c) + chunkChan <- struct{}{} + assert.NoError(t, err) + }() + <-chunkChan // wait for first write + waitChan <- struct{}{} + + <-chunkChan // wait for second write + waitChan <- struct{}{} + + <-chunkChan // wait for final write in handler + <-chunkChan // wait for return from handler + time.Sleep(5 * time.Millisecond) // to have time for flushing + + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + + r, err := gzip.NewReader(rec.Body) + assert.NoError(t, err) buf := new(bytes.Buffer) - defer r.Close() buf.ReadFrom(r) - assert.Equal("test", buf.String()) + assert.Equal(t, "first\nsecond\nthird", buf.String()) } -func TestGzipNoContent(t *testing.T) { +func TestGzip_NoContent(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := Gzip()(func(c echo.Context) error { + h := Gzip()(func(c *echo.Context) error { return c.NoContent(http.StatusNoContent) }) if assert.NoError(t, h(c)) { @@ -106,10 +144,31 @@ func TestGzipNoContent(t *testing.T) { } } -func TestGzipErrorReturned(t *testing.T) { +func TestGzip_Empty(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := Gzip()(func(c *echo.Context) error { + return c.String(http.StatusOK, "") + }) + if assert.NoError(t, h(c)) { + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get(echo.HeaderContentType)) + r, err := gzip.NewReader(rec.Body) + if assert.NoError(t, err) { + var buf bytes.Buffer + buf.ReadFrom(r) + assert.Equal(t, "", buf.String()) + } + } +} + +func TestGzip_ErrorReturned(t *testing.T) { e := echo.New() e.Use(Gzip()) - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { return echo.ErrNotFound }) req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -120,15 +179,25 @@ func TestGzipErrorReturned(t *testing.T) { assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) } +func TestGzipWithConfig_invalidLevel(t *testing.T) { + mw, err := GzipConfig{Level: 12}.ToMiddleware() + assert.EqualError(t, err, "invalid gzip level") + assert.Nil(t, mw) +} + // Issue #806 func TestGzipWithStatic(t *testing.T) { e := echo.New() + e.Filesystem = os.DirFS("../") + e.Use(Gzip()) - e.Static("/test", "../_fixture/images") + e.Static("/test", "_fixture/images") req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) // Data is written out in chunks when Content-Length == "", so only // validate the content length if it's not set. @@ -138,7 +207,7 @@ func TestGzipWithStatic(t *testing.T) { r, err := gzip.NewReader(rec.Body) if assert.NoError(t, err) { defer r.Close() - want, err := ioutil.ReadFile("../_fixture/images/walle.png") + want, err := os.ReadFile("../_fixture/images/walle.png") if assert.NoError(t, err) { buf := new(bytes.Buffer) buf.ReadFrom(r) @@ -146,3 +215,184 @@ func TestGzipWithStatic(t *testing.T) { } } } + +func TestGzipWithMinLength(t *testing.T) { + e := echo.New() + // Minimal response length + e.Use(GzipWithConfig(GzipConfig{MinLength: 10})) + e.GET("/", func(c *echo.Context) error { + c.Response().Write([]byte("foobarfoobar")) + return nil + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + r, err := gzip.NewReader(rec.Body) + if assert.NoError(t, err) { + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal(t, "foobarfoobar", buf.String()) + } +} + +func TestGzipWithMinLengthTooShort(t *testing.T) { + e := echo.New() + // Minimal response length + e.Use(GzipWithConfig(GzipConfig{MinLength: 10})) + e.GET("/", func(c *echo.Context) error { + c.Response().Write([]byte("test")) + return nil + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding)) + assert.Contains(t, rec.Body.String(), "test") +} + +func TestGzipWithResponseWithoutBody(t *testing.T) { + e := echo.New() + + e.Use(Gzip()) + e.GET("/", func(c *echo.Context) error { + return c.Redirect(http.StatusMovedPermanently, "http://localhost") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusMovedPermanently, rec.Code) + assert.Equal(t, "", rec.Header().Get(echo.HeaderContentEncoding)) +} + +func TestGzipWithMinLengthChunked(t *testing.T) { + e := echo.New() + + // Gzip chunked + chunkBuf := make([]byte, 5) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + + var r *gzip.Reader = nil + + c := e.NewContext(req, rec) + next := func(c *echo.Context) error { + rc := http.NewResponseController(c.Response()) + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Transfer-Encoding", "chunked") + + // Write and flush the first part of the data + c.Response().Write([]byte("test\n")) + rc.Flush() + + // Read the first part of the data + assert.True(t, rec.Flushed) + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + + var err error + r, err = gzip.NewReader(rec.Body) + assert.NoError(t, err) + + _, err = io.ReadFull(r, chunkBuf) + assert.NoError(t, err) + assert.Equal(t, "test\n", string(chunkBuf)) + + // Write and flush the second part of the data + c.Response().Write([]byte("test\n")) + rc.Flush() + + _, err = io.ReadFull(r, chunkBuf) + assert.NoError(t, err) + assert.Equal(t, "test\n", string(chunkBuf)) + + // Write the final part of the data and return + c.Response().Write([]byte("test")) + return nil + } + err := GzipWithConfig(GzipConfig{MinLength: 10})(next)(c) + + assert.NoError(t, err) + assert.NotNil(t, r) + + buf := new(bytes.Buffer) + + buf.ReadFrom(r) + assert.Equal(t, "test", buf.String()) + + r.Close() +} + +func TestGzipWithMinLengthNoContent(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c *echo.Context) error { + return c.NoContent(http.StatusNoContent) + }) + if assert.NoError(t, h(c)) { + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) + assert.Equal(t, 0, len(rec.Body.Bytes())) + } +} + +func TestGzipResponseWriter_CanUnwrap(t *testing.T) { + trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()} + bdrw := gzipResponseWriter{ + ResponseWriter: trwu, + } + result := bdrw.Unwrap() + assert.Equal(t, trwu, result) +} + +func TestGzipResponseWriter_CanHijack(t *testing.T) { + trwu := testResponseWriterUnwrapperHijack{testResponseWriterUnwrapper: testResponseWriterUnwrapper{rw: httptest.NewRecorder()}} + bdrw := gzipResponseWriter{ + ResponseWriter: &trwu, // this RW supports hijacking through unwrapping + } + _, _, err := bdrw.Hijack() + assert.EqualError(t, err, "can hijack") +} + +func TestGzipResponseWriter_CanNotHijack(t *testing.T) { + trwu := testResponseWriterUnwrapper{rw: httptest.NewRecorder()} + bdrw := gzipResponseWriter{ + ResponseWriter: &trwu, // this RW supports hijacking through unwrapping + } + _, _, err := bdrw.Hijack() + assert.EqualError(t, err, "feature not supported") +} + +func BenchmarkGzip(b *testing.B) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + + h := Gzip()(func(c *echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // Gzip + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h(c) + } +} diff --git a/middleware/context_timeout.go b/middleware/context_timeout.go new file mode 100644 index 000000000..68465199a --- /dev/null +++ b/middleware/context_timeout.go @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package middleware + +import ( + "context" + "errors" + "time" + + "github.com/labstack/echo/v5" +) + +// ContextTimeoutConfig defines the config for ContextTimeout middleware. +type ContextTimeoutConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // ErrorHandler is a function when error arises in middeware execution. + ErrorHandler func(c *echo.Context, err error) error + + // Timeout configures a timeout for the middleware + Timeout time.Duration +} + +// ContextTimeout returns a middleware which returns error (503 Service Unavailable error) to client +// when underlying method returns context.DeadlineExceeded error. +func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc { + return ContextTimeoutWithConfig(ContextTimeoutConfig{Timeout: timeout}) +} + +// ContextTimeoutWithConfig returns a Timeout middleware with config. +func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts Config to middleware. +func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) { + if config.Timeout == 0 { + return nil, errors.New("timeout must be set") + } + if config.Skipper == nil { + config.Skipper = DefaultSkipper + } + if config.ErrorHandler == nil { + config.ErrorHandler = func(c *echo.Context, err error) error { + if err != nil && errors.Is(err, context.DeadlineExceeded) { + return echo.ErrServiceUnavailable.Wrap(err) + } + return err + } + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c *echo.Context) error { + if config.Skipper(c) { + return next(c) + } + + timeoutContext, cancel := context.WithTimeout(c.Request().Context(), config.Timeout) + defer cancel() + + c.SetRequest(c.Request().WithContext(timeoutContext)) + + if err := next(c); err != nil { + return config.ErrorHandler(c, err) + } + return nil + } + }, nil +} diff --git a/middleware/context_timeout_test.go b/middleware/context_timeout_test.go new file mode 100644 index 000000000..c7ba76beb --- /dev/null +++ b/middleware/context_timeout_test.go @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package middleware + +import ( + "context" + "errors" + "github.com/labstack/echo/v5" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestContextTimeoutSkipper(t *testing.T) { + t.Parallel() + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + Skipper: func(context *echo.Context) bool { + return true + }, + Timeout: 10 * time.Millisecond, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c *echo.Context) error { + if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil { + return err + } + + return errors.New("response from handler") + })(c) + + // if not skipped we would have not returned error due context timeout logic + assert.EqualError(t, err, "response from handler") +} + +func TestContextTimeoutWithTimeout0(t *testing.T) { + t.Parallel() + assert.Panics(t, func() { + ContextTimeout(time.Duration(0)) + }) +} + +func TestContextTimeoutErrorOutInHandler(t *testing.T) { + t.Parallel() + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + // Timeout has to be defined or the whole flow for timeout middleware will be skipped + Timeout: 10 * time.Millisecond, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + rec.Code = 1 // we want to be sure that even 200 will not be sent + err := m(func(c *echo.Context) error { + // this error must not be written to the client response. Middlewares upstream of timeout middleware must be able + // to handle returned error and this can be done only then handler has not yet committed (written status code) + // the response. + return echo.NewHTTPError(http.StatusTeapot, "err") + })(c) + + assert.Error(t, err) + assert.EqualError(t, err, "code=418, message=err") + assert.Equal(t, 1, rec.Code) + assert.Equal(t, "", rec.Body.String()) +} + +func TestContextTimeoutSuccessfulRequest(t *testing.T) { + t.Parallel() + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + // Timeout has to be defined or the whole flow for timeout middleware will be skipped + Timeout: 10 * time.Millisecond, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c *echo.Context) error { + return c.JSON(http.StatusCreated, map[string]string{"data": "ok"}) + })(c) + + assert.NoError(t, err) + assert.Equal(t, http.StatusCreated, rec.Code) + assert.Equal(t, "{\"data\":\"ok\"}\n", rec.Body.String()) +} + +func TestContextTimeoutTestRequestClone(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(http.MethodPost, "/uri?query=value", strings.NewReader(url.Values{"form": {"value"}}.Encode())) + req.AddCookie(&http.Cookie{Name: "cookie", Value: "value"}) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + // Timeout has to be defined or the whole flow for timeout middleware will be skipped + Timeout: 1 * time.Second, + }) + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c *echo.Context) error { + // Cookie test + cookie, err := c.Request().Cookie("cookie") + if assert.NoError(t, err) { + assert.EqualValues(t, "cookie", cookie.Name) + assert.EqualValues(t, "value", cookie.Value) + } + + // Form values + if assert.NoError(t, c.Request().ParseForm()) { + assert.EqualValues(t, "value", c.Request().FormValue("form")) + } + + // Query string + assert.EqualValues(t, "value", c.Request().URL.Query()["query"][0]) + return nil + })(c) + + assert.NoError(t, err) +} + +func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) { + t.Parallel() + + timeout := 10 * time.Millisecond + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + Timeout: timeout, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c *echo.Context) error { + if err := sleepWithContext(c.Request().Context(), time.Duration(80*time.Millisecond)); err != nil { + return err + } + return c.String(http.StatusOK, "Hello, World!") + })(c) + + assert.Error(t, err) + if assert.IsType(t, &echo.HTTPError{}, err) { + assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code) + assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message) + } +} + +func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) { + t.Parallel() + + timeoutErrorHandler := func(c *echo.Context, err error) error { + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return &echo.HTTPError{ + Code: http.StatusServiceUnavailable, + Message: "Timeout! change me", + } + } + return err + } + return nil + } + + timeout := 50 * time.Millisecond + m := ContextTimeoutWithConfig(ContextTimeoutConfig{ + Timeout: timeout, + ErrorHandler: timeoutErrorHandler, + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e := echo.New() + c := e.NewContext(req, rec) + + err := m(func(c *echo.Context) error { + // NOTE: Very short periods are not reliable for tests due to Go routine scheduling and the unpredictable order + // for 1) request and 2) time goroutine. For most OS this works as expected, but MacOS seems most flaky. + + if err := sleepWithContext(c.Request().Context(), 100*time.Millisecond); err != nil { + return err + } + + // The Request Context should have a Deadline set by http.ContextTimeoutHandler + if _, ok := c.Request().Context().Deadline(); !ok { + assert.Fail(t, "No timeout set on Request Context") + } + return c.String(http.StatusOK, "Hello, World!") + })(c) + + assert.IsType(t, &echo.HTTPError{}, err) + assert.Error(t, err) + assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code) + assert.Equal(t, "Timeout! change me", err.(*echo.HTTPError).Message) +} + +func sleepWithContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + + defer func() { + _ = timer.Stop() + }() + + select { + case <-ctx.Done(): + return context.DeadlineExceeded + case <-timer.C: + return nil + } +} diff --git a/middleware/cors.go b/middleware/cors.go index 5dfe31f95..96ed16985 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -1,88 +1,190 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( + "errors" + "fmt" "net/http" "strconv" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // CORSConfig defines the config for CORS middleware. - CORSConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // AllowOrigin defines a list of origins that may access the resource. - // Optional. Default value []string{"*"}. - AllowOrigins []string `yaml:"allow_origins"` - - // AllowMethods defines a list methods allowed when accessing the resource. - // This is used in response to a preflight request. - // Optional. Default value DefaultCORSConfig.AllowMethods. - AllowMethods []string `yaml:"allow_methods"` - - // AllowHeaders defines a list of request headers that can be used when - // making the actual request. This is in response to a preflight request. - // Optional. Default value []string{}. - AllowHeaders []string `yaml:"allow_headers"` - - // AllowCredentials indicates whether or not the response to the request - // can be exposed when the credentials flag is true. When used as part of - // a response to a preflight request, this indicates whether or not the - // actual request can be made using credentials. - // Optional. Default value false. - AllowCredentials bool `yaml:"allow_credentials"` - - // ExposeHeaders defines a whitelist headers that clients are allowed to - // access. - // Optional. Default value []string{}. - ExposeHeaders []string `yaml:"expose_headers"` - - // MaxAge indicates how long (in seconds) the results of a preflight request - // can be cached. - // Optional. Default value 0. - MaxAge int `yaml:"max_age"` - } -) +// CORSConfig defines the config for CORS middleware. +type CORSConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper -var ( - // DefaultCORSConfig is the default CORS middleware config. - DefaultCORSConfig = CORSConfig{ - Skipper: DefaultSkipper, - AllowOrigins: []string{"*"}, - AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, - } -) + // AllowOrigins determines the value of the Access-Control-Allow-Origin + // response header. This header defines a list of origins that may access the + // resource. + // + // Origin consist of following parts: `scheme + "://" + host + optional ":" + port` + // Wildcard can be used, but has to be set explicitly []string{"*"} + // Example: `https://example.com`, `http://example.com:8080`, `*` + // + // Security: use extreme caution when handling the origin, and carefully + // validate any logic. Remember that attackers may register hostile domain names. + // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin + // + // Mandatory. + AllowOrigins []string + + // UnsafeAllowOriginFunc is an optional custom function to validate the origin. It takes the + // origin as an argument and returns + // - string, allowed origin + // - bool, true if allowed or false otherwise. + // - error, if an error is returned, it is returned immediately by the handler. + // If this option is set, AllowOrigins is ignored. + // + // Security: use extreme caution when handling the origin, and carefully + // validate any logic. Remember that attackers may register hostile (sub)domain names. + // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // + // Sub-domain checks example: + // UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) { + // if strings.HasSuffix(origin, ".example.com") { + // return origin, true, nil + // } + // return "", false, nil + // }, + // + // Optional. + UnsafeAllowOriginFunc func(c *echo.Context, origin string) (allowedOrigin string, allowed bool, err error) + + // AllowMethods determines the value of the Access-Control-Allow-Methods + // response header. This header specified the list of methods allowed when + // accessing the resource. This is used in response to a preflight request. + // + // Optional. Default value DefaultCORSConfig.AllowMethods. + // If `allowMethods` is left empty, this middleware will fill for preflight + // request `Access-Control-Allow-Methods` header value + // from `Allow` header that echo.Router set into context. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + AllowMethods []string + + // AllowHeaders determines the value of the Access-Control-Allow-Headers + // response header. This header is used in response to a preflight request to + // indicate which HTTP headers can be used when making the actual request. + // + // Optional. Defaults to empty list. No domains allowed for CORS. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + AllowHeaders []string + + // AllowCredentials determines the value of the + // Access-Control-Allow-Credentials response header. This header indicates + // whether or not the response to the request can be exposed when the + // credentials mode (Request.credentials) is true. When used as part of a + // response to a preflight request, this indicates whether or not the actual + // request can be made using credentials. See also + // [MDN: Access-Control-Allow-Credentials]. + // + // Optional. Default value false, in which case the header is not set. + // + // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. + // See "Exploiting CORS misconfigurations for Bitcoins and bounties", + // https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials + AllowCredentials bool + + // ExposeHeaders determines the value of Access-Control-Expose-Headers, which + // defines a list of headers that clients are allowed to access. + // + // Optional. Default value []string{}, in which case the header is not set. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header + ExposeHeaders []string + + // MaxAge determines the value of the Access-Control-Max-Age response header. + // This header indicates how long (in seconds) the results of a preflight + // request can be cached. + // The header is set only if MaxAge != 0, negative value sends "0" which instructs browsers not to cache that response. + // + // Optional. Default value 0 - meaning header is not sent. + // + // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age + MaxAge int +} // CORS returns a Cross-Origin Resource Sharing (CORS) middleware. -// See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS -func CORS() echo.MiddlewareFunc { - return CORSWithConfig(DefaultCORSConfig) +// See also [MDN: Cross-Origin Resource Sharing (CORS)]. +// +// Origin consist of following parts: `scheme + "://" + host + optional ":" + port` +// Wildcard `*` can be used, but has to be set explicitly. +// Example: `https://example.com`, `http://example.com:8080`, `*` +// +// Security: Poorly configured CORS can compromise security because it allows +// relaxation of the browser's Same-Origin policy. See [Exploiting CORS +// misconfigurations for Bitcoins and bounties] and [Portswigger: Cross-origin +// resource sharing (CORS)] for more details. +// +// [MDN: Cross-Origin Resource Sharing (CORS)]: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS +// [Exploiting CORS misconfigurations for Bitcoins and bounties]: https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html +// [Portswigger: Cross-origin resource sharing (CORS)]: https://portswigger.net/web-security/cors +func CORS(allowOrigins ...string) echo.MiddlewareFunc { + c := CORSConfig{ + AllowOrigins: allowOrigins, + } + return CORSWithConfig(c) } -// CORSWithConfig returns a CORS middleware with config. -// See: `CORS()`. +// CORSWithConfig returns a CORS middleware with config or panics on invalid configuration. +// See: [CORS]. func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration +func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { - config.Skipper = DefaultCORSConfig.Skipper - } - if len(config.AllowOrigins) == 0 { - config.AllowOrigins = DefaultCORSConfig.AllowOrigins + config.Skipper = DefaultSkipper } + hasCustomAllowMethods := true if len(config.AllowMethods) == 0 { - config.AllowMethods = DefaultCORSConfig.AllowMethods + hasCustomAllowMethods = false + config.AllowMethods = []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete} } allowMethods := strings.Join(config.AllowMethods, ",") allowHeaders := strings.Join(config.AllowHeaders, ",") exposeHeaders := strings.Join(config.ExposeHeaders, ",") - maxAge := strconv.Itoa(config.MaxAge) + + maxAge := "0" + if config.MaxAge > 0 { + maxAge = strconv.Itoa(config.MaxAge) + } + + allowOriginFunc := config.UnsafeAllowOriginFunc + if config.UnsafeAllowOriginFunc == nil { + if len(config.AllowOrigins) == 0 { + return nil, errors.New("at least one AllowOrigins is required or UnsafeAllowOriginFunc must be provided") + } + allowOriginFunc = config.defaultAllowOriginFunc + for _, origin := range config.AllowOrigins { + if origin == "*" { + if config.AllowCredentials { + return nil, fmt.Errorf("* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc") + } + allowOriginFunc = config.starAllowOriginFunc + break + } + if err := validateOrigin(origin, "allow origin"); err != nil { + return nil, err + } + } + config.AllowOrigins = append([]string(nil), config.AllowOrigins...) + } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } @@ -90,46 +192,84 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { req := c.Request() res := c.Response() origin := req.Header.Get(echo.HeaderOrigin) - allowOrigin := "" - // Check allowed origins - for _, o := range config.AllowOrigins { - if o == "*" && config.AllowCredentials { - allowOrigin = origin - break - } - if o == "*" || o == origin { - allowOrigin = o - break + res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) + + // Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method, + // Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request + // For simplicity we just consider method type and later `Origin` header. + preflight := req.Method == http.MethodOptions + + // Although router adds special handler in case of OPTIONS method we avoid calling next for OPTIONS in this middleware + // as CORS requests do not have cookies / authentication headers by default, so we could get stuck in auth + // middlewares by calling next(c). + // But we still want to send `Allow` header as response in case of Non-CORS OPTIONS request as router default + // handler does. + routerAllowMethods := "" + if preflight { + tmpAllowMethods, ok := c.Get(echo.ContextKeyHeaderAllow).(string) + if ok && tmpAllowMethods != "" { + routerAllowMethods = tmpAllowMethods + c.Response().Header().Set(echo.HeaderAllow, routerAllowMethods) } - if matchSubdomain(origin, o) { - allowOrigin = origin - break + } + + // No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain + if origin == "" { + if preflight { // req.Method=OPTIONS + return c.NoContent(http.StatusNoContent) } + return next(c) // let non-browser calls through } - // Simple request - if req.Method != http.MethodOptions { - res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) - res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) - if config.AllowCredentials { - res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") + allowedOrigin, allowed, err := allowOriginFunc(c, origin) + if err != nil { + return err + } + if !allowed { + // Origin existed and was NOT allowed + if preflight { + // From: https://github.com/labstack/echo/issues/2767 + // If the request's origin isn't allowed by the CORS configuration, + // the middleware should simply omit the relevant CORS headers from the response + // and let the browser fail the CORS check (if any). + return c.NoContent(http.StatusNoContent) } + // From: https://github.com/labstack/echo/issues/2767 + // no CORS middleware should block non-preflight requests; + // such requests should be let through. One reason is that not all requests that + // carry an Origin header participate in the CORS protocol. + return next(c) + } + + // Origin existed and was allowed + + res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) + if config.AllowCredentials { + res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") + } + + // Simple request will be let though + if !preflight { if exposeHeaders != "" { res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders) } return next(c) } - - // Preflight request - res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) + // Below code is for Preflight (OPTIONS) request + // + // Preflight will end with c.NoContent(http.StatusNoContent) as we do not know if + // at the end of handler chain is actual OPTIONS route or 404/405 route which + // response code will confuse browsers res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) - res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) - res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods) - if config.AllowCredentials { - res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") + + if !hasCustomAllowMethods && routerAllowMethods != "" { + res.Header().Set(echo.HeaderAccessControlAllowMethods, routerAllowMethods) + } else { + res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods) } + if allowHeaders != "" { res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders) } else { @@ -138,10 +278,23 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { res.Header().Set(echo.HeaderAccessControlAllowHeaders, h) } } - if config.MaxAge > 0 { + if config.MaxAge != 0 { res.Header().Set(echo.HeaderAccessControlMaxAge, maxAge) } return c.NoContent(http.StatusNoContent) } + }, nil +} + +func (config CORSConfig) starAllowOriginFunc(c *echo.Context, origin string) (string, bool, error) { + return "*", true, nil +} + +func (config CORSConfig) defaultAllowOriginFunc(c *echo.Context, origin string) (string, bool, error) { + for _, allowedOrigin := range config.AllowOrigins { + if strings.EqualFold(allowedOrigin, origin) { + return allowedOrigin, true, nil + } } + return "", false, nil } diff --git a/middleware/cors_test.go b/middleware/cors_test.go index acfdf47bc..5de4ca063 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -1,85 +1,628 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( + "cmp" + "errors" "net/http" "net/http/httptest" + "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestCORS(t *testing.T) { e := echo.New() - - // Wildcard origin - req := httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequest(http.MethodOptions, "/", nil) // Preflight request + req.Header.Set(echo.HeaderOrigin, "http://example.com") rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := CORS()(echo.NotFoundHandler) - h(c) - assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - // Allow origins - req = httptest.NewRequest(http.MethodGet, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"localhost"}, - })(echo.NotFoundHandler) - req.Header.Set(echo.HeaderOrigin, "localhost") - h(c) - assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - - // Preflight request - req = httptest.NewRequest(http.MethodOptions, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - req.Header.Set(echo.HeaderOrigin, "localhost") - req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) - cors := CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"localhost"}, - AllowCredentials: true, - MaxAge: 3600, + mw := CORS("*") + handler := mw(func(c *echo.Context) error { + return nil }) - h = cors(echo.NotFoundHandler) - h(c) - assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) - assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) - assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge)) - - // Preflight request with `AllowOrigins` * - req = httptest.NewRequest(http.MethodOptions, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - req.Header.Set(echo.HeaderOrigin, "localhost") - req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) - cors = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"*"}, - AllowCredentials: true, - MaxAge: 3600, - }) - h = cors(echo.NotFoundHandler) - h(c) - assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) - assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials)) - assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge)) - - // Preflight request with `AllowOrigins` which allow all subdomains with * - req = httptest.NewRequest(http.MethodOptions, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - req.Header.Set(echo.HeaderOrigin, "http://aaa.example.com") - cors = CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"http://*.example.com"}, - }) - h = cors(echo.NotFoundHandler) - h(c) - assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - req.Header.Set(echo.HeaderOrigin, "http://bbb.example.com") - h(c) - assert.Equal(t, "http://bbb.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + err := handler(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusNoContent, rec.Code) + assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) +} + +func TestCORSConfig(t *testing.T) { + var testCases = []struct { + name string + givenConfig *CORSConfig + whenMethod string + whenHeaders map[string]string + expectHeaders map[string]string + notExpectHeaders map[string]string + expectErr string + }{ + { + name: "ok, wildcard origin", + givenConfig: &CORSConfig{ + AllowOrigins: []string{"*"}, + }, + whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"}, + expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "*"}, + }, + { + name: "ok, wildcard AllowedOrigin with no Origin header in request", + givenConfig: &CORSConfig{ + AllowOrigins: []string{"*"}, + }, + notExpectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: ""}, + }, + { + name: "ok, specific AllowOrigins and AllowCredentials", + givenConfig: &CORSConfig{ + AllowOrigins: []string{"http://localhost", "http://localhost:8080"}, + AllowCredentials: true, + MaxAge: 3600, + }, + whenHeaders: map[string]string{echo.HeaderOrigin: "http://localhost"}, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "http://localhost", + echo.HeaderAccessControlAllowCredentials: "true", + }, + }, + { + name: "ok, preflight request with matching origin for `AllowOrigins`", + givenConfig: &CORSConfig{ + AllowOrigins: []string{"http://localhost"}, + AllowCredentials: true, + MaxAge: 3600, + }, + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "http://localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "http://localhost", + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + echo.HeaderAccessControlAllowCredentials: "true", + echo.HeaderAccessControlMaxAge: "3600", + }, + }, + { + name: "ok, preflight request when `Access-Control-Max-Age` is set", + givenConfig: &CORSConfig{ + AllowOrigins: []string{"http://localhost"}, + AllowCredentials: true, + MaxAge: 1, + }, + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "http://localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlMaxAge: "1", + }, + }, + { + name: "ok, preflight request when `Access-Control-Max-Age` is set to 0 - not to cache response", + givenConfig: &CORSConfig{ + AllowOrigins: []string{"http://localhost"}, + AllowCredentials: true, + MaxAge: -1, // forces `Access-Control-Max-Age: 0` + }, + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "http://localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlMaxAge: "0", + }, + }, + { + name: "ok, CORS check are skipped", + givenConfig: &CORSConfig{ + AllowOrigins: []string{"http://localhost"}, + AllowCredentials: true, + Skipper: func(c *echo.Context) bool { + return true + }, + }, + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "http://localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + notExpectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "localhost", + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + echo.HeaderAccessControlAllowCredentials: "true", + echo.HeaderAccessControlMaxAge: "3600", + }, + }, + { + name: "nok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true", + givenConfig: &CORSConfig{ + AllowOrigins: []string{"*"}, + AllowCredentials: true, + MaxAge: 3600, + }, + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectErr: `* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc`, + }, + { + name: "nok, preflight request with invalid `AllowOrigins` value", + givenConfig: &CORSConfig{ + AllowOrigins: []string{"http://server", "missing-scheme"}, + }, + expectErr: `allow origin is missing scheme or host: missing-scheme`, + }, + { + name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` false", + givenConfig: &CORSConfig{ + AllowOrigins: []string{"*"}, + AllowCredentials: false, // important for this testcase + MaxAge: 3600, + }, + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "*", + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + echo.HeaderAccessControlMaxAge: "3600", + }, + notExpectHeaders: map[string]string{ + echo.HeaderAccessControlAllowCredentials: "", + }, + }, + { + name: "ok, INSECURE preflight request with wildcard `AllowOrigins` and `AllowCredentials` true", + givenConfig: &CORSConfig{ + AllowOrigins: []string{"*"}, + AllowCredentials: true, + MaxAge: 3600, + }, + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + }, + expectErr: `* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc`, + }, + { + name: "ok, preflight request with Access-Control-Request-Headers", + givenConfig: &CORSConfig{ + AllowOrigins: []string{"*"}, + }, + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{ + echo.HeaderOrigin: "localhost", + echo.HeaderContentType: echo.MIMEApplicationJSON, + echo.HeaderAccessControlRequestHeaders: "Special-Request-Header", + }, + expectHeaders: map[string]string{ + echo.HeaderAccessControlAllowOrigin: "*", + echo.HeaderAccessControlAllowHeaders: "Special-Request-Header", + echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + }, + }, + { + name: "ok, preflight request with `AllowOrigins` which allow all subdomains aaa with *", + givenConfig: &CORSConfig{ + UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (allowedOrigin string, allowed bool, err error) { + if strings.HasSuffix(origin, ".example.com") { + allowed = true + } + return origin, allowed, nil + }, + }, + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{echo.HeaderOrigin: "http://aaa.example.com"}, + expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://aaa.example.com"}, + }, + { + name: "ok, preflight request with `AllowOrigins` which allow all subdomains bbb with *", + givenConfig: &CORSConfig{ + UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) { + if strings.HasSuffix(origin, ".example.com") { + return origin, true, nil + } + return "", false, nil + }, + }, + whenMethod: http.MethodOptions, + whenHeaders: map[string]string{echo.HeaderOrigin: "http://bbb.example.com"}, + expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://bbb.example.com"}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + var mw echo.MiddlewareFunc + var err error + if tc.givenConfig != nil { + mw, err = tc.givenConfig.ToMiddleware() + } else { + mw, err = CORSConfig{}.ToMiddleware() + } + if err != nil { + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + return + } + t.Fatal(err) + } + + h := mw(func(c *echo.Context) error { + return nil + }) + + method := cmp.Or(tc.whenMethod, http.MethodGet) + req := httptest.NewRequest(method, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + for k, v := range tc.whenHeaders { + req.Header.Set(k, v) + } + + err = h(c) + + assert.NoError(t, err) + header := rec.Header() + for k, v := range tc.expectHeaders { + assert.Equal(t, v, header.Get(k), "header: `%v` should be `%v`", k, v) + } + for k, v := range tc.notExpectHeaders { + if v == "" { + assert.Len(t, header.Values(k), 0, "header: `%v` should not be set", k) + } else { + assert.NotEqual(t, v, header.Get(k), "header: `%v` should not be `%v`", k, v) + } + } + }) + } +} + +func Test_allowOriginScheme(t *testing.T) { + tests := []struct { + domain, pattern string + expected bool + }{ + { + domain: "http://example.com", + pattern: "http://example.com", + expected: true, + }, + { + domain: "https://example.com", + pattern: "https://example.com", + expected: true, + }, + { + domain: "http://example.com", + pattern: "https://example.com", + expected: false, + }, + { + domain: "https://example.com", + pattern: "http://example.com", + expected: false, + }, + } + + e := echo.New() + for _, tt := range tests { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + req.Header.Set(echo.HeaderOrigin, tt.domain) + cors := CORSWithConfig(CORSConfig{ + AllowOrigins: []string{tt.pattern}, + }) + h := cors(func(c *echo.Context) error { return echo.ErrNotFound }) + h(c) + + if tt.expected { + assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } else { + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + } + } +} + +func TestCORSWithConfig_AllowMethods(t *testing.T) { + var testCases = []struct { + name string + givenAllowOrigins []string + givenAllowMethods []string + whenAllowContextKey string + whenOrigin string + expectAllow string + expectAccessControlAllowMethods string + }{ + { + name: "custom AllowMethods, preflight, no origin, sets only allow header from context key", + givenAllowOrigins: []string{"*"}, + givenAllowMethods: []string{http.MethodGet, http.MethodHead}, + whenAllowContextKey: "OPTIONS, GET", + whenOrigin: "", + expectAllow: "OPTIONS, GET", + }, + { + name: "default AllowMethods, preflight, no origin, no allow header in context key and in response", + givenAllowOrigins: []string{"*"}, + givenAllowMethods: nil, + whenAllowContextKey: "", + whenOrigin: "", + expectAllow: "", + }, + { + name: "custom AllowMethods, preflight, existing origin, sets both headers different values", + givenAllowOrigins: []string{"*"}, + givenAllowMethods: []string{http.MethodGet, http.MethodHead}, + whenAllowContextKey: "OPTIONS, GET", + whenOrigin: "http://google.com", + expectAllow: "OPTIONS, GET", + expectAccessControlAllowMethods: "GET,HEAD", + }, + { + name: "default AllowMethods, preflight, existing origin, sets both headers", + givenAllowOrigins: []string{"*"}, + givenAllowMethods: nil, + whenAllowContextKey: "OPTIONS, GET", + whenOrigin: "http://google.com", + expectAllow: "OPTIONS, GET", + expectAccessControlAllowMethods: "OPTIONS, GET", + }, + { + name: "default AllowMethods, preflight, existing origin, no allows, sets only CORS allow methods", + givenAllowOrigins: []string{"*"}, + givenAllowMethods: nil, + whenAllowContextKey: "", + whenOrigin: "http://google.com", + expectAllow: "", + expectAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + e.GET("/test", func(c *echo.Context) error { + return c.String(http.StatusOK, "OK") + }) + + cors := CORSWithConfig(CORSConfig{ + AllowOrigins: tc.givenAllowOrigins, + AllowMethods: tc.givenAllowMethods, + }) + + req := httptest.NewRequest(http.MethodOptions, "/test", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + req.Header.Set(echo.HeaderOrigin, tc.whenOrigin) + if tc.whenAllowContextKey != "" { + c.Set(echo.ContextKeyHeaderAllow, tc.whenAllowContextKey) + } + + h := cors(func(c *echo.Context) error { + return c.String(http.StatusOK, "OK") + }) + h(c) + + assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow)) + assert.Equal(t, tc.expectAccessControlAllowMethods, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) + }) + } +} + +func TestCorsHeaders(t *testing.T) { + tests := []struct { + name string + originDomain string + method string + allowedOrigin string + expected bool + expectStatus int + expectAllowHeader string + }{ + { + name: "non-preflight request, allow any origin, missing origin header = no CORS logic done", + originDomain: "", + allowedOrigin: "*", + method: http.MethodGet, + expected: false, + expectStatus: http.StatusOK, + }, + { + name: "non-preflight request, allow any origin, specific origin domain", + originDomain: "http://example.com", + allowedOrigin: "*", + method: http.MethodGet, + expected: true, + expectStatus: http.StatusOK, + }, + { + name: "non-preflight request, allow specific origin, missing origin header = no CORS logic done", + originDomain: "", // Request does not have Origin header + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: false, + expectStatus: http.StatusOK, + }, + { + name: "non-preflight request, allow specific origin, different origin header = CORS logic failure", + originDomain: "http://bar.com", + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: false, + expectStatus: http.StatusOK, + }, + { + name: "non-preflight request, allow specific origin, matching origin header = CORS logic done", + originDomain: "http://example.com", + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: true, + expectStatus: http.StatusOK, + }, + { + name: "preflight, allow any origin, missing origin header = no CORS logic done", + originDomain: "", // Request does not have Origin header + allowedOrigin: "*", + method: http.MethodOptions, + expected: false, + expectStatus: http.StatusNoContent, + expectAllowHeader: "OPTIONS, GET, POST", + }, + { + name: "preflight, allow any origin, existing origin header = CORS logic done", + originDomain: "http://example.com", + allowedOrigin: "*", + method: http.MethodOptions, + expected: true, + expectStatus: http.StatusNoContent, + expectAllowHeader: "OPTIONS, GET, POST", + }, + { + name: "preflight, allow any origin, missing origin header = no CORS logic done", + originDomain: "", // Request does not have Origin header + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: false, + expectStatus: http.StatusNoContent, + expectAllowHeader: "OPTIONS, GET, POST", + }, + { + name: "preflight, allow specific origin, different origin header = no CORS logic done", + originDomain: "http://bar.com", + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: false, + expectStatus: http.StatusNoContent, + expectAllowHeader: "OPTIONS, GET, POST", + }, + { + name: "preflight, allow specific origin, matching origin header = CORS logic done", + originDomain: "http://example.com", + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: true, + expectStatus: http.StatusNoContent, + expectAllowHeader: "OPTIONS, GET, POST", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + e.Use(CORSWithConfig(CORSConfig{ + AllowOrigins: []string{tc.allowedOrigin}, + //AllowCredentials: true, + //MaxAge: 3600, + })) + + e.GET("/", func(c *echo.Context) error { + return c.String(http.StatusOK, "OK") + }) + e.POST("/", func(c *echo.Context) error { + return c.String(http.StatusCreated, "OK") + }) + + req := httptest.NewRequest(tc.method, "/", nil) + rec := httptest.NewRecorder() + + if tc.originDomain != "" { + req.Header.Set(echo.HeaderOrigin, tc.originDomain) + } + + // we run through whole Echo handler chain to see how CORS works with Router OPTIONS handler + e.ServeHTTP(rec, req) + + assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary)) + assert.Equal(t, tc.expectAllowHeader, rec.Header().Get(echo.HeaderAllow)) + assert.Equal(t, tc.expectStatus, rec.Code) + + expectedAllowOrigin := "" + if tc.allowedOrigin == "*" { + expectedAllowOrigin = "*" + } else { + expectedAllowOrigin = tc.originDomain + } + switch { + case tc.expected && tc.method == http.MethodOptions: + assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods) + assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + + assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary])) + + case tc.expected && tc.method == http.MethodGet: + assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin + default: + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin + } + }) + + } +} + +func Test_allowOriginFunc(t *testing.T) { + returnTrue := func(c *echo.Context, origin string) (string, bool, error) { + return origin, true, nil + } + returnFalse := func(c *echo.Context, origin string) (string, bool, error) { + return origin, false, nil + } + returnError := func(c *echo.Context, origin string) (string, bool, error) { + return origin, true, errors.New("this is a test error") + } + + allowOriginFuncs := []func(c *echo.Context, origin string) (string, bool, error){ + returnTrue, + returnFalse, + returnError, + } + + const origin = "http://example.com" + + e := echo.New() + for _, allowOriginFunc := range allowOriginFuncs { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + req.Header.Set(echo.HeaderOrigin, origin) + cors, err := CORSConfig{UnsafeAllowOriginFunc: allowOriginFunc}.ToMiddleware() + assert.NoError(t, err) + + h := cors(func(c *echo.Context) error { return echo.ErrNotFound }) + err = h(c) + + allowedOrigin, allowed, expectedErr := allowOriginFunc(c, origin) + if expectedErr != nil { + assert.Equal(t, expectedErr, err) + assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + continue + } + + if allowed { + assert.Equal(t, allowedOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } else { + assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } + } } diff --git a/middleware/csrf.go b/middleware/csrf.go index 09a66bb64..33757b760 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -1,91 +1,120 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( "crypto/subtle" - "errors" "net/http" + "slices" "strings" "time" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" + "github.com/labstack/echo/v5" ) -type ( - // CSRFConfig defines the config for CSRF middleware. - CSRFConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // TokenLength is the length of the generated token. - TokenLength uint8 `yaml:"token_length"` - // Optional. Default value 32. - - // TokenLookup is a string in the form of ":" that is used - // to extract token from the request. - // Optional. Default value "header:X-CSRF-Token". - // Possible values: - // - "header:" - // - "form:" - // - "query:" - TokenLookup string `yaml:"token_lookup"` - - // Context key to store generated CSRF token into context. - // Optional. Default value "csrf". - ContextKey string `yaml:"context_key"` - - // Name of the CSRF cookie. This cookie will store CSRF token. - // Optional. Default value "csrf". - CookieName string `yaml:"cookie_name"` - - // Domain of the CSRF cookie. - // Optional. Default value none. - CookieDomain string `yaml:"cookie_domain"` - - // Path of the CSRF cookie. - // Optional. Default value none. - CookiePath string `yaml:"cookie_path"` - - // Max age (in seconds) of the CSRF cookie. - // Optional. Default value 86400 (24hr). - CookieMaxAge int `yaml:"cookie_max_age"` - - // Indicates if CSRF cookie is secure. - // Optional. Default value false. - CookieSecure bool `yaml:"cookie_secure"` - - // Indicates if CSRF cookie is HTTP only. - // Optional. Default value false. - CookieHTTPOnly bool `yaml:"cookie_http_only"` - } - - // csrfTokenExtractor defines a function that takes `echo.Context` and returns - // either a token or an error. - csrfTokenExtractor func(echo.Context) (string, error) -) +// CSRFConfig defines the config for CSRF middleware. +type CSRFConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + // TrustedOrigin permits any request with `Sec-Fetch-Site` header whose `Origin` header + // exactly matches the specified value. + // Values should be formated as Origin header "scheme://host[:port]". + // + // See [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin + // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers + TrustedOrigins []string -var ( - // DefaultCSRFConfig is the default CSRF middleware config. - DefaultCSRFConfig = CSRFConfig{ - Skipper: DefaultSkipper, - TokenLength: 32, - TokenLookup: "header:" + echo.HeaderXCSRFToken, - ContextKey: "csrf", - CookieName: "_csrf", - CookieMaxAge: 86400, - } -) + // AllowSecFetchSameSite allows custom behaviour for `Sec-Fetch-Site` requests that are about to + // fail with CRSF error, to be allowed or replaced with custom error. + // This function applies to `Sec-Fetch-Site` values: + // - `same-site` same registrable domain (subdomain and/or different port) + // - `cross-site` request originates from different site + // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers + AllowSecFetchSiteFunc func(c *echo.Context) (bool, error) + + // TokenLength is the length of the generated token. + TokenLength uint8 + // Optional. Default value 32. + + // TokenLookup is a string in the form of ":" or ":,:" that is used + // to extract token from the request. + // Optional. Default value "header:X-CSRF-Token". + // Possible values: + // - "header:" or "header::" + // - "query:" + // - "form:" + // Multiple sources example: + // - "header:X-CSRF-Token,query:csrf" + TokenLookup string `yaml:"token_lookup"` + + // Generator defines a function to generate token. + // Optional. Defaults tp randomString(TokenLength). + Generator func() string + + // Context key to store generated CSRF token into context. + // Optional. Default value "csrf". + ContextKey string + + // Name of the CSRF cookie. This cookie will store CSRF token. + // Optional. Default value "csrf". + CookieName string + + // Domain of the CSRF cookie. + // Optional. Default value none. + CookieDomain string + + // Path of the CSRF cookie. + // Optional. Default value none. + CookiePath string + + // Max age (in seconds) of the CSRF cookie. + // Optional. Default value 86400 (24hr). + CookieMaxAge int + + // Indicates if CSRF cookie is secure. + // Optional. Default value false. + CookieSecure bool + + // Indicates if CSRF cookie is HTTP only. + // Optional. Default value false. + CookieHTTPOnly bool + + // Indicates SameSite mode of the CSRF cookie. + // Optional. Default value SameSiteDefaultMode. + CookieSameSite http.SameSite + + // ErrorHandler defines a function which is executed for returning custom errors. + ErrorHandler func(c *echo.Context, err error) error +} + +// ErrCSRFInvalid is returned when CSRF check fails +var ErrCSRFInvalid = &echo.HTTPError{Code: http.StatusForbidden, Message: "invalid csrf token"} + +// DefaultCSRFConfig is the default CSRF middleware config. +var DefaultCSRFConfig = CSRFConfig{ + Skipper: DefaultSkipper, + TokenLength: 32, + TokenLookup: "header:" + echo.HeaderXCSRFToken, + ContextKey: "csrf", + CookieName: "_csrf", + CookieMaxAge: 86400, + CookieSameSite: http.SameSiteDefaultMode, +} // CSRF returns a Cross-Site Request Forgery (CSRF) middleware. // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery func CSRF() echo.MiddlewareFunc { - c := DefaultCSRFConfig - return CSRFWithConfig(c) + return CSRFWithConfig(DefaultCSRFConfig) } -// CSRFWithConfig returns a CSRF middleware with config. -// See `CSRF()`. +// CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration. func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration +func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultCSRFConfig.Skipper @@ -93,6 +122,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.TokenLength == 0 { config.TokenLength = DefaultCSRFConfig.TokenLength } + if config.Generator == nil { + config.Generator = createRandomStringGenerator(config.TokenLength) + } if config.TokenLookup == "" { config.TokenLookup = DefaultCSRFConfig.TokenLookup } @@ -105,45 +137,79 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.CookieMaxAge == 0 { config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge } + if config.CookieSameSite == http.SameSiteNoneMode { + config.CookieSecure = true + } + if len(config.TrustedOrigins) > 0 { + if err := validateOrigins(config.TrustedOrigins, "trusted origin"); err != nil { + return nil, err + } + config.TrustedOrigins = append([]string(nil), config.TrustedOrigins...) + } - // Initialize - parts := strings.Split(config.TokenLookup, ":") - extractor := csrfTokenFromHeader(parts[1]) - switch parts[0] { - case "form": - extractor = csrfTokenFromForm(parts[1]) - case "query": - extractor = csrfTokenFromQuery(parts[1]) + extractors, cErr := createExtractors(config.TokenLookup, 1) + if cErr != nil { + return nil, cErr } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } - req := c.Request() - k, err := c.Cookie(config.CookieName) - token := "" - - // Generate token + // use the `Sec-Fetch-Site` header as part of a modern approach to CSRF protection + allow, err := config.checkSecFetchSiteRequest(c) if err != nil { - token = random.String(config.TokenLength) + return err + } + if allow { + return next(c) + } + + // Fallback to legacy token based CSRF protection + + token := "" + if k, err := c.Cookie(config.CookieName); err != nil { + token = config.Generator() // Generate token } else { - // Reuse token - token = k.Value + token = k.Value // Reuse token } - switch req.Method { + switch c.Request().Method { case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: default: // Validate token only for requests which are not defined as 'safe' by RFC7231 - clientToken, err := extractor(c) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + var lastExtractorErr error + var lastTokenErr error + outer: + for _, extractor := range extractors { + clientTokens, _, err := extractor(c) + if err != nil { + lastExtractorErr = err + continue + } + + for _, clientToken := range clientTokens { + if validateCSRFToken(token, clientToken) { + lastTokenErr = nil + lastExtractorErr = nil + break outer + } + lastTokenErr = ErrCSRFInvalid + } } - if !validateCSRFToken(token, clientToken) { - return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") + var finalErr error + if lastTokenErr != nil { + finalErr = lastTokenErr + } else if lastExtractorErr != nil { + finalErr = echo.ErrBadRequest.Wrap(lastExtractorErr) + } + if finalErr != nil { + if config.ErrorHandler != nil { + return config.ErrorHandler(c, finalErr) + } + return finalErr } } @@ -157,6 +223,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.CookieDomain != "" { cookie.Domain = config.CookieDomain } + if config.CookieSameSite != http.SameSiteDefaultMode { + cookie.SameSite = config.CookieSameSite + } cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second) cookie.Secure = config.CookieSecure cookie.HttpOnly = config.CookieHTTPOnly @@ -170,41 +239,55 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return next(c) } - } + }, nil } -// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the -// provided request header. -func csrfTokenFromHeader(header string) csrfTokenExtractor { - return func(c echo.Context) (string, error) { - return c.Request().Header.Get(header), nil - } +func validateCSRFToken(token, clientToken string) bool { + return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1 } -// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the -// provided form parameter. -func csrfTokenFromForm(param string) csrfTokenExtractor { - return func(c echo.Context) (string, error) { - token := c.FormValue(param) - if token == "" { - return "", errors.New("missing csrf token in the form parameter") - } - return token, nil +var safeMethods = []string{http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace} + +func (config CSRFConfig) checkSecFetchSiteRequest(c *echo.Context) (bool, error) { + // https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers + // Sec-Fetch-Site values are: + // - `same-origin` exact origin match - allow always + // - `same-site` same registrable domain (subdomain and/or different port) - block, unless explicitly trusted + // - `cross-site` request originates from different site - block, unless explicitly trusted + // - `none` direct navigation (URL bar, bookmark) - allow always + secFetchSite := c.Request().Header.Get(echo.HeaderSecFetchSite) + if secFetchSite == "" { + return false, nil } -} -// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the -// provided query parameter. -func csrfTokenFromQuery(param string) csrfTokenExtractor { - return func(c echo.Context) (string, error) { - token := c.QueryParam(param) - if token == "" { - return "", errors.New("missing csrf token in the query string") + if len(config.TrustedOrigins) > 0 { + // trusted sites ala OAuth callbacks etc. should be let through + origin := c.Request().Header.Get(echo.HeaderOrigin) + if origin != "" { + for _, trustedOrigin := range config.TrustedOrigins { + if strings.EqualFold(origin, trustedOrigin) { + return true, nil + } + } } - return token, nil } -} + isSafe := slices.Contains(safeMethods, c.Request().Method) + if !isSafe { // for state-changing request check SecFetchSite value + isSafe = secFetchSite == "same-origin" || secFetchSite == "none" + } -func validateCSRFToken(token, clientToken string) bool { - return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1 + if isSafe { + return true, nil + } + // we are here when request is state-changing and `cross-site` or `same-site` + + // Note: if you want to allow `same-site` use config.TrustedOrigins or `config.AllowSecFetchSiteFunc` + if config.AllowSecFetchSiteFunc != nil { + return config.AllowSecFetchSiteFunc(c) + } + + if secFetchSite == "same-site" { + return false, echo.NewHTTPError(http.StatusForbidden, "same-site request blocked by CSRF") + } + return false, echo.NewHTTPError(http.StatusForbidden, "cross-site request blocked by CSRF") } diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index efb4dd1d2..ddecc10e3 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -1,26 +1,358 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( + "cmp" "net/http" "net/http/httptest" "net/url" "strings" "testing" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) +func TestCSRF_tokenExtractors(t *testing.T) { + var testCases = []struct { + name string + whenTokenLookup string + whenCookieName string + givenCSRFCookie string + givenMethod string + givenQueryTokens map[string][]string + givenFormTokens map[string][]string + givenHeaderTokens map[string][]string + expectError string + expectToMiddlewareError string + }{ + { + name: "ok, multiple token lookups sources, succeeds on last one", + whenTokenLookup: "header:X-CSRF-Token,form:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenHeaderTokens: map[string][]string{ + echo.HeaderXCSRFToken: {"invalid_token"}, + }, + givenFormTokens: map[string][]string{ + "csrf": {"token"}, + }, + }, + { + name: "ok, token from POST form", + whenTokenLookup: "form:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenFormTokens: map[string][]string{ + "csrf": {"token"}, + }, + }, + { + name: "ok, token from POST form, second token passes", + whenTokenLookup: "form:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenFormTokens: map[string][]string{ + "csrf": {"invalid", "token"}, + }, + expectError: "code=403, message=invalid csrf token", + }, + { + name: "nok, invalid token from POST form", + whenTokenLookup: "form:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenFormTokens: map[string][]string{ + "csrf": {"invalid_token"}, + }, + expectError: "code=403, message=invalid csrf token", + }, + { + name: "nok, missing token from POST form", + whenTokenLookup: "form:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenFormTokens: map[string][]string{}, + expectError: "code=400, message=Bad Request, err=missing value in the form", + }, + { + name: "ok, token from POST header", + whenTokenLookup: "", // will use defaults + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenHeaderTokens: map[string][]string{ + echo.HeaderXCSRFToken: {"token"}, + }, + }, + { + name: "nok, token from POST header, tokens limited to 1, second token would pass", + whenTokenLookup: "header:" + echo.HeaderXCSRFToken, + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenHeaderTokens: map[string][]string{ + echo.HeaderXCSRFToken: {"invalid", "token"}, + }, + expectError: "code=403, message=invalid csrf token", + }, + { + name: "nok, invalid token from POST header", + whenTokenLookup: "header:" + echo.HeaderXCSRFToken, + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenHeaderTokens: map[string][]string{ + echo.HeaderXCSRFToken: {"invalid_token"}, + }, + expectError: "code=403, message=invalid csrf token", + }, + { + name: "nok, missing token from POST header", + whenTokenLookup: "header:" + echo.HeaderXCSRFToken, + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenHeaderTokens: map[string][]string{}, + expectError: "code=400, message=Bad Request, err=missing value in request header", + }, + { + name: "ok, token from PUT query param", + whenTokenLookup: "query:csrf-param", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{ + "csrf-param": {"token"}, + }, + }, + { + name: "nok, token from PUT query form, second token would pass", + whenTokenLookup: "query:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{ + "csrf": {"invalid", "token"}, + }, + expectError: "code=403, message=invalid csrf token", + }, + { + name: "nok, invalid token from PUT query form", + whenTokenLookup: "query:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{ + "csrf": {"invalid_token"}, + }, + expectError: "code=403, message=invalid csrf token", + }, + { + name: "nok, missing token from PUT query form", + whenTokenLookup: "query:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{}, + expectError: "code=400, message=Bad Request, err=missing value in the query string", + }, + { + name: "nok, invalid TokenLookup", + whenTokenLookup: "q", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{}, + expectToMiddlewareError: "extractor source for lookup could not be split into needed parts: q", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + q := make(url.Values) + for queryParam, values := range tc.givenQueryTokens { + for _, v := range values { + q.Add(queryParam, v) + } + } + + f := make(url.Values) + for formKey, values := range tc.givenFormTokens { + for _, v := range values { + f.Add(formKey, v) + } + } + + var req *http.Request + switch tc.givenMethod { + case http.MethodGet: + req = httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil) + case http.MethodPost, http.MethodPut: + req = httptest.NewRequest(http.MethodPost, "/?"+q.Encode(), strings.NewReader(f.Encode())) + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + } + + for header, values := range tc.givenHeaderTokens { + for _, v := range values { + req.Header.Add(header, v) + } + } + + if tc.givenCSRFCookie != "" { + req.Header.Set(echo.HeaderCookie, "_csrf="+tc.givenCSRFCookie) + } + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + config := CSRFConfig{ + TokenLookup: tc.whenTokenLookup, + CookieName: tc.whenCookieName, + } + csrf, err := config.ToMiddleware() + if tc.expectToMiddlewareError != "" { + assert.EqualError(t, err, tc.expectToMiddlewareError) + return + } else if err != nil { + assert.NoError(t, err) + } + + h := csrf(func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + err = h(c) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestCSRFWithConfig(t *testing.T) { + token := randomString(16) + + var testCases = []struct { + name string + givenConfig *CSRFConfig + whenMethod string + whenHeaders map[string]string + expectEmptyBody bool + expectMWError string + expectCookieContains string + expectErr string + }{ + { + name: "ok, GET", + whenMethod: http.MethodGet, + expectCookieContains: "_csrf", + }, + { + name: "ok, POST valid token", + whenHeaders: map[string]string{ + echo.HeaderCookie: "_csrf=" + token, + echo.HeaderXCSRFToken: token, + }, + whenMethod: http.MethodPost, + expectCookieContains: "_csrf", + }, + { + name: "nok, POST without token", + whenMethod: http.MethodPost, + expectEmptyBody: true, + expectErr: `code=400, message=Bad Request, err=missing value in request header`, + }, + { + name: "nok, POST empty token", + whenHeaders: map[string]string{echo.HeaderXCSRFToken: ""}, + whenMethod: http.MethodPost, + expectEmptyBody: true, + expectErr: `code=403, message=invalid csrf token`, + }, + { + name: "nok, invalid trusted origin in Config", + givenConfig: &CSRFConfig{ + TrustedOrigins: []string{"http://example.com", "invalid"}, + }, + expectMWError: `trusted origin is missing scheme or host: invalid`, + }, + { + name: "ok, TokenLength", + givenConfig: &CSRFConfig{ + TokenLength: 16, + }, + whenMethod: http.MethodGet, + expectCookieContains: "_csrf", + }, + { + name: "ok, unsafe method + SecFetchSite=same-origin passes", + whenHeaders: map[string]string{ + echo.HeaderSecFetchSite: "same-origin", + }, + whenMethod: http.MethodPost, + }, + { + name: "nok, unsafe method + SecFetchSite=same-cross blocked", + whenHeaders: map[string]string{ + echo.HeaderSecFetchSite: "same-cross", + }, + whenMethod: http.MethodPost, + expectEmptyBody: true, + expectErr: `code=403, message=cross-site request blocked by CSRF`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(cmp.Or(tc.whenMethod, http.MethodPost), "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + for key, value := range tc.whenHeaders { + req.Header.Set(key, value) + } + + config := CSRFConfig{} + if tc.givenConfig != nil { + config = *tc.givenConfig + } + mw, err := config.ToMiddleware() + if tc.expectMWError != "" { + assert.EqualError(t, err, tc.expectMWError) + return + } + assert.NoError(t, err) + + h := mw(func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + err = h(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + + expect := "test" + if tc.expectEmptyBody { + expect = "" + } + assert.Equal(t, expect, rec.Body.String()) + + if tc.expectCookieContains != "" { + assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), tc.expectCookieContains) + } + }) + } +} + func TestCSRF(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - csrf := CSRFWithConfig(CSRFConfig{ - TokenLength: 16, - }) - h := csrf(func(c echo.Context) error { + csrf := CSRF() + h := csrf(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -28,56 +360,495 @@ func TestCSRF(t *testing.T) { h(c) assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf") - // Without CSRF cookie - req = httptest.NewRequest(http.MethodPost, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - assert.Error(t, h(c)) - - // Empty/invalid CSRF token - req = httptest.NewRequest(http.MethodPost, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - req.Header.Set(echo.HeaderXCSRFToken, "") - assert.Error(t, h(c)) - - // Valid CSRF token - token := random.String(16) - req.Header.Set(echo.HeaderCookie, "_csrf="+token) - req.Header.Set(echo.HeaderXCSRFToken, token) - if assert.NoError(t, h(c)) { - assert.Equal(t, http.StatusOK, rec.Code) - } } -func TestCSRFTokenFromForm(t *testing.T) { - f := make(url.Values) - f.Set("csrf", "token") +func TestCSRFSetSameSiteMode(t *testing.T) { e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) - req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) - c := e.NewContext(req, nil) - token, err := csrfTokenFromForm("csrf")(c) - if assert.NoError(t, err) { - assert.Equal(t, "token", token) - } - _, err = csrfTokenFromForm("invalid")(c) - assert.Error(t, err) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + csrf := CSRFWithConfig(CSRFConfig{ + CookieSameSite: http.SameSiteStrictMode, + }) + + h := csrf(func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + r := h(c) + assert.NoError(t, r) + assert.Regexp(t, "SameSite=Strict", rec.Header()["Set-Cookie"]) } -func TestCSRFTokenFromQuery(t *testing.T) { - q := make(url.Values) - q.Set("csrf", "token") +func TestCSRFWithoutSameSiteMode(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) - req.URL.RawQuery = q.Encode() - c := e.NewContext(req, nil) - token, err := csrfTokenFromQuery("csrf")(c) - if assert.NoError(t, err) { - assert.Equal(t, "token", token) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + csrf := CSRFWithConfig(CSRFConfig{}) + + h := csrf(func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + r := h(c) + assert.NoError(t, r) + assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"]) +} + +func TestCSRFWithSameSiteDefaultMode(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + csrf := CSRFWithConfig(CSRFConfig{ + CookieSameSite: http.SameSiteDefaultMode, + }) + + h := csrf(func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + r := h(c) + assert.NoError(t, r) + assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"]) +} + +func TestCSRFWithSameSiteModeNone(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + csrf, err := CSRFConfig{ + CookieSameSite: http.SameSiteNoneMode, + }.ToMiddleware() + assert.NoError(t, err) + + h := csrf(func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + r := h(c) + assert.NoError(t, r) + assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"]) + assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"]) +} + +func TestCSRFConfig_skipper(t *testing.T) { + var testCases = []struct { + name string + whenSkip bool + expectCookies int + }{ + { + name: "do skip", + whenSkip: true, + expectCookies: 0, + }, + { + name: "do not skip", + whenSkip: false, + expectCookies: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + csrf := CSRFWithConfig(CSRFConfig{ + Skipper: func(c *echo.Context) bool { + return tc.whenSkip + }, + }) + + h := csrf(func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + r := h(c) + assert.NoError(t, r) + cookie := rec.Header()["Set-Cookie"] + assert.Len(t, cookie, tc.expectCookies) + }) + } +} + +func TestCSRFErrorHandling(t *testing.T) { + cfg := CSRFConfig{ + ErrorHandler: func(c *echo.Context, err error) error { + return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed") + }, + } + + e := echo.New() + e.POST("/", func(c *echo.Context) error { + return c.String(http.StatusNotImplemented, "should not end up here") + }) + + e.Use(CSRFWithConfig(cfg)) + + req := httptest.NewRequest(http.MethodPost, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.Equal(t, "{\"message\":\"error_handler_executed\"}\n", res.Body.String()) +} + +func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { + var testCases = []struct { + name string + givenConfig CSRFConfig + whenMethod string + whenSecFetchSite string + whenOrigin string + expectAllow bool + expectErr string + }{ + { + name: "ok, unsafe POST, no SecFetchSite is not blocked", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPost, + whenSecFetchSite: "", + expectAllow: false, // should fall back to token CSRF + }, + { + name: "ok, safe GET + same-origin passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodGet, + whenSecFetchSite: "same-origin", + expectAllow: true, + }, + { + name: "ok, safe GET + none passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodGet, + whenSecFetchSite: "none", + expectAllow: true, + }, + { + name: "ok, safe GET + same-site passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodGet, + whenSecFetchSite: "same-site", + expectAllow: true, + }, + { + name: "ok, safe GET + cross-site passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodGet, + whenSecFetchSite: "cross-site", + expectAllow: true, + }, + { + name: "nok, unsafe POST + cross-site is blocked", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + expectAllow: false, + expectErr: `code=403, message=cross-site request blocked by CSRF`, + }, + { + name: "nok, unsafe POST + same-site is blocked", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPost, + whenSecFetchSite: "same-site", + expectAllow: false, + expectErr: `code=403, message=same-site request blocked by CSRF`, + }, + { + name: "ok, unsafe POST + same-origin passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPost, + whenSecFetchSite: "same-origin", + expectAllow: true, + }, + { + name: "ok, unsafe POST + none passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPost, + whenSecFetchSite: "none", + expectAllow: true, + }, + { + name: "ok, unsafe PUT + same-origin passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPut, + whenSecFetchSite: "same-origin", + expectAllow: true, + }, + { + name: "ok, unsafe PUT + none passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPut, + whenSecFetchSite: "none", + expectAllow: true, + }, + { + name: "ok, unsafe DELETE + same-origin passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodDelete, + whenSecFetchSite: "same-origin", + expectAllow: true, + }, + { + name: "ok, unsafe PATCH + same-origin passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPatch, + whenSecFetchSite: "same-origin", + expectAllow: true, + }, + { + name: "nok, unsafe PUT + cross-site is blocked", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPut, + whenSecFetchSite: "cross-site", + expectAllow: false, + expectErr: `code=403, message=cross-site request blocked by CSRF`, + }, + { + name: "nok, unsafe PUT + same-site is blocked", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPut, + whenSecFetchSite: "same-site", + expectAllow: false, + expectErr: `code=403, message=same-site request blocked by CSRF`, + }, + { + name: "nok, unsafe DELETE + cross-site is blocked", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodDelete, + whenSecFetchSite: "cross-site", + expectAllow: false, + expectErr: `code=403, message=cross-site request blocked by CSRF`, + }, + { + name: "nok, unsafe DELETE + same-site is blocked", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodDelete, + whenSecFetchSite: "same-site", + expectAllow: false, + expectErr: `code=403, message=same-site request blocked by CSRF`, + }, + { + name: "nok, unsafe PATCH + cross-site is blocked", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPatch, + whenSecFetchSite: "cross-site", + expectAllow: false, + expectErr: `code=403, message=cross-site request blocked by CSRF`, + }, + { + name: "ok, safe HEAD + same-origin passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodHead, + whenSecFetchSite: "same-origin", + expectAllow: true, + }, + { + name: "ok, safe HEAD + cross-site passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodHead, + whenSecFetchSite: "cross-site", + expectAllow: true, + }, + { + name: "ok, safe OPTIONS + cross-site passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodOptions, + whenSecFetchSite: "cross-site", + expectAllow: true, + }, + { + name: "ok, safe TRACE + cross-site passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodTrace, + whenSecFetchSite: "cross-site", + expectAllow: true, + }, + { + name: "ok, unsafe POST + cross-site + matching trusted origin passes", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://trusted.example.com"}, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + whenOrigin: "https://trusted.example.com", + expectAllow: true, + }, + { + name: "ok, unsafe POST + same-site + matching trusted origin passes", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://trusted.example.com"}, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "same-site", + whenOrigin: "https://trusted.example.com", + expectAllow: true, + }, + { + name: "nok, unsafe POST + cross-site + non-matching origin is blocked", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://trusted.example.com"}, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + whenOrigin: "https://evil.example.com", + expectAllow: false, + expectErr: `code=403, message=cross-site request blocked by CSRF`, + }, + { + name: "ok, unsafe POST + cross-site + case-insensitive trusted origin match passes", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://trusted.example.com"}, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + whenOrigin: "https://TRUSTED.example.com", + expectAllow: true, + }, + { + name: "ok, unsafe POST + same-origin + trusted origins configured but not matched passes", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://trusted.example.com"}, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "same-origin", + whenOrigin: "https://different.example.com", + expectAllow: true, + }, + { + name: "nok, unsafe POST + cross-site + empty origin + trusted origins configured is blocked", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://trusted.example.com"}, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + whenOrigin: "", + expectAllow: false, + expectErr: `code=403, message=cross-site request blocked by CSRF`, + }, + { + name: "ok, unsafe POST + cross-site + multiple trusted origins, second one matches", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://first.example.com", "https://second.example.com"}, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + whenOrigin: "https://second.example.com", + expectAllow: true, + }, + { + name: "ok, unsafe POST + same-site + custom func allows", + givenConfig: CSRFConfig{ + AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { + return true, nil + }, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "same-site", + expectAllow: true, + }, + { + name: "ok, unsafe POST + cross-site + custom func allows", + givenConfig: CSRFConfig{ + AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { + return true, nil + }, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + expectAllow: true, + }, + { + name: "nok, unsafe POST + same-site + custom func returns custom error", + givenConfig: CSRFConfig{ + AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { + return false, echo.NewHTTPError(http.StatusTeapot, "custom error from func") + }, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "same-site", + expectAllow: false, + expectErr: `code=418, message=custom error from func`, + }, + { + name: "nok, unsafe POST + cross-site + custom func returns false with nil error", + givenConfig: CSRFConfig{ + AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { + return false, nil + }, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + expectAllow: false, + expectErr: "", // custom func returns nil error, so no error expected + }, + { + name: "nok, unsafe POST + invalid Sec-Fetch-Site value treated as cross-site", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPost, + whenSecFetchSite: "invalid-value", + expectAllow: false, + expectErr: `code=403, message=cross-site request blocked by CSRF`, + }, + { + name: "ok, unsafe POST + cross-site + trusted origin takes precedence over custom func", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://trusted.example.com"}, + AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { + return false, echo.NewHTTPError(http.StatusTeapot, "should not be called") + }, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + whenOrigin: "https://trusted.example.com", + expectAllow: true, + }, + { + name: "nok, unsafe POST + cross-site + trusted origin not matched, custom func blocks", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://trusted.example.com"}, + AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { + return false, echo.NewHTTPError(http.StatusTeapot, "custom block") + }, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + whenOrigin: "https://evil.example.com", + expectAllow: false, + expectErr: `code=418, message=custom block`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(tc.whenMethod, "/", nil) + if tc.whenSecFetchSite != "" { + req.Header.Set(echo.HeaderSecFetchSite, tc.whenSecFetchSite) + } + if tc.whenOrigin != "" { + req.Header.Set(echo.HeaderOrigin, tc.whenOrigin) + } + + res := httptest.NewRecorder() + c := echo.NewContext(req, res) + + allow, err := tc.givenConfig.checkSecFetchSiteRequest(c) + + assert.Equal(t, tc.expectAllow, allow) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + }) } - _, err = csrfTokenFromQuery("invalid")(c) - assert.Error(t, err) - csrfTokenFromQuery("csrf") } diff --git a/middleware/decompress.go b/middleware/decompress.go new file mode 100644 index 000000000..a384af2ea --- /dev/null +++ b/middleware/decompress.go @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package middleware + +import ( + "compress/gzip" + "io" + "net/http" + "sync" + + "github.com/labstack/echo/v5" +) + +// DecompressConfig defines the config for Decompress middleware. +type DecompressConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers + GzipDecompressPool Decompressor + + // MaxDecompressedSize limits the maximum size of decompressed request body in bytes. + // If the decompressed body exceeds this limit, the middleware returns HTTP 413 error. + // This prevents zip bomb attacks where small compressed payloads decompress to huge sizes. + // Default: 100 * MB (104,857,600 bytes) + // Set to -1 to disable limits (not recommended in production). + MaxDecompressedSize int64 +} + +// GZIPEncoding content-encoding header if set to "gzip", decompress body contents. +const GZIPEncoding string = "gzip" + +// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers +type Decompressor interface { + gzipDecompressPool() sync.Pool +} + +// DefaultGzipDecompressPool is the default implementation of Decompressor interface +type DefaultGzipDecompressPool struct { +} + +func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { + return sync.Pool{New: func() any { return new(gzip.Reader) }} +} + +// Decompress decompresses request body based if content encoding type is set to "gzip" with default config +// +// SECURITY: By default, this limits decompressed data to 100MB to prevent zip bomb attacks. +// To customize the limit, use DecompressWithConfig. To disable limits (not recommended in production), +// set MaxDecompressedSize to -1. +func Decompress() echo.MiddlewareFunc { + return DecompressWithConfig(DecompressConfig{}) +} + +// DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration. +// +// SECURITY: If MaxDecompressedSize is not set (zero value), it defaults to 100MB to prevent +// DoS attacks via zip bombs. Set to -1 to explicitly disable limits if needed for your use case. +func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration +func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) { + if config.Skipper == nil { + config.Skipper = DefaultSkipper + } + if config.GzipDecompressPool == nil { + config.GzipDecompressPool = &DefaultGzipDecompressPool{} + } + // Apply secure default for decompression limit + if config.MaxDecompressedSize == 0 { + config.MaxDecompressedSize = 100 * MB + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + pool := config.GzipDecompressPool.gzipDecompressPool() + + return func(c *echo.Context) error { + if config.Skipper(c) { + return next(c) + } + + if c.Request().Header.Get(echo.HeaderContentEncoding) != GZIPEncoding { + return next(c) + } + + i := pool.Get() + gr, ok := i.(*gzip.Reader) + if !ok || gr == nil { + if err, isErr := i.(error); isErr { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return echo.NewHTTPError(http.StatusInternalServerError, "unexpected type from gzip decompression pool") + } + defer pool.Put(gr) + + b := c.Request().Body + defer b.Close() + + if err := gr.Reset(b); err != nil { + if err == io.EOF { //ignore if body is empty + return next(c) + } + return err + } + + // only Close gzip reader if it was set to a proper gzip source otherwise it will panic on close. + defer gr.Close() + + // Apply decompression size limit to prevent zip bombs + if config.MaxDecompressedSize > 0 { + c.Request().Body = &limitedGzipReader{ + Reader: gr, + remaining: config.MaxDecompressedSize, + limit: config.MaxDecompressedSize, + } + } else { + // -1 means explicitly unlimited (not recommended) + c.Request().Body = gr + } + + return next(c) + } + }, nil +} + +// limitedGzipReader wraps a gzip reader with size limiting to prevent zip bombs +type limitedGzipReader struct { + *gzip.Reader + remaining int64 + limit int64 +} + +func (r *limitedGzipReader) Read(p []byte) (n int, err error) { + if r.remaining <= 0 { + // Limit exceeded - return 413 error + return 0, echo.ErrStatusRequestEntityTooLarge + } + + // Limit the read to remaining bytes + if int64(len(p)) > r.remaining { + p = p[:r.remaining] + } + + n, err = r.Reader.Read(p) + r.remaining -= int64(n) + + return n, err +} + +func (r *limitedGzipReader) Close() error { + return r.Reader.Close() +} diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go new file mode 100644 index 000000000..1823e94bb --- /dev/null +++ b/middleware/decompress_test.go @@ -0,0 +1,508 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package middleware + +import ( + "bytes" + "compress/gzip" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + + "github.com/labstack/echo/v5" + "github.com/stretchr/testify/assert" +) + +func TestDecompress(t *testing.T) { + e := echo.New() + + h := Decompress()(func(c *echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + + // Decompress request body + body := `{"name": "echo"}` + gz, _ := gzipString(body) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := io.ReadAll(req.Body) + assert.NoError(t, err) + assert.Equal(t, body, string(b)) +} + +func TestDecompress_skippedIfNoHeader(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Skip if no Content-Encoding header + h := Decompress()(func(c *echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + + err := h(c) + assert.NoError(t, err) + assert.Equal(t, "test", rec.Body.String()) + +} + +func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{}.ToMiddleware() + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + })(c) + assert.NoError(t, err) + + assert.Equal(t, "test", rec.Body.String()) + +} + +func TestDecompressWithConfig_DefaultConfig(t *testing.T) { + e := echo.New() + + h := Decompress()(func(c *echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + + // Decompress + body := `{"name": "echo"}` + gz, _ := gzipString(body) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := io.ReadAll(req.Body) + assert.NoError(t, err) + assert.Equal(t, body, string(b)) +} + +func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { + e := echo.New() + body := `{"name":"echo"}` + gz, _ := gzipString(body) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + e.NewContext(req, rec) + + e.ServeHTTP(rec, req) + + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + b, err := io.ReadAll(req.Body) + assert.NoError(t, err) + assert.NotEqual(t, b, body) + assert.Equal(t, b, gz) +} + +func TestDecompressNoContent(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := Decompress()(func(c *echo.Context) error { + return c.NoContent(http.StatusNoContent) + }) + + err := h(c) + + if assert.NoError(t, err) { + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) + assert.Equal(t, 0, len(rec.Body.Bytes())) + } +} + +func TestDecompressErrorReturned(t *testing.T) { + e := echo.New() + e.Use(Decompress()) + e.GET("/", func(c *echo.Context) error { + return echo.ErrNotFound + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusNotFound, rec.Code) + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) +} + +func TestDecompressSkipper(t *testing.T) { + e := echo.New() + e.Use(DecompressWithConfig(DecompressConfig{ + Skipper: func(c *echo.Context) bool { + return c.Request().URL.Path == "/skip" + }, + })) + body := `{"name": "echo"}` + req := httptest.NewRequest(http.MethodPost, "/skip", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + e.ServeHTTP(rec, req) + + assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSON) + reqBody, err := io.ReadAll(c.Request().Body) + assert.NoError(t, err) + assert.Equal(t, body, string(reqBody)) +} + +type TestDecompressPoolWithError struct { +} + +func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool { + return sync.Pool{ + New: func() any { + return errors.New("pool error") + }, + } +} + +func TestDecompressPoolError(t *testing.T) { + e := echo.New() + e.Use(DecompressWithConfig(DecompressConfig{ + Skipper: DefaultSkipper, + GzipDecompressPool: &TestDecompressPoolWithError{}, + })) + body := `{"name": "echo"}` + req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + e.ServeHTTP(rec, req) + + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) + reqBody, err := io.ReadAll(c.Request().Body) + assert.NoError(t, err) + assert.Equal(t, body, string(reqBody)) + assert.Equal(t, rec.Code, http.StatusInternalServerError) +} + +func BenchmarkDecompress(b *testing.B) { + e := echo.New() + body := `{"name": "echo"}` + gz, _ := gzipString(body) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + + h := Decompress()(func(c *echo.Context) error { + c.Response().Write([]byte(body)) // For Content-Type sniffing + return nil + }) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // Decompress + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h(c) + } +} + +func gzipString(body string) ([]byte, error) { + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + + _, err := gz.Write([]byte(body)) + if err != nil { + return nil, err + } + + if err := gz.Close(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +func TestDecompress_WithinLimit(t *testing.T) { + e := echo.New() + body := strings.Repeat("test data ", 100) // Small payload ~1KB + gz, _ := gzipString(body) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{MaxDecompressedSize: 100 * MB}.ToMiddleware() + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + b, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(b)) + })(c) + + assert.NoError(t, err) + assert.Equal(t, body, rec.Body.String()) +} + +func TestDecompress_ExceedsLimit(t *testing.T) { + e := echo.New() + // Create 2KB of data but limit to 1KB + largeBody := strings.Repeat("A", 2*1024) + gz, _ := gzipString(largeBody) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + _, readErr := io.ReadAll(c.Request().Body) + return readErr + })(c) + + // Should return 413 error + assert.Error(t, err) + he, ok := err.(echo.HTTPStatusCoder) + assert.True(t, ok) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) +} + +func TestDecompress_AtExactLimit(t *testing.T) { + e := echo.New() + exactBody := strings.Repeat("B", 1024) // Exactly 1KB + gz, _ := gzipString(exactBody) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + b, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(b)) + })(c) + + assert.NoError(t, err) + assert.Equal(t, exactBody, rec.Body.String()) +} + +func TestDecompress_ZipBomb(t *testing.T) { + e := echo.New() + // Create highly compressed data that expands to 2MB + // but limit is 1MB + largeBody := bytes.Repeat([]byte("A"), 2*1024*1024) // 2MB + var buf bytes.Buffer + gzWriter := gzip.NewWriter(&buf) + gzWriter.Write(largeBody) + gzWriter.Close() + + req := httptest.NewRequest(http.MethodPost, "/", &buf) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{MaxDecompressedSize: 1 * MB}.ToMiddleware() + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + _, readErr := io.ReadAll(c.Request().Body) + return readErr + })(c) + + // Should return 413 error + assert.Error(t, err) + he, ok := err.(echo.HTTPStatusCoder) + assert.True(t, ok) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) +} + +func TestDecompress_UnlimitedExplicit(t *testing.T) { + e := echo.New() + largeBody := strings.Repeat("X", 10*1024) // 10KB + gz, _ := gzipString(largeBody) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{MaxDecompressedSize: -1}.ToMiddleware() // Unlimited + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + b, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(b)) + })(c) + + assert.NoError(t, err) + assert.Equal(t, largeBody, rec.Body.String()) +} + +func TestDecompress_DefaultLimit(t *testing.T) { + e := echo.New() + smallBody := "test" + gz, _ := gzipString(smallBody) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Use zero value which should apply 100MB default + h, err := DecompressConfig{}.ToMiddleware() + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + b, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(b)) + })(c) + + assert.NoError(t, err) + assert.Equal(t, smallBody, rec.Body.String()) +} + +func TestDecompress_SmallCustomLimit(t *testing.T) { + e := echo.New() + body := strings.Repeat("D", 512) // 512 bytes + gz, _ := gzipString(body) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + b, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(b)) + })(c) + + assert.NoError(t, err) + assert.Equal(t, body, rec.Body.String()) +} + +func TestDecompress_MultipleReads(t *testing.T) { + e := echo.New() + // Test that limit is enforced across multiple Read() calls + largeBody := strings.Repeat("M", 2*1024) // 2KB + gz, _ := gzipString(largeBody) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + // Read in small chunks + buf := make([]byte, 256) + total := 0 + for { + n, readErr := c.Request().Body.Read(buf) + total += n + if readErr != nil { + if readErr == io.EOF { + return nil + } + return readErr + } + } + })(c) + + // Should return 413 error from cumulative reads + assert.Error(t, err) + he, ok := err.(echo.HTTPStatusCoder) + assert.True(t, ok) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) +} + +func TestDecompress_LargePayloadDosPrevention(t *testing.T) { + e := echo.New() + // Simulate a DoS attack with highly compressed large payload + largeSize := 10 * 1024 * 1024 // 10MB decompressed + largeBody := bytes.Repeat([]byte("Z"), largeSize) + var buf bytes.Buffer + gzWriter := gzip.NewWriter(&buf) + gzWriter.Write(largeBody) + gzWriter.Close() + + req := httptest.NewRequest(http.MethodPost, "/", &buf) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{MaxDecompressedSize: 1 * MB}.ToMiddleware() + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + _, readErr := io.ReadAll(c.Request().Body) + return readErr + })(c) + + // Should prevent DoS by returning 413 + assert.Error(t, err) + he, ok := err.(echo.HTTPStatusCoder) + assert.True(t, ok) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) +} + +func BenchmarkDecompress_WithLimit(b *testing.B) { + e := echo.New() + body := strings.Repeat("benchmark data ", 1000) // ~15KB + gz, _ := gzipString(body) + + h, _ := DecompressConfig{MaxDecompressedSize: 100 * MB}.ToMiddleware() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h(func(c *echo.Context) error { + io.ReadAll(c.Request().Body) + return nil + })(c) + } +} diff --git a/middleware/extractor.go b/middleware/extractor.go new file mode 100644 index 000000000..abb603186 --- /dev/null +++ b/middleware/extractor.go @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package middleware + +import ( + "fmt" + "net/textproto" + "strings" + + "github.com/labstack/echo/v5" +) + +const ( + // extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource exhaustion + // attack vector + extractorLimit = 20 +) + +// ExtractorSource is type to indicate source for extracted value +type ExtractorSource string + +const ( + // ExtractorSourceHeader means value was extracted from request header + ExtractorSourceHeader ExtractorSource = "header" + // ExtractorSourceQuery means value was extracted from request query parameters + ExtractorSourceQuery ExtractorSource = "query" + // ExtractorSourcePathParam means value was extracted from route path parameters + ExtractorSourcePathParam ExtractorSource = "param" + // ExtractorSourceCookie means value was extracted from request cookies + ExtractorSourceCookie ExtractorSource = "cookie" + // ExtractorSourceForm means value was extracted from request form values + ExtractorSourceForm ExtractorSource = "form" +) + +// ValueExtractorError is error type when middleware extractor is unable to extract value from lookups +type ValueExtractorError struct { + message string +} + +// Error returns errors text +func (e *ValueExtractorError) Error() string { + return e.message +} + +var errHeaderExtractorValueMissing = &ValueExtractorError{message: "missing value in request header"} +var errHeaderExtractorValueInvalid = &ValueExtractorError{message: "invalid value in request header"} +var errQueryExtractorValueMissing = &ValueExtractorError{message: "missing value in the query string"} +var errParamExtractorValueMissing = &ValueExtractorError{message: "missing value in path params"} +var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing value in cookies"} +var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"} + +// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context. +type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error) + +// CreateExtractors creates ValuesExtractors from given lookups. +// lookups is a string in the form of ":" or ":,:" that is used +// to extract key from the request. +// Possible values: +// - "header:" or "header::" +// `` is argument value to cut/trim prefix of the extracted value. This is useful if header +// value has static prefix like `Authorization: ` where part that we +// want to cut is ` ` note the space at the end. +// In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. +// - "query:" +// - "param:" +// - "form:" +// - "cookie:" +// +// Multiple sources example: +// - "header:Authorization,header:X-Api-Key" +// +// limit sets the maximum amount how many lookups can be returned. +func CreateExtractors(lookups string, limit uint) ([]ValuesExtractor, error) { + return createExtractors(lookups, limit) +} + +func createExtractors(lookups string, limit uint) ([]ValuesExtractor, error) { + if lookups == "" { + return nil, nil + } + if limit == 0 { + limit = 1 + } else if limit > extractorLimit { + limit = extractorLimit + } + + sources := strings.Split(lookups, ",") + var extractors = make([]ValuesExtractor, 0) + for _, source := range sources { + parts := strings.Split(source, ":") + if len(parts) < 2 { + return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source) + } + + switch parts[0] { + case "query": + extractors = append(extractors, valuesFromQuery(parts[1], limit)) + case "param": + extractors = append(extractors, valuesFromParam(parts[1], limit)) + case "cookie": + extractors = append(extractors, valuesFromCookie(parts[1], limit)) + case "form": + extractors = append(extractors, valuesFromForm(parts[1], limit)) + case "header": + prefix := "" + if len(parts) > 2 { + prefix = parts[2] + } + extractors = append(extractors, valuesFromHeader(parts[1], prefix, limit)) + } + } + return extractors, nil +} + +// valuesFromHeader returns a functions that extracts values from the request header. +// valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static +// prefix like `Authorization: ` where part that we want to remove is ` ` +// note the space at the end. In case of basic authentication `Authorization: Basic ` prefix we want to remove +// is `Basic `. In case of JWT tokens `Authorization: Bearer ` prefix is `Bearer `. +// If prefix is left empty the whole value is returned. +func valuesFromHeader(header string, valuePrefix string, limit uint) ValuesExtractor { + prefixLen := len(valuePrefix) + // standard library parses http.Request header keys in canonical form but we may provide something else so fix this + header = textproto.CanonicalMIMEHeaderKey(header) + if limit == 0 { + limit = 1 + } + return func(c *echo.Context) ([]string, ExtractorSource, error) { + values := c.Request().Header.Values(header) + if len(values) == 0 { + return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing + } + + i := uint(0) + result := make([]string, 0) + for _, value := range values { + if prefixLen == 0 { + result = append(result, value) + i++ + if i >= limit { + break + } + } else if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) { + result = append(result, value[prefixLen:]) + i++ + if i >= limit { + break + } + } + } + + if len(result) == 0 { + if prefixLen > 0 { + return nil, ExtractorSourceHeader, errHeaderExtractorValueInvalid + } + return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing + } + return result, ExtractorSourceHeader, nil + } +} + +// valuesFromQuery returns a function that extracts values from the query string. +func valuesFromQuery(param string, limit uint) ValuesExtractor { + if limit == 0 { + limit = 1 + } + return func(c *echo.Context) ([]string, ExtractorSource, error) { + result := c.QueryParams()[param] + if len(result) == 0 { + return nil, ExtractorSourceQuery, errQueryExtractorValueMissing + } else if len(result) > int(limit)-1 { + result = result[:limit] + } + return result, ExtractorSourceQuery, nil + } +} + +// valuesFromParam returns a function that extracts values from the url param string. +func valuesFromParam(param string, limit uint) ValuesExtractor { + if limit == 0 { + limit = 1 + } + return func(c *echo.Context) ([]string, ExtractorSource, error) { + result := make([]string, 0) + i := uint(0) + for _, p := range c.PathValues() { + if param != p.Name { + continue + } + result = append(result, p.Value) + i++ + if i >= limit { + break + } + } + if len(result) == 0 { + return nil, ExtractorSourcePathParam, errParamExtractorValueMissing + } + return result, ExtractorSourcePathParam, nil + } +} + +// valuesFromCookie returns a function that extracts values from the named cookie. +func valuesFromCookie(name string, limit uint) ValuesExtractor { + if limit == 0 { + limit = 1 + } + return func(c *echo.Context) ([]string, ExtractorSource, error) { + cookies := c.Cookies() + if len(cookies) == 0 { + return nil, ExtractorSourceCookie, errCookieExtractorValueMissing + } + + i := uint(0) + result := make([]string, 0) + for _, cookie := range cookies { + if name != cookie.Name { + continue + } + result = append(result, cookie.Value) + i++ + if i >= limit { + break + } + } + if len(result) == 0 { + return nil, ExtractorSourceCookie, errCookieExtractorValueMissing + } + return result, ExtractorSourceCookie, nil + } +} + +// valuesFromForm returns a function that extracts values from the form field. +func valuesFromForm(name string, limit uint) ValuesExtractor { + if limit == 0 { + limit = 1 + } + return func(c *echo.Context) ([]string, ExtractorSource, error) { + if c.Request().Form == nil { + _, _ = c.MultipartForm() // we want to trigger c.request.ParseMultipartForm(c.formParseMaxMemory) + } + values := c.Request().Form[name] + if len(values) == 0 { + return nil, ExtractorSourceForm, errFormExtractorValueMissing + } + if len(values) > int(limit)-1 { + values = values[:limit] + } + result := append([]string{}, values...) + return result, ExtractorSourceForm, nil + } +} diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go new file mode 100644 index 000000000..04cc7b829 --- /dev/null +++ b/middleware/extractor_test.go @@ -0,0 +1,625 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package middleware + +import ( + "bytes" + "fmt" + "mime/multipart" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/labstack/echo/v5" + "github.com/stretchr/testify/assert" +) + +func TestCreateExtractors(t *testing.T) { + var testCases = []struct { + name string + givenRequest func() *http.Request + givenPathValues echo.PathValues + whenLookups string + whenLimit uint + expectValues []string + expectSource ExtractorSource + expectCreateError string + expectError string + }{ + { + name: "ok, header", + givenRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, "Bearer token") + return req + }, + whenLookups: "header:Authorization:Bearer ", + expectValues: []string{"token"}, + expectSource: ExtractorSourceHeader, + }, + { + name: "ok, form", + givenRequest: func() *http.Request { + f := make(url.Values) + f.Set("name", "Jon Snow") + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + return req + }, + whenLookups: "form:name", + expectValues: []string{"Jon Snow"}, + expectSource: ExtractorSourceForm, + }, + { + name: "ok, cookie", + givenRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderCookie, "_csrf=token") + return req + }, + whenLookups: "cookie:_csrf", + expectValues: []string{"token"}, + expectSource: ExtractorSourceCookie, + }, + { + name: "ok, param", + givenPathValues: echo.PathValues{ + {Name: "id", Value: "123"}, + }, + whenLookups: "param:id", + expectValues: []string{"123"}, + expectSource: ExtractorSourcePathParam, + }, + { + name: "ok, query", + givenRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/?id=999", nil) + return req + }, + whenLookups: "query:id", + expectValues: []string{"999"}, + expectSource: ExtractorSourceQuery, + }, + { + name: "nok, invalid lookup", + whenLookups: "query", + expectCreateError: "extractor source for lookup could not be split into needed parts: query", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenRequest != nil { + req = tc.givenRequest() + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if tc.givenPathValues != nil { + c.SetPathValues(tc.givenPathValues) + } + + extractors, err := CreateExtractors(tc.whenLookups, tc.whenLimit) + if tc.expectCreateError != "" { + assert.EqualError(t, err, tc.expectCreateError) + return + } + assert.NoError(t, err) + + for _, e := range extractors { + values, source, eErr := e(c) + assert.Equal(t, tc.expectValues, values) + assert.Equal(t, tc.expectSource, source) + if tc.expectError != "" { + assert.EqualError(t, eErr, tc.expectError) + return + } + assert.NoError(t, eErr) + } + }) + } +} + +func TestValuesFromHeader(t *testing.T) { + exampleRequest := func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") + } + + var testCases = []struct { + name string + givenRequest func(req *http.Request) + whenName string + whenValuePrefix string + whenLimit uint + expectValues []string + expectError string + }{ + { + name: "ok, single value", + givenRequest: exampleRequest, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "basic ", + expectValues: []string{"dXNlcjpwYXNzd29yZA=="}, + }, + { + name: "ok, single value, case insensitive", + givenRequest: exampleRequest, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "Basic ", + expectValues: []string{"dXNlcjpwYXNzd29yZA=="}, + }, + { + name: "ok, multiple value", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") + req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") + }, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "basic ", + whenLimit: 2, + expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"}, + }, + { + name: "ok, empty prefix", + givenRequest: exampleRequest, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "", + expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="}, + }, + { + name: "nok, no matching due different prefix", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") + req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") + }, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "Bearer ", + expectError: errHeaderExtractorValueInvalid.Error(), + }, + { + name: "nok, no matching due different prefix", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") + req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") + }, + whenName: echo.HeaderWWWAuthenticate, + whenValuePrefix: "", + expectError: errHeaderExtractorValueMissing.Error(), + }, + { + name: "nok, no headers", + givenRequest: nil, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "basic ", + expectError: errHeaderExtractorValueMissing.Error(), + }, + { + name: "ok, prefix, cut values over extractorLimit", + givenRequest: func(req *http.Request) { + for i := 1; i <= 25; i++ { + req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("basic %v", i)) + } + }, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "basic ", + whenLimit: extractorLimit, + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + { + name: "ok, cut values over extractorLimit", + givenRequest: func(req *http.Request) { + for i := 1; i <= 25; i++ { + req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("%v", i)) + } + }, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "", + whenLimit: extractorLimit, + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenRequest != nil { + tc.givenRequest(req) + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix, tc.whenLimit) + + values, source, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceHeader, source) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValuesFromQuery(t *testing.T) { + var testCases = []struct { + name string + givenQueryPart string + whenName string + whenLimit uint + expectValues []string + expectError string + }{ + { + name: "ok, single value", + givenQueryPart: "?id=123&name=test", + whenName: "id", + expectValues: []string{"123"}, + }, + { + name: "ok, multiple value", + givenQueryPart: "?id=123&id=456&name=test", + whenName: "id", + whenLimit: 2, + expectValues: []string{"123", "456"}, + }, + { + name: "nok, missing value", + givenQueryPart: "?id=123&name=test", + whenName: "nope", + expectError: errQueryExtractorValueMissing.Error(), + }, + { + name: "ok, cut values over extractorLimit", + givenQueryPart: "?name=test" + + "&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" + + "&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" + + "&id=21&id=22&id=23&id=24&id=25", + whenName: "id", + whenLimit: extractorLimit, + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + extractor := valuesFromQuery(tc.whenName, tc.whenLimit) + + values, source, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceQuery, source) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValuesFromParam(t *testing.T) { + examplePathValues := echo.PathValues{ + {Name: "id", Value: "123"}, + {Name: "gid", Value: "456"}, + {Name: "gid", Value: "789"}, + } + examplePathValues20 := make(echo.PathValues, 0) + for i := 1; i < 25; i++ { + examplePathValues20 = append(examplePathValues20, echo.PathValue{Name: "id", Value: fmt.Sprintf("%v", i)}) + } + + var testCases = []struct { + name string + givenPathValues echo.PathValues + whenName string + whenLimit uint + expectValues []string + expectError string + }{ + { + name: "ok, single value", + givenPathValues: examplePathValues, + whenName: "id", + expectValues: []string{"123"}, + }, + { + name: "ok, multiple value", + givenPathValues: examplePathValues, + whenName: "gid", + whenLimit: 2, + expectValues: []string{"456", "789"}, + }, + { + name: "nok, no values", + givenPathValues: nil, + whenName: "nope", + expectValues: nil, + expectError: errParamExtractorValueMissing.Error(), + }, + { + name: "nok, no matching value", + givenPathValues: examplePathValues, + whenName: "nope", + expectValues: nil, + expectError: errParamExtractorValueMissing.Error(), + }, + { + name: "ok, cut values over extractorLimit", + givenPathValues: examplePathValues20, + whenName: "id", + whenLimit: extractorLimit, + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if tc.givenPathValues != nil { + c.SetPathValues(tc.givenPathValues) + } + + extractor := valuesFromParam(tc.whenName, tc.whenLimit) + + values, source, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourcePathParam, source) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValuesFromCookie(t *testing.T) { + exampleRequest := func(req *http.Request) { + req.Header.Set(echo.HeaderCookie, "_csrf=token") + } + + var testCases = []struct { + name string + givenRequest func(req *http.Request) + whenName string + whenLimit uint + expectValues []string + expectError string + }{ + { + name: "ok, single value", + givenRequest: exampleRequest, + whenName: "_csrf", + expectValues: []string{"token"}, + }, + { + name: "ok, multiple value", + givenRequest: func(req *http.Request) { + req.Header.Add(echo.HeaderCookie, "_csrf=token") + req.Header.Add(echo.HeaderCookie, "_csrf=token2") + }, + whenName: "_csrf", + whenLimit: 2, + expectValues: []string{"token", "token2"}, + }, + { + name: "nok, no matching cookie", + givenRequest: exampleRequest, + whenName: "xxx", + expectValues: nil, + expectError: errCookieExtractorValueMissing.Error(), + }, + { + name: "nok, no cookies at all", + givenRequest: nil, + whenName: "xxx", + expectValues: nil, + expectError: errCookieExtractorValueMissing.Error(), + }, + { + name: "ok, cut values over extractorLimit", + givenRequest: func(req *http.Request) { + for i := 1; i < 25; i++ { + req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", i)) + } + }, + whenName: "_csrf", + whenLimit: extractorLimit, + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenRequest != nil { + tc.givenRequest(req) + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + extractor := valuesFromCookie(tc.whenName, tc.whenLimit) + + values, source, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceCookie, source) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValuesFromForm(t *testing.T) { + examplePostFormRequest := func(mod func(v *url.Values)) *http.Request { + f := make(url.Values) + f.Set("name", "Jon Snow") + f.Set("emails[]", "jon@labstack.com") + if mod != nil { + mod(&f) + } + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + + return req + } + exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request { + f := make(url.Values) + f.Set("name", "Jon Snow") + f.Set("emails[]", "jon@labstack.com") + if mod != nil { + mod(&f) + } + + req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil) + return req + } + + exampleMultiPartFormRequest := func(mod func(w *multipart.Writer)) *http.Request { + var b bytes.Buffer + w := multipart.NewWriter(&b) + w.WriteField("name", "Jon Snow") + w.WriteField("emails[]", "jon@labstack.com") + if mod != nil { + mod(w) + } + + fw, _ := w.CreateFormFile("upload", "my.file") + fw.Write([]byte(`
hi
`)) + w.Close() + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(b.String())) + req.Header.Add(echo.HeaderContentType, w.FormDataContentType()) + + return req + } + + var testCases = []struct { + name string + givenRequest *http.Request + whenName string + whenLimit uint + expectValues []string + expectError string + }{ + { + name: "ok, POST form, single value", + givenRequest: examplePostFormRequest(nil), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com"}, + }, + { + name: "ok, POST form, multiple value", + givenRequest: examplePostFormRequest(func(v *url.Values) { + v.Add("emails[]", "snow@labstack.com") + }), + whenName: "emails[]", + whenLimit: 2, + expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, + }, + { + name: "ok, POST multipart/form, multiple value", + givenRequest: exampleMultiPartFormRequest(func(w *multipart.Writer) { + w.WriteField("emails[]", "snow@labstack.com") + }), + whenName: "emails[]", + whenLimit: 2, + expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, + }, + { + name: "ok, GET form, single value", + givenRequest: exampleGetFormRequest(nil), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com"}, + }, + { + name: "ok, GET form, multiple value", + givenRequest: examplePostFormRequest(func(v *url.Values) { + v.Add("emails[]", "snow@labstack.com") + }), + whenName: "emails[]", + whenLimit: 2, + expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, + }, + { + name: "nok, POST form, value missing", + givenRequest: examplePostFormRequest(nil), + whenName: "nope", + expectError: errFormExtractorValueMissing.Error(), + }, + { + name: "ok, cut values over extractorLimit", + givenRequest: examplePostFormRequest(func(v *url.Values) { + for i := 1; i < 25; i++ { + v.Add("id[]", fmt.Sprintf("%v", i)) + } + }), + whenName: "id[]", + whenLimit: extractorLimit, + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := tc.givenRequest + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + extractor := valuesFromForm(tc.whenName, tc.whenLimit) + + values, source, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceForm, source) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/middleware/jwt.go b/middleware/jwt.go deleted file mode 100644 index 55a986327..000000000 --- a/middleware/jwt.go +++ /dev/null @@ -1,267 +0,0 @@ -package middleware - -import ( - "fmt" - "net/http" - "reflect" - "strings" - - "github.com/dgrijalva/jwt-go" - "github.com/labstack/echo/v4" -) - -type ( - // JWTConfig defines the config for JWT middleware. - JWTConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // BeforeFunc defines a function which is executed just before the middleware. - BeforeFunc BeforeFunc - - // SuccessHandler defines a function which is executed for a valid token. - SuccessHandler JWTSuccessHandler - - // ErrorHandler defines a function which is executed for an invalid token. - // It may be used to define a custom JWT error. - ErrorHandler JWTErrorHandler - - // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. - ErrorHandlerWithContext JWTErrorHandlerWithContext - - // Signing key to validate token. Used as fallback if SigningKeys has length 0. - // Required. This or SigningKeys. - SigningKey interface{} - - // Map of signing keys to validate token with kid field usage. - // Required. This or SigningKey. - SigningKeys map[string]interface{} - - // Signing method, used to check token signing method. - // Optional. Default value HS256. - SigningMethod string - - // Context key to store user information from the token into context. - // Optional. Default value "user". - ContextKey string - - // Claims are extendable claims data defining token content. - // Optional. Default value jwt.MapClaims - Claims jwt.Claims - - // TokenLookup is a string in the form of ":" that is used - // to extract token from the request. - // Optional. Default value "header:Authorization". - // Possible values: - // - "header:" - // - "query:" - // - "param:" - // - "cookie:" - TokenLookup string - - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string - - keyFunc jwt.Keyfunc - } - - // JWTSuccessHandler defines a function which is executed for a valid token. - JWTSuccessHandler func(echo.Context) - - // JWTErrorHandler defines a function which is executed for an invalid token. - JWTErrorHandler func(error) error - - // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. - JWTErrorHandlerWithContext func(error, echo.Context) error - - jwtExtractor func(echo.Context) (string, error) -) - -// Algorithms -const ( - AlgorithmHS256 = "HS256" -) - -// Errors -var ( - ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt") -) - -var ( - // DefaultJWTConfig is the default JWT auth middleware config. - DefaultJWTConfig = JWTConfig{ - Skipper: DefaultSkipper, - SigningMethod: AlgorithmHS256, - ContextKey: "user", - TokenLookup: "header:" + echo.HeaderAuthorization, - AuthScheme: "Bearer", - Claims: jwt.MapClaims{}, - } -) - -// JWT returns a JSON Web Token (JWT) auth middleware. -// -// For valid token, it sets the user in context and calls next handler. -// For invalid token, it returns "401 - Unauthorized" error. -// For missing token, it returns "400 - Bad Request" error. -// -// See: https://jwt.io/introduction -// See `JWTConfig.TokenLookup` -func JWT(key interface{}) echo.MiddlewareFunc { - c := DefaultJWTConfig - c.SigningKey = key - return JWTWithConfig(c) -} - -// JWTWithConfig returns a JWT auth middleware with config. -// See: `JWT()`. -func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { - // Defaults - if config.Skipper == nil { - config.Skipper = DefaultJWTConfig.Skipper - } - if config.SigningKey == nil && len(config.SigningKeys) == 0 { - panic("echo: jwt middleware requires signing key") - } - if config.SigningMethod == "" { - config.SigningMethod = DefaultJWTConfig.SigningMethod - } - if config.ContextKey == "" { - config.ContextKey = DefaultJWTConfig.ContextKey - } - if config.Claims == nil { - config.Claims = DefaultJWTConfig.Claims - } - if config.TokenLookup == "" { - config.TokenLookup = DefaultJWTConfig.TokenLookup - } - if config.AuthScheme == "" { - config.AuthScheme = DefaultJWTConfig.AuthScheme - } - config.keyFunc = func(t *jwt.Token) (interface{}, error) { - // Check the signing method - if t.Method.Alg() != config.SigningMethod { - return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) - } - if len(config.SigningKeys) > 0 { - if kid, ok := t.Header["kid"].(string); ok { - if key, ok := config.SigningKeys[kid]; ok { - return key, nil - } - } - return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"]) - } - - return config.SigningKey, nil - } - - // Initialize - parts := strings.Split(config.TokenLookup, ":") - extractor := jwtFromHeader(parts[1], config.AuthScheme) - switch parts[0] { - case "query": - extractor = jwtFromQuery(parts[1]) - case "param": - extractor = jwtFromParam(parts[1]) - case "cookie": - extractor = jwtFromCookie(parts[1]) - } - - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - if config.Skipper(c) { - return next(c) - } - - if config.BeforeFunc != nil { - config.BeforeFunc(c) - } - - auth, err := extractor(c) - if err != nil { - if config.ErrorHandler != nil { - return config.ErrorHandler(err) - } - - if config.ErrorHandlerWithContext != nil { - return config.ErrorHandlerWithContext(err, c) - } - return err - } - token := new(jwt.Token) - // Issue #647, #656 - if _, ok := config.Claims.(jwt.MapClaims); ok { - token, err = jwt.Parse(auth, config.keyFunc) - } else { - t := reflect.ValueOf(config.Claims).Type().Elem() - claims := reflect.New(t).Interface().(jwt.Claims) - token, err = jwt.ParseWithClaims(auth, claims, config.keyFunc) - } - if err == nil && token.Valid { - // Store user information from token into context. - c.Set(config.ContextKey, token) - if config.SuccessHandler != nil { - config.SuccessHandler(c) - } - return next(c) - } - if config.ErrorHandler != nil { - return config.ErrorHandler(err) - } - if config.ErrorHandlerWithContext != nil { - return config.ErrorHandlerWithContext(err, c) - } - return &echo.HTTPError{ - Code: http.StatusUnauthorized, - Message: "invalid or expired jwt", - Internal: err, - } - } - } -} - -// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header. -func jwtFromHeader(header string, authScheme string) jwtExtractor { - return func(c echo.Context) (string, error) { - auth := c.Request().Header.Get(header) - l := len(authScheme) - if len(auth) > l+1 && auth[:l] == authScheme { - return auth[l+1:], nil - } - return "", ErrJWTMissing - } -} - -// jwtFromQuery returns a `jwtExtractor` that extracts token from the query string. -func jwtFromQuery(param string) jwtExtractor { - return func(c echo.Context) (string, error) { - token := c.QueryParam(param) - if token == "" { - return "", ErrJWTMissing - } - return token, nil - } -} - -// jwtFromParam returns a `jwtExtractor` that extracts token from the url param string. -func jwtFromParam(param string) jwtExtractor { - return func(c echo.Context) (string, error) { - token := c.Param(param) - if token == "" { - return "", ErrJWTMissing - } - return token, nil - } -} - -// jwtFromCookie returns a `jwtExtractor` that extracts token from the named cookie. -func jwtFromCookie(name string) jwtExtractor { - return func(c echo.Context) (string, error) { - cookie, err := c.Cookie(name) - if err != nil { - return "", ErrJWTMissing - } - return cookie.Value, nil - } -} diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go deleted file mode 100644 index 7f15bd467..000000000 --- a/middleware/jwt_test.go +++ /dev/null @@ -1,329 +0,0 @@ -package middleware - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/dgrijalva/jwt-go" - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" -) - -// jwtCustomInfo defines some custom types we're going to use within our tokens. -type jwtCustomInfo struct { - Name string `json:"name"` - Admin bool `json:"admin"` -} - -// jwtCustomClaims are custom claims expanding default ones. -type jwtCustomClaims struct { - *jwt.StandardClaims - jwtCustomInfo -} - -func TestJWTRace(t *testing.T) { - e := echo.New() - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - initialToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" - raceToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IlJhY2UgQ29uZGl0aW9uIiwiYWRtaW4iOmZhbHNlfQ.Xzkx9mcgGqYMTkuxSCbJ67lsDyk5J2aB7hu65cEE-Ss" - validKey := []byte("secret") - - h := JWTWithConfig(JWTConfig{ - Claims: &jwtCustomClaims{}, - SigningKey: validKey, - })(handler) - - makeReq := func(token string) echo.Context { - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" "+token) - c := e.NewContext(req, res) - assert.NoError(t, h(c)) - return c - } - - c := makeReq(initialToken) - user := c.Get("user").(*jwt.Token) - claims := user.Claims.(*jwtCustomClaims) - assert.Equal(t, claims.Name, "John Doe") - - makeReq(raceToken) - user = c.Get("user").(*jwt.Token) - claims = user.Claims.(*jwtCustomClaims) - // Initial context should still be "John Doe", not "Race Condition" - assert.Equal(t, claims.Name, "John Doe") - assert.Equal(t, claims.Admin, true) -} - -func TestJWT(t *testing.T) { - e := echo.New() - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" - validKey := []byte("secret") - invalidKey := []byte("invalid-key") - validAuth := DefaultJWTConfig.AuthScheme + " " + token - - for _, tc := range []struct { - expPanic bool - expErrCode int // 0 for Success - config JWTConfig - reqURL string // "/" if empty - hdrAuth string - hdrCookie string // test.Request doesn't provide SetCookie(); use name=val - info string - }{ - { - expPanic: true, - info: "No signing key provided", - }, - { - expErrCode: http.StatusBadRequest, - config: JWTConfig{ - SigningKey: validKey, - SigningMethod: "RS256", - }, - info: "Unexpected signing method", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: validAuth, - config: JWTConfig{SigningKey: invalidKey}, - info: "Invalid key", - }, - { - hdrAuth: validAuth, - config: JWTConfig{SigningKey: validKey}, - info: "Valid JWT", - }, - { - hdrAuth: "Token" + " " + token, - config: JWTConfig{AuthScheme: "Token", SigningKey: validKey}, - info: "Valid JWT with custom AuthScheme", - }, - { - hdrAuth: validAuth, - config: JWTConfig{ - Claims: &jwtCustomClaims{}, - SigningKey: []byte("secret"), - }, - info: "Valid JWT with custom claims", - }, - { - hdrAuth: "invalid-auth", - expErrCode: http.StatusBadRequest, - config: JWTConfig{SigningKey: validKey}, - info: "Invalid Authorization header", - }, - { - config: JWTConfig{SigningKey: validKey}, - expErrCode: http.StatusBadRequest, - info: "Empty header auth field", - }, - { - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", - }, - reqURL: "/?a=b&jwt=" + token, - info: "Valid query method", - }, - { - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", - }, - reqURL: "/?a=b&jwtxyz=" + token, - expErrCode: http.StatusBadRequest, - info: "Invalid query param name", - }, - { - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", - }, - reqURL: "/?a=b&jwt=invalid-token", - expErrCode: http.StatusUnauthorized, - info: "Invalid query param value", - }, - { - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "query:jwt", - }, - reqURL: "/?a=b", - expErrCode: http.StatusBadRequest, - info: "Empty query", - }, - { - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "param:jwt", - }, - reqURL: "/" + token, - info: "Valid param method", - }, - { - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", - }, - hdrCookie: "jwt=" + token, - info: "Valid cookie method", - }, - { - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", - }, - expErrCode: http.StatusUnauthorized, - hdrCookie: "jwt=invalid", - info: "Invalid token with cookie method", - }, - { - config: JWTConfig{ - SigningKey: validKey, - TokenLookup: "cookie:jwt", - }, - expErrCode: http.StatusBadRequest, - info: "Empty cookie", - }, - } { - if tc.reqURL == "" { - tc.reqURL = "/" - } - - req := httptest.NewRequest(http.MethodGet, tc.reqURL, nil) - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) - req.Header.Set(echo.HeaderCookie, tc.hdrCookie) - c := e.NewContext(req, res) - - if tc.reqURL == "/" + token { - c.SetParamNames("jwt") - c.SetParamValues(token) - } - - if tc.expPanic { - assert.Panics(t, func() { - JWTWithConfig(tc.config) - }, tc.info) - continue - } - - if tc.expErrCode != 0 { - h := JWTWithConfig(tc.config)(handler) - he := h(c).(*echo.HTTPError) - assert.Equal(t, tc.expErrCode, he.Code, tc.info) - continue - } - - h := JWTWithConfig(tc.config)(handler) - if assert.NoError(t, h(c), tc.info) { - user := c.Get("user").(*jwt.Token) - switch claims := user.Claims.(type) { - case jwt.MapClaims: - assert.Equal(t, claims["name"], "John Doe", tc.info) - case *jwtCustomClaims: - assert.Equal(t, claims.Name, "John Doe", tc.info) - assert.Equal(t, claims.Admin, true, tc.info) - default: - panic("unexpected type of claims") - } - } - } -} - -func TestJWTwithKID(t *testing.T) { - test := assert.New(t) - - e := echo.New() - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - firstToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6ImZpcnN0T25lIn0.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.w5VGpHOe0jlNgf7jMVLHzIYH_XULmpUlreJnilwSkWk" - secondToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.sdghDYQ85jdh0hgQ6bKbMguLI_NSPYWjkhVJkee-yZM" - wrongToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiIsImtpZCI6InNlY29uZE9uZSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.RyhLybtVLpoewF6nz9YN79oXo32kAtgUxp8FNwTkb90" - staticToken := "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.1_-XFYUPpJfgsaGwYhgZEt7hfySMg-a3GN-nfZmbW7o" - validKeys := map[string]interface{}{"firstOne": []byte("first_secret"), "secondOne": []byte("second_secret")} - invalidKeys := map[string]interface{}{"thirdOne": []byte("third_secret")} - staticSecret := []byte("static_secret") - invalidStaticSecret := []byte("invalid_secret") - - for _, tc := range []struct { - expErrCode int // 0 for Success - config JWTConfig - hdrAuth string - info string - }{ - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "First token valid", - }, - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "Second token valid", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + wrongToken, - config: JWTConfig{SigningKeys: validKeys}, - info: "Wrong key id token", - }, - { - hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken, - config: JWTConfig{SigningKey: staticSecret}, - info: "Valid static secret token", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + staticToken, - config: JWTConfig{SigningKey: invalidStaticSecret}, - info: "Invalid static secret", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + firstToken, - config: JWTConfig{SigningKeys: invalidKeys}, - info: "Invalid keys first token", - }, - { - expErrCode: http.StatusUnauthorized, - hdrAuth: DefaultJWTConfig.AuthScheme + " " + secondToken, - config: JWTConfig{SigningKeys: invalidKeys}, - info: "Invalid keys second token", - }, - } { - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) - c := e.NewContext(req, res) - - if tc.expErrCode != 0 { - h := JWTWithConfig(tc.config)(handler) - he := h(c).(*echo.HTTPError) - test.Equal(tc.expErrCode, he.Code, tc.info) - continue - } - - h := JWTWithConfig(tc.config)(handler) - if test.NoError(h(c), tc.info) { - user := c.Get("user").(*jwt.Token) - switch claims := user.Claims.(type) { - case jwt.MapClaims: - test.Equal(claims["name"], "John Doe", tc.info) - case *jwtCustomClaims: - test.Equal(claims.Name, "John Doe", tc.info) - test.Equal(claims.Admin, true, tc.info) - default: - panic("unexpected type of claims") - } - } - } -} diff --git a/middleware/key_auth.go b/middleware/key_auth.go index 94cfd1429..e14bd9e2e 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -1,51 +1,118 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( + "cmp" "errors" + "fmt" "net/http" - "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // KeyAuthConfig defines the config for KeyAuth middleware. - KeyAuthConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // KeyLookup is a string in the form of ":" that is used - // to extract key from the request. - // Optional. Default value "header:Authorization". - // Possible values: - // - "header:" - // - "query:" - // - "form:" - KeyLookup string `yaml:"key_lookup"` - - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string - - // Validator is a function to validate key. - // Required. - Validator KeyAuthValidator - } - - // KeyAuthValidator defines a function to validate KeyAuth credentials. - KeyAuthValidator func(string, echo.Context) (bool, error) - - keyExtractor func(echo.Context) (string, error) -) +// KeyAuthConfig defines the config for KeyAuth middleware. +// +// SECURITY: The Validator function is responsible for securely comparing API keys. +// See KeyAuthValidator documentation for guidance on preventing timing attacks. +type KeyAuthConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // KeyLookup is a string in the form of ":" or ":,:" that is used + // to extract key from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" or "header::" + // `` is argument value to cut/trim prefix of the extracted value. This is useful if header + // value has static prefix like `Authorization: ` where part that we + // want to cut is ` ` note the space at the end. + // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. + // - "query:" + // - "form:" + // - "cookie:" + // Multiple sources example: + // - "header:Authorization,header:X-Api-Key" + KeyLookup string + + // AllowedCheckLimit set how many KeyLookup values are allowed to be checked. This is + // useful environments like corporate test environments with application proxies restricting + // access to environment with their own auth scheme. + AllowedCheckLimit uint + + // Validator is a function to validate key. + // Required. + Validator KeyAuthValidator + + // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator + // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key. + // It may be used to define a custom error. + // + // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler. + // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users + // In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain. + ErrorHandler KeyAuthErrorHandler + + // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to + // ignore the error (by returning `nil`). + // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. + // In that case you can use ErrorHandler to set a default public key auth value in the request context + // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. + ContinueOnIgnoredError bool +} -var ( - // DefaultKeyAuthConfig is the default KeyAuth middleware config. - DefaultKeyAuthConfig = KeyAuthConfig{ - Skipper: DefaultSkipper, - KeyLookup: "header:" + echo.HeaderAuthorization, - AuthScheme: "Bearer", - } -) +// KeyAuthValidator defines a function to validate KeyAuth credentials. +// +// SECURITY WARNING: To prevent timing attacks that could allow attackers to enumerate +// valid API keys, validator implementations MUST use constant-time comparison. +// Use crypto/subtle.ConstantTimeCompare instead of standard string equality (==) +// or switch statements. +// +// Example of SECURE implementation: +// +// import "crypto/subtle" +// +// validator := func(c *echo.Context, key string, source ExtractorSource) (bool, error) { +// // Fetch valid keys from database/config +// validKeys := []string{"key1", "key2", "key3"} +// +// for _, validKey := range validKeys { +// // Use constant-time comparison to prevent timing attacks +// if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 { +// return true, nil +// } +// } +// return false, nil +// } +// +// Example of INSECURE implementation (DO NOT USE): +// +// // VULNERABLE TO TIMING ATTACKS - DO NOT USE +// validator := func(c *echo.Context, key string, source ExtractorSource) (bool, error) { +// switch key { // Timing leak! +// case "valid-key": +// return true, nil +// default: +// return false, nil +// } +// } +type KeyAuthValidator func(c *echo.Context, key string, source ExtractorSource) (bool, error) + +// KeyAuthErrorHandler defines a function which is executed for an invalid key. +type KeyAuthErrorHandler func(c *echo.Context, err error) error + +// ErrKeyMissing denotes an error raised when key value could not be extracted from request +var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key") + +// ErrInvalidKey denotes an error raised when key value is invalid by validator +var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key") + +// DefaultKeyAuthConfig is the default KeyAuth middleware config. +var DefaultKeyAuthConfig = KeyAuthConfig{ + Skipper: DefaultSkipper, + KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ", +} // KeyAuth returns an KeyAuth middleware. // @@ -58,96 +125,81 @@ func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc { return KeyAuthWithConfig(c) } -// KeyAuthWithConfig returns an KeyAuth middleware with config. -// See `KeyAuth()`. +// KeyAuthWithConfig returns an KeyAuth middleware or panics if configuration is invalid. +// +// For first valid key it calls the next handler. +// For invalid key, it sends "401 - Unauthorized" response. +// For missing key, it sends "400 - Bad Request" response. func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration +func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultKeyAuthConfig.Skipper } - // Defaults - if config.AuthScheme == "" { - config.AuthScheme = DefaultKeyAuthConfig.AuthScheme - } if config.KeyLookup == "" { config.KeyLookup = DefaultKeyAuthConfig.KeyLookup } if config.Validator == nil { - panic("echo: key-auth middleware requires a validator function") + return nil, errors.New("echo key-auth middleware requires a validator function") } - // Initialize - parts := strings.Split(config.KeyLookup, ":") - extractor := keyFromHeader(parts[1], config.AuthScheme) - switch parts[0] { - case "query": - extractor = keyFromQuery(parts[1]) - case "form": - extractor = keyFromForm(parts[1]) + limit := cmp.Or(config.AllowedCheckLimit, 1) + + extractors, cErr := createExtractors(config.KeyLookup, limit) + if cErr != nil { + return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", cErr) + } + if len(extractors) == 0 { + return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string") } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } - // Extract and verify key - key, err := extractor(c) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - valid, err := config.Validator(key, c) - if err != nil { - return &echo.HTTPError{ - Code: http.StatusUnauthorized, - Message: "invalid key", - Internal: err, + var lastExtractorErr error + var lastValidatorErr error + for _, extractor := range extractors { + keys, source, extrErr := extractor(c) + if extrErr != nil { + lastExtractorErr = extrErr + continue + } + for _, key := range keys { + valid, err := config.Validator(c, key, source) + if err != nil { + lastValidatorErr = err + continue + } + if !valid { + lastValidatorErr = ErrInvalidKey + continue + } + return next(c) } - } else if valid { - return next(c) } - return echo.ErrUnauthorized - } - } -} -// keyFromHeader returns a `keyExtractor` that extracts key from the request header. -func keyFromHeader(header string, authScheme string) keyExtractor { - return func(c echo.Context) (string, error) { - auth := c.Request().Header.Get(header) - if auth == "" { - return "", errors.New("missing key in request header") - } - if header == echo.HeaderAuthorization { - l := len(authScheme) - if len(auth) > l+1 && auth[:l] == authScheme { - return auth[l+1:], nil + // prioritize validator errors over extracting errors + err := lastValidatorErr + if err == nil { + err = lastExtractorErr } - return "", errors.New("invalid key in the request header") - } - return auth, nil - } -} - -// keyFromQuery returns a `keyExtractor` that extracts key from the query string. -func keyFromQuery(param string) keyExtractor { - return func(c echo.Context) (string, error) { - key := c.QueryParam(param) - if key == "" { - return "", errors.New("missing key in the query string") - } - return key, nil - } -} - -// keyFromForm returns a `keyExtractor` that extracts key from the form. -func keyFromForm(param string) keyExtractor { - return func(c echo.Context) (string, error) { - key := c.FormValue(param) - if key == "" { - return "", errors.New("missing key in the form") + if config.ErrorHandler != nil { + tmpErr := config.ErrorHandler(c, err) + if config.ContinueOnIgnoredError && tmpErr == nil { + return next(c) + } + return tmpErr + } + if lastValidatorErr == nil { + return ErrKeyMissing.Wrap(err) + } + return echo.ErrUnauthorized.Wrap(err) } - return key, nil - } + }, nil } diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go index b874898c8..49a917ed3 100644 --- a/middleware/key_auth_test.go +++ b/middleware/key_auth_test.go @@ -1,75 +1,375 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( + "crypto/subtle" + "errors" "net/http" "net/http/httptest" - "net/url" "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) +func testKeyValidator(c *echo.Context, key string, source ExtractorSource) (bool, error) { + // Use constant-time comparison to prevent timing attacks + if subtle.ConstantTimeCompare([]byte(key), []byte("valid-key")) == 1 { + return true, nil + } + + // Special case for testing error handling + if key == "error-key" { // Error path doesn't need constant-time + return false, errors.New("some user defined error") + } + + return false, nil +} + func TestKeyAuth(t *testing.T) { + handlerCalled := false + handler := func(c *echo.Context) error { + handlerCalled = true + return c.String(http.StatusOK, "test") + } + middlewareChain := KeyAuth(testKeyValidator)(handler) + e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key") rec := httptest.NewRecorder() c := e.NewContext(req, rec) - config := KeyAuthConfig{ - Validator: func(key string, c echo.Context) (bool, error) { - return key == "valid-key", nil + + err := middlewareChain(c) + + assert.NoError(t, err) + assert.True(t, handlerCalled) +} + +func TestKeyAuthWithConfig(t *testing.T) { + var testCases = []struct { + name string + givenRequestFunc func() *http.Request + givenRequest func(req *http.Request) + whenConfig func(conf *KeyAuthConfig) + expectHandlerCalled bool + expectError string + }{ + { + name: "ok, defaults, key from header", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key") + }, + expectHandlerCalled: true, + }, + { + name: "ok, custom skipper", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.Skipper = func(context *echo.Context) bool { + return true + } + }, + expectHandlerCalled: true, + }, + { + name: "nok, defaults, invalid key from header, Authorization: Bearer", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key") + }, + expectHandlerCalled: false, + expectError: "code=401, message=Unauthorized, err=code=401, message=invalid key", + }, + { + name: "nok, defaults, invalid scheme in header", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "Bear valid-key") + }, + expectHandlerCalled: false, + expectError: "code=401, message=missing key, err=invalid value in request header", + }, + { + name: "nok, defaults, missing header", + givenRequest: func(req *http.Request) {}, + expectHandlerCalled: false, + expectError: "code=401, message=missing key, err=missing value in request header", + }, + { + name: "ok, custom key lookup, header", + givenRequest: func(req *http.Request) { + req.Header.Set("API-Key", "valid-key") + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "header:API-Key" + }, + expectHandlerCalled: true, + }, + { + name: "nok, custom key lookup, missing header", + givenRequest: func(req *http.Request) { + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "header:API-Key" + }, + expectHandlerCalled: false, + expectError: "code=401, message=missing key, err=missing value in request header", + }, + { + name: "ok, custom key lookup, query", + givenRequest: func(req *http.Request) { + q := req.URL.Query() + q.Add("key", "valid-key") + req.URL.RawQuery = q.Encode() + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "query:key" + }, + expectHandlerCalled: true, + }, + { + name: "nok, custom key lookup, missing query param", + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "query:key" + }, + expectHandlerCalled: false, + expectError: "code=401, message=missing key, err=missing value in the query string", + }, + { + name: "ok, custom key lookup, form", + givenRequestFunc: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("key=valid-key")) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) + return req + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "form:key" + }, + expectHandlerCalled: true, + }, + { + name: "nok, custom key lookup, missing key in form", + givenRequestFunc: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("xxx=valid-key")) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) + return req + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "form:key" + }, + expectHandlerCalled: false, + expectError: "code=401, message=missing key, err=missing value in the form", + }, + { + name: "ok, custom key lookup, cookie", + givenRequest: func(req *http.Request) { + req.AddCookie(&http.Cookie{ + Name: "key", + Value: "valid-key", + }) + q := req.URL.Query() + q.Add("key", "valid-key") + req.URL.RawQuery = q.Encode() + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "cookie:key" + }, + expectHandlerCalled: true, + }, + { + name: "nok, custom key lookup, missing cookie param", + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "cookie:key" + }, + expectHandlerCalled: false, + expectError: "code=401, message=missing key, err=missing value in cookies", + }, + { + name: "nok, custom errorHandler, error from extractor", + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "header:token" + conf.ErrorHandler = func(c *echo.Context, err error) error { + return echo.NewHTTPError(http.StatusTeapot, "custom").Wrap(err) + } + }, + expectHandlerCalled: false, + expectError: "code=418, message=custom, err=missing value in request header", + }, + { + name: "nok, custom errorHandler, error from validator", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.ErrorHandler = func(c *echo.Context, err error) error { + return echo.NewHTTPError(http.StatusTeapot, "custom").Wrap(err) + } + }, + expectHandlerCalled: false, + expectError: "code=418, message=custom, err=some user defined error", + }, + { + name: "nok, defaults, error from validator", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") + }, + whenConfig: func(conf *KeyAuthConfig) {}, + expectHandlerCalled: false, + expectError: "code=401, message=Unauthorized, err=some user defined error", + }, + { + name: "ok, custom validator checks source", + givenRequest: func(req *http.Request) { + q := req.URL.Query() + q.Add("key", "valid-key") + req.URL.RawQuery = q.Encode() + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "query:key" + conf.Validator = func(c *echo.Context, key string, source ExtractorSource) (bool, error) { + if source == ExtractorSourceQuery { + return true, nil + } + return false, errors.New("invalid source") + } + + }, + expectHandlerCalled: true, }, } - h := KeyAuthWithConfig(config)(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - assert := assert.New(t) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + handlerCalled := false + handler := func(c *echo.Context) error { + handlerCalled = true + return c.String(http.StatusOK, "test") + } + config := KeyAuthConfig{ + Validator: testKeyValidator, + } + if tc.whenConfig != nil { + tc.whenConfig(&config) + } + middlewareChain := KeyAuthWithConfig(config)(handler) - // Valid key - auth := DefaultKeyAuthConfig.AuthScheme + " " + "valid-key" - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(h(c)) + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenRequestFunc != nil { + req = tc.givenRequestFunc() + } + if tc.givenRequest != nil { + tc.givenRequest(req) + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - // Invalid key - auth = DefaultKeyAuthConfig.AuthScheme + " " + "invalid-key" - req.Header.Set(echo.HeaderAuthorization, auth) - he := h(c).(*echo.HTTPError) - assert.Equal(http.StatusUnauthorized, he.Code) + err := middlewareChain(c) - // Missing Authorization header - req.Header.Del(echo.HeaderAuthorization) - he = h(c).(*echo.HTTPError) - assert.Equal(http.StatusBadRequest, he.Code) + assert.Equal(t, tc.expectHandlerCalled, handlerCalled) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} - // Key from custom header - config.KeyLookup = "header:API-Key" - h = KeyAuthWithConfig(config)(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - req.Header.Set("API-Key", "valid-key") - assert.NoError(h(c)) +func TestKeyAuthWithConfig_errors(t *testing.T) { + var testCases = []struct { + name string + whenConfig KeyAuthConfig + expectError string + }{ + { + name: "ok, no error", + whenConfig: KeyAuthConfig{ + Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) { + return false, nil + }, + }, + }, + { + name: "ok, missing validator func", + whenConfig: KeyAuthConfig{ + Validator: nil, + }, + expectError: "echo key-auth middleware requires a validator function", + }, + { + name: "ok, extractor source can not be split", + whenConfig: KeyAuthConfig{ + KeyLookup: "nope", + Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) { + return false, nil + }, + }, + expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope", + }, + { + name: "ok, no extractors", + whenConfig: KeyAuthConfig{ + KeyLookup: "nope:nope", + Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) { + return false, nil + }, + }, + expectError: "echo key-auth middleware could not create extractors from KeyLookup string", + }, + } - // Key from query string - config.KeyLookup = "query:key" - h = KeyAuthWithConfig(config)(func(c echo.Context) error { - return c.String(http.StatusOK, "test") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mw, err := tc.whenConfig.ToMiddleware() + if tc.expectError != "" { + assert.Nil(t, mw) + assert.EqualError(t, err, tc.expectError) + } else { + assert.NotNil(t, mw) + assert.NoError(t, err) + } + }) + } +} + +func TestMustKeyAuthWithConfig_panic(t *testing.T) { + assert.Panics(t, func() { + KeyAuthWithConfig(KeyAuthConfig{}) }) - q := req.URL.Query() - q.Add("key", "valid-key") - req.URL.RawQuery = q.Encode() - assert.NoError(h(c)) - - // Key from form - config.KeyLookup = "form:key" - h = KeyAuthWithConfig(config)(func(c echo.Context) error { +} + +func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) { + handlerCalled := false + var authValue string + handler := func(c *echo.Context) error { + handlerCalled = true + authValue = c.Get("auth").(string) return c.String(http.StatusOK, "test") - }) - f := make(url.Values) - f.Set("key", "valid-key") - req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) - req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) - c = e.NewContext(req, rec) - assert.NoError(h(c)) + } + middlewareChain := KeyAuthWithConfig(KeyAuthConfig{ + Validator: testKeyValidator, + ErrorHandler: func(c *echo.Context, err error) error { + // could check error to decide if we can swallow the error + c.Set("auth", "public") + return nil + }, + ContinueOnIgnoredError: true, + })(handler) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + // no auth header this time + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := middlewareChain(c) + + assert.NoError(t, err) + assert.True(t, handlerCalled) + assert.Equal(t, "public", authValue) } diff --git a/middleware/logger.go b/middleware/logger.go deleted file mode 100644 index 9baac4769..000000000 --- a/middleware/logger.go +++ /dev/null @@ -1,223 +0,0 @@ -package middleware - -import ( - "bytes" - "encoding/json" - "io" - "strconv" - "strings" - "sync" - "time" - - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/color" - "github.com/valyala/fasttemplate" -) - -type ( - // LoggerConfig defines the config for Logger middleware. - LoggerConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Tags to construct the logger format. - // - // - time_unix - // - time_unix_nano - // - time_rfc3339 - // - time_rfc3339_nano - // - time_custom - // - id (Request ID) - // - remote_ip - // - uri - // - host - // - method - // - path - // - protocol - // - referer - // - user_agent - // - status - // - error - // - latency (In nanoseconds) - // - latency_human (Human readable) - // - bytes_in (Bytes received) - // - bytes_out (Bytes sent) - // - header: - // - query: - // - form: - // - // Example "${remote_ip} ${status}" - // - // Optional. Default value DefaultLoggerConfig.Format. - Format string `yaml:"format"` - - // Optional. Default value DefaultLoggerConfig.CustomTimeFormat. - CustomTimeFormat string `yaml:"custom_time_format"` - - // Output is a writer where logs in JSON format are written. - // Optional. Default value os.Stdout. - Output io.Writer - - template *fasttemplate.Template - colorer *color.Color - pool *sync.Pool - } -) - -var ( - // DefaultLoggerConfig is the default Logger middleware config. - DefaultLoggerConfig = LoggerConfig{ - Skipper: DefaultSkipper, - Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` + - `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + - `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + - `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", - CustomTimeFormat: "2006-01-02 15:04:05.00000", - colorer: color.New(), - } -) - -// Logger returns a middleware that logs HTTP requests. -func Logger() echo.MiddlewareFunc { - return LoggerWithConfig(DefaultLoggerConfig) -} - -// LoggerWithConfig returns a Logger middleware with config. -// See: `Logger()`. -func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { - // Defaults - if config.Skipper == nil { - config.Skipper = DefaultLoggerConfig.Skipper - } - if config.Format == "" { - config.Format = DefaultLoggerConfig.Format - } - if config.Output == nil { - config.Output = DefaultLoggerConfig.Output - } - - config.template = fasttemplate.New(config.Format, "${", "}") - config.colorer = color.New() - config.colorer.SetOutput(config.Output) - config.pool = &sync.Pool{ - New: func() interface{} { - return bytes.NewBuffer(make([]byte, 256)) - }, - } - - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { - if config.Skipper(c) { - return next(c) - } - - req := c.Request() - res := c.Response() - start := time.Now() - if err = next(c); err != nil { - c.Error(err) - } - stop := time.Now() - buf := config.pool.Get().(*bytes.Buffer) - buf.Reset() - defer config.pool.Put(buf) - - if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { - switch tag { - case "time_unix": - return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10)) - case "time_unix_nano": - return buf.WriteString(strconv.FormatInt(time.Now().UnixNano(), 10)) - case "time_rfc3339": - return buf.WriteString(time.Now().Format(time.RFC3339)) - case "time_rfc3339_nano": - return buf.WriteString(time.Now().Format(time.RFC3339Nano)) - case "time_custom": - return buf.WriteString(time.Now().Format(config.CustomTimeFormat)) - case "id": - id := req.Header.Get(echo.HeaderXRequestID) - if id == "" { - id = res.Header().Get(echo.HeaderXRequestID) - } - return buf.WriteString(id) - case "remote_ip": - return buf.WriteString(c.RealIP()) - case "host": - return buf.WriteString(req.Host) - case "uri": - return buf.WriteString(req.RequestURI) - case "method": - return buf.WriteString(req.Method) - case "path": - p := req.URL.Path - if p == "" { - p = "/" - } - return buf.WriteString(p) - case "protocol": - return buf.WriteString(req.Proto) - case "referer": - return buf.WriteString(req.Referer()) - case "user_agent": - return buf.WriteString(req.UserAgent()) - case "status": - n := res.Status - s := config.colorer.Green(n) - switch { - case n >= 500: - s = config.colorer.Red(n) - case n >= 400: - s = config.colorer.Yellow(n) - case n >= 300: - s = config.colorer.Cyan(n) - } - return buf.WriteString(s) - case "error": - if err != nil { - // Error may contain invalid JSON e.g. `"` - b, _ := json.Marshal(err.Error()) - b = b[1 : len(b)-1] - return buf.Write(b) - } - case "latency": - l := stop.Sub(start) - return buf.WriteString(strconv.FormatInt(int64(l), 10)) - case "latency_human": - return buf.WriteString(stop.Sub(start).String()) - case "bytes_in": - cl := req.Header.Get(echo.HeaderContentLength) - if cl == "" { - cl = "0" - } - return buf.WriteString(cl) - case "bytes_out": - return buf.WriteString(strconv.FormatInt(res.Size, 10)) - default: - switch { - case strings.HasPrefix(tag, "header:"): - return buf.Write([]byte(c.Request().Header.Get(tag[7:]))) - case strings.HasPrefix(tag, "query:"): - return buf.Write([]byte(c.QueryParam(tag[6:]))) - case strings.HasPrefix(tag, "form:"): - return buf.Write([]byte(c.FormValue(tag[5:]))) - case strings.HasPrefix(tag, "cookie:"): - cookie, err := c.Cookie(tag[7:]) - if err == nil { - return buf.Write([]byte(cookie.Value)) - } - } - } - return 0, nil - }); err != nil { - return - } - - if config.Output == nil { - _, err = c.Logger().Output().Write(buf.Bytes()) - return - } - _, err = config.Output.Write(buf.Bytes()) - return - } - } -} diff --git a/middleware/logger_test.go b/middleware/logger_test.go deleted file mode 100644 index b196bc6c8..000000000 --- a/middleware/logger_test.go +++ /dev/null @@ -1,173 +0,0 @@ -package middleware - -import ( - "bytes" - "encoding/json" - "errors" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - "time" - "unsafe" - - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" -) - -func TestLogger(t *testing.T) { - // Note: Just for the test coverage, not a real test. - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := Logger()(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - - // Status 2xx - h(c) - - // Status 3xx - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = Logger()(func(c echo.Context) error { - return c.String(http.StatusTemporaryRedirect, "test") - }) - h(c) - - // Status 4xx - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = Logger()(func(c echo.Context) error { - return c.String(http.StatusNotFound, "test") - }) - h(c) - - // Status 5xx with empty path - req = httptest.NewRequest(http.MethodGet, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = Logger()(func(c echo.Context) error { - return errors.New("error") - }) - h(c) -} - -func TestLoggerIPAddress(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) - ip := "127.0.0.1" - h := Logger()(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - - // With X-Real-IP - req.Header.Add(echo.HeaderXRealIP, ip) - h(c) - assert.Contains(t, buf.String(), ip) - - // With X-Forwarded-For - buf.Reset() - req.Header.Del(echo.HeaderXRealIP) - req.Header.Add(echo.HeaderXForwardedFor, ip) - h(c) - assert.Contains(t, buf.String(), ip) - - buf.Reset() - h(c) - assert.Contains(t, buf.String(), ip) -} - -func TestLoggerTemplate(t *testing.T) { - buf := new(bytes.Buffer) - - e := echo.New() - e.Use(LoggerWithConfig(LoggerConfig{ - Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` + - `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + - `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` + - `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}", "protocol":"${protocol}"` + - `"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n", - Output: buf, - })) - - e.GET("/", func(c echo.Context) error { - return c.String(http.StatusOK, "Header Logged") - }) - - req := httptest.NewRequest(http.MethodGet, "/?username=apagano-param&password=secret", nil) - req.RequestURI = "/" - req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") - req.Header.Add("Referer", "google.com") - req.Header.Add("User-Agent", "echo-tests-agent") - req.Header.Add("X-Custom-Header", "AAA-CUSTOM-VALUE") - req.Header.Add("X-Request-ID", "6ba7b810-9dad-11d1-80b4-00c04fd430c8") - req.Header.Add("Cookie", "_ga=GA1.2.000000000.0000000000; session=ac08034cd216a647fc2eb62f2bcf7b810") - req.Form = url.Values{ - "username": []string{"apagano-form"}, - "password": []string{"secret-form"}, - } - - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - - cases := map[string]bool{ - "apagano-param": true, - "apagano-form": true, - "AAA-CUSTOM-VALUE": true, - "BBB-CUSTOM-VALUE": false, - "secret-form": false, - "hexvalue": false, - "GET": true, - "127.0.0.1": true, - "\"path\":\"/\"": true, - "\"uri\":\"/\"": true, - "\"status\":200": true, - "\"bytes_in\":0": true, - "google.com": true, - "echo-tests-agent": true, - "6ba7b810-9dad-11d1-80b4-00c04fd430c8": true, - "ac08034cd216a647fc2eb62f2bcf7b810": true, - } - - for token, present := range cases { - assert.True(t, strings.Contains(buf.String(), token) == present, "Case: "+token) - } -} - -func TestLoggerCustomTimestamp(t *testing.T) { - buf := new(bytes.Buffer) - customTimeFormat := "2006-01-02 15:04:05.00000" - e := echo.New() - e.Use(LoggerWithConfig(LoggerConfig{ - Format: `{"time":"${time_custom}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` + - `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + - `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` + - `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}",` + - `"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n", - CustomTimeFormat: customTimeFormat, - Output: buf, - })) - - e.GET("/", func(c echo.Context) error { - return c.String(http.StatusOK, "custom time stamp test") - }) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - - var objs map[string]*json.RawMessage - if err := json.Unmarshal([]byte(buf.String()), &objs); err != nil { - panic(err) - } - loggedTime := *(*string)(unsafe.Pointer(objs["time"])) - _, err := time.Parse(customTimeFormat, loggedTime) - assert.Error(t, err) -} diff --git a/middleware/method_override.go b/middleware/method_override.go index 92b14d2ed..25ec1f935 100644 --- a/middleware/method_override.go +++ b/middleware/method_override.go @@ -1,33 +1,32 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( "net/http" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // MethodOverrideConfig defines the config for MethodOverride middleware. - MethodOverrideConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// MethodOverrideConfig defines the config for MethodOverride middleware. +type MethodOverrideConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Getter is a function that gets overridden method from the request. - // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). - Getter MethodOverrideGetter - } + // Getter is a function that gets overridden method from the request. + // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). + Getter MethodOverrideGetter +} - // MethodOverrideGetter is a function that gets overridden method from the request - MethodOverrideGetter func(echo.Context) string -) +// MethodOverrideGetter is a function that gets overridden method from the request +type MethodOverrideGetter func(c *echo.Context) string -var ( - // DefaultMethodOverrideConfig is the default MethodOverride middleware config. - DefaultMethodOverrideConfig = MethodOverrideConfig{ - Skipper: DefaultSkipper, - Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), - } -) +// DefaultMethodOverrideConfig is the default MethodOverride middleware config. +var DefaultMethodOverrideConfig = MethodOverrideConfig{ + Skipper: DefaultSkipper, + Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), +} // MethodOverride returns a MethodOverride middleware. // MethodOverride middleware checks for the overridden method from the request and @@ -38,9 +37,13 @@ func MethodOverride() echo.MiddlewareFunc { return MethodOverrideWithConfig(DefaultMethodOverrideConfig) } -// MethodOverrideWithConfig returns a MethodOverride middleware with config. -// See: `MethodOverride()`. +// MethodOverrideWithConfig returns a Method Override middleware with config or panics on invalid configuration. func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration +func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultMethodOverrideConfig.Skipper @@ -50,7 +53,7 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } @@ -64,13 +67,13 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { } return next(c) } - } + }, nil } // MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from // the request header. func MethodFromHeader(header string) MethodOverrideGetter { - return func(c echo.Context) string { + return func(c *echo.Context) string { return c.Request().Header.Get(header) } } @@ -78,7 +81,7 @@ func MethodFromHeader(header string) MethodOverrideGetter { // MethodFromForm is a `MethodOverrideGetter` that gets overridden method from the // form parameter. func MethodFromForm(param string) MethodOverrideGetter { - return func(c echo.Context) string { + return func(c *echo.Context) string { return c.FormValue(param) } } @@ -86,7 +89,7 @@ func MethodFromForm(param string) MethodOverrideGetter { // MethodFromQuery is a `MethodOverrideGetter` that gets overridden method from // the query parameter. func MethodFromQuery(param string) MethodOverrideGetter { - return func(c echo.Context) string { + return func(c *echo.Context) string { return c.QueryParam(param) } } diff --git a/middleware/method_override_test.go b/middleware/method_override_test.go index 5760b1581..525ad10ba 100644 --- a/middleware/method_override_test.go +++ b/middleware/method_override_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -6,14 +9,14 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestMethodOverride(t *testing.T) { e := echo.New() m := MethodOverride() - h := func(c echo.Context) error { + h := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } @@ -22,28 +25,68 @@ func TestMethodOverride(t *testing.T) { rec := httptest.NewRecorder() req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) c := e.NewContext(req, rec) - m(h)(c) + + err := m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_formParam(t *testing.T) { + e := echo.New() + h := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + // Override with form parameter - m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")}) - req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) - rec = httptest.NewRecorder() + m, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware() + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) + rec := httptest.NewRecorder() req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) - c = e.NewContext(req, rec) - m(h)(c) + c := e.NewContext(req, rec) + + err = m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_queryParam(t *testing.T) { + e := echo.New() + h := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } // Override with query parameter - m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")}) - req = httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - m(h)(c) + m, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware() + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err = m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_ignoreGet(t *testing.T) { + e := echo.New() + m := MethodOverride() + h := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } // Ignore `GET` - req = httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodGet, req.Method) } diff --git a/middleware/middleware.go b/middleware/middleware.go index d0b7153cb..4562d03b5 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -1,21 +1,22 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( + "net/http" "regexp" "strconv" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // Skipper defines a function to skip middleware. Returning true skips processing - // the middleware. - Skipper func(echo.Context) bool +// Skipper defines a function to skip middleware. Returning true skips processing the middleware. +type Skipper func(c *echo.Context) bool - // BeforeFunc defines a function which is executed just before the middleware. - BeforeFunc func(echo.Context) -) +// BeforeFunc defines a function which is executed just before the middleware. +type BeforeFunc func(c *echo.Context) func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { groups := pattern.FindAllStringSubmatch(input, -1) @@ -32,7 +33,65 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { return strings.NewReplacer(replace...) } +func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string { + // Initialize + rulesRegex := map[*regexp.Regexp]string{} + for k, v := range rewrite { + k = regexp.QuoteMeta(k) + k = strings.ReplaceAll(k, `\*`, "(.*?)") + if strings.HasPrefix(k, `\^`) { + k = strings.ReplaceAll(k, `\^`, "^") + } + k = k + "$" + rulesRegex[regexp.MustCompile(k)] = v + } + return rulesRegex +} + +func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error { + if len(rewriteRegex) == 0 { + return nil + } + + // Depending how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path. + // We only want to use path part for rewriting and therefore trim prefix if it exists + rawURI := req.RequestURI + if rawURI != "" && rawURI[0] != '/' { + prefix := "" + if req.URL.Scheme != "" { + prefix = req.URL.Scheme + "://" + } + if req.URL.Host != "" { + prefix += req.URL.Host // host or host:port + } + if prefix != "" { + rawURI = strings.TrimPrefix(rawURI, prefix) + } + } + + for k, v := range rewriteRegex { + if replacer := captureTokens(k, rawURI); replacer != nil { + url, err := req.URL.Parse(replacer.Replace(v)) + if err != nil { + return err + } + req.URL = url + + return nil // rewrite only once + } + } + return nil +} + // DefaultSkipper returns false which processes the middleware. -func DefaultSkipper(echo.Context) bool { +func DefaultSkipper(c *echo.Context) bool { return false } + +func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc { + mw, err := config.ToMiddleware() + if err != nil { + panic(err) + } + return mw +} diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go new file mode 100644 index 000000000..28407ed5c --- /dev/null +++ b/middleware/middleware_test.go @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package middleware + +import ( + "bufio" + "errors" + "github.com/stretchr/testify/assert" + "net" + "net/http" + "net/http/httptest" + "regexp" + "testing" +) + +func TestRewriteURL(t *testing.T) { + var testCases = []struct { + whenURL string + expectPath string + expectRawPath string + expectQuery string + expectErr string + }{ + { + whenURL: "http://localhost:8080/old", + expectPath: "/new", + expectRawPath: "", + }, + { // encoded `ol%64` (decoded `old`) should not be rewritten to `/new` + whenURL: "/ol%64", // `%64` is decoded `d` + expectPath: "/old", + expectRawPath: "/ol%64", + }, + { + whenURL: "http://localhost:8080/users/+_+/orders/___++++?test=1", + expectPath: "/user/+_+/order/___++++", + expectRawPath: "", + expectQuery: "test=1", + }, + { + whenURL: "http://localhost:8080/users/%20a/orders/%20aa", + expectPath: "/user/ a/order/ aa", + expectRawPath: "", + }, + { + whenURL: "http://localhost:8080/%47%6f%2f?test=1", + expectPath: "/Go/", + expectRawPath: "/%47%6f%2f", + expectQuery: "test=1", + }, + { + whenURL: "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F", + expectPath: "/user/jill/order/T/cO4lW/t/Vp/", + expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + }, + { // do nothing, replace nothing + whenURL: "http://localhost:8080/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + expectPath: "/user/jill/order/T/cO4lW/t/Vp/", + expectRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + }, + { + whenURL: "http://localhost:8080/static", + expectPath: "/static/path", + expectRawPath: "", + expectQuery: "role=AUTHOR&limit=1000", + }, + { + whenURL: "/static", + expectPath: "/static/path", + expectRawPath: "", + expectQuery: "role=AUTHOR&limit=1000", + }, + } + + rules := map[*regexp.Regexp]string{ + regexp.MustCompile("^/old$"): "/new", + regexp.MustCompile("^/users/(.*?)/orders/(.*?)$"): "/user/$1/order/$2", + regexp.MustCompile("^/static$"): "/static/path?role=AUTHOR&limit=1000", + } + + for _, tc := range testCases { + t.Run(tc.whenURL, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + + err := rewriteURL(rules, req) + + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.expectPath, req.URL.Path) // Path field is stored in decoded form: /%47%6f%2f becomes /Go/. + assert.Equal(t, tc.expectRawPath, req.URL.RawPath) // RawPath, an optional field which only gets set if the default encoding is different from Path. + assert.Equal(t, tc.expectQuery, req.URL.RawQuery) + }) + } +} + +type testResponseWriterNoFlushHijack struct { +} + +func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) { +} +func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) { + return 0, nil +} +func (w *testResponseWriterNoFlushHijack) Header() http.Header { + return nil +} + +type testResponseWriterUnwrapper struct { + unwrapCalled int + rw http.ResponseWriter +} + +func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) { +} +func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) { + return 0, nil +} +func (w *testResponseWriterUnwrapper) Header() http.Header { + return nil +} +func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter { + w.unwrapCalled++ + return w.rw +} + +type testResponseWriterUnwrapperHijack struct { + testResponseWriterUnwrapper +} + +func (w *testResponseWriterUnwrapperHijack) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, errors.New("can hijack") +} diff --git a/middleware/proxy.go b/middleware/proxy.go index ef5602bd6..1996032f7 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -1,105 +1,157 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( + "context" + "crypto/tls" + "errors" "fmt" "io" "math/rand" "net" "net/http" + "net/http/httputil" "net/url" "regexp" "strings" "sync" - "sync/atomic" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // TODO: Handle TLS proxy -type ( - // ProxyConfig defines the config for Proxy middleware. - ProxyConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Balancer defines a load balancing technique. - // Required. - Balancer ProxyBalancer - - // Rewrite defines URL path rewrite rules. The values captured in asterisk can be - // retrieved by index e.g. $1, $2 and so on. - // Examples: - // "/old": "/new", - // "/api/*": "/$1", - // "/js/*": "/public/javascripts/$1", - // "/users/*/orders/*": "/user/$1/order/$2", - Rewrite map[string]string - - // Context key to store selected ProxyTarget into context. - // Optional. Default value "target". - ContextKey string - - // To customize the transport to remote. - // Examples: If custom TLS certificates are required. - Transport http.RoundTripper - - rewriteRegex map[*regexp.Regexp]string - } +// ProxyConfig defines the config for Proxy middleware. +type ProxyConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Balancer defines a load balancing technique. + // Required. + Balancer ProxyBalancer + + // RetryCount defines the number of times a failed proxied request should be retried + // using the next available ProxyTarget. Defaults to 0, meaning requests are never retried. + RetryCount int + + // RetryFilter defines a function used to determine if a failed request to a + // ProxyTarget should be retried. The RetryFilter will only be called when the number + // of previous retries is less than RetryCount. If the function returns true, the + // request will be retried. The provided error indicates the reason for the request + // failure. When the ProxyTarget is unavailable, the error will be an instance of + // echo.HTTPError with a code of http.StatusBadGateway. In all other cases, the error + // will indicate an internal error in the Proxy middleware. When a RetryFilter is not + // specified, all requests that fail with http.StatusBadGateway will be retried. A custom + // RetryFilter can be provided to only retry specific requests. Note that RetryFilter is + // only called when the request to the target fails, or an internal error in the Proxy + // middleware has occurred. Successful requests that return a non-200 response code cannot + // be retried. + RetryFilter func(c *echo.Context, e error) bool + + // ErrorHandler defines a function which can be used to return custom errors from + // the Proxy middleware. ErrorHandler is only invoked when there has been + // either an internal error in the Proxy middleware or the ProxyTarget is + // unavailable. Due to the way requests are proxied, ErrorHandler is not invoked + // when a ProxyTarget returns a non-200 response. In these cases, the response + // is already written so errors cannot be modified. ErrorHandler is only + // invoked after all retry attempts have been exhausted. + ErrorHandler func(c *echo.Context, err error) error + + // Rewrite defines URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Examples: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + Rewrite map[string]string + + // RegexRewrite defines rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRewrite map[*regexp.Regexp]string + + // Context key to store selected ProxyTarget into context. + // Optional. Default value "target". + ContextKey string + + // To customize the transport to remote. + // Examples: If custom TLS certificates are required. + Transport http.RoundTripper + + // ModifyResponse defines function to modify response from ProxyTarget. + ModifyResponse func(*http.Response) error +} - // ProxyTarget defines the upstream target. - ProxyTarget struct { - Name string - URL *url.URL - Meta echo.Map - } +// ProxyTarget defines the upstream target. +type ProxyTarget struct { + Name string + URL *url.URL + Meta map[string]any +} - // ProxyBalancer defines an interface to implement a load balancing technique. - ProxyBalancer interface { - AddTarget(*ProxyTarget) bool - RemoveTarget(string) bool - Next(echo.Context) *ProxyTarget - } +// ProxyBalancer defines an interface to implement a load balancing technique. +type ProxyBalancer interface { + AddTarget(target *ProxyTarget) bool + RemoveTarget(targetName string) bool + Next(c *echo.Context) (*ProxyTarget, error) +} - commonBalancer struct { - targets []*ProxyTarget - mutex sync.RWMutex - } +type commonBalancer struct { + targets []*ProxyTarget + mutex sync.Mutex +} - // RandomBalancer implements a random load balancing technique. - randomBalancer struct { - *commonBalancer - random *rand.Rand - } +// RandomBalancer implements a random load balancing technique. +type randomBalancer struct { + commonBalancer + random *rand.Rand +} - // RoundRobinBalancer implements a round-robin load balancing technique. - roundRobinBalancer struct { - *commonBalancer - i uint32 - } -) +// RoundRobinBalancer implements a round-robin load balancing technique. +type roundRobinBalancer struct { + commonBalancer + // tracking the index on `targets` slice for the next `*ProxyTarget` to be used + i int +} -var ( - // DefaultProxyConfig is the default Proxy middleware config. - DefaultProxyConfig = ProxyConfig{ - Skipper: DefaultSkipper, - ContextKey: "target", +// DefaultProxyConfig is the default Proxy middleware config. +var DefaultProxyConfig = ProxyConfig{ + Skipper: DefaultSkipper, + ContextKey: "target", +} + +func proxyRaw(c *echo.Context, t *ProxyTarget, config ProxyConfig) http.Handler { + var dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) + if transport, ok := config.Transport.(*http.Transport); ok { + if transport.TLSClientConfig != nil { + d := tls.Dialer{ + Config: transport.TLSClientConfig, + } + dialFunc = d.DialContext + } + } + if dialFunc == nil { + var d net.Dialer + dialFunc = d.DialContext } -) -func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - in, _, err := c.Response().Hijack() + in, _, err := http.NewResponseController(w).Hijack() if err != nil { - c.Set("_error", fmt.Sprintf("proxy raw, hijack error=%v, url=%s", t.URL, err)) + c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL)) return } defer in.Close() - out, err := net.Dial("tcp", t.URL.Host) + out, err := dialFunc(c.Request().Context(), "tcp", t.URL.Host) if err != nil { - c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err))) + c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL))) return } defer out.Close() @@ -107,53 +159,66 @@ func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { // Write header err = r.Write(out) if err != nil { - c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err))) + c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", err, t.URL))) return } errCh := make(chan error, 2) cp := func(dst io.Writer, src io.Reader) { - _, err = io.Copy(dst, src) - errCh <- err + _, copyErr := io.Copy(dst, src) + errCh <- copyErr } go cp(out, in) go cp(in, out) - err = <-errCh - if err != nil && err != io.EOF { - c.Set("_error", fmt.Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err)) + + // Wait for BOTH goroutines to complete + err1 := <-errCh + err2 := <-errCh + + if err1 != nil && err1 != io.EOF { + c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err1, t.URL)) + } else if err2 != nil && err2 != io.EOF { + c.Set("_error", fmt.Errorf("proxy raw, copy body error=%w, url=%s", err2, t.URL)) } }) } // NewRandomBalancer returns a random proxy balancer. func NewRandomBalancer(targets []*ProxyTarget) ProxyBalancer { - b := &randomBalancer{commonBalancer: new(commonBalancer)} + b := randomBalancer{} b.targets = targets - return b + // G404 (CWE-338): Use of weak random number generator (math/rand or math/rand/v2 instead of crypto/rand) + // this random is used to select next target. I can not think of reason this must be cryptographically safe. If you can - please open PR. + b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) // #nosec G404 + return &b } // NewRoundRobinBalancer returns a round-robin proxy balancer. func NewRoundRobinBalancer(targets []*ProxyTarget) ProxyBalancer { - b := &roundRobinBalancer{commonBalancer: new(commonBalancer)} + b := roundRobinBalancer{} b.targets = targets - return b + return &b } -// AddTarget adds an upstream target to the list. +// AddTarget adds an upstream target to the list and returns `true`. +// +// However, if a target with the same name already exists then the operation is aborted returning `false`. func (b *commonBalancer) AddTarget(target *ProxyTarget) bool { + b.mutex.Lock() + defer b.mutex.Unlock() for _, t := range b.targets { if t.Name == target.Name { return false } } - b.mutex.Lock() - defer b.mutex.Unlock() b.targets = append(b.targets, target) return true } -// RemoveTarget removes an upstream target from the list. +// RemoveTarget removes an upstream target from the list by name. +// +// Returns `true` on success, `false` if no target with the name is found. func (b *commonBalancer) RemoveTarget(name string) bool { b.mutex.Lock() defer b.mutex.Unlock() @@ -167,21 +232,57 @@ func (b *commonBalancer) RemoveTarget(name string) bool { } // Next randomly returns an upstream target. -func (b *randomBalancer) Next(c echo.Context) *ProxyTarget { - if b.random == nil { - b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) +// +// Note: `nil` is returned in case upstream target list is empty. +func (b *randomBalancer) Next(c *echo.Context) (*ProxyTarget, error) { + b.mutex.Lock() + defer b.mutex.Unlock() + if len(b.targets) == 0 { + return nil, nil + } else if len(b.targets) == 1 { + return b.targets[0], nil } - b.mutex.RLock() - defer b.mutex.RUnlock() - return b.targets[b.random.Intn(len(b.targets))] + return b.targets[b.random.Intn(len(b.targets))], nil } -// Next returns an upstream target using round-robin technique. -func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget { - b.i = b.i % uint32(len(b.targets)) - t := b.targets[b.i] - atomic.AddUint32(&b.i, 1) - return t +// Next returns an upstream target using round-robin technique. In the case +// where a previously failed request is being retried, the round-robin +// balancer will attempt to use the next target relative to the original +// request. If the list of targets held by the balancer is modified while a +// failed request is being retried, it is possible that the balancer will +// return the original failed target. +// +// Note: `nil` is returned in case upstream target list is empty. +func (b *roundRobinBalancer) Next(c *echo.Context) (*ProxyTarget, error) { + b.mutex.Lock() + defer b.mutex.Unlock() + if len(b.targets) == 0 { + return nil, nil + } else if len(b.targets) == 1 { + return b.targets[0], nil + } + + var i int + const lastIdxKey = "_round_robin_last_index" + // This request is a retry, start from the index of the previous + // target to ensure we don't attempt to retry the request with + // the same failed target + if c.Get(lastIdxKey) != nil { + i = c.Get(lastIdxKey).(int) + i++ + if i >= len(b.targets) { + i = 0 + } + } else { + // This is a first time request, use the global index + if b.i >= len(b.targets) { + b.i = 0 + } + i = b.i + b.i++ + } + c.Set(lastIdxKey, i) + return b.targets[i], nil } // Proxy returns a Proxy middleware. @@ -193,45 +294,63 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc { return ProxyWithConfig(c) } -// ProxyWithConfig returns a Proxy middleware with config. -// See: `Proxy()` +// ProxyWithConfig returns a Proxy middleware or panics if configuration is invalid. +// +// Proxy middleware forwards the request to upstream server using a configured load balancing technique. func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration +func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultProxyConfig.Skipper } + if config.ContextKey == "" { + config.ContextKey = DefaultProxyConfig.ContextKey + } if config.Balancer == nil { - panic("echo: proxy middleware requires balancer") + return nil, errors.New("echo proxy middleware requires balancer") + } + if config.RetryFilter == nil { + config.RetryFilter = func(c *echo.Context, e error) bool { + if httpErr, ok := e.(*echo.HTTPError); ok { + return httpErr.Code == http.StatusBadGateway + } + return false + } + } + if config.ErrorHandler == nil { + config.ErrorHandler = func(c *echo.Context, err error) error { + return err + } } - config.rewriteRegex = map[*regexp.Regexp]string{} - // Initialize - for k, v := range config.Rewrite { - k = strings.Replace(k, "*", "(\\S*)", -1) - config.rewriteRegex[regexp.MustCompile(k)] = v + if config.Rewrite != nil { + if config.RegexRewrite == nil { + config.RegexRewrite = make(map[*regexp.Regexp]string) + } + for k, v := range rewriteRulesRegex(config.Rewrite) { + config.RegexRewrite[k] = v + } } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c *echo.Context) (err error) { if config.Skipper(c) { return next(c) } req := c.Request() res := c.Response() - tgt := config.Balancer.Next(c) - c.Set(config.ContextKey, tgt) - - // Rewrite - for k, v := range config.rewriteRegex { - replacer := captureTokens(k, req.URL.Path) - if replacer != nil { - req.URL.Path = replacer.Replace(v) - } + if err := rewriteURL(config.RegexRewrite, req); err != nil { + return config.ErrorHandler(c, err) } // Fix header - if req.Header.Get(echo.HeaderXRealIP) == "" { + // Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream. + // However, for backward compatibility, legacy behavior is preserved unless you configure Echo#IPExtractor. + if req.Header.Get(echo.HeaderXRealIP) == "" || c.Echo().IPExtractor != nil { req.Header.Set(echo.HeaderXRealIP, c.RealIP()) } if req.Header.Get(echo.HeaderXForwardedProto) == "" { @@ -241,19 +360,82 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { req.Header.Set(echo.HeaderXForwardedFor, c.RealIP()) } - // Proxy - switch { - case c.IsWebSocket(): - proxyRaw(tgt, c).ServeHTTP(res, req) - case req.Header.Get(echo.HeaderAccept) == "text/event-stream": - default: - proxyHTTP(tgt, c, config).ServeHTTP(res, req) - } - if e, ok := c.Get("_error").(error); ok { - err = e + retries := config.RetryCount + for { + tgt, err := config.Balancer.Next(c) + if err != nil { + return config.ErrorHandler(c, err) + } + + c.Set(config.ContextKey, tgt) + + //If retrying a failed request, clear any previous errors from + //context here so that balancers have the option to check for + //errors that occurred using previous target + if retries < config.RetryCount { + c.Set("_error", nil) + } + + // This is needed for ProxyConfig.ModifyResponse and/or ProxyConfig.Transport to be able to process the Request + // that Balancer may have replaced with c.SetRequest. + req = c.Request() + + // Proxy + switch { + case c.IsWebSocket(): + proxyRaw(c, tgt, config).ServeHTTP(res, req) + default: // even SSE requests + proxyHTTP(c, tgt, config).ServeHTTP(res, req) + } + + err, hasError := c.Get("_error").(error) + if !hasError { + return nil + } + + retry := retries > 0 && config.RetryFilter(c, err) + if !retry { + return config.ErrorHandler(c, err) + } + + retries-- } + } + }, nil +} - return +// StatusCodeContextCanceled is a custom HTTP status code for situations +// where a client unexpectedly closed the connection to the server. +// As there is no standard error code for "client closed connection", but +// various well-known HTTP clients and server implement this HTTP code we use +// 499 too instead of the more problematic 5xx, which does not allow to detect this situation +const StatusCodeContextCanceled = 499 + +func proxyHTTP(c *echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler { + proxy := httputil.NewSingleHostReverseProxy(tgt.URL) + proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { + desc := tgt.URL.String() + if tgt.Name != "" { + desc = fmt.Sprintf("%s(%s)", tgt.Name, tgt.URL.String()) + } + // If the client canceled the request (usually by closing the connection), we can report a + // client error (4xx) instead of a server error (5xx) to correctly identify the situation. + // The Go standard library (at of late 2020) wraps the exported, standard + // context. Canceled error with unexported garbage value requiring a substring check, see + // https://github.com/golang/go/blob/6965b01ea248cabb70c3749fd218b36089a21efb/src/net/net.go#L416-L430 + // From Caddy https://github.com/caddyserver/caddy/blob/afa778ae05503f563af0d1015cdf7e5e78b1eeec/modules/caddyhttp/reverseproxy/reverseproxy.go#L1352 + if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "operation was canceled") { + httpError := echo.NewHTTPError(StatusCodeContextCanceled, "client closed connection").Wrap(err) + c.Set("_error", httpError) + } else { + httpError := echo.NewHTTPError( + http.StatusBadGateway, + "remote server unreachable, could not proxy request", + ).Wrap(fmt.Errorf("server: %s, err: %w", desc, err)) + c.Set("_error", httpError) } } + proxy.Transport = config.Transport + proxy.ModifyResponse = config.ModifyResponse + return proxy } diff --git a/middleware/proxy_1_11.go b/middleware/proxy_1_11.go deleted file mode 100644 index 12b7568bf..000000000 --- a/middleware/proxy_1_11.go +++ /dev/null @@ -1,24 +0,0 @@ -// +build go1.11 - -package middleware - -import ( - "fmt" - "net/http" - "net/http/httputil" - - "github.com/labstack/echo/v4" -) - -func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { - proxy := httputil.NewSingleHostReverseProxy(tgt.URL) - proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { - desc := tgt.URL.String() - if tgt.Name != "" { - desc = fmt.Sprintf("%s(%s)", tgt.Name, tgt.URL.String()) - } - c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err))) - } - proxy.Transport = config.Transport - return proxy -} diff --git a/middleware/proxy_1_11_n.go b/middleware/proxy_1_11_n.go deleted file mode 100644 index 9a78929fe..000000000 --- a/middleware/proxy_1_11_n.go +++ /dev/null @@ -1,14 +0,0 @@ -// +build !go1.11 - -package middleware - -import ( - "net/http" - "net/http/httputil" - - "github.com/labstack/echo/v4" -) - -func proxyHTTP(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { - return httputil.NewSingleHostReverseProxy(t.URL) -} diff --git a/middleware/proxy_1_11_test.go b/middleware/proxy_1_11_test.go deleted file mode 100644 index 26feaabaa..000000000 --- a/middleware/proxy_1_11_test.go +++ /dev/null @@ -1,53 +0,0 @@ -// +build go1.11 - -package middleware - -import ( - "net/http" - "net/http/httptest" - "net/url" - "testing" - - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" -) - -func TestProxy_1_11(t *testing.T) { - // Setup - url1, _ := url.Parse("http://127.0.0.1:27121") - url2, _ := url.Parse("http://127.0.0.1:27122") - - targets := []*ProxyTarget{ - { - Name: "target 1", - URL: url1, - }, - { - Name: "target 2", - URL: url2, - }, - } - rb := NewRandomBalancer(nil) - // must add targets: - for _, target := range targets { - assert.True(t, rb.AddTarget(target)) - } - - // must ignore duplicates: - for _, target := range targets { - assert.False(t, rb.AddTarget(target)) - } - - // Random - e := echo.New() - e.Use(Proxy(rb)) - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - - // Remote unreachable - rec = httptest.NewRecorder() - req.URL.Path = "/api/users" - e.ServeHTTP(rec, req) - assert.Equal(t, "/api/users", req.URL.Path) - assert.Equal(t, http.StatusBadGateway, rec.Code) -} diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index 1a375db86..420be3240 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -1,16 +1,30 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( + "bytes" + "context" + "crypto/tls" + "errors" "fmt" + "io" + "net" "net/http" "net/http/httptest" "net/url" + "regexp" + "sync" "testing" + "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" + "golang.org/x/net/websocket" ) +// Assert expected with url.EscapedPath method to obtain the path. func TestProxy(t *testing.T) { // Setup t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -47,7 +61,7 @@ func TestProxy(t *testing.T) { // Random e := echo.New() - e.Use(Proxy(rb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -69,53 +83,973 @@ func TestProxy(t *testing.T) { // Round-robin rrb := NewRoundRobinBalancer(targets) e = echo.New() - e.Use(Proxy(rrb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb})) + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) body = rec.Body.String() assert.Equal(t, "target 1", body) + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) body = rec.Body.String() assert.Equal(t, "target 2", body) - // Rewrite + // ModifyResponse e = echo.New() e.Use(ProxyWithConfig(ProxyConfig{ Balancer: rrb, - Rewrite: map[string]string{ - "/old": "/new", - "/api/*": "/$1", - "/js/*": "/public/javascripts/$1", - "/users/*/orders/*": "/user/$1/order/$2", + ModifyResponse: func(res *http.Response) error { + res.Body = io.NopCloser(bytes.NewBuffer([]byte("modified"))) + res.Header.Set("X-Modified", "1") + return nil }, })) - req.URL.Path = "/api/users" - e.ServeHTTP(rec, req) - assert.Equal(t, "/users", req.URL.Path) - req.URL.Path = "/js/main.js" - e.ServeHTTP(rec, req) - assert.Equal(t, "/public/javascripts/main.js", req.URL.Path) - req.URL.Path = "/old" - e.ServeHTTP(rec, req) - assert.Equal(t, "/new", req.URL.Path) - req.URL.Path = "/users/jack/orders/1" + + rec = httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/user/jack/order/1", req.URL.Path) - assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "modified", rec.Body.String()) + assert.Equal(t, "1", rec.Header().Get("X-Modified")) // ProxyTarget is set in context contextObserver := func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c *echo.Context) (err error) { next(c) assert.Contains(t, targets, c.Get("target"), "target is not set in context") return nil } } - rrb1 := NewRoundRobinBalancer(targets) + e = echo.New() e.Use(contextObserver) - e.Use(Proxy(rrb1)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)})) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) } + +func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) { + assert.Panics(t, func() { + ProxyWithConfig(ProxyConfig{Balancer: nil}) + }) +} + +func TestProxyRealIPHeader(t *testing.T) { + // Setup + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + defer upstream.Close() + url, _ := url.Parse(upstream.URL) + rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) + e := echo.New() + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb})) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + remoteAddrIP, _, _ := net.SplitHostPort(req.RemoteAddr) + realIPHeaderIP := "203.0.113.1" + extractedRealIP := "203.0.113.10" + tests := []*struct { + hasRealIPheader bool + hasIPExtractor bool + expectedXRealIP string + }{ + {false, false, remoteAddrIP}, + {false, true, extractedRealIP}, + {true, false, realIPHeaderIP}, + {true, true, extractedRealIP}, + } + + for _, tt := range tests { + if tt.hasRealIPheader { + req.Header.Set(echo.HeaderXRealIP, realIPHeaderIP) + } else { + req.Header.Del(echo.HeaderXRealIP) + } + if tt.hasIPExtractor { + e.IPExtractor = func(*http.Request) string { + return extractedRealIP + } + } else { + e.IPExtractor = nil + } + e.ServeHTTP(rec, req) + assert.Equal(t, tt.expectedXRealIP, req.Header.Get(echo.HeaderXRealIP), "hasRealIPheader: %t / hasIPExtractor: %t", tt.hasRealIPheader, tt.hasIPExtractor) + } +} + +func TestProxyRewrite(t *testing.T) { + var testCases = []struct { + whenPath string + expectProxiedURI string + expectStatus int + }{ + { + whenPath: "/api/users", + expectProxiedURI: "/users", + expectStatus: http.StatusOK, + }, + { + whenPath: "/js/main.js", + expectProxiedURI: "/public/javascripts/main.js", + expectStatus: http.StatusOK, + }, + { + whenPath: "/old", + expectProxiedURI: "/new", + expectStatus: http.StatusOK, + }, + { + whenPath: "/users/jack/orders/1", + expectProxiedURI: "/user/jack/order/1", + expectStatus: http.StatusOK, + }, + { + whenPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + expectProxiedURI: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + expectStatus: http.StatusOK, + }, + { // ` ` (space) is encoded by httpClient to `%20` when doing request to Echo. `%20` should not be double escaped when proxying request + whenPath: "/api/new users", + expectProxiedURI: "/new%20users", + expectStatus: http.StatusOK, + }, + { // query params should be proxied and not be modified + whenPath: "/api/users?limit=10", + expectProxiedURI: "/users?limit=10", + expectStatus: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.whenPath, func(t *testing.T) { + receivedRequestURI := make(chan string, 1) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server + // we need unmodified target to see if we are encoding/decoding the url in addition to rewrite/replace logic + // if original request had `%2F` we should not magically decode it to `/` as it would change what was requested + receivedRequestURI <- r.RequestURI + })) + defer upstream.Close() + serverURL, _ := url.Parse(upstream.URL) + rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: serverURL}}) + + // Rewrite + e := echo.New() + e.Use(ProxyWithConfig(ProxyConfig{ + Balancer: rrb, + Rewrite: map[string]string{ + "/old": "/new", + "/api/*": "/$1", + "/js/*": "/public/javascripts/$1", + "/users/*/orders/*": "/user/$1/order/$2", + }, + })) + + targetURL, _ := serverURL.Parse(tc.whenPath) + req := httptest.NewRequest(http.MethodGet, targetURL.String(), nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectStatus, rec.Code) + actualRequestURI := <-receivedRequestURI + assert.Equal(t, tc.expectProxiedURI, actualRequestURI) + }) + } +} + +func TestProxyRewriteRegex(t *testing.T) { + // Setup + receivedRequestURI := make(chan string, 1) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server + // we need unmodified target to see if we are encoding/decoding the url in addition to rewrite/replace logic + // if original request had `%2F` we should not magically decode it to `/` as it would change what was requested + receivedRequestURI <- r.RequestURI + })) + defer upstream.Close() + tmpUrL, _ := url.Parse(upstream.URL) + rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: tmpUrL}}) + + // Rewrite + e := echo.New() + e.Use(ProxyWithConfig(ProxyConfig{ + Balancer: rrb, + Rewrite: map[string]string{ + "^/a/*": "/v1/$1", + "^/b/*/c/*": "/v2/$2/$1", + "^/c/*/*": "/v3/$2", + }, + RegexRewrite: map[*regexp.Regexp]string{ + regexp.MustCompile("^/x/.+?/(.*)"): "/v4/$1", + regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1", + }, + })) + + testCases := []struct { + requestPath string + statusCode int + expectPath string + }{ + {"/unmatched", http.StatusOK, "/unmatched"}, + {"/a/test", http.StatusOK, "/v1/test"}, + {"/b/foo/c/bar/baz", http.StatusOK, "/v2/bar/baz/foo"}, + {"/c/ignore/test", http.StatusOK, "/v3/test"}, + {"/c/ignore1/test/this", http.StatusOK, "/v3/test/this"}, + {"/x/ignore/test", http.StatusOK, "/v4/test"}, + {"/y/foo/bar", http.StatusOK, "/v5/bar/foo"}, + // NB: fragment is not added by golang httputil.NewSingleHostReverseProxy implementation + // $2 = `bar?q=1#frag`, $1 = `foo`. replaced uri = `/v5/bar?q=1#frag/foo` but httputil.NewSingleHostReverseProxy does not send `#frag/foo` (currently) + {"/y/foo/bar?q=1#frag", http.StatusOK, "/v5/bar?q=1"}, + } + + for _, tc := range testCases { + t.Run(tc.requestPath, func(t *testing.T) { + targetURL, _ := url.Parse(tc.requestPath) + req := httptest.NewRequest(http.MethodGet, targetURL.String(), nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + actualRequestURI := <-receivedRequestURI + assert.Equal(t, tc.expectPath, actualRequestURI) + assert.Equal(t, tc.statusCode, rec.Code) + }) + } +} + +func TestProxyError(t *testing.T) { + // Setup + url1, _ := url.Parse("http://127.0.0.1:27121") + url2, _ := url.Parse("http://127.0.0.1:27122") + + targets := []*ProxyTarget{ + { + Name: "target 1", + URL: url1, + }, + { + Name: "target 2", + URL: url2, + }, + } + rb := NewRandomBalancer(nil) + // must add targets: + for _, target := range targets { + assert.True(t, rb.AddTarget(target)) + } + + // must ignore duplicates: + for _, target := range targets { + assert.False(t, rb.AddTarget(target)) + } + + // Random + e := echo.New() + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) + req := httptest.NewRequest(http.MethodGet, "/", nil) + + // Remote unreachable + rec := httptest.NewRecorder() + req.URL.Path = "/api/users" + e.ServeHTTP(rec, req) + assert.Equal(t, "/api/users", req.URL.Path) + assert.Equal(t, http.StatusBadGateway, rec.Code) +} + +func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { + var timeoutStop sync.WaitGroup + timeoutStop.Add(1) + HTTPTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + timeoutStop.Wait() // wait until we have canceled the request + w.WriteHeader(http.StatusOK) + })) + defer HTTPTarget.Close() + targetURL, _ := url.Parse(HTTPTarget.URL) + target := &ProxyTarget{ + Name: "target", + URL: targetURL, + } + rb := NewRandomBalancer(nil) + assert.True(t, rb.AddTarget(target)) + e := echo.New() + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + ctx, cancel := context.WithCancel(req.Context()) + req = req.WithContext(ctx) + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + e.ServeHTTP(rec, req) + timeoutStop.Done() + assert.Equal(t, 499, rec.Code) +} + +type testProvider struct { + commonBalancer + target *ProxyTarget + err error +} + +func (p *testProvider) Next(c *echo.Context) (*ProxyTarget, error) { + return p.target, p.err +} + +func TestTargetProvider(t *testing.T) { + t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "target 1") + })) + defer t1.Close() + url1, _ := url.Parse(t1.URL) + + e := echo.New() + tp := &testProvider{} + tp.target = &ProxyTarget{Name: "target 1", URL: url1} + e.Use(Proxy(tp)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + e.ServeHTTP(rec, req) + body := rec.Body.String() + assert.Equal(t, "target 1", body) +} + +func TestFailNextTarget(t *testing.T) { + url1, err := url.Parse("http://dummy:8080") + assert.Nil(t, err) + + e := echo.New() + tp := &testProvider{} + tp.target = &ProxyTarget{Name: "target 1", URL: url1} + tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target") + + e.Use(Proxy(tp)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + e.ServeHTTP(rec, req) + body := rec.Body.String() + assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body) +} + +func TestRandomBalancerWithNoTargets(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Assert balancer with empty targets does return `nil` on `Next()` + rb := NewRandomBalancer(nil) + target, err := rb.Next(c) + assert.Nil(t, target) + assert.NoError(t, err) +} + +func TestRoundRobinBalancerWithNoTargets(t *testing.T) { + // Assert balancer with empty targets does return `nil` on `Next()` + rrb := NewRoundRobinBalancer([]*ProxyTarget{}) + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + target, err := rrb.Next(c) + assert.Nil(t, target) + assert.NoError(t, err) +} + +func TestProxyRetries(t *testing.T) { + newServer := func(res int) (*url.URL, *httptest.Server) { + server := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(res) + }), + ) + targetURL, _ := url.Parse(server.URL) + return targetURL, server + } + + targetURL, server := newServer(http.StatusOK) + defer server.Close() + goodTarget := &ProxyTarget{ + Name: "Good", + URL: targetURL, + } + + targetURL, server = newServer(http.StatusBadRequest) + defer server.Close() + goodTargetWith40X := &ProxyTarget{ + Name: "Good with 40X", + URL: targetURL, + } + + targetURL, _ = url.Parse("http://127.0.0.1:27121") + badTarget := &ProxyTarget{ + Name: "Bad", + URL: targetURL, + } + + alwaysRetryFilter := func(c *echo.Context, e error) bool { return true } + neverRetryFilter := func(c *echo.Context, e error) bool { return false } + + testCases := []struct { + name string + retryCount int + retryFilters []func(c *echo.Context, e error) bool + targets []*ProxyTarget + expectedResponse int + }{ + { + name: "retry count 0 does not attempt retry on fail", + targets: []*ProxyTarget{ + badTarget, + goodTarget, + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 1 does not attempt retry on success", + retryCount: 1, + targets: []*ProxyTarget{ + goodTarget, + }, + expectedResponse: http.StatusOK, + }, + { + name: "retry count 1 does retry on handler return true", + retryCount: 1, + retryFilters: []func(c *echo.Context, e error) bool{ + alwaysRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + goodTarget, + }, + expectedResponse: http.StatusOK, + }, + { + name: "retry count 1 does not retry on handler return false", + retryCount: 1, + retryFilters: []func(c *echo.Context, e error) bool{ + neverRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + goodTarget, + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 2 returns error when no more retries left", + retryCount: 2, + retryFilters: []func(c *echo.Context, e error) bool{ + alwaysRetryFilter, + alwaysRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + badTarget, + badTarget, + goodTarget, //Should never be reached as only 2 retries + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 2 returns error when retries left but handler returns false", + retryCount: 3, + retryFilters: []func(c *echo.Context, e error) bool{ + alwaysRetryFilter, + alwaysRetryFilter, + neverRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + badTarget, + badTarget, + goodTarget, //Should never be reached as retry handler returns false on 2nd check + }, + expectedResponse: http.StatusBadGateway, + }, + { + name: "retry count 3 succeeds", + retryCount: 3, + retryFilters: []func(c *echo.Context, e error) bool{ + alwaysRetryFilter, + alwaysRetryFilter, + alwaysRetryFilter, + }, + targets: []*ProxyTarget{ + badTarget, + badTarget, + badTarget, + goodTarget, + }, + expectedResponse: http.StatusOK, + }, + { + name: "40x responses are not retried", + retryCount: 1, + targets: []*ProxyTarget{ + goodTargetWith40X, + goodTarget, + }, + expectedResponse: http.StatusBadRequest, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + retryFilterCall := 0 + retryFilter := func(c *echo.Context, e error) bool { + if len(tc.retryFilters) == 0 { + assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall)) + } + + retryFilterCall++ + + nextRetryFilter := tc.retryFilters[0] + tc.retryFilters = tc.retryFilters[1:] + + return nextRetryFilter(c, e) + } + + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Balancer: NewRoundRobinBalancer(tc.targets), + RetryCount: tc.retryCount, + RetryFilter: retryFilter, + }, + )) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedResponse, rec.Code) + if len(tc.retryFilters) > 0 { + assert.FailNow(t, fmt.Sprintf("expected %d more retry handler calls", len(tc.retryFilters))) + } + }) + } +} + +func TestProxyRetryWithBackendTimeout(t *testing.T) { + + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.ResponseHeaderTimeout = time.Millisecond * 500 + + timeoutBackend := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(1 * time.Second) + w.WriteHeader(404) + }), + ) + defer timeoutBackend.Close() + + timeoutTargetURL, _ := url.Parse(timeoutBackend.URL) + goodBackend := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + }), + ) + defer goodBackend.Close() + + goodTargetURL, _ := url.Parse(goodBackend.URL) + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Transport: transport, + Balancer: NewRoundRobinBalancer([]*ProxyTarget{ + { + Name: "Timeout", + URL: timeoutTargetURL, + }, + { + Name: "Good", + URL: goodTargetURL, + }, + }), + RetryCount: 1, + }, + )) + + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, 200, rec.Code) + }() + } + + wg.Wait() + +} + +func TestProxyErrorHandler(t *testing.T) { + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + goodURL, _ := url.Parse(server.URL) + defer server.Close() + goodTarget := &ProxyTarget{ + Name: "Good", + URL: goodURL, + } + + badURL, _ := url.Parse("http://127.0.0.1:27121") + badTarget := &ProxyTarget{ + Name: "Bad", + URL: badURL, + } + + transformedError := errors.New("a new error") + + testCases := []struct { + name string + target *ProxyTarget + errorHandler func(c *echo.Context, e error) error + expectFinalError func(t *testing.T, err error) + }{ + { + name: "Error handler not invoked when request success", + target: goodTarget, + errorHandler: func(c *echo.Context, e error) error { + assert.FailNow(t, "error handler should not be invoked") + return e + }, + }, + { + name: "Error handler invoked when request fails", + target: badTarget, + errorHandler: func(c *echo.Context, e error) error { + httpErr, ok := e.(*echo.HTTPError) + assert.True(t, ok, "expected http error to be passed to handler") + assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler") + return transformedError + }, + expectFinalError: func(t *testing.T, err error) { + assert.Equal(t, transformedError, err, "transformed error not returned from proxy") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Balancer: NewRoundRobinBalancer([]*ProxyTarget{tc.target}), + ErrorHandler: tc.errorHandler, + }, + )) + + errorHandlerCalled := false + dheh := echo.DefaultHTTPErrorHandler(false) + e.HTTPErrorHandler = func(c *echo.Context, err error) { + errorHandlerCalled = true + tc.expectFinalError(t, err) + dheh(c, err) + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + if !errorHandlerCalled && tc.expectFinalError != nil { + t.Fatalf("error handler was not called") + } + + }) + } +} + +type testContextKey string +type customBalancer struct { + target *ProxyTarget +} + +func (b *customBalancer) AddTarget(target *ProxyTarget) bool { + return false +} +func (b *customBalancer) RemoveTarget(name string) bool { + return false +} + +func (b *customBalancer) Next(c *echo.Context) (*ProxyTarget, error) { + ctx := context.WithValue(c.Request().Context(), testContextKey("FROM_BALANCER"), "CUSTOM_BALANCER") + c.SetRequest(c.Request().WithContext(ctx)) + return b.target, nil +} + +func TestModifyResponseUseContext(t *testing.T) { + server := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }), + ) + defer server.Close() + targetURL, _ := url.Parse(server.URL) + e := echo.New() + e.Use(ProxyWithConfig( + ProxyConfig{ + Balancer: &customBalancer{ + target: &ProxyTarget{ + Name: "tst", + URL: targetURL, + }, + }, + RetryCount: 1, + ModifyResponse: func(res *http.Response) error { + val := res.Request.Context().Value(testContextKey("FROM_BALANCER")) + if valStr, ok := val.(string); ok { + res.Header.Set("FROM_BALANCER", valStr) + } + return nil + }, + }, + )) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "OK", rec.Body.String()) + assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER")) +} + +func createSimpleWebSocketServer(serveTLS bool) *httptest.Server { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsHandler := func(conn *websocket.Conn) { + defer conn.Close() + for { + var msg string + err := websocket.Message.Receive(conn, &msg) + if err != nil { + return + } + // message back to the client + websocket.Message.Send(conn, msg) + } + } + websocket.Server{Handler: wsHandler}.ServeHTTP(w, r) + }) + if serveTLS { + return httptest.NewTLSServer(handler) + } + return httptest.NewServer(handler) +} + +func createSimpleProxyServer(t *testing.T, srv *httptest.Server, serveTLS bool, toTLS bool) *httptest.Server { + e := echo.New() + + if toTLS { + // proxy to tls target + tgtURL, _ := url.Parse(srv.URL) + tgtURL.Scheme = "wss" + balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}}) + + defaultTransport, ok := http.DefaultTransport.(*http.Transport) + if !ok { + t.Fatal("Default transport is not of type *http.Transport") + } + transport := defaultTransport.Clone() + transport.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer, Transport: transport})) + } else { + // proxy to non-TLS target + tgtURL, _ := url.Parse(srv.URL) + balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}}) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer})) + } + + if serveTLS { + // serve proxy server with TLS + ts := httptest.NewTLSServer(e) + return ts + } + // serve proxy server without TLS + ts := httptest.NewServer(e) + return ts +} + +// TestProxyWithConfigWebSocketNonTLS2NonTLS tests the proxy with non-TLS to non-TLS WebSocket connection. +func TestProxyWithConfigWebSocketNonTLS2NonTLS(t *testing.T) { + /* + Arrange + */ + // Create a WebSocket test server (non-TLS) + srv := createSimpleWebSocketServer(false) + defer srv.Close() + + // create proxy server (non-TLS to non-TLS) + ts := createSimpleProxyServer(t, srv, false, false) + defer ts.Close() + + tsURL, _ := url.Parse(ts.URL) + tsURL.Scheme = "ws" + tsURL.Path = "/" + + /* + Act + */ + + // Connect to the proxy WebSocket + wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/") + assert.NoError(t, err) + defer wsConn.Close() + + // Send message + sendMsg := "Hello, Non TLS WebSocket!" + err = websocket.Message.Send(wsConn, sendMsg) + assert.NoError(t, err) + + /* + Assert + */ + // Read response + var recvMsg string + err = websocket.Message.Receive(wsConn, &recvMsg) + assert.NoError(t, err) + assert.Equal(t, sendMsg, recvMsg) +} + +// TestProxyWithConfigWebSocketTLS2TLS tests the proxy with TLS to TLS WebSocket connection. +func TestProxyWithConfigWebSocketTLS2TLS(t *testing.T) { + /* + Arrange + */ + // Create a WebSocket test server (TLS) + srv := createSimpleWebSocketServer(true) + defer srv.Close() + + // create proxy server (TLS to TLS) + ts := createSimpleProxyServer(t, srv, true, true) + defer ts.Close() + + tsURL, _ := url.Parse(ts.URL) + tsURL.Scheme = "wss" + tsURL.Path = "/" + + /* + Act + */ + origin, err := url.Parse(ts.URL) + assert.NoError(t, err) + config := &websocket.Config{ + Location: tsURL, + Origin: origin, + TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing + Version: websocket.ProtocolVersionHybi13, + } + wsConn, err := websocket.DialConfig(config) + assert.NoError(t, err) + defer wsConn.Close() + + // Send message + sendMsg := "Hello, TLS to TLS WebSocket!" + err = websocket.Message.Send(wsConn, sendMsg) + assert.NoError(t, err) + + // Read response + var recvMsg string + err = websocket.Message.Receive(wsConn, &recvMsg) + assert.NoError(t, err) + assert.Equal(t, sendMsg, recvMsg) +} + +// TestProxyWithConfigWebSocketNonTLS2TLS tests the proxy with non-TLS to TLS WebSocket connection. +func TestProxyWithConfigWebSocketNonTLS2TLS(t *testing.T) { + /* + Arrange + */ + + // Create a WebSocket test server (TLS) + srv := createSimpleWebSocketServer(true) + defer srv.Close() + + // create proxy server (Non-TLS to TLS) + ts := createSimpleProxyServer(t, srv, false, true) + defer ts.Close() + + tsURL, _ := url.Parse(ts.URL) + tsURL.Scheme = "ws" + tsURL.Path = "/" + + /* + Act + */ + // Connect to the proxy WebSocket + wsConn, err := websocket.Dial(tsURL.String(), "", "http://localhost/") + assert.NoError(t, err) + defer wsConn.Close() + + // Send message + sendMsg := "Hello, Non TLS to TLS WebSocket!" + err = websocket.Message.Send(wsConn, sendMsg) + assert.NoError(t, err) + + /* + Assert + */ + // Read response + var recvMsg string + err = websocket.Message.Receive(wsConn, &recvMsg) + assert.NoError(t, err) + assert.Equal(t, sendMsg, recvMsg) +} + +// TestProxyWithConfigWebSocketTLSToNoneTLS tests the proxy with TLS to non-TLS WebSocket connection. (TLS termination) +func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) { + /* + Arrange + */ + + // Create a WebSocket test server (non-TLS) + srv := createSimpleWebSocketServer(false) + defer srv.Close() + + // create proxy server (TLS to non-TLS) + ts := createSimpleProxyServer(t, srv, true, false) + defer ts.Close() + + tsURL, _ := url.Parse(ts.URL) + tsURL.Scheme = "wss" + tsURL.Path = "/" + + /* + Act + */ + origin, err := url.Parse(ts.URL) + assert.NoError(t, err) + config := &websocket.Config{ + Location: tsURL, + Origin: origin, + TlsConfig: &tls.Config{InsecureSkipVerify: true}, // skip verify for testing + Version: websocket.ProtocolVersionHybi13, + } + wsConn, err := websocket.DialConfig(config) + assert.NoError(t, err) + defer wsConn.Close() + + // Send message + sendMsg := "Hello, TLS to NoneTLS WebSocket!" + err = websocket.Message.Send(wsConn, sendMsg) + assert.NoError(t, err) + + // Read response + var recvMsg string + err = websocket.Message.Receive(wsConn, &recvMsg) + assert.NoError(t, err) + assert.Equal(t, sendMsg, recvMsg) +} diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go new file mode 100644 index 000000000..c04ae157d --- /dev/null +++ b/middleware/rate_limiter.go @@ -0,0 +1,263 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package middleware + +import ( + "errors" + "math" + "net/http" + "sync" + "time" + + "github.com/labstack/echo/v5" + "golang.org/x/time/rate" +) + +// RateLimiterStore is the interface to be implemented by custom stores. +type RateLimiterStore interface { + Allow(identifier string) (bool, error) +} + +// RateLimiterConfig defines the configuration for the rate limiter +type RateLimiterConfig struct { + Skipper Skipper + BeforeFunc BeforeFunc + // IdentifierExtractor uses *echo.Context to extract the identifier for a visitor + IdentifierExtractor Extractor + // Store defines a store for the rate limiter + Store RateLimiterStore + // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error + ErrorHandler func(c *echo.Context, err error) error + // DenyHandler provides a handler to be called when RateLimiter denies access + DenyHandler func(c *echo.Context, identifier string, err error) error +} + +// Extractor is used to extract data from *echo.Context +type Extractor func(c *echo.Context) (string, error) + +// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded +var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") + +// ErrExtractorError denotes an error raised when extractor function is unsuccessful +var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") + +// DefaultRateLimiterConfig defines default values for RateLimiterConfig +var DefaultRateLimiterConfig = RateLimiterConfig{ + Skipper: DefaultSkipper, + IdentifierExtractor: func(ctx *echo.Context) (string, error) { + id := ctx.RealIP() + return id, nil + }, + ErrorHandler: func(c *echo.Context, err error) error { + return ErrExtractorError.Wrap(err) + }, + DenyHandler: func(c *echo.Context, identifier string, err error) error { + return ErrRateLimitExceeded.Wrap(err) + }, +} + +/* +RateLimiter returns a rate limiting middleware + + e := echo.New() + + limiterStore := middleware.NewRateLimiterMemoryStore(20) + + e.GET("/rate-limited", func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + }, RateLimiter(limiterStore)) +*/ +func RateLimiter(store RateLimiterStore) echo.MiddlewareFunc { + config := DefaultRateLimiterConfig + config.Store = store + + return RateLimiterWithConfig(config) +} + +/* +RateLimiterWithConfig returns a rate limiting middleware + + e := echo.New() + + config := middleware.RateLimiterConfig{ + Skipper: DefaultSkipper, + Store: middleware.NewRateLimiterMemoryStore( + middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute} + ) + IdentifierExtractor: func(ctx *echo.Context) (string, error) { + id := ctx.RealIP() + return id, nil + }, + ErrorHandler: func(ctx *echo.Context, err error) error { + return context.JSON(http.StatusTooManyRequests, nil) + }, + DenyHandler: func(ctx *echo.Context, identifier string, err error) error { + return context.JSON(http.StatusForbidden, nil) + }, + } + + e.GET("/rate-limited", func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + }, middleware.RateLimiterWithConfig(config)) +*/ +func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid configuration +func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) { + if config.Skipper == nil { + config.Skipper = DefaultRateLimiterConfig.Skipper + } + if config.IdentifierExtractor == nil { + config.IdentifierExtractor = DefaultRateLimiterConfig.IdentifierExtractor + } + if config.ErrorHandler == nil { + config.ErrorHandler = DefaultRateLimiterConfig.ErrorHandler + } + if config.DenyHandler == nil { + config.DenyHandler = DefaultRateLimiterConfig.DenyHandler + } + if config.Store == nil { + return nil, errors.New("echo rate limiter store configuration must be provided") + } + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c *echo.Context) error { + if config.Skipper(c) { + return next(c) + } + if config.BeforeFunc != nil { + config.BeforeFunc(c) + } + + identifier, err := config.IdentifierExtractor(c) + if err != nil { + return config.ErrorHandler(c, err) + } + + if allow, allowErr := config.Store.Allow(identifier); !allow { + return config.DenyHandler(c, identifier, allowErr) + } + return next(c) + } + }, nil +} + +// RateLimiterMemoryStore is the built-in store implementation for RateLimiter +type RateLimiterMemoryStore struct { + visitors map[string]*Visitor + mutex sync.Mutex + rate float64 // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit + burst int + expiresIn time.Duration + lastCleanup time.Time + + timeNow func() time.Time +} + +// Visitor signifies a unique user's limiter details +type Visitor struct { + *rate.Limiter + lastSeen time.Time +} + +/* +NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with +the provided rate (as req/s). +for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. + +Burst and ExpiresIn will be set to default values. + +Note that if the provided rate is a float number and Burst is zero, Burst will be treated as the rounded down value of the rate. + +Example (with 20 requests/sec): + + limiterStore := middleware.NewRateLimiterMemoryStore(20) +*/ +func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore) { + return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ + Rate: rateLimit, + }) +} + +/* +NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore +with the provided configuration. Rate must be provided. Burst will be set to the rounded down value of +the configured rate if not provided or set to 0. + +The built-in memory store is usually capable for modest loads. For higher loads other +store implementations should be considered. + +Characteristics: +* Concurrency above 100 parallel requests may causes measurable lock contention +* A high number of different IP addresses (above 16000) may be impacted by the internally used Go map +* A high number of requests from a single IP address may cause lock contention + +Example: + + limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig( + middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minute}, + ) +*/ +func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) { + store = &RateLimiterMemoryStore{} + + store.rate = config.Rate + store.burst = config.Burst + store.expiresIn = config.ExpiresIn + if config.ExpiresIn == 0 { + store.expiresIn = DefaultRateLimiterMemoryStoreConfig.ExpiresIn + } + if config.Burst == 0 { + store.burst = int(math.Max(1, math.Ceil(float64(config.Rate)))) + } + store.visitors = make(map[string]*Visitor) + store.timeNow = time.Now + store.lastCleanup = store.timeNow() + return +} + +// RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore +type RateLimiterMemoryStoreConfig struct { + Rate float64 // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. + Burst int // Burst is maximum number of requests to pass at the same moment. It additionally allows a number of requests to pass when rate limit is reached. + ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up +} + +// DefaultRateLimiterMemoryStoreConfig provides default configuration values for RateLimiterMemoryStore +var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{ + ExpiresIn: 3 * time.Minute, +} + +// Allow implements RateLimiterStore.Allow +func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { + store.mutex.Lock() + limiter, exists := store.visitors[identifier] + if !exists { + limiter = new(Visitor) + limiter.Limiter = rate.NewLimiter(rate.Limit(store.rate), store.burst) + store.visitors[identifier] = limiter + } + now := store.timeNow() + limiter.lastSeen = now + if now.Sub(store.lastCleanup) > store.expiresIn { + store.cleanupStaleVisitors(now) + } + allowed := limiter.AllowN(now, 1) + store.mutex.Unlock() + return allowed, nil +} + +/* +cleanupStaleVisitors helps manage the size of the visitors map by removing stale records +of users who haven't visited again after the configured expiry time has elapsed +*/ +func (store *RateLimiterMemoryStore) cleanupStaleVisitors(now time.Time) { + for id, visitor := range store.visitors { + if now.Sub(visitor.lastSeen) > store.expiresIn { + delete(store.visitors, id) + } + } + store.lastCleanup = now +} diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go new file mode 100644 index 000000000..c591d2b19 --- /dev/null +++ b/middleware/rate_limiter_test.go @@ -0,0 +1,648 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package middleware + +import ( + "errors" + "math/rand" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/labstack/echo/v5" + "github.com/stretchr/testify/assert" + "golang.org/x/time/rate" +) + +func TestRateLimiter(t *testing.T) { + e := echo.New() + + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore}) + + testCases := []struct { + id string + expectErr string + }{ + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + } + + for _, tc := range testCases { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, tc.id) + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) + } +} + +func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + assert.Panics(t, func() { + RateLimiterWithConfig(RateLimiterConfig{}) + }) + + assert.NotPanics(t, func() { + RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore}) + }) +} + +func TestRateLimiterWithConfig(t *testing.T) { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + e := echo.New() + + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + + mw, err := RateLimiterConfig{ + IdentifierExtractor: func(c *echo.Context) (string, error) { + id := c.Request().Header.Get(echo.HeaderXRealIP) + if id == "" { + return "", errors.New("invalid identifier") + } + return id, nil + }, + DenyHandler: func(ctx *echo.Context, identifier string, err error) error { + return ctx.JSON(http.StatusForbidden, nil) + }, + ErrorHandler: func(ctx *echo.Context, err error) error { + return ctx.JSON(http.StatusBadRequest, nil) + }, + Store: inMemoryStore, + }.ToMiddleware() + assert.NoError(t, err) + + testCases := []struct { + id string + code int + }{ + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusOK}, + {"127.0.0.1", http.StatusForbidden}, + {"", http.StatusBadRequest}, + {"127.0.0.1", http.StatusForbidden}, + {"127.0.0.1", http.StatusForbidden}, + } + + for _, tc := range testCases { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, tc.id) + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + err := mw(handler)(c) + + assert.NoError(t, err) + assert.Equal(t, tc.code, rec.Code) + } +} + +func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + e := echo.New() + + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + + mw, err := RateLimiterConfig{ + IdentifierExtractor: func(c *echo.Context) (string, error) { + id := c.Request().Header.Get(echo.HeaderXRealIP) + if id == "" { + return "", errors.New("invalid identifier") + } + return id, nil + }, + Store: inMemoryStore, + }.ToMiddleware() + assert.NoError(t, err) + + testCases := []struct { + id string + expectErr string + }{ + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {expectErr: "code=403, message=error while extracting identifier, err=invalid identifier"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + } + + for _, tc := range testCases { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, tc.id) + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) + } +} + +func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { + { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + e := echo.New() + + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + + mw, err := RateLimiterConfig{ + Store: inMemoryStore, + }.ToMiddleware() + assert.NoError(t, err) + + testCases := []struct { + id string + expectErr string + }{ + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + } + + for _, tc := range testCases { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, tc.id) + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) + } + } +} + +func TestRateLimiterWithConfig_skipper(t *testing.T) { + e := echo.New() + + var beforeFuncRan bool + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + var inMemoryStore = NewRateLimiterMemoryStore(5) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + mw, err := RateLimiterConfig{ + Skipper: func(c *echo.Context) bool { + return true + }, + BeforeFunc: func(c *echo.Context) { + beforeFuncRan = true + }, + Store: inMemoryStore, + IdentifierExtractor: func(ctx *echo.Context) (string, error) { + return "127.0.0.1", nil + }, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(handler)(c) + + assert.NoError(t, err) + assert.Equal(t, false, beforeFuncRan) +} + +func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { + e := echo.New() + + var beforeFuncRan bool + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + var inMemoryStore = NewRateLimiterMemoryStore(5) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + mw, err := RateLimiterConfig{ + Skipper: func(c *echo.Context) bool { + return false + }, + BeforeFunc: func(c *echo.Context) { + beforeFuncRan = true + }, + Store: inMemoryStore, + IdentifierExtractor: func(ctx *echo.Context) (string, error) { + return "127.0.0.1", nil + }, + }.ToMiddleware() + assert.NoError(t, err) + + _ = mw(handler)(c) + + assert.Equal(t, true, beforeFuncRan) +} + +func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { + e := echo.New() + + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + + var beforeRan bool + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") + + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + + mw, err := RateLimiterConfig{ + BeforeFunc: func(c *echo.Context) { + beforeRan = true + }, + Store: inMemoryStore, + IdentifierExtractor: func(ctx *echo.Context) (string, error) { + return "127.0.0.1", nil + }, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(handler)(c) + + assert.NoError(t, err) + assert.Equal(t, true, beforeRan) +} + +func TestRateLimiterMemoryStore_Allow(t *testing.T) { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3, ExpiresIn: 2 * time.Second}) + testCases := []struct { + id string + allowed bool + }{ + {"127.0.0.1", true}, // 0 ms + {"127.0.0.1", true}, // 220 ms burst #2 + {"127.0.0.1", true}, // 440 ms burst #3 + {"127.0.0.1", false}, // 660 ms block + {"127.0.0.1", false}, // 880 ms block + {"127.0.0.1", true}, // 1100 ms next second #1 + {"127.0.0.2", true}, // 1320 ms allow other ip + {"127.0.0.1", false}, // 1540 ms no burst + {"127.0.0.1", false}, // 1760 ms no burst + {"127.0.0.1", false}, // 1980 ms no burst + {"127.0.0.1", true}, // 2200 ms no burst + {"127.0.0.1", false}, // 2420 ms no burst + {"127.0.0.1", false}, // 2640 ms no burst + {"127.0.0.1", false}, // 2860 ms no burst + {"127.0.0.1", true}, // 3080 ms no burst + {"127.0.0.1", false}, // 3300 ms no burst + {"127.0.0.1", false}, // 3520 ms no burst + {"127.0.0.1", false}, // 3740 ms no burst + {"127.0.0.1", false}, // 3960 ms no burst + {"127.0.0.1", true}, // 4180 ms no burst + {"127.0.0.1", false}, // 4400 ms no burst + {"127.0.0.1", false}, // 4620 ms no burst + {"127.0.0.1", false}, // 4840 ms no burst + {"127.0.0.1", true}, // 5060 ms no burst + } + + for i, tc := range testCases { + t.Logf("Running testcase #%d => %v", i, time.Duration(i)*220*time.Millisecond) + inMemoryStore.timeNow = func() time.Time { + return time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).Add(time.Duration(i) * 220 * time.Millisecond) + } + allowed, _ := inMemoryStore.Allow(tc.id) + assert.Equal(t, tc.allowed, allowed) + } +} + +func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) { + var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) + inMemoryStore.visitors = map[string]*Visitor{ + "A": { + Limiter: rate.NewLimiter(1, 3), + lastSeen: time.Now(), + }, + "B": { + Limiter: rate.NewLimiter(1, 3), + lastSeen: time.Now().Add(-1 * time.Minute), + }, + "C": { + Limiter: rate.NewLimiter(1, 3), + lastSeen: time.Now().Add(-5 * time.Minute), + }, + "D": { + Limiter: rate.NewLimiter(1, 3), + lastSeen: time.Now().Add(-10 * time.Minute), + }, + } + + inMemoryStore.Allow("D") + inMemoryStore.cleanupStaleVisitors(time.Now()) + + var exists bool + + _, exists = inMemoryStore.visitors["A"] + assert.Equal(t, true, exists) + + _, exists = inMemoryStore.visitors["B"] + assert.Equal(t, true, exists) + + _, exists = inMemoryStore.visitors["C"] + assert.Equal(t, false, exists) + + _, exists = inMemoryStore.visitors["D"] + assert.Equal(t, true, exists) +} + +func TestNewRateLimiterMemoryStore(t *testing.T) { + testCases := []struct { + rate float64 + burst int + expiresIn time.Duration + expectedExpiresIn time.Duration + }{ + {1, 3, 5 * time.Second, 5 * time.Second}, + {2, 4, 0, 3 * time.Minute}, + {1, 5, 10 * time.Minute, 10 * time.Minute}, + {3, 7, 0, 3 * time.Minute}, + } + + for _, tc := range testCases { + store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: tc.rate, Burst: tc.burst, ExpiresIn: tc.expiresIn}) + assert.Equal(t, tc.rate, store.rate) + assert.Equal(t, tc.burst, store.burst) + assert.Equal(t, tc.expectedExpiresIn, store.expiresIn) + } +} + +func TestRateLimiterMemoryStore_FractionalRateDefaultBurst(t *testing.T) { + store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ + Rate: 0.5, // fractional rate should get a burst of at least 1 + }) + + base := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) + store.timeNow = func() time.Time { + return base + } + + allowed, err := store.Allow("user") + assert.NoError(t, err) + assert.True(t, allowed, "first request should not be blocked") + + allowed, err = store.Allow("user") + assert.NoError(t, err) + assert.False(t, allowed, "burst token should be consumed immediately") + + store.timeNow = func() time.Time { + return base.Add(2 * time.Second) + } + + allowed, err = store.Allow("user") + assert.NoError(t, err) + assert.True(t, allowed, "token should refill for fractional rate after time passes") +} + +func generateAddressList(count int) []string { + addrs := make([]string, count) + for i := 0; i < count; i++ { + addrs[i] = randomString(15) + } + return addrs +} + +func run(wg *sync.WaitGroup, store RateLimiterStore, addrs []string, max int, b *testing.B) { + for i := 0; i < b.N; i++ { + store.Allow(addrs[rand.Intn(max)]) + } + wg.Done() +} + +func benchmarkStore(store RateLimiterStore, parallel int, max int, b *testing.B) { + addrs := generateAddressList(max) + wg := &sync.WaitGroup{} + for i := 0; i < parallel; i++ { + wg.Add(1) + go run(wg, store, addrs, max, b) + } + wg.Wait() +} + +const ( + testExpiresIn = 1000 * time.Millisecond +) + +func BenchmarkRateLimiterMemoryStore_1000(b *testing.B) { + var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) + benchmarkStore(store, 10, 1000, b) +} + +func BenchmarkRateLimiterMemoryStore_10000(b *testing.B) { + var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) + benchmarkStore(store, 10, 10000, b) +} + +func BenchmarkRateLimiterMemoryStore_100000(b *testing.B) { + var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) + benchmarkStore(store, 10, 100000, b) +} + +func BenchmarkRateLimiterMemoryStore_conc100_10000(b *testing.B) { + var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) + benchmarkStore(store, 100, 10000, b) +} + +// TestRateLimiterMemoryStore_TOCTOUFix verifies that the TOCTOU race condition is fixed +// by ensuring timeNow() is only called once per Allow() call +func TestRateLimiterMemoryStore_TOCTOUFix(t *testing.T) { + t.Parallel() + + store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ + Rate: 1, + Burst: 1, + ExpiresIn: 2 * time.Second, + }) + + // Track time calls to verify we use the same time value + timeCallCount := 0 + baseTime := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) + + store.timeNow = func() time.Time { + timeCallCount++ + return baseTime + } + + // First request - should succeed + allowed, err := store.Allow("127.0.0.1") + assert.NoError(t, err) + assert.True(t, allowed, "First request should be allowed") + + // Verify timeNow() was only called once + assert.Equal(t, 1, timeCallCount, "timeNow() should only be called once per Allow()") +} + +// TestRateLimiterMemoryStore_ConcurrentAccess verifies rate limiting correctness under concurrent load +func TestRateLimiterMemoryStore_ConcurrentAccess(t *testing.T) { + t.Parallel() + + store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ + Rate: 10, + Burst: 5, + ExpiresIn: 5 * time.Second, + }) + + const goroutines = 50 + const requestsPerGoroutine = 20 + + var wg sync.WaitGroup + var allowedCount, deniedCount int32 + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < requestsPerGoroutine; j++ { + allowed, err := store.Allow("test-user") + assert.NoError(t, err) + if allowed { + atomic.AddInt32(&allowedCount, 1) + } else { + atomic.AddInt32(&deniedCount, 1) + } + time.Sleep(time.Millisecond) + } + }() + } + + wg.Wait() + + totalRequests := goroutines * requestsPerGoroutine + allowed := int(allowedCount) + denied := int(deniedCount) + + assert.Equal(t, totalRequests, allowed+denied, "All requests should be processed") + assert.Greater(t, denied, 0, "Some requests should be denied due to rate limiting") + assert.Greater(t, allowed, 0, "Some requests should be allowed") +} + +// TestRateLimiterMemoryStore_RaceDetection verifies no data races with high concurrency +// Run with: go test -race ./middleware -run TestRateLimiterMemoryStore_RaceDetection +func TestRateLimiterMemoryStore_RaceDetection(t *testing.T) { + t.Parallel() + + store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ + Rate: 100, + Burst: 200, + ExpiresIn: 1 * time.Second, + }) + + const goroutines = 100 + const requestsPerGoroutine = 100 + + var wg sync.WaitGroup + identifiers := []string{"user1", "user2", "user3", "user4", "user5"} + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(routineID int) { + defer wg.Done() + for j := 0; j < requestsPerGoroutine; j++ { + identifier := identifiers[routineID%len(identifiers)] + _, err := store.Allow(identifier) + assert.NoError(t, err) + } + }(i) + } + + wg.Wait() +} + +// TestRateLimiterMemoryStore_TimeOrdering verifies time ordering consistency in rate limiting decisions +func TestRateLimiterMemoryStore_TimeOrdering(t *testing.T) { + t.Parallel() + + store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ + Rate: 1, + Burst: 2, + ExpiresIn: 5 * time.Second, + }) + + currentTime := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) + store.timeNow = func() time.Time { + return currentTime + } + + // First two requests should succeed (burst=2) + allowed1, _ := store.Allow("user1") + assert.True(t, allowed1, "Request 1 should be allowed (burst)") + + allowed2, _ := store.Allow("user1") + assert.True(t, allowed2, "Request 2 should be allowed (burst)") + + // Third request should be denied + allowed3, _ := store.Allow("user1") + assert.False(t, allowed3, "Request 3 should be denied (burst exhausted)") + + // Advance time by 1 second + currentTime = currentTime.Add(1 * time.Second) + + // Fourth request should succeed + allowed4, _ := store.Allow("user1") + assert.True(t, allowed4, "Request 4 should be allowed (1 token available)") +} diff --git a/middleware/recover.go b/middleware/recover.go index e87aaf321..01fde5152 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -1,42 +1,42 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( "fmt" + "net/http" "runtime" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // RecoverConfig defines the config for Recover middleware. - RecoverConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// RecoverConfig defines the config for Recover middleware. +type RecoverConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Size of the stack to be printed. - // Optional. Default value 4KB. - StackSize int `yaml:"stack_size"` + // Size of the stack to be printed. + // Optional. Default value 4KB. + StackSize int - // DisableStackAll disables formatting stack traces of all other goroutines - // into buffer after the trace for the current goroutine. - // Optional. Default value false. - DisableStackAll bool `yaml:"disable_stack_all"` + // DisableStackAll disables formatting stack traces of all other goroutines + // into buffer after the trace for the current goroutine. + // Optional. Default value false. + DisableStackAll bool - // DisablePrintStack disables printing stack trace. - // Optional. Default value as false. - DisablePrintStack bool `yaml:"disable_print_stack"` - } -) + // DisablePrintStack disables printing stack trace. + // Optional. Default value as false. + DisablePrintStack bool +} -var ( - // DefaultRecoverConfig is the default Recover middleware config. - DefaultRecoverConfig = RecoverConfig{ - Skipper: DefaultSkipper, - StackSize: 4 << 10, // 4 KB - DisableStackAll: false, - DisablePrintStack: false, - } -) +// DefaultRecoverConfig is the default Recover middleware config. +var DefaultRecoverConfig = RecoverConfig{ + Skipper: DefaultSkipper, + StackSize: 4 << 10, // 4 KB + DisableStackAll: false, + DisablePrintStack: false, +} // Recover returns a middleware which recovers from panics anywhere in the chain // and handles the control to the centralized HTTPErrorHandler. @@ -44,9 +44,13 @@ func Recover() echo.MiddlewareFunc { return RecoverWithConfig(DefaultRecoverConfig) } -// RecoverWithConfig returns a Recover middleware with config. -// See: `Recover()`. +// RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration. func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration +func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultRecoverConfig.Skipper @@ -56,26 +60,44 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) (err error) { if config.Skipper(c) { return next(c) } defer func() { if r := recover(); r != nil { - err, ok := r.(error) + if r == http.ErrAbortHandler { + panic(r) + } + tmpErr, ok := r.(error) if !ok { - err = fmt.Errorf("%v", r) + tmpErr = fmt.Errorf("%v", r) } - stack := make([]byte, config.StackSize) - length := runtime.Stack(stack, !config.DisableStackAll) if !config.DisablePrintStack { - c.Logger().Printf("[PANIC RECOVER] %v %s\n", err, stack[:length]) + stack := make([]byte, config.StackSize) + length := runtime.Stack(stack, !config.DisableStackAll) + tmpErr = &PanicStackError{Stack: stack[:length], Err: tmpErr} } - c.Error(err) + err = tmpErr } }() return next(c) } - } + }, nil +} + +// PanicStackError is an error type that wraps an error along with its stack trace. +// It is returned when config.DisablePrintStack is set to false. +type PanicStackError struct { + Stack []byte + Err error +} + +func (e *PanicStackError) Error() string { + return fmt.Sprintf("[PANIC RECOVER] %s %s", e.Err.Error(), e.Stack) +} + +func (e *PanicStackError) Unwrap() error { + return e.Err } diff --git a/middleware/recover_test.go b/middleware/recover_test.go index 37707c5c1..719e0cc3d 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -1,26 +1,150 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( "bytes" + "errors" + "log/slog" "net/http" "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestRecover(t *testing.T) { e := echo.New() buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) + e.Logger = slog.New(&discardHandler{}) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := Recover()(echo.HandlerFunc(func(c echo.Context) error { + h := Recover()(func(c *echo.Context) error { panic("test") - })) - h(c) + }) + err := h(c) + assert.Contains(t, err.Error(), "[PANIC RECOVER] test goroutine") + + var pse *PanicStackError + if errors.As(err, &pse) { + assert.Contains(t, string(pse.Stack), "middleware/recover.go") + } else { + assert.Fail(t, "not of type PanicStackError") + } + + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain + assert.Contains(t, buf.String(), "") // nothing is logged +} + +func TestRecover_skipper(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + config := RecoverConfig{ + Skipper: func(c *echo.Context) bool { + return true + }, + } + h := RecoverWithConfig(config)(func(c *echo.Context) error { + panic("testPANIC") + }) + + var err error + assert.Panics(t, func() { + err = h(c) + }) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain +} + +func TestRecoverErrAbortHandler(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := Recover()(func(c *echo.Context) error { + panic(http.ErrAbortHandler) + }) + defer func() { + r := recover() + if r == nil { + assert.Fail(t, "expecting `http.ErrAbortHandler`, got `nil`") + } else { + if err, ok := r.(error); ok { + assert.ErrorIs(t, err, http.ErrAbortHandler) + } else { + assert.Fail(t, "not of error type") + } + } + }() + + hErr := h(c) + assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.Contains(t, buf.String(), "PANIC RECOVER") + assert.NotContains(t, hErr.Error(), "PANIC RECOVER") +} + +func TestRecoverWithConfig(t *testing.T) { + var testCases = []struct { + name string + givenNoPanic bool + whenConfig RecoverConfig + expectErrContain string + expectErr string + }{ + { + name: "ok, default config", + whenConfig: DefaultRecoverConfig, + expectErrContain: "[PANIC RECOVER] testPANIC goroutine", + }, + { + name: "ok, no panic", + givenNoPanic: true, + whenConfig: DefaultRecoverConfig, + expectErrContain: "", + }, + { + name: "ok, DisablePrintStack", + whenConfig: RecoverConfig{ + DisablePrintStack: true, + }, + expectErr: "testPANIC", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + config := tc.whenConfig + h := RecoverWithConfig(config)(func(c *echo.Context) error { + if tc.givenNoPanic { + return nil + } + panic("testPANIC") + }) + + err := h(c) + + if tc.expectErrContain != "" { + assert.Contains(t, err.Error(), tc.expectErrContain) + } else if tc.expectErr != "" { + assert.Contains(t, err.Error(), tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain + }) + } } diff --git a/middleware/redirect.go b/middleware/redirect.go index 813e5b856..bb7045cfe 100644 --- a/middleware/redirect.go +++ b/middleware/redirect.go @@ -1,9 +1,14 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( + "errors" "net/http" + "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // RedirectConfig defines the config for Redirect middleware. @@ -13,7 +18,9 @@ type RedirectConfig struct { // Status code to be used when redirecting the request. // Optional. Default value http.StatusMovedPermanently. - Code int `yaml:"code"` + Code int + + redirect redirectLogic } // redirectLogic represents a function that given a scheme, host and uri @@ -23,29 +30,33 @@ type redirectLogic func(scheme, host, uri string) (ok bool, url string) const www = "www." -// DefaultRedirectConfig is the default Redirect middleware config. -var DefaultRedirectConfig = RedirectConfig{ - Skipper: DefaultSkipper, - Code: http.StatusMovedPermanently, -} +// RedirectHTTPSConfig is the HTTPS Redirect middleware config. +var RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS} + +// RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config. +var RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW} + +// RedirectNonHTTPSWWWConfig is the non HTTPS WWW Redirect middleware config. +var RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW} + +// RedirectWWWConfig is the WWW Redirect middleware config. +var RedirectWWWConfig = RedirectConfig{redirect: redirectWWW} + +// RedirectNonWWWConfig is the non WWW Redirect middleware config. +var RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW} // HTTPSRedirect redirects http requests to https. // For example, http://labstack.com will be redirect to https://labstack.com. // // Usage `Echo#Pre(HTTPSRedirect())` func HTTPSRedirect() echo.MiddlewareFunc { - return HTTPSRedirectWithConfig(DefaultRedirectConfig) + return HTTPSRedirectWithConfig(RedirectHTTPSConfig) } -// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSRedirect()`. +// HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration. func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if ok = scheme != "https"; ok { - url = "https://" + host + uri - } - return - }) + config.redirect = redirectHTTPS + return toMiddlewareOrPanic(config) } // HTTPSWWWRedirect redirects http requests to https www. @@ -53,18 +64,13 @@ func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(HTTPSWWWRedirect())` func HTTPSWWWRedirect() echo.MiddlewareFunc { - return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig) + return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig) } -// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSWWWRedirect()`. +// HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration. func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if ok = scheme != "https" && host[:4] != www; ok { - url = "https://www." + host + uri - } - return - }) + config.redirect = redirectHTTPSWWW + return toMiddlewareOrPanic(config) } // HTTPSNonWWWRedirect redirects http requests to https non www. @@ -72,21 +78,13 @@ func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(HTTPSNonWWWRedirect())` func HTTPSNonWWWRedirect() echo.MiddlewareFunc { - return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig) + return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig) } -// HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSNonWWWRedirect()`. +// HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration. func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if ok = scheme != "https"; ok { - if host[:4] == www { - host = host[4:] - } - url = "https://" + host + uri - } - return - }) + config.redirect = redirectNonHTTPSWWW + return toMiddlewareOrPanic(config) } // WWWRedirect redirects non www requests to www. @@ -94,18 +92,13 @@ func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(WWWRedirect())` func WWWRedirect() echo.MiddlewareFunc { - return WWWRedirectWithConfig(DefaultRedirectConfig) + return WWWRedirectWithConfig(RedirectWWWConfig) } -// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `WWWRedirect()`. +// WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration. func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if ok = host[:4] != www; ok { - url = scheme + "://www." + host + uri - } - return - }) + config.redirect = redirectWWW + return toMiddlewareOrPanic(config) } // NonWWWRedirect redirects www requests to non www. @@ -113,41 +106,79 @@ func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(NonWWWRedirect())` func NonWWWRedirect() echo.MiddlewareFunc { - return NonWWWRedirectWithConfig(DefaultRedirectConfig) + return NonWWWRedirectWithConfig(RedirectNonWWWConfig) } -// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `NonWWWRedirect()`. +// NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration. func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if ok = host[:4] == www; ok { - url = scheme + "://" + host[4:] + uri - } - return - }) + config.redirect = redirectNonWWW + return toMiddlewareOrPanic(config) } -func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { +// ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration +func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultTrailingSlashConfig.Skipper + config.Skipper = DefaultSkipper } if config.Code == 0 { - config.Code = DefaultRedirectConfig.Code + config.Code = http.StatusMovedPermanently + } + if config.redirect == nil { + return nil, errors.New("redirectConfig is missing redirect function") } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } req, scheme := c.Request(), c.Scheme() host := req.Host - if ok, url := cb(scheme, host, req.RequestURI); ok { + if ok, url := config.redirect(scheme, host, req.RequestURI); ok { return c.Redirect(config.Code, url) } return next(c) } + }, nil +} + +var redirectHTTPS = func(scheme, host, uri string) (bool, string) { + if scheme != "https" { + return true, "https://" + host + uri + } + return false, "" +} + +var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) { + // Redirect if not HTTPS OR missing www prefix (needs either fix) + if scheme != "https" || !strings.HasPrefix(host, www) { + host = strings.TrimPrefix(host, www) // Remove www if present to avoid duplication + return true, "https://www." + host + uri + } + return false, "" +} + +var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) { + // Redirect if not HTTPS OR has www prefix (needs either fix) + if scheme != "https" || strings.HasPrefix(host, www) { + host = strings.TrimPrefix(host, www) + return true, "https://" + host + uri + } + return false, "" +} + +var redirectWWW = func(scheme, host, uri string) (bool, string) { + if !strings.HasPrefix(host, www) { + return true, scheme + "://www." + host + uri + } + return false, "" +} + +var redirectNonWWW = func(scheme, host, uri string) (bool, string) { + if strings.HasPrefix(host, www) { + return true, scheme + "://" + host[4:] + uri } + return false, "" } diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go index 082609574..a127ca40c 100644 --- a/middleware/redirect_test.go +++ b/middleware/redirect_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -5,74 +8,277 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) type middlewareGenerator func() echo.MiddlewareFunc func TestRedirectHTTPSRedirect(t *testing.T) { - res := redirectTest(HTTPSRedirect, "labstack.com", nil) - - assert.Equal(t, http.StatusMovedPermanently, res.Code) - assert.Equal(t, "https://labstack.com/", res.Header().Get(echo.HeaderLocation)) -} + var testCases = []struct { + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + whenHost: "labstack.com", + expectLocation: "https://labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "labstack.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + } -func TestHTTPSRedirectBehindTLSTerminationProxy(t *testing.T) { - header := http.Header{} - header.Set(echo.HeaderXForwardedProto, "https") - res := redirectTest(HTTPSRedirect, "labstack.com", header) + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + res := redirectTest(HTTPSRedirect, tc.whenHost, tc.whenHeader) - assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } } func TestRedirectHTTPSWWWRedirect(t *testing.T) { - res := redirectTest(HTTPSWWWRedirect, "labstack.com", nil) - - assert.Equal(t, http.StatusMovedPermanently, res.Code) - assert.Equal(t, "https://www.labstack.com/", res.Header().Get(echo.HeaderLocation)) -} + var testCases = []struct { + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + whenHost: "labstack.com", + expectLocation: "https://www.labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.labstack.com", + expectLocation: "https://www.labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "a.com", + expectLocation: "https://www.a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "ip", + expectLocation: "https://www.ip/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "labstack.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "https://www.labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.labstack.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + } -func TestRedirectHTTPSWWWRedirectBehindTLSTerminationProxy(t *testing.T) { - header := http.Header{} - header.Set(echo.HeaderXForwardedProto, "https") - res := redirectTest(HTTPSWWWRedirect, "labstack.com", header) + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + res := redirectTest(HTTPSWWWRedirect, tc.whenHost, tc.whenHeader) - assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } } func TestRedirectHTTPSNonWWWRedirect(t *testing.T) { - res := redirectTest(HTTPSNonWWWRedirect, "www.labstack.com", nil) - - assert.Equal(t, http.StatusMovedPermanently, res.Code) - assert.Equal(t, "https://labstack.com/", res.Header().Get(echo.HeaderLocation)) -} + var testCases = []struct { + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + whenHost: "www.labstack.com", + expectLocation: "https://labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "a.com", + expectLocation: "https://a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "ip", + expectLocation: "https://ip/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.labstack.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "https://labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "labstack.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + } -func TestRedirectHTTPSNonWWWRedirectBehindTLSTerminationProxy(t *testing.T) { - header := http.Header{} - header.Set(echo.HeaderXForwardedProto, "https") - res := redirectTest(HTTPSNonWWWRedirect, "www.labstack.com", header) + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + res := redirectTest(HTTPSNonWWWRedirect, tc.whenHost, tc.whenHeader) - assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } } func TestRedirectWWWRedirect(t *testing.T) { - res := redirectTest(WWWRedirect, "labstack.com", nil) + var testCases = []struct { + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + whenHost: "labstack.com", + expectLocation: "http://www.labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "a.com", + expectLocation: "http://www.a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "ip", + expectLocation: "http://www.ip/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "a.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "https://www.a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.ip", + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + res := redirectTest(WWWRedirect, tc.whenHost, tc.whenHeader) - assert.Equal(t, http.StatusMovedPermanently, res.Code) - assert.Equal(t, "http://www.labstack.com/", res.Header().Get(echo.HeaderLocation)) + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } } func TestRedirectNonWWWRedirect(t *testing.T) { - res := redirectTest(NonWWWRedirect, "www.labstack.com", nil) + var testCases = []struct { + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + whenHost: "www.labstack.com", + expectLocation: "http://labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.a.com", + expectLocation: "http://a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.a.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "https://a.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "ip", + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + } + + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + res := redirectTest(NonWWWRedirect, tc.whenHost, tc.whenHeader) + + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } +} + +func TestNonWWWRedirectWithConfig(t *testing.T) { + var testCases = []struct { + name string + givenCode int + givenSkipFunc func(c *echo.Context) bool + whenHost string + whenHeader http.Header + expectLocation string + expectStatusCode int + }{ + { + name: "usual redirect", + whenHost: "www.labstack.com", + expectLocation: "http://labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + name: "redirect is skipped", + givenSkipFunc: func(c *echo.Context) bool { + return true // skip always + }, + whenHost: "www.labstack.com", + expectLocation: "", + expectStatusCode: http.StatusOK, + }, + { + name: "redirect with custom status code", + givenCode: http.StatusSeeOther, + whenHost: "www.labstack.com", + expectLocation: "http://labstack.com/", + expectStatusCode: http.StatusSeeOther, + }, + } + + for _, tc := range testCases { + t.Run(tc.whenHost, func(t *testing.T) { + middleware := func() echo.MiddlewareFunc { + return NonWWWRedirectWithConfig(RedirectConfig{ + Skipper: tc.givenSkipFunc, + Code: tc.givenCode, + }) + } + res := redirectTest(middleware, tc.whenHost, tc.whenHeader) - assert.Equal(t, http.StatusMovedPermanently, res.Code) - assert.Equal(t, "http://labstack.com/", res.Header().Get(echo.HeaderLocation)) + assert.Equal(t, tc.expectStatusCode, res.Code) + assert.Equal(t, tc.expectLocation, res.Header().Get(echo.HeaderLocation)) + }) + } } func redirectTest(fn middlewareGenerator, host string, header http.Header) *httptest.ResponseRecorder { e := echo.New() - next := func(c echo.Context) (err error) { + next := func(c *echo.Context) (err error) { return c.NoContent(http.StatusOK) } req := httptest.NewRequest(http.MethodGet, "/", nil) diff --git a/middleware/request_id.go b/middleware/request_id.go index 21f801f3b..b3de40d19 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -1,64 +1,73 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" + "github.com/labstack/echo/v5" ) -type ( - // RequestIDConfig defines the config for RequestID middleware. - RequestIDConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper +// RequestIDConfig defines the config for RequestID middleware. +type RequestIDConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - // Generator defines a function to generate an ID. - // Optional. Default value random.String(32). - Generator func() string - } -) + // Generator defines a function to generate an ID. + // Optional. Default value random.String(32). + Generator func() string -var ( - // DefaultRequestIDConfig is the default RequestID middleware config. - DefaultRequestIDConfig = RequestIDConfig{ - Skipper: DefaultSkipper, - Generator: generator, - } -) + // RequestIDHandler defines a function which is executed for a request id. + RequestIDHandler func(c *echo.Context, requestID string) + + // TargetHeader defines what header to look for to populate the id. + // Optional. Default value is `X-Request-Id` + TargetHeader string +} -// RequestID returns a X-Request-ID middleware. +// RequestID returns a middleware that reads RequestIDConfig.TargetHeader (`X-Request-ID`) header value or when +// the header value is empty, generates that value and sets request ID to response +// as RequestIDConfig.TargetHeader (`X-Request-Id`) value. func RequestID() echo.MiddlewareFunc { - return RequestIDWithConfig(DefaultRequestIDConfig) + return RequestIDWithConfig(RequestIDConfig{}) } -// RequestIDWithConfig returns a X-Request-ID middleware with config. +// RequestIDWithConfig returns a middleware with given valid config or panics on invalid configuration. +// The middleware reads RequestIDConfig.TargetHeader (`X-Request-ID`) header value or when the header value is empty, +// generates that value and sets request ID to response as RequestIDConfig.TargetHeader (`X-Request-Id`) value. func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration +func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultRequestIDConfig.Skipper + config.Skipper = DefaultSkipper } if config.Generator == nil { - config.Generator = generator + config.Generator = createRandomStringGenerator(32) + } + if config.TargetHeader == "" { + config.TargetHeader = echo.HeaderXRequestID } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } req := c.Request() res := c.Response() - rid := req.Header.Get(echo.HeaderXRequestID) + rid := req.Header.Get(config.TargetHeader) if rid == "" { rid = config.Generator() } - res.Header().Set(echo.HeaderXRequestID, rid) + res.Header().Set(config.TargetHeader, rid) + if config.RequestIDHandler != nil { + config.RequestIDHandler(c, rid) + } return next(c) } - } -} - -func generator() string { - return random.String(32) + }, nil } diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go index 30eecdef9..465e6fc42 100644 --- a/middleware/request_id_test.go +++ b/middleware/request_id_test.go @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( @@ -5,7 +8,7 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -14,11 +17,97 @@ func TestRequestID(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - handler := func(c echo.Context) error { + handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } - rid := RequestIDWithConfig(RequestIDConfig{}) + rid := RequestID() + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) +} + +func TestMustRequestIDWithConfig_skipper(t *testing.T) { + e := echo.New() + e.GET("/", func(c *echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) + + generatorCalled := false + e.Use(RequestIDWithConfig(RequestIDConfig{ + Skipper: func(c *echo.Context) bool { + return true + }, + Generator: func() string { + generatorCalled = true + return "customGenerator" + }, + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.Equal(t, "test", res.Body.String()) + + assert.Equal(t, res.Header().Get(echo.HeaderXRequestID), "") + assert.False(t, generatorCalled) +} + +func TestMustRequestIDWithConfig_customGenerator(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid := RequestIDWithConfig(RequestIDConfig{ + Generator: func() string { return "customGenerator" }, + }) + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") +} + +func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + + called := false + rid := RequestIDWithConfig(RequestIDConfig{ + Generator: func() string { return "customGenerator" }, + RequestIDHandler: func(c *echo.Context, s string) { + called = true + }, + }) + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") + assert.True(t, called) +} + +func TestRequestIDWithConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid, err := RequestIDConfig{}.ToMiddleware() + assert.NoError(t, err) h := rid(handler) h(c) assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) @@ -31,3 +120,51 @@ func TestRequestID(t *testing.T) { h(c) assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") } + +func TestRequestID_IDNotAltered(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Add(echo.HeaderXRequestID, "") + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid := RequestIDWithConfig(RequestIDConfig{}) + h := rid(handler) + _ = h(c) + assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "") +} + +func TestRequestIDConfigDifferentHeader(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid := RequestIDWithConfig(RequestIDConfig{TargetHeader: echo.HeaderXCorrelationID}) + h := rid(handler) + h(c) + assert.Len(t, rec.Header().Get(echo.HeaderXCorrelationID), 32) + + // Custom generator and handler + customID := "customGenerator" + calledHandler := false + rid = RequestIDWithConfig(RequestIDConfig{ + Generator: func() string { return customID }, + TargetHeader: echo.HeaderXCorrelationID, + RequestIDHandler: func(_ *echo.Context, id string) { + calledHandler = true + assert.Equal(t, customID, id) + }, + }) + h = rid(handler) + h(c) + assert.Equal(t, rec.Header().Get(echo.HeaderXCorrelationID), "customGenerator") + assert.True(t, calledHandler) +} diff --git a/middleware/request_logger.go b/middleware/request_logger.go new file mode 100644 index 000000000..76903c62a --- /dev/null +++ b/middleware/request_logger.go @@ -0,0 +1,462 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package middleware + +import ( + "context" + "errors" + "log/slog" + "net/http" + "time" + + "github.com/labstack/echo/v5" +) + +// Example for `slog` https://pkg.go.dev/log/slog +// logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)) +// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ +// LogStatus: true, +// LogURI: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code +// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { +// if v.Error == nil { +// logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST", +// slog.String("uri", v.URI), +// slog.Int("status", v.Status), +// ) +// } else { +// logger.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR", +// slog.String("uri", v.URI), +// slog.Int("status", v.Status), +// slog.String("err", v.Error.Error()), +// ) +// } +// return nil +// }, +// })) +// +// Example for `fmt.Printf` +// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ +// LogStatus: true, +// LogURI: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code +// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { +// if v.Error == nil { +// fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status) +// } else { +// fmt.Printf("REQUEST_ERROR: uri: %v, status: %v, err: %v\n", v.URI, v.Status, v.Error) +// } +// return nil +// }, +// })) +// +// Example for Zerolog (https://github.com/rs/zerolog) +// logger := zerolog.New(os.Stdout) +// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ +// LogURI: true, +// LogStatus: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code +// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { +// if v.Error == nil { +// logger.Info(). +// Str("URI", v.URI). +// Int("status", v.Status). +// Msg("request") +// } else { +// logger.Error(). +// Err(v.Error). +// Str("URI", v.URI). +// Int("status", v.Status). +// Msg("request error") +// } +// return nil +// }, +// })) +// +// Example for Zap (https://github.com/uber-go/zap) +// logger, _ := zap.NewProduction() +// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ +// LogURI: true, +// LogStatus: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code +// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { +// if v.Error == nil { +// logger.Info("request", +// zap.String("URI", v.URI), +// zap.Int("status", v.Status), +// ) +// } else { +// logger.Error("request error", +// zap.String("URI", v.URI), +// zap.Int("status", v.Status), +// zap.Error(v.Error), +// ) +// } +// return nil +// }, +// })) +// +// Example for Logrus (https://github.com/sirupsen/logrus) +// log := logrus.New() +// e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ +// LogURI: true, +// LogStatus: true, +// HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code +// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { +// if v.Error == nil { +// log.WithFields(logrus.Fields{ +// "URI": v.URI, +// "status": v.Status, +// }).Info("request") +// } else { +// log.WithFields(logrus.Fields{ +// "URI": v.URI, +// "status": v.Status, +// "error": v.Error, +// }).Error("request error") +// } +// return nil +// }, +// })) + +// RequestLoggerConfig is configuration for Request Logger middleware. +type RequestLoggerConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // BeforeNextFunc defines a function that is called before next middleware or handler is called in chain. + BeforeNextFunc func(c *echo.Context) + // LogValuesFunc defines a function that is called with values extracted by logger from request/response. + // Mandatory. + LogValuesFunc func(c *echo.Context, v RequestLoggerValues) error + + // HandleError instructs logger to call global error handler when next middleware/handler returns an error. + // This is useful when you have custom error handler that can decide to use different status codes. + // + // A side-effect of calling global error handler is that now Response has been committed and sent to the client + // and middlewares up in chain can not change Response status code or response body. + HandleError bool + + // LogLatency instructs logger to record duration it took to execute rest of the handler chain (next(c) call). + LogLatency bool + // LogProtocol instructs logger to extract request protocol (i.e. `HTTP/1.1` or `HTTP/2`) + LogProtocol bool + // LogRemoteIP instructs logger to extract request remote IP. See `echo.Context.RealIP()` for implementation details. + LogRemoteIP bool + // LogHost instructs logger to extract request host value (i.e. `example.com`) + LogHost bool + // LogMethod instructs logger to extract request method value (i.e. `GET` etc) + LogMethod bool + // LogURI instructs logger to extract request URI (i.e. `/list?lang=en&page=1`) + LogURI bool + // LogURIPath instructs logger to extract request URI path part (i.e. `/list`) + LogURIPath bool + // LogRoutePath instructs logger to extract route path part to which request was matched to (i.e. `/user/:id`) + LogRoutePath bool + // LogRequestID instructs logger to extract request ID from request `X-Request-ID` header or response if request did not have value. + LogRequestID bool + // LogReferer instructs logger to extract request referer values. + LogReferer bool + // LogUserAgent instructs logger to extract request user agent values. + LogUserAgent bool + // LogStatus instructs logger to extract response status code. If handler chain returns an echo.HTTPError, + // the status code is extracted from the echo.HTTPError returned + LogStatus bool + // LogContentLength instructs logger to extract content length header value. Note: this value could be different from + // actual request body size as it could be spoofed etc. + LogContentLength bool + // LogResponseSize instructs logger to extract response content length value. Note: when used with Gzip middleware + // this value may not be always correct. + LogResponseSize bool + // LogHeaders instructs logger to extract given list of headers from request. Note: request can contain more than + // one header with same value so slice of values is been logger for each given header. + // + // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header + // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding". + LogHeaders []string + // LogQueryParams instructs logger to extract given list of query parameters from request URI. Note: request can + // contain more than one query parameter with same name so slice of values is been logger for each given query param name. + LogQueryParams []string + // LogFormValues instructs logger to extract given list of form values from request body+URI. Note: request can + // contain more than one form value with same name so slice of values is been logger for each given form value name. + LogFormValues []string + + timeNow func() time.Time +} + +// RequestLoggerValues contains extracted values from logger. +type RequestLoggerValues struct { + // StartTime is time recorded before next middleware/handler is executed. + StartTime time.Time + // Latency is duration it took to execute rest of the handler chain (next(c) call). + Latency time.Duration + // Protocol is request protocol (i.e. `HTTP/1.1` or `HTTP/2`) + Protocol string + // RemoteIP is request remote IP. See `echo.Context.RealIP()` for implementation details. + RemoteIP string + // Host is request host value (i.e. `example.com`) + Host string + // Method is request method value (i.e. `GET` etc) + Method string + // URI is request URI (i.e. `/list?lang=en&page=1`) + URI string + // URIPath is request URI path part (i.e. `/list`) + URIPath string + // RoutePath is route path part to which request was matched to (i.e. `/user/:id`) + RoutePath string + // RequestID is request ID from request `X-Request-ID` header or response if request did not have value. + RequestID string + // Referer is request referer values. + Referer string + // UserAgent is request user agent values. + UserAgent string + // Status is response status code. Then handler returns an echo.HTTPError then code from there. + Status int + // Error is error returned from executed handler chain. + Error error + // ContentLength is content length header value. Note: this value could be different from actual request body size + // as it could be spoofed etc. + ContentLength string + // ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct. + ResponseSize int64 + // Headers are list of headers from request. Note: request can contain more than one header with same value so slice + // of values is what will be returned/logged for each given header. + // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header + // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding". + Headers map[string][]string + // QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter + // with same name so slice of values is what will be returned/logged for each given query param name. + QueryParams map[string][]string + // FormValues are list of form values from request body+URI. Note: request can contain more than one form value with + // same name so slice of values is what will be returned/logged for each given form value name. + FormValues map[string][]string +} + +// RequestLoggerWithConfig returns a RequestLogger middleware with config. +func RequestLoggerWithConfig(config RequestLoggerConfig) echo.MiddlewareFunc { + mw, err := config.ToMiddleware() + if err != nil { + panic(err) + } + return mw +} + +// ToMiddleware converts RequestLoggerConfig into middleware or returns an error for invalid configuration. +func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { + if config.Skipper == nil { + config.Skipper = DefaultSkipper + } + now := time.Now + if config.timeNow != nil { + now = config.timeNow + } + + if config.LogValuesFunc == nil { + return nil, errors.New("missing LogValuesFunc callback function for request logger middleware") + } + + logHeaders := len(config.LogHeaders) > 0 + headers := append([]string(nil), config.LogHeaders...) + for i, v := range headers { + headers[i] = http.CanonicalHeaderKey(v) + } + + logQueryParams := len(config.LogQueryParams) > 0 + logFormValues := len(config.LogFormValues) > 0 + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c *echo.Context) error { + if config.Skipper(c) { + return next(c) + } + + req := c.Request() + res := c.Response() + start := now() + + if config.BeforeNextFunc != nil { + config.BeforeNextFunc(c) + } + err := next(c) + if err != nil && config.HandleError { + // When global error handler writes the error to the client the Response gets "committed". This state can be + // checked with `c.Response().Committed` field. + c.Echo().HTTPErrorHandler(c, err) + } + + v := RequestLoggerValues{ + StartTime: start, + } + if config.LogLatency { + v.Latency = now().Sub(start) + } + if config.LogProtocol { + v.Protocol = req.Proto + } + if config.LogRemoteIP { + v.RemoteIP = c.RealIP() + } + if config.LogHost { + v.Host = req.Host + } + if config.LogMethod { + v.Method = req.Method + } + if config.LogURI { + v.URI = req.RequestURI + } + if config.LogURIPath { + p := req.URL.Path + if p == "" { + p = "/" + } + v.URIPath = p + } + if config.LogRoutePath { + v.RoutePath = c.Path() + } + if config.LogRequestID { + id := req.Header.Get(echo.HeaderXRequestID) + if id == "" { + id = res.Header().Get(echo.HeaderXRequestID) + } + v.RequestID = id + } + if config.LogReferer { + v.Referer = req.Referer() + } + if config.LogUserAgent { + v.UserAgent = req.UserAgent() + } + + var resp *echo.Response + if config.LogStatus || config.LogResponseSize { + if r, err := echo.UnwrapResponse(res); err != nil { + c.Logger().Error("can not determine response status and/or size. ResponseWriter in context does not implement unwrapper interface") + } else { + resp = r + } + } + + if config.LogStatus { + v.Status = -1 + if resp != nil { + v.Status = resp.Status + } + if err != nil && !config.HandleError { + // this block should not be executed in case of HandleError=true as the global error handler will decide + // the status code. In that case status code could be different from what err contains. + var hsc echo.HTTPStatusCoder + if errors.As(err, &hsc) { + v.Status = hsc.StatusCode() + } + } + } + if err != nil { + v.Error = err + } + if config.LogContentLength { + v.ContentLength = req.Header.Get(echo.HeaderContentLength) + } + if config.LogResponseSize { + v.ResponseSize = -1 + if resp != nil { + v.ResponseSize = resp.Size + } + } + if logHeaders { + v.Headers = map[string][]string{} + for _, header := range headers { + if values, ok := req.Header[header]; ok { + v.Headers[header] = values + } + } + } + if logQueryParams { + queryParams := c.QueryParams() + v.QueryParams = map[string][]string{} + for _, param := range config.LogQueryParams { + if values, ok := queryParams[param]; ok { + v.QueryParams[param] = values + } + } + } + if logFormValues { + v.FormValues = map[string][]string{} + for _, formValue := range config.LogFormValues { + if values, ok := req.Form[formValue]; ok { + v.FormValues[formValue] = values + } + } + } + + if errOnLog := config.LogValuesFunc(c, v); errOnLog != nil { + return errOnLog + } + // in case of HandleError=true we are returning the error that we already have handled with global error handler + // this is deliberate as this error could be useful for upstream middlewares and default global error handler + // will ignore that error when it bubbles up in middleware chain. + // Committed response can be checked in custom error handler with following logic + // + // if r, _ := echo.UnwrapResponse(c.Response()); r != nil && r.Committed { + // return + // } + return err + } + }, nil +} + +// RequestLogger creates Request Logger middleware with Echo default settings that uses Context.Logger() as logger. +func RequestLogger() echo.MiddlewareFunc { + return RequestLoggerWithConfig(RequestLoggerConfig{ + LogLatency: true, + LogRemoteIP: true, + LogHost: true, + LogMethod: true, + LogURI: true, + LogRequestID: true, + LogUserAgent: true, + LogStatus: true, + LogContentLength: true, + LogResponseSize: true, + // forwards error to the global error handler, so it can decide appropriate status code. + // NB: side-effect of that is - request is now "commited" written to the client. Middlewares up in chain can not + // change Response status code or response body. + HandleError: true, + LogValuesFunc: func(c *echo.Context, v RequestLoggerValues) error { + logger := c.Logger() + if v.Error == nil { + logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST", + slog.String("method", v.Method), + slog.String("uri", v.URI), + slog.Int("status", v.Status), + slog.Duration("latency", v.Latency), + slog.String("host", v.Host), + slog.String("bytes_in", v.ContentLength), + slog.Int64("bytes_out", v.ResponseSize), + slog.String("user_agent", v.UserAgent), + slog.String("remote_ip", v.RemoteIP), + slog.String("request_id", v.RequestID), + ) + return nil + } + + logger.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR", + slog.String("method", v.Method), + slog.String("uri", v.URI), + slog.Int("status", v.Status), + slog.Duration("latency", v.Latency), + slog.String("host", v.Host), + slog.String("bytes_in", v.ContentLength), + slog.Int64("bytes_out", v.ResponseSize), + slog.String("user_agent", v.UserAgent), + slog.String("remote_ip", v.RemoteIP), + slog.String("request_id", v.RequestID), + + slog.String("error", v.Error.Error()), + ) + return nil + }, + }) +} diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go new file mode 100644 index 000000000..af39eb32a --- /dev/null +++ b/middleware/request_logger_test.go @@ -0,0 +1,630 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package middleware + +import ( + "bytes" + "encoding/json" + "errors" + "log/slog" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "strings" + "testing" + "time" + + "github.com/labstack/echo/v5" + "github.com/stretchr/testify/assert" +) + +func TestRequestLoggerOK(t *testing.T) { + old := slog.Default() + t.Cleanup(func() { + slog.SetDefault(old) + }) + + e := echo.New() + buf := new(bytes.Buffer) + e.Logger = slog.New(slog.NewJSONHandler(buf, nil)) + e.Use(RequestLogger()) + + e.POST("/test", func(c *echo.Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + reader := strings.NewReader(`{"foo":"bar"}`) + req := httptest.NewRequest(http.MethodPost, "/test", reader) + req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size()))) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Set(echo.HeaderXRealIP, "8.8.8.8") + req.Header.Set("User-Agent", "curl/7.68.0") + + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + logAttrs := map[string]interface{}{} + assert.NoError(t, json.Unmarshal(buf.Bytes(), &logAttrs)) + logAttrs["latency"] = 123 + logAttrs["time"] = "x" + + expect := map[string]interface{}{ + "level": "INFO", + "msg": "REQUEST", + "method": "POST", + "uri": "/test", + "status": float64(418), + "bytes_in": "13", + "host": "example.com", + "bytes_out": float64(2), + "user_agent": "curl/7.68.0", + "remote_ip": "8.8.8.8", + "request_id": "", + + "time": "x", + "latency": 123, + } + assert.Equal(t, expect, logAttrs) +} + +func TestRequestLoggerError(t *testing.T) { + old := slog.Default() + t.Cleanup(func() { + slog.SetDefault(old) + }) + + e := echo.New() + buf := new(bytes.Buffer) + e.Logger = slog.New(slog.NewJSONHandler(buf, nil)) + e.Use(RequestLogger()) + + e.GET("/test", func(c *echo.Context) error { + return errors.New("nope") + }) + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + logAttrs := map[string]interface{}{} + assert.NoError(t, json.Unmarshal(buf.Bytes(), &logAttrs)) + logAttrs["latency"] = 123 + logAttrs["time"] = "x" + + expect := map[string]interface{}{ + "level": "ERROR", + "msg": "REQUEST_ERROR", + "method": "GET", + "uri": "/test", + "status": float64(500), + "bytes_in": "", + "host": "example.com", + "bytes_out": float64(36.0), + "user_agent": "", + "remote_ip": "192.0.2.1", + "request_id": "", + "error": "nope", + + "latency": 123, + "time": "x", + } + assert.Equal(t, expect, logAttrs) +} + +func TestRequestLoggerWithConfig(t *testing.T) { + e := echo.New() + + var expect RequestLoggerValues + e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + LogRoutePath: true, + LogURI: true, + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + expect = values + return nil + }, + })) + + e.GET("/test", func(c *echo.Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.Equal(t, "/test", expect.RoutePath) +} + +func TestRequestLoggerWithConfig_missingOnLogValuesPanics(t *testing.T) { + assert.Panics(t, func() { + RequestLoggerWithConfig(RequestLoggerConfig{ + LogValuesFunc: nil, + }) + }) +} + +func TestRequestLogger_skipper(t *testing.T) { + e := echo.New() + + loggerCalled := false + e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + Skipper: func(c *echo.Context) bool { + return true + }, + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + loggerCalled = true + return nil + }, + })) + + e.GET("/test", func(c *echo.Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.False(t, loggerCalled) +} + +func TestRequestLogger_beforeNextFunc(t *testing.T) { + e := echo.New() + + var myLoggerInstance int + e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + BeforeNextFunc: func(c *echo.Context) { + c.Set("myLoggerInstance", 42) + }, + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + myLoggerInstance = c.Get("myLoggerInstance").(int) + return nil + }, + })) + + e.GET("/test", func(c *echo.Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.Equal(t, 42, myLoggerInstance) +} + +func TestRequestLogger_logError(t *testing.T) { + e := echo.New() + + var actual RequestLoggerValues + e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + LogStatus: true, + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + actual = values + return nil + }, + })) + + e.GET("/test", func(c *echo.Context) error { + return echo.NewHTTPError(http.StatusNotAcceptable, "nope") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusNotAcceptable, rec.Code) + assert.Equal(t, http.StatusNotAcceptable, actual.Status) + assert.EqualError(t, actual.Error, "code=406, message=nope") +} + +func TestRequestLogger_HandleError(t *testing.T) { + e := echo.New() + + var actual RequestLoggerValues + e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + timeNow: func() time.Time { + return time.Unix(1631045377, 0).UTC() + }, + HandleError: true, + LogStatus: true, + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + actual = values + return nil + }, + })) + + // to see if "HandleError" works we create custom error handler that uses its own status codes + e.HTTPErrorHandler = func(c *echo.Context, err error) { + if r, _ := echo.UnwrapResponse(c.Response()); r != nil && r.Committed { + return + } + c.JSON(http.StatusTeapot, "custom error handler") + } + + e.GET("/test", func(c *echo.Context) error { + return echo.NewHTTPError(http.StatusForbidden, "nope") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + + expect := RequestLoggerValues{ + StartTime: time.Unix(1631045377, 0).UTC(), + Status: http.StatusTeapot, + Error: echo.NewHTTPError(http.StatusForbidden, "nope"), + } + assert.Equal(t, expect, actual) +} + +func TestRequestLogger_LogValuesFuncError(t *testing.T) { + e := echo.New() + + var expect RequestLoggerValues + e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + LogStatus: true, + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + expect = values + return echo.NewHTTPError(http.StatusNotAcceptable, "LogValuesFuncError") + }, + })) + + e.GET("/test", func(c *echo.Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + // NOTE: when global error handler received error returned from middleware the status has already + // been written to the client and response has been "committed" therefore global error handler does not do anything + // and error that bubbled up in middleware chain will not be reflected in response code. + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.Equal(t, http.StatusTeapot, expect.Status) +} + +func TestRequestLogger_ID(t *testing.T) { + var testCases = []struct { + name string + whenFromRequest bool + expect string + }{ + { + name: "ok, ID is provided from request headers", + whenFromRequest: true, + expect: "123", + }, + { + name: "ok, ID is from response headers", + whenFromRequest: false, + expect: "321", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + var expect RequestLoggerValues + e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ + LogRequestID: true, + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + expect = values + return nil + }, + })) + + e.GET("/test", func(c *echo.Context) error { + c.Response().Header().Set(echo.HeaderXRequestID, "321") + return c.String(http.StatusTeapot, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + if tc.whenFromRequest { + req.Header.Set(echo.HeaderXRequestID, "123") + } + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.Equal(t, tc.expect, expect.RequestID) + }) + } +} + +func TestRequestLogger_headerIsCaseInsensitive(t *testing.T) { + e := echo.New() + + var expect RequestLoggerValues + mw := RequestLoggerWithConfig(RequestLoggerConfig{ + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + expect = values + return nil + }, + LogHeaders: []string{"referer", "User-Agent"}, + })(func(c *echo.Context) error { + c.Request().Header.Set(echo.HeaderXRequestID, "123") + c.FormValue("to force parse form") + return c.String(http.StatusTeapot, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/test?lang=en&checked=1&checked=2", nil) + req.Header.Set("referer", "https://echo.labstack.com/") + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := mw(c) + + assert.NoError(t, err) + assert.Len(t, expect.Headers, 1) + assert.Equal(t, []string{"https://echo.labstack.com/"}, expect.Headers["Referer"]) +} + +func TestRequestLogger_allFields(t *testing.T) { + e := echo.New() + + isFirstNowCall := true + var expect RequestLoggerValues + mw := RequestLoggerWithConfig(RequestLoggerConfig{ + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + expect = values + return nil + }, + LogLatency: true, + LogProtocol: true, + LogRemoteIP: true, + LogHost: true, + LogMethod: true, + LogURI: true, + LogURIPath: true, + LogRoutePath: true, + LogRequestID: true, + LogReferer: true, + LogUserAgent: true, + LogStatus: true, + LogContentLength: true, + LogResponseSize: true, + LogHeaders: []string{"accept-encoding", "User-Agent"}, + LogQueryParams: []string{"lang", "checked"}, + LogFormValues: []string{"csrf", "multiple"}, + timeNow: func() time.Time { + if isFirstNowCall { + isFirstNowCall = false + return time.Unix(1631045377, 0) + } + return time.Unix(1631045377+10, 0) + }, + })(func(c *echo.Context) error { + c.Request().Header.Set(echo.HeaderXRequestID, "123") + c.FormValue("to force parse form") + return c.String(http.StatusTeapot, "OK") + }) + + f := make(url.Values) + f.Set("csrf", "token") + f.Set("multiple", "1") + f.Add("multiple", "2") + reader := strings.NewReader(f.Encode()) + req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", reader) + req.Header.Set("Referer", "https://echo.labstack.com/") + req.Header.Set("User-Agent", "curl/7.68.0") + req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size()))) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) + req.Header.Set(echo.HeaderXRealIP, "8.8.8.8") + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + c.SetPath("/test*") + + err := mw(c) + + assert.NoError(t, err) + assert.Equal(t, time.Unix(1631045377, 0), expect.StartTime) + assert.Equal(t, 10*time.Second, expect.Latency) + assert.Equal(t, "HTTP/1.1", expect.Protocol) + assert.Equal(t, "8.8.8.8", expect.RemoteIP) + assert.Equal(t, "example.com", expect.Host) + assert.Equal(t, http.MethodPost, expect.Method) + assert.Equal(t, "/test?lang=en&checked=1&checked=2", expect.URI) + assert.Equal(t, "/test", expect.URIPath) + assert.Equal(t, "/test*", expect.RoutePath) + assert.Equal(t, "123", expect.RequestID) + assert.Equal(t, "https://echo.labstack.com/", expect.Referer) + assert.Equal(t, "curl/7.68.0", expect.UserAgent) + assert.Equal(t, 418, expect.Status) + assert.Equal(t, nil, expect.Error) + assert.Equal(t, "32", expect.ContentLength) + assert.Equal(t, int64(2), expect.ResponseSize) + + assert.Len(t, expect.Headers, 1) + assert.Equal(t, []string{"curl/7.68.0"}, expect.Headers["User-Agent"]) + + assert.Len(t, expect.QueryParams, 2) + assert.Equal(t, []string{"en"}, expect.QueryParams["lang"]) + assert.Equal(t, []string{"1", "2"}, expect.QueryParams["checked"]) + + assert.Len(t, expect.FormValues, 2) + assert.Equal(t, []string{"token"}, expect.FormValues["csrf"]) + assert.Equal(t, []string{"1", "2"}, expect.FormValues["multiple"]) +} + +func TestTestRequestLogger(t *testing.T) { + var testCases = []struct { + name string + whenStatus int + whenError error + expectStatus string + expectError string + }{ + { + name: "ok", + whenStatus: http.StatusTeapot, + expectStatus: "418", + }, + { + name: "error", + whenError: echo.NewHTTPError(http.StatusBadGateway, "bad gw"), + expectStatus: "502", + expectError: `"error":"code=502, message=bad gw"`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + buf := new(bytes.Buffer) + e.Logger = slog.New(slog.NewJSONHandler(buf, nil)) + + e.Use(RequestLogger()) + e.POST("/test", func(c *echo.Context) error { + if tc.whenError != nil { + return tc.whenError + } + return c.String(tc.whenStatus, "OK") + }) + + f := make(url.Values) + f.Set("csrf", "token") + f.Set("multiple", "1") + f.Add("multiple", "2") + reader := strings.NewReader(f.Encode()) + req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", reader) + req.Header.Set("Referer", "https://echo.labstack.com/") + req.Header.Set("User-Agent", "curl/7.68.0") + req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size()))) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) + req.Header.Set(echo.HeaderXRealIP, "8.8.8.8") + req.Header.Set(echo.HeaderXRequestID, "MY_ID") + + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + rawlog := buf.Bytes() + if tc.expectError != "" { + assert.Contains(t, string(rawlog), `"level":"ERROR"`) + assert.Contains(t, string(rawlog), `"msg":"REQUEST_ERROR"`) + assert.Contains(t, string(rawlog), tc.expectError) + } else { + assert.Contains(t, string(rawlog), `"level":"INFO"`) + assert.Contains(t, string(rawlog), `"msg":"REQUEST"`) + } + assert.Contains(t, string(rawlog), `"status":`+tc.expectStatus) + assert.Contains(t, string(rawlog), `"method":"POST"`) + assert.Contains(t, string(rawlog), `"uri":"/test?lang=en&checked=1&checked=2"`) + assert.Contains(t, string(rawlog), `"latency":`) // this value varies + assert.Contains(t, string(rawlog), `"request_id":"MY_ID"`) + assert.Contains(t, string(rawlog), `"remote_ip":"8.8.8.8"`) + assert.Contains(t, string(rawlog), `"host":"example.com"`) + assert.Contains(t, string(rawlog), `"user_agent":"curl/7.68.0"`) + assert.Contains(t, string(rawlog), `"bytes_in":"32"`) + assert.Contains(t, string(rawlog), `"bytes_out":2`) + }) + } +} + +func BenchmarkRequestLogger_withoutMapFields(b *testing.B) { + e := echo.New() + + mw := RequestLoggerWithConfig(RequestLoggerConfig{ + Skipper: nil, + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + return nil + }, + LogLatency: true, + LogProtocol: true, + LogRemoteIP: true, + LogHost: true, + LogMethod: true, + LogURI: true, + LogURIPath: true, + LogRoutePath: true, + LogRequestID: true, + LogReferer: true, + LogUserAgent: true, + LogStatus: true, + LogContentLength: true, + LogResponseSize: true, + })(func(c *echo.Context) error { + c.Request().Header.Set(echo.HeaderXRequestID, "123") + return c.String(http.StatusTeapot, "OK") + }) + + req := httptest.NewRequest(http.MethodGet, "/test?lang=en", nil) + req.Header.Set("Referer", "https://echo.labstack.com/") + req.Header.Set("User-Agent", "curl/7.68.0") + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + mw(c) + } +} + +func BenchmarkRequestLogger_withMapFields(b *testing.B) { + e := echo.New() + + mw := RequestLoggerWithConfig(RequestLoggerConfig{ + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { + return nil + }, + LogLatency: true, + LogProtocol: true, + LogRemoteIP: true, + LogHost: true, + LogMethod: true, + LogURI: true, + LogURIPath: true, + LogRoutePath: true, + LogRequestID: true, + LogReferer: true, + LogUserAgent: true, + LogStatus: true, + LogContentLength: true, + LogResponseSize: true, + LogHeaders: []string{"accept-encoding", "User-Agent"}, + LogQueryParams: []string{"lang", "checked"}, + LogFormValues: []string{"csrf", "multiple"}, + })(func(c *echo.Context) error { + c.Request().Header.Set(echo.HeaderXRequestID, "123") + c.FormValue("to force parse form") + return c.String(http.StatusTeapot, "OK") + }) + + f := make(url.Values) + f.Set("csrf", "token") + f.Add("multiple", "1") + f.Add("multiple", "2") + req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", strings.NewReader(f.Encode())) + req.Header.Set("Referer", "https://echo.labstack.com/") + req.Header.Set("User-Agent", "curl/7.68.0") + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + mw(c) + } +} diff --git a/middleware/rewrite.go b/middleware/rewrite.go index a64e10bb3..ea58091b0 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -1,84 +1,80 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( + "errors" "regexp" - "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // RewriteConfig defines the config for Rewrite middleware. - RewriteConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Rules defines the URL path rewrite rules. The values captured in asterisk can be - // retrieved by index e.g. $1, $2 and so on. - // Example: - // "/old": "/new", - // "/api/*": "/$1", - // "/js/*": "/public/javascripts/$1", - // "/users/*/orders/*": "/user/$1/order/$2", - // Required. - Rules map[string]string `yaml:"rules"` +// RewriteConfig defines the config for Rewrite middleware. +type RewriteConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper - rulesRegex map[*regexp.Regexp]string - } -) + // Rules defines the URL path rewrite rules. The values captured in asterisk can be + // retrieved by index e.g. $1, $2 and so on. + // Example: + // "/old": "/new", + // "/api/*": "/$1", + // "/js/*": "/public/javascripts/$1", + // "/users/*/orders/*": "/user/$1/order/$2", + // Required. + Rules map[string]string -var ( - // DefaultRewriteConfig is the default Rewrite middleware config. - DefaultRewriteConfig = RewriteConfig{ - Skipper: DefaultSkipper, - } -) + // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRules map[*regexp.Regexp]string +} // Rewrite returns a Rewrite middleware. // // Rewrite middleware rewrites the URL path based on the provided rules. func Rewrite(rules map[string]string) echo.MiddlewareFunc { - c := DefaultRewriteConfig + c := RewriteConfig{} c.Rules = rules return RewriteWithConfig(c) } -// RewriteWithConfig returns a Rewrite middleware with config. -// See: `Rewrite()`. +// RewriteWithConfig returns a Rewrite middleware or panics on invalid configuration. +// +// Rewrite middleware rewrites the URL path based on the provided rules. func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { - // Defaults - if config.Rules == nil { - panic("echo: rewrite middleware requires url path rewrite rules") - } + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration +func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultBodyDumpConfig.Skipper + config.Skipper = DefaultSkipper + } + if config.Rules == nil && config.RegexRules == nil { + return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules") } - config.rulesRegex = map[*regexp.Regexp]string{} - // Initialize - for k, v := range config.Rules { - k = strings.Replace(k, "*", "(.*)", -1) - k = k + "$" - config.rulesRegex[regexp.MustCompile(k)] = v + if config.RegexRules == nil { + config.RegexRules = make(map[*regexp.Regexp]string) + } + for k, v := range rewriteRulesRegex(config.Rules) { + config.RegexRules[k] = v } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c *echo.Context) (err error) { if config.Skipper(c) { return next(c) } - req := c.Request() - - // Rewrite - for k, v := range config.rulesRegex { - replacer := captureTokens(k, req.URL.Path) - if replacer != nil { - req.URL.Path = replacer.Replace(v) - break - } + if err := rewriteURL(config.RegexRules, c.Request()); err != nil { + return err } return next(c) } - } + }, nil } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index eb5a46d89..f45b8d98a 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -1,17 +1,23 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( - "io/ioutil" + "io" "net/http" "net/http/httptest" + "net/url" + "regexp" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) -func TestRewrite(t *testing.T) { +func TestRewriteAfterRouting(t *testing.T) { e := echo.New() + // middlewares added with `Use()` are executed after routing is done and do not affect which route handler is matched e.Use(RewriteWithConfig(RewriteConfig{ Rules: map[string]string{ "/old": "/new", @@ -20,54 +26,156 @@ func TestRewrite(t *testing.T) { "/users/*/orders/*": "/user/$1/order/$2", }, })) - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - req.URL.Path = "/api/users" - e.ServeHTTP(rec, req) - assert.Equal(t, "/users", req.URL.Path) - req.URL.Path = "/js/main.js" - e.ServeHTTP(rec, req) - assert.Equal(t, "/public/javascripts/main.js", req.URL.Path) - req.URL.Path = "/old" - e.ServeHTTP(rec, req) - assert.Equal(t, "/new", req.URL.Path) - req.URL.Path = "/users/jack/orders/1" - e.ServeHTTP(rec, req) - assert.Equal(t, "/user/jack/order/1", req.URL.Path) - req.URL.Path = "/api/new users" - e.ServeHTTP(rec, req) - assert.Equal(t, "/new users", req.URL.Path) + e.GET("/public/*", func(c *echo.Context) error { + return c.String(http.StatusOK, c.Param("*")) + }) + e.GET("/*", func(c *echo.Context) error { + return c.String(http.StatusOK, c.Param("*")) + }) + + var testCases = []struct { + whenPath string + expectRoutePath string + expectRequestPath string + expectRequestRawPath string + }{ + { + whenPath: "/api/users", + expectRoutePath: "api/users", + expectRequestPath: "/users", + expectRequestRawPath: "", + }, + { + whenPath: "/js/main.js", + expectRoutePath: "js/main.js", + expectRequestPath: "/public/javascripts/main.js", + expectRequestRawPath: "", + }, + { + whenPath: "/users/jack/orders/1", + expectRoutePath: "users/jack/orders/1", + expectRequestPath: "/user/jack/order/1", + expectRequestRawPath: "", + }, + { // no rewrite rule matched. already encoded URL should not be double encoded or changed in any way + whenPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + expectRoutePath: "user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + expectRequestPath: "/user/jill/order/T/cO4lW/t/Vp/", // this is equal to `url.Parse(tc.whenPath)` result + expectRequestRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + }, + { // just rewrite but do not touch encoding. already encoded URL should not be double encoded + whenPath: "/users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F", + expectRoutePath: "users/jill/orders/T%2FcO4lW%2Ft%2FVp%2F", + expectRequestPath: "/user/jill/order/T/cO4lW/t/Vp/", // this is equal to `url.Parse(tc.whenPath)` result + expectRequestRawPath: "/user/jill/order/T%2FcO4lW%2Ft%2FVp%2F", + }, + { // ` ` (space) is encoded by httpClient to `%20` when doing request to Echo. `%20` should not be double escaped or changed in any way when rewriting request + whenPath: "/api/new users", + expectRoutePath: "api/new users", + expectRequestPath: "/new users", + expectRequestRawPath: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.whenPath, func(t *testing.T) { + target, _ := url.Parse(tc.whenPath) + req := httptest.NewRequest(http.MethodGet, target.String(), nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, tc.expectRoutePath, rec.Body.String()) + assert.Equal(t, tc.expectRequestPath, req.URL.Path) + assert.Equal(t, tc.expectRequestRawPath, req.URL.RawPath) + }) + } +} + +func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) { + assert.Panics(t, func() { + RewriteWithConfig(RewriteConfig{}) + }) +} + +func TestMustRewriteWithConfig_skipper(t *testing.T) { + var testCases = []struct { + name string + givenSkipper func(c *echo.Context) bool + whenURL string + expectURL string + expectStatus int + }{ + { + name: "not skipped", + whenURL: "/old", + expectURL: "/new", + expectStatus: http.StatusOK, + }, + { + name: "skipped", + givenSkipper: func(c *echo.Context) bool { + return true + }, + whenURL: "/old", + expectURL: "/old", + expectStatus: http.StatusNotFound, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + e.Pre(RewriteWithConfig( + RewriteConfig{ + Skipper: tc.givenSkipper, + Rules: map[string]string{"/old": "/new"}}, + )) + + e.GET("/new", func(c *echo.Context) error { + return c.NoContent(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectURL, req.URL.EscapedPath()) + assert.Equal(t, tc.expectStatus, rec.Code) + }) + } } // Issue #1086 func TestEchoRewritePreMiddleware(t *testing.T) { e := echo.New() - r := e.Router() // Rewrite old url to new one + // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches e.Pre(RewriteWithConfig(RewriteConfig{ - Rules: map[string]string{ - "/old": "/new", - }, - })) + Rules: map[string]string{"/old": "/new"}}), + ) // Route - r.Add(http.MethodGet, "/new", func(c echo.Context) error { - return c.NoContent(200) + e.Add(http.MethodGet, "/new", func(c *echo.Context) error { + return c.NoContent(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/old", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/new", req.URL.Path) - assert.Equal(t, 200, rec.Code) + assert.Equal(t, "/new", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) } // Issue #1143 func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { e := echo.New() - r := e.Router() + // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches e.Pre(RewriteWithConfig(RewriteConfig{ Rules: map[string]string{ "/api/*/mgmt/proj/*/agt": "/api/$1/hosts/$2", @@ -75,22 +183,135 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { }, })) - r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { - return c.String(200, "hosts") + e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c *echo.Context) error { + return c.String(http.StatusOK, "hosts") }) - r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { - return c.String(200, "eng") + e.Add(http.MethodGet, "/api/:version/eng", func(c *echo.Context) error { + return c.String(http.StatusOK, "eng") }) for i := 0; i < 100; i++ { req := httptest.NewRequest(http.MethodGet, "/api/v1/mgmt/proj/test/agt", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) - assert.Equal(t, "/api/v1/hosts/test", req.URL.Path) - assert.Equal(t, 200, rec.Code) + assert.Equal(t, "/api/v1/hosts/test", req.URL.EscapedPath()) + assert.Equal(t, http.StatusOK, rec.Code) defer rec.Result().Body.Close() - bodyBytes, _ := ioutil.ReadAll(rec.Result().Body) + bodyBytes, _ := io.ReadAll(rec.Result().Body) assert.Equal(t, "hosts", string(bodyBytes)) } } + +// Issue #1573 +func TestEchoRewriteWithCaret(t *testing.T) { + e := echo.New() + + e.Pre(RewriteWithConfig(RewriteConfig{ + Rules: map[string]string{ + "^/abc/*": "/v1/abc/$1", + }, + })) + + rec := httptest.NewRecorder() + + var req *http.Request + + req = httptest.NewRequest(http.MethodGet, "/abc/test", nil) + e.ServeHTTP(rec, req) + assert.Equal(t, "/v1/abc/test", req.URL.Path) + + req = httptest.NewRequest(http.MethodGet, "/v1/abc/test", nil) + e.ServeHTTP(rec, req) + assert.Equal(t, "/v1/abc/test", req.URL.Path) + + req = httptest.NewRequest(http.MethodGet, "/v2/abc/test", nil) + e.ServeHTTP(rec, req) + assert.Equal(t, "/v2/abc/test", req.URL.Path) +} + +// Verify regex used with rewrite +func TestEchoRewriteWithRegexRules(t *testing.T) { + e := echo.New() + + e.Pre(RewriteWithConfig(RewriteConfig{ + Rules: map[string]string{ + "^/a/*": "/v1/$1", + "^/b/*/c/*": "/v2/$2/$1", + "^/c/*/*": "/v3/$2", + }, + RegexRules: map[*regexp.Regexp]string{ + regexp.MustCompile("^/x/.+?/(.*)"): "/v4/$1", + regexp.MustCompile("^/y/(.+?)/(.*)"): "/v5/$2/$1", + }, + })) + + var rec *httptest.ResponseRecorder + var req *http.Request + + testCases := []struct { + requestPath string + expectPath string + }{ + {"/unmatched", "/unmatched"}, + {"/a/test", "/v1/test"}, + {"/b/foo/c/bar/baz", "/v2/bar/baz/foo"}, + {"/c/ignore/test", "/v3/test"}, + {"/c/ignore1/test/this", "/v3/test/this"}, + {"/x/ignore/test", "/v4/test"}, + {"/y/foo/bar", "/v5/bar/foo"}, + } + + for _, tc := range testCases { + t.Run(tc.requestPath, func(t *testing.T) { + req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil) + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expectPath, req.URL.EscapedPath()) + }) + } +} + +// Ensure correct escaping as defined in replacement (issue #1798) +func TestEchoRewriteReplacementEscaping(t *testing.T) { + e := echo.New() + + // NOTE: these are incorrect regexps as they do not factor in that URI we are replacing could contain ? (query) and # (fragment) parts + // so in reality they append query and fragment part as `$1` matches everything after that prefix + e.Pre(RewriteWithConfig(RewriteConfig{ + Rules: map[string]string{ + "^/a/*": "/$1?query=param", + "^/b/*": "/$1;part#one", + }, + RegexRules: map[*regexp.Regexp]string{ + regexp.MustCompile("^/x/(.*)"): "/$1?query=param", + regexp.MustCompile("^/y/(.*)"): "/$1;part#one", + regexp.MustCompile("^/z/(.*)"): "/$1?test=1#escaped%20test", + }, + })) + + var rec *httptest.ResponseRecorder + var req *http.Request + + testCases := []struct { + requestPath string + expect string + }{ + {"/unmatched", "/unmatched"}, + {"/a/test", "/test?query=param"}, + {"/b/foo/bar", "/foo/bar;part#one"}, + {"/x/test", "/test?query=param"}, + {"/y/foo/bar", "/foo/bar;part#one"}, + {"/z/foo/b%20ar", "/foo/b%20ar?test=1#escaped%20test"}, + {"/z/foo/b%20ar?nope=1#yes", "/foo/b%20ar?nope=1#yes?test=1%23escaped%20test"}, // example of appending + } + + for _, tc := range testCases { + t.Run(tc.requestPath, func(t *testing.T) { + req = httptest.NewRequest(http.MethodGet, tc.requestPath, nil) + rec = httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expect, req.URL.String()) + }) + } +} diff --git a/middleware/secure.go b/middleware/secure.go index 6c4051723..bd389f7ae 100644 --- a/middleware/secure.go +++ b/middleware/secure.go @@ -1,89 +1,88 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + package middleware import ( "fmt" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -type ( - // SecureConfig defines the config for Secure middleware. - SecureConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // XSSProtection provides protection against cross-site scripting attack (XSS) - // by setting the `X-XSS-Protection` header. - // Optional. Default value "1; mode=block". - XSSProtection string `yaml:"xss_protection"` - - // ContentTypeNosniff provides protection against overriding Content-Type - // header by setting the `X-Content-Type-Options` header. - // Optional. Default value "nosniff". - ContentTypeNosniff string `yaml:"content_type_nosniff"` - - // XFrameOptions can be used to indicate whether or not a browser should - // be allowed to render a page in a ,