From 948877c15b5ffece42e1a10bb589d2e22ce9b1fd Mon Sep 17 00:00:00 2001 From: kingecg Date: Sat, 14 Mar 2026 18:48:36 +0800 Subject: [PATCH] =?UTF-8?q?refactor(engine):=20=E5=B0=86=E8=BE=85=E5=8A=A9?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E6=8F=90=E5=8F=96=E5=88=B0=E7=8B=AC=E7=AB=8B?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=B9=B6=E6=9B=B4=E6=96=B0=E5=BC=95=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将类型转换、文档操作、比较等辅助函数移动到 helpers.go 文件 - 更新 aggregate_helpers.go 中的函数调用使用新的公共辅助函数 - 更新 operators.go 中的比较函数使用新的公共辅助函数 - 更新 type_conversion.go 中的类型转换函数使用新的公共辅助函数 - 添加导出版本的辅助函数供其他包使用 - 保持向后兼容性,确保现有功能正常工作 --- internal/engine/aggregate_helpers.go | 104 +++--- internal/engine/helpers.go | 477 +++++++++++++++++++++++++++ internal/engine/operators.go | 229 ++----------- internal/engine/type_conversion.go | 67 +--- 4 files changed, 558 insertions(+), 319 deletions(-) create mode 100644 internal/engine/helpers.go diff --git a/internal/engine/aggregate_helpers.go b/internal/engine/aggregate_helpers.go index 5025cb6..bb51b96 100644 --- a/internal/engine/aggregate_helpers.go +++ b/internal/engine/aggregate_helpers.go @@ -1,5 +1,8 @@ package engine +// aggregate_helpers.go - 聚合辅助函数 +// 使用 helpers.go 中的公共辅助函数 + import ( "fmt" "math" @@ -27,7 +30,7 @@ func (e *AggregationEngine) concat(operand interface{}, data map[string]interfac if str, ok := item.(string); ok { result += str } else { - result += toString(item) + result += FormatValueToString(item) } } return result @@ -40,8 +43,8 @@ func (e *AggregationEngine) substr(operand interface{}, data map[string]interfac return "" } - str := e.getFieldValueStr(types.Document{Data: data}, arr[0]) - start := int(toFloat64(arr[1])) + str := GetFieldValueStr(types.Document{Data: data}, arr[0]) + start := int(ToFloat64(arr[1])) if start < 0 { start = 0 @@ -52,7 +55,7 @@ func (e *AggregationEngine) substr(operand interface{}, data map[string]interfac end := len(str) if len(arr) > 2 { - length := int(toFloat64(arr[2])) + length := int(ToFloat64(arr[2])) if length > 0 { end = start + length if end > len(str) { @@ -73,7 +76,7 @@ func (e *AggregationEngine) add(operand interface{}, data map[string]interface{} sum := 0.0 for _, item := range arr { - sum += toFloat64(e.evaluateExpression(data, item)) + sum += ToFloat64(e.evaluateExpression(data, item)) } return sum } @@ -87,7 +90,7 @@ func (e *AggregationEngine) multiply(operand interface{}, data map[string]interf product := 1.0 for _, item := range arr { - product *= toFloat64(e.evaluateExpression(data, item)) + product *= ToFloat64(e.evaluateExpression(data, item)) } return product } @@ -99,8 +102,8 @@ func (e *AggregationEngine) divide(operand interface{}, data map[string]interfac return 0 } - dividend := toFloat64(e.evaluateExpression(data, arr[0])) - divisor := toFloat64(e.evaluateExpression(data, arr[1])) + dividend := ToFloat64(e.evaluateExpression(data, arr[0])) + divisor := ToFloat64(e.evaluateExpression(data, arr[1])) if divisor == 0 { return 0 @@ -132,14 +135,14 @@ func (e *AggregationEngine) cond(operand interface{}, data map[string]interface{ elseCond, ok3 := op["else"] if ok1 && ok2 && ok3 { - if isTrue(e.evaluateExpression(data, ifCond)) { + if IsTrueValue(e.evaluateExpression(data, ifCond)) { return thenCond } return elseCond } case []interface{}: if len(op) >= 3 { - if isTrue(e.evaluateExpression(data, op[0])) { + if IsTrueValue(e.evaluateExpression(data, op[0])) { return op[1] } return op[2] @@ -167,7 +170,7 @@ func (e *AggregationEngine) switchExpr(operand interface{}, data map[string]inte caseRaw, _ := branch["case"] thenRaw, _ := branch["then"] - if isTrueValue(e.evaluateExpression(data, caseRaw)) { + if IsTrueValue(e.evaluateExpression(data, caseRaw)) { return e.evaluateExpression(data, thenRaw) } } @@ -175,13 +178,9 @@ func (e *AggregationEngine) switchExpr(operand interface{}, data map[string]inte return defaultVal } -// getFieldValueStr 获取字段值的字符串形式 +// getFieldValueStr 获取字段值的字符串形式(已移到 helpers.go,此处为向后兼容) func (e *AggregationEngine) getFieldValueStr(doc types.Document, field interface{}) string { - val := e.getFieldValue(doc, field) - if str, ok := val.(string); ok { - return str - } - return toString(val) + return GetFieldValueStr(doc, field) } // executeAddFields 执行 $addFields / $set 阶段 @@ -193,7 +192,7 @@ func (e *AggregationEngine) executeAddFields(spec interface{}, docs []types.Docu var results []types.Document for _, doc := range docs { - newData := deepCopyMap(doc.Data) + newData := DeepCopyMap(doc.Data) for field, expr := range fields { newData[field] = e.evaluateExpression(newData, expr) } @@ -225,9 +224,9 @@ func (e *AggregationEngine) executeUnset(spec interface{}, docs []types.Document var results []types.Document for _, doc := range docs { - newData := deepCopyMap(doc.Data) + newData := DeepCopyMap(doc.Data) for _, field := range fields { - removeNestedValue(newData, field) + RemoveNestedValue(newData, field) } results = append(results, types.Document{ ID: doc.ID, @@ -280,7 +279,7 @@ func (e *AggregationEngine) executeSample(spec interface{}, docs []types.Documen switch s := spec.(type) { case map[string]interface{}: if sizeVal, ok := s["size"]; ok { - size = int(toFloat64(sizeVal)) + size = int(ToFloat64(sizeVal)) } case float64: size = int(s) @@ -317,7 +316,7 @@ func (e *AggregationEngine) executeBucket(spec interface{}, docs []types.Documen // 转换边界为 float64 数组 boundaries := make([]float64, 0, len(boundariesRaw)) for _, b := range boundariesRaw { - boundaries = append(boundaries, toFloat64(b)) + boundaries = append(boundaries, ToFloat64(b)) } // 创建桶 @@ -332,7 +331,7 @@ func (e *AggregationEngine) executeBucket(spec interface{}, docs []types.Documen // 分组 for _, doc := range docs { - value := toFloat64(getNestedValue(doc.Data, groupBy)) + value := ToFloat64(GetNestedValue(doc.Data, groupBy)) bucketName := "" for i := 0; i < len(boundaries)-1; i++ { @@ -384,7 +383,7 @@ func (e *AggregationEngine) ExecutePipeline(docs []types.Document, pipeline []ty // abs 绝对值 func (e *AggregationEngine) abs(operand interface{}, data map[string]interface{}) float64 { - val := toFloat64(e.evaluateExpression(data, operand)) + val := ToFloat64(e.evaluateExpression(data, operand)) if val < 0 { return -val } @@ -393,13 +392,13 @@ func (e *AggregationEngine) abs(operand interface{}, data map[string]interface{} // ceil 向上取整 func (e *AggregationEngine) ceil(operand interface{}, data map[string]interface{}) float64 { - val := toFloat64(e.evaluateExpression(data, operand)) + val := ToFloat64(e.evaluateExpression(data, operand)) return math.Ceil(val) } // floor 向下取整 func (e *AggregationEngine) floor(operand interface{}, data map[string]interface{}) float64 { - val := toFloat64(e.evaluateExpression(data, operand)) + val := ToFloat64(e.evaluateExpression(data, operand)) return math.Floor(val) } @@ -410,24 +409,23 @@ func (e *AggregationEngine) round(operand interface{}, data map[string]interface switch op := operand.(type) { case []interface{}: - value = toFloat64(e.evaluateExpression(data, op[0])) + value = ToFloat64(e.evaluateExpression(data, op[0])) if len(op) > 1 { - precision = int(toFloat64(op[1])) + precision = int(ToFloat64(op[1])) } else { precision = 0 } default: - value = toFloat64(e.evaluateExpression(data, op)) + value = ToFloat64(e.evaluateExpression(data, op)) precision = 0 } - multiplier := math.Pow(10, float64(precision)) - return math.Round(value*multiplier) / multiplier + return RoundToPrecision(value, precision) } // sqrt 平方根 func (e *AggregationEngine) sqrt(operand interface{}, data map[string]interface{}) float64 { - val := toFloat64(e.evaluateExpression(data, operand)) + val := ToFloat64(e.evaluateExpression(data, operand)) return math.Sqrt(val) } @@ -438,9 +436,9 @@ func (e *AggregationEngine) subtract(operand interface{}, data map[string]interf return 0 } - result := toFloat64(e.evaluateExpression(data, arr[0])) + result := ToFloat64(e.evaluateExpression(data, arr[0])) for i := 1; i < len(arr); i++ { - result -= toFloat64(e.evaluateExpression(data, arr[i])) + result -= ToFloat64(e.evaluateExpression(data, arr[i])) } return result } @@ -452,8 +450,8 @@ func (e *AggregationEngine) pow(operand interface{}, data map[string]interface{} return 0 } - base := toFloat64(e.evaluateExpression(data, arr[0])) - exp := toFloat64(e.evaluateExpression(data, arr[1])) + base := ToFloat64(e.evaluateExpression(data, arr[0])) + exp := ToFloat64(e.evaluateExpression(data, arr[1])) return math.Pow(base, exp) } @@ -467,15 +465,15 @@ func (e *AggregationEngine) trim(operand interface{}, data map[string]interface{ switch op := operand.(type) { case map[string]interface{}: if in, ok := op["input"]; ok { - input = e.getFieldValueStr(types.Document{Data: data}, in) + input = GetFieldValueStr(types.Document{Data: data}, in) } if c, ok := op["characters"]; ok { chars = c.(string) } case string: - input = e.getFieldValueStr(types.Document{Data: data}, op) + input = GetFieldValueStr(types.Document{Data: data}, op) default: - input = toString(operand) + input = FormatValueToString(operand) } return strings.Trim(input, chars) @@ -483,13 +481,13 @@ func (e *AggregationEngine) trim(operand interface{}, data map[string]interface{ // ltrim 去除左侧空格 func (e *AggregationEngine) ltrim(operand interface{}, data map[string]interface{}) string { - input := e.getFieldValueStr(types.Document{Data: data}, operand) + input := GetFieldValueStr(types.Document{Data: data}, operand) return strings.TrimLeft(input, " ") } // rtrim 去除右侧空格 func (e *AggregationEngine) rtrim(operand interface{}, data map[string]interface{}) string { - input := e.getFieldValueStr(types.Document{Data: data}, operand) + input := GetFieldValueStr(types.Document{Data: data}, operand) return strings.TrimRight(input, " ") } @@ -500,7 +498,7 @@ func (e *AggregationEngine) split(operand interface{}, data map[string]interface return nil } - input := e.getFieldValueStr(types.Document{Data: data}, arr[0]) + input := GetFieldValueStr(types.Document{Data: data}, arr[0]) delimiter := arr[1].(string) parts := strings.Split(input, delimiter) @@ -518,11 +516,11 @@ func (e *AggregationEngine) replaceAll(operand interface{}, data map[string]inte return "" } - input := e.getFieldValueStr(types.Document{Data: data}, spec["input"]) + input := GetFieldValueStr(types.Document{Data: data}, spec["input"]) find := spec["find"].(string) replacement := "" if rep, ok := spec["replacement"]; ok { - replacement = toString(rep) + replacement = FormatValueToString(rep) } return strings.ReplaceAll(input, find, replacement) @@ -535,8 +533,8 @@ func (e *AggregationEngine) strcasecmp(operand interface{}, data map[string]inte return 0 } - str1 := strings.ToLower(e.getFieldValueStr(types.Document{Data: data}, arr[0])) - str2 := strings.ToLower(e.getFieldValueStr(types.Document{Data: data}, arr[1])) + str1 := strings.ToLower(GetFieldValueStr(types.Document{Data: data}, arr[0])) + str2 := strings.ToLower(GetFieldValueStr(types.Document{Data: data}, arr[1])) if str1 < str2 { return -1 @@ -573,7 +571,7 @@ func (e *AggregationEngine) filter(operand interface{}, data map[string]interfac } tempData["$$"+as] = item - if isTrue(e.evaluateExpression(tempData, condRaw)) { + if IsTrueValue(e.evaluateExpression(tempData, condRaw)) { result = append(result, item) } } @@ -638,9 +636,9 @@ func (e *AggregationEngine) slice(operand interface{}, data map[string]interface case []interface{}: if len(op) >= 2 { arr = e.toArray(op[0]) - skip = int(toFloat64(op[1])) + skip = int(ToFloat64(op[1])) if len(op) > 2 { - limit = int(toFloat64(op[2])) + limit = int(ToFloat64(op[2])) } else { limit = len(arr) - skip } @@ -702,7 +700,7 @@ func (e *AggregationEngine) objectToArray(operand interface{}, data map[string]i // ========== 辅助函数 ========== -// toArray 将值转换为数组 +// toArray 将值转换为数组(保持向后兼容) func (e *AggregationEngine) toArray(value interface{}) []interface{} { switch v := value.(type) { case []interface{}: @@ -724,7 +722,7 @@ func (e *AggregationEngine) boolAnd(operand interface{}, data map[string]interfa } for _, item := range arr { - if !isTrue(e.evaluateExpression(data, item)) { + if !IsTrueValue(e.evaluateExpression(data, item)) { return false } } @@ -739,7 +737,7 @@ func (e *AggregationEngine) boolOr(operand interface{}, data map[string]interfac } for _, item := range arr { - if isTrue(e.evaluateExpression(data, item)) { + if IsTrueValue(e.evaluateExpression(data, item)) { return true } } @@ -748,5 +746,5 @@ func (e *AggregationEngine) boolOr(operand interface{}, data map[string]interfac // boolNot 布尔非 func (e *AggregationEngine) boolNot(operand interface{}, data map[string]interface{}) bool { - return !isTrue(e.evaluateExpression(data, operand)) + return !IsTrueValue(e.evaluateExpression(data, operand)) } diff --git a/internal/engine/helpers.go b/internal/engine/helpers.go new file mode 100644 index 0000000..d53b312 --- /dev/null +++ b/internal/engine/helpers.go @@ -0,0 +1,477 @@ +package engine + +import ( + "fmt" + "math" + "reflect" + "regexp" + "strconv" + "strings" + "time" + + "git.kingecg.top/kingecg/gomog/pkg/types" +) + +// ========== 类型转换辅助函数 ========== + +// ToFloat64 将任意值转换为 float64(导出版本) +func ToFloat64(v interface{}) float64 { + return toFloat64(v) +} + +// toFloat64 将值转换为 float64 +func toFloat64(v interface{}) float64 { + switch val := v.(type) { + case int: + return float64(val) + case int8: + return float64(val) + case int16: + return float64(val) + case int32: + return float64(val) + case int64: + return float64(val) + case uint: + return float64(val) + case uint8: + return float64(val) + case uint16: + return float64(val) + case uint32: + return float64(val) + case uint64: + return float64(val) + case float32: + return float64(val) + case float64: + return val + case string: + if num, err := strconv.ParseFloat(val, 64); err == nil { + return num + } + } + return 0 +} + +// ToInt64 将任意值转换为 int64(导出版本) +func ToInt64(v interface{}) int64 { + return toInt64(v) +} + +// toInt64 将值转换为 int64 +func toInt64(v interface{}) int64 { + switch val := v.(type) { + case int: + return int64(val) + case int8: + return int64(val) + case int16: + return int64(val) + case int32: + return int64(val) + case int64: + return val + case uint: + return int64(val) + case uint8: + return int64(val) + case uint16: + return int64(val) + case uint32: + return int64(val) + case uint64: + return int64(val) + case float32: + return int64(val) + case float64: + return int64(val) + case string: + if num, err := strconv.ParseInt(val, 10, 64); err == nil { + return num + } + } + return 0 +} + +// FormatValueToString 将任意值格式化为字符串(导出版本) +func FormatValueToString(value interface{}) string { + return formatValueToString(value) +} + +// formatValueToString 将任意值格式化为字符串 +func formatValueToString(value interface{}) string { + if value == nil { + return "" + } + + switch v := value.(type) { + case string: + return v + case bool: + return strconv.FormatBool(v) + case int, int8, int16, int32, int64: + return fmt.Sprintf("%d", v) + case uint, uint8, uint16, uint32, uint64: + return fmt.Sprintf("%d", v) + case float32: + return strconv.FormatFloat(float64(v), 'g', -1, 32) + case float64: + return strconv.FormatFloat(v, 'g', -1, 64) + case time.Time: + return v.Format(time.RFC3339) + case []interface{}: + result := "[" + for i, item := range v { + if i > 0 { + result += "," + } + result += formatValueToString(item) + } + result += "]" + return result + case map[string]interface{}: + result := "{" + first := true + for k, val := range v { + if !first { + result += "," + } + result += fmt.Sprintf("%s:%v", k, val) + first = false + } + result += "}" + return result + default: + return fmt.Sprintf("%v", v) + } +} + +// IsTrueValue 检查值是否为真值(导出版本) +// 注意:内部使用的 isTrueValue/isTrue 已在 query.go 和 aggregate.go 中定义 +func IsTrueValue(v interface{}) bool { + // 使用统一的转换逻辑 + if v == nil { + return false + } + + switch val := v.(type) { + case bool: + return val + case int, int8, int16, int32, int64: + return ToInt64(v) != 0 + case uint, uint8, uint16, uint32, uint64: + return ToInt64(v) != 0 + case float32, float64: + return ToFloat64(v) != 0 + case string: + return val != "" && val != "0" && strings.ToLower(val) != "false" + case []interface{}: + return len(val) > 0 + case map[string]interface{}: + return len(val) > 0 + default: + return true + } +} + +// ========== 文档操作辅助函数 ========== + +// GetNestedValue 从嵌套 map 中获取值 +func GetNestedValue(data map[string]interface{}, field string) interface{} { + if field == "" { + return nil + } + + parts := strings.Split(field, ".") + current := data + + for i, part := range parts { + if i == len(parts)-1 { + return current[part] + } + + if next, ok := current[part].(map[string]interface{}); ok { + current = next + } else { + return nil + } + } + + return nil +} + +// SetNestedValue 设置嵌套 map 中的值 +func SetNestedValue(data map[string]interface{}, field string, value interface{}) { + if field == "" { + return + } + + parts := strings.Split(field, ".") + current := data + + for i, part := range parts { + if i == len(parts)-1 { + current[part] = value + return + } + + if next, ok := current[part].(map[string]interface{}); ok { + current = next + } else { + newMap := make(map[string]interface{}) + current[part] = newMap + current = newMap + } + } +} + +// RemoveNestedValue 移除嵌套 map 中的值 +func RemoveNestedValue(data map[string]interface{}, field string) { + if field == "" { + return + } + + parts := strings.Split(field, ".") + current := data + + for i, part := range parts { + if i == len(parts)-1 { + delete(current, part) + return + } + + if next, ok := current[part].(map[string]interface{}); ok { + current = next + } else { + return + } + } +} + +// DeepCopyMap 深度复制 map +func DeepCopyMap(src map[string]interface{}) map[string]interface{} { + dst := make(map[string]interface{}) + for k, v := range src { + dst[k] = deepCopyValue(v) + } + return dst +} + +// deepCopyValue 深度复制值 +func deepCopyValue(v interface{}) interface{} { + switch val := v.(type) { + case map[string]interface{}: + return DeepCopyMap(val) + case []interface{}: + arr := make([]interface{}, len(val)) + for i, item := range val { + arr[i] = deepCopyValue(item) + } + return arr + default: + return v + } +} + +// ========== 比较辅助函数 ========== + +// CompareEq 相等比较(导出版本) +func CompareEq(a, b interface{}) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + + if isComplexType(a) || isComplexType(b) { + return reflect.DeepEqual(a, b) + } + + return normalizeValue(a) == normalizeValue(b) +} + +// CompareNumbers 比较两个数值,返回 -1/0/1(导出版本) +func CompareNumbers(a, b interface{}) int { + return compareNumbers(a, b) +} + +// compareNumbers 比较两个数值 +func compareNumbers(a, b interface{}) int { + numA := toFloat64(a) + numB := toFloat64(b) + + if numA < numB { + return -1 + } else if numA > numB { + return 1 + } + return 0 +} + +// IsComplexType 检查是否是复杂类型(导出版本) +func IsComplexType(v interface{}) bool { + return isComplexType(v) +} + +// isComplexType 检查是否是复杂类型 +func isComplexType(v interface{}) bool { + switch v.(type) { + case []interface{}: + return true + case map[string]interface{}: + return true + case map[interface{}]interface{}: + return true + default: + return false + } +} + +// NormalizeValue 标准化值用于比较(导出版本) +func NormalizeValue(v interface{}) interface{} { + return normalizeValue(v) +} + +// normalizeValue 标准化值 +func normalizeValue(v interface{}) interface{} { + if v == nil { + return nil + } + + switch val := v.(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: + return toFloat64(v) + case string: + if num, err := strconv.ParseFloat(val, 64); err == nil { + return num + } + return strings.ToLower(val) + } + + return v +} + +// ========== 数组辅助函数 ========== + +// ContainsElement 检查数组是否包含指定元素 +func ContainsElement(arr []interface{}, element interface{}) bool { + for _, item := range arr { + if compareEq(item, element) { + return true + } + } + return false +} + +// ContainsAllElements 检查数组是否包含所有指定元素 +func ContainsAllElements(arr []interface{}, elements []interface{}) bool { + for _, elem := range elements { + if !ContainsElement(arr, elem) { + return false + } + } + return true +} + +// ArrayIntersection 计算数组交集 +func ArrayIntersection(a, b []interface{}) []interface{} { + result := make([]interface{}, 0) + for _, item := range a { + if ContainsElement(b, item) && !ContainsElement(result, item) { + result = append(result, item) + } + } + return result +} + +// ArrayUnion 计算数组并集 +func ArrayUnion(a, b []interface{}) []interface{} { + result := make([]interface{}, len(a)) + copy(result, a) + for _, item := range b { + if !ContainsElement(result, item) { + result = append(result, item) + } + } + return result +} + +// ========== 正则表达式辅助函数 ========== + +// MatchRegex 正则表达式匹配 +func MatchRegex(value interface{}, pattern interface{}) bool { + str, ok := value.(string) + if !ok { + return false + } + + patternStr, ok := pattern.(string) + if !ok { + return false + } + + matched, _ := regexp.MatchString(patternStr, str) + return matched +} + +// ========== 类型检查辅助函数 ========== + +// CheckType 检查值的类型 +func CheckType(value interface{}, typeName string) bool { + if value == nil { + return typeName == "null" + } + + var actualType string + switch reflect.TypeOf(value).Kind() { + case reflect.String: + actualType = "string" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + actualType = "int" + case reflect.Float32, reflect.Float64: + actualType = "double" + case reflect.Bool: + actualType = "bool" + case reflect.Slice, reflect.Array: + actualType = "array" + case reflect.Map: + actualType = "object" + } + + return actualType == typeName +} + +// ========== 数学辅助函数 ========== + +// RoundToPrecision 四舍五入到指定精度 +func RoundToPrecision(value float64, precision int) float64 { + multiplier := math.Pow(10, float64(precision)) + return math.Round(value*multiplier) / multiplier +} + +// ========== 文档辅助函数 ========== + +// GetFieldValue 从文档中获取字段值 +func GetFieldValue(doc types.Document, field interface{}) interface{} { + switch f := field.(type) { + case string: + if strings.HasPrefix(f, "$") { + return GetNestedValue(doc.Data, f[1:]) + } + return GetNestedValue(doc.Data, f) + default: + return nil + } +} + +// GetFieldValueStr 从文档中获取字段值的字符串形式 +func GetFieldValueStr(doc types.Document, field interface{}) string { + val := GetFieldValue(doc, field) + if str, ok := val.(string); ok { + return str + } + return toString(val) +} diff --git a/internal/engine/operators.go b/internal/engine/operators.go index a07c310..1b8e563 100644 --- a/internal/engine/operators.go +++ b/internal/engine/operators.go @@ -1,111 +1,31 @@ package engine -import ( - "reflect" - "regexp" - "strconv" - "strings" -) +// operators.go - 查询操作符实现 +// 使用 helpers.go 中的公共辅助函数 // compareEq 相等比较 func compareEq(a, b interface{}) bool { - if a == nil && b == nil { - return true - } - if a == nil || b == nil { - return false - } - - // 对于 slice、map 等复杂类型,使用 reflect.DeepEqual - if isComplexType(a) || isComplexType(b) { - return reflect.DeepEqual(a, b) - } - - // 类型转换后比较 - return normalizeValue(a) == normalizeValue(b) -} - -// isComplexType 检查是否是复杂类型(slice、map 等) -func isComplexType(v interface{}) bool { - switch v.(type) { - case []interface{}: - return true - case map[string]interface{}: - return true - case map[interface{}]interface{}: - return true - default: - return false - } + return CompareEq(a, b) } // compareGt 大于比较 func compareGt(a, b interface{}) bool { - return compareNumbers(a, b) > 0 + return CompareNumbers(a, b) > 0 } // compareGte 大于等于比较 func compareGte(a, b interface{}) bool { - return compareNumbers(a, b) >= 0 + return CompareNumbers(a, b) >= 0 } // compareLt 小于比较 func compareLt(a, b interface{}) bool { - return compareNumbers(a, b) < 0 + return CompareNumbers(a, b) < 0 } // compareLte 小于等于比较 func compareLte(a, b interface{}) bool { - return compareNumbers(a, b) <= 0 -} - -// compareNumbers 比较两个数值,返回 -1/0/1 -func compareNumbers(a, b interface{}) int { - numA := toFloat64(a) - numB := toFloat64(b) - - if numA < numB { - return -1 - } else if numA > numB { - return 1 - } - return 0 -} - -// toFloat64 将值转换为 float64 -func toFloat64(v interface{}) float64 { - switch val := v.(type) { - case int: - return float64(val) - case int8: - return float64(val) - case int16: - return float64(val) - case int32: - return float64(val) - case int64: - return float64(val) - case uint: - return float64(val) - case uint8: - return float64(val) - case uint16: - return float64(val) - case uint32: - return float64(val) - case uint64: - return float64(val) - case float32: - return float64(val) - case float64: - return val - case string: - // 尝试解析字符串为数字 - if num, err := strconv.ParseFloat(val, 64); err == nil { - return num - } - } - return 0 + return CompareNumbers(a, b) <= 0 } // compareIn 检查值是否在数组中 @@ -115,28 +35,12 @@ func compareIn(value interface{}, operand interface{}) bool { return false } - for _, item := range arr { - if compareEq(value, item) { - return true - } - } - return false + return ContainsElement(arr, value) } // compareRegex 正则表达式匹配 func compareRegex(value interface{}, operand interface{}) bool { - str, ok := value.(string) - if !ok { - return false - } - - pattern, ok := operand.(string) - if !ok { - return false - } - - matched, _ := regexp.MatchString(pattern, str) - return matched + return MatchRegex(value, operand) } // compareType 类型检查 @@ -145,30 +49,12 @@ func compareType(value interface{}, operand interface{}) bool { return operand == "null" } - var typeName string - switch reflect.TypeOf(value).Kind() { - case reflect.String: - typeName = "string" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - typeName = "int" - case reflect.Float32, reflect.Float64: - typeName = "double" - case reflect.Bool: - typeName = "bool" - case reflect.Slice, reflect.Array: - typeName = "array" - case reflect.Map: - typeName = "object" - } - - // 支持字符串或数组形式的类型检查 switch op := operand.(type) { case string: - return typeName == op + return CheckType(value, op) case []interface{}: for _, t := range op { - if ts, ok := t.(string); ok && typeName == ts { + if ts, ok := t.(string); ok && CheckType(value, ts) { return true } } @@ -189,20 +75,7 @@ func compareAll(value interface{}, operand interface{}) bool { return false } - for _, req := range required { - found := false - for _, item := range arr { - if compareEq(item, req) { - found = true - break - } - } - if !found { - return false - } - } - - return true + return ContainsAllElements(arr, required) } // compareElemMatch 数组元素匹配 @@ -252,7 +125,7 @@ func compareSize(value interface{}, operand interface{}) bool { // compareMod 模运算:value % divisor == remainder func compareMod(value interface{}, operand interface{}) bool { - num := toFloat64(value) + num := ToFloat64(value) var divisor, remainder float64 switch op := operand.(type) { @@ -260,8 +133,8 @@ func compareMod(value interface{}, operand interface{}) bool { if len(op) != 2 { return false } - divisor = toFloat64(op[0]) - remainder = toFloat64(op[1]) + divisor = ToFloat64(op[0]) + remainder = ToFloat64(op[1]) default: return false } @@ -282,84 +155,28 @@ func compareMod(value interface{}, operand interface{}) bool { // compareBitsAllClear 位运算:所有指定位都为 0 func compareBitsAllClear(value interface{}, operand interface{}) bool { - num := toInt64(value) - mask := toInt64(operand) + num := ToInt64(value) + mask := ToInt64(operand) return (num & mask) == 0 } // compareBitsAllSet 位运算:所有指定位都为 1 func compareBitsAllSet(value interface{}, operand interface{}) bool { - num := toInt64(value) - mask := toInt64(operand) + num := ToInt64(value) + mask := ToInt64(operand) return (num & mask) == mask } // compareBitsAnyClear 位运算:任意指定位为 0 func compareBitsAnyClear(value interface{}, operand interface{}) bool { - num := toInt64(value) - mask := toInt64(operand) + num := ToInt64(value) + mask := ToInt64(operand) return (num & mask) != mask } // compareBitsAnySet 位运算:任意指定位为 1 func compareBitsAnySet(value interface{}, operand interface{}) bool { - num := toInt64(value) - mask := toInt64(operand) + num := ToInt64(value) + mask := ToInt64(operand) return (num & mask) != 0 } - -// toInt64 将值转换为 int64 -func toInt64(v interface{}) int64 { - switch val := v.(type) { - case int: - return int64(val) - case int8: - return int64(val) - case int16: - return int64(val) - case int32: - return int64(val) - case int64: - return val - case uint: - return int64(val) - case uint8: - return int64(val) - case uint16: - return int64(val) - case uint32: - return int64(val) - case uint64: - return int64(val) - case float32: - return int64(val) - case float64: - return int64(val) - case string: - if num, err := strconv.ParseInt(val, 10, 64); err == nil { - return num - } - } - return 0 -} - -// normalizeValue 标准化值用于比较 -func normalizeValue(v interface{}) interface{} { - if v == nil { - return nil - } - - // 处理数字类型 - switch val := v.(type) { - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: - return toFloat64(v) - case string: - // 尝试将字符串解析为数字 - if num, err := strconv.ParseFloat(val, 64); err == nil { - return num - } - return strings.ToLower(val) - } - - return v -} diff --git a/internal/engine/type_conversion.go b/internal/engine/type_conversion.go index f39e980..58211b1 100644 --- a/internal/engine/type_conversion.go +++ b/internal/engine/type_conversion.go @@ -1,39 +1,36 @@ package engine -import ( - "fmt" - "strconv" - "time" -) +// type_conversion.go - 类型转换操作符实现 +// 使用 helpers.go 中的公共辅助函数 // toString 转换为字符串 func (e *AggregationEngine) toString(operand interface{}, data map[string]interface{}) string { val := e.evaluateExpression(data, operand) - return formatValueToString(val) + return FormatValueToString(val) } // toInt 转换为整数 (int32) func (e *AggregationEngine) toInt(operand interface{}, data map[string]interface{}) int32 { val := e.evaluateExpression(data, operand) - return int32(toInt64(val)) + return int32(ToInt64(val)) } // toLong 转换为长整数 (int64) func (e *AggregationEngine) toLong(operand interface{}, data map[string]interface{}) int64 { val := e.evaluateExpression(data, operand) - return toInt64(val) + return ToInt64(val) } // toDouble 转换为浮点数 (double) func (e *AggregationEngine) toDouble(operand interface{}, data map[string]interface{}) float64 { val := e.evaluateExpression(data, operand) - return toFloat64(val) + return ToFloat64(val) } // toBool 转换为布尔值 func (e *AggregationEngine) toBool(operand interface{}, data map[string]interface{}) bool { val := e.evaluateExpression(data, operand) - return isTrueValue(val) + return IsTrueValue(val) } // toDocument 转换为文档(对象) @@ -53,53 +50,3 @@ func (e *AggregationEngine) toDocument(operand interface{}, data map[string]inte // 其他情况返回空对象(MongoDB 行为) return map[string]interface{}{} } - -// formatValueToString 将任意值格式化为字符串 -func formatValueToString(value interface{}) string { - if value == nil { - return "" - } - - switch v := value.(type) { - case string: - return v - case bool: - return strconv.FormatBool(v) - case int, int8, int16, int32, int64: - return fmt.Sprintf("%d", v) - case uint, uint8, uint16, uint32, uint64: - return fmt.Sprintf("%d", v) - case float32: - return strconv.FormatFloat(float64(v), 'g', -1, 32) - case float64: - return strconv.FormatFloat(v, 'g', -1, 64) - case time.Time: - return v.Format(time.RFC3339) - case []interface{}: - // 数组转为 JSON 风格字符串 - result := "[" - for i, item := range v { - if i > 0 { - result += "," - } - result += formatValueToString(item) - } - result += "]" - return result - case map[string]interface{}: - // 对象转为 JSON 风格字符串(简化版) - result := "{" - first := true - for k, val := range v { - if !first { - result += "," - } - result += fmt.Sprintf("%s:%v", k, val) - first = false - } - result += "}" - return result - default: - return fmt.Sprintf("%v", v) - } -}