refactor(engine): 将辅助函数提取到独立文件并更新引用

- 将类型转换、文档操作、比较等辅助函数移动到 helpers.go 文件
- 更新 aggregate_helpers.go 中的函数调用使用新的公共辅助函数
- 更新 operators.go 中的比较函数使用新的公共辅助函数
- 更新 type_conversion.go 中的类型转换函数使用新的公共辅助函数
- 添加导出版本的辅助函数供其他包使用
- 保持向后兼容性,确保现有功能正常工作
This commit is contained in:
kingecg 2026-03-14 18:48:36 +08:00
parent d0b5e956c4
commit 948877c15b
4 changed files with 558 additions and 319 deletions

View File

@ -1,5 +1,8 @@
package engine
// aggregate_helpers.go - 聚合辅助函数
// 使用 helpers.go 中的公共辅助函数
import (
"fmt"
"math"
@ -27,7 +30,7 @@ func (e *AggregationEngine) concat(operand interface{}, data map[string]interfac
if str, ok := item.(string); ok {
result += str
} else {
result += toString(item)
result += FormatValueToString(item)
}
}
return result
@ -40,8 +43,8 @@ func (e *AggregationEngine) substr(operand interface{}, data map[string]interfac
return ""
}
str := e.getFieldValueStr(types.Document{Data: data}, arr[0])
start := int(toFloat64(arr[1]))
str := GetFieldValueStr(types.Document{Data: data}, arr[0])
start := int(ToFloat64(arr[1]))
if start < 0 {
start = 0
@ -52,7 +55,7 @@ func (e *AggregationEngine) substr(operand interface{}, data map[string]interfac
end := len(str)
if len(arr) > 2 {
length := int(toFloat64(arr[2]))
length := int(ToFloat64(arr[2]))
if length > 0 {
end = start + length
if end > len(str) {
@ -73,7 +76,7 @@ func (e *AggregationEngine) add(operand interface{}, data map[string]interface{}
sum := 0.0
for _, item := range arr {
sum += toFloat64(e.evaluateExpression(data, item))
sum += ToFloat64(e.evaluateExpression(data, item))
}
return sum
}
@ -87,7 +90,7 @@ func (e *AggregationEngine) multiply(operand interface{}, data map[string]interf
product := 1.0
for _, item := range arr {
product *= toFloat64(e.evaluateExpression(data, item))
product *= ToFloat64(e.evaluateExpression(data, item))
}
return product
}
@ -99,8 +102,8 @@ func (e *AggregationEngine) divide(operand interface{}, data map[string]interfac
return 0
}
dividend := toFloat64(e.evaluateExpression(data, arr[0]))
divisor := toFloat64(e.evaluateExpression(data, arr[1]))
dividend := ToFloat64(e.evaluateExpression(data, arr[0]))
divisor := ToFloat64(e.evaluateExpression(data, arr[1]))
if divisor == 0 {
return 0
@ -132,14 +135,14 @@ func (e *AggregationEngine) cond(operand interface{}, data map[string]interface{
elseCond, ok3 := op["else"]
if ok1 && ok2 && ok3 {
if isTrue(e.evaluateExpression(data, ifCond)) {
if IsTrueValue(e.evaluateExpression(data, ifCond)) {
return thenCond
}
return elseCond
}
case []interface{}:
if len(op) >= 3 {
if isTrue(e.evaluateExpression(data, op[0])) {
if IsTrueValue(e.evaluateExpression(data, op[0])) {
return op[1]
}
return op[2]
@ -167,7 +170,7 @@ func (e *AggregationEngine) switchExpr(operand interface{}, data map[string]inte
caseRaw, _ := branch["case"]
thenRaw, _ := branch["then"]
if isTrueValue(e.evaluateExpression(data, caseRaw)) {
if IsTrueValue(e.evaluateExpression(data, caseRaw)) {
return e.evaluateExpression(data, thenRaw)
}
}
@ -175,13 +178,9 @@ func (e *AggregationEngine) switchExpr(operand interface{}, data map[string]inte
return defaultVal
}
// getFieldValueStr 获取字段值的字符串形式
// getFieldValueStr 获取字段值的字符串形式(已移到 helpers.go此处为向后兼容
func (e *AggregationEngine) getFieldValueStr(doc types.Document, field interface{}) string {
val := e.getFieldValue(doc, field)
if str, ok := val.(string); ok {
return str
}
return toString(val)
return GetFieldValueStr(doc, field)
}
// executeAddFields 执行 $addFields / $set 阶段
@ -193,7 +192,7 @@ func (e *AggregationEngine) executeAddFields(spec interface{}, docs []types.Docu
var results []types.Document
for _, doc := range docs {
newData := deepCopyMap(doc.Data)
newData := DeepCopyMap(doc.Data)
for field, expr := range fields {
newData[field] = e.evaluateExpression(newData, expr)
}
@ -225,9 +224,9 @@ func (e *AggregationEngine) executeUnset(spec interface{}, docs []types.Document
var results []types.Document
for _, doc := range docs {
newData := deepCopyMap(doc.Data)
newData := DeepCopyMap(doc.Data)
for _, field := range fields {
removeNestedValue(newData, field)
RemoveNestedValue(newData, field)
}
results = append(results, types.Document{
ID: doc.ID,
@ -280,7 +279,7 @@ func (e *AggregationEngine) executeSample(spec interface{}, docs []types.Documen
switch s := spec.(type) {
case map[string]interface{}:
if sizeVal, ok := s["size"]; ok {
size = int(toFloat64(sizeVal))
size = int(ToFloat64(sizeVal))
}
case float64:
size = int(s)
@ -317,7 +316,7 @@ func (e *AggregationEngine) executeBucket(spec interface{}, docs []types.Documen
// 转换边界为 float64 数组
boundaries := make([]float64, 0, len(boundariesRaw))
for _, b := range boundariesRaw {
boundaries = append(boundaries, toFloat64(b))
boundaries = append(boundaries, ToFloat64(b))
}
// 创建桶
@ -332,7 +331,7 @@ func (e *AggregationEngine) executeBucket(spec interface{}, docs []types.Documen
// 分组
for _, doc := range docs {
value := toFloat64(getNestedValue(doc.Data, groupBy))
value := ToFloat64(GetNestedValue(doc.Data, groupBy))
bucketName := ""
for i := 0; i < len(boundaries)-1; i++ {
@ -384,7 +383,7 @@ func (e *AggregationEngine) ExecutePipeline(docs []types.Document, pipeline []ty
// abs 绝对值
func (e *AggregationEngine) abs(operand interface{}, data map[string]interface{}) float64 {
val := toFloat64(e.evaluateExpression(data, operand))
val := ToFloat64(e.evaluateExpression(data, operand))
if val < 0 {
return -val
}
@ -393,13 +392,13 @@ func (e *AggregationEngine) abs(operand interface{}, data map[string]interface{}
// ceil 向上取整
func (e *AggregationEngine) ceil(operand interface{}, data map[string]interface{}) float64 {
val := toFloat64(e.evaluateExpression(data, operand))
val := ToFloat64(e.evaluateExpression(data, operand))
return math.Ceil(val)
}
// floor 向下取整
func (e *AggregationEngine) floor(operand interface{}, data map[string]interface{}) float64 {
val := toFloat64(e.evaluateExpression(data, operand))
val := ToFloat64(e.evaluateExpression(data, operand))
return math.Floor(val)
}
@ -410,24 +409,23 @@ func (e *AggregationEngine) round(operand interface{}, data map[string]interface
switch op := operand.(type) {
case []interface{}:
value = toFloat64(e.evaluateExpression(data, op[0]))
value = ToFloat64(e.evaluateExpression(data, op[0]))
if len(op) > 1 {
precision = int(toFloat64(op[1]))
precision = int(ToFloat64(op[1]))
} else {
precision = 0
}
default:
value = toFloat64(e.evaluateExpression(data, op))
value = ToFloat64(e.evaluateExpression(data, op))
precision = 0
}
multiplier := math.Pow(10, float64(precision))
return math.Round(value*multiplier) / multiplier
return RoundToPrecision(value, precision)
}
// sqrt 平方根
func (e *AggregationEngine) sqrt(operand interface{}, data map[string]interface{}) float64 {
val := toFloat64(e.evaluateExpression(data, operand))
val := ToFloat64(e.evaluateExpression(data, operand))
return math.Sqrt(val)
}
@ -438,9 +436,9 @@ func (e *AggregationEngine) subtract(operand interface{}, data map[string]interf
return 0
}
result := toFloat64(e.evaluateExpression(data, arr[0]))
result := ToFloat64(e.evaluateExpression(data, arr[0]))
for i := 1; i < len(arr); i++ {
result -= toFloat64(e.evaluateExpression(data, arr[i]))
result -= ToFloat64(e.evaluateExpression(data, arr[i]))
}
return result
}
@ -452,8 +450,8 @@ func (e *AggregationEngine) pow(operand interface{}, data map[string]interface{}
return 0
}
base := toFloat64(e.evaluateExpression(data, arr[0]))
exp := toFloat64(e.evaluateExpression(data, arr[1]))
base := ToFloat64(e.evaluateExpression(data, arr[0]))
exp := ToFloat64(e.evaluateExpression(data, arr[1]))
return math.Pow(base, exp)
}
@ -467,15 +465,15 @@ func (e *AggregationEngine) trim(operand interface{}, data map[string]interface{
switch op := operand.(type) {
case map[string]interface{}:
if in, ok := op["input"]; ok {
input = e.getFieldValueStr(types.Document{Data: data}, in)
input = GetFieldValueStr(types.Document{Data: data}, in)
}
if c, ok := op["characters"]; ok {
chars = c.(string)
}
case string:
input = e.getFieldValueStr(types.Document{Data: data}, op)
input = GetFieldValueStr(types.Document{Data: data}, op)
default:
input = toString(operand)
input = FormatValueToString(operand)
}
return strings.Trim(input, chars)
@ -483,13 +481,13 @@ func (e *AggregationEngine) trim(operand interface{}, data map[string]interface{
// ltrim 去除左侧空格
func (e *AggregationEngine) ltrim(operand interface{}, data map[string]interface{}) string {
input := e.getFieldValueStr(types.Document{Data: data}, operand)
input := GetFieldValueStr(types.Document{Data: data}, operand)
return strings.TrimLeft(input, " ")
}
// rtrim 去除右侧空格
func (e *AggregationEngine) rtrim(operand interface{}, data map[string]interface{}) string {
input := e.getFieldValueStr(types.Document{Data: data}, operand)
input := GetFieldValueStr(types.Document{Data: data}, operand)
return strings.TrimRight(input, " ")
}
@ -500,7 +498,7 @@ func (e *AggregationEngine) split(operand interface{}, data map[string]interface
return nil
}
input := e.getFieldValueStr(types.Document{Data: data}, arr[0])
input := GetFieldValueStr(types.Document{Data: data}, arr[0])
delimiter := arr[1].(string)
parts := strings.Split(input, delimiter)
@ -518,11 +516,11 @@ func (e *AggregationEngine) replaceAll(operand interface{}, data map[string]inte
return ""
}
input := e.getFieldValueStr(types.Document{Data: data}, spec["input"])
input := GetFieldValueStr(types.Document{Data: data}, spec["input"])
find := spec["find"].(string)
replacement := ""
if rep, ok := spec["replacement"]; ok {
replacement = toString(rep)
replacement = FormatValueToString(rep)
}
return strings.ReplaceAll(input, find, replacement)
@ -535,8 +533,8 @@ func (e *AggregationEngine) strcasecmp(operand interface{}, data map[string]inte
return 0
}
str1 := strings.ToLower(e.getFieldValueStr(types.Document{Data: data}, arr[0]))
str2 := strings.ToLower(e.getFieldValueStr(types.Document{Data: data}, arr[1]))
str1 := strings.ToLower(GetFieldValueStr(types.Document{Data: data}, arr[0]))
str2 := strings.ToLower(GetFieldValueStr(types.Document{Data: data}, arr[1]))
if str1 < str2 {
return -1
@ -573,7 +571,7 @@ func (e *AggregationEngine) filter(operand interface{}, data map[string]interfac
}
tempData["$$"+as] = item
if isTrue(e.evaluateExpression(tempData, condRaw)) {
if IsTrueValue(e.evaluateExpression(tempData, condRaw)) {
result = append(result, item)
}
}
@ -638,9 +636,9 @@ func (e *AggregationEngine) slice(operand interface{}, data map[string]interface
case []interface{}:
if len(op) >= 2 {
arr = e.toArray(op[0])
skip = int(toFloat64(op[1]))
skip = int(ToFloat64(op[1]))
if len(op) > 2 {
limit = int(toFloat64(op[2]))
limit = int(ToFloat64(op[2]))
} else {
limit = len(arr) - skip
}
@ -702,7 +700,7 @@ func (e *AggregationEngine) objectToArray(operand interface{}, data map[string]i
// ========== 辅助函数 ==========
// toArray 将值转换为数组
// toArray 将值转换为数组(保持向后兼容)
func (e *AggregationEngine) toArray(value interface{}) []interface{} {
switch v := value.(type) {
case []interface{}:
@ -724,7 +722,7 @@ func (e *AggregationEngine) boolAnd(operand interface{}, data map[string]interfa
}
for _, item := range arr {
if !isTrue(e.evaluateExpression(data, item)) {
if !IsTrueValue(e.evaluateExpression(data, item)) {
return false
}
}
@ -739,7 +737,7 @@ func (e *AggregationEngine) boolOr(operand interface{}, data map[string]interfac
}
for _, item := range arr {
if isTrue(e.evaluateExpression(data, item)) {
if IsTrueValue(e.evaluateExpression(data, item)) {
return true
}
}
@ -748,5 +746,5 @@ func (e *AggregationEngine) boolOr(operand interface{}, data map[string]interfac
// boolNot 布尔非
func (e *AggregationEngine) boolNot(operand interface{}, data map[string]interface{}) bool {
return !isTrue(e.evaluateExpression(data, operand))
return !IsTrueValue(e.evaluateExpression(data, operand))
}

477
internal/engine/helpers.go Normal file
View File

@ -0,0 +1,477 @@
package engine
import (
"fmt"
"math"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"git.kingecg.top/kingecg/gomog/pkg/types"
)
// ========== 类型转换辅助函数 ==========
// ToFloat64 将任意值转换为 float64导出版本
func ToFloat64(v interface{}) float64 {
return toFloat64(v)
}
// toFloat64 将值转换为 float64
func toFloat64(v interface{}) float64 {
switch val := v.(type) {
case int:
return float64(val)
case int8:
return float64(val)
case int16:
return float64(val)
case int32:
return float64(val)
case int64:
return float64(val)
case uint:
return float64(val)
case uint8:
return float64(val)
case uint16:
return float64(val)
case uint32:
return float64(val)
case uint64:
return float64(val)
case float32:
return float64(val)
case float64:
return val
case string:
if num, err := strconv.ParseFloat(val, 64); err == nil {
return num
}
}
return 0
}
// ToInt64 将任意值转换为 int64导出版本
func ToInt64(v interface{}) int64 {
return toInt64(v)
}
// toInt64 将值转换为 int64
func toInt64(v interface{}) int64 {
switch val := v.(type) {
case int:
return int64(val)
case int8:
return int64(val)
case int16:
return int64(val)
case int32:
return int64(val)
case int64:
return val
case uint:
return int64(val)
case uint8:
return int64(val)
case uint16:
return int64(val)
case uint32:
return int64(val)
case uint64:
return int64(val)
case float32:
return int64(val)
case float64:
return int64(val)
case string:
if num, err := strconv.ParseInt(val, 10, 64); err == nil {
return num
}
}
return 0
}
// FormatValueToString 将任意值格式化为字符串(导出版本)
func FormatValueToString(value interface{}) string {
return formatValueToString(value)
}
// formatValueToString 将任意值格式化为字符串
func formatValueToString(value interface{}) string {
if value == nil {
return ""
}
switch v := value.(type) {
case string:
return v
case bool:
return strconv.FormatBool(v)
case int, int8, int16, int32, int64:
return fmt.Sprintf("%d", v)
case uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", v)
case float32:
return strconv.FormatFloat(float64(v), 'g', -1, 32)
case float64:
return strconv.FormatFloat(v, 'g', -1, 64)
case time.Time:
return v.Format(time.RFC3339)
case []interface{}:
result := "["
for i, item := range v {
if i > 0 {
result += ","
}
result += formatValueToString(item)
}
result += "]"
return result
case map[string]interface{}:
result := "{"
first := true
for k, val := range v {
if !first {
result += ","
}
result += fmt.Sprintf("%s:%v", k, val)
first = false
}
result += "}"
return result
default:
return fmt.Sprintf("%v", v)
}
}
// IsTrueValue 检查值是否为真值(导出版本)
// 注意:内部使用的 isTrueValue/isTrue 已在 query.go 和 aggregate.go 中定义
func IsTrueValue(v interface{}) bool {
// 使用统一的转换逻辑
if v == nil {
return false
}
switch val := v.(type) {
case bool:
return val
case int, int8, int16, int32, int64:
return ToInt64(v) != 0
case uint, uint8, uint16, uint32, uint64:
return ToInt64(v) != 0
case float32, float64:
return ToFloat64(v) != 0
case string:
return val != "" && val != "0" && strings.ToLower(val) != "false"
case []interface{}:
return len(val) > 0
case map[string]interface{}:
return len(val) > 0
default:
return true
}
}
// ========== 文档操作辅助函数 ==========
// GetNestedValue 从嵌套 map 中获取值
func GetNestedValue(data map[string]interface{}, field string) interface{} {
if field == "" {
return nil
}
parts := strings.Split(field, ".")
current := data
for i, part := range parts {
if i == len(parts)-1 {
return current[part]
}
if next, ok := current[part].(map[string]interface{}); ok {
current = next
} else {
return nil
}
}
return nil
}
// SetNestedValue 设置嵌套 map 中的值
func SetNestedValue(data map[string]interface{}, field string, value interface{}) {
if field == "" {
return
}
parts := strings.Split(field, ".")
current := data
for i, part := range parts {
if i == len(parts)-1 {
current[part] = value
return
}
if next, ok := current[part].(map[string]interface{}); ok {
current = next
} else {
newMap := make(map[string]interface{})
current[part] = newMap
current = newMap
}
}
}
// RemoveNestedValue 移除嵌套 map 中的值
func RemoveNestedValue(data map[string]interface{}, field string) {
if field == "" {
return
}
parts := strings.Split(field, ".")
current := data
for i, part := range parts {
if i == len(parts)-1 {
delete(current, part)
return
}
if next, ok := current[part].(map[string]interface{}); ok {
current = next
} else {
return
}
}
}
// DeepCopyMap 深度复制 map
func DeepCopyMap(src map[string]interface{}) map[string]interface{} {
dst := make(map[string]interface{})
for k, v := range src {
dst[k] = deepCopyValue(v)
}
return dst
}
// deepCopyValue 深度复制值
func deepCopyValue(v interface{}) interface{} {
switch val := v.(type) {
case map[string]interface{}:
return DeepCopyMap(val)
case []interface{}:
arr := make([]interface{}, len(val))
for i, item := range val {
arr[i] = deepCopyValue(item)
}
return arr
default:
return v
}
}
// ========== 比较辅助函数 ==========
// CompareEq 相等比较(导出版本)
func CompareEq(a, b interface{}) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
if isComplexType(a) || isComplexType(b) {
return reflect.DeepEqual(a, b)
}
return normalizeValue(a) == normalizeValue(b)
}
// CompareNumbers 比较两个数值,返回 -1/0/1导出版本
func CompareNumbers(a, b interface{}) int {
return compareNumbers(a, b)
}
// compareNumbers 比较两个数值
func compareNumbers(a, b interface{}) int {
numA := toFloat64(a)
numB := toFloat64(b)
if numA < numB {
return -1
} else if numA > numB {
return 1
}
return 0
}
// IsComplexType 检查是否是复杂类型(导出版本)
func IsComplexType(v interface{}) bool {
return isComplexType(v)
}
// isComplexType 检查是否是复杂类型
func isComplexType(v interface{}) bool {
switch v.(type) {
case []interface{}:
return true
case map[string]interface{}:
return true
case map[interface{}]interface{}:
return true
default:
return false
}
}
// NormalizeValue 标准化值用于比较(导出版本)
func NormalizeValue(v interface{}) interface{} {
return normalizeValue(v)
}
// normalizeValue 标准化值
func normalizeValue(v interface{}) interface{} {
if v == nil {
return nil
}
switch val := v.(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
return toFloat64(v)
case string:
if num, err := strconv.ParseFloat(val, 64); err == nil {
return num
}
return strings.ToLower(val)
}
return v
}
// ========== 数组辅助函数 ==========
// ContainsElement 检查数组是否包含指定元素
func ContainsElement(arr []interface{}, element interface{}) bool {
for _, item := range arr {
if compareEq(item, element) {
return true
}
}
return false
}
// ContainsAllElements 检查数组是否包含所有指定元素
func ContainsAllElements(arr []interface{}, elements []interface{}) bool {
for _, elem := range elements {
if !ContainsElement(arr, elem) {
return false
}
}
return true
}
// ArrayIntersection 计算数组交集
func ArrayIntersection(a, b []interface{}) []interface{} {
result := make([]interface{}, 0)
for _, item := range a {
if ContainsElement(b, item) && !ContainsElement(result, item) {
result = append(result, item)
}
}
return result
}
// ArrayUnion 计算数组并集
func ArrayUnion(a, b []interface{}) []interface{} {
result := make([]interface{}, len(a))
copy(result, a)
for _, item := range b {
if !ContainsElement(result, item) {
result = append(result, item)
}
}
return result
}
// ========== 正则表达式辅助函数 ==========
// MatchRegex 正则表达式匹配
func MatchRegex(value interface{}, pattern interface{}) bool {
str, ok := value.(string)
if !ok {
return false
}
patternStr, ok := pattern.(string)
if !ok {
return false
}
matched, _ := regexp.MatchString(patternStr, str)
return matched
}
// ========== 类型检查辅助函数 ==========
// CheckType 检查值的类型
func CheckType(value interface{}, typeName string) bool {
if value == nil {
return typeName == "null"
}
var actualType string
switch reflect.TypeOf(value).Kind() {
case reflect.String:
actualType = "string"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
actualType = "int"
case reflect.Float32, reflect.Float64:
actualType = "double"
case reflect.Bool:
actualType = "bool"
case reflect.Slice, reflect.Array:
actualType = "array"
case reflect.Map:
actualType = "object"
}
return actualType == typeName
}
// ========== 数学辅助函数 ==========
// RoundToPrecision 四舍五入到指定精度
func RoundToPrecision(value float64, precision int) float64 {
multiplier := math.Pow(10, float64(precision))
return math.Round(value*multiplier) / multiplier
}
// ========== 文档辅助函数 ==========
// GetFieldValue 从文档中获取字段值
func GetFieldValue(doc types.Document, field interface{}) interface{} {
switch f := field.(type) {
case string:
if strings.HasPrefix(f, "$") {
return GetNestedValue(doc.Data, f[1:])
}
return GetNestedValue(doc.Data, f)
default:
return nil
}
}
// GetFieldValueStr 从文档中获取字段值的字符串形式
func GetFieldValueStr(doc types.Document, field interface{}) string {
val := GetFieldValue(doc, field)
if str, ok := val.(string); ok {
return str
}
return toString(val)
}

View File

@ -1,111 +1,31 @@
package engine
import (
"reflect"
"regexp"
"strconv"
"strings"
)
// operators.go - 查询操作符实现
// 使用 helpers.go 中的公共辅助函数
// compareEq 相等比较
func compareEq(a, b interface{}) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
// 对于 slice、map 等复杂类型,使用 reflect.DeepEqual
if isComplexType(a) || isComplexType(b) {
return reflect.DeepEqual(a, b)
}
// 类型转换后比较
return normalizeValue(a) == normalizeValue(b)
}
// isComplexType 检查是否是复杂类型slice、map 等)
func isComplexType(v interface{}) bool {
switch v.(type) {
case []interface{}:
return true
case map[string]interface{}:
return true
case map[interface{}]interface{}:
return true
default:
return false
}
return CompareEq(a, b)
}
// compareGt 大于比较
func compareGt(a, b interface{}) bool {
return compareNumbers(a, b) > 0
return CompareNumbers(a, b) > 0
}
// compareGte 大于等于比较
func compareGte(a, b interface{}) bool {
return compareNumbers(a, b) >= 0
return CompareNumbers(a, b) >= 0
}
// compareLt 小于比较
func compareLt(a, b interface{}) bool {
return compareNumbers(a, b) < 0
return CompareNumbers(a, b) < 0
}
// compareLte 小于等于比较
func compareLte(a, b interface{}) bool {
return compareNumbers(a, b) <= 0
}
// compareNumbers 比较两个数值,返回 -1/0/1
func compareNumbers(a, b interface{}) int {
numA := toFloat64(a)
numB := toFloat64(b)
if numA < numB {
return -1
} else if numA > numB {
return 1
}
return 0
}
// toFloat64 将值转换为 float64
func toFloat64(v interface{}) float64 {
switch val := v.(type) {
case int:
return float64(val)
case int8:
return float64(val)
case int16:
return float64(val)
case int32:
return float64(val)
case int64:
return float64(val)
case uint:
return float64(val)
case uint8:
return float64(val)
case uint16:
return float64(val)
case uint32:
return float64(val)
case uint64:
return float64(val)
case float32:
return float64(val)
case float64:
return val
case string:
// 尝试解析字符串为数字
if num, err := strconv.ParseFloat(val, 64); err == nil {
return num
}
}
return 0
return CompareNumbers(a, b) <= 0
}
// compareIn 检查值是否在数组中
@ -115,28 +35,12 @@ func compareIn(value interface{}, operand interface{}) bool {
return false
}
for _, item := range arr {
if compareEq(value, item) {
return true
}
}
return false
return ContainsElement(arr, value)
}
// compareRegex 正则表达式匹配
func compareRegex(value interface{}, operand interface{}) bool {
str, ok := value.(string)
if !ok {
return false
}
pattern, ok := operand.(string)
if !ok {
return false
}
matched, _ := regexp.MatchString(pattern, str)
return matched
return MatchRegex(value, operand)
}
// compareType 类型检查
@ -145,30 +49,12 @@ func compareType(value interface{}, operand interface{}) bool {
return operand == "null"
}
var typeName string
switch reflect.TypeOf(value).Kind() {
case reflect.String:
typeName = "string"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
typeName = "int"
case reflect.Float32, reflect.Float64:
typeName = "double"
case reflect.Bool:
typeName = "bool"
case reflect.Slice, reflect.Array:
typeName = "array"
case reflect.Map:
typeName = "object"
}
// 支持字符串或数组形式的类型检查
switch op := operand.(type) {
case string:
return typeName == op
return CheckType(value, op)
case []interface{}:
for _, t := range op {
if ts, ok := t.(string); ok && typeName == ts {
if ts, ok := t.(string); ok && CheckType(value, ts) {
return true
}
}
@ -189,20 +75,7 @@ func compareAll(value interface{}, operand interface{}) bool {
return false
}
for _, req := range required {
found := false
for _, item := range arr {
if compareEq(item, req) {
found = true
break
}
}
if !found {
return false
}
}
return true
return ContainsAllElements(arr, required)
}
// compareElemMatch 数组元素匹配
@ -252,7 +125,7 @@ func compareSize(value interface{}, operand interface{}) bool {
// compareMod 模运算value % divisor == remainder
func compareMod(value interface{}, operand interface{}) bool {
num := toFloat64(value)
num := ToFloat64(value)
var divisor, remainder float64
switch op := operand.(type) {
@ -260,8 +133,8 @@ func compareMod(value interface{}, operand interface{}) bool {
if len(op) != 2 {
return false
}
divisor = toFloat64(op[0])
remainder = toFloat64(op[1])
divisor = ToFloat64(op[0])
remainder = ToFloat64(op[1])
default:
return false
}
@ -282,84 +155,28 @@ func compareMod(value interface{}, operand interface{}) bool {
// compareBitsAllClear 位运算:所有指定位都为 0
func compareBitsAllClear(value interface{}, operand interface{}) bool {
num := toInt64(value)
mask := toInt64(operand)
num := ToInt64(value)
mask := ToInt64(operand)
return (num & mask) == 0
}
// compareBitsAllSet 位运算:所有指定位都为 1
func compareBitsAllSet(value interface{}, operand interface{}) bool {
num := toInt64(value)
mask := toInt64(operand)
num := ToInt64(value)
mask := ToInt64(operand)
return (num & mask) == mask
}
// compareBitsAnyClear 位运算:任意指定位为 0
func compareBitsAnyClear(value interface{}, operand interface{}) bool {
num := toInt64(value)
mask := toInt64(operand)
num := ToInt64(value)
mask := ToInt64(operand)
return (num & mask) != mask
}
// compareBitsAnySet 位运算:任意指定位为 1
func compareBitsAnySet(value interface{}, operand interface{}) bool {
num := toInt64(value)
mask := toInt64(operand)
num := ToInt64(value)
mask := ToInt64(operand)
return (num & mask) != 0
}
// toInt64 将值转换为 int64
func toInt64(v interface{}) int64 {
switch val := v.(type) {
case int:
return int64(val)
case int8:
return int64(val)
case int16:
return int64(val)
case int32:
return int64(val)
case int64:
return val
case uint:
return int64(val)
case uint8:
return int64(val)
case uint16:
return int64(val)
case uint32:
return int64(val)
case uint64:
return int64(val)
case float32:
return int64(val)
case float64:
return int64(val)
case string:
if num, err := strconv.ParseInt(val, 10, 64); err == nil {
return num
}
}
return 0
}
// normalizeValue 标准化值用于比较
func normalizeValue(v interface{}) interface{} {
if v == nil {
return nil
}
// 处理数字类型
switch val := v.(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
return toFloat64(v)
case string:
// 尝试将字符串解析为数字
if num, err := strconv.ParseFloat(val, 64); err == nil {
return num
}
return strings.ToLower(val)
}
return v
}

View File

@ -1,39 +1,36 @@
package engine
import (
"fmt"
"strconv"
"time"
)
// type_conversion.go - 类型转换操作符实现
// 使用 helpers.go 中的公共辅助函数
// toString 转换为字符串
func (e *AggregationEngine) toString(operand interface{}, data map[string]interface{}) string {
val := e.evaluateExpression(data, operand)
return formatValueToString(val)
return FormatValueToString(val)
}
// toInt 转换为整数 (int32)
func (e *AggregationEngine) toInt(operand interface{}, data map[string]interface{}) int32 {
val := e.evaluateExpression(data, operand)
return int32(toInt64(val))
return int32(ToInt64(val))
}
// toLong 转换为长整数 (int64)
func (e *AggregationEngine) toLong(operand interface{}, data map[string]interface{}) int64 {
val := e.evaluateExpression(data, operand)
return toInt64(val)
return ToInt64(val)
}
// toDouble 转换为浮点数 (double)
func (e *AggregationEngine) toDouble(operand interface{}, data map[string]interface{}) float64 {
val := e.evaluateExpression(data, operand)
return toFloat64(val)
return ToFloat64(val)
}
// toBool 转换为布尔值
func (e *AggregationEngine) toBool(operand interface{}, data map[string]interface{}) bool {
val := e.evaluateExpression(data, operand)
return isTrueValue(val)
return IsTrueValue(val)
}
// toDocument 转换为文档(对象)
@ -53,53 +50,3 @@ func (e *AggregationEngine) toDocument(operand interface{}, data map[string]inte
// 其他情况返回空对象MongoDB 行为)
return map[string]interface{}{}
}
// formatValueToString 将任意值格式化为字符串
func formatValueToString(value interface{}) string {
if value == nil {
return ""
}
switch v := value.(type) {
case string:
return v
case bool:
return strconv.FormatBool(v)
case int, int8, int16, int32, int64:
return fmt.Sprintf("%d", v)
case uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", v)
case float32:
return strconv.FormatFloat(float64(v), 'g', -1, 32)
case float64:
return strconv.FormatFloat(v, 'g', -1, 64)
case time.Time:
return v.Format(time.RFC3339)
case []interface{}:
// 数组转为 JSON 风格字符串
result := "["
for i, item := range v {
if i > 0 {
result += ","
}
result += formatValueToString(item)
}
result += "]"
return result
case map[string]interface{}:
// 对象转为 JSON 风格字符串(简化版)
result := "{"
first := true
for k, val := range v {
if !first {
result += ","
}
result += fmt.Sprintf("%s:%v", k, val)
first = false
}
result += "}"
return result
default:
return fmt.Sprintf("%v", v)
}
}