Skip to Content
Steel is in alpha 🎉
Middleware › Pure Middleware

Pure Middleware

Middleware in SteelRouter provides a powerful way to add cross-cutting functionality to your application. From authentication and logging to rate limiting and CORS, middleware helps keep your handlers clean and focused.

Middleware Basics

Middleware functions wrap HTTP handlers to add functionality before or after request processing:

// MiddlewareFunc takes an http.Handler and returns an http.Handler that
// wraps it, which lets the middleware run logic before and/or after the
// wrapped handler processes the request.
type MiddlewareFunc func(http.Handler) http.Handler

Simple Middleware Example

func loggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() // Call the next handler next.ServeHTTP(w, r) // Log after processing duration := time.Since(start) log.Printf("%s %s %v", r.Method, r.URL.Path, duration) }) } // Use the middleware r := router.NewRouter() r.Use(loggingMiddleware)

Built-in Middleware

SteelRouter includes several built-in middleware functions:

Logger

Logs requests with method, path, and duration:

// Register the built-in request logger on the router.
r.Use(router.Logger)

Output:

GET /users/123 1.2ms POST /users 5.4ms

Recoverer

Recovers from panics and returns a 500 Internal Server Error:

r.Use(router.Recoverer)

// This handler always panics. With Recoverer registered, the panic is
// recovered and the client receives a 500 Internal Server Error.
panicHandler := func(w http.ResponseWriter, r *http.Request) {
	panic("Something went wrong!")
}
r.GET("/panic", panicHandler)

Timeout

Sets a timeout for request processing:

// 30-second timeout r.Use(router.Timeout(30 * time.Second)) // Slow handler that might timeout r.GET("/slow", func(w http.ResponseWriter, r *http.Request) { select { case <-time.After(45 * time.Second): w.Write([]byte("Done")) case <-r.Context().Done(): // Request timed out or was cancelled return } })

Custom Middleware

Create your own middleware for specific needs:

Authentication Middleware

func authMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Get token from header token := r.Header.Get("Authorization") if token == "" { http.Error(w, "Authorization header required", http.StatusUnauthorized) return } // Remove "Bearer " prefix if strings.HasPrefix(token, "Bearer ") { token = token[7:] } // Validate token userID, err := validateJWTToken(token) if err != nil { http.Error(w, "Invalid token", http.StatusUnauthorized) return } // Add user ID to context ctx := context.WithValue(r.Context(), "user_id", userID) next.ServeHTTP(w, r.WithContext(ctx)) }) } // Usage r.Route("/protected", func(protected router.Router) { protected.Use(authMiddleware) protected.GET("/profile", getProfileHandler) protected.POST("/data", createDataHandler) })

CORS Middleware

func corsMiddleware(allowedOrigins ...string) router.MiddlewareFunc { if len(allowedOrigins) == 0 { allowedOrigins = []string{"*"} } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") // Check if origin is allowed allowed := false for _, allowedOrigin := range allowedOrigins { if allowedOrigin == "*" || allowedOrigin == origin { allowed = true break } } if allowed { w.Header().Set("Access-Control-Allow-Origin", origin) } w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With") w.Header().Set("Access-Control-Allow-Credentials", "true") w.Header().Set("Access-Control-Max-Age", "86400") // Handle preflight requests if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } next.ServeHTTP(w, r) }) } } // Usage r.Use(corsMiddleware("https://myapp.com", "https://app.myapp.com"))

Rate Limiting Middleware

import "golang.org/x/time/rate" type RateLimiter struct { limiters map[string]*rate.Limiter mu sync.RWMutex rate rate.Limit burst int } func NewRateLimiter(r rate.Limit, burst int) *RateLimiter { return &RateLimiter{ limiters: make(map[string]*rate.Limiter), rate: r, burst: burst, } } func (rl *RateLimiter) getLimiter(ip string) *rate.Limiter { rl.mu.RLock() limiter, exists := rl.limiters[ip] rl.mu.RUnlock() if !exists { rl.mu.Lock() limiter = rate.NewLimiter(rl.rate, rl.burst) rl.limiters[ip] = limiter rl.mu.Unlock() } return limiter } func (rl *RateLimiter) Middleware() router.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Get client IP ip := r.Header.Get("X-Forwarded-For") if ip == "" { ip = r.Header.Get("X-Real-IP") } if ip == "" { ip = r.RemoteAddr } limiter := rl.getLimiter(ip) if !limiter.Allow() { http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } } // Usage limiter := NewRateLimiter(rate.Every(time.Second), 10) // 10 requests per second r.Use(limiter.Middleware())

Request ID Middleware

func requestIDMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Check if request ID already exists requestID := r.Header.Get("X-Request-ID") if requestID == "" { requestID = generateRequestID() } // Add to context ctx := context.WithValue(r.Context(), "request_id", requestID) // Add to response header w.Header().Set("X-Request-ID", requestID) next.ServeHTTP(w, r.WithContext(ctx)) }) } func generateRequestID() string { return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8)) }

Middleware Chains

Apply multiple middleware in order:

r := router.NewRouter()

