Bootstrap

【Go】-基于Gin框架的博客项目

目录

项目分析

项目分层

初始化

用户模块

注册

登录

社区模块

所有社区

指定社区

帖子模块

顺序获取帖子

获取指定帖子

投票模块

发帖

投票

项目开发及部署

开发中使用air

makefile的编写

docker

总结


项目分析

基于Gin框架的IM即时通讯小demo,实现了用户注册,添加好友,创建和加入群聊;好友聊天,群聊天;可以自定义个人信息,群信息;在聊天中可以发送文字、图片、表情、语音。

        项目地址:knoci/GinBlog


项目分层

        config:配置文件

        static:前端静态资源

        docs:swagger

        setting:负责配置文件的读取

        template:html模板

        router:路由层

        controller:控制层,负责服务的转发

        logic:逻辑层,具体功能的实现

        dao:数据层,负责数据库和存储相关

        pkg:第三方库

        models:模板层,结构的定义

        middlewares:中间件

        logger:日志相关


初始化

        main函数如下:

package main

import (
	"GinBlog/controller"
	"GinBlog/dao/mysql"
	"GinBlog/dao/redis"
	"GinBlog/logger"
	"GinBlog/pkg/snowflake"
	"GinBlog/router"
	"GinBlog/setting"
	"fmt"
	"os"
)

// @title GinBlog项目接口文档
// @version 1.0
// @description Go web博客项目

// @contact.name knoci
// @contact.email [email protected]

// @host 127.0.0.1:8808
// @BasePath /api/v1

func main() {
	// 1.读取配置
	if len(os.Args) < 2 {
		fmt.Println("need config file.eg: GinBlog config.yaml")
		return
	}
	config_set := os.Args[1]
	if lenth := len(config_set); config_set[lenth-4:] == ".exe" {
		config_set = config_set[0 : lenth-4]
	}
	err := setting.Init(config_set)
	if err != nil {
		fmt.Printf("load setting failed: %v", err)
		return
	}
	// 2.加载日志
	err = logger.Init(setting.Conf.LogConfig, setting.Conf.Mode)
	if err != nil {
		fmt.Printf("load logger failed: %v", err)
		return
	}
	// 3.配置mysql
	err = mysql.Init(setting.Conf.MySQLConfig)
	if err != nil {
		fmt.Printf("init mysql failed: %v", err)
		return
	}
	defer mysql.Close()
	// 4.配置redis
	err = redis.Init(setting.Conf.RedisConfig)
	if err != nil {
		fmt.Println("init redis failed: %v", err)
		return
	}
	defer redis.Close()
	// 5.获取路由并运行服务
	if err := snowflake.Init(setting.Conf.StartTime, setting.Conf.MachineID); err != nil {
		fmt.Printf("init snowflake failed: %v", err)
		return
	}

	if err := controller.InitTrans("zh"); err != nil {
		fmt.Println("init trans failed: %v", err)
		return
	}

	r := router.InitRouter(setting.Conf.Mode)
	err = r.Run(fmt.Sprintf(":%d", setting.Conf.Port))
	if err != nil {
		fmt.Printf("run server failed: %v", err)
		return
	}
}

        一开始进来用命令行参数指令读取配置,然后去到setting下的Init()函数初始化配置,setting中还定义了嵌套结构体,用这种方法确保拿到我们每个程序需要的参数。其中WatchConfig()函数和OnConfigChange()来监视config变化实现实时更新。

        需要注意的是vipeReadInConfig()函数要读入参数到结构体,一定要打上mapstructure的tag将map[string]interface{} 类型的数据解码到 Go 的结构体中。

package setting

import (
	"fmt"

	"github.com/fsnotify/fsnotify"
	"github.com/spf13/viper"
)

var Conf = new(AppConfig)

type AppConfig struct {
	Name         string `mapstructure:"name"`
	Mode         string `mapstructure:"mode"`
	Version      string `mapstructure:"version"`
	StartTime    string `mapstructure:"start_time"`
	MachineID    int64  `mapstructure:"machine_id"`
	Port         int    `mapstructure:"port"`
	*LogConfig   `mapstructure:"log"`
	*MySQLConfig `mapstructure:"mysql"`
	*RedisConfig `mapstructure:"redis"`
}

type LogConfig struct {
	Level      string `mapstructure:"level"`
	Filename   string `mapstructure:"filename"`
	MaxSize    int    `mapstructure:"max_size"`
	MaxAge     int    `mapstructure:"max_age"`
	MaxBackups int    `mapstructure:"max_backups"`
}

type MySQLConfig struct {
	User         string `mapstructure:"user"`
	Host         string `mapstructure:"host"`
	Password     string `mapstructure:"password"`
	DbName       string `mapstructure:"dbname"`
	Port         int    `mapstructure:"port"`
	MaxOpenConns int    `mapstructure:"max_open_conns"`
	MaxIdleConns int    `mapstructure:"max_idle_conns"`
}

type RedisConfig struct {
	Host         string `mapstructure:"name"`
	Password     string `mapstructure:"password"`
	Port         int    `mapstructure:"port"`
	DB           int    `mapstructure:"db"`
	PoolSize     int    `mapstructure:"pool_size"`
	MinIdleConns int    `mapstructure:"min_idle_conns"`
}

func Init(filepath string) (err error) {
	viper.SetConfigFile(filepath)
	err = viper.ReadInConfig()
	if err != nil {
		fmt.Print("Error loading config...")
		return
	}
	err = viper.Unmarshal(Conf)
	if err != nil {
		fmt.Print("Error unmarshalling config...")
	}
	viper.WatchConfig()
	viper.OnConfigChange(func(in fsnotify.Event) {
		err = viper.Unmarshal(Conf)
		if err != nil {
			fmt.Print("Error unmarshalling config while config changing...")
			return
		}
	})
	return
}

        把LogConfig参数给到logger的Init启动日志,日志这里用了Zap库,创建了全局替换,修改了gin的日志中间件,把日志记录到zap中。

package logger

import (
	"GinBlog/setting"
	"net"
	"net/http"
	"net/http/httputil"
	"os"
	"runtime/debug"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/natefinch/lumberjack" // lumberjack 是一个简单的日志文件滚动库
	"go.uber.org/zap"                 // zap 是一个快速的、结构化的、可靠的 Go 日志库
	"go.uber.org/zap/zapcore"         // zap 的核心包
)

// lg 是一个全局的 zap.Logger 实例
var lg *zap.Logger

// Init 初始化日志配置
func Init(cfg *setting.LogConfig, mode string) (err error) {
	// 获取日志写入器
	writeSyncer := getLogWriter(cfg.Filename, cfg.MaxSize, cfg.MaxBackups, cfg.MaxAge)
	// 获取编码器
	encoder := getEncoder()

	// 解析日志级别
	var l = new(zapcore.Level)
	err = l.UnmarshalText([]byte(cfg.Level))
	if err != nil {
		return // 如果解析失败,返回错误
	}

	// 创建 zap 核心对象
	var core zapcore.Core
	if mode == "dev" {
		// 如果是开发模式,日志同时输出到终端和文件
		consoleEncoder := zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig())
		core = zapcore.NewTee( // Tee 表示同时写入多个 Writer
			//zapcore.NewCore(encoder, writeSyncer, l),
			zapcore.NewCore(consoleEncoder, zapcore.Lock(os.Stdout), zapcore.DebugLevel),
		)
	} else {
		// 生产模式,只输出到文件
		core = zapcore.NewCore(encoder, writeSyncer, l)
	}

	// 创建 zap 日志对象
	lg = zap.New(core, zap.AddCaller()) // 添加调用者信息
	// 替换全局日志对象
	zap.ReplaceGlobals(lg)
	// 记录初始化日志
	zap.L().Info("init logger success")
	return
}

