golang double token實現

使用golang編寫雙token驗證


實現目標

  1. 基本Token驗證
  2. Refresh Token進行雙token驗證
  3. 使用Redis保存
  4. 新增Black list機制
  5. 對高併發下使用message queue來批量處理black list
  6. 使用Prometheus進行流量統計,以便在不同情況下採用不同的處理方式

定義Token Struct

  • configs.go
package configs
 
// Token carries the identifiers and expiry times for an access/refresh
// token pair.
type Token struct {
	// AccessToken is the access token string.
	// NOTE(review): no code visible here assigns this field — confirm it is used.
	AccessToken string `json:"access_token"`
	// AccessUuid is the unique identifier of the access token.
	AccessUuid  string `json:"access_uuid"`
	// AtExpires is the access token expiry as a Unix timestamp in seconds.
	AtExpires   int64  `json:"at_expires"`
 
	// RefreshToken is the refresh token string.
	// NOTE(review): no code visible here assigns this field — confirm it is used.
	RefreshToken     string `json:"refresh_token"`
	// RefreshUuid is the unique identifier of the refresh token.
	RefreshUuid      string `json:"refresh_uuid"`
	// RefreshAtExpires is the refresh token expiry as a Unix timestamp in
	// seconds. Note the JSON name is "rat_expires".
	RefreshAtExpires int64  `json:"rat_expires"`
}
  • token.go
var (
	// err is a package-level error variable shared by functions in this
	// package.
	// NOTE(review): shared mutable state — any function that assigns to it is
	// not safe for concurrent use; prefer function-local error variables.
	err        error
	// TokenKey holds the access-token signing key as a string. NewToken
	// stores it from the "token.token_key" config entry.
	TokenKey   atomic.Value
	// RefreshKey holds the refresh-token signing key as a string. NewToken
	// stores it from the "token.refresh_key" config entry.
	RefreshKey atomic.Value
)
 
// Token is the working state for issuing tokens for a single user.
type Token struct {
	// UserId is the id of the user the tokens are issued for.
	UserId       int64         `json:"user_id"`
	// AccessId is the access UUID recorded by CreateToken.
	AccessId     string        `json:"access_id"`
	// AccessToken is the signed access JWT produced by CreateToken.
	AccessToken  string        `json:"access_token"`
	// RefreshId and RefreshToken are the refresh-token counterparts.
	// NOTE(review): nothing visible here assigns them — presumably set by a
	// refresh-token creation method; confirm.
	RefreshId    string        `json:"refresh_id"`
	RefreshToken string        `json:"refresh_token"`
	// Token holds the UUIDs and expiry timestamps generated by NewToken.
	Token        configs.Token `json:"token"`
}

NewToken邏輯

// NewToken prepares the token state for the user identified by id.
//
// It refreshes TokenKey and RefreshKey from the "token.token_key" and
// "token.refresh_key" config entries, generates fresh access and refresh
// UUIDs, and sets the access expiry to two hours and the refresh expiry to
// two weeks from now. No token is signed here.
func NewToken(id int64) *Token {
	TokenKey.Store(configs.Config.GetString("token.token_key"))
	RefreshKey.Store(configs.Config.GetString("token.refresh_key"))

	meta := configs.Token{
		AccessUuid:       uuid.NewString(),
		RefreshUuid:      uuid.NewString(),
		AtExpires:        time.Now().Add(time.Hour * 2).Unix(),
		RefreshAtExpires: time.Now().Add(time.Hour * 24 * 7 * 2).Unix(),
	}

	result := new(Token)
	result.UserId = id
	result.Token = meta
	return result
}

建立基本Token

