refactor(engine): 将辅助函数提取到独立文件并更新引用

- 将类型转换、文档操作、比较等辅助函数移动到 helpers.go 文件
- 更新 aggregate_helpers.go 中的函数调用使用新的公共辅助函数
- 更新 operators.go 中的比较函数使用新的公共辅助函数
- 更新 type_conversion.go 中的类型转换函数使用新的公共辅助函数
- 添加导出版本的辅助函数供其他包使用
- 保持向后兼容性,确保现有功能正常工作
This commit is contained in:
kingecg 2026-03-14 18:48:36 +08:00
parent d0b5e956c4
commit 948877c15b
4 changed files with 558 additions and 319 deletions

View File

@ -1,5 +1,8 @@
package engine package engine
// aggregate_helpers.go - 聚合辅助函数
// 使用 helpers.go 中的公共辅助函数
import ( import (
"fmt" "fmt"
"math" "math"
@ -27,7 +30,7 @@ func (e *AggregationEngine) concat(operand interface{}, data map[string]interfac
if str, ok := item.(string); ok { if str, ok := item.(string); ok {
result += str result += str
} else { } else {
result += toString(item) result += FormatValueToString(item)
} }
} }
return result return result
@ -40,8 +43,8 @@ func (e *AggregationEngine) substr(operand interface{}, data map[string]interfac
return "" return ""
} }
str := e.getFieldValueStr(types.Document{Data: data}, arr[0]) str := GetFieldValueStr(types.Document{Data: data}, arr[0])
start := int(toFloat64(arr[1])) start := int(ToFloat64(arr[1]))
if start < 0 { if start < 0 {
start = 0 start = 0
@ -52,7 +55,7 @@ func (e *AggregationEngine) substr(operand interface{}, data map[string]interfac
end := len(str) end := len(str)
if len(arr) > 2 { if len(arr) > 2 {
length := int(toFloat64(arr[2])) length := int(ToFloat64(arr[2]))
if length > 0 { if length > 0 {
end = start + length end = start + length
if end > len(str) { if end > len(str) {
@ -73,7 +76,7 @@ func (e *AggregationEngine) add(operand interface{}, data map[string]interface{}
sum := 0.0 sum := 0.0
for _, item := range arr { for _, item := range arr {
sum += toFloat64(e.evaluateExpression(data, item)) sum += ToFloat64(e.evaluateExpression(data, item))
} }
return sum return sum
} }
@ -87,7 +90,7 @@ func (e *AggregationEngine) multiply(operand interface{}, data map[string]interf
product := 1.0 product := 1.0
for _, item := range arr { for _, item := range arr {
product *= toFloat64(e.evaluateExpression(data, item)) product *= ToFloat64(e.evaluateExpression(data, item))
} }
return product return product
} }
@ -99,8 +102,8 @@ func (e *AggregationEngine) divide(operand interface{}, data map[string]interfac
return 0 return 0
} }
dividend := toFloat64(e.evaluateExpression(data, arr[0])) dividend := ToFloat64(e.evaluateExpression(data, arr[0]))
divisor := toFloat64(e.evaluateExpression(data, arr[1])) divisor := ToFloat64(e.evaluateExpression(data, arr[1]))
if divisor == 0 { if divisor == 0 {
return 0 return 0
@ -132,14 +135,14 @@ func (e *AggregationEngine) cond(operand interface{}, data map[string]interface{
elseCond, ok3 := op["else"] elseCond, ok3 := op["else"]
if ok1 && ok2 && ok3 { if ok1 && ok2 && ok3 {
if isTrue(e.evaluateExpression(data, ifCond)) { if IsTrueValue(e.evaluateExpression(data, ifCond)) {
return thenCond return thenCond
} }
return elseCond return elseCond
} }
case []interface{}: case []interface{}:
if len(op) >= 3 { if len(op) >= 3 {
if isTrue(e.evaluateExpression(data, op[0])) { if IsTrueValue(e.evaluateExpression(data, op[0])) {
return op[1] return op[1]
} }
return op[2] return op[2]
@ -167,7 +170,7 @@ func (e *AggregationEngine) switchExpr(operand interface{}, data map[string]inte
caseRaw, _ := branch["case"] caseRaw, _ := branch["case"]
thenRaw, _ := branch["then"] thenRaw, _ := branch["then"]
if isTrueValue(e.evaluateExpression(data, caseRaw)) { if IsTrueValue(e.evaluateExpression(data, caseRaw)) {
return e.evaluateExpression(data, thenRaw) return e.evaluateExpression(data, thenRaw)
} }
} }
@ -175,13 +178,9 @@ func (e *AggregationEngine) switchExpr(operand interface{}, data map[string]inte
return defaultVal return defaultVal
} }
// getFieldValueStr 获取字段值的字符串形式 // getFieldValueStr 获取字段值的字符串形式(已移到 helpers.go此处为向后兼容
func (e *AggregationEngine) getFieldValueStr(doc types.Document, field interface{}) string { func (e *AggregationEngine) getFieldValueStr(doc types.Document, field interface{}) string {
val := e.getFieldValue(doc, field) return GetFieldValueStr(doc, field)
if str, ok := val.(string); ok {
return str
}
return toString(val)
} }
// executeAddFields 执行 $addFields / $set 阶段 // executeAddFields 执行 $addFields / $set 阶段
@ -193,7 +192,7 @@ func (e *AggregationEngine) executeAddFields(spec interface{}, docs []types.Docu
var results []types.Document var results []types.Document
for _, doc := range docs { for _, doc := range docs {
newData := deepCopyMap(doc.Data) newData := DeepCopyMap(doc.Data)
for field, expr := range fields { for field, expr := range fields {
newData[field] = e.evaluateExpression(newData, expr) newData[field] = e.evaluateExpression(newData, expr)
} }
@ -225,9 +224,9 @@ func (e *AggregationEngine) executeUnset(spec interface{}, docs []types.Document
var results []types.Document var results []types.Document
for _, doc := range docs { for _, doc := range docs {
newData := deepCopyMap(doc.Data) newData := DeepCopyMap(doc.Data)
for _, field := range fields { for _, field := range fields {
removeNestedValue(newData, field) RemoveNestedValue(newData, field)
} }
results = append(results, types.Document{ results = append(results, types.Document{
ID: doc.ID, ID: doc.ID,
@ -280,7 +279,7 @@ func (e *AggregationEngine) executeSample(spec interface{}, docs []types.Documen
switch s := spec.(type) { switch s := spec.(type) {
case map[string]interface{}: case map[string]interface{}:
if sizeVal, ok := s["size"]; ok { if sizeVal, ok := s["size"]; ok {
size = int(toFloat64(sizeVal)) size = int(ToFloat64(sizeVal))
} }
case float64: case float64:
size = int(s) size = int(s)
@ -317,7 +316,7 @@ func (e *AggregationEngine) executeBucket(spec interface{}, docs []types.Documen
// 转换边界为 float64 数组 // 转换边界为 float64 数组
boundaries := make([]float64, 0, len(boundariesRaw)) boundaries := make([]float64, 0, len(boundariesRaw))
for _, b := range boundariesRaw { for _, b := range boundariesRaw {
boundaries = append(boundaries, toFloat64(b)) boundaries = append(boundaries, ToFloat64(b))
} }
// 创建桶 // 创建桶
@ -332,7 +331,7 @@ func (e *AggregationEngine) executeBucket(spec interface{}, docs []types.Documen
// 分组 // 分组
for _, doc := range docs { for _, doc := range docs {
value := toFloat64(getNestedValue(doc.Data, groupBy)) value := ToFloat64(GetNestedValue(doc.Data, groupBy))
bucketName := "" bucketName := ""
for i := 0; i < len(boundaries)-1; i++ { for i := 0; i < len(boundaries)-1; i++ {
@ -384,7 +383,7 @@ func (e *AggregationEngine) ExecutePipeline(docs []types.Document, pipeline []ty
// abs 绝对值 // abs 绝对值
func (e *AggregationEngine) abs(operand interface{}, data map[string]interface{}) float64 { func (e *AggregationEngine) abs(operand interface{}, data map[string]interface{}) float64 {
val := toFloat64(e.evaluateExpression(data, operand)) val := ToFloat64(e.evaluateExpression(data, operand))
if val < 0 { if val < 0 {
return -val return -val
} }
@ -393,13 +392,13 @@ func (e *AggregationEngine) abs(operand interface{}, data map[string]interface{}
// ceil 向上取整 // ceil 向上取整
func (e *AggregationEngine) ceil(operand interface{}, data map[string]interface{}) float64 { func (e *AggregationEngine) ceil(operand interface{}, data map[string]interface{}) float64 {
val := toFloat64(e.evaluateExpression(data, operand)) val := ToFloat64(e.evaluateExpression(data, operand))
return math.Ceil(val) return math.Ceil(val)
} }
// floor 向下取整 // floor 向下取整
func (e *AggregationEngine) floor(operand interface{}, data map[string]interface{}) float64 { func (e *AggregationEngine) floor(operand interface{}, data map[string]interface{}) float64 {
val := toFloat64(e.evaluateExpression(data, operand)) val := ToFloat64(e.evaluateExpression(data, operand))
return math.Floor(val) return math.Floor(val)
} }
@ -410,24 +409,23 @@ func (e *AggregationEngine) round(operand interface{}, data map[string]interface
switch op := operand.(type) { switch op := operand.(type) {
case []interface{}: case []interface{}:
value = toFloat64(e.evaluateExpression(data, op[0])) value = ToFloat64(e.evaluateExpression(data, op[0]))
if len(op) > 1 { if len(op) > 1 {
precision = int(toFloat64(op[1])) precision = int(ToFloat64(op[1]))
} else { } else {
precision = 0 precision = 0
} }
default: default:
value = toFloat64(e.evaluateExpression(data, op)) value = ToFloat64(e.evaluateExpression(data, op))
precision = 0 precision = 0
} }
multiplier := math.Pow(10, float64(precision)) return RoundToPrecision(value, precision)
return math.Round(value*multiplier) / multiplier
} }
// sqrt 平方根 // sqrt 平方根
func (e *AggregationEngine) sqrt(operand interface{}, data map[string]interface{}) float64 { func (e *AggregationEngine) sqrt(operand interface{}, data map[string]interface{}) float64 {
val := toFloat64(e.evaluateExpression(data, operand)) val := ToFloat64(e.evaluateExpression(data, operand))
return math.Sqrt(val) return math.Sqrt(val)
} }
@ -438,9 +436,9 @@ func (e *AggregationEngine) subtract(operand interface{}, data map[string]interf
return 0 return 0
} }
result := toFloat64(e.evaluateExpression(data, arr[0])) result := ToFloat64(e.evaluateExpression(data, arr[0]))
for i := 1; i < len(arr); i++ { for i := 1; i < len(arr); i++ {
result -= toFloat64(e.evaluateExpression(data, arr[i])) result -= ToFloat64(e.evaluateExpression(data, arr[i]))
} }
return result return result
} }
@ -452,8 +450,8 @@ func (e *AggregationEngine) pow(operand interface{}, data map[string]interface{}
return 0 return 0
} }
base := toFloat64(e.evaluateExpression(data, arr[0])) base := ToFloat64(e.evaluateExpression(data, arr[0]))
exp := toFloat64(e.evaluateExpression(data, arr[1])) exp := ToFloat64(e.evaluateExpression(data, arr[1]))
return math.Pow(base, exp) return math.Pow(base, exp)
} }
@ -467,15 +465,15 @@ func (e *AggregationEngine) trim(operand interface{}, data map[string]interface{
switch op := operand.(type) { switch op := operand.(type) {
case map[string]interface{}: case map[string]interface{}:
if in, ok := op["input"]; ok { if in, ok := op["input"]; ok {
input = e.getFieldValueStr(types.Document{Data: data}, in) input = GetFieldValueStr(types.Document{Data: data}, in)
} }
if c, ok := op["characters"]; ok { if c, ok := op["characters"]; ok {
chars = c.(string) chars = c.(string)
} }
case string: case string:
input = e.getFieldValueStr(types.Document{Data: data}, op) input = GetFieldValueStr(types.Document{Data: data}, op)
default: default:
input = toString(operand) input = FormatValueToString(operand)
} }
return strings.Trim(input, chars) return strings.Trim(input, chars)
@ -483,13 +481,13 @@ func (e *AggregationEngine) trim(operand interface{}, data map[string]interface{
// ltrim 去除左侧空格 // ltrim 去除左侧空格
func (e *AggregationEngine) ltrim(operand interface{}, data map[string]interface{}) string { func (e *AggregationEngine) ltrim(operand interface{}, data map[string]interface{}) string {
input := e.getFieldValueStr(types.Document{Data: data}, operand) input := GetFieldValueStr(types.Document{Data: data}, operand)
return strings.TrimLeft(input, " ") return strings.TrimLeft(input, " ")
} }
// rtrim 去除右侧空格 // rtrim 去除右侧空格
func (e *AggregationEngine) rtrim(operand interface{}, data map[string]interface{}) string { func (e *AggregationEngine) rtrim(operand interface{}, data map[string]interface{}) string {
input := e.getFieldValueStr(types.Document{Data: data}, operand) input := GetFieldValueStr(types.Document{Data: data}, operand)
return strings.TrimRight(input, " ") return strings.TrimRight(input, " ")
} }
@ -500,7 +498,7 @@ func (e *AggregationEngine) split(operand interface{}, data map[string]interface
return nil return nil
} }
input := e.getFieldValueStr(types.Document{Data: data}, arr[0]) input := GetFieldValueStr(types.Document{Data: data}, arr[0])
delimiter := arr[1].(string) delimiter := arr[1].(string)
parts := strings.Split(input, delimiter) parts := strings.Split(input, delimiter)
@ -518,11 +516,11 @@ func (e *AggregationEngine) replaceAll(operand interface{}, data map[string]inte
return "" return ""
} }
input := e.getFieldValueStr(types.Document{Data: data}, spec["input"]) input := GetFieldValueStr(types.Document{Data: data}, spec["input"])
find := spec["find"].(string) find := spec["find"].(string)
replacement := "" replacement := ""
if rep, ok := spec["replacement"]; ok { if rep, ok := spec["replacement"]; ok {
replacement = toString(rep) replacement = FormatValueToString(rep)
} }
return strings.ReplaceAll(input, find, replacement) return strings.ReplaceAll(input, find, replacement)
@ -535,8 +533,8 @@ func (e *AggregationEngine) strcasecmp(operand interface{}, data map[string]inte
return 0 return 0
} }
str1 := strings.ToLower(e.getFieldValueStr(types.Document{Data: data}, arr[0])) str1 := strings.ToLower(GetFieldValueStr(types.Document{Data: data}, arr[0]))
str2 := strings.ToLower(e.getFieldValueStr(types.Document{Data: data}, arr[1])) str2 := strings.ToLower(GetFieldValueStr(types.Document{Data: data}, arr[1]))
if str1 < str2 { if str1 < str2 {
return -1 return -1
@ -573,7 +571,7 @@ func (e *AggregationEngine) filter(operand interface{}, data map[string]interfac
} }
tempData["$$"+as] = item tempData["$$"+as] = item
if isTrue(e.evaluateExpression(tempData, condRaw)) { if IsTrueValue(e.evaluateExpression(tempData, condRaw)) {
result = append(result, item) result = append(result, item)
} }
} }
@ -638,9 +636,9 @@ func (e *AggregationEngine) slice(operand interface{}, data map[string]interface
case []interface{}: case []interface{}:
if len(op) >= 2 { if len(op) >= 2 {
arr = e.toArray(op[0]) arr = e.toArray(op[0])
skip = int(toFloat64(op[1])) skip = int(ToFloat64(op[1]))
if len(op) > 2 { if len(op) > 2 {
limit = int(toFloat64(op[2])) limit = int(ToFloat64(op[2]))
} else { } else {
limit = len(arr) - skip limit = len(arr) - skip
} }
@ -702,7 +700,7 @@ func (e *AggregationEngine) objectToArray(operand interface{}, data map[string]i
// ========== 辅助函数 ========== // ========== 辅助函数 ==========
// toArray 将值转换为数组 // toArray 将值转换为数组(保持向后兼容)
func (e *AggregationEngine) toArray(value interface{}) []interface{} { func (e *AggregationEngine) toArray(value interface{}) []interface{} {
switch v := value.(type) { switch v := value.(type) {
case []interface{}: case []interface{}:
@ -724,7 +722,7 @@ func (e *AggregationEngine) boolAnd(operand interface{}, data map[string]interfa
} }
for _, item := range arr { for _, item := range arr {
if !isTrue(e.evaluateExpression(data, item)) { if !IsTrueValue(e.evaluateExpression(data, item)) {
return false return false
} }
} }
@ -739,7 +737,7 @@ func (e *AggregationEngine) boolOr(operand interface{}, data map[string]interfac
} }
for _, item := range arr { for _, item := range arr {
if isTrue(e.evaluateExpression(data, item)) { if IsTrueValue(e.evaluateExpression(data, item)) {
return true return true
} }
} }
@ -748,5 +746,5 @@ func (e *AggregationEngine) boolOr(operand interface{}, data map[string]interfac
// boolNot 布尔非 // boolNot 布尔非
func (e *AggregationEngine) boolNot(operand interface{}, data map[string]interface{}) bool { func (e *AggregationEngine) boolNot(operand interface{}, data map[string]interface{}) bool {
return !isTrue(e.evaluateExpression(data, operand)) return !IsTrueValue(e.evaluateExpression(data, operand))
} }

477
internal/engine/helpers.go Normal file
View File

@ -0,0 +1,477 @@
package engine
import (
"fmt"
"math"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"git.kingecg.top/kingecg/gomog/pkg/types"
)
// ========== 类型转换辅助函数 ==========
// ToFloat64 将任意值转换为 float64导出版本
func ToFloat64(v interface{}) float64 {
return toFloat64(v)
}
// toFloat64 将值转换为 float64
func toFloat64(v interface{}) float64 {
switch val := v.(type) {
case int:
return float64(val)
case int8:
return float64(val)
case int16:
return float64(val)
case int32:
return float64(val)
case int64:
return float64(val)
case uint:
return float64(val)
case uint8:
return float64(val)
case uint16:
return float64(val)
case uint32:
return float64(val)
case uint64:
return float64(val)
case float32:
return float64(val)
case float64:
return val
case string:
if num, err := strconv.ParseFloat(val, 64); err == nil {
return num
}
}
return 0
}
// ToInt64 将任意值转换为 int64导出版本
func ToInt64(v interface{}) int64 {
return toInt64(v)
}
// toInt64 将值转换为 int64
func toInt64(v interface{}) int64 {
switch val := v.(type) {
case int:
return int64(val)
case int8:
return int64(val)
case int16:
return int64(val)
case int32:
return int64(val)
case int64:
return val
case uint:
return int64(val)
case uint8:
return int64(val)
case uint16:
return int64(val)
case uint32:
return int64(val)
case uint64:
return int64(val)
case float32:
return int64(val)
case float64:
return int64(val)
case string:
if num, err := strconv.ParseInt(val, 10, 64); err == nil {
return num
}
}
return 0
}
// FormatValueToString 将任意值格式化为字符串(导出版本)
func FormatValueToString(value interface{}) string {
return formatValueToString(value)
}
// formatValueToString 将任意值格式化为字符串
func formatValueToString(value interface{}) string {
if value == nil {
return ""
}
switch v := value.(type) {
case string:
return v
case bool:
return strconv.FormatBool(v)
case int, int8, int16, int32, int64:
return fmt.Sprintf("%d", v)
case uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", v)
case float32:
return strconv.FormatFloat(float64(v), 'g', -1, 32)
case float64:
return strconv.FormatFloat(v, 'g', -1, 64)
case time.Time:
return v.Format(time.RFC3339)
case []interface{}:
result := "["
for i, item := range v {
if i > 0 {
result += ","
}
result += formatValueToString(item)
}
result += "]"
return result
case map[string]interface{}:
result := "{"
first := true
for k, val := range v {
if !first {
result += ","
}
result += fmt.Sprintf("%s:%v", k, val)
first = false
}
result += "}"
return result
default:
return fmt.Sprintf("%v", v)
}
}
// IsTrueValue 检查值是否为真值(导出版本)
// 注意:内部使用的 isTrueValue/isTrue 已在 query.go 和 aggregate.go 中定义
func IsTrueValue(v interface{}) bool {
// 使用统一的转换逻辑
if v == nil {
return false
}
switch val := v.(type) {
case bool:
return val
case int, int8, int16, int32, int64:
return ToInt64(v) != 0
case uint, uint8, uint16, uint32, uint64:
return ToInt64(v) != 0
case float32, float64:
return ToFloat64(v) != 0
case string:
return val != "" && val != "0" && strings.ToLower(val) != "false"
case []interface{}:
return len(val) > 0
case map[string]interface{}:
return len(val) > 0
default:
return true
}
}
// ========== 文档操作辅助函数 ==========
// GetNestedValue 从嵌套 map 中获取值
func GetNestedValue(data map[string]interface{}, field string) interface{} {
if field == "" {
return nil
}
parts := strings.Split(field, ".")
current := data
for i, part := range parts {
if i == len(parts)-1 {
return current[part]
}
if next, ok := current[part].(map[string]interface{}); ok {
current = next
} else {
return nil
}
}
return nil
}
// SetNestedValue 设置嵌套 map 中的值
func SetNestedValue(data map[string]interface{}, field string, value interface{}) {
if field == "" {
return
}
parts := strings.Split(field, ".")
current := data
for i, part := range parts {
if i == len(parts)-1 {
current[part] = value
return
}
if next, ok := current[part].(map[string]interface{}); ok {
current = next
} else {
newMap := make(map[string]interface{})
current[part] = newMap
current = newMap
}
}
}
// RemoveNestedValue 移除嵌套 map 中的值
func RemoveNestedValue(data map[string]interface{}, field string) {
if field == "" {
return
}
parts := strings.Split(field, ".")
current := data
for i, part := range parts {
if i == len(parts)-1 {
delete(current, part)
return
}
if next, ok := current[part].(map[string]interface{}); ok {
current = next
} else {
return
}
}
}
// DeepCopyMap 深度复制 map
func DeepCopyMap(src map[string]interface{}) map[string]interface{} {
dst := make(map[string]interface{})
for k, v := range src {
dst[k] = deepCopyValue(v)
}
return dst
}
// deepCopyValue 深度复制值
func deepCopyValue(v interface{}) interface{} {
switch val := v.(type) {
case map[string]interface{}:
return DeepCopyMap(val)
case []interface{}:
arr := make([]interface{}, len(val))
for i, item := range val {
arr[i] = deepCopyValue(item)
}
return arr
default:
return v
}
}
// ========== 比较辅助函数 ==========
// CompareEq 相等比较(导出版本)
func CompareEq(a, b interface{}) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
if isComplexType(a) || isComplexType(b) {
return reflect.DeepEqual(a, b)
}
return normalizeValue(a) == normalizeValue(b)
}
// CompareNumbers 比较两个数值,返回 -1/0/1导出版本
func CompareNumbers(a, b interface{}) int {
return compareNumbers(a, b)
}
// compareNumbers 比较两个数值
func compareNumbers(a, b interface{}) int {
numA := toFloat64(a)
numB := toFloat64(b)
if numA < numB {
return -1
} else if numA > numB {
return 1
}
return 0
}
// IsComplexType 检查是否是复杂类型(导出版本)
func IsComplexType(v interface{}) bool {
return isComplexType(v)
}
// isComplexType 检查是否是复杂类型
func isComplexType(v interface{}) bool {
switch v.(type) {
case []interface{}:
return true
case map[string]interface{}:
return true
case map[interface{}]interface{}:
return true
default:
return false
}
}
// NormalizeValue 标准化值用于比较(导出版本)
func NormalizeValue(v interface{}) interface{} {
return normalizeValue(v)
}
// normalizeValue 标准化值
func normalizeValue(v interface{}) interface{} {
if v == nil {
return nil
}
switch val := v.(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
return toFloat64(v)
case string:
if num, err := strconv.ParseFloat(val, 64); err == nil {
return num
}
return strings.ToLower(val)
}
return v
}
// ========== 数组辅助函数 ==========
// ContainsElement 检查数组是否包含指定元素
func ContainsElement(arr []interface{}, element interface{}) bool {
for _, item := range arr {
if compareEq(item, element) {
return true
}
}
return false
}
// ContainsAllElements 检查数组是否包含所有指定元素
func ContainsAllElements(arr []interface{}, elements []interface{}) bool {
for _, elem := range elements {
if !ContainsElement(arr, elem) {
return false
}
}
return true
}
// ArrayIntersection 计算数组交集
func ArrayIntersection(a, b []interface{}) []interface{} {
result := make([]interface{}, 0)
for _, item := range a {
if ContainsElement(b, item) && !ContainsElement(result, item) {
result = append(result, item)
}
}
return result
}
// ArrayUnion 计算数组并集
func ArrayUnion(a, b []interface{}) []interface{} {
result := make([]interface{}, len(a))
copy(result, a)
for _, item := range b {
if !ContainsElement(result, item) {
result = append(result, item)
}
}
return result
}
// ========== 正则表达式辅助函数 ==========
// MatchRegex 正则表达式匹配
func MatchRegex(value interface{}, pattern interface{}) bool {
str, ok := value.(string)
if !ok {
return false
}
patternStr, ok := pattern.(string)
if !ok {
return false
}
matched, _ := regexp.MatchString(patternStr, str)
return matched
}
// ========== 类型检查辅助函数 ==========
// CheckType 检查值的类型
func CheckType(value interface{}, typeName string) bool {
if value == nil {
return typeName == "null"
}
var actualType string
switch reflect.TypeOf(value).Kind() {
case reflect.String:
actualType = "string"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
actualType = "int"
case reflect.Float32, reflect.Float64:
actualType = "double"
case reflect.Bool:
actualType = "bool"
case reflect.Slice, reflect.Array:
actualType = "array"
case reflect.Map:
actualType = "object"
}
return actualType == typeName
}
// ========== 数学辅助函数 ==========
// RoundToPrecision 四舍五入到指定精度
func RoundToPrecision(value float64, precision int) float64 {
multiplier := math.Pow(10, float64(precision))
return math.Round(value*multiplier) / multiplier
}
// ========== 文档辅助函数 ==========
// GetFieldValue 从文档中获取字段值
func GetFieldValue(doc types.Document, field interface{}) interface{} {
switch f := field.(type) {
case string:
if strings.HasPrefix(f, "$") {
return GetNestedValue(doc.Data, f[1:])
}
return GetNestedValue(doc.Data, f)
default:
return nil
}
}
// GetFieldValueStr 从文档中获取字段值的字符串形式
func GetFieldValueStr(doc types.Document, field interface{}) string {
val := GetFieldValue(doc, field)
if str, ok := val.(string); ok {
return str
}
return toString(val)
}

View File

@ -1,111 +1,31 @@
package engine package engine
import ( // operators.go - 查询操作符实现
"reflect" // 使用 helpers.go 中的公共辅助函数
"regexp"
"strconv"
"strings"
)
// compareEq 相等比较 // compareEq 相等比较
func compareEq(a, b interface{}) bool { func compareEq(a, b interface{}) bool {
if a == nil && b == nil { return CompareEq(a, b)
return true
}
if a == nil || b == nil {
return false
}
// 对于 slice、map 等复杂类型,使用 reflect.DeepEqual
if isComplexType(a) || isComplexType(b) {
return reflect.DeepEqual(a, b)
}
// 类型转换后比较
return normalizeValue(a) == normalizeValue(b)
}
// isComplexType 检查是否是复杂类型slice、map 等)
func isComplexType(v interface{}) bool {
switch v.(type) {
case []interface{}:
return true
case map[string]interface{}:
return true
case map[interface{}]interface{}:
return true
default:
return false
}
} }
// compareGt 大于比较 // compareGt 大于比较
func compareGt(a, b interface{}) bool { func compareGt(a, b interface{}) bool {
return compareNumbers(a, b) > 0 return CompareNumbers(a, b) > 0
} }
// compareGte 大于等于比较 // compareGte 大于等于比较
func compareGte(a, b interface{}) bool { func compareGte(a, b interface{}) bool {
return compareNumbers(a, b) >= 0 return CompareNumbers(a, b) >= 0
} }
// compareLt 小于比较 // compareLt 小于比较
func compareLt(a, b interface{}) bool { func compareLt(a, b interface{}) bool {
return compareNumbers(a, b) < 0 return CompareNumbers(a, b) < 0
} }
// compareLte 小于等于比较 // compareLte 小于等于比较
func compareLte(a, b interface{}) bool { func compareLte(a, b interface{}) bool {
return compareNumbers(a, b) <= 0 return CompareNumbers(a, b) <= 0
}
// compareNumbers 比较两个数值,返回 -1/0/1
func compareNumbers(a, b interface{}) int {
numA := toFloat64(a)
numB := toFloat64(b)
if numA < numB {
return -1
} else if numA > numB {
return 1
}
return 0
}
// toFloat64 将值转换为 float64
func toFloat64(v interface{}) float64 {
switch val := v.(type) {
case int:
return float64(val)
case int8:
return float64(val)
case int16:
return float64(val)
case int32:
return float64(val)
case int64:
return float64(val)
case uint:
return float64(val)
case uint8:
return float64(val)
case uint16:
return float64(val)
case uint32:
return float64(val)
case uint64:
return float64(val)
case float32:
return float64(val)
case float64:
return val
case string:
// 尝试解析字符串为数字
if num, err := strconv.ParseFloat(val, 64); err == nil {
return num
}
}
return 0
} }
// compareIn 检查值是否在数组中 // compareIn 检查值是否在数组中
@ -115,28 +35,12 @@ func compareIn(value interface{}, operand interface{}) bool {
return false return false
} }
for _, item := range arr { return ContainsElement(arr, value)
if compareEq(value, item) {
return true
}
}
return false
} }
// compareRegex 正则表达式匹配 // compareRegex 正则表达式匹配
func compareRegex(value interface{}, operand interface{}) bool { func compareRegex(value interface{}, operand interface{}) bool {
str, ok := value.(string) return MatchRegex(value, operand)
if !ok {
return false
}
pattern, ok := operand.(string)
if !ok {
return false
}
matched, _ := regexp.MatchString(pattern, str)
return matched
} }
// compareType 类型检查 // compareType 类型检查
@ -145,30 +49,12 @@ func compareType(value interface{}, operand interface{}) bool {
return operand == "null" return operand == "null"
} }
var typeName string
switch reflect.TypeOf(value).Kind() {
case reflect.String:
typeName = "string"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
typeName = "int"
case reflect.Float32, reflect.Float64:
typeName = "double"
case reflect.Bool:
typeName = "bool"
case reflect.Slice, reflect.Array:
typeName = "array"
case reflect.Map:
typeName = "object"
}
// 支持字符串或数组形式的类型检查
switch op := operand.(type) { switch op := operand.(type) {
case string: case string:
return typeName == op return CheckType(value, op)
case []interface{}: case []interface{}:
for _, t := range op { for _, t := range op {
if ts, ok := t.(string); ok && typeName == ts { if ts, ok := t.(string); ok && CheckType(value, ts) {
return true return true
} }
} }
@ -189,20 +75,7 @@ func compareAll(value interface{}, operand interface{}) bool {
return false return false
} }
for _, req := range required { return ContainsAllElements(arr, required)
found := false
for _, item := range arr {
if compareEq(item, req) {
found = true
break
}
}
if !found {
return false
}
}
return true
} }
// compareElemMatch 数组元素匹配 // compareElemMatch 数组元素匹配
@ -252,7 +125,7 @@ func compareSize(value interface{}, operand interface{}) bool {
// compareMod 模运算value % divisor == remainder // compareMod 模运算value % divisor == remainder
func compareMod(value interface{}, operand interface{}) bool { func compareMod(value interface{}, operand interface{}) bool {
num := toFloat64(value) num := ToFloat64(value)
var divisor, remainder float64 var divisor, remainder float64
switch op := operand.(type) { switch op := operand.(type) {
@ -260,8 +133,8 @@ func compareMod(value interface{}, operand interface{}) bool {
if len(op) != 2 { if len(op) != 2 {
return false return false
} }
divisor = toFloat64(op[0]) divisor = ToFloat64(op[0])
remainder = toFloat64(op[1]) remainder = ToFloat64(op[1])
default: default:
return false return false
} }
@ -282,84 +155,28 @@ func compareMod(value interface{}, operand interface{}) bool {
// compareBitsAllClear 位运算:所有指定位都为 0 // compareBitsAllClear 位运算:所有指定位都为 0
func compareBitsAllClear(value interface{}, operand interface{}) bool { func compareBitsAllClear(value interface{}, operand interface{}) bool {
num := toInt64(value) num := ToInt64(value)
mask := toInt64(operand) mask := ToInt64(operand)
return (num & mask) == 0 return (num & mask) == 0
} }
// compareBitsAllSet 位运算:所有指定位都为 1 // compareBitsAllSet 位运算:所有指定位都为 1
func compareBitsAllSet(value interface{}, operand interface{}) bool { func compareBitsAllSet(value interface{}, operand interface{}) bool {
num := toInt64(value) num := ToInt64(value)
mask := toInt64(operand) mask := ToInt64(operand)
return (num & mask) == mask return (num & mask) == mask
} }
// compareBitsAnyClear 位运算:任意指定位为 0 // compareBitsAnyClear 位运算:任意指定位为 0
func compareBitsAnyClear(value interface{}, operand interface{}) bool { func compareBitsAnyClear(value interface{}, operand interface{}) bool {
num := toInt64(value) num := ToInt64(value)
mask := toInt64(operand) mask := ToInt64(operand)
return (num & mask) != mask return (num & mask) != mask
} }
// compareBitsAnySet 位运算:任意指定位为 1 // compareBitsAnySet 位运算:任意指定位为 1
func compareBitsAnySet(value interface{}, operand interface{}) bool { func compareBitsAnySet(value interface{}, operand interface{}) bool {
num := toInt64(value) num := ToInt64(value)
mask := toInt64(operand) mask := ToInt64(operand)
return (num & mask) != 0 return (num & mask) != 0
} }
// toInt64 将值转换为 int64
func toInt64(v interface{}) int64 {
switch val := v.(type) {
case int:
return int64(val)
case int8:
return int64(val)
case int16:
return int64(val)
case int32:
return int64(val)
case int64:
return val
case uint:
return int64(val)
case uint8:
return int64(val)
case uint16:
return int64(val)
case uint32:
return int64(val)
case uint64:
return int64(val)
case float32:
return int64(val)
case float64:
return int64(val)
case string:
if num, err := strconv.ParseInt(val, 10, 64); err == nil {
return num
}
}
return 0
}
// normalizeValue 标准化值用于比较
func normalizeValue(v interface{}) interface{} {
if v == nil {
return nil
}
// 处理数字类型
switch val := v.(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
return toFloat64(v)
case string:
// 尝试将字符串解析为数字
if num, err := strconv.ParseFloat(val, 64); err == nil {
return num
}
return strings.ToLower(val)
}
return v
}

View File

@ -1,39 +1,36 @@
package engine package engine
import ( // type_conversion.go - 类型转换操作符实现
"fmt" // 使用 helpers.go 中的公共辅助函数
"strconv"
"time"
)
// toString 转换为字符串 // toString 转换为字符串
func (e *AggregationEngine) toString(operand interface{}, data map[string]interface{}) string { func (e *AggregationEngine) toString(operand interface{}, data map[string]interface{}) string {
val := e.evaluateExpression(data, operand) val := e.evaluateExpression(data, operand)
return formatValueToString(val) return FormatValueToString(val)
} }
// toInt 转换为整数 (int32) // toInt 转换为整数 (int32)
func (e *AggregationEngine) toInt(operand interface{}, data map[string]interface{}) int32 { func (e *AggregationEngine) toInt(operand interface{}, data map[string]interface{}) int32 {
val := e.evaluateExpression(data, operand) val := e.evaluateExpression(data, operand)
return int32(toInt64(val)) return int32(ToInt64(val))
} }
// toLong 转换为长整数 (int64) // toLong 转换为长整数 (int64)
func (e *AggregationEngine) toLong(operand interface{}, data map[string]interface{}) int64 { func (e *AggregationEngine) toLong(operand interface{}, data map[string]interface{}) int64 {
val := e.evaluateExpression(data, operand) val := e.evaluateExpression(data, operand)
return toInt64(val) return ToInt64(val)
} }
// toDouble 转换为浮点数 (double) // toDouble 转换为浮点数 (double)
func (e *AggregationEngine) toDouble(operand interface{}, data map[string]interface{}) float64 { func (e *AggregationEngine) toDouble(operand interface{}, data map[string]interface{}) float64 {
val := e.evaluateExpression(data, operand) val := e.evaluateExpression(data, operand)
return toFloat64(val) return ToFloat64(val)
} }
// toBool 转换为布尔值 // toBool 转换为布尔值
func (e *AggregationEngine) toBool(operand interface{}, data map[string]interface{}) bool { func (e *AggregationEngine) toBool(operand interface{}, data map[string]interface{}) bool {
val := e.evaluateExpression(data, operand) val := e.evaluateExpression(data, operand)
return isTrueValue(val) return IsTrueValue(val)
} }
// toDocument 转换为文档(对象) // toDocument 转换为文档(对象)
@ -53,53 +50,3 @@ func (e *AggregationEngine) toDocument(operand interface{}, data map[string]inte
// 其他情况返回空对象MongoDB 行为) // 其他情况返回空对象MongoDB 行为)
return map[string]interface{}{} return map[string]interface{}{}
} }
// formatValueToString 将任意值格式化为字符串
func formatValueToString(value interface{}) string {
if value == nil {
return ""
}
switch v := value.(type) {
case string:
return v
case bool:
return strconv.FormatBool(v)
case int, int8, int16, int32, int64:
return fmt.Sprintf("%d", v)
case uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", v)
case float32:
return strconv.FormatFloat(float64(v), 'g', -1, 32)
case float64:
return strconv.FormatFloat(v, 'g', -1, 64)
case time.Time:
return v.Format(time.RFC3339)
case []interface{}:
// 数组转为 JSON 风格字符串
result := "["
for i, item := range v {
if i > 0 {
result += ","
}
result += formatValueToString(item)
}
result += "]"
return result
case map[string]interface{}:
// 对象转为 JSON 风格字符串(简化版)
result := "{"
first := true
for k, val := range v {
if !first {
result += ","
}
result += fmt.Sprintf("%s:%v", k, val)
first = false
}
result += "}"
return result
default:
return fmt.Sprintf("%v", v)
}
}