// getEncoder 创建并配置日志编码器
func getEncoder() zapcore.Encoder {
	// 使用生产环境的编码器配置
	encoderConfig := zap.NewProductionEncoderConfig()
	// 设置时间编码器
	encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
	// 设置时间字段名称
	encoderConfig.TimeKey = "time"
	// 设置日志级别编码器
	encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
	// 设置持续时间编码器
	encoderConfig.EncodeDuration = zapcore.SecondsDurationEncoder
	// 设置调用者编码器
	encoderConfig.EncodeCaller = zapcore.ShortCallerEncoder
	// 创建 JSON 编码器
	return zapcore.NewJSONEncoder(encoderConfig)
}

// getLogWriter 创建并配置日志写入器
func getLogWriter(filename string, maxSize, maxBackup, maxAge int) zapcore.WriteSyncer {
	// 使用 lumberjack 作为日志文件的写入器
	lumberJackLogger := &lumberjack.Logger{
		Filename:   filename,  // 日志文件路径
		MaxSize:    maxSize,   // 文件最大大小
		MaxBackups: maxBackup, // 最多备份文件数量
		MaxAge:     maxAge,    // 文件最长保存天数
	}
	// 将 lumberjack 写入器包装为 zap 写入器
	return zapcore.AddSync(lumberJackLogger)
}

// GinLogger 是一个 Gin 中间件,用于记录 HTTP 请求日志
func GinLogger() gin.HandlerFunc {
	return func(c *gin.Context) {
		// 记录请求开始时间
		start := time.Now()
		// 获取请求路径和查询字符串
		path := c.Request.URL.Path
		query := c.Request.URL.RawQuery
		// 继续处理请求
		c.Next()

		// 计算请求处理时间
		cost := time.Since(start)
		// 记录日志
		lg.Info(path,
			zap.Int("status", c.Writer.Status()),
			zap.String("method", c.Request.Method),
			zap.String("path", path),
			zap.String("query", query),
			zap.String("ip", c.ClientIP()),
			zap.String("user-agent", c.Request.UserAgent()),
			zap.String("errors", c.Errors.ByType(gin.ErrorTypePrivate).String()),
			zap.Duration("cost", cost),
		)
	}
}

// GinRecovery 是一个 Gin 中间件,用于捕获和记录 panic
func GinRecovery(stack bool) gin.HandlerFunc {
	return func(c *gin.Context) {
		// 使用 defer 延迟执行 panic 恢复
		defer func() {
			if err := recover(); err != nil {
				// 检查是否是连接断开导致的错误
				var brokenPipe bool
				if ne, ok := err.(*net.OpError); ok {
					if se, ok := ne.Err.(*os.SyscallError); ok {
						if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
							brokenPipe = true
						}
					}
				}

				// 记录请求信息
				httpRequest, _ := httputil.DumpRequest(c.Request, false)
				if brokenPipe {
					// 如果是连接断开,记录错误日志
					lg.Error(c.Request.URL.Path,
						zap.Any("error", err),
						zap.String("request", string(httpRequest)),
					)
					// 如果连接已断开,记录错误并终止请求
					c.Error(err.(error)) // nolint: errcheck
					c.Abort()
					return
				}

				// 如果需要记录 stack trace
				if stack {
					lg.Error("[Recovery from panic]",
						zap.Any("error", err),
						zap.String("request", string(httpRequest)),
						zap.String("stack", string(debug.Stack())),
					)
				} else {
					lg.Error("[Recovery from panic]",
						zap.Any("error", err),
						zap.String("request", string(httpRequest)),
					)
				}
				// 设置 HTTP 状态码并终止请求
				c.AbortWithStatus(http.StatusInternalServerError)
			}
		}()
		// 继续处理请求
		c.Next()
	}
}

        然后是配置mysql,用的是sqlx库,这里封装一个Close()函数允许外部关闭。

package mysql

import (
	"GinBlog/setting"
	"fmt"

	_ "github.com/go-sql-driver/mysql"
	"github.com/jmoiron/sqlx"
)

var db *sqlx.DB

// Init 初始化MySQL连接
func Init(cfg *setting.MySQLConfig) (err error) {
	// "user:password@tcp(host:port)/dbname"
	fmt.Println(cfg.Host)
	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true&loc=Local", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DbName)
	db, err = sqlx.Connect("mysql", dsn)
	if err != nil {
		return
	}
	db.SetMaxOpenConns(cfg.MaxOpenConns)
	db.SetMaxIdleConns(cfg.MaxIdleConns)
	return
}

// Close 关闭MySQL连接
func Close() {
	_ = db.Close()
}

        接着是redis,用的是go-redis。

package redis

import (
	"GinBlog/setting"
	"fmt"

	"github.com/go-redis/redis/v8"
)

var (
	client *redis.Client
	Nil    = redis.Nil
)

func Init(cfg *setting.RedisConfig) (err error) {
	client = redis.NewClient(&redis.Options{
		Addr:         fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
		Password:     cfg.Password, // no password set
		DB:           cfg.DB,       // use default DB
		PoolSize:     cfg.PoolSize,
		MinIdleConns: cfg.MinIdleConns,
	})
	ctx := client.Context()
	_, err = client.Ping(ctx).Result()
	if err != nil {
		return err
	}
	return nil
}

func Close() {
	_ = client.Close()
}

        之后用pkg的snowflake库,初始化了一个node,GenID()接收node用雪花算法生个一个id。

package snowflake

import (
	"time"

	sf "github.com/bwmarrin/snowflake"
)

var node *sf.Node

func Init(startTime string, machineID int64) (err error) {
	var st time.Time
	st, err = time.Parse("2006-01-02", startTime)
	if err != nil {
		return
	}
	sf.Epoch = st.UnixNano() / 1000000
	node, err = sf.NewNode(machineID)
	return
}
func GenID() int64 {
	return node.Generate().Int64()
}

        还进行了一下validator的设置,这是一个参数校验的库。

package controller

import (
	"GinBlog/models"
	"fmt"
	"github.com/go-playground/validator/v10"
	"reflect"
	"strings"

	"github.com/gin-gonic/gin/binding"
	"github.com/go-playground/locales/en"
	"github.com/go-playground/locales/zh"
	ut "github.com/go-playground/universal-translator"
	enTranslations "github.com/go-playground/validator/v10/translations/en"
	zhTranslations "github.com/go-playground/validator/v10/translations/zh"
)

// 定义一个全局翻译器T
var trans ut.Translator

// InitTrans 初始化翻译器
func InitTrans(locale string) (err error) {
	// 修改gin框架中的Validator引擎属性,实现自定制
	if v, ok := binding.Validator.Engine().(*validator.Validate); ok {

		// 注册一个获取json tag的自定义方法
		v.RegisterTagNameFunc(func(fld reflect.StructField) string {
			name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
			if name == "-" {
				return ""
			}
			return name
		})

		// 为SignUpParam注册自定义校验方法
		v.RegisterStructValidation(SignUpParamStructLevelValidation, models.ParamSignUp{})

		zhT := zh.New() // 中文翻译器
		enT := en.New() // 英文翻译器

		// 第一个参数是备用(fallback)的语言环境
		// 后面的参数是应该支持的语言环境(支持多个)
		// uni := ut.New(zhT, zhT) 也是可以的
		uni := ut.New(enT, zhT, enT)

		// locale 通常取决于 http 请求头的 'Accept-Language'
		var ok bool
		// 也可以使用 uni.FindTranslator(...) 传入多个locale进行查找
		trans, ok = uni.GetTranslator(locale)
		if !ok {
			return fmt.Errorf("uni.GetTranslator(%s) failed", locale)
		}

		// 注册翻译器
		switch locale {
		case "en":
			err = enTranslations.RegisterDefaultTranslations(v, trans)
		case "zh":
			err = zhTranslations.RegisterDefaultTranslations(v, trans)
		default:
			err = enTranslations.RegisterDefaultTranslations(v, trans)
		}
		return
	}
	return
}

// removeTopStruct 去除提示信息中的结构体名称
func removeTopStruct(fields map[string]string) map[string]string {
	res := map[string]string{}
	for field, err := range fields {
		res[field[strings.Index(field, ".")+1:]] = err
	}
	return res
}

// SignUpParamStructLevelValidation 自定义SignUpParam结构体校验函数
func SignUpParamStructLevelValidation(sl validator.StructLevel) {
	su := sl.Current().Interface().(models.ParamSignUp)

	if su.Password != su.RePassword {
		// 输出错误提示信息,最后一个参数就是传递的param
		sl.ReportError(su.RePassword, "re_password", "RePassword", "eqfield", "password")
	}
}

        终于设置完配置了,然后通过InitRouter()去注册路由了。

package router

import (
	"GinBlog/controller"
	"GinBlog/docs"
	"GinBlog/logger"
	"GinBlog/middlewares"
	"github.com/gin-gonic/gin"
	swaggerFiles "github.com/swaggo/files"
	ginSwagger "github.com/swaggo/gin-swagger"
	"net/http"
)

func InitRouter(mode string) (r *gin.Engine) {
	if mode == gin.ReleaseMode {
		gin.SetMode(gin.ReleaseMode) // gin设置成发布模式
	}
	r = gin.New()
	//r.Use(logger.GinLogger(), logger.GinRecovery(true), middlewares.RateLimitMiddleware(1*time.Second, 1))
	r.Use(logger.GinLogger(), logger.GinRecovery(true))
	// 注册swagger api相关路由
	docs.SwaggerInfo.BasePath = ""
	r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))

	r.LoadHTMLFiles("./templates/index.html")
	r.Static("/static", "./static")

	r.GET("/", func(c *gin.Context) {
		c.HTML(http.StatusOK, "index.html", nil)
	})

	r.GET("/ping", func(c *gin.Context) {
		c.String(http.StatusOK, "pong")
	})

	v1 := r.Group("/api/v1")

	// 注册
	v1.POST("/signup", controller.SignUpHandler)
	// 登录
	v1.POST("/login", controller.LoginHandler)

	// 根据时间或分数获取帖子列表
	v1.GET("/posts2", controller.GetPostListHandler2)
	//v1.GET("/posts", controller.GetPostListHandler)
	v1.GET("/community", controller.CommunityHandler)
	v1.GET("/community/:id", controller.CommunityDetailHandler)
	v1.GET("/post/:id", controller.GetPostDetailHandler)

	v1.Use(middlewares.JWTAuthMiddleware()) // 应用JWT认证中间件
	{
		v1.POST("/post", controller.CreatePostHandler)

		// 投票
		v1.POST("/vote", controller.PostVoteController)
	}
	return
}

用户模块

        用到的参数模板,即models下的params.go如下:

package models

type ParamSignUp struct {
	Username   string `json:"username" binding:"required"`
	Password   string `json:"password" binding:"required"`
	RePassword string `json:"re_password" binding:"required,eqfield=Password"`
}

type ParamLogin struct {
	Username string `json:"username" binding:"required"`
	Password string `json:"password" binding:"required"`
}

type ParamVoteData struct {
	PostID string `json:"post_id" binding:"required"`
	// 赞成1反对-1取消投票0,validator的oneof限定
	Direction int8 `json:"direction,string" binding:"oneof=1 0 -1" `
}

type ParamPostList struct {
	CommunityID int64  `json:"community_id" form:"community_id"`   // 可以为空
	Page int64 `form:"page"`
	Size int64 `form:"size"`
	Order string `form:"order"`
}

type ParamCommunityPostList struct {

}

         返回状态的封装,在controller的response.go:

package controller

import (
	"net/http"

	"github.com/gin-gonic/gin"
)

/*
	{
		"code":
		"msg":
		"data":
	}
*/
type ResCode int64

const (
	CodeSuccess ResCode = 1000 + iota
	CodeInvalidParam
	CodeUserExist
	CodeUserNotExist
	CodeInvalidPassword
	CodeServerBusy
	CodeNeedLogin
	CodeInvalidToken
)

var getMsg = map[ResCode]string{
	CodeSuccess:         "success",
	CodeInvalidParam:    "请求参数错误",
	CodeUserExist:       "用户已存在",
	CodeUserNotExist:    "用户不存在",
	CodeInvalidPassword: "用户名或密码错误",
	CodeServerBusy:      "服务繁忙",
	CodeNeedLogin:       "需要登录",
	CodeInvalidToken:    "无效的token",
}

type Response struct {
	Code ResCode     `json:"code"`
	Msg  interface{} `json:"msg"`
	Data interface{} `json:"data,omitempty"` // 为空忽略
}

func ResponseSuccess(c *gin.Context, data interface{}) {
	rd := &Response{
		Code: CodeSuccess,
		Msg:  getMsg[CodeSuccess],
		Data: data,
	}
	c.JSON(http.StatusOK, rd)
}

func ResponseError(c *gin.Context, code ResCode) {
	rd := &Response{
		Code: code,
		Msg:  getMsg[code],
		Data: nil,
	}
	c.JSON(http.StatusOK, rd)
}

func ResponseErrorWithMsg(c *gin.Context, code ResCode, msg interface{}) {
	rd := &Response{
		Code: code,
		Msg:  msg,
		Data: nil,
	}
	c.JSON(http.StatusOK, rd)
}

注册

        ShouldBindJSON解析Post传来的JSON自动绑定到参数结构体上,在controller层验证一下参数是否正确,然后转到逻辑层处理具体注册逻辑

func SignUpHandler(c *gin.Context) {
	// 参数校验
	var p = new(models.ParamSignUp)
	if err := c.ShouldBindJSON(p); err != nil {
		zap.L().Error("SignUp with invalid param", zap.Error(err))
		// 判断err是不是validator.ValidationErrors类型
		fmt.Println(p.Password)
		fmt.Println(p.RePassword)
		errs, ok := err.(validator.ValidationErrors)
		if !ok {
			// 非validator.ValidationErrors类型错误直接返回
			ResponseError(c, CodeInvalidParam)
			return
		}
		ResponseErrorWithMsg(c, CodeInvalidParam, removeTopStruct(errs.Translate(trans)))
		return
	}
	// 业务处理
	if err := logic.SignUp(p); err != nil {
		zap.L().Error("logic.SignUp failed", zap.Error(err))
		if errors.Is(err, mysql.ErrUserExist) {
			ResponseError(c, CodeUserExist)
		}
		ResponseError(c, CodeServerBusy)
		return
	}
	// 返回响应
	ResponseSuccess(c, nil)
}

        在logic的SignUp中,判断是否合法,然后存入数据库。

func SignUp(p *models.ParamSignUp) (err error) {
	// 合法性判断
	if err := mysql.CheckUserExist(p.Username); err != nil {
		return err
	}
	// 保存到数据库
	user := models.User{
		UserID:   snowflake.GenID(),
		Username: p.Username,
		Password: p.Password,
	}
	err = mysql.InsertUser(&user)
	if err != nil {
		return err
	}
	return nil
}

        数据库相关的函数封装在dao的mysql,这里密码用到了md5加密。

