Pure Middleware
Middleware in SteelRouter provides a powerful way to add cross-cutting functionality to your application. From authentication and logging to rate limiting and CORS, middleware helps keep your handlers clean and focused.
Middleware Basics
Middleware functions wrap HTTP handlers to add functionality before or after request processing:
type MiddlewareFunc func(http.Handler) http.Handler
Simple Middleware Example
// loggingMiddleware wraps next so that, once next has finished serving a
// request, the request method, URL path and elapsed time are written to the
// standard logger.
func loggingMiddleware(next http.Handler) http.Handler {
	timed := func(w http.ResponseWriter, r *http.Request) {
		begin := time.Now()

		// Hand the request to the wrapped handler first; the log line is
		// emitted only after it returns.
		next.ServeHTTP(w, r)

		log.Printf("%s %s %v", r.Method, r.URL.Path, time.Since(begin))
	}
	return http.HandlerFunc(timed)
}
// Use the middleware
r := router.NewRouter()
r.Use(loggingMiddleware)
Built-in Middleware
SteelRouter includes several built-in middleware functions:
Logger
Logs requests with method, path, and duration:
r.Use(router.Logger)
Output:
GET /users/123 1.2ms
POST /users 5.4ms
Recoverer
Recovers from panics and returns a 500 Internal Server Error:
r.Use(router.Recoverer)
// Without recoverer, this would crash the server
r.GET("/panic", func(w http.ResponseWriter, r *http.Request) {
panic("Something went wrong!")
})
Timeout
Sets a timeout for request processing:
// 30-second timeout
r.Use(router.Timeout(30 * time.Second))
// Slow handler that might timeout
r.GET("/slow", func(w http.ResponseWriter, r *http.Request) {
select {
case <-time.After(45 * time.Second):
w.Write([]byte("Done"))
case <-r.Context().Done():
// Request timed out or was cancelled
return
}
})
Custom Middleware
Create your own middleware for specific needs:
Authentication Middleware
// authMiddleware only lets a request through to next when its Authorization
// header carries a token that validateJWTToken accepts. On success the user ID
// returned by validateJWTToken is stored in the request context under the
// string key "user_id"; otherwise the request ends with 401 Unauthorized.
func authMiddleware(next http.Handler) http.Handler {
	guarded := func(w http.ResponseWriter, r *http.Request) {
		header := r.Header.Get("Authorization")
		if header == "" {
			http.Error(w, "Authorization header required", http.StatusUnauthorized)
			return
		}

		// A leading "Bearer " is stripped when present; any other value is
		// handed to the validator unchanged.
		userID, err := validateJWTToken(strings.TrimPrefix(header, "Bearer "))
		if err != nil {
			http.Error(w, "Invalid token", http.StatusUnauthorized)
			return
		}

		// NOTE(review): the key is a plain string, so readers must look the
		// value up with the string "user_id"; a lookup with a typed key such
		// as contextKey("user_id") will not find it.
		authed := r.WithContext(context.WithValue(r.Context(), "user_id", userID))
		next.ServeHTTP(w, authed)
	}
	return http.HandlerFunc(guarded)
}
// Usage
r.Route("/protected", func(protected router.Router) {
protected.Use(authMiddleware)
protected.GET("/profile", getProfileHandler)
protected.POST("/data", createDataHandler)
})
CORS Middleware
// corsMiddleware returns middleware that adds CORS response headers and
// answers preflight (OPTIONS) requests itself with 200 OK.
//
// allowedOrigins lists the exact origins that may make credentialed
// cross-origin requests. Passing no origins, or including "*", allows any
// origin, but only without credentials: in that case the literal "*" is sent
// and Access-Control-Allow-Credentials is omitted, because echoing back an
// arbitrary Origin together with credentials would let any website act with
// the user's cookies.
func corsMiddleware(allowedOrigins ...string) router.MiddlewareFunc {
	// Resolve the configuration once instead of scanning the list per request.
	allowAny := len(allowedOrigins) == 0
	exact := make(map[string]struct{}, len(allowedOrigins))
	for _, o := range allowedOrigins {
		if o == "*" {
			allowAny = true
			continue
		}
		exact[o] = struct{}{}
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()
			origin := r.Header.Get("Origin")

			if len(exact) > 0 {
				// The response differs per Origin, so caches must key on it.
				h.Add("Vary", "Origin")
			}

			_, listed := exact[origin]
			switch {
			case origin == "":
				// Not a cross-origin request; no allow-origin header needed.
			case listed:
				h.Set("Access-Control-Allow-Origin", origin)
				h.Set("Access-Control-Allow-Credentials", "true")
			case allowAny:
				h.Set("Access-Control-Allow-Origin", "*")
			}

			h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
			h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
			h.Set("Access-Control-Max-Age", "86400")

			// Handle preflight requests
			if r.Method == http.MethodOptions {
				w.WriteHeader(http.StatusOK)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
// Usage
r.Use(corsMiddleware("https://myapp.com", "https://app.myapp.com"))
Rate Limiting Middleware
import (
	"net"

	"golang.org/x/time/rate"
)
// RateLimiter keeps one rate.Limiter per string key (getLimiter is called
// with a client IP) and creates limiters on demand with a shared rate and
// burst.
type RateLimiter struct {
	limiters map[string]*rate.Limiter // limiter per key, filled in by getLimiter
	mu       sync.RWMutex             // guards limiters
	rate     rate.Limit               // rate given to every new limiter
	burst    int                      // burst given to every new limiter
}

// NewRateLimiter returns a RateLimiter with an empty limiter table whose
// limiters will be created with rate r and the given burst.
func NewRateLimiter(r rate.Limit, burst int) *RateLimiter {
	rl := new(RateLimiter)
	rl.limiters = map[string]*rate.Limiter{}
	rl.rate = r
	rl.burst = burst
	return rl
}
// getLimiter returns the limiter registered for ip, creating and storing one
// with the RateLimiter's rate and burst the first time the key is seen.
//
// NOTE(review): entries are never removed, so the table grows with the number
// of distinct keys; consider evicting idle entries if keys are unbounded.
func (rl *RateLimiter) getLimiter(ip string) *rate.Limiter {
	// Fast path: most calls find an existing limiter under the read lock.
	rl.mu.RLock()
	limiter, exists := rl.limiters[ip]
	rl.mu.RUnlock()
	if exists {
		return limiter
	}

	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Look again under the write lock: another goroutine may have inserted a
	// limiter for ip between RUnlock and Lock. Without this check that limiter
	// would be overwritten and the two callers would count against different
	// limiters.
	if limiter, exists = rl.limiters[ip]; exists {
		return limiter
	}

	limiter = rate.NewLimiter(rl.rate, rl.burst)
	rl.limiters[ip] = limiter
	return limiter
}
// Middleware returns middleware that rate-limits requests per client address.
// Requests over the limit are answered with 429 Too Many Requests and never
// reach next.
func (rl *RateLimiter) Middleware() router.MiddlewareFunc {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !rl.getLimiter(clientIP(r)).Allow() {
				http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}

// clientIP derives the rate-limiting key for r: the first address in
// X-Forwarded-For, else X-Real-IP, else the host part of r.RemoteAddr.
//
// NOTE(review): both headers are set by the client unless a trusted proxy
// overwrites them, so a caller can evade the limit by varying them. Only rely
// on them when the service sits behind such a proxy.
func clientIP(r *http.Request) string {
	// X-Forwarded-For may hold "client, proxy1, proxy2"; using the whole
	// string as the key would give one client a new limiter per proxy path.
	if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
		if first := strings.TrimSpace(strings.SplitN(fwd, ",", 2)[0]); first != "" {
			return first
		}
	}
	if real := strings.TrimSpace(r.Header.Get("X-Real-IP")); real != "" {
		return real
	}

	// RemoteAddr is "IP:port"; keeping the port would give every new
	// connection from the same client its own limiter.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}
// Usage
limiter := NewRateLimiter(rate.Every(time.Second), 10) // refills 1 request per second, with bursts of up to 10
r.Use(limiter.Middleware())
Request ID Middleware
// requestIDMiddleware makes sure every request has an ID: it reuses the
// incoming X-Request-ID header when present and otherwise calls
// generateRequestID. The ID is echoed in the X-Request-ID response header and
// stored in the request context under the string key "request_id".
func requestIDMiddleware(next http.Handler) http.Handler {
	tagged := func(w http.ResponseWriter, r *http.Request) {
		id := r.Header.Get("X-Request-ID")
		if id == "" {
			id = generateRequestID()
		}

		// NOTE(review): plain string context key; readers must use the same
		// string "request_id" to retrieve the value.
		ctx := context.WithValue(r.Context(), "request_id", id)

		// Set the response header before next runs so it is sent with the
		// response headers.
		w.Header().Set("X-Request-ID", id)

		next.ServeHTTP(w, r.WithContext(ctx))
	}
	return http.HandlerFunc(tagged)
}
func generateRequestID() string {
return fmt.Sprintf("%d-%s", time.Now().UnixNano(), randomString(8))
}Middleware Chains
Apply multiple middleware in order:
r := router.NewRouter()
// Global middleware (applied to all routes)
r.Use(requestIDMiddleware)
r.Use(router.Logger)
r.Use(router.Recoverer)
r.Use(corsMiddleware())
// Route-specific middleware
r.Route("/api", func(api router.Router) {
// API-specific middleware
api.Use(rateLimitMiddleware)
api.Use(jsonContentTypeMiddleware)
// Public endpoints
api.GET("/health", healthHandler)
// Protected endpoints
api.Route("/protected", func(protected router.Router) {
protected.Use(authMiddleware)
protected.GET("/profile", profileHandler)
})
})
Middleware Execution Order
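Middleware executes in registration order (FIFO, First In, First Out) on the request path and in reverse registration order (LIFO, Last In, First Out) on the response path. The first middleware passed to Use is the outermost wrapper, so it sees the request first and the response last: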
// middleware1 logs a marker line before and after delegating to next, making
// its position in the middleware chain visible in the log output.
func middleware1(next http.Handler) http.Handler {
	traced := func(w http.ResponseWriter, r *http.Request) {
		log.Println("Middleware 1: Before")
		next.ServeHTTP(w, r)
		// Reached only once next has returned.
		log.Println("Middleware 1: After")
	}
	return http.HandlerFunc(traced)
}
// middleware2 logs a marker line before and after delegating to next, making
// its position in the middleware chain visible in the log output.
func middleware2(next http.Handler) http.Handler {
	traced := func(w http.ResponseWriter, r *http.Request) {
		log.Println("Middleware 2: Before")
		next.ServeHTTP(w, r)
		// Reached only once next has returned.
		log.Println("Middleware 2: After")
	}
	return http.HandlerFunc(traced)
}
r.Use(middleware1)
r.Use(middleware2)
// Output:
// Middleware 1: Before
// Middleware 2: Before
// Handler executes
// Middleware 2: After
// Middleware 1: After
Advanced Middleware Patterns
Conditional Middleware
Apply middleware based on conditions:
// conditionalAuth returns middleware that sends a request through
// authMiddleware only when condition reports true for it; all other requests
// go straight to next.
func conditionalAuth(condition func(*http.Request) bool) router.MiddlewareFunc {
	return func(next http.Handler) http.Handler {
		// Build the authenticated chain once when the middleware is installed
		// rather than allocating a new wrapper on every request.
		authed := authMiddleware(next)

		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if condition(r) {
				authed.ServeHTTP(w, r)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// Usage: Only require auth for non-GET requests
r.Use(conditionalAuth(func(r *http.Request) bool {
return r.Method != "GET"
}))
Middleware with Configuration
// SecurityConfig selects which security response headers securityHeaders
// sets. A zero-valued field leaves the corresponding header unset.
type SecurityConfig struct {
	CSPPolicy          string // value for Content-Security-Policy
	HSTSMaxAge         int    // max-age for Strict-Transport-Security; sent only when > 0
	FrameOptions       string // value for X-Frame-Options
	ContentTypeNoSniff bool   // when true, sends X-Content-Type-Options: nosniff
}

// securityHeaders returns middleware that sets the headers enabled in config
// on every response before calling next.
func securityHeaders(config SecurityConfig) router.MiddlewareFunc {
	type header struct{ name, value string }

	// The configuration is fixed for the lifetime of the middleware, so the
	// header values (including the formatted HSTS string) are computed once
	// here instead of on every request.
	var static []header
	if config.CSPPolicy != "" {
		static = append(static, header{"Content-Security-Policy", config.CSPPolicy})
	}
	if config.HSTSMaxAge > 0 {
		static = append(static, header{
			"Strict-Transport-Security",
			fmt.Sprintf("max-age=%d; includeSubDomains", config.HSTSMaxAge),
		})
	}
	if config.FrameOptions != "" {
		static = append(static, header{"X-Frame-Options", config.FrameOptions})
	}
	if config.ContentTypeNoSniff {
		static = append(static, header{"X-Content-Type-Options", "nosniff"})
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()
			for _, hd := range static {
				h.Set(hd.name, hd.value)
			}
			next.ServeHTTP(w, r)
		})
	}
}
// Usage
r.Use(securityHeaders(SecurityConfig{
CSPPolicy: "default-src 'self'",
HSTSMaxAge: 31536000, // 1 year
FrameOptions: "DENY",
ContentTypeNoSniff: true,
}))
Response Writer Middleware
Middleware that wraps the ResponseWriter to capture the status code and body of a response:
// responseWriter wraps an http.ResponseWriter and records the status code and
// a copy of the body bytes that pass through it, so middleware can inspect
// them after the handler returns.
//
// NOTE(review): the whole body is duplicated in memory; for large or streamed
// responses a byte counter would be cheaper than a buffer.
type responseWriter struct {
	http.ResponseWriter
	statusCode  int           // final status recorded for the response
	body        *bytes.Buffer // copy of the bytes accepted by the wrapped writer
	wroteHeader bool          // set once the final status has been recorded
}

// WriteHeader records code and forwards the call to the wrapped writer.
// Only the first final status is recorded: later calls do not change what was
// sent, so they must not change what is reported either.
func (rw *responseWriter) WriteHeader(code int) {
	// 1xx responses other than 101 are informational and may precede the
	// final status, so they are passed through without being recorded.
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !informational && !rw.wroteHeader {
		rw.statusCode = code
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write forwards b to the wrapped writer and keeps a copy of the bytes that
// were actually accepted. It returns the wrapped writer's result.
func (rw *responseWriter) Write(b []byte) (int, error) {
	if !rw.wroteHeader {
		// Writing a body without an explicit WriteHeader implies 200 OK.
		rw.statusCode = http.StatusOK
		rw.wroteHeader = true
	}
	if rw.body == nil {
		// Tolerate a responseWriter built without a buffer.
		rw.body = new(bytes.Buffer)
	}

	n, err := rw.ResponseWriter.Write(b)
	rw.body.Write(b[:n])
	return n, err
}
func responseLoggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rw := &responseWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
body: &bytes.Buffer{},
}
start := time.Now()
next.ServeHTTP(rw, r)
duration := time.Since(start)
log.Printf("%s %s %d %d bytes %v",
r.Method,
r.URL.Path,
rw.statusCode,
rw.body.Len(),
duration,
)
})
}Error Handling in Middleware
Handle errors gracefully in middleware:
func errorHandlingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
log.Printf("Panic in request %s %s: %v", r.Method, r.URL.Path, err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
response := map[string]interface{}{
"error": "Internal server error",
"code": "INTERNAL_ERROR",
}
json.NewEncoder(w).Encode(response)
}
}()
next.ServeHTTP(w, r)
})
}Database Transaction Middleware
func transactionMiddleware(db *sql.DB) router.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Only use transactions for modifying requests
if r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" {
next.ServeHTTP(w, r)
return
}
tx, err := db.Begin()
if err != nil {
http.Error(w, "Database error", http.StatusInternalServerError)
return
}
// Add transaction to context
ctx := context.WithValue(r.Context(), "tx", tx)
// Use custom response writer to capture status
rw := &responseWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
body: &bytes.Buffer{},
}
next.ServeHTTP(rw, r.WithContext(ctx))
// Commit or rollback based on response status
if rw.statusCode >= 400 {
tx.Rollback()
} else {
if err := tx.Commit(); err != nil {
log.Printf("Failed to commit transaction: %v", err)
tx.Rollback()
}
}
})
}
}Context Values in Middleware
Pass data between middleware and handlers using context:
// contextKey is a dedicated string type for context keys. A value stored
// under a contextKey is not found by a lookup with a plain string of the same
// text.
type contextKey string

// Keys under which request-scoped values are stored in the context.
const (
	userIDKey    = contextKey("user_id")
	requestIDKey = contextKey("request_id")
	startTimeKey = contextKey("start_time")
)
// Set context values
// setUserContext stores the value returned by extractUserID in the request
// context under userIDKey and passes the updated request to next.
func setUserContext(next http.Handler) http.Handler {
	withUser := func(w http.ResponseWriter, r *http.Request) {
		id := extractUserID(r)
		enriched := r.WithContext(context.WithValue(r.Context(), userIDKey, id))
		next.ServeHTTP(w, enriched)
	}
	return http.HandlerFunc(withUser)
}
// Get context values in handlers
// protectedHandler writes the user ID found in the request context under
// userIDKey. It responds with 500 when no int value is stored under that key.
func protectedHandler(w http.ResponseWriter, r *http.Request) {
	if id, found := r.Context().Value(userIDKey).(int); found {
		fmt.Fprintf(w, "User ID: %d", id)
		return
	}
	http.Error(w, "User not found in context", http.StatusInternalServerError)
}
// Helper function to get user ID from context
func GetUserIDFromContext(ctx context.Context) (int, bool) {
userID, ok := ctx.Value(userIDKey).(int)
return userID, ok
}Third-Party Middleware Integration
SteelRouter works with any middleware that follows the standard func(http.Handler) http.Handler pattern:
Gorilla Handlers
import "github.com/gorilla/handlers"
// handlers.LoggingHandler takes the handler to wrap and returns an
// http.Handler, so it is adapted here to the func(http.Handler) http.Handler
// shape that Use expects, wrapping the next handler in the chain.
r.Use(func(next http.Handler) http.Handler {
	return handlers.LoggingHandler(os.Stdout, next)
})
r.Use(handlers.CompressHandler)
Custom Integration
// Adapt third-party middleware if needed
// adaptMiddleware exposes a standard func(http.Handler) http.Handler
// middleware as a router.MiddlewareFunc by delegating to it unchanged.
func adaptMiddleware(middleware func(http.Handler) http.Handler) router.MiddlewareFunc {
	adapted := func(next http.Handler) http.Handler {
		wrapped := middleware(next)
		return wrapped
	}
	return adapted
}
// Usage
r.Use(adaptMiddleware(someThirdPartyMiddleware))
Performance Considerations
Performance Tip: Order middleware by frequency of early termination. Place authentication and rate limiting before expensive operations like logging.
Efficient Middleware Order
// ✅ Good: Fast-failing middleware first
r.Use(rateLimitMiddleware) // Quick check, may terminate early
r.Use(authMiddleware) // Authentication check
r.Use(corsMiddleware()) // CORS headers
r.Use(requestIDMiddleware) // Add request ID
r.Use(router.Logger) // Logging (always executes)
r.Use(router.Recoverer) // Panic recovery (always wraps)
Avoid Heavy Operations
// ❌ Avoid: Heavy database operations in middleware
// heavyMiddleware is the anti-pattern example: it calls fetchUserFromDatabase
// for every request before handing over to next. The result is placed in the
// request context under the string key "user" so the lookup is actually used
// (an unused local would not compile).
func heavyMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Don't do this - heavy operation on every request
		user := fetchUserFromDatabase(r.Header.Get("User-ID"))
		ctx := context.WithValue(r.Context(), "user", user)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// ✅ Better: Cache or lazy load
func efficientMiddleware(cache *Cache) router.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
userID := r.Header.Get("User-ID")
// Check cache first
if user, found := cache.Get(userID); found {
ctx := context.WithValue(r.Context(), "user", user)
next.ServeHTTP(w, r.WithContext(ctx))
return
}
// Only fetch if not in cache
user := fetchUserFromDatabase(userID)
cache.Set(userID, user)
ctx := context.WithValue(r.Context(), "user", user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}Testing Middleware
Test middleware independently from handlers:
func TestAuthMiddleware(t *testing.T) {
testCases := []struct {
name string
authHeader string
expectedStatus int
}{
{"Valid token", "Bearer valid-token", http.StatusOK},
{"Invalid token", "Bearer invalid-token", http.StatusUnauthorized},
{"No token", "", http.StatusUnauthorized},
{"Malformed token", "InvalidFormat", http.StatusUnauthorized},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create test handler
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
})
// Wrap with middleware
handler := authMiddleware(testHandler)
// Create test request
req := httptest.NewRequest("GET", "/test", nil)
if tc.authHeader != "" {
req.Header.Set("Authorization", tc.authHeader)
}
// Execute
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
// Assert
if w.Code != tc.expectedStatus {
t.Errorf("Expected status %d, got %d", tc.expectedStatus, w.Code)
}
})
}
}Middleware in SteelRouter provides a clean, composable way to add functionality to your applications while keeping your code organized and testable.