package middleware import ( "context" "net/http" "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport" kratoshttp "github.com/go-kratos/kratos/v2/transport/http" ) // CORSWhitelist CORS白名单配置 type CORSWhitelist struct { AllowOrigin string AllowHeaders string AllowMethods string ExposeHeaders string AllowCredentials bool } // CORSConfig CORS配置 type CORSConfig struct { // Mode 模式: allow-all, whitelist, strict-whitelist Mode string Whitelist []CORSWhitelist } // CORS 跨域中间件(放行所有) func CORS() middleware.Middleware { return func(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (interface{}, error) { if tr, ok := transport.FromServerContext(ctx); ok { if ht, ok := tr.(kratoshttp.Transporter); ok { origin := ht.Request().Header.Get("Origin") header := ht.ReplyHeader() header.Set("Access-Control-Allow-Origin", origin) header.Set("Access-Control-Allow-Headers", "Content-Type,AccessToken,X-CSRF-Token,Authorization,Token,X-Token,X-User-Id") header.Set("Access-Control-Allow-Methods", "POST,GET,OPTIONS,DELETE,PUT") header.Set("Access-Control-Expose-Headers", "Content-Length,Access-Control-Allow-Origin,Access-Control-Allow-Headers,Content-Type,New-Token,New-Expires-At") header.Set("Access-Control-Allow-Credentials", "true") // OPTIONS请求直接返回 if ht.Request().Method == http.MethodOptions { return nil, nil } } } return handler(ctx, req) } } } // CORSByRules 按配置规则处理跨域 func CORSByRules(cfg CORSConfig) middleware.Middleware { // 放行全部 if cfg.Mode == "allow-all" { return CORS() } return func(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (interface{}, error) { if tr, ok := transport.FromServerContext(ctx); ok { if ht, ok := tr.(kratoshttp.Transporter); ok { r := ht.Request() origin := r.Header.Get("Origin") whitelist := checkCors(origin, cfg.Whitelist) header := ht.ReplyHeader() // 通过检查,添加请求头 if whitelist != nil { header.Set("Access-Control-Allow-Origin", whitelist.AllowOrigin) header.Set("Access-Control-Allow-Headers", whitelist.AllowHeaders) header.Set("Access-Control-Allow-Methods", whitelist.AllowMethods) header.Set("Access-Control-Expose-Headers", whitelist.ExposeHeaders) if whitelist.AllowCredentials { header.Set("Access-Control-Allow-Credentials", "true") } } // 严格白名单模式且未通过检查,直接拒绝 if whitelist == nil && cfg.Mode == "strict-whitelist" { // 健康检查放行 if !(r.Method == http.MethodGet && r.URL.Path == "/health") { return nil, errors.Forbidden("CORS_FORBIDDEN", "forbidden") } } // OPTIONS请求直接返回 if r.Method == http.MethodOptions { return nil, nil } } } return handler(ctx, req) } } } func checkCors(currentOrigin string, whitelist []CORSWhitelist) *CORSWhitelist { for _, w := range whitelist { if currentOrigin == w.AllowOrigin { return &w } } return nil }