package mysql

import (
	"GinBlog/models"
	"crypto/md5"
	"database/sql"
	"encoding/hex"
)

const salt = "knoci1337"


func InsertUser(user *models.User) (err error) {
	// 密码加密
	user.Password = encryptPassword(user.Password)
	sqlStr := `insert into user(user_id, username, password) value(?,?,?)`
	_, err = db.Exec(sqlStr, user.UserID, user.Username, user.Password)
	if err != nil {
		return
	}
	return nil
}

func CheckUserExist(username string) (err error) {
	sqlStr := `select count(user_id) from user where username = ?`
	var count int
	if err = db.Get(&count, sqlStr, username); err != nil {
		return err
	}
	if count > 0 {
		return ErrUserExist
	}
	return nil
}

func encryptPassword(password string) string {
	h := md5.New()
	h.Write([]byte(salt))
	return hex.EncodeToString(h.Sum([]byte(password)))
}

func Login(user *models.User) (err error) {
	oPassword := user.Password
	sqlStr := `select user_id, username , password from user where username = ?`
	err = db.Get(user, sqlStr, user.Username)
	if err == sql.ErrNoRows {
		return ErrUserNotExist
	}
	if err != nil {
		// 查询数据库失败
		return err
	}
	password := encryptPassword(oPassword)
	if password != user.Password {
		return ErrInvalidPassword
	}
	return
}

// GetUserById 根据id获取用户信息
func GetUserById(uid int64) (user *models.User, err error) {
	user = new(models.User)
	sqlStr := `select user_id, username from user where user_id = ?`
	err = db.Get(user, sqlStr, uid)
	return
}

登录

        controller层

func LoginHandler(c *gin.Context) {
	// 参数校验
	p := new(models.ParamLogin)
	if err := c.ShouldBindJSON(p); err != nil {
		zap.L().Error("Login with invalid param", zap.Error(err))
		errs, ok := err.(validator.ValidationErrors)
		if !ok {
			ResponseError(c, CodeInvalidParam)
			return
		}
		ResponseErrorWithMsg(c, CodeInvalidParam, removeTopStruct(errs.Translate(trans)))
		return
	}
	// 2.业务逻辑处理
	user, err := logic.Login(p)
	if err != nil {
		zap.L().Error("logic.Login failed", zap.String("username", p.Username), zap.Error(err))
		if errors.Is(err, mysql.ErrUserNotExist) {
			ResponseError(c, CodeUserNotExist)
			return
		}
		ResponseError(c, CodeInvalidPassword)
		return
	}

	// 3.返回响应
	ResponseSuccess(c, gin.H{
		"user_id":   fmt.Sprintf("%d", user.UserID), // id值大于1<<53-1  int64类型的最大值是1<<63-1
		"user_name": user.Username,
		"token":     user.Token,
	})
}

        logic层,登录成功后生成jwt的token。

func Login(p *models.ParamLogin) (*models.User,error) {
	user := &models.User{
		Username: p.Username,
		Password: p.Password,
	}
	// 传递的是指针,就能拿到user.UserID
	if err := mysql.Login(user); err != nil {
		return nil, err
	}
	// 生成JWT
	token, err := jwt.GenToken(user.UserID, user.Username)
	if err != nil {
		return nil, err
	}
	user.Token = token
	return user, err
}

        jwt的生成以及解析函数封装,在pkg目录jwt的jwt.go:

package jwt

import (
	"errors"
	"time"

	"github.com/spf13/viper"

	"github.com/dgrijalva/jwt-go"
)

var mySecret = []byte("人生不过一场梦")

// MyClaims 自定义声明结构体并内嵌jwt.StandardClaims
// jwt包自带的jwt.StandardClaims只包含了官方字段
// 我们这里需要额外记录一个username字段,所以要自定义结构体
// 如果想要保存更多信息,都可以添加到这个结构体中
type MyClaims struct {
	UserID   int64  `json:"user_id"`
	Username string `json:"username"`
	jwt.StandardClaims
}

// GenToken 生成JWT
func GenToken(userID int64, username string) (string, error) {
	// 创建一个我们自己的声明的数据
	c := MyClaims{
		userID,
		"username", // 自定义字段
		jwt.StandardClaims{
			ExpiresAt: time.Now().Add(
				time.Duration(viper.GetInt("auth.jwt_expire")) * time.Hour).Unix(), // 过期时间
			Issuer: "GinBlog", // 签发人
		},
	}
	// 使用指定的签名方法创建签名对象
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, c)
	// 使用指定的secret签名并获得完整的编码后的字符串token
	return token.SignedString(mySecret)
}

// ParseToken 解析JWT
func ParseToken(tokenString string) (*MyClaims, error) {
	// 解析token
	var mc = new(MyClaims)
	token, err := jwt.ParseWithClaims(tokenString, mc, func(token *jwt.Token) (i interface{}, err error) {
		return mySecret, nil
	})
	if err != nil {
		return nil, err
	}
	if token.Valid { // 校验token
		return mc, nil
	}
	return nil, errors.New("invalid token")
}

社区模块

        models的模板如下:

package models

import "time"

type Community struct {
	ID   int64  `json:"id" db:"community_id"`
	Name string `json:"name" db:"community_name"`
}

type CommunityDetail struct {
	ID   int64  `json:"id" db:"community_id"`
	Name         string    `json:"name" db:"community_name"`
	Introduction string    `json:"introduction" db:"introduction"`
	CreateTime   time.Time `json:"create_time" db:"create_time"`
}

所有社区

        v1.GET "/community"是获取所有社区,CommunityHandler转到逻辑层

func CommunityHandler(c *gin.Context) {
	//查询到所有community
	communities, err := logic.GetCommunityList()
	if err != nil {
		zap.L().Error("logic.GetCommunityList() failed", zap.Error(err))
		ResponseError(c, CodeServerBusy)
		return
	}
	ResponseSuccess(c, communities)
}

        logic层转到数据库,在mysql的community.go

func GetCommunityList() (data []*models.Community, err error) {
	//查找到所有的community并且返回
	data, err = mysql.GetCommunityList()
	if err != nil {
		zap.L().Error("logic.GetCommunityList() failed", zap.Error(err))
		return nil, err
	}
	return data, err
}

        数据库里用sql语句直接查询

func GetCommunityList() (data []*models.Community, err error) {
	sqlStr := "select community_id, community_name from community"
	if err := db.Select(&data, sqlStr); err != nil {
		if err == sql.ErrNoRows {
			zap.L().Warn("there is no community in db")
			err = nil
		}
	}
	return
}

指定社区

        Get请求Query一个id查询单个社区,和查询所有大差不差。

func CommunityDetailHandler(c *gin.Context) {
	// 获取社区id
	communityID := c.Param("id")
	id, err := strconv.ParseInt(communityID, 10, 64)
	if err != nil {
		ResponseError(c, CodeInvalidParam)
		return
	}
	data , err := logic.GetCommunityDetail(id)
	if err != nil {
		zap.L().Error("logic.GetCommunityDetail() failed", zap.Error(err))
		ResponseError(c, CodeServerBusy)
		return
	}
	ResponseSuccess(c, data)
}
func GetCommunityDetail(id int64) (*models.CommunityDetail, error) {
	return mysql.GetCommunityDetailByID(id)
}
func GetCommunityDetailByID(id int64) (detail *models.CommunityDetail, err error) {
	detail = new(models.CommunityDetail)
	sqlStr := "select community_id, community_name, introduction, create_time from community where community_id = ?"
	if err := db.Get(detail, sqlStr, id); err != nil {
		if err == sql.ErrNoRows {
			err = ErrInvalidID
		}
	}
	fmt.Println("%v", detail)
	return detail, err
}

