Skip to content

Commit

Permalink
Merge pull request #242 from actiontech/fix-issue228-1
Browse files Browse the repository at this point in the history
add get token from context func
  • Loading branch information
LordofAvernus authored Apr 25, 2024
2 parents 0f4273b + 832ec6f commit b8b791f
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 55 deletions.
24 changes: 13 additions & 11 deletions internal/dms/biz/access_token.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package biz

import (
"fmt"
"net/http"

"github.com/actiontech/dms/pkg/dms-common/api/accesstoken"
jwtPkg "github.com/actiontech/dms/pkg/dms-common/api/jwt"
utilLog "github.com/actiontech/dms/pkg/dms-common/pkg/log"

"github.com/labstack/echo/v4"
)

Expand All @@ -26,28 +28,28 @@ func NewAuthAccessTokenUsecase(log utilLog.Logger, usecase *UserUsecase) *AuthAc
func (au *AuthAccessTokenUsecase) CheckLatestAccessToken() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
token, exist, err := accesstoken.GetTokenFromContext(c)
tokenDetail, err := jwtPkg.GetTokenDetailFromContext(c)
if err != nil {
echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("get token detail failed, err:%v", err))
return err
}
if !exist {

// LoginType为空,不需要校验access token
if tokenDetail.LoginType == "" {
return next(c)
}
uid, exist, err := accesstoken.GetUidFromAccessToken(token)
if err != nil {
return err
}
if !exist {
return next(c)

if tokenDetail.LoginType != AccessTokenLogin {
return echo.NewHTTPError(http.StatusUnauthorized, "access token login type is error")
}

accessTokenInfo, err := au.userUsecase.repo.GetAccessTokenByUser(c.Request().Context(), uid)
accessTokenInfo, err := au.userUsecase.repo.GetAccessTokenByUser(c.Request().Context(), tokenDetail.UID)

if err != nil {
return err
}

if accessTokenInfo.Token != token.Raw {
if accessTokenInfo.Token != tokenDetail.TokenStr {
return echo.NewHTTPError(http.StatusUnauthorized, "access token is not latest")
}

Expand Down
59 changes: 15 additions & 44 deletions pkg/dms-common/api/accesstoken/access_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,75 +7,46 @@ import (

jwtPkg "github.com/actiontech/dms/pkg/dms-common/api/jwt"
"github.com/actiontech/dms/pkg/dms-common/dmsobject"
"github.com/golang-jwt/jwt/v4"
"github.com/labstack/echo/v4"
)

const AccessTokenLogin = "access_token_login"

func CheckLatestAccessToken(dmsAddress string) echo.MiddlewareFunc {
func CheckLatestAccessToken(dmsAddress string, getTokenDetail func(c jwtPkg.EchoContextGetter) (*jwtPkg.TokenDetail, error)) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
token, exist, err := GetTokenFromContext(c)
tokenDetail, err := getTokenDetail(c)

if err != nil {
echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("get token detail failed, err:%v", err))
return err
}
if !exist {

if tokenDetail.TokenStr == "" {
return next(c)
}
uid, exist, err := GetUidFromAccessToken(token)
if err != nil {
return err
}
if !exist {

// LoginType为空,不需要校验access token
if tokenDetail.LoginType == "" {
return next(c)
}

userInfo, err := dmsobject.GetUser(context.TODO(), uid, dmsAddress)
if tokenDetail.LoginType != AccessTokenLogin {
return echo.NewHTTPError(http.StatusUnauthorized, "access token login type is error")
}

userInfo, err := dmsobject.GetUser(context.TODO(), tokenDetail.UID, dmsAddress)
if err != nil {
return err
}
if userInfo == nil {
return echo.NewHTTPError(http.StatusNotFound, "access token: cannot get user info")
}

if userInfo.AccessTokenInfo.AccessToken != token.Raw {
if userInfo.AccessTokenInfo.AccessToken != tokenDetail.TokenStr {
return echo.NewHTTPError(http.StatusUnauthorized, "access token is not latest")
}

return next(c)
}
}
}

func GetTokenFromContext(c echo.Context) (token *jwt.Token, exist bool, err error) {
user := c.Get("user")
// 获取token为空,代表该请求不需要校验token或者是sqle和provision的请求
if user == nil {
return nil, false, nil
}
token, ok := user.(*jwt.Token)
if !ok {
return nil, true, echo.NewHTTPError(http.StatusBadRequest, "failed to convert user from jwt token")
}

return token, true, nil
}

func GetUidFromAccessToken(token *jwt.Token) (uid string, exist bool, err error) {
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return "", true, echo.NewHTTPError(http.StatusBadRequest, "failed to convert token claims to jwt")
}

// 如果不存在JWTLoginType字段,代表是账号密码登录获取的token或者是扫描任务的凭证,不进行校验
loginType, ok := claims[jwtPkg.JWTLoginType]
if !ok {
return "", false, nil
}
if loginType != AccessTokenLogin {
return "", true, echo.NewHTTPError(http.StatusUnauthorized, "access token login type is error")
}
uid = fmt.Sprintf("%v", claims[jwtPkg.JWTUserId])
return uid, true, nil
}
76 changes: 76 additions & 0 deletions pkg/dms-common/api/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,79 @@ func ParseUserUidStrFromTokenWithOldJwt(token *jwtOld.Token) (uid string, err er
}
return uidStr, nil
}

type TokenDetail struct {
TokenStr string
UID string
LoginType string
}

// 由于sqle使用的github.com/golang-jwt/jwt,本方法为sqle兼容
func GetTokenDetailFromContextWithOldJwt(c EchoContextGetter) (tokenDetail *TokenDetail, err error) {
tokenDetail = &TokenDetail{}

if c.Get("user") == nil {
return tokenDetail, nil
}

// Gets user token from the context.
u, ok := c.Get("user").(*jwtOld.Token)
if !ok {
return nil, fmt.Errorf("failed to convert user from jwt token")
}
tokenDetail.TokenStr = u.Raw

// get uid from token
uid, err := ParseUserUidStrFromTokenWithOldJwt(u)
if err != nil {
return nil, err
}
tokenDetail.UID = uid

// get login type from token
claims, ok := u.Claims.(jwtOld.MapClaims)
if !ok {
return nil, fmt.Errorf("failed to convert token claims to jwt")
}
loginType, ok := claims[JWTLoginType]
if !ok {
return tokenDetail, nil
}

tokenDetail.LoginType = fmt.Sprint(loginType)
return tokenDetail, nil
}

func GetTokenDetailFromContext(c EchoContextGetter) (tokenDetail *TokenDetail, err error) {
tokenDetail = &TokenDetail{}
if c.Get("user") == nil {
return tokenDetail, nil
}

// Gets user token from the context.
u, ok := c.Get("user").(*jwt.Token)
if !ok {
return nil, fmt.Errorf("failed to convert user from jwt token")
}
tokenDetail.TokenStr = u.Raw

// get uid from token
uid, err := ParseUserUidStrFromToken(u)
if err != nil {
return nil, err
}
tokenDetail.UID = uid

// get login type from token
claims, ok := u.Claims.(jwt.MapClaims)
if !ok {
return nil, fmt.Errorf("failed to convert token claims to jwt")
}
loginType, ok := claims[JWTLoginType]
if !ok {
return tokenDetail, nil
}

tokenDetail.LoginType = fmt.Sprint(loginType)
return tokenDetail, nil
}

0 comments on commit b8b791f

Please sign in to comment.