本文为转载,原文:Golang 学习笔记(12)—— ORM实现
介绍
本文将利用之前所学习到的内容实现一个简单的orm,实现比较简单,没有考虑过多的设计原则,以及性能安全之类的,只是单纯的以学习为导向,做的一个练手的小工具。其中有不合理的地方还请看到的同学见谅,并指出,本人也好加以改正。
实现
首先看下完整的代码目录吧:
数据库
数据库,这里我选择mysql数据库,可以用以下sql语句创建一个测试的表:
CREATE TABLE `userinfo` (
`uid` INT(10) NOT NULL AUTO_INCREMENT,
`username` VARCHAR(64) NULL DEFAULT NULL,
`departname` VARCHAR(64) NULL DEFAULT NULL,
`created` DATE NULL DEFAULT NULL,
PRIMARY KEY (`uid`)
)
go
实体
建完表之后,我们在go中需要穿件一个与之对应的struct
type UserInfo struct{
TableName orm.TableName "userinfo"
UserName string `name:"username"`
Uid int `name:"uid"PK:"true"auto:"true"`
DepartName string `name:"departname"`
Created string `name:"created"`
}
从struct中可以看到,我们的字段分为两部分,第一个就是TableName,该字段没有具体内容,只是使用后面的tag标记其对应的数据库表名,剩余的则是与数据库表字段一一对应的字段了。
每个字段的tag则说明了其在数据库的属性。name
表示在数据库中的字段名,不写则与字段名一致,PK
表示是否为主键,若为主键则标记为"ture",默认为false,auto
表示是否为自增长,若为自增长则标记为"true",默认为false,由于我的设计比较简单,所以,先就标记这几种属性。
反射对象
既然有了实体对象,自然我们要通过反射去解析该对象。该代码写在orm/orm.go代码文件中。
package orm
import (
"fmt"
"database/sql"
"errors"
"strings"
"reflect"
)
/*
表信息
*/
type TableInfo struct{
Name string //表名
Fields []FieldInfo //表字段信息
TMMap map[string]string //表字段与实体字段名映射关系,key为表字段名,val为实体字段名
}
/*
表字段详细信息
*/
type FieldInfo struct{
Name string
IsPrimaryKey bool
IsAutoGenerate bool
Valve reflect.Value
}
/*
实体对象信息
*/
type ModelInfo struct{
TableInfo // 实体对应的表信息
TbName string // 表名称
Model interface{} //实体实例
}
//表名
type TableName string
//表名类型
var typeTableName TableName
var tableNameType reflect.Type = reflect.TypeOf(typeTableName)
//实体映射,key为表名,val为实体信息
var ModelMapping map[string]ModelInfo
/*
注册实体,每当有一个实体时,需要调用该方法注册。
注册到 ModelMapping
*/
func Register(model interface{}){
if ModelMapping == nil{
ModelMapping = make(map[string]ModelInfo)
}
tbInfo, _ := getTableInfo(model)
ModelMapping[tbInfo.Name] = ModelInfo{TbName:tbInfo.Name, Model:model}
}
/*
根据实体通过反射获取表信息
返回表信息
*/
func getTableInfo(model interface{})(tabInfo *TableInfo, err error){
defer func(){
if e := recover(); err != nil{
tabInfo = nil
err = e.(error)
}
}()
err = nil
tabInfo = &TableInfo{}
tabInfo.TMMap = make(map[string]string)
rt := reflect.TypeOf(model)
rv := reflect.ValueOf(model)
tabInfo.Name = rt.Name()
if rt.Kind() == reflect.Ptr{
rt = rt.Elem()
rv = rv.Elem()
}
//字段解析
for i, j := 0, rt.NumField(); i < j; i++{
rtf := rt.Field(i)
rvf := rv.Field(i)
if rtf.Type == tableNameType{
tabInfo.Name = string(rtf.Tag)
continue
}
if rtf.Tag == "-"{
continue
}
//解析字段的tag
var f FieldInfo
//没有tag,表字段名与实体字段ing一致
if rtf.Tag == ""{
f = FieldInfo{Name:rtf.Name, IsAutoGenerate:false, IsPrimaryKey:false, Valve:rvf}
tabInfo.TMMap[rtf.Name] = rtf.Name
}else{
strTag := string(rtf.Tag)
if strings.Index(strTag, ":") == -1{
//tag中没有":"时,表字段名与实体字段ing一致
f = FieldInfo{Name:rtf.Name, IsAutoGenerate:false, IsPrimaryKey:false, Valve:rvf}
tabInfo.TMMap[rtf.Name] = rtf.Name
}else{
//解析tag中的name值为表字段名
strName := rtf.Tag.Get("name")
if strName == ""{
strName = rtf.Name
}
//解析tag中的PK
isPk := false
strIspk := rtf.Tag.Get("PK")
if strIspk == "true"{
isPk = true
}
//解析tag中的auto
isAuto := false
strIsauto := rtf.Tag.Get("auto")
if strIsauto == "true"{
isAuto = true
}
f = FieldInfo{Name:strName, IsPrimaryKey:isPk, IsAutoGenerate:isAuto, Valve:rvf}
tabInfo.TMMap[strName] = rtf.Name
}
}
tabInfo.Fields = append(tabInfo.Fields, f)
}
return
}
/*
根据实体生成插入语句
*/
func generateInsertSql(model interface{})(string, []interface{}, *TableInfo, error){
//获取表信息
tbInfo, err := getTableInfo(model)
if err != nil{
return "", nil, nil, err
}
if len(tbInfo.Fields) == 0 {
return "", nil, nil, errors.New(tbInfo.Name + "结构体中没有字段")
}
//根据字段信息拼Sql语句,以及参数值
strSql := "insert into " + tbInfo.Name
strFileds := ""
strValues := ""
var params []interface{}
for _, v := range tbInfo.Fields{
if v.IsAutoGenerate {
continue
}
strFileds += v.Name + ","
strValues += "?,"
params = append(params, v.Valve.Interface())
}
if strFileds == ""{
return "", nil, nil, errors.New(tbInfo.Name + "结构体中没有字段,或只有自增字段")
}
strFileds = strings.TrimRight(strFileds, ",")
strValues = strings.TrimRight(strValues, ",")
strSql += " (" + strFileds + ") values(" + strValues + ")"
fmt.Println("sql: ",strSql)
fmt.Println("params: ",params)
return strSql, params, tbInfo, nil
}
/*
根据实体生成修改的sql语句
*/
func generateUpdateSql(model interface{})(string, []interface{}, error){
//获取表信息
tbInfo, err := getTableInfo(model)
if err != nil{
return "", nil, err
}
if len(tbInfo.Fields) == 0 {
return "", nil, errors.New(tbInfo.Name + "结构体中没有字段")
}
//根据字段信息拼Sql语句,以及参数值
strSql := "update " + tbInfo.Name + " set "
strFileds := ""
strWhere := ""
var p interface{}
var params []interface{}
for _, v := range tbInfo.Fields{
if v.IsAutoGenerate && !v.IsPrimaryKey{
continue
}
if v.IsPrimaryKey{
strWhere += v.Name + "=?"
p = v.Valve.Interface()
continue
}
strFileds += v.Name + "=?,"
params = append(params, v.Valve.Interface())
}
params = append(params, p)
strFileds = strings.TrimRight(strFileds, ",")
strSql += strFileds + " where " + strWhere
fmt.Println("update sql: ", strSql)
fmt.Println("update params: ", params)
return strSql, params, nil
}
/*
自动生成删除的sql语句,以主键为删除条件
*/
func generateDeleteSql(model interface{})(string, []interface{}, error){
//获取表信息
tbInfo, err := getTableInfo(model)
if err != nil{
return "", nil, err
}
//根据字段信息拼Sql语句,以及参数值
strSql := "delete from " + tbInfo.Name + " where "
var idVal interface{}
for _, v := range tbInfo.Fields{
if v.IsPrimaryKey{
strSql += v.Name + "=?"
idVal = v.Valve.Interface()
}
}
params := []interface{}{idVal}
fmt.Println("update sql: ", strSql)
fmt.Println("update params: ", params)
return strSql, params, nil
}
/*
设置自增长字段的值
*/
func setAuto(result sql.Result, tbInfo *TableInfo)(err error){
defer func(){
if e := recover(); e != nil{
err = e.(error)
}
}()
id, err := result.LastInsertId()
if id == 0{
return
}
if err != nil{
return
}
for _, v := range tbInfo.Fields{
if v.IsAutoGenerate && v.Valve.CanSet(){
v.Valve.SetInt(id)
break
}
}
return
}
这里面,我们实现了通过model实体,来生成新增,修改,删除的sql语句,以及参数。但是查询的怎么办呢?
MyRows
在orm/MyRows.go代码文件中,实现了一个自己的Rows,来处理查询:
package orm
import (
"strconv"
"reflect"
"database/sql"
)
type MyRows struct{
* sql.Rows
Values map[string]interface{} //表字段和值的映射
ColumnNames []string //表字段名集合
}
/*
获取数据
*/
func (this *MyRows)Next()bool{
bResult := this.Rows.Next()
if bResult{
//获取表字段名称集合
if this.ColumnNames == nil || len(this.ColumnNames) == 0{
this.ColumnNames, _ = this.Rows.Columns()
}
//初始化表字段和值的映射
if this.Values == nil{
this.Values = make(map[string]interface{})
}
//调用scan函数的参数
scanArgs := make([]interface{}, len(this.ColumnNames))
//scan函数的值
values := make([][]byte, len(this.ColumnNames))
for i := range values{
scanArgs[i] = &values[i]
}
this.Rows.Scan(scanArgs...)
//将结果存放到Values中
for i := 0; i < len(this.ColumnNames); i++{
this.Values[this.ColumnNames[i]] = values[i]
}
}
return bResult
}
/*
将数据映射到实体切片
tbname:U对应的数据表名
*/
func (this *MyRows)To(tbname string) ([]interface{},error){
mi := ModelMapping[tbname]
ti, _ := getTableInfo(mi.Model)
var models []interface{}
for this.Next(){
v := reflect.New(reflect.TypeOf(mi.Model).Elem()).Elem()
for k, val := range this.Values{
f := v.FieldByName(ti.TMMap[k])
var strVal string
if bt, ok := val.([]byte); ok{
strVal = string(bt)
switch f.Type().Name(){
case "int":
i, _ := strconv.ParseInt(strVal, 10, 64)
f.SetInt(i)
break
case "string":
f.SetString(strVal)
break
}
}
}
models = append(models, v.Interface())
}
return models, nil
}
该代码实现了查询结果到实体的映射
MysqlDB
orm/MysqlDB.go是实现数据库方法的代码文件。
package orm
import (
"strconv"
"reflect"
"database/sql"
)
type MyRows struct{
* sql.Rows
Values map[string]interface{} //表字段和值的映射
ColumnNames []string //表字段名集合
}
/*
获取数据
*/
func (this *MyRows)Next()bool{
bResult := this.Rows.Next()
if bResult{
//获取表字段名称集合
if this.ColumnNames == nil || len(this.ColumnNames) == 0{
this.ColumnNames, _ = this.Rows.Columns()
}
//初始化表字段和值的映射
if this.Values == nil{
this.Values = make(map[string]interface{})
}
//调用scan函数的参数
scanArgs := make([]interface{}, len(this.ColumnNames))
//scan函数的值
values := make([][]byte, len(this.ColumnNames))
for i := range values{
scanArgs[i] = &values[i]
}
this.Rows.Scan(scanArgs...)
//将结果存放到Values中
for i := 0; i < len(this.ColumnNames); i++{
this.Values[this.ColumnNames[i]] = values[i]
}
}
return bResult
}
/*
将数据映射到实体切片
tbname:U对应的数据表名
*/
func (this *MyRows)To(tbname string) ([]interface{},error){
mi := ModelMapping[tbname]
ti, _ := getTableInfo(mi.Model)
var models []interface{}
for this.Next(){
v := reflect.New(reflect.TypeOf(mi.Model).Elem()).Elem()
for k, val := range this.Values{
f := v.FieldByName(ti.TMMap[k])
var strVal string
if bt, ok := val.([]byte); ok{
strVal = string(bt)
switch f.Type().Name(){
case "int":
i, _ := strconv.ParseInt(strVal, 10, 64)
f.SetInt(i)
break
case "string":
f.SetString(strVal)
break
}
}
}
models = append(models, v.Interface())
}
return models, nil
}
详情请见注释
调用
package main
import (
"time"
"fmt"
"stu_demo/orm"
_ "github.com/go-sql-driver/mysql"
)
type UserInfo struct{
TableName orm.TableName "userinfo"
UserName string `name:"username"`
Uid int `name:"uid"PK:"true"auto:"true"`
DepartName string `name:"departname"`
Created string `name:"created"`
}
func main(){
ui := UserInfo{UserName:"CHAIN", DepartName:"TEST", Created:time.Now().String()}
orm.Register(new(UserInfo))
db, err := orm.NewDb("mysql", "root:pwd@tcp(xxx.xxx.xxx.xxx:x3306/demo?charset=utf8")
if err != nil {
fmt.Println("打开SQL时出错:", err.Error())
return
}
defer db.Close()
//插入测试
err = db.Insert(&ui)
if err != nil {
fmt.Println("插入时错误:", err.Error())
}
fmt.Println("插入成功")
//修改测试
ui.UserName = "BBBB"
err = db.Update(ui)
if err != nil {
fmt.Println("修改时错误:", err.Error())
}
fmt.Println("修改成功")
//删除测试
err = db.Delete(ui)
if err != nil {
fmt.Println("删除时错误:", err.Error())
}
fmt.Println("删除成功")
//查询测试
res, err := db.From("userinfo").
Select("username", "departname", "uid").
Where("uid__gt", 20).
Where("username", "chain").Get()
if err != nil{
fmt.Println("err: ", err.Error())
}
fmt.Println(res)
}
源码
完
转载请注明出处:
Golang 学习笔记(12)—— ORM实现
目录
上一节:Golang 学习笔记(11)—— 反射
下一节:Golang Web学习(13)—— 搭建简单的Web服务器