108 lines
3.2 KiB
Go
108 lines
3.2 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
|
|
"github.com/go-kratos/kratos/v2/errors"
|
|
"github.com/go-kratos/kratos/v2/middleware"
|
|
"github.com/go-kratos/kratos/v2/transport"
|
|
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
|
|
)
|
|
|
|
// CORSWhitelist is a single whitelist entry: the origin it applies to and the
// CORS response headers that CORSByRules sends back to requests from that
// origin.
type CORSWhitelist struct {
	// AllowOrigin is compared verbatim (exact, case-sensitive string equality)
	// with the request's Origin header; on a match it is echoed in
	// Access-Control-Allow-Origin.
	AllowOrigin string

	// AllowHeaders, AllowMethods and ExposeHeaders are copied unchanged into
	// Access-Control-Allow-Headers, Access-Control-Allow-Methods and
	// Access-Control-Expose-Headers respectively (empty strings are written
	// as empty header values).
	AllowHeaders, AllowMethods, ExposeHeaders string

	// AllowCredentials adds "Access-Control-Allow-Credentials: true" to the
	// reply when set; nothing is written when it is false.
	AllowCredentials bool
}
|
|
|
|
// CORSConfig CORS配置
|
|
type CORSConfig struct {
|
|
// Mode 模式: allow-all, whitelist, strict-whitelist
|
|
Mode string
|
|
Whitelist []CORSWhitelist
|
|
}
|
|
|
|
// CORS returns a middleware that allows cross-origin requests from any origin.
//
// For HTTP transports it echoes the request's Origin header in
// Access-Control-Allow-Origin and writes a fixed set of allowed request
// headers, allowed methods, exposed headers and
// "Access-Control-Allow-Credentials: true". Because the allowed origin is
// copied from each request, the reply also carries "Vary: Origin" so shared
// caches do not serve one origin's response to another.
//
// OPTIONS requests return (nil, nil) without invoking the wrapped handler.
// All other requests, and requests on non-HTTP transports, are passed through
// to the handler.
//
// NOTE(review): reflecting any Origin while allowing credentials lets every
// site issue credentialed requests; prefer CORSByRules with a whitelist where
// that is not intended.
func CORS() middleware.Middleware {
	return func(handler middleware.Handler) middleware.Handler {
		return func(ctx context.Context, req interface{}) (interface{}, error) {
			if tr, ok := transport.FromServerContext(ctx); ok {
				if ht, ok := tr.(kratoshttp.Transporter); ok {
					origin := ht.Request().Header.Get("Origin")
					header := ht.ReplyHeader()
					header.Set("Access-Control-Allow-Origin", origin)
					header.Set("Access-Control-Allow-Headers", "Content-Type,AccessToken,X-CSRF-Token,Authorization,Token,X-Token,X-User-Id")
					header.Set("Access-Control-Allow-Methods", "POST,GET,OPTIONS,DELETE,PUT")
					header.Set("Access-Control-Expose-Headers", "Content-Length,Access-Control-Allow-Origin,Access-Control-Allow-Headers,Content-Type,New-Token,New-Expires-At")
					header.Set("Access-Control-Allow-Credentials", "true")

					// Allow-Origin varies per request, so caches must key on
					// Origin. Merge with any existing Vary value instead of
					// overwriting it (only Get/Set are used, for compatibility
					// with older transport.Header interfaces).
					if vary := header.Get("Vary"); vary == "" {
						header.Set("Vary", "Origin")
					} else if vary != "Origin" {
						header.Set("Vary", vary+", Origin")
					}

					// Answer OPTIONS (preflight) directly; the handler is not called.
					if ht.Request().Method == http.MethodOptions {
						return nil, nil
					}
				}
			}
			return handler(ctx, req)
		}
	}
}
|
|
|
|
// CORSByRules returns a CORS middleware driven by cfg.
//
// Behaviour by cfg.Mode:
//   - "allow-all": delegates to CORS, which accepts every origin.
//   - "strict-whitelist": a request whose Origin matches no whitelist entry is
//     rejected with errors.Forbidden("CORS_FORBIDDEN", "forbidden"), except
//     GET /health. This includes requests that carry no Origin header at all,
//     unless an entry with an empty AllowOrigin exists.
//   - anything else (e.g. "whitelist"): requests from non-matching origins
//     are passed through without CORS headers.
//
// When the Origin matches an entry (see checkCors), that entry's headers are
// written to the reply. HTTP replies always get "Vary: Origin", because the
// CORS headers (and, in strict mode, the status) depend on the request's
// Origin. OPTIONS requests that are not rejected return (nil, nil) without
// invoking the wrapped handler. Non-HTTP transports are passed through.
func CORSByRules(cfg CORSConfig) middleware.Middleware {
	// Allow everything: reuse the permissive middleware.
	if cfg.Mode == "allow-all" {
		return CORS()
	}

	return func(handler middleware.Handler) middleware.Handler {
		return func(ctx context.Context, req interface{}) (interface{}, error) {
			if tr, ok := transport.FromServerContext(ctx); ok {
				if ht, ok := tr.(kratoshttp.Transporter); ok {
					r := ht.Request()
					origin := r.Header.Get("Origin")
					whitelist := checkCors(origin, cfg.Whitelist)

					header := ht.ReplyHeader()

					// Every reply produced here depends on Origin, so caches
					// must key on it. Set this before the strict-mode rejection
					// so 403 replies carry it too. Existing Vary values are kept.
					if vary := header.Get("Vary"); vary == "" {
						header.Set("Vary", "Origin")
					} else if vary != "Origin" {
						header.Set("Vary", vary+", Origin")
					}

					// Origin is whitelisted: emit that entry's CORS headers.
					if whitelist != nil {
						header.Set("Access-Control-Allow-Origin", whitelist.AllowOrigin)
						header.Set("Access-Control-Allow-Headers", whitelist.AllowHeaders)
						header.Set("Access-Control-Allow-Methods", whitelist.AllowMethods)
						header.Set("Access-Control-Expose-Headers", whitelist.ExposeHeaders)
						if whitelist.AllowCredentials {
							header.Set("Access-Control-Allow-Credentials", "true")
						}
					}

					// Strict mode: reject origins that are not whitelisted, but
					// keep the health check reachable.
					if whitelist == nil && cfg.Mode == "strict-whitelist" {
						if !(r.Method == http.MethodGet && r.URL.Path == "/health") {
							return nil, errors.Forbidden("CORS_FORBIDDEN", "forbidden")
						}
					}

					// Answer OPTIONS (preflight) directly; the handler is not called.
					if r.Method == http.MethodOptions {
						return nil, nil
					}
				}
			}
			return handler(ctx, req)
		}
	}
}
|
|
|
|
func checkCors(currentOrigin string, whitelist []CORSWhitelist) *CORSWhitelist {
|
|
for _, w := range whitelist {
|
|
if currentOrigin == w.AllowOrigin {
|
|
return &w
|
|
}
|
|
}
|
|
return nil
|
|
}
|