gorm 介绍:
基于 Go 语言的 ORM 框架。本文分析的是 v1 版本(github.com/jinzhu/gorm),相关文档:https://v1.gorm.io/docs/ (原地址 http://doc.gorm.io/)
demo:
package main
import (
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/sqlite"
)
// Product is the model used by this demo. gorm.Model is embedded to supply the
// common model fields (presumably ID, CreatedAt, UpdatedAt and DeletedAt —
// confirm against the gorm version in use); Code and Price are mapped to
// table columns by AutoMigrate.
type Product struct {
gorm.Model
Code string
Price uint
}
// main walks through the basic gorm CRUD flow against a local SQLite file:
// migrate the schema, create a Product, read it back, update it and delete it.
// Any error reported by gorm aborts the demo with a descriptive panic.
func main() {
	db, err := gorm.Open("sqlite3", "test.db")
	if err != nil {
		panic("failed to connect database: " + err.Error())
	}
	defer db.Close()

	// must aborts the demo when a gorm call reports an error through the
	// returned *DB's Error field, instead of silently carrying on.
	must := func(step string, e error) {
		if e != nil {
			panic(step + ": " + e.Error())
		}
	}

	// Migrate the schema
	must("migrate schema", db.AutoMigrate(&Product{}).Error)

	// Create
	created := Product{Code: "L1212", Price: 1000}
	must("create product", db.Create(&created).Error)

	// Read
	var product Product
	// Look the row up by the ID Create just assigned rather than a hard-coded 1:
	// test.db persists between runs, so id 1 is not necessarily this row.
	must("find product by id", db.First(&product, created.ID).Error)
	must("find product by code", db.First(&product, "code = ?", "L1212").Error) // find product with code L1212

	// Update - update product's price to 2000
	must("update price", db.Model(&product).Update("Price", 2000).Error)

	// Delete - delete product
	must("delete product", db.Delete(&product).Error)
}
源码分析:
现在我们从main函数开始分析:
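gorm.Open 本身并不注册 SQL 驱动:驱动是通过 `_ "github.com/jinzhu/gorm/dialects/sqlite"` 这类空白导入,在其依赖的驱动包的 init 中调用 sql.Register 完成注册的。Open 的作用是用 sql.Open 打开数据库,并把原生的 sql 连接封装为 gorm 的 DB 对象。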
核心结构体:
DB:保存数据库连接等相关信息。每次操作都会克隆出一个新的 DB,并把 Value 绑定为本次操作的对象,例如 Product{}:
// DB holds a database handle together with per-operation state. Each
// operation works on a clone (see NewScope) whose Value is the object being
// operated on, e.g. &Product{}.
type DB struct {
Value interface{} // object the current operation works on; NewScope sets it on the clone
Error error // NOTE(review): presumably the error of the most recent operation — confirm where it is set
RowsAffected int64 // rows affected by the last statement (createCallback sets it after the INSERT)
// single db
db sqlCommon // underlying database/sql handle; Scope.Begin swaps in the transaction and CommitOrRollback restores parent.db
blockGlobalUpdate bool
logMode int
logger logger // Open initializes it to defaultLogger
search *search // accumulated query conditions (where, limit, group, ...); NewScope clones it into Scope.Search
values map[string]interface{} // Open initializes it to an empty map
// global db
parent *DB // root DB returned by Open (Open sets parent to itself); Create reads the callback chains from it
callbacks *Callback // callback chains (creates, updates, ...) run for each operation; Open uses DefaultCallback
dialect Dialect // database-specific behavior, built by newDialect from the dialect name and handle
singularTable bool
}
// Scope holds the state of a single SQL operation: the target value, its
// search conditions, and the SQL text and bind variables being built for it.
type Scope struct {
Search *search // search conditions; NewScope gives each scope a clone of the DB's
Value interface{} // object being operated on
sql string // SQL statement being built (presumably set via Raw) and executed by callbacks such as createCallback
sqlVars []interface{} // bind variables for sql; presumably appended by AddToVars — confirm
db *DB // the cloned gorm DB this scope runs on (not a *sql.DB); its db field is the raw handle
instanceID string
primaryKeyField *Field
skipLeft bool
fields *[]*Field // presumably caches the result of Fields() — confirm
selectAttrs *[]string
}
// Callback holds the ordered callback chains for each kind of operation.
// For example, DB.Create runs every function in creates. The per-kind slices
// are rebuilt from processors by reorder whenever a callback is registered.
type Callback struct {
creates []*func(scope *Scope)
updates []*func(scope *Scope)
deletes []*func(scope *Scope)
queries []*func(scope *Scope)
rowQueries []*func(scope *Scope)
processors []*CallbackProcessor // every registered processor, across all kinds
}
打开数据库(Open):
// Open opens a database connection for the given dialect and wraps it in a
// *DB whose parent is itself.
//
// args[0] is either a data-source string or an existing sqlCommon handle.
// With two string arguments, args[0] is the driver name and args[1] the data
// source; with one, the dialect name doubles as the driver name.
//
// sql.Open may only validate its arguments, so Open pings the handle to make
// sure the database is reachable. If Open created the *sql.DB itself and the
// ping fails, the handle is closed again so a failed Open does not leak it.
//
// NOTE(review): an empty args panics on args[0], and an args[0] of any other
// type leaves the handle nil; upstream gorm returns an error in both cases.
func Open(dialect string, args ...interface{}) (db *DB, err error) {
	var source string
	var dbsql sqlCommon
	ownDB := false // true when this call opened dbsql and is responsible for closing it

	switch value := args[0].(type) {
	case string:
		driver := dialect
		if len(args) == 1 {
			source = value
		} else if len(args) >= 2 {
			driver = value
			source = args[1].(string)
		}
		dbsql, err = sql.Open(driver, source)
		ownDB = true
	case sqlCommon:
		dbsql = value
	}

	db = &DB{
		db:        dbsql,
		logger:    defaultLogger,
		values:    map[string]interface{}{},
		callbacks: DefaultCallback,
		dialect:   newDialect(dialect, dbsql),
	}
	db.parent = db
	if err != nil {
		return db, err
	}

	// Check that the connection is actually usable.
	if d, ok := dbsql.(*sql.DB); ok {
		if err = d.Ping(); err != nil && ownDB {
			_ = d.Close() // best effort; the ping error is what the caller needs
		}
	}
	return db, err
}
接下来看 Create 函数。demo 中的 Create 调用会向 Product 对应的表中插入一条记录:
// Create inserts value (e.g. &Product{...}). It builds a fresh scope for value,
// runs the root DB's "create" callback chain on it, and returns the DB of the
// scope that callCallbacks yields.
func (s *DB) Create(value interface{}) *DB {
	scope := s.NewScope(value)
	chain := s.parent.callbacks.creates
	finished := scope.callCallbacks(chain)
	return finished.db
}
// NewScope creates a scope for the current operation. It clones s, binds
// value to the clone, and gives the scope a clone of the search conditions.
func (s *DB) NewScope(value interface{}) *Scope {
	cloned := s.clone()
	cloned.Value = value
	return &Scope{
		Value:  value,
		Search: cloned.search.clone(),
		db:     cloned,
	}
}
那么 creates 数组里到底有哪些函数?Open 中只是把 callbacks 设为 DefaultCallback,并没有直接填充它。另一种初始化途径是包级的 init 函数:翻看 gorm 包下的所有 .go 文件,果然在 callback_create.go 中定义了 init,通过 Register 把各个回调依次注册进去:
package gorm
// Define callbacks for creating
func init() {
DefaultCallback.Create().Register("gorm:begin_transaction",beginTransactionCallback)
DefaultCallback.Create().Register("gorm:before_create",beforeCreateCallback)
DefaultCallback.Create().Register("gorm:save_before_associations",saveBeforeAssociationsCallback)
DefaultCallback.Create().Register("gorm:update_time_stamp",updateTimeStampForCreateCallback)
DefaultCallback.Create().Register("gorm:create",createCallback)
DefaultCallback.Create().Register("gorm:force_reload_after_create",forceReloadAfterCreateCallback)
DefaultCallback.Create().Register("gorm:save_after_associations",saveAfterAssociationsCallback)
DefaultCallback.Create().Register("gorm:after_create",afterCreateCallback)
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction",commitOrRollbackTransactionCallback)
}
我们看看Callback.Create().Register()里面的逻辑:
// CallbackProcessor describes one callback being registered: its name, its
// kind (create, update, delete, query, row_query), optional ordering
// constraints and the handler. Callback.Create() returns one bound to its
// parent; Register fills in name and processor and appends it to the parent.
type CallbackProcessor struct {
name string // current callback's name
before string // register current callback before a callback
after string // register current callback after a callback
replace bool // replace callbacks with same name
remove bool // delete callbacks with same name
kind string // callback type: create,update,delete,query,row_query
processor *func(scope *Scope) // callback handler
parent *Callback
}
// Create returns a CallbackProcessor of kind "create" bound to c; call
// Register on it to add a callback to the create chain.
func (c *Callback) Create() *CallbackProcessor {
	cp := new(CallbackProcessor)
	cp.kind, cp.parent = "create", c
	return cp
}
// Register adds callback under callbackName to the parent Callback's chain
// for cp.kind, then rebuilds and re-sorts every chain via reorder.
//
// For compatibility, a "row_query" callback registered without Before() or
// After() is ordered before "gorm:row_query" by default, and a log line says
// so.
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
	if cp.kind == "row_query" && cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" {
		log.Printf("Registering RowQuery callback %v without specifying order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...\n", callbackName)
		cp.before = "gorm:row_query"
	}

	cp.name = callbackName
	cp.processor = &callback
	cp.parent.processors = append(cp.parent.processors, cp)
	// Rebuild the per-kind slices (creates, updates, ...) from all processors.
	cp.parent.reorder()
}
// reorder groups every registered processor by kind and rebuilds the sorted
// per-kind callback slices (creates, updates, deletes, queries, rowQueries).
// Processors without a name, or whose kind is not one of those five, are
// ignored.
func (c *Callback) reorder() {
	byKind := make(map[string][]*CallbackProcessor)
	for _, p := range c.processors {
		if p.name == "" {
			continue
		}
		byKind[p.kind] = append(byKind[p.kind], p)
	}

	c.creates = sortProcessors(byKind["create"])
	c.updates = sortProcessors(byKind["update"])
	c.deletes = sortProcessors(byKind["delete"])
	c.queries = sortProcessors(byKind["query"])
	c.rowQueries = sortProcessors(byKind["row_query"])
}
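creates 的初始化过程已经大体了解了。再回头看 creates 数组本身,它包含上面 init 中注册的全部回调,例如第一个:
DefaultCallback.Create().Register("gorm:begin_transaction",beginTransactionCallback)
按 init 中的注册顺序分别是:开始事务、create 前钩子(before_create)、保存前置关联、更新时间戳、执行 INSERT、创建后强制重新加载、保存后置关联、create 后钩子(after_create)、提交或回滚事务。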
事务代码:
// beginTransactionCallback is the "gorm:begin_transaction" step of the create
// chain: it starts a transaction on the scope via Scope.Begin.
func beginTransactionCallback(scope *Scope) { scope.Begin() }

// commitOrRollbackTransactionCallback is the final step of the create chain:
// it finishes the scope's transaction via Scope.CommitOrRollback.
func commitOrRollbackTransactionCallback(scope *Scope) { scope.CommitOrRollback() }
// Begin starts a transaction on the scope if its underlying handle can begin
// one (i.e. implements sqlDb).
//
// On success the scope's handle (scope.db.db) is replaced by the transaction,
// and "gorm:started_transaction" is set so CommitOrRollback will finish it.
// If starting the transaction fails, the error is recorded on the scope via
// Err instead of being discarded. Later callbacks that check HasError (such as
// createCallback) therefore stop, rather than silently running their
// statements outside a transaction.
func (scope *Scope) Begin() *Scope {
	db, ok := scope.sqlDB().(sqlDb)
	if !ok {
		// NOTE(review): presumably the handle is already a transaction or cannot
		// start one; nothing to do.
		return scope
	}
	tx, err := db.Begin()
	if scope.Err(err) != nil {
		return scope
	}
	scope.db.db = interface{}(tx).(sqlCommon)
	scope.InstanceSet("gorm:started_transaction", true)
	return scope
}
// CommitOrRollback finishes a transaction previously started by Begin.
//
// It commits when the scope has no error, recording any commit error on the
// scope, and rolls back otherwise (a rollback error is not recorded).
// Afterwards the scope's handle is reset to the parent DB's handle. Scopes
// with no started transaction, or whose handle is not a transaction, are
// returned untouched.
func (scope *Scope) CommitOrRollback() *Scope {
	if _, started := scope.InstanceGet("gorm:started_transaction"); !started {
		return scope
	}
	tx, isTx := scope.db.db.(sqlTx)
	if !isTx {
		return scope
	}

	if !scope.HasError() {
		scope.Err(tx.Commit())
	} else {
		tx.Rollback()
	}
	scope.db.db = scope.db.parent.db
	return scope
}
创建代码:
func createCallback(scope *Scope) {
if !scope.HasError() {
defer scope.trace(NowFunc())
var (
columns,placeholders []string
blankColumnsWithDefaultValue []string
)
for _,field := range scope.Fields() {
if scope.changeableField(field) {
if field.IsNormal {
if field.IsBlank && field.HasDefaultValue {
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue,scope.Quote(field.DBName))
scope.InstanceSet("gorm:blank_columns_with_default_value",blankColumnsWithDefaultValue)
} else if !field.IsPrimaryKey || !field.IsBlank {
columns = append(columns,scope.Quote(field.DBName))
placeholders = append(placeholders,scope.AddToVars(field.Field.Interface()))
}
} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
for _,foreignKey := range field.Relationship.ForeignDBNames {
if foreignField,ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
columns = append(columns,scope.Quote(foreignField.DBName))
placeholders = append(placeholders,scope.AddToVars(foreignField.Field.Interface()))
}
}
}
}
}
var (
returningColumn = "*"
quotedTableName = scope.QuotedTableName()
primaryField = scope.PrimaryField()
extraOption string
)
if str,ok := scope.Get("gorm:insert_option"); ok {
extraOption = fmt.Sprint(str)
}
if primaryField != nil {
returningColumn = scope.Quote(primaryField.DBName)
}
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName,returningColumn)
// 拼接执行sql
if len(columns) == 0 {
scope.Raw(fmt.Sprintf(
"INSERT INTO %v %v%v%v",quotedTableName,scope.Dialect().DefaultValueStr(),addExtraSpaceIfExist(extraOption),addExtraSpaceIfExist(lastInsertIDReturningSuffix),))
} else {
scope.Raw(fmt.Sprintf(
"INSERT INTO %v (%v) VALUES (%v)%v%v",scope.QuotedTableName(),strings.Join(columns,","),strings.Join(placeholders,))
}
// execute create sql
if lastInsertIDReturningSuffix == "" || primaryField == nil {
if result,err := scope.sqlDB().Exec(scope.sql,scope.sqlVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected,_ = result.RowsAffected()
// set primary value to primary field
if primaryField != nil && primaryField.IsBlank {
if primaryValue,err := result.LastInsertId(); scope.Err(err) == nil {
scope.Err(primaryField.Set(primaryValue))
}
}
}
} else {
if primaryField.Field.CanAddr() {
if err := scope.sqlDB().QueryRow(scope.sql,scope.sqlVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
scope.db.RowsAffected = 1
}
} else {
scope.Err(ErrUnaddressable)
}
}
}
}
以上以 Create 为例,分析了简单 CRUD 操作的调用流程。