目录
项目分析
基于Gin框架的IM即时通讯小demo,实现了用户注册,添加好友,创建和加入群聊;好友聊天,群聊天;可以自定义个人信息,群信息;在聊天中可以发送文字、图片、表情、语音。
项目地址:knoci/GinBlog
项目分层
config:配置文件
static:前端静态资源
docs:swagger
setting:负责配置文件的读取
template:html模板
router:路由层
controller:控制层,负责服务的转发
logic:逻辑层,具体功能的实现
dao:数据层,负责数据库和存储相关
pkg:第三方库
models:模板层,结构的定义
middlewares:中间件
logger:日志相关
初始化
main函数如下:
package main
import (
"GinBlog/controller"
"GinBlog/dao/mysql"
"GinBlog/dao/redis"
"GinBlog/logger"
"GinBlog/pkg/snowflake"
"GinBlog/router"
"GinBlog/setting"
"fmt"
"os"
)
// @title GinBlog项目接口文档
// @version 1.0
// @description Go web博客项目
// @contact.name knoci
// @contact.email [email protected]
// @host 127.0.0.1:8808
// @BasePath /api/v1
func main() {
// 1.读取配置
if len(os.Args) < 2 {
fmt.Println("need config file.eg: GinBlog config.yaml")
return
}
config_set := os.Args[1]
if lenth := len(config_set); config_set[lenth-4:] == ".exe" {
config_set = config_set[0 : lenth-4]
}
err := setting.Init(config_set)
if err != nil {
fmt.Printf("load setting failed: %v", err)
return
}
// 2.加载日志
err = logger.Init(setting.Conf.LogConfig, setting.Conf.Mode)
if err != nil {
fmt.Printf("load logger failed: %v", err)
return
}
// 3.配置mysql
err = mysql.Init(setting.Conf.MySQLConfig)
if err != nil {
fmt.Printf("init mysql failed: %v", err)
return
}
defer mysql.Close()
// 4.配置redis
err = redis.Init(setting.Conf.RedisConfig)
if err != nil {
fmt.Println("init redis failed: %v", err)
return
}
defer redis.Close()
// 5.获取路由并运行服务
if err := snowflake.Init(setting.Conf.StartTime, setting.Conf.MachineID); err != nil {
fmt.Printf("init snowflake failed: %v", err)
return
}
if err := controller.InitTrans("zh"); err != nil {
fmt.Println("init trans failed: %v", err)
return
}
r := router.InitRouter(setting.Conf.Mode)
err = r.Run(fmt.Sprintf(":%d", setting.Conf.Port))
if err != nil {
fmt.Printf("run server failed: %v", err)
return
}
}
一开始进来用命令行参数指令读取配置,然后去到setting下的Init()函数初始化配置,setting中还定义了嵌套结构体,用这种方法确保拿到我们每个程序需要的参数。其中WatchConfig()函数和OnConfigChange()来监视config变化实现实时更新。
需要注意的是vipeReadInConfig()函数要读入参数到结构体,一定要打上mapstructure的tag将map[string]interface{}
类型的数据解码到 Go 的结构体中。
package setting
import (
"fmt"
"github.com/fsnotify/fsnotify"
"github.com/spf13/viper"
)
var Conf = new(AppConfig)
type AppConfig struct {
Name string `mapstructure:"name"`
Mode string `mapstructure:"mode"`
Version string `mapstructure:"version"`
StartTime string `mapstructure:"start_time"`
MachineID int64 `mapstructure:"machine_id"`
Port int `mapstructure:"port"`
*LogConfig `mapstructure:"log"`
*MySQLConfig `mapstructure:"mysql"`
*RedisConfig `mapstructure:"redis"`
}
type LogConfig struct {
Level string `mapstructure:"level"`
Filename string `mapstructure:"filename"`
MaxSize int `mapstructure:"max_size"`
MaxAge int `mapstructure:"max_age"`
MaxBackups int `mapstructure:"max_backups"`
}
type MySQLConfig struct {
User string `mapstructure:"user"`
Host string `mapstructure:"host"`
Password string `mapstructure:"password"`
DbName string `mapstructure:"dbname"`
Port int `mapstructure:"port"`
MaxOpenConns int `mapstructure:"max_open_conns"`
MaxIdleConns int `mapstructure:"max_idle_conns"`
}
type RedisConfig struct {
Host string `mapstructure:"name"`
Password string `mapstructure:"password"`
Port int `mapstructure:"port"`
DB int `mapstructure:"db"`
PoolSize int `mapstructure:"pool_size"`
MinIdleConns int `mapstructure:"min_idle_conns"`
}
func Init(filepath string) (err error) {
viper.SetConfigFile(filepath)
err = viper.ReadInConfig()
if err != nil {
fmt.Print("Error loading config...")
return
}
err = viper.Unmarshal(Conf)
if err != nil {
fmt.Print("Error unmarshalling config...")
}
viper.WatchConfig()
viper.OnConfigChange(func(in fsnotify.Event) {
err = viper.Unmarshal(Conf)
if err != nil {
fmt.Print("Error unmarshalling config while config changing...")
return
}
})
return
}
把LogConfig参数给到logger的Init启动日志,日志这里用了Zap库,创建了全局替换,修改了gin的日志中间件,把日志记录到zap中。
package logger
import (
"GinBlog/setting"
"net"
"net/http"
"net/http/httputil"
"os"
"runtime/debug"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/natefinch/lumberjack" // lumberjack 是一个简单的日志文件滚动库
"go.uber.org/zap" // zap 是一个快速的、结构化的、可靠的 Go 日志库
"go.uber.org/zap/zapcore" // zap 的核心包
)
// lg 是一个全局的 zap.Logger 实例
var lg *zap.Logger
// Init 初始化日志配置
func Init(cfg *setting.LogConfig, mode string) (err error) {
// 获取日志写入器
writeSyncer := getLogWriter(cfg.Filename, cfg.MaxSize, cfg.MaxBackups, cfg.MaxAge)
// 获取编码器
encoder := getEncoder()
// 解析日志级别
var l = new(zapcore.Level)
err = l.UnmarshalText([]byte(cfg.Level))
if err != nil {
return // 如果解析失败,返回错误
}
// 创建 zap 核心对象
var core zapcore.Core
if mode == "dev" {
// 如果是开发模式,日志同时输出到终端和文件
consoleEncoder := zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig())
core = zapcore.NewTee( // Tee 表示同时写入多个 Writer
//zapcore.NewCore(encoder, writeSyncer, l),
zapcore.NewCore(consoleEncoder, zapcore.Lock(os.Stdout), zapcore.DebugLevel),
)
} else {
// 生产模式,只输出到文件
core = zapcore.NewCore(encoder, writeSyncer, l)
}
// 创建 zap 日志对象
lg = zap.New(core, zap.AddCaller()) // 添加调用者信息
// 替换全局日志对象
zap.ReplaceGlobals(lg)
// 记录初始化日志
zap.L().Info("init logger success")
return
}
// getEncoder 创建并配置日志编码器
func getEncoder() zapcore.Encoder {
// 使用生产环境的编码器配置
encoderConfig := zap.NewProductionEncoderConfig()
// 设置时间编码器
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
// 设置时间字段名称
encoderConfig.TimeKey = "time"
// 设置日志级别编码器
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
// 设置持续时间编码器
encoderConfig.EncodeDuration = zapcore.SecondsDurationEncoder
// 设置调用者编码器
encoderConfig.EncodeCaller = zapcore.ShortCallerEncoder
// 创建 JSON 编码器
return zapcore.NewJSONEncoder(encoderConfig)
}
// getLogWriter 创建并配置日志写入器
func getLogWriter(filename string, maxSize, maxBackup, maxAge int) zapcore.WriteSyncer {
// 使用 lumberjack 作为日志文件的写入器
lumberJackLogger := &lumberjack.Logger{
Filename: filename, // 日志文件路径
MaxSize: maxSize, // 文件最大大小
MaxBackups: maxBackup, // 最多备份文件数量
MaxAge: maxAge, // 文件最长保存天数
}
// 将 lumberjack 写入器包装为 zap 写入器
return zapcore.AddSync(lumberJackLogger)
}
// GinLogger 是一个 Gin 中间件,用于记录 HTTP 请求日志
func GinLogger() gin.HandlerFunc {
return func(c *gin.Context) {
// 记录请求开始时间
start := time.Now()
// 获取请求路径和查询字符串
path := c.Request.URL.Path
query := c.Request.URL.RawQuery
// 继续处理请求
c.Next()
// 计算请求处理时间
cost := time.Since(start)
// 记录日志
lg.Info(path,
zap.Int("status", c.Writer.Status()),
zap.String("method", c.Request.Method),
zap.String("path", path),
zap.String("query", query),
zap.String("ip", c.ClientIP()),
zap.String("user-agent", c.Request.UserAgent()),
zap.String("errors", c.Errors.ByType(gin.ErrorTypePrivate).String()),
zap.Duration("cost", cost),
)
}
}
// GinRecovery 是一个 Gin 中间件,用于捕获和记录 panic
func GinRecovery(stack bool) gin.HandlerFunc {
return func(c *gin.Context) {
// 使用 defer 延迟执行 panic 恢复
defer func() {
if err := recover(); err != nil {
// 检查是否是连接断开导致的错误
var brokenPipe bool
if ne, ok := err.(*net.OpError); ok {
if se, ok := ne.Err.(*os.SyscallError); ok {
if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
brokenPipe = true
}
}
}
// 记录请求信息
httpRequest, _ := httputil.DumpRequest(c.Request, false)
if brokenPipe {
// 如果是连接断开,记录错误日志
lg.Error(c.Request.URL.Path,
zap.Any("error", err),
zap.String("request", string(httpRequest)),
)
// 如果连接已断开,记录错误并终止请求
c.Error(err.(error)) // nolint: errcheck
c.Abort()
return
}
// 如果需要记录 stack trace
if stack {
lg.Error("[Recovery from panic]",
zap.Any("error", err),
zap.String("request", string(httpRequest)),
zap.String("stack", string(debug.Stack())),
)
} else {
lg.Error("[Recovery from panic]",
zap.Any("error", err),
zap.String("request", string(httpRequest)),
)
}
// 设置 HTTP 状态码并终止请求
c.AbortWithStatus(http.StatusInternalServerError)
}
}()
// 继续处理请求
c.Next()
}
}
然后是配置mysql,用的是sqlx库,这里封装一个Close()函数允许外部关闭。
package mysql
import (
"GinBlog/setting"
"fmt"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
)
var db *sqlx.DB
// Init 初始化MySQL连接
func Init(cfg *setting.MySQLConfig) (err error) {
// "user:password@tcp(host:port)/dbname"
fmt.Println(cfg.Host)
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true&loc=Local", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DbName)
db, err = sqlx.Connect("mysql", dsn)
if err != nil {
return
}
db.SetMaxOpenConns(cfg.MaxOpenConns)
db.SetMaxIdleConns(cfg.MaxIdleConns)
return
}
// Close 关闭MySQL连接
func Close() {
_ = db.Close()
}
接着是redis,用的是go-redis。
package redis
import (
"GinBlog/setting"
"fmt"
"github.com/go-redis/redis/v8"
)
var (
client *redis.Client
Nil = redis.Nil
)
func Init(cfg *setting.RedisConfig) (err error) {
client = redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
Password: cfg.Password, // no password set
DB: cfg.DB, // use default DB
PoolSize: cfg.PoolSize,
MinIdleConns: cfg.MinIdleConns,
})
ctx := client.Context()
_, err = client.Ping(ctx).Result()
if err != nil {
return err
}
return nil
}
func Close() {
_ = client.Close()
}
之后用pkg的snowflake库,初始化了一个node,GenID()接收node用雪花算法生个一个id。
package snowflake
import (
"time"
sf "github.com/bwmarrin/snowflake"
)
var node *sf.Node
func Init(startTime string, machineID int64) (err error) {
var st time.Time
st, err = time.Parse("2006-01-02", startTime)
if err != nil {
return
}
sf.Epoch = st.UnixNano() / 1000000
node, err = sf.NewNode(machineID)
return
}
func GenID() int64 {
return node.Generate().Int64()
}
还进行了一下validator的设置,这是一个参数校验的库。
package controller
import (
"GinBlog/models"
"fmt"
"github.com/go-playground/validator/v10"
"reflect"
"strings"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/locales/en"
"github.com/go-playground/locales/zh"
ut "github.com/go-playground/universal-translator"
enTranslations "github.com/go-playground/validator/v10/translations/en"
zhTranslations "github.com/go-playground/validator/v10/translations/zh"
)
// 定义一个全局翻译器T
var trans ut.Translator
// InitTrans 初始化翻译器
func InitTrans(locale string) (err error) {
// 修改gin框架中的Validator引擎属性,实现自定制
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
// 注册一个获取json tag的自定义方法
v.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
// 为SignUpParam注册自定义校验方法
v.RegisterStructValidation(SignUpParamStructLevelValidation, models.ParamSignUp{})
zhT := zh.New() // 中文翻译器
enT := en.New() // 英文翻译器
// 第一个参数是备用(fallback)的语言环境
// 后面的参数是应该支持的语言环境(支持多个)
// uni := ut.New(zhT, zhT) 也是可以的
uni := ut.New(enT, zhT, enT)
// locale 通常取决于 http 请求头的 'Accept-Language'
var ok bool
// 也可以使用 uni.FindTranslator(...) 传入多个locale进行查找
trans, ok = uni.GetTranslator(locale)
if !ok {
return fmt.Errorf("uni.GetTranslator(%s) failed", locale)
}
// 注册翻译器
switch locale {
case "en":
err = enTranslations.RegisterDefaultTranslations(v, trans)
case "zh":
err = zhTranslations.RegisterDefaultTranslations(v, trans)
default:
err = enTranslations.RegisterDefaultTranslations(v, trans)
}
return
}
return
}
// removeTopStruct 去除提示信息中的结构体名称
func removeTopStruct(fields map[string]string) map[string]string {
res := map[string]string{}
for field, err := range fields {
res[field[strings.Index(field, ".")+1:]] = err
}
return res
}
// SignUpParamStructLevelValidation 自定义SignUpParam结构体校验函数
func SignUpParamStructLevelValidation(sl validator.StructLevel) {
su := sl.Current().Interface().(models.ParamSignUp)
if su.Password != su.RePassword {
// 输出错误提示信息,最后一个参数就是传递的param
sl.ReportError(su.RePassword, "re_password", "RePassword", "eqfield", "password")
}
}
终于设置完配置了,然后通过InitRouter()去注册路由了。
package router
import (
"GinBlog/controller"
"GinBlog/docs"
"GinBlog/logger"
"GinBlog/middlewares"
"github.com/gin-gonic/gin"
swaggerFiles "github.com/swaggo/files"
ginSwagger "github.com/swaggo/gin-swagger"
"net/http"
)
func InitRouter(mode string) (r *gin.Engine) {
if mode == gin.ReleaseMode {
gin.SetMode(gin.ReleaseMode) // gin设置成发布模式
}
r = gin.New()
//r.Use(logger.GinLogger(), logger.GinRecovery(true), middlewares.RateLimitMiddleware(1*time.Second, 1))
r.Use(logger.GinLogger(), logger.GinRecovery(true))
// 注册swagger api相关路由
docs.SwaggerInfo.BasePath = ""
r.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
r.LoadHTMLFiles("./templates/index.html")
r.Static("/static", "./static")
r.GET("/", func(c *gin.Context) {
c.HTML(http.StatusOK, "index.html", nil)
})
r.GET("/ping", func(c *gin.Context) {
c.String(http.StatusOK, "pong")
})
v1 := r.Group("/api/v1")
// 注册
v1.POST("/signup", controller.SignUpHandler)
// 登录
v1.POST("/login", controller.LoginHandler)
// 根据时间或分数获取帖子列表
v1.GET("/posts2", controller.GetPostListHandler2)
//v1.GET("/posts", controller.GetPostListHandler)
v1.GET("/community", controller.CommunityHandler)
v1.GET("/community/:id", controller.CommunityDetailHandler)
v1.GET("/post/:id", controller.GetPostDetailHandler)
v1.Use(middlewares.JWTAuthMiddleware()) // 应用JWT认证中间件
{
v1.POST("/post", controller.CreatePostHandler)
// 投票
v1.POST("/vote", controller.PostVoteController)
}
return
}
用户模块
用到的参数模板,即models下的params.go如下:
package models
type ParamSignUp struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
RePassword string `json:"re_password" binding:"required,eqfield=Password"`
}
type ParamLogin struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}
type ParamVoteData struct {
PostID string `json:"post_id" binding:"required"`
// 赞成1反对-1取消投票0,validator的oneof限定
Direction int8 `json:"direction,string" binding:"oneof=1 0 -1" `
}
type ParamPostList struct {
CommunityID int64 `json:"community_id" form:"community_id"` // 可以为空
Page int64 `form:"page"`
Size int64 `form:"size"`
Order string `form:"order"`
}
type ParamCommunityPostList struct {
}
返回状态的封装,在controller的response.go:
package controller
import (
"net/http"
"github.com/gin-gonic/gin"
)
/*
{
"code":
"msg":
"data":
}
*/
type ResCode int64
const (
CodeSuccess ResCode = 1000 + iota
CodeInvalidParam
CodeUserExist
CodeUserNotExist
CodeInvalidPassword
CodeServerBusy
CodeNeedLogin
CodeInvalidToken
)
var getMsg = map[ResCode]string{
CodeSuccess: "success",
CodeInvalidParam: "请求参数错误",
CodeUserExist: "用户已存在",
CodeUserNotExist: "用户不存在",
CodeInvalidPassword: "用户名或密码错误",
CodeServerBusy: "服务繁忙",
CodeNeedLogin: "需要登录",
CodeInvalidToken: "无效的token",
}
type Response struct {
Code ResCode `json:"code"`
Msg interface{} `json:"msg"`
Data interface{} `json:"data,omitempty"` // 为空忽略
}
func ResponseSuccess(c *gin.Context, data interface{}) {
rd := &Response{
Code: CodeSuccess,
Msg: getMsg[CodeSuccess],
Data: data,
}
c.JSON(http.StatusOK, rd)
}
func ResponseError(c *gin.Context, code ResCode) {
rd := &Response{
Code: code,
Msg: getMsg[code],
Data: nil,
}
c.JSON(http.StatusOK, rd)
}
func ResponseErrorWithMsg(c *gin.Context, code ResCode, msg interface{}) {
rd := &Response{
Code: code,
Msg: msg,
Data: nil,
}
c.JSON(http.StatusOK, rd)
}
注册
ShouldBindJSON解析Post传来的JSON自动绑定到参数结构体上,在controller层验证一下参数是否正确,然后转到逻辑层处理具体注册逻辑
func SignUpHandler(c *gin.Context) {
// 参数校验
var p = new(models.ParamSignUp)
if err := c.ShouldBindJSON(p); err != nil {
zap.L().Error("SignUp with invalid param", zap.Error(err))
// 判断err是不是validator.ValidationErrors类型
fmt.Println(p.Password)
fmt.Println(p.RePassword)
errs, ok := err.(validator.ValidationErrors)
if !ok {
// 非validator.ValidationErrors类型错误直接返回
ResponseError(c, CodeInvalidParam)
return
}
ResponseErrorWithMsg(c, CodeInvalidParam, removeTopStruct(errs.Translate(trans)))
return
}
// 业务处理
if err := logic.SignUp(p); err != nil {
zap.L().Error("logic.SignUp failed", zap.Error(err))
if errors.Is(err, mysql.ErrUserExist) {
ResponseError(c, CodeUserExist)
}
ResponseError(c, CodeServerBusy)
return
}
// 返回响应
ResponseSuccess(c, nil)
}
在logic的SignUp中,判断是否合法,然后存入数据库。
func SignUp(p *models.ParamSignUp) (err error) {
// 合法性判断
if err := mysql.CheckUserExist(p.Username); err != nil {
return err
}
// 保存到数据库
user := models.User{
UserID: snowflake.GenID(),
Username: p.Username,
Password: p.Password,
}
err = mysql.InsertUser(&user)
if err != nil {
return err
}
return nil
}
数据库相关的函数封装在dao的mysql,这里密码用到了md5加密。
package mysql
import (
"GinBlog/models"
"crypto/md5"
"database/sql"
"encoding/hex"
)
const salt = "knoci1337"
func InsertUser(user *models.User) (err error) {
// 密码加密
user.Password = encryptPassword(user.Password)
sqlStr := `insert into user(user_id, username, password) value(?,?,?)`
_, err = db.Exec(sqlStr, user.UserID, user.Username, user.Password)
if err != nil {
return
}
return nil
}
func CheckUserExist(username string) (err error) {
sqlStr := `select count(user_id) from user where username = ?`
var count int
if err = db.Get(&count, sqlStr, username); err != nil {
return err
}
if count > 0 {
return ErrUserExist
}
return nil
}
func encryptPassword(password string) string {
h := md5.New()
h.Write([]byte(salt))
return hex.EncodeToString(h.Sum([]byte(password)))
}
func Login(user *models.User) (err error) {
oPassword := user.Password
sqlStr := `select user_id, username , password from user where username = ?`
err = db.Get(user, sqlStr, user.Username)
if err == sql.ErrNoRows {
return ErrUserNotExist
}
if err != nil {
// 查询数据库失败
return err
}
password := encryptPassword(oPassword)
if password != user.Password {
return ErrInvalidPassword
}
return
}
// GetUserById 根据id获取用户信息
func GetUserById(uid int64) (user *models.User, err error) {
user = new(models.User)
sqlStr := `select user_id, username from user where user_id = ?`
err = db.Get(user, sqlStr, uid)
return
}
登录
controller层
func LoginHandler(c *gin.Context) {
// 参数校验
p := new(models.ParamLogin)
if err := c.ShouldBindJSON(p); err != nil {
zap.L().Error("Login with invalid param", zap.Error(err))
errs, ok := err.(validator.ValidationErrors)
if !ok {
ResponseError(c, CodeInvalidParam)
return
}
ResponseErrorWithMsg(c, CodeInvalidParam, removeTopStruct(errs.Translate(trans)))
return
}
// 2.业务逻辑处理
user, err := logic.Login(p)
if err != nil {
zap.L().Error("logic.Login failed", zap.String("username", p.Username), zap.Error(err))
if errors.Is(err, mysql.ErrUserNotExist) {
ResponseError(c, CodeUserNotExist)
return
}
ResponseError(c, CodeInvalidPassword)
return
}
// 3.返回响应
ResponseSuccess(c, gin.H{
"user_id": fmt.Sprintf("%d", user.UserID), // id值大于1<<53-1 int64类型的最大值是1<<63-1
"user_name": user.Username,
"token": user.Token,
})
}
logic层,登录成功后生成jwt的token。
func Login(p *models.ParamLogin) (*models.User,error) {
user := &models.User{
Username: p.Username,
Password: p.Password,
}
// 传递的是指针,就能拿到user.UserID
if err := mysql.Login(user); err != nil {
return nil, err
}
// 生成JWT
token, err := jwt.GenToken(user.UserID, user.Username)
if err != nil {
return nil, err
}
user.Token = token
return user, err
}
jwt的生成以及解析函数封装,在pkg目录jwt的jwt.go:
package jwt
import (
"errors"
"time"
"github.com/spf13/viper"
"github.com/dgrijalva/jwt-go"
)
var mySecret = []byte("人生不过一场梦")
// MyClaims 自定义声明结构体并内嵌jwt.StandardClaims
// jwt包自带的jwt.StandardClaims只包含了官方字段
// 我们这里需要额外记录一个username字段,所以要自定义结构体
// 如果想要保存更多信息,都可以添加到这个结构体中
type MyClaims struct {
UserID int64 `json:"user_id"`
Username string `json:"username"`
jwt.StandardClaims
}
// GenToken 生成JWT
func GenToken(userID int64, username string) (string, error) {
// 创建一个我们自己的声明的数据
c := MyClaims{
userID,
"username", // 自定义字段
jwt.StandardClaims{
ExpiresAt: time.Now().Add(
time.Duration(viper.GetInt("auth.jwt_expire")) * time.Hour).Unix(), // 过期时间
Issuer: "GinBlog", // 签发人
},
}
// 使用指定的签名方法创建签名对象
token := jwt.NewWithClaims(jwt.SigningMethodHS256, c)
// 使用指定的secret签名并获得完整的编码后的字符串token
return token.SignedString(mySecret)
}
// ParseToken 解析JWT
func ParseToken(tokenString string) (*MyClaims, error) {
// 解析token
var mc = new(MyClaims)
token, err := jwt.ParseWithClaims(tokenString, mc, func(token *jwt.Token) (i interface{}, err error) {
return mySecret, nil
})
if err != nil {
return nil, err
}
if token.Valid { // 校验token
return mc, nil
}
return nil, errors.New("invalid token")
}
社区模块
models的模板如下:
package models
import "time"
type Community struct {
ID int64 `json:"id" db:"community_id"`
Name string `json:"name" db:"community_name"`
}
type CommunityDetail struct {
ID int64 `json:"id" db:"community_id"`
Name string `json:"name" db:"community_name"`
Introduction string `json:"introduction" db:"introduction"`
CreateTime time.Time `json:"create_time" db:"create_time"`
}
所有社区
v1.GET "/community"是获取所有社区,CommunityHandler转到逻辑层
func CommunityHandler(c *gin.Context) {
//查询到所有community
communities, err := logic.GetCommunityList()
if err != nil {
zap.L().Error("logic.GetCommunityList() failed", zap.Error(err))
ResponseError(c, CodeServerBusy)
return
}
ResponseSuccess(c, communities)
}
logic层转到数据库,在mysql的community.go
func GetCommunityList() (data []*models.Community, err error) {
//查找到所有的community并且返回
data, err = mysql.GetCommunityList()
if err != nil {
zap.L().Error("logic.GetCommunityList() failed", zap.Error(err))
return nil, err
}
return data, err
}
数据库里用sql语句直接查询
func GetCommunityList() (data []*models.Community, err error) {
sqlStr := "select community_id, community_name from community"
if err := db.Select(&data, sqlStr); err != nil {
if err == sql.ErrNoRows {
zap.L().Warn("there is no community in db")
err = nil
}
}
return
}
指定社区
Get请求Query一个id查询单个社区,和查询所有大差不差。
func CommunityDetailHandler(c *gin.Context) {
// 获取社区id
communityID := c.Param("id")
id, err := strconv.ParseInt(communityID, 10, 64)
if err != nil {
ResponseError(c, CodeInvalidParam)
return
}
data , err := logic.GetCommunityDetail(id)
if err != nil {
zap.L().Error("logic.GetCommunityDetail() failed", zap.Error(err))
ResponseError(c, CodeServerBusy)
return
}
ResponseSuccess(c, data)
}
func GetCommunityDetail(id int64) (*models.CommunityDetail, error) {
return mysql.GetCommunityDetailByID(id)
}
func GetCommunityDetailByID(id int64) (detail *models.CommunityDetail, err error) {
detail = new(models.CommunityDetail)
sqlStr := "select community_id, community_name, introduction, create_time from community where community_id = ?"
if err := db.Get(detail, sqlStr, id); err != nil {
if err == sql.ErrNoRows {
err = ErrInvalidID
}
}
fmt.Println("%v", detail)
return detail, err
}
帖子模块
models的模板如下:
package models
import "time"
const (
OderTime = "time"
OderScore = "score"
)
type Post struct {
ID int64 `json:"id,string" db:"post_id"` // 帖子id
AuthorID int64 `json:"author_id" db:"author_id"` // 作者id
CommunityID int64 `json:"community_id" db:"community_id" binding:"required"` // 社区id
Status int32 `json:"status" db:"status"` // 帖子状态
Title string `json:"title" db:"title" binding:"required"` // 帖子标题
Content string `json:"content" db:"content" binding:"required"` // 帖子内容
CreateTime time.Time `json:"create_time" db:"create_time"` // 帖子创建时间
}
// ApiPostDetail 帖子详情接口的结构体
type ApiPostDetail struct {
AuthorName string `json:"author_name"` // 作者
VoteNum int64 `json:"vote_num"` // 投票数
*Post // 嵌入帖子结构体
*CommunityDetail `json:"community"` // 嵌入社区信息
}
顺序获取帖子
post2实现按时间或分数顺序获取帖子,这里默认按时间
func GetPostListHandler2(c *gin.Context) {
// 获取参数
p := &models.ParamPostList{
Page: 1,
Size: 10,
Order: models.OderTime,
}
if err := c.ShouldBindQuery(p); err != nil {
zap.L().Error("GetPostListHandler2 with invalid param", zap.Error(err))
return
}
// 获取帖子列表
data, err := logic.GetPostListNew(p)
if err != nil {
zap.L().Error("logic GetPostList() failed", zap.Error(err))
return
}
// 去redis查询id列表
// 根据id去数据库查询帖子详情信息
ResponseSuccess(c, data)
}
转到logic中,由一个GetPostListNew()函数实现判断和分发
func GetPostListNew(p *models.ParamPostList) (data []*models.ApiPostDetail, err error) {
// 根据请求参数的不同,执行不同的逻辑。
if p.CommunityID == 0 {
// 查所有
data, err = GetPostList2(p)
} else {
// 根据社区id查询
data, err = GetCommunityPostList(p)
}
if err != nil {
zap.L().Error("GetPostListNew failed", zap.Error(err))
return nil, err
}
return
}
返回的是一个ApiPostDetail类型的切片,包含了坐着,票数,帖子信息,社区信息。
func GetPostList2(p *models.ParamPostList) (data []*models.ApiPostDetail, err error) {
ids, err := redis.GetPostIDsInOrder(p)
if err != nil {
return
}
if len(ids) == 0 {
zap.L().Warn("redis.GetPostIDsInOrder(p) return 0 data")
return
}
posts, err := mysql.GetPostListByIDs(ids)
if err != nil {
return
}
//提前查好每篇帖子的投票数
voteData, err := redis.GetPostVoteData(ids)
if err != nil {
return
}
for idx, post := range posts {
// 根据作者id查询作者信息
user, err := mysql.GetUserById(post.AuthorID)
if err != nil {
zap.L().Error("mysql.GetUserById(post.AuthorID) failed",
zap.Int64("author_id", post.AuthorID),
zap.Error(err))
continue
}
// 根据社区id查询社区详细信息
community, err := mysql.GetCommunityDetailByID(post.CommunityID)
if err != nil {
zap.L().Error("mysql.GetUserById(post.AuthorID) failed",
zap.Int64("community_id", post.CommunityID),
zap.Error(err))
continue
}
postDetail := &models.ApiPostDetail{
AuthorName: user.Username,
VoteNum: voteData[idx],
Post: post,
CommunityDetail: community,
}
data = append(data, postDetail)
}
return
}
func GetCommunityPostList(p *models.ParamPostList) (data []*models.ApiPostDetail, err error) {
// 2. 去redis查询id列表
ids, err := redis.GetCommunityPostIDsInOrder(p)
if err != nil {
return
}
if len(ids) == 0 {
zap.L().Warn("redis.GetPostIDsInOrder(p) return 0 data")
return
}
zap.L().Debug("GetCommunityPostIDsInOrder", zap.Any("ids", ids))
// 3. 根据id去MySQL数据库查询帖子详细信息
// 返回的数据还要按照我给定的id的顺序返回
posts, err := mysql.GetPostListByIDs(ids)
if err != nil {
return
}
zap.L().Debug("GetPostList2", zap.Any("posts", posts))
// 提前查询好每篇帖子的投票数
voteData, err := redis.GetPostVoteData(ids)
if err != nil {
return
}
// 将帖子的作者及分区信息查询出来填充到帖子中
for idx, post := range posts {
// 根据作者id查询作者信息
user, err := mysql.GetUserById(post.AuthorID)
if err != nil {
zap.L().Error("mysql.GetUserById(post.AuthorID) failed",
zap.Int64("author_id", post.AuthorID),
zap.Error(err))
continue
}
// 根据社区id查询社区详细信息
community, err := mysql.GetCommunityDetailByID(post.CommunityID)
if err != nil {
zap.L().Error("mysql.GetUserById(post.AuthorID) failed",
zap.Int64("community_id", post.CommunityID),
zap.Error(err))
continue
}
postDetail := &models.ApiPostDetail{
AuthorName: user.Username,
VoteNum: voteData[idx],
Post: post,
CommunityDetail: community,
}
data = append(data, postDetail)
}
return
}
通过redis目录下post.go的GetPostIDsInOrder()获取按时间排序好的帖子id,GetCommunityPostIDsInOrder()按社区获取帖子id,用字符串切片返回。
GetCommunityPostIDsInOrder()用ZInterStore
命令将社区的帖子集合和排序的有序集合进行交集计算,并将结果存储在key
中。这里使用"MAX"
作为聚合函数,意味着取两个有序集合中的最大分数。
GetPostVoteData()函数获取每篇帖子投票顺序,使用了pipeline事务确保一致性,ZCount限定范围在1到999票,cmders是Exec结果的切片,遍历用.(*redis.IntCmd)类型断言其是整数Int型结果并获取其Val添加到data切片。
getRedisKey()是在key.go自己封装的方法,获取key值,getIDsFromKey()是按照page和size分页获取分区的帖子。
func GetCommunityPostIDsInOrder(p *models.ParamPostList) ([]string, error) {
orderKey := getRedisKey(KeyPostTimeZSet)
if p.Order == models.OderScore {
orderKey = getRedisKey(KeyPostScoreZSet)
}
// 使用 zinterstore 把分区的帖子set与帖子分数的 zset 生成一个新的zset
// 针对新的zset 按之前的逻辑取数据
// 社区的key
cKey := getRedisKey(KeyCommunitySetPF + strconv.Itoa(int(p.CommunityID)))
// 利用缓存key减少zinterstore执行的次数
key := orderKey + strconv.Itoa(int(p.CommunityID))
ctx := context.Background()
if client.Exists(ctx, key).Val() < 1 {
// 不存在,需要计算
pipeline := client.Pipeline()
pipeline.ZInterStore(ctx, key, &redis.ZStore{
Keys: []string{cKey, orderKey},
Aggregate: "MAX",
}) // zinterstore 计算
pipeline.Expire(ctx, key, 60*time.Second) // 设置超时时间
_, err := pipeline.Exec(ctx)
if err != nil {
return nil, err
}
}
// 存在的话就直接根据key查询ids
return getIDsFormKey(key, p.Page, p.Size)
}
func GetPostIDsInOrder(p *models.ParamPostList) ([]string, error) {
// 获取排序顺序
key := getRedisKey(KeyPostTimeZSet)
if p.Order == models.OderScore {
key = getRedisKey(KeyPostScoreZSet)
}
return getIDsFormKey(key, p.Page, p.Size)
}
func GetPostVoteData(ids []string) (data []int64, err error){
ctx := context.TODO()
data = make([]int64, 0, len(ids))
// 使用pipeline一次发送剁掉指令减少RTT
pipeline := client.Pipeline()
for _, id := range ids {
key := getRedisKey(KeyPostVotedZSetPF+id)
pipeline.ZCount(ctx, key, "1", "1")
}
cmders, err := pipeline.Exec(ctx)
if err != nil {
return nil, err
}
for _, cmder := range cmders {
v := cmder.(*redis.IntCmd).Val()
data = append(data, v)
}
return
}
func getIDsFormKey(key string, page, size int64) ([]string, error) {
// 确定所有起点
start := (page -1) * size
end := start + size -1
// 查询,按分数从大到小
ctx := context.TODO()
return client.ZRevRange(ctx, key, start, end).Result()
}
package redis
// redis key注意使用命名空间方式方便查询和拆分
const (
Prefix = "ginblog:"
KeyPostTimeZSet = "post:time" // Zset 发帖时间为分数
KeyPostScoreZSet = "post:score" // Zset 帖子评分为分数
KeyPostVotedZSetPF = "post:voted" // zset 记录用户及投票类型,参数是post id
KeyCommunitySetPF = "community:" // set 保存每个分区下帖子的id
)
func getRedisKey(key string) string {
return Prefix + key
}
在mysql中,也用了许多方法,目标都是为了获取帖子的相关信息。
// GetPostList 查询帖子列表函数
func GetPostList(page, size int64) (posts []*models.Post, err error) {
sqlStr := `select
post_id, title, content, author_id, community_id, create_time
from post
ORDER BY create_time
DESC
limit ?,?
`
posts = make([]*models.Post, 0, 2) // 不要写成make([]*models.Post, 2)
err = db.Select(&posts, sqlStr, (page-1)*size, size)
return
}
// GetPostListByIDs 根据给定的id列表查询帖子数据
func GetPostListByIDs(ids []string) (postList []*models.Post, err error) {
sqlStr := "select post_id, title, content, author_id, community_id, create_time from post where post_id in (?) order by FIND_IN_SET(post_id, ?)"
// sqlx.In批量生成
query, args, err := sqlx.In(sqlStr, ids, strings.Join(ids, ","))
if err != nil {
return nil, err
}
err = db.Select(&postList, query, args...) // 一定要加...
return
}
// GetUserById 根据id获取用户信息
func GetUserById(uid int64) (user *models.User, err error) {
user = new(models.User)
sqlStr := `select user_id, username from user where user_id = ?`
err = db.Get(user, sqlStr, uid)
return
}
func GetCommunityDetailByID(id int64) (detail *models.CommunityDetail, err error) {
detail = new(models.CommunityDetail)
sqlStr := "select community_id, community_name, introduction, create_time from community where community_id = ?"
if err := db.Get(detail, sqlStr, id); err != nil {
if err == sql.ErrNoRows {
err = ErrInvalidID
}
}
fmt.Println("%v", detail)
return detail, err
}
获取指定帖子
Get的Query获取id参数,然后查询,逻辑比较简单
func GetPostDetailHandler(c *gin.Context) {
// 1. 获取参数(从URL中获取帖子的id)
pidStr := c.Param("id")
pid, err := strconv.ParseInt(pidStr, 10, 64)
if err != nil {
zap.L().Error("get post detail with invalid param", zap.Error(err))
ResponseError(c, CodeInvalidParam)
return
}
// 2. 根据id取出帖子数据(查数据库)
data, err := logic.GetPostById(pid)
if err != nil {
zap.L().Error("logic.GetPostById(pid) failed", zap.Error(err))
ResponseError(c, CodeServerBusy)
return
}
// 3. 返回响应
ResponseSuccess(c, data)
}
logic层 :
func GetPostById(pid int64) (data *models.ApiPostDetail, err error) {
// 查询并组合我们接口想用的数据
post, err := mysql.GetPostById(pid)
if err != nil {
zap.L().Error("mysql.GetPostById(pid) failed",
zap.Int64("pid", pid),
zap.Error(err))
return
}
// 根据作者id查询作者信息
user, err := mysql.GetUserById(post.AuthorID)
if err != nil {
zap.L().Error("mysql.GetUserById(post.AuthorID) failed",
zap.Int64("author_id", post.AuthorID),
zap.Error(err))
return
}
// 根据社区id查询社区详细信息
community, err := mysql.GetCommunityDetailByID(post.CommunityID)
if err != nil {
zap.L().Error("mysql.GetUserById(post.AuthorID) failed",
zap.Int64("community_id", post.CommunityID),
zap.Error(err))
return
}
// 接口数据拼接
data = &models.ApiPostDetail{
AuthorName: user.Username,
Post: post,
CommunityDetail: community,
}
return
}
用到的数据库查询方法 :
func GetPostById(pid int64) (post *models.Post, err error) {
post = new(models.Post)
sqlStr := `select
post_id, title, content, author_id, community_id, create_time
from post
where post_id = ?
`
err = db.Get(post, sqlStr, pid)
return
}
func GetUserById(uid int64) (user *models.User, err error) {
user = new(models.User)
sqlStr := `select user_id, username from user where user_id = ?`
err = db.Get(user, sqlStr, uid)
return
}
func GetCommunityDetailByID(id int64) (detail *models.CommunityDetail, err error) {
detail = new(models.CommunityDetail)
sqlStr := "select community_id, community_name, introduction, create_time from community where community_id = ?"
if err := db.Get(detail, sqlStr, id); err != nil {
if err == sql.ErrNoRows {
err = ErrInvalidID
}
}
fmt.Println("%v", detail)
return detail, err
}
投票模块
在这里,我们引入了鉴权中间件,因为没有登录是不能发帖和投票的
v1.Use(middlewares.JWTAuthMiddleware()) // 应用JWT认证中间件
{
v1.POST("/post", controller.CreatePostHandler)
// 投票
v1.POST("/vote", controller.PostVoteController)
}
具体实现在middlewares的auth.go中,这里我们规定把token放在头部 Authorization 中,并且以 Bearer 开头,解析成功后存入gin.Context的上下文中
package middlewares
import (
"GinBlog/controller"
"GinBlog/pkg/jwt"
"strings"
"github.com/gin-gonic/gin"
)
// JWTAuthMiddleware 基于JWT的认证中间件
func JWTAuthMiddleware() func(c *gin.Context) {
return func(c *gin.Context) {
// 客户端携带Token有三种方式 1.放在请求头 2.放在请求体 3.放在URI
// 这里假设Token放在Header的Authorization中,并使用Bearer开头
// Authorization: Bearer xxxxxxx.xxx.xxx / X-TOKEN: xxx.xxx.xx
// 这里的具体实现方式要依据你的实际业务情况决定
authHeader := c.Request.Header.Get("Authorization")
if authHeader == "" {
controller.ResponseError(c, controller.CodeNeedLogin)
c.Abort()
return
}
// 按空格分割
parts := strings.SplitN(authHeader, " ", 2)
if !(len(parts) == 2 && parts[0] == "Bearer") {
controller.ResponseError(c, controller.CodeInvalidToken)
c.Abort()
return
}
// parts[1]是获取到的tokenString,我们使用之前定义好的解析JWT的函数来解析它
mc, err := jwt.ParseToken(parts[1])
if err != nil {
controller.ResponseError(c, controller.CodeInvalidToken)
c.Abort()
return
}
// 将当前请求的userID信息保存到请求的上下文c上
c.Set(controller.CtxUserIDKey, mc.UserID)
c.Next() // 后续的处理请求的函数中 可以用过c.Get(CtxUserIDKey) 来获取当前请求的用户信息
}
}
发帖
func CreatePostHandler(c *gin.Context) {
// 参数校验 ShouldBindJSON()
p := new(models.Post)
if err := c.ShouldBindJSON(p); err != nil {
zap.L().Error("create post with invalid param")
ResponseError(c, CodeInvalidParam)
return
}
// 从c中取到发帖子的id
userID, err := GetCurrenUser(c)
if err != nil {
ResponseError(c, CodeNeedLogin)
return
}
p.AuthorID = userID
// 创建帖子
if err := logic.CreatePost(p); err != nil {
zap.L().Error("logic.CreatePost failed", zap.Error(err))
ResponseError(c, CodeServerBusy)
return
}
// 返回响应
ResponseSuccess(c, nil)
}
logic层用雪花算法生成帖子id
func CreatePost(p *models.Post) error {
// 生成post id
p.ID = snowflake.GenID()
err := mysql.CreatePost(p)
if err != nil {
return err
}
err = redis.CreatePost(p.ID, p.CommunityID)
return err
}
数据库mysql和redis都要存,redis中的TxPipeline
与 Pipeline
类似,但确保操作的原子性。它将命令包装在 MULTI
和 EXEC
命令中,确保所有命令要么全部执行,要么全部不执行
// Mysql
func CreatePost(p *models.Post) (err error) {
sqlStr := `insert into post(post_id, title, content, author_id, community_id) values (?, ?, ?, ?, ?)`
_, err = db.Exec(sqlStr, p.ID, p.Title, p.Content, p.AuthorID, p.CommunityID)
return err
}
//Redis
func CreatePost(postID, communityID int64) error {
pipeline := client.TxPipeline()
ctx := context.Background()
// 帖子时间
pipeline.ZAdd(ctx, getRedisKey(KeyPostTimeZSet), &redis.Z{
Score: float64(time.Now().Unix()),
Member: postID,
})
// 帖子分数
pipeline.ZAdd(ctx, getRedisKey(KeyPostScoreZSet), &redis.Z{
Score: float64(time.Now().Unix()),
Member: postID,
})
// 更新:把帖子id加到社区的set
cKey := getRedisKey(KeyCommunitySetPF + strconv.Itoa(int(communityID)))
pipeline.SAdd(ctx, cKey, postID)
_, err := pipeline.Exec(ctx)
return err
}
投票
type ParamVoteData struct {
PostID string `json:"post_id" binding:"required"`
// 赞成1反对-1取消投票0,validator的oneof限定
Direction int8 `json:"direction,string" binding:"oneof=1 0 -1" `
}
获取投票的类型,然后获取当前都票的用户,转到logic处理
func PostVoteController(c *gin.Context) {
// 获取参数
p := new(models.ParamVoteData)
if err := c.ShouldBindJSON(p); err != nil {
errs, ok := err.(validator.ValidationErrors) // 类型断言
if !ok {
ResponseError(c, CodeInvalidParam)
return
}
errData := removeTopStruct(errs.Translate(trans))
ResponseErrorWithMsg(c, CodeInvalidParam, errData)
return
}
// 逻辑处理
userID, err := GetCurrenUser(c)
if err != nil {
ResponseError(c, CodeNeedLogin)
return
}
if err := logic.PostVote(userID, p); err != nil {
zap.L().Error("logic.PostVote() failed", zap.Error(err))
ResponseError(c, CodeServerBusy)
return
}
ResponseSuccess(c, nil)
}
这里转到redis进行,投票的底层实现就是对排序分数的修改。日志不打也可以。
func PostVote(userID int64, p *models.ParamVoteData) error{
zap.L().Debug("VoteForPost",
zap.Int64("userID", userID),
zap.String("postID", p.PostID),
zap.Int8("direction", p.Direction))
return redis.VoteForPost(strconv.Itoa(int(userID)), p.PostID, float64(p.Direction))
}
首先检查帖子的投票是否已经超过一周的有效期限,如果超时,则返回错误。如果投票有效,函数会创建一个 Redis 事务管道。接着,根据用户投票的值,函数会执行不同的操作:
- 如果用户投票值为0,表示用户要取消对帖子的投票,函数会从特定有序集合中移除该用户的ID。
- 如果用户投票值非0,函数会检查用户是否已经对该帖子投过票,并且比较新旧投票值:
- 如果新旧投票值相同,表示投票重复,函数返回错误。
- 如果投票值不同,函数计算两者的差值,并根据差值更新帖子的得分(增加或减少),同时更新用户对该帖子的投票记录。
最后,函数执行事务管道中的所有命令,并处理可能出现的错误。这个过程确保了投票操作的原子性,即所有更新要么同时成功,要么同时失败,从而保证了数据的一致性。
func VoteForPost(userID, postID string, value float64) error {
// 判断投票情况
ctx := context.Background()
postTime := client.ZScore(ctx, getRedisKey(KeyPostTimeZSet), postID).Val()
if float64(time.Now().Unix()) - postTime > oneWeek {
return ErrVoteTimeExpire
}
// 更新分数
pipeline := client.TxPipeline()
// 记录用户投票
if value == 0 {
pipeline.ZRem(ctx, getRedisKey(KeyPostVotedZSetPF+postID), userID)
} else {
oldVal := client.ZScore(ctx, getRedisKey(KeyPostVotedZSetPF+postID), userID).Val() // 查询投票记录
var op float64
if value == oldVal {
return ErrVoteRepeat
}
if value > oldVal {
op = 1
} else {
op = -1
}
diff := math.Abs(oldVal - value)
pipeline.ZIncrBy(ctx, getRedisKey(KeyPostScoreZSet), op*diff*scorePerVote, postID)
pipeline.ZAdd(ctx, getRedisKey(KeyPostVotedZSetPF+postID), &redis.Z{
Score: value,
Member: userID,
})
}
_, err := pipeline.Exec(ctx)
return err
}
项目开发及部署
开发中使用air
air可以让gin框架在开发中运行,对应修改自动重启。.air.conf如下:
# [Air](https://github.com/cosmtrek/air) TOML 格式的配置文件
# 工作目录
# 使用 . 或绝对路径,请注意 `tmp_dir` 目录必须在 `root` 目录下
root = "."
tmp_dir = "tmp"
[build]
# 只需要写你平常编译使用的shell命令。你也可以使用 `make`
# Windows平台示例: cmd = "go build -o ./tmp/main.exe ."
cmd = "go build -o ./tmp/main.exe ."
# 由`cmd`命令得到的二进制文件名
# Windows平台示例:bin = "tmp/main.exe"
bin = "tmp/main.exe"
# 自定义执行程序的命令,可以添加额外的编译标识例如添加 GIN_MODE=release
full_bin = "tmp/main.exe config/config.yaml"
# Windows平台示例:full_bin = "./tmp/main.exe"
# Linux平台示例:full_bin = "APP_ENV=dev APP_USER=air ./tmp/main.exe"
#full_bin = "./tmp/main.exe"
# 监听以下文件扩展名的文件.
include_ext = ["go", "tpl", "tmpl", "html"]
# 忽略这些文件扩展名或目录
exclude_dir = ["assets", "tmp", "vendor", "frontend/node_modules"]
# 监听以下指定目录的文件
include_dir = []
# 排除以下文件
exclude_file = []
# 如果文件更改过于频繁,则没有必要在每次更改时都触发构建。可以设置触发构建的延迟时间
delay = 1000 # ms
# 发生构建错误时,停止运行旧的二进制文件。
stop_on_error = true
# air的日志文件名,该日志文件放置在你的`tmp_dir`中
log = "air_errors.log"
[log]
# 显示日志时间
time = true
[color]
# 自定义每个部分显示的颜色。如果找不到颜色,使用原始的应用程序日志。
main = "magenta"
watcher = "cyan"
build = "yellow"
runner = "green"
[misc]
# 退出时删除tmp目录
clean_on_exit = true
makefile的编写
makefile用于在linux部署,可能存在错误(因为项目并未在linux下测试过)
.PHONY: all build run gotool clean help
BINARY="GinBlog"
all: gotool build
build:
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags "-s -w" -o ./bin/${BINARY}
run:
@go run ./main.go conf/config.yaml
gotool:
go fmt ./
go vet ./
clean:
@if [ -f ${BINARY} ] ; then rm ${BINARY} ; fi
help:
@echo "make - 格式化 Go 代码, 并编译生成二进制文件"
@echo "make build - 编译 Go 代码, 生成二进制文件"
@echo "make run - 直接运行 Go 代码"
@echo "make clean - 移除二进制文件和 vim swap files"
@echo "make gotool - 运行 Go 工具 'fmt' and 'vet'"
docker
dockerfile生成image
FROM golang:alpine AS builder
# 为我们的镜像设置必要的环境变量
ENV GO111MODULE=on \
GOPROXY=https://goproxy.cn,direct \
CGO_ENABLED=0 \
GOOS=linux \
GOARCH=amd64
# 移动到工作目录:/build
WORKDIR /build
# 将代码复制到容器中
COPY . .
# 下载依赖信息
RUN go mod download
# 将我们的代码编译成二进制可执行文件 bubble
RUN go build -o ginblog .
###################
# 接下来创建一个小镜像
###################
FROM debian:stretch-slim
# 从builder镜像中把脚本拷贝到当前目录
COPY ./wait-for.sh /
# 从builder镜像中把静态文件拷贝到当前目录
COPY ./templates /templates
COPY ./static /static
# 从builder镜像中把配置文件拷贝到当前目录
COPY ./config /config
# 从builder镜像中把/dist/app 拷贝到当前目录
COPY --from=builder /build/ginblog /
EXPOSE 8808
RUN echo "" > /etc/apt/sources.list; \
echo "deb http://mirrors.aliyun.com/debian buster main" >> /etc/apt/sources.list ; \
echo "deb http://mirrors.aliyun.com/debian-security buster/updates main" >> /etc/apt/sources.list ; \
echo "deb http://mirrors.aliyun.com/debian buster-updates main" >> /etc/apt/sources.list ; \
set -eux; \
apt-get update; \
apt-get install -y \
--no-install-recommends \
netcat; \
chmod 755 wait-for.sh
## 需要运行的命令
#ENTRYPOINT ["/ginblog", "config/config.yaml"]
docker-compose启动
# yaml 配置
services:
mysql:
image: "oilrmutp57/mysql5.7:1.1"
ports:
- "3306:3306"
command: "--default-authentication-plugin=mysql_native_password --init-file /data/application/init.sql"
environment:
MYSQL_ROOT_PASSWORD: "123"
MYSQL_DATABASE: "ginblog"
MYSQL_PASSWORD: "123"
volumes:
- ./init.sql:/data/application/init.sql
redis:
image: "redis:7.4.1"
ports:
- "6379:6379"
ginblog:
build: .
command: sh -c "./wait-for.sh mysql:3306 redis:6379 -- ./ginblog ./config/config.yaml"
depends_on:
- mysql
- redis
ports:
- "8808:8808"
wait-for.sh
#!/bin/bash
TIMEOUT=30
QUIET=0
ADDRS=()
echoerr() {
if [ "$QUIET" -ne 1 ]; then printf "%s\n" "$*" 1>&2; fi
}
usage() {
exitcode="$1"
cat << USAGE >&2
client:
$cmdname host:port [host:port] [host:port] [-t timeout] [-- command args]
-q | --quiet Do not output any status messages
-t TIMEOUT | --timeout=timeout Timeout in seconds, zero for no timeout
-- COMMAND ARGS Execute command with args after the test finishes
USAGE
exit "$exitcode"
}
wait_for() {
results=()
for addr in ${ADDRS[@]}
do
HOST=$(printf "%s\n" "$addr"| cut -d : -f 1)
PORT=$(printf "%s\n" "$addr"| cut -d : -f 2)
result=1
for i in `seq $TIMEOUT` ; do
nc -z "$HOST" "$PORT" > /dev/null 2>&1
result=$?
if [ $result -ne 0 ] ; then
sleep 1
continue
fi
break
done
results=(${results[@]} $result)
done
num=${#results[@]}
for result in ${results[@]}
do
if [ $result -eq 0 ] ; then
num=`expr $num - 1`
fi
done
if [ $num -eq 0 ] ; then
if [ $# -gt 0 ] ; then
exec "$@"
fi
exit 0
fi
echo "Operation timed out" >&2
exit 1
}
while [ $# -gt 0 ]
do
case "$1" in
*:* )
ADDRS=(${ADDRS[@]} $1)
shift 1
;;
-q | --quiet)
QUIET=1
shift 1
;;
-t)
TIMEOUT="$2"
if [ "$TIMEOUT" = "" ]; then break; fi
shift 2
;;
--timeout=*)
TIMEOUT="${1#*=}"
shift 1
;;
--)
shift
break
;;
--help)
usage 0
;;
*)
echoerr "Unknown argument: $1"
usage 1
;;
esac
done
if [ "${#ADDRS[@]}" -eq 0 ]; then
echoerr "Error: you need to provide a host and port to test."
usage 2
fi
wait_for "$@"
总结
这是Gin框架的Blog小demo,基本实现了登录注册,社区,发帖,投票功能,一些拓展的功能没有实现,但是其核心对于入门Gin来说,是不错的练习。