diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f4b907f --- /dev/null +++ b/go.mod @@ -0,0 +1,7 @@ +module github.com/twiglab/sqlt + +require ( + github.com/jmoiron/sqlx v1.2.0 + github.com/lib/pq v1.0.0 + google.golang.org/appengine v1.4.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7258bec --- /dev/null +++ b/go.sum @@ -0,0 +1,13 @@ +github.com/go-sql-driver/mysql v1.4.0 h1:7LxgVwFb2hIQtMm87NdgAVfXjnt4OePseqT1tKx+opk= +github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA= +github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= +github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mattn/go-sqlite3 v1.9.0 h1:pDRiWfl+++eC2FEFRy6jXmQlvp4Yh3z1MJKg4UeYM/4= +github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= diff --git a/internal/mapper/base.go b/internal/mapper/base.go new file mode 100644 index 0000000..c486e64 --- /dev/null +++ b/internal/mapper/base.go @@ -0,0 +1,54 @@ +package mapper + +import ( + "errors" + "reflect" +) + +type ColScanner interface { + Columns() ([]string, error) + Scan(dest ...interface{}) error + Err() error +} + +func MapScan(r ColScanner, dest map[string]interface{}) error { + columns, err := r.Columns() + if err != nil { + return err + } + + values := make([]interface{}, len(columns)) + for i := range values { + values[i] = new(interface{}) + } + + err = r.Scan(values...) + if err != nil { + return err + } + + for i, column := range columns { + dest[column] = *(values[i].(*interface{})) + } + + return r.Err() +} + +func StructScan(rows ColScanner, dest interface{}) (err error) { + destValue := reflect.ValueOf(dest) + elemType := destValue.Type() + + if elemType.Kind() != reflect.Ptr { + return errors.New("slice elem must ptr ") + } + + rowMap := make(map[string]interface{}) + if err = MapScan(rows, rowMap); err != nil { + return + } + if err = MapperMap(rowMap, dest); err != nil { + return + } + + return +} diff --git a/internal/mapper/convert.go b/internal/mapper/convert.go new file mode 100644 index 0000000..3478760 --- /dev/null +++ b/internal/mapper/convert.go @@ -0,0 +1,233 @@ +package mapper + +import ( + "fmt" + "math/big" + "reflect" + "strconv" + "time" +) + +// Convert is the target string +type Convert string + +// Set string +func (f *Convert) Set(v string) { + if v != "" { + *f = Convert(v) + } else { + f.Clear() + } +} + +// Clear string +func (f *Convert) Clear() { + *f = Convert(0x1E) +} + +// Exist check string exist +func (f Convert) Exist() bool { + return string(f) != string(0x1E) +} + +// Bool string to bool +func (f Convert) Bool() (bool, error) { + return strconv.ParseBool(f.String()) +} + +// Float32 string to float32 +func (f Convert) Float32() (float32, error) { + v, err := strconv.ParseFloat(f.String(), 32) + return float32(v), err +} + +// Float64 string to float64 +func (f Convert) Float64() (float64, error) { + return strconv.ParseFloat(f.String(), 64) +} + +// Int string to int +func (f Convert) Int() (int, error) { + v, err := strconv.ParseInt(f.String(), 10, 32) + return int(v), err +} + +// Int8 string to int8 +func (f Convert) Int8() (int8, error) { + v, err := strconv.ParseInt(f.String(), 10, 8) + return int8(v), err +} + +// Int16 string to int16 +func (f Convert) Int16() (int16, error) { + v, err := strconv.ParseInt(f.String(), 10, 16) + return int16(v), err +} + +// Int32 string to int32 +func (f Convert) Int32() (int32, error) { + v, err := strconv.ParseInt(f.String(), 10, 32) + return int32(v), err +} + +// Int64 string to int64 +func (f Convert) Int64() (int64, error) { + v, err := strconv.ParseInt(f.String(), 10, 64) + if err != nil { + i := new(big.Int) + ni, ok := i.SetString(f.String(), 10) // octal + if !ok { + return v, err + } + return ni.Int64(), nil + } + return v, err +} + +// Uint string to uint +func (f Convert) Uint() (uint, error) { + v, err := strconv.ParseUint(f.String(), 10, 32) + return uint(v), err +} + +// Uint8 string to uint8 +func (f Convert) Uint8() (uint8, error) { + v, err := strconv.ParseUint(f.String(), 10, 8) + return uint8(v), err +} + +// Uint16 string to uint16 +func (f Convert) Uint16() (uint16, error) { + v, err := strconv.ParseUint(f.String(), 10, 16) + return uint16(v), err +} + +// Uint32 string to uint32 +func (f Convert) Uint32() (uint32, error) { + v, err := strconv.ParseUint(f.String(), 10, 32) + return uint32(v), err +} + +// Uint64 string to uint64 +func (f Convert) Uint64() (uint64, error) { + v, err := strconv.ParseUint(f.String(), 10, 64) + if err != nil { + i := new(big.Int) + ni, ok := i.SetString(f.String(), 10) + if !ok { + return v, err + } + return ni.Uint64(), nil + } + return v, err +} + +// String string to string +func (f Convert) String() string { + if f.Exist() { + return string(f) + } + return "" +} + +// ToString interface to string +func ToString(value interface{}, args ...int) (s string) { + switch v := value.(type) { + case bool: + s = strconv.FormatBool(v) + case float32: + s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32)) + case float64: + s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64)) + case int: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int8: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int16: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int32: + s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10)) + case int64: + s = strconv.FormatInt(v, argInt(args).Get(0, 10)) + case uint: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint8: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint16: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint32: + s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10)) + case uint64: + s = strconv.FormatUint(v, argInt(args).Get(0, 10)) + case string: + s = v + case []byte: + s = string(v) + default: + s = fmt.Sprintf("%v", v) + } + return s +} + +// ToInt64 interface to int64 +func ToInt64(value interface{}) (d int64) { + val := reflect.ValueOf(value) + switch value.(type) { + case int, int8, int16, int32, int64: + d = val.Int() + case uint, uint8, uint16, uint32, uint64: + d = int64(val.Uint()) + default: + panic(fmt.Errorf("ToInt64 need numeric not `%T`", value)) + } + return +} + +type argInt []int + +// get int by index from int slice +func (a argInt) Get(i int, args ...int) (r int) { + if i >= 0 && i < len(a) { + r = a[i] + } + if len(args) > 0 { + r = args[0] + } + return +} + +// TimeToUnix transform time to Unix time, the number of seconds elapsed +func TimeToUnix(t time.Time) int64 { + return t.Unix() +} + +// UnixToTime transform Unix time to local Time +func UnixToTime(tt int64) time.Time { + return time.Unix(tt, 0) +} + +// TimeToUnixLocation transform time to Unix time with time location +// location like "Asia/Shanghai" +func TimeToUnixLocation(t time.Time, location string) (int64, error) { + timeStr := t.Format("2006-01-02 15:04:05") + loc, err := time.LoadLocation(location) + if err != nil { + return 0, err + } + tt, err := time.ParseInLocation("2006-01-02 15:04:05", timeStr, loc) + if err != nil { + return 0, err + } + return tt.Unix(), err +} + +// UnixToTimeLocation transform Unix time to local Time with time location +// location like "Asia/Shanghai" +func UnixToTimeLocation(tt int64, location string) (time.Time, error) { + loc, err := time.LoadLocation(location) + if err != nil { + return time.Now(), err + } + time.Local = loc + return time.Unix(tt, 0), nil +} diff --git a/internal/mapper/jsontime.go b/internal/mapper/jsontime.go new file mode 100644 index 0000000..de06b19 --- /dev/null +++ b/internal/mapper/jsontime.go @@ -0,0 +1,35 @@ +package mapper + +import "time" + +type JSONTime time.Time + +var ( + timeJSONFormat = "2006-01-02 15:04:05" +) + +func SetTimeJSONFormat(format string) { + timeJSONFormat = format +} + +func GetTimeJSONFormat() string { + return timeJSONFormat +} + +func (t *JSONTime) UnmarshalJSON(data []byte) (err error) { + now, err := time.ParseInLocation(`"`+timeJSONFormat+`"`, string(data), time.Local) + *t = JSONTime(now) + return +} + +func (t JSONTime) MarshalJSON() ([]byte, error) { + b := make([]byte, 0, len(timeJSONFormat)+2) + b = append(b, '"') + b = time.Time(t).AppendFormat(b, timeJSONFormat) + b = append(b, '"') + return b, nil +} + +func (t JSONTime) String() string { + return time.Time(t).Format(timeJSONFormat) +} diff --git a/internal/mapper/mapper.go b/internal/mapper/mapper.go new file mode 100644 index 0000000..119ad9d --- /dev/null +++ b/internal/mapper/mapper.go @@ -0,0 +1,504 @@ +package mapper + +import ( + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" + "sync" + "time" +) + +var ( + ZeroValue reflect.Value + fieldNameMap sync.Map + registerMap sync.Map + enabledTypeChecking bool + enabledMapperStructField bool + enabledAutoTypeConvert bool + timeType = reflect.TypeOf(time.Now()) + jsonTimeType = reflect.TypeOf(JSONTime(time.Now())) +) + +const ( + packageVersion = "0.6" + mapperTagKey = "mapper" + jsonTagKey = "json" + sqltTagKey = "sqlt" + sqlxTagKey = "db" // for sqlx + xormTagKey = "xorm" // for xorm + IgnoreTagValue = "-" + nameConnector = "_" + formatTime = "15:04:05" + formatDate = "2006-01-02" + formatDateTime = "2006-01-02 15:04:05" +) + +func init() { + ZeroValue = reflect.Value{} + enabledTypeChecking = false + enabledMapperStructField = true + enabledAutoTypeConvert = true +} + +func PackageVersion() string { + return packageVersion +} + +// SetEnabledTypeChecking set enabled flag for TypeChecking +// if set true, the field type will be checked for consistency during mapping +// default is false +func SetEnabledTypeChecking(isEnabled bool) { + enabledTypeChecking = isEnabled +} + +// SetEnabledAutoTypeConvert set enabled flag for auto type convert +// if set true, field will auto convert in Time and Unix +// default is true +func SetEnabledAutoTypeConvert(isEnabled bool) { + enabledAutoTypeConvert = isEnabled +} + +// SetEnabledMapperStructField set enabled flag for MapperStructField +// if set true, the reflect.Struct field will auto mapper +// must follow premises: +// 1. fromField and toField type must be reflect.Struct and not time.Time +// 2. fromField and toField must be not same type +// default is enabled +func SetEnabledMapperStructField(isEnabled bool) { + enabledMapperStructField = isEnabled +} + +// Register register struct to init Map +func Register(obj interface{}) error { + objValue := reflect.ValueOf(obj) + if objValue == ZeroValue { + return errors.New("no exists this value") + } + + return registerValue(objValue.Elem()) +} + +// registerValue register Value to init Map +func registerValue(objValue reflect.Value) error { + regValue := objValue + if objValue == ZeroValue { + return errors.New("no exists this value") + } + + if regValue.Type().Kind() == reflect.Ptr { + regValue = regValue.Elem() + } + + typeName := regValue.Type().String() + for i := 0; i < regValue.NumField(); i++ { + mapFieldName := typeName + nameConnector + GetFieldName(regValue, i) + realFieldName := regValue.Type().Field(i).Name + fieldNameMap.Store(mapFieldName, realFieldName) + } + + //store register flag + registerMap.Store(typeName, nil) + return nil +} + +// GetTypeName get type name +func GetTypeName(obj interface{}) string { + object := reflect.ValueOf(obj) + return object.String() +} + +// CheckExistsField check field is exists by name +func CheckExistsField(elem reflect.Value, fieldName string) (realFieldName string, exists bool) { + typeName := elem.Type().String() + fileKey := typeName + nameConnector + fieldName + realName, isOk := fieldNameMap.Load(fileKey) + + if !isOk { + return "", isOk + } else { + return realName.(string), isOk + } + +} + +// GetFieldName get fieldName with ElemValue and index +// if config tag string, return tag value +func GetFieldName(objElem reflect.Value, index int) string { + fieldName := "" + field := objElem.Type().Field(index) + tag := getStructTag(field) + if tag != "" { + fieldName = tag + } else { + fieldName = field.Name + } + return fieldName +} + +// MapperMap mapper and set value from map to object +// support auto register struct +// now support field type: +// 1.reflect.Bool +// 2.reflect.String +// 3.reflect.Int8\16\32\64 +// 4.reflect.Uint8\16\32\64 +// 5.reflect.Float32\64 +// 6.time.Time +func MapperMap(fromMap map[string]interface{}, toObj interface{}) error { + toElem := reflect.ValueOf(toObj).Elem() + if toElem == ZeroValue { + return errors.New("to obj is not legal value") + } + //check register flag + //if not register, register it + if !checkIsRegister(toElem) { + Register(toObj) + } + for k, v := range fromMap { + fieldName := k + //check field is exists + realFieldName, exists := CheckExistsField(toElem, fieldName) + if !exists { + continue + } + fieldInfo, exists := toElem.Type().FieldByName(realFieldName) + if !exists { + continue + } + fieldKind := fieldInfo.Type.Kind() + fieldValue := toElem.FieldByName(realFieldName) + setFieldValue(fieldValue, fieldKind, v) + } + return nil +} + +// MapperMapSlice mapper from map[string]map[string]interface{} to a slice of any type's ptr +// toSlice must be a slice of any type's ptr. +func MapperMapSlice(fromMaps map[string]map[string]interface{}, toSlice interface{}) error { + var err error + toValue := reflect.ValueOf(toSlice) + if toValue.Kind() != reflect.Ptr { + return errors.New("toSlice must pointer of slice") + } + if toValue.IsNil() { + return errors.New("toSlice must not nil pointer") + } + + toElemType := reflect.TypeOf(toSlice).Elem().Elem() + if toElemType.Kind() != reflect.Ptr { + return errors.New("slice elem must ptr ") + } + + direct := reflect.Indirect(toValue) + //3 elem parse: 1.[]*type 2.*type 3.type + toElemType = toElemType.Elem() + for _, v := range fromMaps { + elem := reflect.New(toElemType) + err = MapperMap(v, elem.Interface()) + if err == nil { + direct.Set(reflect.Append(direct, elem)) + } + } + return err +} + +// MapToJson mapper from map[string]interface{} to json []byte +func MapToJson(fromMap map[string]interface{}) ([]byte, error) { + json, err := json.Marshal(fromMap) + if err != nil { + return nil, err + } + return json, nil +} + +// JsonToMap mapper from json []byte to map[string]interface{} +func JsonToMap(body []byte, toMap *map[string]interface{}) error { + err := json.Unmarshal(body, toMap) + return err +} + +// Mapper mapper and set value from struct fromObj to toObj +// not support auto register struct +func Mapper(fromObj, toObj interface{}) error { + fromElem := reflect.ValueOf(fromObj).Elem() + toElem := reflect.ValueOf(toObj).Elem() + if fromElem == ZeroValue { + return errors.New("from obj is not legal value") + } + if toElem == ZeroValue { + return errors.New("to obj is not legal value") + } + return elemMapper(fromElem, toElem) +} + +// MapperSlice mapper from slice of struct to a slice of any type +// fromSlice and toSlice must be a slice of any type. +func MapperSlice(fromSlice, toSlice interface{}) error { + var err error + toValue := reflect.ValueOf(toSlice) + if toValue.Kind() != reflect.Ptr { + return errors.New("toSlice must pointer of slice") + } + if toValue.IsNil() { + return errors.New("toSlice must not nil pointer") + } + + elemType := reflect.TypeOf(toSlice).Elem().Elem() + if elemType.Kind() != reflect.Ptr { + return errors.New("slice elem must ptr ") + } + + direct := reflect.Indirect(toValue) + //3 elem parse: 1.[]*type 2.*type 3.type + elemType = elemType.Elem() + + fromElems := convertToSlice(fromSlice) + for _, v := range fromElems { + elem := reflect.New(elemType) + err = elemMapper(reflect.ValueOf(v).Elem(), elem.Elem()) + if err == nil { + direct.Set(reflect.Append(direct, elem)) + } + } + return err +} + +// Mapper mapper and set value from struct fromObj to toObj +// support auto register struct +func AutoMapper(fromObj, toObj interface{}) error { + return Mapper(fromObj, toObj) +} + +func elemMapper(fromElem, toElem reflect.Value) error { + //check register flag + //if not register, register it + if !checkIsRegister(fromElem) { + registerValue(fromElem) + } + if !checkIsRegister(toElem) { + registerValue(toElem) + } + + for i := 0; i < fromElem.NumField(); i++ { + fromFieldInfo := fromElem.Field(i) + fieldName := GetFieldName(fromElem, i) + //check field is exists + realFieldName, exists := CheckExistsField(toElem, fieldName) + if !exists { + continue + } + + toFieldInfo := toElem.FieldByName(realFieldName) + //check field is same type + if enabledTypeChecking { + if fromFieldInfo.Kind() != toFieldInfo.Kind() { + continue + } + } + + if enabledMapperStructField && + toFieldInfo.Kind() == reflect.Struct && fromFieldInfo.Kind() == reflect.Struct && + toFieldInfo.Type() != fromFieldInfo.Type() && + !isTimeField(toFieldInfo) && !isTimeField(fromFieldInfo) { + x := reflect.New(toFieldInfo.Type()).Elem() + err := elemMapper(fromFieldInfo, x) + if err != nil { + fmt.Println("auto mapper field", fromFieldInfo, "=>", toFieldInfo, "error", err.Error()) + } else { + toFieldInfo.Set(x) + } + } else { + isSet := false + if enabledAutoTypeConvert { + if isTimeField(fromFieldInfo) && toFieldInfo.Kind() == reflect.Int64 { + fromTime := fromFieldInfo.Interface().(time.Time) + toFieldInfo.Set(reflect.ValueOf(TimeToUnix(fromTime))) + isSet = true + } else if isTimeField(toFieldInfo) && fromFieldInfo.Kind() == reflect.Int64 { + fromTime := fromFieldInfo.Interface().(int64) + toFieldInfo.Set(reflect.ValueOf(UnixToTime(fromTime))) + isSet = true + } + } + if !isSet { + toFieldInfo.Set(fromFieldInfo) + } + } + + } + return nil +} + +func setFieldValue(fieldValue reflect.Value, fieldKind reflect.Kind, value interface{}) error { + switch fieldKind { + case reflect.Bool: + if value == nil { + fieldValue.SetBool(false) + } else if v, ok := value.(bool); ok { + fieldValue.SetBool(v) + } else { + v, _ := Convert(ToString(value)).Bool() + fieldValue.SetBool(v) + } + + case reflect.String: + if value == nil { + fieldValue.SetString("") + } else { + fieldValue.SetString(ToString(value)) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if value == nil { + fieldValue.SetInt(0) + } else { + val := reflect.ValueOf(value) + switch val.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + fieldValue.SetInt(val.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + fieldValue.SetInt(int64(val.Uint())) + default: + v, _ := Convert(ToString(value)).Int64() + fieldValue.SetInt(v) + } + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if value == nil { + fieldValue.SetUint(0) + } else { + val := reflect.ValueOf(value) + switch val.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + fieldValue.SetUint(uint64(val.Int())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + fieldValue.SetUint(val.Uint()) + default: + v, _ := Convert(ToString(value)).Uint64() + fieldValue.SetUint(v) + } + } + case reflect.Float64, reflect.Float32: + if value == nil { + fieldValue.SetFloat(0) + } else { + val := reflect.ValueOf(value) + switch val.Kind() { + case reflect.Float64: + fieldValue.SetFloat(val.Float()) + default: + v, _ := Convert(ToString(value)).Float64() + fieldValue.SetFloat(v) + } + } + case reflect.Struct: + if value == nil { + fieldValue.Set(reflect.Zero(fieldValue.Type())) + } else if isTimeField(fieldValue) { + var timeString string + if fieldValue.Type() == timeType { + timeString = "" + fieldValue.Set(reflect.ValueOf(value)) + } + if fieldValue.Type() == jsonTimeType { + timeString = "" + fieldValue.Set(reflect.ValueOf(JSONTime(value.(time.Time)))) + } + switch d := value.(type) { + case []byte: + timeString = string(d) + case string: + timeString = d + case int64: + if enabledAutoTypeConvert { + //try to transform Unix time to local Time + t, err := UnixToTimeLocation(value.(int64), time.UTC.String()) + if err != nil { + return err + } + fieldValue.Set(reflect.ValueOf(t)) + } + } + if timeString != "" { + if len(timeString) >= 19 { + //满足yyyy-MM-dd HH:mm:ss格式 + timeString = timeString[:19] + t, err := time.ParseInLocation(formatDateTime, timeString, time.UTC) + if err == nil { + t = t.In(time.UTC) + fieldValue.Set(reflect.ValueOf(t)) + } + } else if len(timeString) >= 10 { + //满足yyyy-MM-dd格式 + timeString = timeString[:10] + t, err := time.ParseInLocation(formatDate, timeString, time.UTC) + if err == nil { + fieldValue.Set(reflect.ValueOf(t)) + } + } + } + } + default: + if reflect.ValueOf(value).Type() == fieldValue.Type() { + fieldValue.Set(reflect.ValueOf(value)) + } + } + + return nil +} + +func isTimeField(fieldValue reflect.Value) bool { + if _, ok := fieldValue.Interface().(time.Time); ok { + return true + } + if _, ok := fieldValue.Interface().(JSONTime); ok { + return true + } + return false +} + +func getStructTag(field reflect.StructField) string { + tagValue := "" + //1.check mapperTagKey + tagValue = field.Tag.Get(mapperTagKey) + if checkTagValidity(tagValue) { + return tagValue + } + + //2.check jsonTagKey + tagValue = field.Tag.Get(jsonTagKey) + if checkTagValidity(tagValue) { + // support more tag property, as json tag omitempty 2018-07-13 + return strings.Split(tagValue, ",")[0] + } + + return "" +} + +func checkTagValidity(tagValue string) bool { + if tagValue != "" && tagValue != IgnoreTagValue { + return true + } + return false +} + +func checkIsRegister(objElem reflect.Value) bool { + typeName := objElem.Type().String() + _, isOk := registerMap.Load(typeName) + return isOk +} + +//convert slice interface{} to []interface{} +func convertToSlice(arr interface{}) []interface{} { + v := reflect.ValueOf(arr) + if v.Kind() != reflect.Slice { + panic("toslice arr not slice") + } + l := v.Len() + ret := make([]interface{}, l) + for i := 0; i < l; i++ { + ret[i] = v.Index(i).Interface() + } + return ret +} diff --git a/internal/mapper/reflectx.go b/internal/mapper/reflectx.go new file mode 100644 index 0000000..22ab591 --- /dev/null +++ b/internal/mapper/reflectx.go @@ -0,0 +1,33 @@ +package mapper + +import ( + "fmt" + "reflect" +) + +// Deref is Indirect for reflect.Types +func Deref(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +// BaseType get baseType from reflect.Type and check is same with expected +func BaseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { + t = Deref(t) + if t.Kind() != expected { + return nil, fmt.Errorf("expected %s but got %s", expected, t.Kind()) + } + return t, nil +} + +// GetSliceType get slice's elem type from slice +func GetSliceType(slice reflect.Value) reflect.Type { + isPtr := slice.Kind() == reflect.Ptr + elemType := slice.Type().Elem() + if isPtr { + elemType = elemType.Elem() + } + return elemType +} diff --git a/sqlt.go b/sqlt.go index f39eb28..ad80ad5 100644 --- a/sqlt.go +++ b/sqlt.go @@ -15,7 +15,6 @@ type Rows interface { Scan(...interface{}) error MapScan(map[string]interface{}) error StructScan(interface{}) error - SliceScan() ([]interface{}, error) ColumnTypes() ([]*sql.ColumnType, error) Columns() ([]string, error) diff --git a/tpl.go b/tpl.go index 30ed792..8418c65 100644 --- a/tpl.go +++ b/tpl.go @@ -37,6 +37,10 @@ func (t *SqlTemplate) MakeSql(id string, param interface{}) (string, error) { return sb.String(), err } +func (t *SqlTemplate) SetDebug(b bool) { + t.Debug = b +} + type Maker interface { MakeSql(string, interface{}) (string, error) } diff --git a/util.go b/util.go index c3b07ea..14eb2e9 100644 --- a/util.go +++ b/util.go @@ -5,6 +5,7 @@ import ( "database/sql" "github.com/jmoiron/sqlx" + "github.com/twiglab/sqlt/internal/mapper" ) var DefaultTxOptions *sql.TxOptions = NewTxOptions(sql.LevelDefault, false) @@ -43,7 +44,7 @@ func query(ctx context.Context, ext sqltExecer, id string, data interface{}, h R return err } defer rows.Close() - return h.Extract(rows) + return h.Extract(&Rs{Rows: rows}) } func exec(ctx context.Context, ext sqltExecer, id string, data interface{}) (r sql.Result, e error) { @@ -120,3 +121,15 @@ func Commit(t TxEnd) error { func Rollback(t TxEnd) error { return t.TRollback() } + +type Rs struct { + *sqlx.Rows +} + +func (r *Rs) MapScan(m map[string]interface{}) error { + return mapper.MapScan(r, m) +} + +func (r *Rs) StructScan(dist interface{}) error { + return mapper.StructScan(r, dist) +}