4

Gorm源码学习-创建行记录 - Amos01

 1 year ago
source link: https://www.cnblogs.com/amos01/p/16913681.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

Gorm源码学习系列

此文是Gorm源码学习系列的第二篇,主要梳理下通过Gorm创建表的流程。

2. 创建行记录代码示例

gorm提供了以下几个接口来创建行记录

  • 一次创建一行 func (db *DB) Create(value interface{}) (tx *DB)
  • 批量创建 func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB)
  • 数据库不存在主键时创建,存在时更新 func (db *DB) Save(value interface{}) (tx *DB)

详细请看教程及源码finisher_api.go,这里使用func (db *DB) Create(value interface{}) (tx *DB)来说明创建行记录等大致流程。

2.1 声明模型

type Stu struct { ID int64 `gorm:"column:id; primary_key" json:"id"` Age int64 `gorm:"column:age;"` Height int64 `gorm:"column:height;"` Weight int64 `gorm:"column:weight;"`} // 设置表名func (Stu) TableName() string { return "t_student"}

模型代码的主要用途如下,

  • 申明的表中有哪些列及每列的名称、特性等,如gorm标签指定每个字断对于的表的列名
  • 通过实现Tabler接口指定了固定的表名,接口定义如下
type Tabler interface { TableName() string}

关于模型定义中更多的约定和约束等,请看教程

出于分表等业务场景,我们并不希望固定模型等表名,gorm提供了func (db *DB) Table(name string, args ...interface{}) (tx *DB)等方法

来动态指定表名,详情请看教程

2.2 创建行