帖子模块

        models的模板如下:

package models

import "time"

const (
	OderTime  = "time"
	OderScore = "score"
)

type Post struct {
	ID          int64     `json:"id,string" db:"post_id"`                            // 帖子id
	AuthorID    int64     `json:"author_id" db:"author_id"`                          // 作者id
	CommunityID int64     `json:"community_id" db:"community_id" binding:"required"` // 社区id
	Status      int32     `json:"status" db:"status"`                                // 帖子状态
	Title       string    `json:"title" db:"title" binding:"required"`               // 帖子标题
	Content     string    `json:"content" db:"content" binding:"required"`           // 帖子内容
	CreateTime  time.Time `json:"create_time" db:"create_time"`                      // 帖子创建时间
}

// ApiPostDetail 帖子详情接口的结构体
type ApiPostDetail struct {
	AuthorName       string             `json:"author_name"` // 作者
	VoteNum          int64              `json:"vote_num"`    // 投票数
	*Post                               // 嵌入帖子结构体
	*CommunityDetail `json:"community"` // 嵌入社区信息
}

顺序获取帖子

        post2实现按时间或分数顺序获取帖子,这里默认按时间

func GetPostListHandler2(c *gin.Context) {
	// 获取参数
	p := &models.ParamPostList{
		Page:  1,
		Size:  10,
		Order: models.OderTime,
	}
	if err := c.ShouldBindQuery(p); err != nil {
		zap.L().Error("GetPostListHandler2 with invalid param", zap.Error(err))
		return
	}
	// 获取帖子列表
	data, err := logic.GetPostListNew(p)
	if err != nil {
		zap.L().Error("logic GetPostList() failed", zap.Error(err))
		return
	}
	// 去redis查询id列表
	// 根据id去数据库查询帖子详情信息
	ResponseSuccess(c, data)
}

        转到logic中,由一个GetPostListNew()函数实现判断和分发

func GetPostListNew(p *models.ParamPostList) (data []*models.ApiPostDetail, err error) {
	// 根据请求参数的不同,执行不同的逻辑。
	if p.CommunityID == 0 {
		// 查所有
		data, err = GetPostList2(p)
	} else {
		// 根据社区id查询
		data, err = GetCommunityPostList(p)
	}
	if err != nil {
		zap.L().Error("GetPostListNew failed", zap.Error(err))
		return nil, err
	}
	return
}

        返回的是一个ApiPostDetail类型的切片,包含了坐着,票数,帖子信息,社区信息。

func GetPostList2(p *models.ParamPostList) (data []*models.ApiPostDetail, err error) {
	ids, err := redis.GetPostIDsInOrder(p)
	if err != nil {
		return
	}
	if len(ids) == 0 {
		zap.L().Warn("redis.GetPostIDsInOrder(p) return 0 data")
		return
	}
	posts, err := mysql.GetPostListByIDs(ids)
	if err != nil {
		return
	}
	//提前查好每篇帖子的投票数
	voteData, err := redis.GetPostVoteData(ids)
	if err != nil {
		return
	}
	for idx, post := range posts {
		// 根据作者id查询作者信息
		user, err := mysql.GetUserById(post.AuthorID)
		if err != nil {
			zap.L().Error("mysql.GetUserById(post.AuthorID) failed",
				zap.Int64("author_id", post.AuthorID),
				zap.Error(err))
			continue
		}
		// 根据社区id查询社区详细信息
		community, err := mysql.GetCommunityDetailByID(post.CommunityID)
		if err != nil {
			zap.L().Error("mysql.GetUserById(post.AuthorID) failed",
				zap.Int64("community_id", post.CommunityID),
				zap.Error(err))
			continue
		}
		postDetail := &models.ApiPostDetail{
			AuthorName: user.Username,
			VoteNum:         voteData[idx],
			Post:            post,
			CommunityDetail: community,
		}
		data = append(data, postDetail)
	}
	return
}

func GetCommunityPostList(p *models.ParamPostList) (data []*models.ApiPostDetail, err error) {
	// 2. 去redis查询id列表
	ids, err := redis.GetCommunityPostIDsInOrder(p)
	if err != nil {
		return
	}
	if len(ids) == 0 {
		zap.L().Warn("redis.GetPostIDsInOrder(p) return 0 data")
		return
	}
	zap.L().Debug("GetCommunityPostIDsInOrder", zap.Any("ids", ids))
	// 3. 根据id去MySQL数据库查询帖子详细信息
	// 返回的数据还要按照我给定的id的顺序返回
	posts, err := mysql.GetPostListByIDs(ids)
	if err != nil {
		return
	}
	zap.L().Debug("GetPostList2", zap.Any("posts", posts))
	// 提前查询好每篇帖子的投票数
	voteData, err := redis.GetPostVoteData(ids)
	if err != nil {
		return
	}

	// 将帖子的作者及分区信息查询出来填充到帖子中
	for idx, post := range posts {
		// 根据作者id查询作者信息
		user, err := mysql.GetUserById(post.AuthorID)
		if err != nil {
			zap.L().Error("mysql.GetUserById(post.AuthorID) failed",
				zap.Int64("author_id", post.AuthorID),
				zap.Error(err))
			continue
		}
		// 根据社区id查询社区详细信息
		community, err := mysql.GetCommunityDetailByID(post.CommunityID)
		if err != nil {
			zap.L().Error("mysql.GetUserById(post.AuthorID) failed",
				zap.Int64("community_id", post.CommunityID),
				zap.Error(err))
			continue
		}
		postDetail := &models.ApiPostDetail{
			AuthorName:      user.Username,
			VoteNum:         voteData[idx],
			Post:            post,
			CommunityDetail: community,
		}
		data = append(data, postDetail)
	}
	return
}

        通过redis目录下post.go的GetPostIDsInOrder()获取按时间排序好的帖子id,GetCommunityPostIDsInOrder()按社区获取帖子id,用字符串切片返回。

        GetCommunityPostIDsInOrder()用ZInterStore命令将社区的帖子集合和排序的有序集合进行交集计算,并将结果存储在key中。这里使用"MAX"作为聚合函数,意味着取两个有序集合中的最大分数。

        GetPostVoteData()函数获取每篇帖子投票顺序,使用了pipeline事务确保一致性,ZCount限定范围在1到999票,cmders是Exec结果的切片,遍历用.(*redis.IntCmd)类型断言其是整数Int型结果并获取其Val添加到data切片。

        getRedisKey()是在key.go自己封装的方法,获取key值,getIDsFromKey()是按照page和size分页获取分区的帖子。

func GetCommunityPostIDsInOrder(p *models.ParamPostList) ([]string, error) {
	orderKey := getRedisKey(KeyPostTimeZSet)
	if p.Order == models.OderScore  {
		orderKey = getRedisKey(KeyPostScoreZSet)
	}
	// 使用 zinterstore 把分区的帖子set与帖子分数的 zset 生成一个新的zset
	// 针对新的zset 按之前的逻辑取数据
	// 社区的key
	cKey := getRedisKey(KeyCommunitySetPF + strconv.Itoa(int(p.CommunityID)))
	// 利用缓存key减少zinterstore执行的次数
	key := orderKey + strconv.Itoa(int(p.CommunityID))
	ctx := context.Background()
	if client.Exists(ctx, key).Val() < 1 {
		// 不存在,需要计算
		pipeline := client.Pipeline()
		pipeline.ZInterStore(ctx, key, &redis.ZStore{
			Keys:    []string{cKey, orderKey},
			Aggregate: "MAX",
		}) // zinterstore 计算
		pipeline.Expire(ctx, key, 60*time.Second) // 设置超时时间
		_, err := pipeline.Exec(ctx)
		if err != nil {
			return nil, err
		}
	}
	// 存在的话就直接根据key查询ids
	return getIDsFormKey(key, p.Page, p.Size)
}

func GetPostIDsInOrder(p *models.ParamPostList) ([]string, error) {
	// 获取排序顺序
	key := getRedisKey(KeyPostTimeZSet)
	if p.Order == models.OderScore {
		key = getRedisKey(KeyPostScoreZSet)
	}
	return getIDsFormKey(key, p.Page, p.Size)
}

func GetPostVoteData(ids []string) (data []int64, err error){
	ctx := context.TODO()
	data = make([]int64, 0, len(ids))
	// 使用pipeline一次发送剁掉指令减少RTT
	pipeline := client.Pipeline()
	for _, id := range ids {
		key := getRedisKey(KeyPostVotedZSetPF+id)
		pipeline.ZCount(ctx, key, "1", "1")
	}
	cmders, err := pipeline.Exec(ctx)
	if err != nil {
		return nil, err
	}
	for _, cmder := range cmders {
		v := cmder.(*redis.IntCmd).Val()
		data = append(data, v)
	}
	return
}

func getIDsFormKey(key string, page, size int64) ([]string, error) {
	// 确定所有起点
	start := (page -1) * size
	end := start + size -1
	// 查询,按分数从大到小
	ctx := context.TODO()
	return  client.ZRevRange(ctx, key, start, end).Result()
}
package redis

// redis key注意使用命名空间方式方便查询和拆分
const (
	Prefix = "ginblog:"
	KeyPostTimeZSet = "post:time" // Zset 发帖时间为分数
	KeyPostScoreZSet = "post:score" // Zset 帖子评分为分数
	KeyPostVotedZSetPF = "post:voted" // zset 记录用户及投票类型,参数是post id
	KeyCommunitySetPF = "community:" // set 保存每个分区下帖子的id
)

func getRedisKey(key string) string {
	return Prefix + key
}

        在mysql中,也用了许多方法,目标都是为了获取帖子的相关信息。

// GetPostList 查询帖子列表函数
func GetPostList(page, size int64) (posts []*models.Post, err error) {
	sqlStr := `select 
	post_id, title, content, author_id, community_id, create_time
	from post
	ORDER BY create_time
	DESC
	limit ?,?
	`
	posts = make([]*models.Post, 0, 2) // 不要写成make([]*models.Post, 2)
	err = db.Select(&posts, sqlStr, (page-1)*size, size)
	return
}

// GetPostListByIDs 根据给定的id列表查询帖子数据
func GetPostListByIDs(ids []string) (postList []*models.Post, err error) {
	sqlStr := "select post_id, title, content, author_id, community_id, create_time from post where post_id in (?) order by FIND_IN_SET(post_id, ?)"
	// sqlx.In批量生成
	query, args, err := sqlx.In(sqlStr, ids, strings.Join(ids, ","))
	if err != nil {
		return nil, err
	}
	err = db.Select(&postList, query, args...) // 一定要加...
	return
}

// GetUserById 根据id获取用户信息
func GetUserById(uid int64) (user *models.User, err error) {
	user = new(models.User)
	sqlStr := `select user_id, username from user where user_id = ?`
	err = db.Get(user, sqlStr, uid)
	return
}

func GetCommunityDetailByID(id int64) (detail *models.CommunityDetail, err error) {
	detail = new(models.CommunityDetail)
	sqlStr := "select community_id, community_name, introduction, create_time from community where community_id = ?"
	if err := db.Get(detail, sqlStr, id); err != nil {
		if err == sql.ErrNoRows {
			err = ErrInvalidID
		}
	}
	fmt.Println("%v", detail)
	return detail, err
}

获取指定帖子

        Get的Query获取id参数,然后查询,逻辑比较简单

func GetPostDetailHandler(c *gin.Context) {
	// 1. 获取参数(从URL中获取帖子的id)
	pidStr := c.Param("id")
	pid, err := strconv.ParseInt(pidStr, 10, 64)
	if err != nil {
		zap.L().Error("get post detail with invalid param", zap.Error(err))
		ResponseError(c, CodeInvalidParam)
		return
	}

	// 2. 根据id取出帖子数据(查数据库)
	data, err := logic.GetPostById(pid)
	if err != nil {
		zap.L().Error("logic.GetPostById(pid) failed", zap.Error(err))
		ResponseError(c, CodeServerBusy)
		return
	}
	// 3. 返回响应
	ResponseSuccess(c, data)
}

        logic层 :

func GetPostById(pid int64) (data *models.ApiPostDetail, err error) {
	// 查询并组合我们接口想用的数据
	post, err := mysql.GetPostById(pid)
	if err != nil {
		zap.L().Error("mysql.GetPostById(pid) failed",
			zap.Int64("pid", pid),
			zap.Error(err))
		return
	}
	// 根据作者id查询作者信息
	user, err := mysql.GetUserById(post.AuthorID)
	if err != nil {
		zap.L().Error("mysql.GetUserById(post.AuthorID) failed",
			zap.Int64("author_id", post.AuthorID),
			zap.Error(err))
		return
	}
	// 根据社区id查询社区详细信息
	community, err := mysql.GetCommunityDetailByID(post.CommunityID)
	if err != nil {
		zap.L().Error("mysql.GetUserById(post.AuthorID) failed",
			zap.Int64("community_id", post.CommunityID),
			zap.Error(err))
		return
	}
	// 接口数据拼接
	data = &models.ApiPostDetail{
		AuthorName:      user.Username,
		Post:            post,
		CommunityDetail: community,
	}
	return
}

        用到的数据库查询方法 :

func GetPostById(pid int64) (post *models.Post, err error) {
	post = new(models.Post)
	sqlStr := `select
	post_id, title, content, author_id, community_id, create_time
	from post
	where post_id = ?
	`
	err = db.Get(post, sqlStr, pid)
	return
}

func GetUserById(uid int64) (user *models.User, err error) {
	user = new(models.User)
	sqlStr := `select user_id, username from user where user_id = ?`
	err = db.Get(user, sqlStr, uid)
	return
}

func GetCommunityDetailByID(id int64) (detail *models.CommunityDetail, err error) {
	detail = new(models.CommunityDetail)
	sqlStr := "select community_id, community_name, introduction, create_time from community where community_id = ?"
	if err := db.Get(detail, sqlStr, id); err != nil {
		if err == sql.ErrNoRows {
			err = ErrInvalidID
		}
	}
	fmt.Println("%v", detail)
	return detail, err
}

投票模块

        在这里,我们引入了鉴权中间件,因为没有登录是不能发帖和投票的

v1.Use(middlewares.JWTAuthMiddleware()) // 应用JWT认证中间件
	{
		v1.POST("/post", controller.CreatePostHandler)

		// 投票
		v1.POST("/vote", controller.PostVoteController)
	}

        具体实现在middlewares的auth.go中,这里我们规定把token放在头部 Authorization 中,并且以 Bearer 开头,解析成功后存入gin.Context的上下文中

package middlewares

import (
	"GinBlog/controller"
	"GinBlog/pkg/jwt"
	"strings"

	"github.com/gin-gonic/gin"
)