// Global middleware, applied to all routes.
r.Use(requestIDMiddleware)
r.Use(router.Logger)
r.Use(router.Recoverer)
r.Use(corsMiddleware())

// Routes under /api/protected additionally require authentication.
protectedRoutes := func(protected router.Router) {
	protected.Use(authMiddleware)
	protected.GET("/profile", profileHandler)
}

// Routes under /api get API-specific middleware on top of the global set.
apiRoutes := func(api router.Router) {
	api.Use(rateLimitMiddleware)
	api.Use(jsonContentTypeMiddleware)

	// Public endpoints
	api.GET("/health", healthHandler)

	// Protected endpoints
	api.Route("/protected", protectedRoutes)
}

r.Route("/api", apiRoutes)

Middleware Execution Order

Middleware executes in FIFO (First In, First Out) order on the request path — the first middleware registered runs first — and in LIFO (Last In, First Out) order on the response path:

func middleware1(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Println("Middleware 1: Before") next.ServeHTTP(w, r) log.Println("Middleware 1: After") }) } func middleware2(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Println("Middleware 2: Before") next.ServeHTTP(w, r) log.Println("Middleware 2: After") }) } r.Use(middleware1) r.Use(middleware2) // Output: // Middleware 1: Before // Middleware 2: Before // Handler executes // Middleware 2: After // Middleware 1: After

Advanced Middleware Patterns

Conditional Middleware

Apply middleware based on conditions:

func conditionalAuth(condition func(*http.Request) bool) router.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if condition(r) { authMiddleware(next).ServeHTTP(w, r) } else { next.ServeHTTP(w, r) } }) } } // Usage: Only require auth for non-GET requests r.Use(conditionalAuth(func(r *http.Request) bool { return r.Method != "GET" }))

Middleware with Configuration

type SecurityConfig struct { CSPPolicy string HSTSMaxAge int FrameOptions string ContentTypeNoSniff bool } func securityHeaders(config SecurityConfig) router.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if config.CSPPolicy != "" { w.Header().Set("Content-Security-Policy", config.CSPPolicy) } if config.HSTSMaxAge > 0 { w.Header().Set("Strict-Transport-Security", fmt.Sprintf("max-age=%d; includeSubDomains", config.HSTSMaxAge)) } if config.FrameOptions != "" { w.Header().Set("X-Frame-Options", config.FrameOptions) } if config.ContentTypeNoSniff { w.Header().Set("X-Content-Type-Options", "nosniff") } next.ServeHTTP(w, r) }) } } // Usage r.Use(securityHeaders(SecurityConfig{ CSPPolicy: "default-src 'self'", HSTSMaxAge: 31536000, // 1 year FrameOptions: "DENY", ContentTypeNoSniff: true, }))

Response Writer Middleware

Middleware that modifies responses:

// responseWriter decorates an http.ResponseWriter so that the status code
// and a copy of the body written by a handler can be inspected afterwards.
type responseWriter struct {
	http.ResponseWriter
	statusCode int           // last status passed to WriteHeader
	body       *bytes.Buffer // copy of every byte passed to Write
}

// WriteHeader records code and forwards it to the wrapped writer.
func (rw *responseWriter) WriteHeader(code int) {
	rw.statusCode = code
	rw.ResponseWriter.WriteHeader(code)
}

// Write appends b to the captured body and forwards it to the wrapped
// writer, returning the wrapped writer's result.
func (rw *responseWriter) Write(b []byte) (int, error) {
	_, _ = rw.body.Write(b) // bytes.Buffer.Write always returns a nil error
	return rw.ResponseWriter.Write(b)
}

// responseLoggingMiddleware logs method, path, status code, body size and
// duration for every request after the wrapped handler has returned.
func responseLoggingMiddleware(next http.Handler) http.Handler {
	logged := func(w http.ResponseWriter, r *http.Request) {
		// The status starts at 200 because a handler that never calls
		// WriteHeader explicitly responds with 200.
		capture := &responseWriter{
			ResponseWriter: w,
			statusCode:     http.StatusOK,
			body:           new(bytes.Buffer),
		}

		began := time.Now()
		next.ServeHTTP(capture, r)
		elapsed := time.Since(began)

		log.Printf("%s %s %d %d bytes %v",
			r.Method,
			r.URL.Path,
			capture.statusCode,
			capture.body.Len(),
			elapsed,
		)
	}
	return http.HandlerFunc(logged)
}

Error Handling in Middleware

Handle errors gracefully in middleware:

// errorHandlingMiddleware recovers from panics raised by the wrapped handler,
// logs them, and answers with a 500 status and a JSON error body.
//
// A panic with http.ErrAbortHandler is re-raised untouched: net/http uses that
// value as the documented way to abort a response, so it must reach the
// server instead of being turned into a 500.
//
// If the handler already started writing the response before it panicked,
// the status line has been sent and the 500 written here cannot replace it.
func errorHandlingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			recovered := recover()
			if recovered == nil {
				return
			}
			if recovered == http.ErrAbortHandler {
				panic(recovered)
			}

			log.Printf("Panic in request %s %s: %v", r.Method, r.URL.Path, recovered)

			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusInternalServerError)

			response := map[string]interface{}{
				"error": "Internal server error",
				"code":  "INTERNAL_ERROR",
			}
			// The client may already be gone; there is nothing more to send,
			// but the failure should not disappear silently.
			if err := json.NewEncoder(w).Encode(response); err != nil {
				log.Printf("Failed to write panic response for %s %s: %v", r.Method, r.URL.Path, err)
			}
		}()

		next.ServeHTTP(w, r)
	})
}