func main() { // 数据库连接, 具体查看https://www.cnblogs.com/amos01/p/16890747.html 连接数据库代码示例 db, _ := dbOpen() // 打开调试模式、会打印DML db = db.Debug() stu := &Stu{ Age: 18, Height: 185, Weight: 70, } db = db.Create(stu) fmt.Printf("Error:%v ID:%v RowsAffected:%v\n", db.Error, stu.ID, db.RowsAffected)}

代码输出如下

$ go run main.go2022/12/11 14:59:59 /Users/zbw/workspace/test/main.go:33[1.910ms] [rows:1] INSERT INTO `t_student` (`age`,`height`,`weight`) VALUES (18,185,70)Error:<nil> ID:1027 RowsAffected:1

从代码输出可以看,行记录的ID为1027,连接数据库查询,结果如下。

mysql> select * from t_student where id = 1027\G*************************** 1. row *************************** id: 1027 age: 18height: 185weight: 701 row in set (0.01 sec)

因此,我们带着以下问题来梳理下Gorm创建行记录的流程

  • 如何从model到DML语句的
  • 如何将ID写入到model的 

3. 从Model到DML

func (db *DB) Create(value interface{}) (tx *DB)的实现如下

// Create inserts value, returning the inserted data's primary key in value's idfunc (db *DB) Create(value interface{}) (tx *DB) { if db.CreateBatchSize > 0 { return db.CreateInBatches(value, db.CreateBatchSize) } tx = db.getInstance() tx.Statement.Dest = value return tx.callbacks.Create().Execute(tx)}

func (p *processor) Execute(db *DB) *DB 的实现比较长,具体代码见github

总结下来,做了两件主要的事情,

  • 解析model获取表名、每列的定义等
  • 执行钩子函数以及创建行函数

3.1 数据结构理解

  • gorm.Statement

查看gorm.Statement代码

// Statement statementtype Statement struct { *DB TableExpr *clause.Expr Table string // 表名 Model interface{} // model定义 Unscoped bool Dest interface{} // model的另外一种表达,如map ReflectValue reflect.Value Clauses map[string]clause.Clause BuildClauses []string Distinct bool Selects []string // selected columns Omits []string // omit columns Joins []join Preloads map[string][]interface{} Settings sync.Map ConnPool ConnPool // 数据库连接 Schema *schema.Schema // 表结构化信息 Context context.Context RaiseErrorOnNotFound bool SkipHooks bool SQL strings.Builder // 最终的DML语句 Vars []interface{} // DML语句的参数值 CurDestIndex int // 批量创建/更新时,gorm当前操作的数组/slice的下标 attrs []interface{} assigns []interface{} scopes []func(*DB) *DB}
  • schema.Schem

查看schema.Schema代码

type Schema struct { Name string ModelType reflect.Type Table string // 表名 PrioritizedPrimaryField *Field DBNames []string // 表每列的名字 PrimaryFields []*Field PrimaryFieldDBNames []string // 表的主键列明 Fields []*Field // gorm自定义的model每个字短 FieldsByName map[string]*Field FieldsByDBName map[string]*Field FieldsWithDefaultDBValue []*Field // fields with default value assigned by database Relationships Relationships CreateClauses []clause.Interface // 创建行的子句 QueryClauses []clause.Interface UpdateClauses []clause.Interface DeleteClauses []clause.Interface BeforeCreate, AfterCreate bool BeforeUpdate, AfterUpdate bool BeforeDelete, AfterDelete bool BeforeSave, AfterSave bool AfterFind bool err error initialized chan struct{} namer Namer cacheStore *sync.Map}
  • schema.Field

查看schema.Field代码

// Field is the representation of model schema's fieldtype Field struct { Name string // model的字段名 DBName string // 对应表的列名 BindNames []string DataType DataType GORMDataType DataType PrimaryKey bool AutoIncrement bool AutoIncrementIncrement int64 Creatable bool Updatable bool Readable bool AutoCreateTime TimeType AutoUpdateTime TimeType HasDefaultValue bool DefaultValue string DefaultValueInterface interface{} NotNull bool Unique bool Comment string Size int Precision int Scale int IgnoreMigration bool FieldType reflect.Type // 反射类型 IndirectFieldType reflect.Type // 反射类型 StructField reflect.StructField // model字段信息 Tag reflect.StructTag // tag TagSettings map[string]string Schema *Schema EmbeddedSchema *Schema OwnerSchema *Schema ReflectValueOf func(context.Context, reflect.Value) reflect.Value // 通过反射获取该字段的反射对象 ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool) // 通过反射获取该字段的值 get方法 Set func(context.Context, reflect.Value, interface{}) error // 通过反射设置该字段的值 set方法 Serializer SerializerInterface NewValuePool FieldNewValuePool}
  • clause.Interfaceclause.Clause

gorm定义了多种clause,包括

查看clause.Interface代码

// Interface clause interfacetype Interface interface { Name() string Build(Builder) MergeClause(*Clause)}

查看clause.Clause代码

// Clausetype Clause struct { Name string // WHERE BeforeExpression Expression AfterNameExpression Expression AfterExpression Expression Expression Expression Builder ClauseBuilder}

3.2 解析Model

通过调用stmt.Parse(stmt.Model)进行model解析

stmt.Parse(stmt.Model)会调用到函数func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) 进行解析。

详细代码见schema.go,下面列举重要的几个点。

  • 通过反射判断dest interface{}是否为reflect.Struct
  • 通过接口获取表名,其中stu实现了Tabler接口
// 获取表名 modelValue := reflect.New(modelType) tableName := namer.TableName(modelType.Name()) if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { tableName = tabler.TableName(namer) } if en, ok := namer.(embeddedNamer); ok { tableName = en.Table } if specialTableName != "" && specialTableName != tableName { tableName = specialTableName }
  • 解析model每个字段
// 通过反射获取每个字段for i := 0; i < modelType.NumField(); i++ { if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { // 解析每个字段 if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil { schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...) } else { schema.Fields = append(schema.Fields, field) } }}
  • 放到map方便查找,并且通过func (field *Field) setupValuerAndSetter()初始化每个Field的ReflectValueOfValueOfSet方法。
    for _, field := range schema.Fields { if field.DBName == "" && field.DataType != "" { field.DBName = namer.ColumnName(schema.Table, field.Name) } if field.DBName != "" { // nonexistence or shortest path or first appear prioritized if has permission if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { if _, ok := schema.FieldsByDBName[field.DBName]; !ok { schema.DBNames = append(schema.DBNames, field.DBName) } // gorm tag字段到field的映射 schema.FieldsByDBName[field.DBName] = field // model 字段到field的映射 schema.FieldsByName[field.Name] = field if v != nil && v.PrimaryKey { for idx, f := range schema.PrimaryFields { if f == v { schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) } } } // 主键 if field.PrimaryKey { schema.PrimaryFields = append(schema.PrimaryFields, field) } } } if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { schema.FieldsByName[field.Name] = field } // 挂载字段的set方法和get方法 field.setupValuerAndSetter()}

值得一提的是,每个model解析后的结果是一致,可以将结果解析的结构缓存下来,并且通过chan来解决并发的问题。

解析model之后,通过process获取到钩子函数及创建行的函数,具体代码见Github

for _, f := range p.fns { f(db) }

3.3 执行钩子函数及创建行的函数

创建行的函数及对应的钩子函数位于create.go

  • 创建行记录
if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.Build(db.Statement.BuildClauses...)}

 这里插入两个clause.Clause,分别为clause.Insert以及clause.Values,然后调用这两种clause.Clausebuild方法生成SQL语句。

首先,看下ConvertToCreateValues的实现,这里只截取部分代码

values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}// 获取每一列的名字for _, db := range stmt.Schema.DBNames { if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) { values.Columns = append(values.Columns, clause.Column{Name: db}) } }} // 获取每一列对应的值switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: case reflect.Struct: values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] // func (field *Field) setupValuerAndSetter() 挂载的方法 if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero { if field.DefaultValueInterface != nil { values.Values[0][idx] = field.DefaultValueInterface stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface)) } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } } else if field.AutoUpdateTime > 0 && updateTrackTime { stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime)) values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue) } }

通过ConvertToCreateValues获取了每一列的名称及对应的值。

接下来,看clause.ClauseSQL语句的过程。

遍历加入clause,此时分别为clause.Insert以及clause.Values

// Build build sql with clauses namesfunc (stmt *Statement) Build(clauses ...string) { var firstClauseWritten bool for _, name := range clauses { if c, ok := stmt.Clauses[name]; ok { // 代码有删减 c.Build(stmt) } }}

接着调用func (c Clause) Build(builder Builder)

// Build build clause func (c Clause) Build(builder Builder) { // 有删减 // c为clause.Insert以及clause.Values if c.Name != "" { // builder写入 INSERT 或者 VALUES builder.WriteString(c.Name) builder.WriteByte(' ') } // 通过clause.Insert以及clause.Values的MergeClause函数,c.Expression为clause.Insert以及clause.Values // 因此,这里调用clause.Insert或者clause.Values的Build的方法 c.Expression.Build(builder)}

接下来分别看clause.Insert以及clause.Values

// Build build insert clausefunc (insert Insert) Build(builder Builder) { // builder写入INTO,此时builder为INSERT INTO builder.WriteString("INTO ") // builder写入表名 builder.WriteQuoted(currentTable)}

从调用的链路可以得出,这里builderstmt *Statement,并且currentTable类型为clause.Table,因此

// WriteQuoted write quoted valuefunc (stmt *Statement) WriteQuoted(value interface{}) { stmt.QuoteTo(&stmt.SQL, value)} // QuoteTo write quoted value to writer 有删减func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { write := func(raw bool, str string) { // mysql驱动Dialector stmt.DB.Dialector.QuoteTo(writer, str) } switch v := field.(type) { case clause.Table: write(v.Raw, stmt.Table) }}

至此,builder已经拼装出INSERT INTO `t_student` ,解析来再看clause.Valuesbuild方法

// Build build from clausefunc (values Values) Build(builder Builder) { if len(values.Columns) > 0 { builder.WriteByte('(') for idx, column := range values.Columns { if idx > 0 { builder.WriteByte(',') } builder.WriteQuoted(column) } builder.WriteByte(')') builder.WriteString(" VALUES ") for idx, value := range values.Values { if idx > 0 { builder.WriteByte(',') } builder.WriteByte('(') builder.AddVar(builder, value...) builder.WriteByte(')') } } else { builder.WriteString("DEFAULT VALUES") }}

func (values Values) Build(builder Builder)取出所有列名和列对应的值

最终builder拼装成例子的完整SQL语句INSERT INTO `t_student` (`age`,`height`,`weight`) VALUES (18,185,70)

有了SQL语句,就可以执行了

result, err := db.Statement.ConnPool.ExecContext( db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,)

通过前一面学习,db.Statement.ConnPool的值为sql.DB,实际执行的函数为func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error)

至此,从Model到DML到流程已经完成。

4. 将ID写入到model的 

看返回参数sql.Result,因此通过LastInsertId() (int64, error)可以获取到插入行的ID值。

// A Result summarizes an executed SQL command.type Result interface { // LastInsertId returns the integer generated by the database // in response to a command. Typically this will be from an // "auto increment" column when inserting a new row. Not all // databases support this feature, and the syntax of such // statements varies. LastInsertId() (int64, error) // RowsAffected returns the number of rows affected by an // update, insert, or delete. Not every database or database // driver may support this. RowsAffected() (int64, error)}

获取到刚插入的行ID值,再通过反射写入model的ID字段即可。

db.RowsAffected, _ = result.RowsAffected()if db.RowsAffected != 0 && db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { insertID, err := result.LastInsertId() switch db.Statement.ReflectValue.Kind() { case reflect.Struct: _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) if isZero { // 通过反射更新ID db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) } }}

使用反射解析Model,获得每个成员对应的表的列名、值等信息。

定义SQL各个关键词如INSERTVALUESFROMDELETE的结构体,并实现clause.Interface接口

进而对SQL语句的构造进行抽象封装。


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK