gorm 简单调用源码分析

前端之家收集整理的这篇文章主要介绍了 gorm 简单调用的源码分析。前端之家小编觉得挺不错的,现在分享给大家,也给大家做个参考。

gorm 介绍:

基于golang的orm框架,相关文档:http://doc.gorm.io/

demo:

package main

import (
    "github.com/jinzhu/gorm"
    _ "github.com/jinzhu/gorm/dialects/sqlite"
)

// Product is the model used by the demo below. Embedding gorm.Model
// presumably supplies the ID primary key and timestamp columns — confirm
// against the gorm version in use.
type Product struct {
  gorm.Model
  Code string // product code, e.g. "L1212" in main
  Price uint
}

// main runs a minimal create/read/update/delete round trip for Product
// against the SQLite file test.db. It panics if the connection cannot be
// opened; the results of the individual gorm calls are not inspected.
func main() {
	conn, err := gorm.Open("sqlite3", "test.db")
	if err != nil {
		panic("Failed to connect database")
	}
	defer conn.Close()

	// Migrate the schema for the Product model.
	conn.AutoMigrate(&Product{})

	// Create: insert a new product.
	newProduct := &Product{Code: "L1212", Price: 1000}
	conn.Create(newProduct)

	// Read: load into product, first by primary key 1, then by code.
	var product Product
	conn.First(&product, 1)
	conn.First(&product, "code = ?", "L1212")

	// Update: set the loaded product's price to 2000.
	conn.Model(&product).Update("Price", 2000)

	// Delete: remove the loaded product.
	conn.Delete(&product)
}

源码分析:

现在我们从main函数开始分析:
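Open 根据传入的 dialect(或显式给出的驱动名)调用 sql.Open 打开数据库,并将原生的 sql 操作对象封装为 gorm 的 DB 对象。sql driver 本身的注册并不在 Open 中完成,而是由 demo 中 `_ "github.com/jinzhu/gorm/dialects/sqlite"` 这样的空白导入在包初始化时完成的。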

核心结构体:

DB 结构体:保存数据库相关信息,每次操作都会绑定不同的 Value(即操作对象,例如 Product{})。
// DB wraps a database connection together with the state of the operation it
// is currently bound to. NewScope clones a DB per operation and sets Value on
// the clone.
type DB struct {
    Value        interface{} // object the current operation is bound to, e.g. a *Product
    Error        error
    RowsAffected int64 // rows affected by the last statement (set e.g. by createCallback)

    // single db
    db                sqlCommon // underlying connection from Open (sql.Open result or caller-supplied); Scope.Begin swaps in a transaction
    blockGlobalUpdate bool
    logMode           int
    logger            logger
    search            *search // pending search conditions (where, limit, group, ...); NewScope hands the scope a clone of it
    values            map[string]interface{}

    // global db
    parent        *DB       // root DB; Open sets db.parent = db
    callbacks     *Callback // callback chains run by this DB (DB.Create reads parent.callbacks.creates)
    dialect       Dialect // database-specific adapter, built by newDialect from the dialect name and connection
    singularTable bool
}

// Scope holds the state of a single SQL operation while its callback chain runs.
type Scope struct {
    Search          *search // query conditions, cloned from the DB's search in NewScope
    Value           interface{} // object being operated on (same value as db.Value)
    sql             string // SQL text to execute (passed to Exec/QueryRow in createCallback)
    sqlVars         []interface{} // bind arguments passed alongside sql
    db              *DB // cloned gorm DB for this operation (a gorm *DB, not a raw *sql.DB)
    instanceID      string // presumably scopes InstanceSet/InstanceGet values to this scope — confirm
    primaryKeyField *Field
    skipLeft        bool
    fields          *[]*Field // fields of Value (presumably cached by Fields() — confirm)
    selectAttrs     *[]string
}

// Callback holds the callback chain for each kind of operation. For example,
// DB.Create runs every function in creates, in order.
type Callback struct {
    creates    []*func(scope *Scope)
    updates    []*func(scope *Scope)
    deletes    []*func(scope *Scope)
    queries    []*func(scope *Scope)
    rowQueries []*func(scope *Scope)
    processors []*CallbackProcessor // every registered processor; reorder rebuilds the chains above from it
}

注册db:

// Open opens a database and wraps it in a *DB that uses DefaultCallback and
// defaultLogger.
//
// args[0] is either a data source string or an existing sqlCommon value.
// With a single string argument, the dialect name doubles as the database/sql
// driver name. With two string arguments, args[0] is the driver name and
// args[1] the data source.
//
// An error is returned when args is empty, a source argument has an
// unsupported type, sql.Open fails, or the resulting *sql.DB cannot be pinged.
// If Open created the connection itself and the ping fails, that connection
// is closed before returning.
func Open(dialect string, args ...interface{}) (db *DB, err error) {
	if len(args) == 0 {
		return nil, fmt.Errorf("invalid database source: no arguments given")
	}

	var (
		dbsql sqlCommon
		owned bool // true when Open created dbsql and must close it on failure
	)

	switch value := args[0].(type) {
	case string:
		driver, source := dialect, value
		if len(args) >= 2 {
			src, ok := args[1].(string)
			if !ok {
				return nil, fmt.Errorf("invalid database source: %v is not a string", args[1])
			}
			driver, source = value, src
		}
		sqlDB, openErr := sql.Open(driver, source)
		if openErr != nil {
			return nil, openErr
		}
		dbsql, owned = sqlDB, true
	case sqlCommon:
		dbsql = value
	default:
		return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
	}

	db = &DB{
		db:        dbsql,
		logger:    defaultLogger,
		values:    map[string]interface{}{},
		callbacks: DefaultCallback,
		dialect:   newDialect(dialect, dbsql),
	}
	db.parent = db

	// sql.Open only validates its arguments; ping to make sure the database is
	// actually reachable before handing the handle out.
	if sqlDB, ok := dbsql.(*sql.DB); ok {
		if err = sqlDB.Ping(); err != nil {
			if owned {
				_ = sqlDB.Close() // the Ping error is the one worth reporting
			}
			return nil, err
		}
	}
	return db, nil
}

接下来看下 Create 函数:demo 中的 Create 调用会往数据库中插入一条 Product 记录。

// Create inserts value by running the root DB's create callback chain on a
// fresh Scope bound to value. It returns that scope's cloned DB, so callers can
// inspect the result (e.g. RowsAffected) without touching s.
func (s *DB) Create(value interface{}) *DB {
	sc := s.NewScope(value)
	chain := s.parent.callbacks.creates
	sc = sc.callCallbacks(chain)
	return sc.db
}

// NewScope creates a Scope for the current operation. The scope gets its own
// clone of s, with Value set to value, plus a copy of the clone's search
// conditions, so the operation never mutates s directly.
func (s *DB) NewScope(value interface{}) *Scope {
	clone := s.clone()
	clone.Value = value

	sc := &Scope{db: clone, Value: value}
	sc.Search = clone.search.clone()
	return sc
}

那么 creates 数组中到底有哪些函数呢?回看 Open 函数,它只是把 DefaultCallback 赋给了 callbacks,并没有往里面注册任何函数。那么还有一种初始化方式,就是包内的 init 函数,我们可以翻看 gorm 包下的所有 .go 文件。果然,在 callback_create.go 文件中定义了 init 函数,里面把一系列回调函数注册(append)到了 create 调用链中:

package gorm

// Define callbacks for creating
func init() {
    DefaultCallback.Create().Register("gorm:begin_transaction",beginTransactionCallback)
    DefaultCallback.Create().Register("gorm:before_create",beforeCreateCallback)
    DefaultCallback.Create().Register("gorm:save_before_associations",saveBeforeAssociationsCallback)
    DefaultCallback.Create().Register("gorm:update_time_stamp",updateTimeStampForCreateCallback)
    DefaultCallback.Create().Register("gorm:create",createCallback)
    DefaultCallback.Create().Register("gorm:force_reload_after_create",forceReloadAfterCreateCallback)
    DefaultCallback.Create().Register("gorm:save_after_associations",saveAfterAssociationsCallback)
    DefaultCallback.Create().Register("gorm:after_create",afterCreateCallback)
    DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction",commitOrRollbackTransactionCallback)
}

我们看看Callback.Create().Register()里面的逻辑:

// CallbackProcessor describes one callback registration: the chain it belongs
// to (kind), its name, optional ordering hints, and the handler itself.
type CallbackProcessor struct {
    name      string              // current callback's name
    before    string              // register current callback before a callback
    after     string              // register current callback after a callback
    replace   bool                // replace callbacks with same name
    remove    bool                // delete callbacks with same name
    kind      string              // callback type: create, update, delete, query, row_query
    processor *func(scope *Scope) // callback handler
    parent    *Callback           // Callback whose chains this processor is registered into
}

// Create returns a processor that registers callbacks into c's create chain.
func (c *Callback) Create() *CallbackProcessor {
	cp := &CallbackProcessor{parent: c}
	cp.kind = "create"
	return cp
}

// Register adds callback under callbackName to the parent Callback and then
// rebuilds the parent's per-kind chains; refer `Callbacks.Create` for usage.
// A row_query callback registered without Before/After ordering (other than
// gorm's own "gorm:row_query") is placed before "gorm:row_query", and a notice
// is logged.
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
	needsDefaultOrder := cp.kind == "row_query" &&
		cp.before == "" && cp.after == "" &&
		callbackName != "gorm:row_query"
	if needsDefaultOrder {
		log.Printf("Registing RowQuery callback %v without specify order with Before(),After(),applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
		cp.before = "gorm:row_query"
	}

	cp.name = callbackName
	cp.processor = &callback

	// Record the processor, then regroup every chain (not just this kind).
	owner := cp.parent
	owner.processors = append(owner.processors, cp)
	owner.reorder()
}

// reorder groups every named processor by kind and rebuilds the CRUD chains
// (creates, updates, deletes, queries, rowQueries) in the order produced by
// sortProcessors. Unnamed processors, and processors of an unknown kind, end
// up in no chain.
func (c *Callback) reorder() {
	grouped := make(map[string][]*CallbackProcessor)
	for _, p := range c.processors {
		if p.name == "" {
			continue
		}
		grouped[p.kind] = append(grouped[p.kind], p)
	}

	c.creates = sortProcessors(grouped["create"])
	c.updates = sortProcessors(grouped["update"])
	c.deletes = sortProcessors(grouped["delete"])
	c.queries = sortProcessors(grouped["query"])
	c.rowQueries = sortProcessors(grouped["row_query"])
}

creates 的初始化过程已经大体了解了,再回头看 creates 数组本身:它包含了 init 中注册的那些回调函数,例如第一个回调就是开启事务:

// The first step of the create chain opens a transaction.
DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)

整个 create 调用链依次是:开始事务、执行 create 前的回调(before_create)、保存前置关联、更新时间戳、执行 create、创建后强制重新加载、保存后置关联、执行 create 后的回调(after_create),最后提交或回滚事务。
事务代码

// beginTransactionCallback is the "gorm:begin_transaction" step of the create
// chain: it asks the scope to open a transaction.
func beginTransactionCallback(scope *Scope) {
	_ = scope.Begin() // Begin returns scope itself; nothing further to do with it
}

// commitOrRollbackTransactionCallback is the final step of the create chain:
// it finishes the transaction opened by beginTransactionCallback.
func commitOrRollbackTransactionCallback(scope *Scope) {
	_ = scope.CommitOrRollback() // returns scope itself; nothing further to do with it
}

// Begin starts a transaction when the scope's connection can begin one (it
// satisfies sqlDb).
//
// On success, the scope's DB is switched to the transaction and
// "gorm:started_transaction" is recorded, so CommitOrRollback knows to finish it.
// If beginning fails, the error is recorded on the scope via scope.Err rather
// than silently dropped. The later steps check scope.HasError, so they will not
// quietly run outside the intended transaction.
func (scope *Scope) Begin() *Scope {
	if db, ok := scope.sqlDB().(sqlDb); ok {
		if tx, err := db.Begin(); scope.Err(err) == nil {
			scope.db.db = interface{}(tx).(sqlCommon)
			scope.InstanceSet("gorm:started_transaction", true)
		}
	}
	return scope
}

// CommitOrRollback finishes a transaction started by Begin. It commits when the
// scope has no error and rolls back otherwise; a commit error is recorded on
// the scope. Afterwards, the scope's DB is pointed back at the parent's
// connection. It does nothing if no transaction was started or the current
// connection is not a transaction.
func (scope *Scope) CommitOrRollback() *Scope {
	if _, started := scope.InstanceGet("gorm:started_transaction"); !started {
		return scope
	}
	tx, isTx := scope.db.db.(sqlTx)
	if !isTx {
		return scope
	}

	if scope.HasError() {
		// Rollback error is ignored: the scope already carries the original error.
		tx.Rollback()
	} else {
		scope.Err(tx.Commit())
	}
	scope.db.db = scope.db.parent.db
	return scope
}

创建代码

func createCallback(scope *Scope) {
    if !scope.HasError() {
        defer scope.trace(NowFunc())

        var (
            columns,placeholders        []string
            blankColumnsWithDefaultValue []string
        )

        for _,field := range scope.Fields() {
            if scope.changeableField(field) {
                if field.IsNormal {
                    if field.IsBlank && field.HasDefaultValue {
                        blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue,scope.Quote(field.DBName))
                        scope.InstanceSet("gorm:blank_columns_with_default_value",blankColumnsWithDefaultValue)
                    } else if !field.IsPrimaryKey || !field.IsBlank {
                        columns = append(columns,scope.Quote(field.DBName))
                        placeholders = append(placeholders,scope.AddToVars(field.Field.Interface()))
                    }
                } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
                    for _,foreignKey := range field.Relationship.ForeignDBNames {
                        if foreignField,ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
                            columns = append(columns,scope.Quote(foreignField.DBName))
                            placeholders = append(placeholders,scope.AddToVars(foreignField.Field.Interface()))
                        }
                    }
                }
            }
        }

        var (
            returningColumn = "*"
            quotedTableName = scope.QuotedTableName()
            primaryField    = scope.PrimaryField()
            extraOption     string
        )

        if str,ok := scope.Get("gorm:insert_option"); ok {
            extraOption = fmt.Sprint(str)
        }

        if primaryField != nil {
            returningColumn = scope.Quote(primaryField.DBName)
        }

        lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName,returningColumn)

        // 拼接执行sql
        if len(columns) == 0 {
            scope.Raw(fmt.Sprintf(
                "INSERT INTO %v %v%v%v",quotedTableName,scope.Dialect().DefaultValueStr(),addExtraSpaceIfExist(extraOption),addExtraSpaceIfExist(lastInsertIDReturningSuffix),))
        } else {
            scope.Raw(fmt.Sprintf(
                "INSERT INTO %v (%v) VALUES (%v)%v%v",scope.QuotedTableName(),strings.Join(columns,","),strings.Join(placeholders,))
        }

        // execute create sql
        if lastInsertIDReturningSuffix == "" || primaryField == nil {
            if result,err := scope.sqlDB().Exec(scope.sql,scope.sqlVars...); scope.Err(err) == nil {
                // set rows affected count
                scope.db.RowsAffected,_ = result.RowsAffected()

                // set primary value to primary field
                if primaryField != nil && primaryField.IsBlank {
                    if primaryValue,err := result.LastInsertId(); scope.Err(err) == nil {
                        scope.Err(primaryField.Set(primaryValue))
                    }
                }
            }
        } else {
            if primaryField.Field.CanAddr() {
                if err := scope.sqlDB().QueryRow(scope.sql,scope.sqlVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
                    primaryField.IsBlank = false
                    scope.db.RowsAffected = 1
                }
            } else {
                scope.Err(ErrUnaddressable)
            }
        }
    }
}

以上以 Create 为例,分析了 gorm 简单 CRUD 操作的调用流程。

猜你在找的Go相关文章