Database Transaction Middleware

func transactionMiddleware(db *sql.DB) router.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Only use transactions for modifying requests if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" { next.ServeHTTP(w, r) return } tx, err := db.Begin() if err != nil { http.Error(w, "Database error", http.StatusInternalServerError) return } // Add transaction to context ctx := context.WithValue(r.Context(), "tx", tx) // Use custom response writer to capture status rw := &responseWriter{ ResponseWriter: w, statusCode: http.StatusOK, body: &bytes.Buffer{}, } next.ServeHTTP(rw, r.WithContext(ctx)) // Commit or rollback based on response status if rw.statusCode >= 400 { tx.Rollback() } else { if err := tx.Commit(); err != nil { log.Printf("Failed to commit transaction: %v", err) tx.Rollback() } } }) } }

Context Values in Middleware

Pass data between middleware and handlers using context:

type contextKey string const ( userIDKey contextKey = "user_id" requestIDKey contextKey = "request_id" startTimeKey contextKey = "start_time" ) // Set context values func setUserContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { userID := extractUserID(r) ctx := context.WithValue(r.Context(), userIDKey, userID) next.ServeHTTP(w, r.WithContext(ctx)) }) } // Get context values in handlers func protectedHandler(w http.ResponseWriter, r *http.Request) { userID, ok := r.Context().Value(userIDKey).(int) if !ok { http.Error(w, "User not found in context", http.StatusInternalServerError) return } fmt.Fprintf(w, "User ID: %d", userID) } // Helper function to get user ID from context func GetUserIDFromContext(ctx context.Context) (int, bool) { userID, ok := ctx.Value(userIDKey).(int) return userID, ok }

Third-Party Middleware Integration

SteelRouter works with any middleware that follows the standard func(http.Handler) http.Handler pattern:

Gorilla Handlers

import "github.com/gorilla/handlers" r.Use(handlers.LoggingHandler(os.Stdout, http.DefaultServeMux)) r.Use(handlers.CompressHandler)

Custom Integration

// Adapt third-party middleware if needed func adaptMiddleware(middleware func(http.Handler) http.Handler) router.MiddlewareFunc { return func(next http.Handler) http.Handler { return middleware(next) } } // Usage r.Use(adaptMiddleware(someThirdPartyMiddleware))

Performance Considerations

Performance Tip: Order middleware by frequency of early termination. Place authentication and rate limiting before expensive operations like logging.

Efficient Middleware Order

// âś… Good: Fast-failing middleware first r.Use(rateLimitMiddleware) // Quick check, may terminate early r.Use(authMiddleware) // Authentication check r.Use(corsMiddleware()) // CORS headers r.Use(requestIDMiddleware) // Add request ID r.Use(router.Logger) // Logging (always executes) r.Use(router.Recoverer) // Panic recovery (always wraps)

Avoid Heavy Operations

// ❌ Avoid: Heavy database operations in middleware func heavyMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Don't do this - heavy operation on every request user := fetchUserFromDatabase(r.Header.Get("User-ID")) next.ServeHTTP(w, r) }) } // ✅ Better: Cache or lazy load func efficientMiddleware(cache *Cache) router.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { userID := r.Header.Get("User-ID") // Check cache first if user, found := cache.Get(userID); found { ctx := context.WithValue(r.Context(), "user", user) next.ServeHTTP(w, r.WithContext(ctx)) return } // Only fetch if not in cache user := fetchUserFromDatabase(userID) cache.Set(userID, user) ctx := context.WithValue(r.Context(), "user", user) next.ServeHTTP(w, r.WithContext(ctx)) }) } }

Testing Middleware

Test middleware independently from handlers:

func TestAuthMiddleware(t *testing.T) { testCases := []struct { name string authHeader string expectedStatus int }{ {"Valid token", "Bearer valid-token", http.StatusOK}, {"Invalid token", "Bearer invalid-token", http.StatusUnauthorized}, {"No token", "", http.StatusUnauthorized}, {"Malformed token", "InvalidFormat", http.StatusUnauthorized}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Create test handler testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("OK")) }) // Wrap with middleware handler := authMiddleware(testHandler) // Create test request req := httptest.NewRequest("GET", "/test", nil) if tc.authHeader != "" { req.Header.Set("Authorization", tc.authHeader) } // Execute w := httptest.NewRecorder() handler.ServeHTTP(w, req) // Assert if w.Code != tc.expectedStatus { t.Errorf("Expected status %d, got %d", tc.expectedStatus, w.Code) } }) } }

Middleware in SteelRouter provides a clean, composable way to add functionality to your applications while keeping your code organized and testable.

Last updated on