76 lines
1.7 KiB
Go
76 lines
1.7 KiB
Go
package cors
|
||
|
||
import (
|
||
"strconv"
|
||
"strings"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// Config describes which cross-origin requests the Cors middleware admits and
// which CORS response headers it emits for them.
type Config struct {
	// AllowOrigins lists the origins that may make cross-origin requests; each
	// entry is compared by exact string match with the request's Origin header.
	// The entry "*" admits every origin.
	//
	// AllowMethods, AllowHeaders and ExposeHeaders are each joined with "," and
	// sent as Access-Control-Allow-Methods, Access-Control-Allow-Headers and
	// Access-Control-Expose-Headers respectively.
	AllowOrigins, AllowMethods, AllowHeaders, ExposeHeaders []string

	// AllowCredentials, when true, adds "Access-Control-Allow-Credentials: true"
	// to responses for admitted origins.
	AllowCredentials bool

	// MaxAge is sent as Access-Control-Max-Age (seconds a preflight result may
	// be cached by the browser).
	MaxAge int
}
|
||
|
||
func DefaultConfig() Config {
|
||
return Config{
|
||
AllowOrigins: []string{"*"},
|
||
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"},
|
||
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization", "X-Requested-With"},
|
||
ExposeHeaders: []string{"Content-Length", "Content-Type"},
|
||
AllowCredentials: true,
|
||
MaxAge: 86400, // 24 hours
|
||
}
|
||
}
|
||
|
||
func Cors(config Config) gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
origin := c.Request.Header.Get("Origin")
|
||
if origin == "" {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
// 检查是否允许该来源
|
||
allowOrigin := ""
|
||
for _, o := range config.AllowOrigins {
|
||
if o == "*" {
|
||
// 通配符时,回显实际 origin(兼容 credentials)
|
||
allowOrigin = origin
|
||
break
|
||
}
|
||
if o == origin {
|
||
allowOrigin = origin
|
||
break
|
||
}
|
||
}
|
||
|
||
if allowOrigin == "" {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
c.Header("Access-Control-Allow-Origin", allowOrigin)
|
||
c.Header("Access-Control-Allow-Methods", strings.Join(config.AllowMethods, ","))
|
||
c.Header("Access-Control-Allow-Headers", strings.Join(config.AllowHeaders, ","))
|
||
c.Header("Access-Control-Expose-Headers", strings.Join(config.ExposeHeaders, ","))
|
||
c.Header("Access-Control-Max-Age", strconv.Itoa(config.MaxAge))
|
||
|
||
if config.AllowCredentials {
|
||
c.Header("Access-Control-Allow-Credentials", "true")
|
||
}
|
||
|
||
if c.Request.Method == "OPTIONS" {
|
||
c.AbortWithStatus(204)
|
||
return
|
||
}
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|