// CreateToken signs an HS256 access JWT for t and records the result on t.
//
// The claims carry the access UUID, the access expiry, the token type
// "access", the user id and the issue time. After the call t.AccessToken holds
// the signed string and t.AccessId the access UUID embedded in it. A signing
// failure is reported through tools.HandelError.
//
// NOTE(review): the "jti" claim is set to the user id, which is shared by
// every token of that user; RFC 7519 describes "jti" as a unique identifier
// for the token itself. It is left unchanged because consumers of the claim
// are not visible here — consider using the access UUID instead.
func (t *Token) CreateToken() {
	claims := jwt.MapClaims{
		"access_id": t.Token.AccessUuid,
		"exp":       t.Token.AtExpires,
		"type":      "access",
		"userId":    t.UserId,
		"jti":       t.UserId,
		"iat":       time.Now().Unix(),
	}

	// Use a function-local error instead of the package-level err so that
	// concurrent CreateToken calls do not race on shared state.
	var err error
	tk := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	t.AccessToken, err = tk.SignedString([]byte(TokenKey.Load().(string)))
	tools.HandelError("create token error", err)
	t.AccessId = t.Token.AccessUuid
}

測試Token

package token
 
import (
	"github.com/golang-jwt/jwt/v5"
	"github.com/peterouob/golang_template/configs"
	"github.com/peterouob/golang_template/pkg/verify"
	"github.com/peterouob/golang_template/tools"
	"github.com/stretchr/testify/assert"
	"testing"
	"time"
)
 
// TestNewToken checks that NewToken records the user id, generates both
// UUIDs, and sets both expiry timestamps in the future.
func TestNewToken(t *testing.T) {
	tools.InitLogger()
	configs.InitViper()

	const wantID int64 = 123
	got := verify.NewToken(wantID)

	assert.NotNil(t, got)
	assert.Equal(t, wantID, got.UserId)

	meta := got.Token
	assert.NotEmpty(t, meta.AccessUuid)
	assert.NotEmpty(t, meta.RefreshUuid)
	assert.Greater(t, meta.AtExpires, time.Now().Unix())
	assert.Greater(t, meta.RefreshAtExpires, time.Now().Unix())

	t.Logf("token: %v", got)
	t.Logf("token.Token: %v", meta)
}
 
// TestCreateToken checks that CreateToken fills in AccessToken and that the
// result parses and validates with the configured token key.
func TestCreateToken(t *testing.T) {
	tools.InitLogger()
	configs.InitViper()
	userId := int64(123)
	token := verify.NewToken(userId)
	// testify takes (expected, actual); keep that order so failure output
	// labels the values correctly.
	assert.Equal(t, "", token.AccessToken)
	token.CreateToken()
	assert.NotEqual(t, "", token.AccessToken)
	t.Log(verify.TokenKey.Load().(string))

	// The key func parameter is named tk so it does not shadow the outer token.
	parsed, err := jwt.Parse(token.AccessToken, func(tk *jwt.Token) (interface{}, error) {
		return []byte(verify.TokenKey.Load().(string)), nil
	})
	assert.NoError(t, err)
	if assert.NotNil(t, parsed) {
		assert.True(t, parsed.Valid)
	}
}

將Token運用在Grpc服務上

新增需要Token的service

// TokenTestRequest is the (empty) request of the TokenTest RPC.
message TokenTestRequest {}
 
// TokenTestResponse is the reply of the TokenTest RPC.
message TokenTestResponse {
  // msg is the human-readable reply text.
  string msg = 1;
}
 
// Token exposes an RPC used to exercise token verification.
// NOTE(review): the Go snippets in this document register and call this RPC
// through the generated User service types (RegisterUserServer,
// NewUserClient); confirm which service TokenTest actually belongs to.
service Token {
    rpc TokenTest(TokenTestRequest) returns (TokenTestResponse);
}

編寫Server

package grpcserver
 
import (
	"context"
	"fmt"
	"github.com/peterouob/golang_template/api/protobuf"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)
 
// TokenTestServer implements the TokenTest RPC. It embeds
// protobuf.UnimplementedUserServer for the remaining methods.
type TokenTestServer struct {
	protobuf.UnimplementedUserServer
}
 
// NewTokenTestServer returns a pointer to a zero-valued TokenTestServer.
func NewTokenTestServer() *TokenTestServer {
	return new(TokenTestServer)
}
 
