refactor(engine): 将辅助函数提取到独立文件并更新引用
- 将类型转换、文档操作、比较等辅助函数移动到 helpers.go 文件 - 更新 aggregate_helpers.go 中的函数调用使用新的公共辅助函数 - 更新 operators.go 中的比较函数使用新的公共辅助函数 - 更新 type_conversion.go 中的类型转换函数使用新的公共辅助函数 - 添加导出版本的辅助函数供其他包使用 - 保持向后兼容性,确保现有功能正常工作
This commit is contained in:
parent
d0b5e956c4
commit
948877c15b
|
|
@ -1,5 +1,8 @@
|
|||
package engine
|
||||
|
||||
// aggregate_helpers.go - 聚合辅助函数
|
||||
// 使用 helpers.go 中的公共辅助函数
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
|
@ -27,7 +30,7 @@ func (e *AggregationEngine) concat(operand interface{}, data map[string]interfac
|
|||
if str, ok := item.(string); ok {
|
||||
result += str
|
||||
} else {
|
||||
result += toString(item)
|
||||
result += FormatValueToString(item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
|
@ -40,8 +43,8 @@ func (e *AggregationEngine) substr(operand interface{}, data map[string]interfac
|
|||
return ""
|
||||
}
|
||||
|
||||
str := e.getFieldValueStr(types.Document{Data: data}, arr[0])
|
||||
start := int(toFloat64(arr[1]))
|
||||
str := GetFieldValueStr(types.Document{Data: data}, arr[0])
|
||||
start := int(ToFloat64(arr[1]))
|
||||
|
||||
if start < 0 {
|
||||
start = 0
|
||||
|
|
@ -52,7 +55,7 @@ func (e *AggregationEngine) substr(operand interface{}, data map[string]interfac
|
|||
|
||||
end := len(str)
|
||||
if len(arr) > 2 {
|
||||
length := int(toFloat64(arr[2]))
|
||||
length := int(ToFloat64(arr[2]))
|
||||
if length > 0 {
|
||||
end = start + length
|
||||
if end > len(str) {
|
||||
|
|
@ -73,7 +76,7 @@ func (e *AggregationEngine) add(operand interface{}, data map[string]interface{}
|
|||
|
||||
sum := 0.0
|
||||
for _, item := range arr {
|
||||
sum += toFloat64(e.evaluateExpression(data, item))
|
||||
sum += ToFloat64(e.evaluateExpression(data, item))
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
|
@ -87,7 +90,7 @@ func (e *AggregationEngine) multiply(operand interface{}, data map[string]interf
|
|||
|
||||
product := 1.0
|
||||
for _, item := range arr {
|
||||
product *= toFloat64(e.evaluateExpression(data, item))
|
||||
product *= ToFloat64(e.evaluateExpression(data, item))
|
||||
}
|
||||
return product
|
||||
}
|
||||
|
|
@ -99,8 +102,8 @@ func (e *AggregationEngine) divide(operand interface{}, data map[string]interfac
|
|||
return 0
|
||||
}
|
||||
|
||||
dividend := toFloat64(e.evaluateExpression(data, arr[0]))
|
||||
divisor := toFloat64(e.evaluateExpression(data, arr[1]))
|
||||
dividend := ToFloat64(e.evaluateExpression(data, arr[0]))
|
||||
divisor := ToFloat64(e.evaluateExpression(data, arr[1]))
|
||||
|
||||
if divisor == 0 {
|
||||
return 0
|
||||
|
|
@ -132,14 +135,14 @@ func (e *AggregationEngine) cond(operand interface{}, data map[string]interface{
|
|||
elseCond, ok3 := op["else"]
|
||||
|
||||
if ok1 && ok2 && ok3 {
|
||||
if isTrue(e.evaluateExpression(data, ifCond)) {
|
||||
if IsTrueValue(e.evaluateExpression(data, ifCond)) {
|
||||
return thenCond
|
||||
}
|
||||
return elseCond
|
||||
}
|
||||
case []interface{}:
|
||||
if len(op) >= 3 {
|
||||
if isTrue(e.evaluateExpression(data, op[0])) {
|
||||
if IsTrueValue(e.evaluateExpression(data, op[0])) {
|
||||
return op[1]
|
||||
}
|
||||
return op[2]
|
||||
|
|
@ -167,7 +170,7 @@ func (e *AggregationEngine) switchExpr(operand interface{}, data map[string]inte
|
|||
caseRaw, _ := branch["case"]
|
||||
thenRaw, _ := branch["then"]
|
||||
|
||||
if isTrueValue(e.evaluateExpression(data, caseRaw)) {
|
||||
if IsTrueValue(e.evaluateExpression(data, caseRaw)) {
|
||||
return e.evaluateExpression(data, thenRaw)
|
||||
}
|
||||
}
|
||||
|
|
@ -175,13 +178,9 @@ func (e *AggregationEngine) switchExpr(operand interface{}, data map[string]inte
|
|||
return defaultVal
|
||||
}
|
||||
|
||||
// getFieldValueStr 获取字段值的字符串形式
|
||||
// getFieldValueStr 获取字段值的字符串形式(已移到 helpers.go,此处为向后兼容)
|
||||
func (e *AggregationEngine) getFieldValueStr(doc types.Document, field interface{}) string {
|
||||
val := e.getFieldValue(doc, field)
|
||||
if str, ok := val.(string); ok {
|
||||
return str
|
||||
}
|
||||
return toString(val)
|
||||
return GetFieldValueStr(doc, field)
|
||||
}
|
||||
|
||||
// executeAddFields 执行 $addFields / $set 阶段
|
||||
|
|
@ -193,7 +192,7 @@ func (e *AggregationEngine) executeAddFields(spec interface{}, docs []types.Docu
|
|||
|
||||
var results []types.Document
|
||||
for _, doc := range docs {
|
||||
newData := deepCopyMap(doc.Data)
|
||||
newData := DeepCopyMap(doc.Data)
|
||||
for field, expr := range fields {
|
||||
newData[field] = e.evaluateExpression(newData, expr)
|
||||
}
|
||||
|
|
@ -225,9 +224,9 @@ func (e *AggregationEngine) executeUnset(spec interface{}, docs []types.Document
|
|||
|
||||
var results []types.Document
|
||||
for _, doc := range docs {
|
||||
newData := deepCopyMap(doc.Data)
|
||||
newData := DeepCopyMap(doc.Data)
|
||||
for _, field := range fields {
|
||||
removeNestedValue(newData, field)
|
||||
RemoveNestedValue(newData, field)
|
||||
}
|
||||
results = append(results, types.Document{
|
||||
ID: doc.ID,
|
||||
|
|
@ -280,7 +279,7 @@ func (e *AggregationEngine) executeSample(spec interface{}, docs []types.Documen
|
|||
switch s := spec.(type) {
|
||||
case map[string]interface{}:
|
||||
if sizeVal, ok := s["size"]; ok {
|
||||
size = int(toFloat64(sizeVal))
|
||||
size = int(ToFloat64(sizeVal))
|
||||
}
|
||||
case float64:
|
||||
size = int(s)
|
||||
|
|
@ -317,7 +316,7 @@ func (e *AggregationEngine) executeBucket(spec interface{}, docs []types.Documen
|
|||
// 转换边界为 float64 数组
|
||||
boundaries := make([]float64, 0, len(boundariesRaw))
|
||||
for _, b := range boundariesRaw {
|
||||
boundaries = append(boundaries, toFloat64(b))
|
||||
boundaries = append(boundaries, ToFloat64(b))
|
||||
}
|
||||
|
||||
// 创建桶
|
||||
|
|
@ -332,7 +331,7 @@ func (e *AggregationEngine) executeBucket(spec interface{}, docs []types.Documen
|
|||
|
||||
// 分组
|
||||
for _, doc := range docs {
|
||||
value := toFloat64(getNestedValue(doc.Data, groupBy))
|
||||
value := ToFloat64(GetNestedValue(doc.Data, groupBy))
|
||||
|
||||
bucketName := ""
|
||||
for i := 0; i < len(boundaries)-1; i++ {
|
||||
|
|
@ -384,7 +383,7 @@ func (e *AggregationEngine) ExecutePipeline(docs []types.Document, pipeline []ty
|
|||
|
||||
// abs 绝对值
|
||||
func (e *AggregationEngine) abs(operand interface{}, data map[string]interface{}) float64 {
|
||||
val := toFloat64(e.evaluateExpression(data, operand))
|
||||
val := ToFloat64(e.evaluateExpression(data, operand))
|
||||
if val < 0 {
|
||||
return -val
|
||||
}
|
||||
|
|
@ -393,13 +392,13 @@ func (e *AggregationEngine) abs(operand interface{}, data map[string]interface{}
|
|||
|
||||
// ceil 向上取整
|
||||
func (e *AggregationEngine) ceil(operand interface{}, data map[string]interface{}) float64 {
|
||||
val := toFloat64(e.evaluateExpression(data, operand))
|
||||
val := ToFloat64(e.evaluateExpression(data, operand))
|
||||
return math.Ceil(val)
|
||||
}
|
||||
|
||||
// floor 向下取整
|
||||
func (e *AggregationEngine) floor(operand interface{}, data map[string]interface{}) float64 {
|
||||
val := toFloat64(e.evaluateExpression(data, operand))
|
||||
val := ToFloat64(e.evaluateExpression(data, operand))
|
||||
return math.Floor(val)
|
||||
}
|
||||
|
||||
|
|
@ -410,24 +409,23 @@ func (e *AggregationEngine) round(operand interface{}, data map[string]interface
|
|||
|
||||
switch op := operand.(type) {
|
||||
case []interface{}:
|
||||
value = toFloat64(e.evaluateExpression(data, op[0]))
|
||||
value = ToFloat64(e.evaluateExpression(data, op[0]))
|
||||
if len(op) > 1 {
|
||||
precision = int(toFloat64(op[1]))
|
||||
precision = int(ToFloat64(op[1]))
|
||||
} else {
|
||||
precision = 0
|
||||
}
|
||||
default:
|
||||
value = toFloat64(e.evaluateExpression(data, op))
|
||||
value = ToFloat64(e.evaluateExpression(data, op))
|
||||
precision = 0
|
||||
}
|
||||
|
||||
multiplier := math.Pow(10, float64(precision))
|
||||
return math.Round(value*multiplier) / multiplier
|
||||
return RoundToPrecision(value, precision)
|
||||
}
|
||||
|
||||
// sqrt 平方根
|
||||
func (e *AggregationEngine) sqrt(operand interface{}, data map[string]interface{}) float64 {
|
||||
val := toFloat64(e.evaluateExpression(data, operand))
|
||||
val := ToFloat64(e.evaluateExpression(data, operand))
|
||||
return math.Sqrt(val)
|
||||
}
|
||||
|
||||
|
|
@ -438,9 +436,9 @@ func (e *AggregationEngine) subtract(operand interface{}, data map[string]interf
|
|||
return 0
|
||||
}
|
||||
|
||||
result := toFloat64(e.evaluateExpression(data, arr[0]))
|
||||
result := ToFloat64(e.evaluateExpression(data, arr[0]))
|
||||
for i := 1; i < len(arr); i++ {
|
||||
result -= toFloat64(e.evaluateExpression(data, arr[i]))
|
||||
result -= ToFloat64(e.evaluateExpression(data, arr[i]))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
|
@ -452,8 +450,8 @@ func (e *AggregationEngine) pow(operand interface{}, data map[string]interface{}
|
|||
return 0
|
||||
}
|
||||
|
||||
base := toFloat64(e.evaluateExpression(data, arr[0]))
|
||||
exp := toFloat64(e.evaluateExpression(data, arr[1]))
|
||||
base := ToFloat64(e.evaluateExpression(data, arr[0]))
|
||||
exp := ToFloat64(e.evaluateExpression(data, arr[1]))
|
||||
return math.Pow(base, exp)
|
||||
}
|
||||
|
||||
|
|
@ -467,15 +465,15 @@ func (e *AggregationEngine) trim(operand interface{}, data map[string]interface{
|
|||
switch op := operand.(type) {
|
||||
case map[string]interface{}:
|
||||
if in, ok := op["input"]; ok {
|
||||
input = e.getFieldValueStr(types.Document{Data: data}, in)
|
||||
input = GetFieldValueStr(types.Document{Data: data}, in)
|
||||
}
|
||||
if c, ok := op["characters"]; ok {
|
||||
chars = c.(string)
|
||||
}
|
||||
case string:
|
||||
input = e.getFieldValueStr(types.Document{Data: data}, op)
|
||||
input = GetFieldValueStr(types.Document{Data: data}, op)
|
||||
default:
|
||||
input = toString(operand)
|
||||
input = FormatValueToString(operand)
|
||||
}
|
||||
|
||||
return strings.Trim(input, chars)
|
||||
|
|
@ -483,13 +481,13 @@ func (e *AggregationEngine) trim(operand interface{}, data map[string]interface{
|
|||
|
||||
// ltrim 去除左侧空格
|
||||
func (e *AggregationEngine) ltrim(operand interface{}, data map[string]interface{}) string {
|
||||
input := e.getFieldValueStr(types.Document{Data: data}, operand)
|
||||
input := GetFieldValueStr(types.Document{Data: data}, operand)
|
||||
return strings.TrimLeft(input, " ")
|
||||
}
|
||||
|
||||
// rtrim 去除右侧空格
|
||||
func (e *AggregationEngine) rtrim(operand interface{}, data map[string]interface{}) string {
|
||||
input := e.getFieldValueStr(types.Document{Data: data}, operand)
|
||||
input := GetFieldValueStr(types.Document{Data: data}, operand)
|
||||
return strings.TrimRight(input, " ")
|
||||
}
|
||||
|
||||
|
|
@ -500,7 +498,7 @@ func (e *AggregationEngine) split(operand interface{}, data map[string]interface
|
|||
return nil
|
||||
}
|
||||
|
||||
input := e.getFieldValueStr(types.Document{Data: data}, arr[0])
|
||||
input := GetFieldValueStr(types.Document{Data: data}, arr[0])
|
||||
delimiter := arr[1].(string)
|
||||
|
||||
parts := strings.Split(input, delimiter)
|
||||
|
|
@ -518,11 +516,11 @@ func (e *AggregationEngine) replaceAll(operand interface{}, data map[string]inte
|
|||
return ""
|
||||
}
|
||||
|
||||
input := e.getFieldValueStr(types.Document{Data: data}, spec["input"])
|
||||
input := GetFieldValueStr(types.Document{Data: data}, spec["input"])
|
||||
find := spec["find"].(string)
|
||||
replacement := ""
|
||||
if rep, ok := spec["replacement"]; ok {
|
||||
replacement = toString(rep)
|
||||
replacement = FormatValueToString(rep)
|
||||
}
|
||||
|
||||
return strings.ReplaceAll(input, find, replacement)
|
||||
|
|
@ -535,8 +533,8 @@ func (e *AggregationEngine) strcasecmp(operand interface{}, data map[string]inte
|
|||
return 0
|
||||
}
|
||||
|
||||
str1 := strings.ToLower(e.getFieldValueStr(types.Document{Data: data}, arr[0]))
|
||||
str2 := strings.ToLower(e.getFieldValueStr(types.Document{Data: data}, arr[1]))
|
||||
str1 := strings.ToLower(GetFieldValueStr(types.Document{Data: data}, arr[0]))
|
||||
str2 := strings.ToLower(GetFieldValueStr(types.Document{Data: data}, arr[1]))
|
||||
|
||||
if str1 < str2 {
|
||||
return -1
|
||||
|
|
@ -573,7 +571,7 @@ func (e *AggregationEngine) filter(operand interface{}, data map[string]interfac
|
|||
}
|
||||
tempData["$$"+as] = item
|
||||
|
||||
if isTrue(e.evaluateExpression(tempData, condRaw)) {
|
||||
if IsTrueValue(e.evaluateExpression(tempData, condRaw)) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
|
|
@ -638,9 +636,9 @@ func (e *AggregationEngine) slice(operand interface{}, data map[string]interface
|
|||
case []interface{}:
|
||||
if len(op) >= 2 {
|
||||
arr = e.toArray(op[0])
|
||||
skip = int(toFloat64(op[1]))
|
||||
skip = int(ToFloat64(op[1]))
|
||||
if len(op) > 2 {
|
||||
limit = int(toFloat64(op[2]))
|
||||
limit = int(ToFloat64(op[2]))
|
||||
} else {
|
||||
limit = len(arr) - skip
|
||||
}
|
||||
|
|
@ -702,7 +700,7 @@ func (e *AggregationEngine) objectToArray(operand interface{}, data map[string]i
|
|||
|
||||
// ========== 辅助函数 ==========
|
||||
|
||||
// toArray 将值转换为数组
|
||||
// toArray 将值转换为数组(保持向后兼容)
|
||||
func (e *AggregationEngine) toArray(value interface{}) []interface{} {
|
||||
switch v := value.(type) {
|
||||
case []interface{}:
|
||||
|
|
@ -724,7 +722,7 @@ func (e *AggregationEngine) boolAnd(operand interface{}, data map[string]interfa
|
|||
}
|
||||
|
||||
for _, item := range arr {
|
||||
if !isTrue(e.evaluateExpression(data, item)) {
|
||||
if !IsTrueValue(e.evaluateExpression(data, item)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
@ -739,7 +737,7 @@ func (e *AggregationEngine) boolOr(operand interface{}, data map[string]interfac
|
|||
}
|
||||
|
||||
for _, item := range arr {
|
||||
if isTrue(e.evaluateExpression(data, item)) {
|
||||
if IsTrueValue(e.evaluateExpression(data, item)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
|
@ -748,5 +746,5 @@ func (e *AggregationEngine) boolOr(operand interface{}, data map[string]interfac
|
|||
|
||||
// boolNot 布尔非
|
||||
func (e *AggregationEngine) boolNot(operand interface{}, data map[string]interface{}) bool {
|
||||
return !isTrue(e.evaluateExpression(data, operand))
|
||||
return !IsTrueValue(e.evaluateExpression(data, operand))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,477 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.kingecg.top/kingecg/gomog/pkg/types"
|
||||
)
|
||||
|
||||
// ========== 类型转换辅助函数 ==========
|
||||
|
||||
// ToFloat64 将任意值转换为 float64(导出版本)
|
||||
func ToFloat64(v interface{}) float64 {
|
||||
return toFloat64(v)
|
||||
}
|
||||
|
||||
// toFloat64 将值转换为 float64
|
||||
func toFloat64(v interface{}) float64 {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return float64(val)
|
||||
case int8:
|
||||
return float64(val)
|
||||
case int16:
|
||||
return float64(val)
|
||||
case int32:
|
||||
return float64(val)
|
||||
case int64:
|
||||
return float64(val)
|
||||
case uint:
|
||||
return float64(val)
|
||||
case uint8:
|
||||
return float64(val)
|
||||
case uint16:
|
||||
return float64(val)
|
||||
case uint32:
|
||||
return float64(val)
|
||||
case uint64:
|
||||
return float64(val)
|
||||
case float32:
|
||||
return float64(val)
|
||||
case float64:
|
||||
return val
|
||||
case string:
|
||||
if num, err := strconv.ParseFloat(val, 64); err == nil {
|
||||
return num
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// ToInt64 将任意值转换为 int64(导出版本)
|
||||
func ToInt64(v interface{}) int64 {
|
||||
return toInt64(v)
|
||||
}
|
||||
|
||||
// toInt64 将值转换为 int64
|
||||
func toInt64(v interface{}) int64 {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return int64(val)
|
||||
case int8:
|
||||
return int64(val)
|
||||
case int16:
|
||||
return int64(val)
|
||||
case int32:
|
||||
return int64(val)
|
||||
case int64:
|
||||
return val
|
||||
case uint:
|
||||
return int64(val)
|
||||
case uint8:
|
||||
return int64(val)
|
||||
case uint16:
|
||||
return int64(val)
|
||||
case uint32:
|
||||
return int64(val)
|
||||
case uint64:
|
||||
return int64(val)
|
||||
case float32:
|
||||
return int64(val)
|
||||
case float64:
|
||||
return int64(val)
|
||||
case string:
|
||||
if num, err := strconv.ParseInt(val, 10, 64); err == nil {
|
||||
return num
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// FormatValueToString 将任意值格式化为字符串(导出版本)
|
||||
func FormatValueToString(value interface{}) string {
|
||||
return formatValueToString(value)
|
||||
}
|
||||
|
||||
// formatValueToString 将任意值格式化为字符串
|
||||
func formatValueToString(value interface{}) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return v
|
||||
case bool:
|
||||
return strconv.FormatBool(v)
|
||||
case int, int8, int16, int32, int64:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(v), 'g', -1, 32)
|
||||
case float64:
|
||||
return strconv.FormatFloat(v, 'g', -1, 64)
|
||||
case time.Time:
|
||||
return v.Format(time.RFC3339)
|
||||
case []interface{}:
|
||||
result := "["
|
||||
for i, item := range v {
|
||||
if i > 0 {
|
||||
result += ","
|
||||
}
|
||||
result += formatValueToString(item)
|
||||
}
|
||||
result += "]"
|
||||
return result
|
||||
case map[string]interface{}:
|
||||
result := "{"
|
||||
first := true
|
||||
for k, val := range v {
|
||||
if !first {
|
||||
result += ","
|
||||
}
|
||||
result += fmt.Sprintf("%s:%v", k, val)
|
||||
first = false
|
||||
}
|
||||
result += "}"
|
||||
return result
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// IsTrueValue 检查值是否为真值(导出版本)
|
||||
// 注意:内部使用的 isTrueValue/isTrue 已在 query.go 和 aggregate.go 中定义
|
||||
func IsTrueValue(v interface{}) bool {
|
||||
// 使用统一的转换逻辑
|
||||
if v == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch val := v.(type) {
|
||||
case bool:
|
||||
return val
|
||||
case int, int8, int16, int32, int64:
|
||||
return ToInt64(v) != 0
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
return ToInt64(v) != 0
|
||||
case float32, float64:
|
||||
return ToFloat64(v) != 0
|
||||
case string:
|
||||
return val != "" && val != "0" && strings.ToLower(val) != "false"
|
||||
case []interface{}:
|
||||
return len(val) > 0
|
||||
case map[string]interface{}:
|
||||
return len(val) > 0
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 文档操作辅助函数 ==========
|
||||
|
||||
// GetNestedValue 从嵌套 map 中获取值
|
||||
func GetNestedValue(data map[string]interface{}, field string) interface{} {
|
||||
if field == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.Split(field, ".")
|
||||
current := data
|
||||
|
||||
for i, part := range parts {
|
||||
if i == len(parts)-1 {
|
||||
return current[part]
|
||||
}
|
||||
|
||||
if next, ok := current[part].(map[string]interface{}); ok {
|
||||
current = next
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetNestedValue 设置嵌套 map 中的值
|
||||
func SetNestedValue(data map[string]interface{}, field string, value interface{}) {
|
||||
if field == "" {
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.Split(field, ".")
|
||||
current := data
|
||||
|
||||
for i, part := range parts {
|
||||
if i == len(parts)-1 {
|
||||
current[part] = value
|
||||
return
|
||||
}
|
||||
|
||||
if next, ok := current[part].(map[string]interface{}); ok {
|
||||
current = next
|
||||
} else {
|
||||
newMap := make(map[string]interface{})
|
||||
current[part] = newMap
|
||||
current = newMap
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveNestedValue 移除嵌套 map 中的值
|
||||
func RemoveNestedValue(data map[string]interface{}, field string) {
|
||||
if field == "" {
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.Split(field, ".")
|
||||
current := data
|
||||
|
||||
for i, part := range parts {
|
||||
if i == len(parts)-1 {
|
||||
delete(current, part)
|
||||
return
|
||||
}
|
||||
|
||||
if next, ok := current[part].(map[string]interface{}); ok {
|
||||
current = next
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DeepCopyMap 深度复制 map
|
||||
func DeepCopyMap(src map[string]interface{}) map[string]interface{} {
|
||||
dst := make(map[string]interface{})
|
||||
for k, v := range src {
|
||||
dst[k] = deepCopyValue(v)
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// deepCopyValue 深度复制值
|
||||
func deepCopyValue(v interface{}) interface{} {
|
||||
switch val := v.(type) {
|
||||
case map[string]interface{}:
|
||||
return DeepCopyMap(val)
|
||||
case []interface{}:
|
||||
arr := make([]interface{}, len(val))
|
||||
for i, item := range val {
|
||||
arr[i] = deepCopyValue(item)
|
||||
}
|
||||
return arr
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 比较辅助函数 ==========
|
||||
|
||||
// CompareEq 相等比较(导出版本)
|
||||
func CompareEq(a, b interface{}) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if isComplexType(a) || isComplexType(b) {
|
||||
return reflect.DeepEqual(a, b)
|
||||
}
|
||||
|
||||
return normalizeValue(a) == normalizeValue(b)
|
||||
}
|
||||
|
||||
// CompareNumbers 比较两个数值,返回 -1/0/1(导出版本)
|
||||
func CompareNumbers(a, b interface{}) int {
|
||||
return compareNumbers(a, b)
|
||||
}
|
||||
|
||||
// compareNumbers 比较两个数值
|
||||
func compareNumbers(a, b interface{}) int {
|
||||
numA := toFloat64(a)
|
||||
numB := toFloat64(b)
|
||||
|
||||
if numA < numB {
|
||||
return -1
|
||||
} else if numA > numB {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// IsComplexType 检查是否是复杂类型(导出版本)
|
||||
func IsComplexType(v interface{}) bool {
|
||||
return isComplexType(v)
|
||||
}
|
||||
|
||||
// isComplexType 检查是否是复杂类型
|
||||
func isComplexType(v interface{}) bool {
|
||||
switch v.(type) {
|
||||
case []interface{}:
|
||||
return true
|
||||
case map[string]interface{}:
|
||||
return true
|
||||
case map[interface{}]interface{}:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// NormalizeValue 标准化值用于比较(导出版本)
|
||||
func NormalizeValue(v interface{}) interface{} {
|
||||
return normalizeValue(v)
|
||||
}
|
||||
|
||||
// normalizeValue 标准化值
|
||||
func normalizeValue(v interface{}) interface{} {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch val := v.(type) {
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
|
||||
return toFloat64(v)
|
||||
case string:
|
||||
if num, err := strconv.ParseFloat(val, 64); err == nil {
|
||||
return num
|
||||
}
|
||||
return strings.ToLower(val)
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// ========== 数组辅助函数 ==========
|
||||
|
||||
// ContainsElement 检查数组是否包含指定元素
|
||||
func ContainsElement(arr []interface{}, element interface{}) bool {
|
||||
for _, item := range arr {
|
||||
if compareEq(item, element) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ContainsAllElements 检查数组是否包含所有指定元素
|
||||
func ContainsAllElements(arr []interface{}, elements []interface{}) bool {
|
||||
for _, elem := range elements {
|
||||
if !ContainsElement(arr, elem) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ArrayIntersection 计算数组交集
|
||||
func ArrayIntersection(a, b []interface{}) []interface{} {
|
||||
result := make([]interface{}, 0)
|
||||
for _, item := range a {
|
||||
if ContainsElement(b, item) && !ContainsElement(result, item) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ArrayUnion 计算数组并集
|
||||
func ArrayUnion(a, b []interface{}) []interface{} {
|
||||
result := make([]interface{}, len(a))
|
||||
copy(result, a)
|
||||
for _, item := range b {
|
||||
if !ContainsElement(result, item) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ========== 正则表达式辅助函数 ==========
|
||||
|
||||
// MatchRegex 正则表达式匹配
|
||||
func MatchRegex(value interface{}, pattern interface{}) bool {
|
||||
str, ok := value.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
patternStr, ok := pattern.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
matched, _ := regexp.MatchString(patternStr, str)
|
||||
return matched
|
||||
}
|
||||
|
||||
// ========== 类型检查辅助函数 ==========
|
||||
|
||||
// CheckType 检查值的类型
|
||||
func CheckType(value interface{}, typeName string) bool {
|
||||
if value == nil {
|
||||
return typeName == "null"
|
||||
}
|
||||
|
||||
var actualType string
|
||||
switch reflect.TypeOf(value).Kind() {
|
||||
case reflect.String:
|
||||
actualType = "string"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
actualType = "int"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
actualType = "double"
|
||||
case reflect.Bool:
|
||||
actualType = "bool"
|
||||
case reflect.Slice, reflect.Array:
|
||||
actualType = "array"
|
||||
case reflect.Map:
|
||||
actualType = "object"
|
||||
}
|
||||
|
||||
return actualType == typeName
|
||||
}
|
||||
|
||||
// ========== 数学辅助函数 ==========
|
||||
|
||||
// RoundToPrecision 四舍五入到指定精度
|
||||
func RoundToPrecision(value float64, precision int) float64 {
|
||||
multiplier := math.Pow(10, float64(precision))
|
||||
return math.Round(value*multiplier) / multiplier
|
||||
}
|
||||
|
||||
// ========== 文档辅助函数 ==========
|
||||
|
||||
// GetFieldValue 从文档中获取字段值
|
||||
func GetFieldValue(doc types.Document, field interface{}) interface{} {
|
||||
switch f := field.(type) {
|
||||
case string:
|
||||
if strings.HasPrefix(f, "$") {
|
||||
return GetNestedValue(doc.Data, f[1:])
|
||||
}
|
||||
return GetNestedValue(doc.Data, f)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetFieldValueStr 从文档中获取字段值的字符串形式
|
||||
func GetFieldValueStr(doc types.Document, field interface{}) string {
|
||||
val := GetFieldValue(doc, field)
|
||||
if str, ok := val.(string); ok {
|
||||
return str
|
||||
}
|
||||
return toString(val)
|
||||
}
|
||||
|
|
@ -1,111 +1,31 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
// operators.go - 查询操作符实现
|
||||
// 使用 helpers.go 中的公共辅助函数
|
||||
|
||||
// compareEq 相等比较
|
||||
func compareEq(a, b interface{}) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 对于 slice、map 等复杂类型,使用 reflect.DeepEqual
|
||||
if isComplexType(a) || isComplexType(b) {
|
||||
return reflect.DeepEqual(a, b)
|
||||
}
|
||||
|
||||
// 类型转换后比较
|
||||
return normalizeValue(a) == normalizeValue(b)
|
||||
}
|
||||
|
||||
// isComplexType 检查是否是复杂类型(slice、map 等)
|
||||
func isComplexType(v interface{}) bool {
|
||||
switch v.(type) {
|
||||
case []interface{}:
|
||||
return true
|
||||
case map[string]interface{}:
|
||||
return true
|
||||
case map[interface{}]interface{}:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return CompareEq(a, b)
|
||||
}
|
||||
|
||||
// compareGt 大于比较
|
||||
func compareGt(a, b interface{}) bool {
|
||||
return compareNumbers(a, b) > 0
|
||||
return CompareNumbers(a, b) > 0
|
||||
}
|
||||
|
||||
// compareGte 大于等于比较
|
||||
func compareGte(a, b interface{}) bool {
|
||||
return compareNumbers(a, b) >= 0
|
||||
return CompareNumbers(a, b) >= 0
|
||||
}
|
||||
|
||||
// compareLt 小于比较
|
||||
func compareLt(a, b interface{}) bool {
|
||||
return compareNumbers(a, b) < 0
|
||||
return CompareNumbers(a, b) < 0
|
||||
}
|
||||
|
||||
// compareLte 小于等于比较
|
||||
func compareLte(a, b interface{}) bool {
|
||||
return compareNumbers(a, b) <= 0
|
||||
}
|
||||
|
||||
// compareNumbers 比较两个数值,返回 -1/0/1
|
||||
func compareNumbers(a, b interface{}) int {
|
||||
numA := toFloat64(a)
|
||||
numB := toFloat64(b)
|
||||
|
||||
if numA < numB {
|
||||
return -1
|
||||
} else if numA > numB {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// toFloat64 将值转换为 float64
|
||||
func toFloat64(v interface{}) float64 {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return float64(val)
|
||||
case int8:
|
||||
return float64(val)
|
||||
case int16:
|
||||
return float64(val)
|
||||
case int32:
|
||||
return float64(val)
|
||||
case int64:
|
||||
return float64(val)
|
||||
case uint:
|
||||
return float64(val)
|
||||
case uint8:
|
||||
return float64(val)
|
||||
case uint16:
|
||||
return float64(val)
|
||||
case uint32:
|
||||
return float64(val)
|
||||
case uint64:
|
||||
return float64(val)
|
||||
case float32:
|
||||
return float64(val)
|
||||
case float64:
|
||||
return val
|
||||
case string:
|
||||
// 尝试解析字符串为数字
|
||||
if num, err := strconv.ParseFloat(val, 64); err == nil {
|
||||
return num
|
||||
}
|
||||
}
|
||||
return 0
|
||||
return CompareNumbers(a, b) <= 0
|
||||
}
|
||||
|
||||
// compareIn 检查值是否在数组中
|
||||
|
|
@ -115,28 +35,12 @@ func compareIn(value interface{}, operand interface{}) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
for _, item := range arr {
|
||||
if compareEq(value, item) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return ContainsElement(arr, value)
|
||||
}
|
||||
|
||||
// compareRegex 正则表达式匹配
|
||||
func compareRegex(value interface{}, operand interface{}) bool {
|
||||
str, ok := value.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
pattern, ok := operand.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
matched, _ := regexp.MatchString(pattern, str)
|
||||
return matched
|
||||
return MatchRegex(value, operand)
|
||||
}
|
||||
|
||||
// compareType 类型检查
|
||||
|
|
@ -145,30 +49,12 @@ func compareType(value interface{}, operand interface{}) bool {
|
|||
return operand == "null"
|
||||
}
|
||||
|
||||
var typeName string
|
||||
switch reflect.TypeOf(value).Kind() {
|
||||
case reflect.String:
|
||||
typeName = "string"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
typeName = "int"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
typeName = "double"
|
||||
case reflect.Bool:
|
||||
typeName = "bool"
|
||||
case reflect.Slice, reflect.Array:
|
||||
typeName = "array"
|
||||
case reflect.Map:
|
||||
typeName = "object"
|
||||
}
|
||||
|
||||
// 支持字符串或数组形式的类型检查
|
||||
switch op := operand.(type) {
|
||||
case string:
|
||||
return typeName == op
|
||||
return CheckType(value, op)
|
||||
case []interface{}:
|
||||
for _, t := range op {
|
||||
if ts, ok := t.(string); ok && typeName == ts {
|
||||
if ts, ok := t.(string); ok && CheckType(value, ts) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
|
@ -189,20 +75,7 @@ func compareAll(value interface{}, operand interface{}) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
for _, req := range required {
|
||||
found := false
|
||||
for _, item := range arr {
|
||||
if compareEq(item, req) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
return ContainsAllElements(arr, required)
|
||||
}
|
||||
|
||||
// compareElemMatch 数组元素匹配
|
||||
|
|
@ -252,7 +125,7 @@ func compareSize(value interface{}, operand interface{}) bool {
|
|||
|
||||
// compareMod 模运算:value % divisor == remainder
|
||||
func compareMod(value interface{}, operand interface{}) bool {
|
||||
num := toFloat64(value)
|
||||
num := ToFloat64(value)
|
||||
|
||||
var divisor, remainder float64
|
||||
switch op := operand.(type) {
|
||||
|
|
@ -260,8 +133,8 @@ func compareMod(value interface{}, operand interface{}) bool {
|
|||
if len(op) != 2 {
|
||||
return false
|
||||
}
|
||||
divisor = toFloat64(op[0])
|
||||
remainder = toFloat64(op[1])
|
||||
divisor = ToFloat64(op[0])
|
||||
remainder = ToFloat64(op[1])
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
|
@ -282,84 +155,28 @@ func compareMod(value interface{}, operand interface{}) bool {
|
|||
|
||||
// compareBitsAllClear 位运算:所有指定位都为 0
|
||||
func compareBitsAllClear(value interface{}, operand interface{}) bool {
|
||||
num := toInt64(value)
|
||||
mask := toInt64(operand)
|
||||
num := ToInt64(value)
|
||||
mask := ToInt64(operand)
|
||||
return (num & mask) == 0
|
||||
}
|
||||
|
||||
// compareBitsAllSet 位运算:所有指定位都为 1
|
||||
func compareBitsAllSet(value interface{}, operand interface{}) bool {
|
||||
num := toInt64(value)
|
||||
mask := toInt64(operand)
|
||||
num := ToInt64(value)
|
||||
mask := ToInt64(operand)
|
||||
return (num & mask) == mask
|
||||
}
|
||||
|
||||
// compareBitsAnyClear 位运算:任意指定位为 0
|
||||
func compareBitsAnyClear(value interface{}, operand interface{}) bool {
|
||||
num := toInt64(value)
|
||||
mask := toInt64(operand)
|
||||
num := ToInt64(value)
|
||||
mask := ToInt64(operand)
|
||||
return (num & mask) != mask
|
||||
}
|
||||
|
||||
// compareBitsAnySet 位运算:任意指定位为 1
|
||||
func compareBitsAnySet(value interface{}, operand interface{}) bool {
|
||||
num := toInt64(value)
|
||||
mask := toInt64(operand)
|
||||
num := ToInt64(value)
|
||||
mask := ToInt64(operand)
|
||||
return (num & mask) != 0
|
||||
}
|
||||
|
||||
// toInt64 将值转换为 int64
|
||||
func toInt64(v interface{}) int64 {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return int64(val)
|
||||
case int8:
|
||||
return int64(val)
|
||||
case int16:
|
||||
return int64(val)
|
||||
case int32:
|
||||
return int64(val)
|
||||
case int64:
|
||||
return val
|
||||
case uint:
|
||||
return int64(val)
|
||||
case uint8:
|
||||
return int64(val)
|
||||
case uint16:
|
||||
return int64(val)
|
||||
case uint32:
|
||||
return int64(val)
|
||||
case uint64:
|
||||
return int64(val)
|
||||
case float32:
|
||||
return int64(val)
|
||||
case float64:
|
||||
return int64(val)
|
||||
case string:
|
||||
if num, err := strconv.ParseInt(val, 10, 64); err == nil {
|
||||
return num
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// normalizeValue 标准化值用于比较
|
||||
func normalizeValue(v interface{}) interface{} {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 处理数字类型
|
||||
switch val := v.(type) {
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
|
||||
return toFloat64(v)
|
||||
case string:
|
||||
// 尝试将字符串解析为数字
|
||||
if num, err := strconv.ParseFloat(val, 64); err == nil {
|
||||
return num
|
||||
}
|
||||
return strings.ToLower(val)
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,39 +1,36 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
// type_conversion.go - 类型转换操作符实现
|
||||
// 使用 helpers.go 中的公共辅助函数
|
||||
|
||||
// toString 转换为字符串
|
||||
func (e *AggregationEngine) toString(operand interface{}, data map[string]interface{}) string {
|
||||
val := e.evaluateExpression(data, operand)
|
||||
return formatValueToString(val)
|
||||
return FormatValueToString(val)
|
||||
}
|
||||
|
||||
// toInt 转换为整数 (int32)
|
||||
func (e *AggregationEngine) toInt(operand interface{}, data map[string]interface{}) int32 {
|
||||
val := e.evaluateExpression(data, operand)
|
||||
return int32(toInt64(val))
|
||||
return int32(ToInt64(val))
|
||||
}
|
||||
|
||||
// toLong 转换为长整数 (int64)
|
||||
func (e *AggregationEngine) toLong(operand interface{}, data map[string]interface{}) int64 {
|
||||
val := e.evaluateExpression(data, operand)
|
||||
return toInt64(val)
|
||||
return ToInt64(val)
|
||||
}
|
||||
|
||||
// toDouble 转换为浮点数 (double)
|
||||
func (e *AggregationEngine) toDouble(operand interface{}, data map[string]interface{}) float64 {
|
||||
val := e.evaluateExpression(data, operand)
|
||||
return toFloat64(val)
|
||||
return ToFloat64(val)
|
||||
}
|
||||
|
||||
// toBool 转换为布尔值
|
||||
func (e *AggregationEngine) toBool(operand interface{}, data map[string]interface{}) bool {
|
||||
val := e.evaluateExpression(data, operand)
|
||||
return isTrueValue(val)
|
||||
return IsTrueValue(val)
|
||||
}
|
||||
|
||||
// toDocument 转换为文档(对象)
|
||||
|
|
@ -53,53 +50,3 @@ func (e *AggregationEngine) toDocument(operand interface{}, data map[string]inte
|
|||
// 其他情况返回空对象(MongoDB 行为)
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
|
||||
// formatValueToString 将任意值格式化为字符串
|
||||
func formatValueToString(value interface{}) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return v
|
||||
case bool:
|
||||
return strconv.FormatBool(v)
|
||||
case int, int8, int16, int32, int64:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(v), 'g', -1, 32)
|
||||
case float64:
|
||||
return strconv.FormatFloat(v, 'g', -1, 64)
|
||||
case time.Time:
|
||||
return v.Format(time.RFC3339)
|
||||
case []interface{}:
|
||||
// 数组转为 JSON 风格字符串
|
||||
result := "["
|
||||
for i, item := range v {
|
||||
if i > 0 {
|
||||
result += ","
|
||||
}
|
||||
result += formatValueToString(item)
|
||||
}
|
||||
result += "]"
|
||||
return result
|
||||
case map[string]interface{}:
|
||||
// 对象转为 JSON 风格字符串(简化版)
|
||||
result := "{"
|
||||
first := true
|
||||
for k, val := range v {
|
||||
if !first {
|
||||
result += ","
|
||||
}
|
||||
result += fmt.Sprintf("%s:%v", k, val)
|
||||
first = false
|
||||
}
|
||||
result += "}"
|
||||
return result
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue