// kra/internal/server/middleware/cors.go

package middleware
import (
"context"
"net/http"
"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
)
// CORSWhitelist is a single CORS whitelist entry. An entry applies to a
// request whose Origin header is exactly equal to AllowOrigin (plain,
// case-sensitive string comparison). CORSByRules then writes the entry's
// fields into the corresponding Access-Control-* response headers.
type CORSWhitelist struct {
// AllowOrigin is matched against the request Origin header and, on a match,
// sent back as Access-Control-Allow-Origin.
AllowOrigin string
// AllowHeaders is sent as Access-Control-Allow-Headers.
AllowHeaders string
// AllowMethods is sent as Access-Control-Allow-Methods.
AllowMethods string
// ExposeHeaders is sent as Access-Control-Expose-Headers.
ExposeHeaders string
// AllowCredentials, when true, adds "Access-Control-Allow-Credentials: true".
AllowCredentials bool
}
// CORSConfig is the configuration consumed by CORSByRules.
type CORSConfig struct {
// Mode selects the policy: "allow-all", "whitelist" or "strict-whitelist".
// "allow-all" makes CORSByRules delegate to CORS (every origin allowed).
// "strict-whitelist" rejects requests whose origin is not in Whitelist,
// except GET /health. Any other value only adds CORS headers for whitelisted
// origins and lets all other requests through unchanged.
Mode string
// Whitelist holds the allowed origins for the whitelist modes. The first
// entry whose AllowOrigin equals the request Origin is used.
Whitelist []CORSWhitelist
}
// CORS returns a middleware that allows cross-origin requests from any origin.
//
// For HTTP transports it:
//   - echoes the request's Origin header back in Access-Control-Allow-Origin
//     (an empty value when the request has no Origin header);
//   - advertises a fixed set of allowed headers, allowed methods and exposed
//     headers;
//   - sets Access-Control-Allow-Credentials to "true".
//
// Because the allowed origin is reflected from the request, the response
// also carries "Vary: Origin". Without it, shared caches could serve one
// origin's CORS headers to a different origin.
//
// OPTIONS requests get an immediate nil reply and nil error; the next handler
// is not invoked. Non-HTTP transports pass straight through.
//
// NOTE(review): reflecting an arbitrary Origin together with
// Access-Control-Allow-Credentials: true lets any website make credentialed
// requests. Prefer CORSByRules with a whitelist outside trusted environments.
func CORS() middleware.Middleware {
	return func(handler middleware.Handler) middleware.Handler {
		return func(ctx context.Context, req interface{}) (interface{}, error) {
			if tr, ok := transport.FromServerContext(ctx); ok {
				if ht, ok := tr.(kratoshttp.Transporter); ok {
					origin := ht.Request().Header.Get("Origin")
					header := ht.ReplyHeader()
					header.Set("Access-Control-Allow-Origin", origin)
					// The allowed origin is copied from the request, so the
					// response differs per Origin and caches must key on it.
					header.Set("Vary", "Origin")
					header.Set("Access-Control-Allow-Headers", "Content-Type,AccessToken,X-CSRF-Token,Authorization,Token,X-Token,X-User-Id")
					header.Set("Access-Control-Allow-Methods", "POST,GET,OPTIONS,DELETE,PUT")
					header.Set("Access-Control-Expose-Headers", "Content-Length,Access-Control-Allow-Origin,Access-Control-Allow-Headers,Content-Type,New-Token,New-Expires-At")
					header.Set("Access-Control-Allow-Credentials", "true")
					// Answer preflight (OPTIONS) requests here without running the handler.
					if ht.Request().Method == http.MethodOptions {
						return nil, nil
					}
				}
			}
			return handler(ctx, req)
		}
	}
}
// CORSByRules returns a CORS middleware driven by cfg.
//
// Behaviour by cfg.Mode:
//   - "allow-all": delegates to CORS, which allows every origin.
//   - "strict-whitelist": a request whose Origin header does not exactly
//     match an entry in cfg.Whitelist is rejected with a Forbidden error
//     (reason "CORS_FORBIDDEN"). This includes requests with no Origin header,
//     unless some entry has an empty AllowOrigin. GET /health is always let
//     through so health checks keep working.
//   - any other value (e.g. "whitelist"): CORS response headers are added only
//     for whitelisted origins; every request still reaches the handler.
//
// For a whitelisted origin, the matching entry's AllowOrigin, AllowHeaders,
// AllowMethods and ExposeHeaders are written to the corresponding
// Access-Control-* response headers. Access-Control-Allow-Credentials is set
// to "true" when AllowCredentials is enabled.
//
// Which CORS headers are emitted depends on the request's Origin, so
// "Vary: Origin" is set on every HTTP response. This keeps shared caches from
// mixing responses up between origins.
//
// OPTIONS requests that are not rejected get an immediate nil reply and nil
// error; the next handler is not invoked. Non-HTTP transports pass straight
// through.
func CORSByRules(cfg CORSConfig) middleware.Middleware {
	// Allow everything.
	if cfg.Mode == "allow-all" {
		return CORS()
	}
	return func(handler middleware.Handler) middleware.Handler {
		return func(ctx context.Context, req interface{}) (interface{}, error) {
			if tr, ok := transport.FromServerContext(ctx); ok {
				if ht, ok := tr.(kratoshttp.Transporter); ok {
					r := ht.Request()
					origin := r.Header.Get("Origin")
					whitelist := checkCors(origin, cfg.Whitelist)
					header := ht.ReplyHeader()
					// The CORS headers below depend on Origin, so caches must
					// key the response on it even when the origin is rejected.
					header.Set("Vary", "Origin")
					// Whitelisted origin: emit the entry's CORS headers.
					if whitelist != nil {
						header.Set("Access-Control-Allow-Origin", whitelist.AllowOrigin)
						header.Set("Access-Control-Allow-Headers", whitelist.AllowHeaders)
						header.Set("Access-Control-Allow-Methods", whitelist.AllowMethods)
						header.Set("Access-Control-Expose-Headers", whitelist.ExposeHeaders)
						if whitelist.AllowCredentials {
							header.Set("Access-Control-Allow-Credentials", "true")
						}
					}
					// Strict mode rejects unknown origins outright, except the
					// health-check endpoint.
					if whitelist == nil && cfg.Mode == "strict-whitelist" {
						if !(r.Method == http.MethodGet && r.URL.Path == "/health") {
							return nil, errors.Forbidden("CORS_FORBIDDEN", "forbidden")
						}
					}
					// Answer preflight (OPTIONS) requests here without running the handler.
					if r.Method == http.MethodOptions {
						return nil, nil
					}
				}
			}
			return handler(ctx, req)
		}
	}
}
// checkCors finds the whitelist entry whose AllowOrigin is exactly equal to
// currentOrigin (case-sensitive). It returns a pointer to a copy of the first
// matching entry, so callers cannot modify the configured slice through it.
// It returns nil when no entry matches.
func checkCors(currentOrigin string, whitelist []CORSWhitelist) *CORSWhitelist {
	for i := range whitelist {
		if whitelist[i].AllowOrigin != currentOrigin {
			continue
		}
		match := whitelist[i]
		return &match
	}
	return nil
}