// TokenTest replies with a greeting that embeds a user id. In this first
// version the id is a fixed placeholder; neither ctx nor in is read.
func (t TokenTestServer) TokenTest(ctx context.Context, in *protobuf.TokenTestRequest) (*protobuf.TokenTestResponse, error) {
	const placeholderID = 123

	reply := new(protobuf.TokenTestResponse)
	reply.Msg = fmt.Sprintf("This is Token Test! your id is :%d", placeholderID)
	return reply, nil
}

將服務開在8086

package server
 
import (
	"fmt"
	"github.com/peterouob/golang_template/api/protobuf"
	"github.com/peterouob/golang_template/pkg/grpc_service/interceptors"
	grpcserver "github.com/peterouob/golang_template/pkg/grpc_service/server"
	"github.com/peterouob/golang_template/tools"
	"google.golang.org/grpc"
	"net"
)
 
// InitTokenServer starts the token-test gRPC server on the local IP at port
// 8086 and serves until Serve returns. Listen and serve failures are reported
// through tools.HandelError.
func InitTokenServer() {
	// tokenServerPort is the TCP port the token-test server listens on.
	const tokenServerPort = 8086

	tools.Log("start grpc token server ...")
	localIp := tools.GetLocalIP()
	lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", localIp, tokenServerPort))
	tools.HandelError("listen errors ", err)

	s := grpc.NewServer()

	tokenServer := grpcserver.NewTokenTestServer()
	protobuf.RegisterUserServer(s, tokenServer)
	err = s.Serve(lis)
	// The message names the token server; it previously said "login", which
	// pointed at the wrong service when reading logs.
	tools.HandelError("token serve errors ", err)
}

編寫Interceptor驗證Token的存在

  • 這邊由於是單次請求/回應(Unary)的驗證,因此使用UnaryServerInterceptor而非StreamServerInterceptor

Server端接收方法

md, ok := metadata.FromIncomingContext(ctx)
  • 跟蹤源碼如下
// FromIncomingContext returns a copy of the incoming metadata stored in ctx,
// with every key lower-cased. The boolean is false when ctx carries no
// incoming metadata.
func FromIncomingContext(ctx context.Context) (MD, bool) {
  md, ok := ctx.Value(mdIncomingKey{}).(MD)
  if !ok {
  	return nil, false
  }
  out := make(MD, len(md))
  for k, v := range md {
    // Keys are normalised to lower case; values are copied via copyOf.
    key := strings.ToLower(k)
  	out[key] = copyOf(v)
  }
    return out, true
}
  • 發現它接收Context並從中解析取值;返回的MD為一個Map,其key會被轉為小寫(ToLower),value則為拷貝後的副本

Client端傳送方法

    ctx := metadata.AppendToOutgoingContext(context.Background(), "authorization", fmt.Sprintf("Bearer %s", token))
  • 藉由源碼跟蹤可以發現第一個為context,接下來為傳入kv且回傳context
func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context 

解析Metadata以獲得TokenString

  • 從Client端和Server端觀察後發現只需要簡單的從Map中取值即可
    authHeader, ok := md["authorization"]
  • 此時的authHeader內第一個值會是剛剛藉由Client端傳入的
    fmt.Sprintf("Bearer %s",token)
  • 因此簡單對他做處理後即可
    parts := strings.Split(authHeader[0]," ")
    parts[0] == "Bearer"
    parts[1] == token

驗證傳入Token

// VerifyToken parses tokenString as an HMAC-signed JWT using the key held in
// TokenKey and logs the outcome.
//
// The token returned by jwt.Parse is passed back unchanged. It may be nil
// (when the input cannot be parsed at all) or have Valid == false, so callers
// must check `token != nil && token.Valid` before trusting its claims. Any
// parse error is forwarded to tools.HandelError.
func VerifyToken(tokenString string) *jwt.Token {
	token, err := jwt.Parse(tokenString, func(tk *jwt.Token) (interface{}, error) {
		// Refuse to hand out the key unless the token declares an HMAC
		// algorithm; returning an error here makes jwt.Parse fail instead of
		// verifying with an unexpected algorithm.
		if _, ok := tk.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", tk.Header["alg"])
		}
		return []byte(TokenKey.Load().(string)), nil
	})
	tools.HandelError("parse token error", err)
	// TODO: count the failures and report them to a Prometheus counter.
	switch {
	// token can be nil when parsing failed early, so guard before reading Valid.
	case token != nil && token.Valid:
		tools.Log("valid success token")
	case errors.Is(err, jwt.ErrTokenMalformed):
		tools.Log("error in Malformed token type")
	case errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet):
		tools.Log("error in token expired")
	default:
		tools.HandelError("couldn't handle this token", err)
	}
	return token
}

獲取userId並傳入Context

  • 先轉為float64的原因是:claims經由JSON解析成interface{}時,數字會被解為float64,因此需先斷言為float64再轉成int64
// Read the "userId" claim as float64 and convert it to int64.
// NOTE(review): both type assertions are unchecked and will panic if Claims
// is not a jwt.MapClaims or if "userId" is missing or not a float64.
uid := int64(token.Claims.(jwt.MapClaims)["userId"].(float64))
// Expose the user id to downstream handlers under the string key "uid".
ctx = context.WithValue(ctx, "uid", uid)
return handler(ctx, req)

註冊Interceptors

auth := interceptors.NewTokenInterceptor()
s := grpc.NewServer(grpc.UnaryInterceptor(auth.UnaryServerInterceptor()))

修改TokenTest的Server端以獲得UserId

// TokenTest replies with a greeting that embeds the caller's user id, read
// from the context value stored under "uid". If that value is absent or is not
// an int64, it returns an InvalidArgument status error.
func (t TokenTestServer) TokenTest(ctx context.Context, in *protobuf.TokenTestRequest) (*protobuf.TokenTestResponse, error) {
	switch uid := ctx.Value("uid").(type) {
	case int64:
		greeting := fmt.Sprintf("This is Token Test! your id is :%d", uid)
		return &protobuf.TokenTestResponse{Msg: greeting}, nil
	default:
		return nil, status.Error(codes.InvalidArgument, "userId not found in context")
	}
}

測試Interceptors

// TestLoginServer logs in against the server on :8085 and then uses the
// returned access token to call the token-protected RPC via testToken.
func TestLoginServer(t *testing.T) {
	conn, err := grpc.NewClient(":8085", grpc.WithTransportCredentials(insecure.NewCredentials()))
	// assert.NoError does not stop the test, so bail out explicitly rather
	// than using a nil conn below.
	if !assert.NoError(t, err) {
		return
	}
	defer func() {
		assert.NoError(t, conn.Close())
	}()

	c := protobuf.NewUserClient(conn)
	ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
	defer cancel()

	r, err := c.LoginUser(ctx, &protobuf.LoginUserRequest{
		Email:    "admin",
		Password: "admin",
	})
	// r is not usable when the call failed; stop here instead of
	// dereferencing it.
	if !assert.NoError(t, err) {
		return
	}
	t.Logf("Access Token :%s", r.AccessToken)
	t.Logf("Refresh Token :%s", r.RefreshToken)
	testToken(t, r.AccessToken)
}
 
// testToken calls the TokenTest RPC on :8086, sending token as a Bearer
// credential in the "authorization" metadata entry, and asserts that the call
// succeeds.
func testToken(t *testing.T, token string) {
	// Report failures at the caller's line.
	t.Helper()

	conn, err := grpc.NewClient(":8086", grpc.WithTransportCredentials(insecure.NewCredentials()))
	// assert.NoError does not stop the test, so bail out explicitly rather
	// than using a nil conn below.
	if !assert.NoError(t, err) {
		return
	}
	defer func() {
		assert.NoError(t, conn.Close())
	}()
	c := protobuf.NewUserClient(conn)
	ctx := metadata.AppendToOutgoingContext(context.Background(), "authorization", fmt.Sprintf("Bearer %s", token))
	_, err = c.TokenTest(ctx, &protobuf.TokenTestRequest{})
	assert.NoError(t, err)
}