// JWTAuthMiddleware 基于JWT的认证中间件
func JWTAuthMiddleware() func(c *gin.Context) {
	return func(c *gin.Context) {
		// 客户端携带Token有三种方式 1.放在请求头 2.放在请求体 3.放在URI
		// 这里假设Token放在Header的Authorization中,并使用Bearer开头
		// Authorization: Bearer xxxxxxx.xxx.xxx  / X-TOKEN: xxx.xxx.xx
		// 这里的具体实现方式要依据你的实际业务情况决定
		authHeader := c.Request.Header.Get("Authorization")
		if authHeader == "" {
			controller.ResponseError(c, controller.CodeNeedLogin)
			c.Abort()
			return
		}
		// 按空格分割
		parts := strings.SplitN(authHeader, " ", 2)
		if !(len(parts) == 2 && parts[0] == "Bearer") {
			controller.ResponseError(c, controller.CodeInvalidToken)
			c.Abort()
			return
		}
		// parts[1]是获取到的tokenString,我们使用之前定义好的解析JWT的函数来解析它
		mc, err := jwt.ParseToken(parts[1])
		if err != nil {
			controller.ResponseError(c, controller.CodeInvalidToken)
			c.Abort()
			return
		}
		// 将当前请求的userID信息保存到请求的上下文c上
		c.Set(controller.CtxUserIDKey, mc.UserID)

		c.Next() // 后续的处理请求的函数中 可以用过c.Get(CtxUserIDKey) 来获取当前请求的用户信息
	}
}

发帖

func CreatePostHandler(c *gin.Context) {
	// 参数校验 ShouldBindJSON()
	p := new(models.Post)
	if err := c.ShouldBindJSON(p); err != nil {
		zap.L().Error("create post with invalid param")
		ResponseError(c, CodeInvalidParam)
		return
	}
	// 从c中取到发帖子的id
	userID, err := GetCurrenUser(c)
	if err != nil {
		ResponseError(c, CodeNeedLogin)
		return
	}
	p.AuthorID = userID
	// 创建帖子
	if err := logic.CreatePost(p); err != nil {
		zap.L().Error("logic.CreatePost failed", zap.Error(err))
		ResponseError(c, CodeServerBusy)
		return
	}
	// 返回响应
	ResponseSuccess(c, nil)
}

        logic层用雪花算法生成帖子id

func CreatePost(p *models.Post) error {
	// 生成post id
	p.ID = snowflake.GenID()
	err := mysql.CreatePost(p)
	if err != nil {
		return err
	}
	err = redis.CreatePost(p.ID, p.CommunityID)
	return err
}

        数据库mysql和redis都要存,redis中的TxPipeline与 Pipeline 类似,但确保操作的原子性。它将命令包装在 MULTI 和 EXEC 命令中,确保所有命令要么全部执行,要么全部不执行

// Mysql
func CreatePost(p *models.Post) (err error) {
	sqlStr := `insert into post(post_id, title, content, author_id, community_id) values (?, ?, ?, ?, ?)`
	_, err = db.Exec(sqlStr, p.ID, p.Title, p.Content, p.AuthorID, p.CommunityID)
	return err
}
//Redis
func CreatePost(postID, communityID int64) error {
	pipeline := client.TxPipeline()
	ctx := context.Background()
	// 帖子时间
	pipeline.ZAdd(ctx, getRedisKey(KeyPostTimeZSet), &redis.Z{
		Score:  float64(time.Now().Unix()),
		Member: postID,
	})

	// 帖子分数
	pipeline.ZAdd(ctx, getRedisKey(KeyPostScoreZSet), &redis.Z{
		Score:  float64(time.Now().Unix()),
		Member: postID,
	})
	// 更新:把帖子id加到社区的set
	cKey := getRedisKey(KeyCommunitySetPF + strconv.Itoa(int(communityID)))
	pipeline.SAdd(ctx, cKey, postID)
	_, err := pipeline.Exec(ctx)
	return err
}

投票

type ParamVoteData struct {
	PostID string `json:"post_id" binding:"required"`
	// 赞成1反对-1取消投票0,validator的oneof限定
	Direction int8 `json:"direction,string" binding:"oneof=1 0 -1" `
}

        获取投票的类型,然后获取当前都票的用户,转到logic处理

func PostVoteController(c *gin.Context) {
	// 获取参数
	p := new(models.ParamVoteData)
	if err := c.ShouldBindJSON(p); err != nil {
		errs, ok := err.(validator.ValidationErrors) // 类型断言
		if !ok {
			ResponseError(c, CodeInvalidParam)
			return
		}
		errData := removeTopStruct(errs.Translate(trans))
		ResponseErrorWithMsg(c, CodeInvalidParam, errData)
		return
	}
	// 逻辑处理
	userID, err := GetCurrenUser(c)
	if err != nil {
		ResponseError(c, CodeNeedLogin)
		return
	}
	if err := logic.PostVote(userID, p); err != nil {
		zap.L().Error("logic.PostVote() failed", zap.Error(err))
		ResponseError(c, CodeServerBusy)
		return
	}
	ResponseSuccess(c, nil)
}

        这里转到redis进行,投票的底层实现就是对排序分数的修改。日志不打也可以。

func PostVote(userID int64, p *models.ParamVoteData) error{
	zap.L().Debug("VoteForPost",
		zap.Int64("userID", userID),
		zap.String("postID", p.PostID),
		zap.Int8("direction", p.Direction))
	return redis.VoteForPost(strconv.Itoa(int(userID)), p.PostID, float64(p.Direction))
}

        首先检查帖子的投票是否已经超过一周的有效期限,如果超时,则返回错误。如果投票有效,函数会创建一个 Redis 事务管道。接着,根据用户投票的值,函数会执行不同的操作:

  1. 如果用户投票值为0,表示用户要取消对帖子的投票,函数会从特定有序集合中移除该用户的ID。
  2. 如果用户投票值非0,函数会检查用户是否已经对该帖子投过票,并且比较新旧投票值:
    • 如果新旧投票值相同,表示投票重复,函数返回错误。
    • 如果投票值不同,函数计算两者的差值,并根据差值更新帖子的得分(增加或减少),同时更新用户对该帖子的投票记录。

        最后,函数执行事务管道中的所有命令,并处理可能出现的错误。这个过程确保了投票操作的原子性,即所有更新要么同时成功,要么同时失败,从而保证了数据的一致性。

func VoteForPost(userID, postID string, value float64) error {
	// 判断投票情况
	ctx := context.Background()
	postTime := client.ZScore(ctx, getRedisKey(KeyPostTimeZSet), postID).Val()
	if float64(time.Now().Unix()) - postTime > oneWeek {
		return ErrVoteTimeExpire
	}
	// 更新分数

	pipeline := client.TxPipeline()
	// 记录用户投票
	if value == 0 {
		pipeline.ZRem(ctx, getRedisKey(KeyPostVotedZSetPF+postID), userID)
	} else {
		oldVal := client.ZScore(ctx, getRedisKey(KeyPostVotedZSetPF+postID), userID).Val() // 查询投票记录
		var op float64
		if value == oldVal {
			return ErrVoteRepeat
		}
		if value > oldVal {
			op = 1
		} else {
			op = -1
		}
		diff := math.Abs(oldVal - value)
		pipeline.ZIncrBy(ctx, getRedisKey(KeyPostScoreZSet), op*diff*scorePerVote, postID)
		pipeline.ZAdd(ctx, getRedisKey(KeyPostVotedZSetPF+postID), &redis.Z{
			Score: value,
			Member: userID,
		})
	}
	_, err := pipeline.Exec(ctx)
	return err
}

项目开发及部署

开发中使用air

        air可以让gin框架在开发中运行,对应修改自动重启。.air.conf如下:

# [Air](https://github.com/cosmtrek/air) TOML 格式的配置文件

# 工作目录
# 使用 . 或绝对路径,请注意 `tmp_dir` 目录必须在 `root` 目录下
root = "."
tmp_dir = "tmp"

[build]
# 只需要写你平常编译使用的shell命令。你也可以使用 `make`
# Windows平台示例: cmd = "go build -o ./tmp/main.exe ."
cmd = "go build -o ./tmp/main.exe ."
# 由`cmd`命令得到的二进制文件名
# Windows平台示例:bin = "tmp/main.exe"
bin = "tmp/main.exe"
# 自定义执行程序的命令,可以添加额外的编译标识例如添加 GIN_MODE=release
full_bin = "tmp/main.exe config/config.yaml"
# Windows平台示例:full_bin = "./tmp/main.exe"
# Linux平台示例:full_bin = "APP_ENV=dev APP_USER=air ./tmp/main.exe"
#full_bin = "./tmp/main.exe"
# 监听以下文件扩展名的文件.
include_ext = ["go", "tpl", "tmpl", "html"]
# 忽略这些文件扩展名或目录
exclude_dir = ["assets", "tmp", "vendor", "frontend/node_modules"]
# 监听以下指定目录的文件
include_dir = []
# 排除以下文件
exclude_file = []
# 如果文件更改过于频繁,则没有必要在每次更改时都触发构建。可以设置触发构建的延迟时间
delay = 1000 # ms
# 发生构建错误时,停止运行旧的二进制文件。
stop_on_error = true
# air的日志文件名,该日志文件放置在你的`tmp_dir`中
log = "air_errors.log"

[log]
# 显示日志时间
time = true

[color]
# 自定义每个部分显示的颜色。如果找不到颜色,使用原始的应用程序日志。
main = "magenta"
watcher = "cyan"
build = "yellow"
runner = "green"

[misc]
# 退出时删除tmp目录
clean_on_exit = true

makefile的编写

        makefile用于在linux部署,可能存在错误(因为项目并未在linux下测试过)

.PHONY: all build run gotool clean help

BINARY="GinBlog"

all: gotool build

build:
	CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags "-s -w" -o ./bin/${BINARY}

run:
	@go run ./main.go conf/config.yaml

gotool:
	go fmt ./
	go vet ./

clean:
	@if [ -f ${BINARY} ] ; then rm ${BINARY} ; fi

help:
	@echo "make - 格式化 Go 代码, 并编译生成二进制文件"
	@echo "make build - 编译 Go 代码, 生成二进制文件"
	@echo "make run - 直接运行 Go 代码"
	@echo "make clean - 移除二进制文件和 vim swap files"
	@echo "make gotool - 运行 Go 工具 'fmt' and 'vet'"

docker

        dockerfile生成image

FROM golang:alpine AS builder

# 为我们的镜像设置必要的环境变量
ENV GO111MODULE=on \
    GOPROXY=https://goproxy.cn,direct \
    CGO_ENABLED=0 \
    GOOS=linux \
    GOARCH=amd64

# 移动到工作目录:/build
WORKDIR /build

# 将代码复制到容器中
COPY . .

# 下载依赖信息
RUN go mod download

# 将我们的代码编译成二进制可执行文件 bubble
RUN go build -o ginblog .

###################
# 接下来创建一个小镜像
###################
FROM debian:stretch-slim

# 从builder镜像中把脚本拷贝到当前目录
COPY ./wait-for.sh /

# 从builder镜像中把静态文件拷贝到当前目录
COPY ./templates /templates
COPY ./static /static

# 从builder镜像中把配置文件拷贝到当前目录
COPY ./config /config

# 从builder镜像中把/dist/app 拷贝到当前目录
COPY --from=builder /build/ginblog /

EXPOSE 8808

RUN echo "" > /etc/apt/sources.list; \
    echo "deb http://mirrors.aliyun.com/debian buster main" >> /etc/apt/sources.list ; \
    echo "deb http://mirrors.aliyun.com/debian-security buster/updates main" >> /etc/apt/sources.list ; \
    echo "deb http://mirrors.aliyun.com/debian buster-updates main" >> /etc/apt/sources.list ; \
    set -eux; \
	apt-get update; \
	apt-get install -y \
		--no-install-recommends \
		netcat; \
    chmod 755 wait-for.sh

## 需要运行的命令
#ENTRYPOINT ["/ginblog", "config/config.yaml"]

        docker-compose启动

# yaml 配置
services:
  mysql:
    image: "oilrmutp57/mysql5.7:1.1"
    ports:
      - "3306:3306"
    command: "--default-authentication-plugin=mysql_native_password --init-file /data/application/init.sql"
    environment:
      MYSQL_ROOT_PASSWORD: "123"
      MYSQL_DATABASE: "ginblog"
      MYSQL_PASSWORD: "123"
    volumes:
      - ./init.sql:/data/application/init.sql
  redis:
    image: "redis:7.4.1"
    ports:
      - "6379:6379"
  ginblog:
    build: .
    command: sh -c "./wait-for.sh mysql:3306 redis:6379 -- ./ginblog ./config/config.yaml"
    depends_on:
      - mysql
      - redis
    ports:
      - "8808:8808"

        wait-for.sh

#!/bin/bash

TIMEOUT=30
QUIET=0

ADDRS=()

echoerr() {
  if [ "$QUIET" -ne 1 ]; then printf "%s\n" "$*" 1>&2; fi
}

usage() {
  exitcode="$1"
  cat << USAGE >&2
client:
  $cmdname host:port [host:port] [host:port] [-t timeout] [-- command args]
  -q | --quiet                        Do not output any status messages
  -t TIMEOUT | --timeout=timeout      Timeout in seconds, zero for no timeout
  -- COMMAND ARGS                     Execute command with args after the test finishes
USAGE
  exit "$exitcode"
}

wait_for() {
  results=()
  for addr in ${ADDRS[@]}
  do
    HOST=$(printf "%s\n" "$addr"| cut -d : -f 1)
    PORT=$(printf "%s\n" "$addr"| cut -d : -f 2)
    result=1
    for i in `seq $TIMEOUT` ; do
      nc -z "$HOST" "$PORT" > /dev/null 2>&1
      result=$?
      if [ $result -ne 0 ] ; then
        sleep 1
	continue
      fi
      break
    done
    results=(${results[@]} $result)
  done
  num=${#results[@]}
  for result in ${results[@]}
  do
    if [ $result -eq 0 ] ; then
	    num=`expr $num - 1`
    fi
  done
  if [ $num -eq 0 ] ; then
    if [ $# -gt 0 ] ; then
      exec "$@"
    fi
    exit 0
  fi
  echo "Operation timed out" >&2
  exit 1
}

while [ $# -gt 0 ]
do
  case "$1" in
    *:* )
    ADDRS=(${ADDRS[@]} $1)
    shift 1
    ;;
    -q | --quiet)
    QUIET=1
    shift 1
    ;;
    -t)
    TIMEOUT="$2"
    if [ "$TIMEOUT" = "" ]; then break; fi
    shift 2
    ;;
    --timeout=*)
    TIMEOUT="${1#*=}"
    shift 1
    ;;
    --)
    shift
    break
    ;;
    --help)
    usage 0
    ;;
    *)
    echoerr "Unknown argument: $1"
    usage 1
    ;;
  esac
done

if [ "${#ADDRS[@]}" -eq 0 ]; then
  echoerr "Error: you need to provide a host and port to test."
  usage 2
fi

wait_for "$@"

总结

        这是Gin框架的Blog小demo,基本实现了登录注册,社区,发帖,投票功能,一些拓展的功能没有实现,但是其核心对于入门Gin来说,是不错的练习。

;