49

利用golang反射实现一个迷你orm

 6 years ago
source link: https://studygolang.com/articles/14803?amp%3Butm_medium=referral
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
neoserver,ios ssh client

github地址

作为静态语言,golang稍显笨拙,还好go的标准包 reflect (反射)包弥补了这点不足,它提供了一系列强大的API,能够根据执行过程中对象的类型来改变程序控制流。本文将通过设计并实现一个简易的mysql orm来学习它,要求读者了解 mysql 基本知识,并且跟我一样至少已经接触golang两到三个月。

orm这个概念相信同学们都非常熟悉,尤其是写过 rails 的同学,对 active_record 的强大肯定深有体会(得益于的 method_missingdefine_method 方法,少写了海量代码),所以对orm我就不过多介绍了。本文要实现的orm只提供基本的 CRUD (增删改查)和 transaction (事务)功能,核心代码控制在300行左右。

如果想手把手照着写,需要先做一些准备工作。

准备工作

在本地mysql里 create database orm_db ,然后再 create 一张 user 表,结构如下:

CREATE TABLE `user` (
  `id` int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT '自增主键',
  `age` smallint(10) unsigned NOT NULL DEFAULT 0 COMMENT '年龄',
  `first_name` varchar(45) NOT NULL DEFAULT '' COMMENT '姓',
  `last_name` varchar(45) NOT NULL DEFAULT '' COMMENT '名',
  `email` varchar(45) NOT NULL DEFAULT '' COMMENT '邮箱地址',
  `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
  `updated_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
  PRIMARY KEY (`id`),
  KEY `idx_email` (`email`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='用户表';

同时,golang代码里定义一个与之对应的 struct

type User struct {
    ID        int64     `json:"id"`         // 自增主键
    Age       int64     `json:"age"`        // 年龄
    FirstName string    `json:"first_name"` // 姓
    LastName  string    `json:"last_name"`  // 名
    Email     string    `json:"email"`      // 邮箱地址
    CreatedAt time.Time `json:"created_at"` // 创建时间
    UpdatedAt time.Time `json:"updated_at"` // 更新时间
}

与mysql交互需要用到一个go标准包和一个 驱动 ,代码 import 如下:

package orm

import (
    "database/sql"
    
    //register driver
    _ "github.com/go-sql-driver/mysql"
)

首先按照 database 维度建立连接,写一个可以返回mysql连接的函数:

//Connect db by dsn e.g. "user:password@tcp(127.0.0.1:3306)/dbname"
func Connect(dsn string) (*sql.DB, error) {
    conn, err := sql.Open("mysql", dsn)
    if err != nil {
        return nil, err
    }
    //设置连接池
    conn.SetMaxOpenConns(100)
    conn.SetMaxIdleConns(10)
    conn.SetConnMaxLifetime(10 * time.Minute)
    return conn, conn.Ping()
}

设计一个 struct 用于实现orm(go不是面向对象的语言,没有 class ):

//Query will build a sql
type Query struct {
    db      *sql.DB
    table   string
}

最后将通过 Query 拼接出sql语句与mysql交互,所以写一个绑定函数:

//Table bind db and table
func Table(db *sql.DB, tableName string) func() *Query {
    return func() *Query {
        return &Query{
            db:    db,
            table: tableName,
        }
    }
}

返回值是一个闭包函数,这样使用时直接调用这个闭包函数就可以获取一个绑定好的database和table的 Query ,比如现在有数据库 orm_dbuser 表:

//全局变量ormDB和users
ormDB, _ := Connect("user:password@tcp(127.0.0.1:3306)/orm_db")
users := Table(ormDB, "user")
//调用
users().Insert(...)

准备工作到此完成,下面进入正题。

Insert方法

首先分析一下标准 insert 语句:

insert into user (first_name, last_name) values ('Tom', 'Cat'), ('Tom', 'Cruise')

把sql语句中变化的部分抽象出来,其实就是 key (字段)和 value (值),那么orm里的 Insert 方法原型就有了,如下,参数是struct或者map,因为它们都能提供键值对:

//Insert in can be *User, []*User, map[string]interface{}
func (q *Query) Insert(in interface{}) (int64, error) {
    var keys, values []string
    v := reflect.ValueOf(in)
    //剥离指针
    for v.Kind() == reflect.Ptr {
        v = v.Elem()
    }
    switch v.Kind() {
    case reflect.Struct:
        keys, values = sKV(v)
    case reflect.Map:
        keys, values = mKV(v)
    case reflect.Slice:
        for i := 0; i < v.Len(); i++ {
            //Kind是切片时,可以用Index()方法遍历
            sv := v.Index(i)
            for sv.Kind() == reflect.Ptr || sv.Kind() == reflect.Interface {
                sv = sv.Elem()
            }
            //切片元素不是struct或者指针,报错
            if sv.Kind() != reflect.Struct {
                return 0, errors.New("method Insert error: in slice is not structs")
            }
            //keys只保存一次就行,因为后面的都一样了
            if len(keys) == 0 {
                keys, values = sKV(sv)
                continue
            }
            _, val := sKV(sv)
            values = append(values, val...)
        }
    default:
        return 0, errors.New("method Insert error: type error")
    }
    //todo
    //...
}

参数 in 可以是一个 User (前文定义好的结构体)实例的指针(或者指针集合),也可以是一个map,这两个结构都可以提供键值对,我们通过反射来分析它的 类型 ,然后根据类型执行相应的逻辑。

reflect包里的有两个重要结构 TypeValue ,Type是一个接口,定义了所有类型相关的api,reflect里的 *rtype 实现了这个接口,通过reflect.TypeOf函数可以获取任何传入值的 *rtype 。Value是一个struct,通过reflect.ValueOf函数获取,它在 *rtype 的基础上又封装了传入值的unsafe.Pointer类型的 地址 以及这个值的 元数据

在Type和Value之上还有一个 Kind ,它代表传入值的 原始类型 ,比如:

type myInt int
var i myInt
t := reflect.TypeOf(i)
k := t.Kind()

t是myInt,而k是int,Type和Kind是不同的,这一点要注意区分。

如果Type的Kind是指针、接口、切片、map等复合类型,可以调用Elem()方法获取基类型。

如果Value的Kind是指针、接口,可以调用Elem()方法获取实际值。

Value上还定义了一个 Interface() 方法,它是ValueOf()方法的反操作。

有了上面这些反射方法,我们可以封装一个 sKV() 函数,它专门处理struct类型的值,获取key(取json tag)和value:

func sKV(v reflect.Value) ([]string, []string) {
    var keys, values []string
    t := v.Type()
    for n := 0; n < t.NumField(); n++ {
        tf := t.Field(n)
        vf := v.Field(n)
        //忽略非导出字段
        if tf.Anonymous {
            continue
        }
        //忽略无效、零值字段
        if !vf.IsValid() || reflect.DeepEqual(vf.Interface(), reflect.Zero(vf.Type()).Interface()) {
            continue
        }
        for vf.Type().Kind() == reflect.Ptr {
            vf = vf.Elem()
        }
        //有时候根据需求会组合struct,这里处理下,支持获取嵌套的struct tag和value
        //如果字段值是time类型之外的struct,递归获取keys和values
        if vf.Kind() == reflect.Struct && tf.Type.Name() != "Time" {
            cKeys, cValues := sKV(vf)
            keys = append(keys, cKeys...)
            values = append(values, cValues...)
            continue
        }
        //根据字段的json tag获取key,忽略无tag字段
        key := strings.Split(tf.Tag.Get("json"), ",")[0]
        if key == "" {
            continue
        }
        value := format(vf)
        if value != "" {
            keys = append(keys, key)
            values = append(values, value)
        }
    }
    return keys, values
}

sKV() 函数里需要格式化字符串,那么定义一个 format() 函数。

time.Time 类型怎么转化成各种数据库的时间类型我有点拿不准,所以需要对比时间类型的值时,一律用unxi时间戳,感觉比较省事不会出错:

func format(v reflect.Value) string {
    //断言出time类型直接转unix时间戳
    if t, ok := v.Interface().(time.Time); ok {
        return fmt.Sprintf("FROM_UNIXTIME(%d)", t.Unix())
    }
    switch v.Kind() {
    case reflect.String:
        return fmt.Sprintf(`'%s'`, v.Interface())
    case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
        return fmt.Sprintf(`%d`, v.Interface())
    case reflect.Float32, reflect.Float64:
        return fmt.Sprintf(`%f`, v.Interface())
    //如果是切片类型,遍历元素,递归格式化成"(, , , )"形式
    case reflect.Slice:
        var values []string
        for i := 0; i < v.Len(); i++ {
            values = append(values, format(v.Index(i)))
        }
        return fmt.Sprintf(`(%s)`, strings.Join(values, ","))
    //接口类型剥一层递归
    case reflect.Interface:
        return format(v.Elem())
    }
    return ""
}

map类型处理起来和struct不同,所以我们再定义一个 mKV() 函数,目的和sKV()一样,都是获取键值对:

func mKV(v reflect.Value) ([]string, []string) {
    var keys, values []string
    //获取map的key组成的切片
    mapKeys := v.MapKeys()
    for _, key := range mapKeys {
        value := format(v.MapIndex(key))
        if value != "" {
            values = append(values, value)
            keys = append(keys, key.Interface().(string))
        }
    }
    return keys, values
}

利用sKV()和mKV()函数取到键值对后,就得到了insert语句中的变化部分,补全Insert()方法的 todo 部分:

//Insert in can be User, *User, []User, []*User, map[string]interface{}
func (q *Query) Insert(in interface{}) (int64, error) {
    //already done
    kl := len(keys)
    vl := len(values)
    if kl == 0 || vl == 0 {
        return 0, errors.New("method Insert error: no data")
    }
    var insertValue string
    //插入多条记录时需要用","拼接一下values
    if kl < vl {
        var tmpValues []string
        for kl <= vl {
            if kl%(len(keys)) == 0 {
                tmpValues = append(tmpValues, fmt.Sprintf("(%s)", strings.Join(values[kl-len(keys):kl], ",")))
            }
            kl++
        }
        insertValue = strings.Join(tmpValues, ",")
    } else {
        insertValue = fmt.Sprintf("(%s)", strings.Join(values, ","))
    }
    query := fmt.Sprintf(`insert into %s (%s) values %s`, q.table, strings.Join(keys, ","), insertValue)
    log.Printf("insert sql: %s", query)
    st, err := q.DB.Prepare(query)
    if err != nil {
        return 0, err
    }
    result, err := st.Exec()
    if err != nil {
        return 0, err
    }
    return result.LastInsertId()
}

原理很简单,利用反射分析参数,取键值对,然后拼接sql语句,再通过mysql驱动入库。

调用示例:

user1 := &User{
    Age:       30,
    FirstName: "Tom",
    LastName:  "Cat",
}
user2 := User{
    Age:       30,
    FirstName: "Tom",
    LastName:  "Curise",
}
user3 := User{
    Age:       30,
    FirstName: "Tom",
    LastName:  "Hanks",
}
user4 := map[string]interface{}{
    "age":        30,
    "first_name": "Tom",
    "last_name":  "Zzy",
}
users().Insert([]interface{}{user1, user2})
users().Insert(user3)
users().Insert(user4)

增删改查的 部分到此完成,因为查询语句非常复杂多变,所以有了数据后,先进行

Select方法

先分析一下标准 select 语句

select id, age from user where first_name = 'Tom' and last_name = 'Cat'

可见sql语句的变量部分是 select 后面的字段和 where 后面的键值对,所以我们需要一个 Where() 来方法构造查询条件,并且需要一个 Select() 方法最后执行查询,最终形成一个链式调用效果:

var user []User
users().Where(?).WhereNot(?).Limit(100).Offset(100).Order("id desc").Only("id", "age").Select(&user)

所以需要改造Query如下,增加属性用于暂存链式调用中添加的值:

//Query will build a sql
type Query struct {
    db     *sql.DB
    table  string
    wheres []string
    only   []string
    limit  string
    offset string
    order  string
    errs   []string
}

为Query添加Where()方法,支持struct和map参数,同时支持传如同 "age > 10" 形式的字符串:

//Where args can be string, User, *User, map[string]interface{}
func (q *Query) Where(wheres ...interface{}) *Query {
    for _, w := range wheres {
        v := reflect.ValueOf(w)
        for v.Kind() == reflect.Ptr {
            v = v.Elem()
        }
        switch v.Kind() {
        case reflect.String:
            q.wheres = append(q.wheres, w.(string))
        case reflect.Struct:
            //todo
        case reflect.Map:
            //todo
        default:
            q.errs = append(q.errs, "method Where error: type error")
        }
    }
    return q
}

但是考虑到后面还会实现一个 WhereNot() 方法,所以把公共逻辑抽到一个 where() 函数里,并且直接复用之前的sKV()、mKv()函数获取键值对:

func where(eq bool, w interface{}) (string, error) {
    var keys, values []string
    v := reflect.ValueOf(w)
    for v.Kind() == reflect.Ptr {
        v = v.Elem()
    }
    switch v.Kind() {
    case reflect.String:
        return w.(string), nil
    case reflect.Struct:
        keys, values = sKV(v)
    case reflect.Map:
        keys, values = mKV(v)
    default:
        return "", errors.New("method Where error: type error")
    }
    if len(keys) != len(values) {
        return "", errors.New("method Where error: len(keys) not equal len(values))")
    }
    var wheres []string
    //之前的format()函数里,已经将切片类型值处理成"( , , ,)“形式
    for idx, key := range keys {
        if eq {
            if strings.HasPrefix(values[idx], "(") && strings.HasSuffix(values[idx], ")") {
                wheres = append(wheres, fmt.Sprintf("%s in %s", key, values[idx]))
                continue
            }
            wheres = append(wheres, fmt.Sprintf("%s = %s", key, values[idx]))
            continue
        }
        if strings.HasPrefix(values[idx], "(") && strings.HasSuffix(values[idx], ")") {
            wheres = append(wheres, fmt.Sprintf("%s not in %s", key, values[idx]))
            continue
        }
        wheres = append(wheres, fmt.Sprintf("%s != %s", key, values[idx]))
    }
    return strings.Join(wheres, " and "), nil
}

Where()方法最终变成:

//Where args can be string, User, *User, map[string]interface{}
func (q *Query) Where(wheres ...interface{}) *Query {
    for _, w := range wheres {
        str, err := where(true, w)
        q.wheres = append(q.wheres, str)
        if err != nil {
            //因为需要达到链式调用的效果,所以把错误都搜集起来,最后再处理
            q.errs = append(q.errs, err.Error())
        }
    }
    return q
}

WhereNot()把调用where()的第一个参数改成false就行了,不贴代码了。

Limit()Offset()Order()Only() 这几个方法也很简单:

//Limit .
func (q *Query) Limit(limit int) *Query {
    if limit <= 0 {
        q.errs = append(q.errs, "Limit error")
        return q
    }
    q.limit = fmt.Sprintf("limit %d", limit)
    return q
}

//Offset .
func (q *Query) Offset(offset int) *Query {
    if offset <= 0 {
        q.errs = append(q.errs, "Offset error")
        return q
    }
    q.offset = fmt.Sprintf("offset %d", offset)
    return q
}

//Order .
func (q *Query) Order(ord string) *Query {
    q.order = fmt.Sprintf("order by %s", ord)
    return q
}

//Only .
func (q *Query) Only(columns ...string) *Query {
    q.only = append(q.only, columns...)
    return q
}

有了上面这些条件之后,我们可以写一个 toSQL() 方法,把Query的属性组装成一条sql语句:

func (q *Query) toSQL() string {
    var where string
    if len(q.wheres) > 0 {
        where = fmt.Sprintf(`where %s`, strings.Join(q.wheres, " and "))
    }
    sqlStr := fmt.Sprintf(`select %s from %s %s %s %s %s`, strings.Join(q.only, ","), q.table, where, q.order, q.limit, q.offset)
    log.Printf("select sql: %s", sqlStr)
    return sqlStr
}

有了sql语句我们就可以查询数据了,但是想查一个表的全部字段时,为了方便,只需要传入对应的 struct ,比如 user 表对应的 User ,我们就直接分析这个struct,取它的tag作为查询字段,而不需要再调用Only()方法指定字段。

另外,因为golang中的参数传递全都是值传递,要修改传入值,必须传值的指针,这里要注意一点:

var user User
users.Select(&user)
var userPtr *User
users.Select(user)

这两种声明方式是不同的,后者只声明了一个指针类型,是错误的。

综上,我们首先为Select()方法做一下的参数检查,确保传入值是一个正确的指针,并确保only属性有值:

//Select dest must be a ptr, e.g. *user, *[]user, *[]*user, *map, *[]map, *int, *[]int
func (q *Query) Select(dest interface{}) error {
    if len(q.errs) != 0 {
        return errors.New(strings.Join(q.errs, "
"))
    }
    t := reflect.TypeOf(dest)
    v := reflect.ValueOf(dest)
    typeErr := errors.New("method Select error: type error")
    if t.Kind() != reflect.Ptr {
        return typeErr
    }
    //如果是用 var userPtr *User 方式声明的变量,则不可取址
    if !v.Elem().CanAddr() {
        return typeErr
    }
    t = t.Elem()
    v = v.Elem()
    //如果only此时仍然为空,说明Only()方法未被调用,我们从struct上取tag填充
    if len(q.only) == 0 {
        switch t.Kind() {
        case reflect.Struct:
            if t.Name() != "Time" {
                q.only = sK(v)
            }
        case reflect.Slice:
            //获取切片的基本类型给一个局部变量
            t := t.Elem()
            if t.Kind() == reflect.Ptr {
                t = t.Elem()
            }
            if t.Kind() == reflect.Struct {
                if t.Name() != "Time" {
                    q.only = sK(reflect.Zero(t))
                }
            }
        }
    }
    if len(q.only) == 0 {
        return errors.New("method Select error: type error, no columns to select")
    }
    if t.Kind() != reflect.Slice {
        q.limit = "limit 1"
    }
    //todo
}

这里只取struct的tag,不取value,我们定义一个新的sK()函数:

func sK(v reflect.Value) []string {
    var keys []string
    t := v.Type()
    for n := 0; n < t.NumField(); n++ {
        tf := t.Field(n)
        vf := v.Field(n)
        //忽略非导出字段
        if tf.Anonymous {
            continue
        }
        for vf.Type().Kind() == reflect.Ptr {
            vf = vf.Elem()
        }
        //如果字段值是time类型之外的struct,递归获取keys
        if vf.Kind() == reflect.Struct && tf.Type.Name() != "Time" {
            keys = append(keys, sK(vf)...)
            continue
        }
        //根据字段的json tag获取key,忽略无tag字段
        key := strings.Split(tf.Tag.Get("json"), ",")[0]
        if key == "" {
            continue
        }
        keys = append(keys, key)
    }
    return keys
}

现在sql语句已经完备了,可以执行最后的取值步骤了。

我们根据传入Select()的指针的基类型生成实际数据,对其取址后交给sql包的 Scan() 方法填充,然后 Set() 回去,所以这里需要一个 address() 函数用于取址:

func address(dest reflect.Value, columns []string) []interface{} {
    dest = dest.Elem()
    t := dest.Type()
    addrs := make([]interface{}, 0)
    switch t.Kind() {
    case reflect.Struct:
        for n := 0; n < t.NumField(); n++ {
            tf := t.Field(n)
            vf := dest.Field(n)
            if tf.Anonymous {
                continue
            }
            for vf.Type().Kind() == reflect.Ptr {
                vf = vf.Elem()
            }
            //如果字段值是time类型之外的struct,递归取址
            if vf.Kind() == reflect.Struct && tf.Type.Name() != "Time" {
                nVf := reflect.New(vf.Type())
                vf.Set(nVf.Elem())
                addrs = append(addrs, address(nVf, columns)...)
                continue
            }
            column := strings.Split(tf.Tag.Get("json"), ",")[0]
            if column == "" {
                continue
            }
            //只取选定的字段的地址
            for _, col := range columns {
                if col == column {
                    addrs = append(addrs, vf.Addr().Interface())
                    break
                }
            }
        }
    default:
        addrs = append(addrs, dest.Addr().Interface())
    }
    return addrs
}

Value.Addr() 函数可用于取址,前提是 Value.CanAddr() 返回true。

relfect.New() 可以根据 Typenew 出一个 Value ,这个Value是一个 指针 ,它的基值是可以取址的,把它的基值 Set() 到目标值上,就达到了根据Type从无到有生成对应值的目的。

因为map不能用new()函数生成,所以需要写一个用于生成map的函数 setMap()

//map的value类型必须是interface{},因为无类型信息,所以mysql驱动会返回一个字节切片,需要自行用[]byte断言
func (q *Query) setMap(rows *sql.Rows, t reflect.Type) (reflect.Value, error) {
    if t.Elem().Kind() != reflect.Interface {
        return reflect.ValueOf(nil), errors.New("method setMap error: type error, must be map[string]interface{}")
    }
    m := reflect.MakeMap(t)
    addrs := make([]interface{}, len(q.only))
    for idx := range q.only {
        addrs[idx] = new(interface{})
    }
    if err := rows.Scan(addrs...); err != nil {
        return reflect.ValueOf(nil), err
    }
    for idx, column := range q.only {
        //从指针剥出interface{},再剥出实际值
        m.SetMapIndex(reflect.ValueOf(column), reflect.ValueOf(addrs[idx]).Elem().Elem())
    }
    return m, nil
}

reflect.MakeMap()make() 作用差不多,它接受一个 Kindreflect.MapType 作为参数,生成一个对应类型的map。

对于其它适用于 new 的类型,写一个通用的函数 setElem() 处理:

//适用于基类型和struct
func (q *Query) setElem(rows *sql.Rows, t reflect.Type) (reflect.Value, error) {
    addrsErr := errors.New("method setElem error: columns not match addresses")
    dest := reflect.New(t)
    addrs := address(dest, q.only)
    if len(q.only) != len(addrs) {
        return reflect.ValueOf(nil), addrsErr
    }
    if err := rows.Scan(addrs...); err != nil {
        return reflect.ValueOf(nil), err
    }
    return dest, nil
}

这些函数完成后,就可以着手完善Select()里的todo部分了:

//already done
rows, err := q.DB.Query(q.toSQL())
    if err != nil {
        return err
    }
    switch t.Kind() {
    case reflect.Slice:
        dt := t.Elem()
        for dt.Kind() == reflect.Ptr {
            dt = dt.Elem()
        }
        sl := reflect.MakeSlice(t, 0, 0)
        for rows.Next() {
            var destination reflect.Value
            if dt.Kind() == reflect.Map {
                destination, err = q.setMap(rows, dt)
            } else {
                destination, err = q.setElem(rows, dt)
            }
            if err != nil {
                return err
            }
            //区分切片元素是否指针
            switch t.Elem().Kind() {
            case reflect.Ptr, reflect.Map:
                sl = reflect.Append(sl, destination)
            default:
                sl = reflect.Append(sl, destination.Elem())
            }
        }
        v.Set(sl)
        return nil
    case reflect.Map:
        for rows.Next() {
            m, err := q.setMap(rows, t)
            if err != nil {
                return err
            }
            v.Set(m)
        }
        return nil
    default:
        for rows.Next() {
            destination, err := q.setElem(rows, t)
            if err != nil {
                return err
            }
            v.Set(destination.Elem())
        }
    }
    return nil

至此,Select()方法就大功告成了,部分调用方式示例:

var user User
users()
.Where("first_name = 'Tom'", map[string]interface{}{
    "id": []int{1, 2, 3, 4},
})
.WhereNot(&User{LastName: "Cat"})
.Only("last_name")
.Select(&user)

var userMore []User
users().Where("first_name = 'Tom'").Order("id desc").Select(&userMore)
var userMoreP []*User
users().Where("first_name = 'Tom'").Select(&userMoreP)
var lastName string
users().Where(&User{FirstName: "Tom"}).Only("last_name").Select(&lastName)
var lastNames []string
users().Where(map[string]interface{}{
    "first_name": "Tom",
}).Only("last_name").Select(&lastNames)
var userM map[string]interface{}
users().Where(&User{FirstName: "Tom"}).Only("last_name").Select(&userM)
var userMS []map[string]interface{}
users().Where("age > 10").Only("last_name", "age").Limit(100).Select(&userMS)

Update方法

分析update sql语句:

update user set first_name = "z", last_name = "zy" where first_name = "Tom" and last_name = "Curise"

比较简单,直接复用之前写的sKV()和mKV()函数:

//Update src can be *user, user, map[string]interface{}, string
func (q *Query) Update(src interface{}) (int64, error) {
    if len(q.errs) != 0 {
        return 0, errors.New(strings.Join(q.errs, "
"))
    }
    v := reflect.ValueOf(src)
    for v.Kind() == reflect.Ptr {
        v = v.Elem()
    }
    var toBeUpdated, where string
    var keys, values []string
    switch v.Kind() {
    case reflect.String:
        toBeUpdated = src.(string)
    case reflect.Struct:
        keys, values = sKV(v)
    case reflect.Map:
        keys, values = mKV(v)
    default:
        return 0, errors.New("method Update error: type error")
    }
    if toBeUpdated == "" {
        if len(keys) != len(values) {
            return 0, errors.New("method Update error: keys not match values")
        }
        var kvs []string
        for idx, key := range keys {
            kvs = append(kvs, fmt.Sprintf("%s = %s", key, values[idx]))
        }
        toBeUpdated = strings.Join(kvs, ",")
    }
    if len(q.wheres) > 0 {
        where = fmt.Sprintf(`where %s`, strings.Join(q.wheres, " and "))
    }
    query := fmt.Sprintf("update %s set %s %s", q.table, toBeUpdated, where)
    st, err := q.DB.Prepare(query)
    if err != nil {
        return 0, err
    }
    result, err := st.Exec()
    if err != nil {
        return 0, err
    }
    return result.RowsAffected()
}

调用方式:

u1 := "age = 100"
u2 := map[string]interface{}{
    "age":        100,
    "first_name": "z",
    "last_name":  "zy",
}
u3 := &User{
    Age:       100,
    FirstName: "z",
    LastName:  "zy",
}
_, _ = users().Where("age > 10").Update(u1)
_, _ = users().Where("age > 10").Update(u2)
_, _ = users().Where("age > 10").Update(u3)

Delete方法

这个最简单,没啥好说的:

//Delete no args
func (q *Query) Delete() (int64, error) {
    if len(q.errs) != 0 {
        return 0, errors.New(strings.Join(q.errs, "
"))
    }
    var where string
    if len(q.wheres) > 0 {
        where = fmt.Sprintf(`where %s`, strings.Join(q.wheres, " and "))
    }
    st, err := q.DB.Prepare(fmt.Sprintf(`delete from %s %s`, q.table, where))
    if err != nil {
        return 0, err
    }
    result, err := st.Exec()
    if err != nil {
        return 0, err
    }
    return result.RowsAffected()
}

删除id为1,2,3,4,并且age大于10的用户的调用方式:

w := map[string]interface{}{
    "id": []int{1, 2, 3, 4},
}
_, _ = users().Where(w, "age > 10").Delete()

最后,写一个简单的事务处理函数 Transaction()

Transaction函数

事务有三个关键动作 beginrollbackcommit

begin后,要求所有操作要不全部成功,要不全部失败,所以我们要检查所有error,一旦出现错误就rollback,并且还要 recover 程序的panic,发现panic时也要rollback,直到最后确保无错,才能commit。

调用 *sql.DB.Begin() 方法后,我们会得到一个事务具柄,事务内的mysql交互都要通过它来进行,它也实现了 Query()Prepare() 等方法。

所以我们定义一个接口:

//Dba *sql.DB or *sql.Tx
type Dba interface {
    Query(string, ...interface{}) (*sql.Rows, error)
    Prepare(string) (*sql.Stmt, error)
}

然后把 Query 结构体的 DB 属性的类型改成这个接口:

//Query will build a sql
type Query struct {
    DB     Dba
    ...
}

同时, 改造 Table() 函数:

//Table bind db and table
func Table(db *sql.DB, tableName string) func(...Dba) *Query {
    return func(tx ...Dba) *Query {
        if len(tx) == 1 {
            return &Query{
                DB:    tx[0],
                table: tableName,
            }
        }
        return &Query{
            DB:    db,
            table: tableName,
        }
    }
}

这样我们就可以有选择性的和mysql进行普通交互或者事务交互。

然后把 Transaction() 函数写成这样:

//Transaction .
func Transaction(db *sql.DB, f func(Dba) error) (err error) {
    tx, err := db.Begin()
    if err != nil {
        return err
    }
    defer func() {
        p := recover()
        if err != nil {
            if rerr := tx.Rollback(); rerr != nil {
                panic(rerr)
            }
            return
        }
        if p != nil {
            if rerr := tx.Rollback(); rerr != nil {
                panic(rerr)
            }
            err = fmt.Errorf("function Transaction error: %v", p)
            return
        }
        if cerr := tx.Commit(); cerr != nil {
            panic(cerr)
        }
    }()
    err = f(tx)
    return err
}

第二个参数是一个接受事务具柄,返回error的函数,我们将需要事务的操作全部封装在这个函数里,就能抓到所有的panic和error。

调用方式示例:

unc doTx() error {
    ormDB, err := Connect("root@tcp(127.0.0.1:3306)/orm_db?parseTime=true&loc=Local")
    if err != nil {
        panic(err)
    }
    users := Table(ormDB, "user")
    args := something()
    //利用闭包传递变量
    f := func(tx Dba) error {
        var id int
        //select语句无需在事务具柄上进行
        if err := users().Where(args).Select(&id); err != nil {
            return err
        }
        //增删改需要在事务上进行
        if _, err = users(tx).Insert(args); err != nil {
            return err
        }
        if _, err = users(tx).Update(args); err != nil {
            return err
        }
        if _, err = users(tx).Where(args).Delete(); err != nil {
            return err
        }
        return nil
    }
    //开始事务
    if err := Transaction(ormDB, f); err != nil {
        return err
    }
    return nil
}

到此,这个迷你orm的增删改查和事务功能全部都实现了,代码大概600行,比我预想的多了一倍。

后记

golang的反射虽然强大(其实并不,没有ruby的元编程那么方便),但还是比较烦琐的,而且类型不对时动不动就panic,使用的时候要尽量检查一下Kind。


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK