Bootstrap

gorm的upsert操作不同字段

场景:

“INSERT INTO … ON DUPLICATE KEY UPDATE”的应用,在 UPDATE 时不能更新字段 f_create_uid 和 f_create_time 的值,而必须更新 f_update_uid 和 f_update_time 的值。关键点在于指定 UPDATE 不更新的字段列表,实现依赖 gorm 的 tag,但如果 struct 的 field 名同表的 field 名,这没有此依赖。

示例表:

DROP TABLE IF EXISTS TableA;
CREATE TABLE TableA (
    f_id INT UNSIGNED NOT NULL AUTO_INCREMENT,
    f_name VARCHAR(25) NOT NULL,
    f_address VARCHAR(100) NOT NULL,
    f_create_uid VARCHAR(20) NOT NULL,
    f_update_uid VARCHAR(20) NOT NULL,
    f_create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
    f_update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    PRIMARY KEY (f_id),
    UNIQUE KEY (f_name)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4;

使用 (sql2struct)[https://github.com/eyjian/sql2struct] 工具生成的 struct:

type TableA struct {
    Id uint32 `gorm:"column:f_id;primaryKey;autoIncrement" json:"id" db:"f_id" form:"id"`
    Name string `gorm:"column:f_name" json:"name" db:"f_name" form:"name"`
    Address string `gorm:"column:f_address" json:"address" db:"f_address" form:"address"`
    CreateUid string `gorm:"column:f_create_uid" json:"create_uid" db:"f_create_uid" form:"create_uid"`
    UpdateUid string `gorm:"column:f_update_uid" json:"update_uid" db:"f_update_uid" form:"update_uid"`
    CreateTime time.Time `gorm:"column:f_create_time" json:"create_time" db:"f_create_time" form:"create_time"`
    UpdateTime time.Time `gorm:"column:f_update_time" json:"update_time" db:"f_update_time" form:"update_time"`
}

表的“INSERT INTO … ON DUPLICATE KEY UPDATE”操作:

package main

import (
    "flag"
    "fmt"
    "gorm.io/driver/mysql"
    "gorm.io/gorm"
    "gorm.io/gorm/clause"
    "os"
    "reflect"
    "strings"
    "time"
)

/*
DROP TABLE IF EXISTS TableA;
CREATE TABLE TableA (
    f_id INT UNSIGNED NOT NULL AUTO_INCREMENT,
    f_name VARCHAR(25) NOT NULL,
    f_address VARCHAR(100) NOT NULL,
    f_create_uid VARCHAR(20) NOT NULL,
    f_update_uid VARCHAR(20) NOT NULL,
    f_create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
    f_update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    PRIMARY KEY (f_id),
    UNIQUE KEY (f_name)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4;
*/

// TableA 本例依赖 gorm 的 tag
type TableA struct {
    Id         uint32    `gorm:"column:f_id;primaryKey;autoIncrement" json:"id" db:"f_id" form:"id"`
    Name       string    `gorm:"column:f_name" json:"name" db:"f_name" form:"name"`
    Address    string    `gorm:"column:f_address" json:"address" db:"f_address" form:"address"`
    CreateUid  string    `gorm:"column:f_create_uid" json:"create_uid" db:"f_create_uid" form:"create_uid"`
    UpdateUid  string    `gorm:"column:f_update_uid" json:"update_uid" db:"f_update_uid" form:"update_uid"`
    CreateTime time.Time `gorm:"column:f_create_time" json:"create_time" db:"f_create_time" form:"create_time"`
    UpdateTime time.Time `gorm:"column:f_update_time" json:"update_time" db:"f_update_time" form:"update_time"`
}

func (t *TableA) TableName() string {
    return "TableA"
}

var (
    dsn = flag.String("dsn", "", "dbuser:dbpassword@tcp(dbhost:dbport)/dbname?charset=utf8mb4&parseTime=True&loc=Local")
)

func main() {
    // 命令行参数解析
    flag.Parse()

    // 命令行参数检查
    if *dsn == "" {
        fmt.Printf("parameter[-dsn] is not set\n")
        flag.Usage()
        os.Exit(1)
    }

    // 连接数据库
    tx, err := gorm.Open(mysql.Open(*dsn), &gorm.Config{})
    if err != nil {
        fmt.Printf("open mysql error: %s\n", err.Error())
        os.Exit(1)
    }
    tx = tx.Debug() //.Table("TableA") // 这里不能有 Table 调用,否则会导致 Omit 影响后面的 Find,导致 Find 不是全部的字段,即不是”SELECT * FROM“

    // 获取 DsFundBatchStudent 结构体的所有字段
    typeOfModel := reflect.TypeOf((*TableA)(nil)).Elem()
    numFields := typeOfModel.NumField()           // 取得总的字段数
    updateColumns := make([]string, 0, numFields) // 存放所有需更新的字段

    // 遍历所有字段,排除不需要更新的字段
    for i := 0; i < numFields; i++ {
        field := typeOfModel.Field(i)
        fieldName := field.Name
        if fieldName != "Id" && fieldName != "CreateUid" && fieldName != "CreateTime" && fieldName != "UpdateTime" {
            // 使用结构体标签(tag)获取字段的列名
            columnName := field.Tag.Get("gorm")
            if strings.HasPrefix(columnName, "column:") {
                columnName = strings.TrimPrefix(columnName, "column:")
            }
            updateColumns = append(updateColumns, columnName)
        }
    }
    fmt.Printf("updateColumns=>%+v\n", updateColumns)

    aa10 := []*TableA{
        &TableA{
            Name:      "zhangsan",
            Address:   "shenzhen",
            CreateUid: "lisi",
            UpdateUid: "lisi",
        },
        &TableA{
            Name:      "wangwu",
            Address:   "nanjing",
            CreateUid: "lisi",
            UpdateUid: "lisi",
        },
    }
    // 插入时忽略的字段
    db := tx
    db = db.Omit("f_id", "f_update_time", "f_create_time")
    db = db.Clauses(clause.OnConflict{
        //Columns: []clause.Column{{Name: "f_id"}}, // 指定冲突检查的列
        DoUpdates: clause.AssignmentColumns(updateColumns), // 指定更新的列
    })
    err = db.Create(aa10).Error
    if err != nil {
        fmt.Printf("create aa1 error: %s\n", err.Error())
        os.Exit(1)
    }

    var a11 []*TableA
    db = tx // .Table("TableA") // 这里的 Table 调用可有可无
    err = db.Order("f_name").Find(&a11).Error
    if err != nil {
        fmt.Printf("find aa11 error: %s\n", err.Error())
        os.Exit(1)
    }
    for _, a110 := range a11 {
        fmt.Printf("%+v\n", *a110)
    }

    aa20 := []*TableA{
        &TableA{
            Name:      "zhangsan",
            Address:   "hangzhou",
            CreateUid: "xiaoming",
            UpdateUid: "xiaoming",
        },
        &TableA{
            Name:      "zhouba",
            Address:   "guangzhou",
            CreateUid: "xiaoming",
            UpdateUid: "xiaoming",
        },
    }
    // 插入时忽略的字段
    db = tx
    db = db.Omit("f_id", "f_update_time", "f_create_time")
    db = db.Clauses(clause.OnConflict{
        //Columns: []clause.Column{{Name: "f_id"}}, // 指定冲突检查的列
        DoUpdates: clause.AssignmentColumns(updateColumns), // 指定更新的列
    })
    err = db.Create(aa20).Error
    if err != nil {
        fmt.Printf("create aa2 error: %s\n", err.Error())
        os.Exit(1)
    }

    var a21 []*TableA
    db = tx //.Table("TableA") // 这里的 Table 调用可有可无
    err = db.Order("f_name").Find(&a21).Error
    if err != nil {
        fmt.Printf("find a21 error: %s\n", err.Error())
        os.Exit(1)
    }
    for _, a210 := range a11 {
        fmt.Printf("%+v\n", *a210)
    }

    // 当本程序反复执行时,由于只有 zhangsan 的 UpdateUid 的交替变化,
    // 故只会观察到 zhangsan 一行的 f_update_time 在一直变化,
    // wangwu 和 zhouba 执行一次后保持不再变化,因为没有任何字段值发生变化。
}
;