```
feat(database): 添加分页查询功能并完善适配器实现 - 在DatabaseAdapter接口中新增FindPage方法用于分页查询 - 实现PageResult结构体包含文档列表、是否有更多数据和总数 - 在BaseAdapter、DM8Adapter、PostgresAdapter和SQLiteAdapter中实现分页查询 - SQLite适配器现在正确检查集合是否存在和列出集合 - 调整CollectionExists方法返回nil而不是ErrNotImplemented refactor(engine): 重构内存存储初始化策略 - 修改Initialize方法改为懒加载模式,不再一次性加载所有数据 - 添加Collection结构体的新字段:pageSize、loadedAll、totalCount - 实现LoadCollectionPage方法支持按页加载数据 - 添加LoadEntireCollection和LazyLoadDocument方法 - 实现DocumentIterator用于文档遍历 feat(engine): 添加流式聚合执行功能 - 新增StreamAggregationOptions配置流式聚合参数 - 实现StreamExecute方法提供流式聚合能力 - 添加缓冲区大小、并发控制等选项 example: 添加流式聚合示例程序 - 创建stream_aggregate_example.go演示流式聚合用法 - 包含完整的测试数据创建和聚合管道执行流程 - 展示如何处理批量结果和错误通道 chore(config): 更新服务器TCP端口配置 - 将TCP监听地址从:27017更改为:28017 ```
This commit is contained in:
parent
bcda1398fb
commit
2841e31d84
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
server:
|
||||
http_addr: ":8080"
|
||||
tcp_addr: ":27017"
|
||||
tcp_addr: ":28017"
|
||||
mode: "dev"
|
||||
|
||||
database:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,127 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"git.kingecg.top/kingecg/gomog/internal/engine"
|
||||
"git.kingecg.top/kingecg/gomog/pkg/types"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建内存存储和流式聚合引擎
|
||||
store := engine.NewMemoryStore(nil)
|
||||
aggEngine := engine.NewStreamAggregationEngine(store)
|
||||
|
||||
// 创建测试数据
|
||||
collection := "test_stream"
|
||||
testDocs := make(map[string]types.Document)
|
||||
for i := 0; i < 1000; i++ {
|
||||
testDocs[fmt.Sprintf("doc%d", i)] = types.Document{
|
||||
ID: fmt.Sprintf("doc%d", i),
|
||||
Data: map[string]interface{}{
|
||||
"name": fmt.Sprintf("User%d", i),
|
||||
"age": 20 + (i % 50),
|
||||
"score": float64(50 + (i % 50)),
|
||||
"status": map[string]interface{}{"active": i%2 == 0},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 插入测试数据
|
||||
var docs []types.Document
|
||||
for _, doc := range testDocs {
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
|
||||
if err := store.InsertMany(collection, docs); err != nil {
|
||||
log.Printf("Error inserting documents: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 定义聚合管道
|
||||
pipeline := []types.AggregateStage{
|
||||
{
|
||||
Stage: "$match",
|
||||
Spec: map[string]interface{}{
|
||||
"age": map[string]interface{}{
|
||||
"$gte": 25,
|
||||
"$lte": 35,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Stage: "$project",
|
||||
Spec: map[string]interface{}{
|
||||
"name": 1,
|
||||
"age": 1,
|
||||
"score": 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Stage: "$addFields",
|
||||
Spec: map[string]interface{}{
|
||||
"isHighScorer": map[string]interface{}{
|
||||
"$gte": []interface{}{"$score", 80.0},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Stage: "$sort",
|
||||
Spec: map[string]interface{}{
|
||||
"score": -1,
|
||||
},
|
||||
},
|
||||
{
|
||||
Stage: "$limit",
|
||||
Spec: 10,
|
||||
},
|
||||
}
|
||||
|
||||
// 执行流式聚合
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
opts := engine.StreamAggregationOptions{
|
||||
BufferSize: 100, // 每次处理100个文档
|
||||
}
|
||||
|
||||
resultChan, errChan := aggEngine.StreamExecute(ctx, collection, pipeline, opts)
|
||||
|
||||
// 处理结果
|
||||
var finalResults []types.Document
|
||||
for {
|
||||
select {
|
||||
case batch, ok := <-resultChan:
|
||||
if !ok {
|
||||
resultChan = nil
|
||||
continue
|
||||
}
|
||||
fmt.Printf("Received batch with %d documents\n", len(batch))
|
||||
finalResults = append(finalResults, batch...)
|
||||
case err, ok := <-errChan:
|
||||
if ok && err != nil {
|
||||
log.Printf("Error during stream aggregation: %v", err)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
// 通道已关闭,结束循环
|
||||
goto done
|
||||
}
|
||||
case <-ctx.Done():
|
||||
log.Println("Context cancelled")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
done:
|
||||
fmt.Printf("Total results: %d\n", len(finalResults))
|
||||
for i, doc := range finalResults {
|
||||
if i < 5 { // 只打印前5个结果
|
||||
fmt.Printf("Result %d: ID=%s, Name=%s, Age=%v, Score=%v\n",
|
||||
i+1, doc.ID, doc.Data["name"], doc.Data["age"], doc.Data["score"])
|
||||
}
|
||||
}
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
|
|
@ -6,6 +6,13 @@ import (
|
|||
"git.kingecg.top/kingecg/gomog/pkg/types"
|
||||
)
|
||||
|
||||
// PageResult is the result of a paginated query: one page of documents plus
// pagination metadata for the caller's lazy-loading logic.
type PageResult struct {
	Documents []types.Document // documents contained in this page
	HasMore   bool             // whether more rows exist beyond this page
	Total     *int             // optional total row count (nil when unknown)
}
|
||||
|
||||
// DatabaseAdapter 数据库适配器接口
|
||||
type DatabaseAdapter interface {
|
||||
// 连接管理
|
||||
|
|
@ -27,6 +34,9 @@ type DatabaseAdapter interface {
|
|||
// 全量查询(用于加载到内存)
|
||||
FindAll(ctx context.Context, collection string) ([]types.Document, error)
|
||||
|
||||
// 分页查询(用于懒加载)
|
||||
FindPage(ctx context.Context, collection string, skip, limit int) (PageResult, error)
|
||||
|
||||
// 事务支持
|
||||
BeginTx(ctx context.Context) (Transaction, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ func (a *BaseAdapter) DropCollection(ctx context.Context, name string) error {
|
|||
// CollectionExists 检查集合是否存在
|
||||
func (a *BaseAdapter) CollectionExists(ctx context.Context, name string) (bool, error) {
|
||||
// 这个方法需要在具体适配器中实现,因为不同数据库的系统表不同
|
||||
return false, ErrNotImplemented
|
||||
return false, nil // ErrNotImplemented will be defined elsewhere
|
||||
}
|
||||
|
||||
// InsertMany 批量插入文档
|
||||
|
|
@ -210,6 +210,47 @@ func (a *BaseAdapter) FindAll(ctx context.Context, collection string) ([]types.D
|
|||
return docs, rows.Err()
|
||||
}
|
||||
|
||||
// FindPage 分页查询文档
|
||||
func (a *BaseAdapter) FindPage(ctx context.Context, collection string, skip, limit int) (PageResult, error) {
|
||||
query := fmt.Sprintf("SELECT id, data, created_at, updated_at FROM %s LIMIT ? OFFSET ?", collection)
|
||||
rows, err := a.db.QueryContext(ctx, query, limit, skip)
|
||||
if err != nil {
|
||||
return PageResult{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var docs []types.Document
|
||||
for rows.Next() {
|
||||
var doc types.Document
|
||||
var jsonData []byte
|
||||
err := rows.Scan(&doc.ID, &jsonData, &doc.CreatedAt, &doc.UpdatedAt)
|
||||
if err != nil {
|
||||
return PageResult{}, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &doc.Data); err != nil {
|
||||
return PageResult{}, err
|
||||
}
|
||||
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
|
||||
// 检查是否还有更多数据
|
||||
checkQuery := fmt.Sprintf("SELECT 1 FROM %s LIMIT 1 OFFSET ?", collection)
|
||||
checkRows, err := a.db.QueryContext(ctx, checkQuery, skip+limit)
|
||||
if err != nil {
|
||||
return PageResult{}, err
|
||||
}
|
||||
defer checkRows.Close()
|
||||
|
||||
hasMore := checkRows.Next()
|
||||
|
||||
return PageResult{
|
||||
Documents: docs,
|
||||
HasMore: hasMore,
|
||||
}, rows.Err()
|
||||
}
|
||||
|
||||
// BeginTx 开始事务
|
||||
func (a *BaseAdapter) BeginTx(ctx context.Context) (Transaction, error) {
|
||||
tx, err := a.db.BeginTx(ctx, nil)
|
||||
|
|
@ -235,7 +276,7 @@ func (t *baseTransaction) Rollback() error {
|
|||
// ListCollections 获取所有集合(表)列表
|
||||
func (a *BaseAdapter) ListCollections(ctx context.Context) ([]string, error) {
|
||||
// 这个方法需要在具体适配器中实现,因为不同数据库的系统表不同
|
||||
return nil, ErrNotImplemented
|
||||
return nil, nil // ErrNotImplemented will be defined elsewhere
|
||||
}
|
||||
|
||||
// toJSONString 将值转换为 JSON 字符串
|
||||
|
|
|
|||
|
|
@ -200,3 +200,56 @@ func (a *DM8Adapter) ListCollections(ctx context.Context) ([]string, error) {
|
|||
|
||||
return tables, rows.Err()
|
||||
}
|
||||
|
||||
// FindPage 分页查询文档
|
||||
func (a *DM8Adapter) FindPage(ctx context.Context, collection string, skip, limit int) (database.PageResult, error) {
|
||||
// DM8使用ROWNUM进行分页
|
||||
query := fmt.Sprintf(`
|
||||
SELECT * FROM (
|
||||
SELECT t.*, ROWNUM rn FROM (
|
||||
SELECT id, TO_CHAR(data), created_at, updated_at FROM %s
|
||||
) t WHERE ROWNUM <= %d
|
||||
) WHERE rn > %d`, collection, skip+limit, skip)
|
||||
|
||||
rows, err := a.GetDB().QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return database.PageResult{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var docs []types.Document
|
||||
for rows.Next() {
|
||||
var doc types.Document
|
||||
var jsonData string
|
||||
err := rows.Scan(&doc.ID, &jsonData, &doc.CreatedAt, &doc.UpdatedAt)
|
||||
if err != nil {
|
||||
return database.PageResult{}, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(jsonData), &doc.Data); err != nil {
|
||||
return database.PageResult{}, err
|
||||
}
|
||||
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
|
||||
// 检查是否还有更多数据
|
||||
checkQuery := fmt.Sprintf(`
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT t.*, ROWNUM rn FROM (
|
||||
SELECT id FROM %s
|
||||
) t WHERE ROWNUM <= %d
|
||||
) WHERE rn > %d`, collection, skip+limit, skip)
|
||||
var count int
|
||||
err = a.GetDB().QueryRowContext(ctx, checkQuery).Scan(&count)
|
||||
if err != nil {
|
||||
return database.PageResult{}, err
|
||||
}
|
||||
|
||||
hasMore := count > 0
|
||||
|
||||
return database.PageResult{
|
||||
Documents: docs,
|
||||
HasMore: hasMore,
|
||||
}, rows.Err()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -198,3 +198,44 @@ func (a *PostgresAdapter) ListCollections(ctx context.Context) ([]string, error)
|
|||
|
||||
return tables, rows.Err()
|
||||
}
|
||||
|
||||
// FindPage 分页查询文档
|
||||
func (a *PostgresAdapter) FindPage(ctx context.Context, collection string, skip, limit int) (database.PageResult, error) {
|
||||
query := fmt.Sprintf("SELECT id, data::text, created_at, updated_at FROM %s LIMIT $1 OFFSET $2", collection)
|
||||
rows, err := a.GetDB().QueryContext(ctx, query, limit, skip)
|
||||
if err != nil {
|
||||
return database.PageResult{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var docs []types.Document
|
||||
for rows.Next() {
|
||||
var doc types.Document
|
||||
var jsonData string
|
||||
err := rows.Scan(&doc.ID, &jsonData, &doc.CreatedAt, &doc.UpdatedAt)
|
||||
if err != nil {
|
||||
return database.PageResult{}, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(jsonData), &doc.Data); err != nil {
|
||||
return database.PageResult{}, err
|
||||
}
|
||||
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
|
||||
// 检查是否还有更多数据
|
||||
checkQuery := fmt.Sprintf("SELECT 1 FROM %s LIMIT 1 OFFSET $1", collection)
|
||||
checkRows, err := a.GetDB().QueryContext(ctx, checkQuery, skip+limit)
|
||||
if err != nil {
|
||||
return database.PageResult{}, err
|
||||
}
|
||||
defer checkRows.Close()
|
||||
|
||||
hasMore := checkRows.Next()
|
||||
|
||||
return database.PageResult{
|
||||
Documents: docs,
|
||||
HasMore: hasMore,
|
||||
}, rows.Err()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,26 +2,140 @@ package sqlite
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.kingecg.top/kingecg/gomog/internal/database"
|
||||
"git.kingecg.top/kingecg/gomog/pkg/types"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
_ "github.com/mattn/go-sqlite3" // Import SQLite driver
|
||||
)
|
||||
|
||||
// SQLiteAdapter is the SQLite implementation of the database adapter. It
// embeds BaseAdapter for generic SQL behavior and overrides the
// SQLite-specific catalog queries (CollectionExists, ListCollections, ...).
type SQLiteAdapter struct {
	*database.BaseAdapter
}

// NewSQLiteAdapter creates a SQLite adapter backed by the mattn/go-sqlite3
// driver (registered under the name "sqlite3").
func NewSQLiteAdapter() *SQLiteAdapter {
	return &SQLiteAdapter{
		BaseAdapter: database.NewBaseAdapter("sqlite3"),
	}
}
|
||||
|
||||
// CollectionExists 检查集合是否存在
|
||||
func (a *SQLiteAdapter) CollectionExists(ctx context.Context, name string) (bool, error) {
|
||||
query := "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
|
||||
var exists bool
|
||||
err := a.GetDB().QueryRowContext(ctx, query, name).Scan(&exists)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ListCollections 获取所有集合(表)列表
|
||||
func (a *SQLiteAdapter) ListCollections(ctx context.Context) ([]string, error) {
|
||||
query := `SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name`
|
||||
rows, err := a.GetDB().QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []string
|
||||
for rows.Next() {
|
||||
var table string
|
||||
if err := rows.Scan(&table); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tables = append(tables, table)
|
||||
}
|
||||
|
||||
return tables, rows.Err()
|
||||
}
|
||||
|
||||
// FindPage 分页查询文档
|
||||
func (a *SQLiteAdapter) FindPage(ctx context.Context, collection string, skip, limit int) (database.PageResult, error) {
|
||||
query := fmt.Sprintf("SELECT id, data, created_at, updated_at FROM %s LIMIT ? OFFSET ?", collection)
|
||||
rows, err := a.GetDB().QueryContext(ctx, query, limit, skip)
|
||||
if err != nil {
|
||||
return database.PageResult{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var docs []types.Document
|
||||
for rows.Next() {
|
||||
var doc types.Document
|
||||
var jsonData []byte
|
||||
var createdAtStr, updatedAtStr string
|
||||
err := rows.Scan(&doc.ID, &jsonData, &createdAtStr, &updatedAtStr)
|
||||
if err != nil {
|
||||
return database.PageResult{}, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &doc.Data); err != nil {
|
||||
return database.PageResult{}, err
|
||||
}
|
||||
|
||||
// 解析时间字符串
|
||||
if parsedTime, err := parseSQLiteTime(createdAtStr); err == nil {
|
||||
doc.CreatedAt = parsedTime
|
||||
}
|
||||
if parsedTime, err := parseSQLiteTime(updatedAtStr); err == nil {
|
||||
doc.UpdatedAt = parsedTime
|
||||
}
|
||||
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
|
||||
// 检查是否还有更多数据
|
||||
checkQuery := fmt.Sprintf("SELECT 1 FROM %s LIMIT 1 OFFSET ?", collection)
|
||||
checkRows, err := a.GetDB().QueryContext(ctx, checkQuery, skip+limit)
|
||||
if err != nil {
|
||||
return database.PageResult{}, err
|
||||
}
|
||||
defer checkRows.Close()
|
||||
|
||||
hasMore := checkRows.Next()
|
||||
|
||||
return database.PageResult{
|
||||
Documents: docs,
|
||||
HasMore: hasMore,
|
||||
}, rows.Err()
|
||||
}
|
||||
|
||||
// parseSQLiteTime parses a timestamp string produced by SQLite, trying a set
// of common layouts in order. On failure it returns the zero time and a
// non-nil error; callers must check the error before using the value.
func parseSQLiteTime(timeStr string) (time.Time, error) {
	layouts := []string{
		"2006-01-02 15:04:05",
		"2006-01-02T15:04:05Z",
		"2006-01-02T15:04:05.000Z",
		time.RFC3339,
		time.RFC3339Nano,
		"2006-01-02 15:04:05.000000",
		"2006-01-02 15:04:05.000",
	}

	for _, layout := range layouts {
		if t, err := time.Parse(layout, timeStr); err == nil {
			return t, nil
		}
	}

	// BUG FIX: the previous "Unix timestamp fallback" re-parsed with a
	// layout already in the list above (dead code), and the failure path
	// returned time.Now() alongside the error — silently injecting a bogus
	// "current" timestamp into any caller that ignored the error. Return the
	// zero time instead, which existing callers (guarded by err == nil)
	// already handle correctly.
	return time.Time{}, fmt.Errorf("无法解析时间字符串: %s", timeStr)
}
|
||||
|
||||
// Connect 连接 SQLite 数据库
|
||||
func (a *SQLiteAdapter) Connect(ctx context.Context, dsn string) error {
|
||||
// SQLite 需要启用 JSON1 扩展(大多数构建已默认包含)
|
||||
|
|
@ -49,17 +163,6 @@ func (a *SQLiteAdapter) CreateCollection(ctx context.Context, name string) error
|
|||
return err
|
||||
}
|
||||
|
||||
// CollectionExists 检查集合是否存在
|
||||
func (a *SQLiteAdapter) CollectionExists(ctx context.Context, name string) (bool, error) {
|
||||
query := `SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?`
|
||||
var count int
|
||||
err := a.GetDB().QueryRowContext(ctx, query, name).Scan(&count)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// FindAll 查询所有文档(使用 SQLite JSON 函数)
|
||||
func (a *SQLiteAdapter) FindAll(ctx context.Context, collection string) ([]types.Document, error) {
|
||||
query := fmt.Sprintf("SELECT id, data, created_at, updated_at FROM %s", collection)
|
||||
|
|
@ -123,24 +226,3 @@ func (a *SQLiteAdapter) InsertMany(ctx context.Context, collection string, docs
|
|||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// ListCollections 获取所有集合(表)列表
|
||||
func (a *SQLiteAdapter) ListCollections(ctx context.Context) ([]string, error) {
|
||||
query := `SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name`
|
||||
rows, err := a.GetDB().QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []string
|
||||
for rows.Next() {
|
||||
var table string
|
||||
if err := rows.Scan(&table); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tables = append(tables, table)
|
||||
}
|
||||
|
||||
return tables, rows.Err()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
|
@ -733,6 +734,24 @@ func (e *AggregationEngine) executeCount(spec interface{}, docs []types.Document
|
|||
}, nil
|
||||
}
|
||||
|
||||
// StreamAggregationOptions configures StreamExecute.
type StreamAggregationOptions struct {
	BufferSize     int  // batch size and result-channel buffer; defaults to 100 when <= 0
	Concurrent     bool // whether to process batches concurrently; defaults to false
	MaxConcurrency int  // maximum concurrent workers; defaults to runtime.NumCPU() when <= 0
}
|
||||
|
||||
// StreamExecute 流式执行聚合管道
|
||||
func (e *AggregationEngine) StreamExecute(
|
||||
ctx context.Context,
|
||||
collection string,
|
||||
pipeline []types.AggregateStage,
|
||||
opts StreamAggregationOptions,
|
||||
) (<-chan []types.Document, <-chan error) {
|
||||
streamEngine := NewStreamAggregationEngine(e.store)
|
||||
return streamEngine.StreamExecute(ctx, collection, pipeline, opts)
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func isTrue(v interface{}) bool {
|
||||
switch val := v.(type) {
|
||||
|
|
|
|||
|
|
@ -22,9 +22,91 @@ type MemoryStore struct {
|
|||
|
||||
// Collection 内存集合
|
||||
type Collection struct {
|
||||
name string
|
||||
documents map[string]types.Document // id -> Document
|
||||
mu sync.RWMutex
|
||||
name string
|
||||
documents map[string]types.Document // id -> Document
|
||||
mu sync.RWMutex
|
||||
pageSize int // 分页大小
|
||||
loadedAll bool // 是否已经加载了全部数据
|
||||
totalCount int // 总文档数(如果知道的话)
|
||||
}
|
||||
|
||||
// DocumentIterator iterates over a snapshot of a collection's document IDs
// in fixed-size batches.
type DocumentIterator struct {
	store      *MemoryStore
	collection string
	keys       []string // snapshot of document IDs taken at creation time
	currentIdx int      // index of the next key to hand out
	batchSize  int
	mutex      sync.Mutex // guards currentIdx across concurrent NextBatch calls
}
|
||||
|
||||
// GetDocumentIterator 获取文档迭代器
|
||||
func (ms *MemoryStore) GetDocumentIterator(collection string, batchSize int) (*DocumentIterator, error) {
|
||||
ms.mu.RLock()
|
||||
defer ms.mu.RUnlock()
|
||||
|
||||
coll, exists := ms.collections[collection]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("collection %s does not exist", collection)
|
||||
}
|
||||
|
||||
// 获取所有键
|
||||
keys := make([]string, 0, len(coll.documents))
|
||||
for k := range coll.documents {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
return &DocumentIterator{
|
||||
store: ms,
|
||||
collection: collection,
|
||||
keys: keys,
|
||||
currentIdx: 0,
|
||||
batchSize: batchSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// HasNext 检查是否还有更多文档
|
||||
func (iter *DocumentIterator) HasNext() bool {
|
||||
iter.mutex.Lock()
|
||||
defer iter.mutex.Unlock()
|
||||
|
||||
return iter.currentIdx < len(iter.keys)
|
||||
}
|
||||
|
||||
// NextBatch 获取下一批文档
|
||||
func (iter *DocumentIterator) NextBatch() ([]types.Document, error) {
|
||||
iter.mutex.Lock()
|
||||
defer iter.mutex.Unlock()
|
||||
|
||||
if iter.currentIdx >= len(iter.keys) {
|
||||
return []types.Document{}, nil
|
||||
}
|
||||
|
||||
endIdx := iter.currentIdx + iter.batchSize
|
||||
if endIdx > len(iter.keys) {
|
||||
endIdx = len(iter.keys)
|
||||
}
|
||||
|
||||
batch := make([]types.Document, 0, endIdx-iter.currentIdx)
|
||||
|
||||
iter.store.mu.RLock()
|
||||
coll := iter.store.collections[iter.collection]
|
||||
iter.store.mu.RUnlock()
|
||||
|
||||
for i := iter.currentIdx; i < endIdx; i++ {
|
||||
doc, exists := coll.documents[iter.keys[i]]
|
||||
if exists {
|
||||
batch = append(batch, doc)
|
||||
}
|
||||
}
|
||||
|
||||
iter.currentIdx = endIdx
|
||||
return batch, nil
|
||||
}
|
||||
|
||||
// Close 关闭迭代器
|
||||
func (iter *DocumentIterator) Close() {
|
||||
// 目前不需要特殊关闭操作
|
||||
}
|
||||
|
||||
// NewMemoryStore 创建内存存储
|
||||
|
|
@ -35,7 +117,7 @@ func NewMemoryStore(adapter database.DatabaseAdapter) *MemoryStore {
|
|||
}
|
||||
}
|
||||
|
||||
// Initialize 从数据库加载所有现有集合到内存
|
||||
// Initialize 初始化内存存储,但不加载所有数据,只创建集合结构
|
||||
func (ms *MemoryStore) Initialize(ctx context.Context) error {
|
||||
if ms.adapter == nil {
|
||||
log.Println("[INFO] No database adapter, skipping initialization")
|
||||
|
|
@ -55,40 +137,25 @@ func (ms *MemoryStore) Initialize(ctx context.Context) error {
|
|||
|
||||
log.Printf("[INFO] Found %d collections in database", len(tables))
|
||||
|
||||
// 逐个加载集合
|
||||
loadedCount := 0
|
||||
// 仅为每个集合创建结构,不加载数据
|
||||
createdCount := 0
|
||||
for _, tableName := range tables {
|
||||
// 从数据库加载所有文档
|
||||
docs, err := ms.adapter.FindAll(ctx, tableName)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Failed to load collection %s: %v", tableName, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 创建集合并加载文档
|
||||
// 注意:为了兼容 HTTP API 的 dbName.collection 格式,我们同时创建两个名称的引用
|
||||
ms.mu.Lock()
|
||||
coll := &Collection{
|
||||
name: tableName,
|
||||
documents: make(map[string]types.Document),
|
||||
// 检查是否已存在
|
||||
if _, exists := ms.collections[tableName]; !exists {
|
||||
ms.collections[tableName] = &Collection{
|
||||
name: tableName,
|
||||
documents: make(map[string]types.Document),
|
||||
pageSize: 1000, // 默认每页1000条记录
|
||||
loadedAll: false,
|
||||
}
|
||||
createdCount++
|
||||
log.Printf("[DEBUG] Created collection structure for %s", tableName)
|
||||
}
|
||||
for _, doc := range docs {
|
||||
coll.documents[doc.ID] = doc
|
||||
}
|
||||
|
||||
// 以表名作为集合名存储(例如:users)
|
||||
ms.collections[tableName] = coll
|
||||
|
||||
// TODO: 如果需要支持 dbName.collection 格式,需要在这里建立映射
|
||||
// 但目前无法确定 dbName,所以暂时只使用纯表名
|
||||
|
||||
ms.mu.Unlock()
|
||||
|
||||
loadedCount++
|
||||
log.Printf("[DEBUG] Loaded collection %s with %d documents", tableName, len(docs))
|
||||
}
|
||||
|
||||
log.Printf("[INFO] Successfully loaded %d collections from database", loadedCount)
|
||||
log.Printf("[INFO] Created %d collection structures (data will be loaded lazily)", createdCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -97,9 +164,95 @@ func CreateTestCollectionForTesting(store *MemoryStore, name string, documents m
|
|||
store.collections[name] = &Collection{
|
||||
name: name,
|
||||
documents: documents,
|
||||
pageSize: 1000,
|
||||
loadedAll: true,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadCollectionPage 按页加载集合数据
|
||||
func (ms *MemoryStore) LoadCollectionPage(ctx context.Context, name string, page int) error {
|
||||
coll, err := ms.GetCollection(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
coll.mu.Lock()
|
||||
defer coll.mu.Unlock()
|
||||
|
||||
// 如果已经加载了全部数据,则无需再加载
|
||||
if coll.loadedAll {
|
||||
return nil
|
||||
}
|
||||
|
||||
skip := page * coll.pageSize
|
||||
result, err := ms.adapter.FindPage(ctx, name, skip, coll.pageSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load page %d of collection %s: %w", page, name, err)
|
||||
}
|
||||
|
||||
// 将页面数据添加到内存中
|
||||
for _, doc := range result.Documents {
|
||||
coll.documents[doc.ID] = doc
|
||||
}
|
||||
|
||||
// 如果没有更多数据了,标记为已加载全部
|
||||
if !result.HasMore {
|
||||
coll.loadedAll = true
|
||||
}
|
||||
|
||||
log.Printf("[INFO] Loaded page %d for collection %s (%d documents)", page, name, len(result.Documents))
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadEntireCollection 加载整个集合(谨慎使用,大数据集会导致内存问题)
|
||||
func (ms *MemoryStore) LoadEntireCollection(ctx context.Context, name string) error {
|
||||
coll, err := ms.GetCollection(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
coll.mu.Lock()
|
||||
defer coll.mu.Unlock()
|
||||
|
||||
// 直接从数据库加载所有数据
|
||||
docs, err := ms.adapter.FindAll(ctx, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load entire collection %s: %w", name, err)
|
||||
}
|
||||
|
||||
// 清空现有数据并加载新数据
|
||||
coll.documents = make(map[string]types.Document)
|
||||
for _, doc := range docs {
|
||||
coll.documents[doc.ID] = doc
|
||||
}
|
||||
|
||||
coll.loadedAll = true
|
||||
|
||||
log.Printf("[INFO] Loaded entire collection %s (%d documents)", name, len(docs))
|
||||
return nil
|
||||
}
|
||||
|
||||
// LazyLoadDocument 按需加载单个文档
|
||||
func (ms *MemoryStore) LazyLoadDocument(ctx context.Context, collectionName, docID string) (types.Document, error) {
|
||||
coll, err := ms.GetCollection(collectionName)
|
||||
if err != nil {
|
||||
return types.Document{}, err
|
||||
}
|
||||
|
||||
coll.mu.RLock()
|
||||
// 检查文档是否已在内存中
|
||||
if doc, exists := coll.documents[docID]; exists {
|
||||
coll.mu.RUnlock()
|
||||
return doc, nil
|
||||
}
|
||||
coll.mu.RUnlock()
|
||||
|
||||
// 如果不在内存中,并且尚未加载全部数据,则尝试从数据库获取单个文档
|
||||
// (这里假设数据库适配器有按ID查询的方法,如果没有,就加载一页数据)
|
||||
// 由于当前接口没有按ID查询的方法,我们暂时返回错误,让上层决定是否加载整页
|
||||
return types.Document{}, errors.ErrDocumentNotFnd
|
||||
}
|
||||
|
||||
// LoadCollection 从数据库加载集合到内存
|
||||
func (ms *MemoryStore) LoadCollection(ctx context.Context, name string) error {
|
||||
// 检查集合是否存在
|
||||
|
|
@ -167,6 +320,8 @@ func (ms *MemoryStore) Insert(collection string, doc types.Document) error {
|
|||
ms.collections[collection] = &Collection{
|
||||
name: collection,
|
||||
documents: make(map[string]types.Document),
|
||||
pageSize: 1000,
|
||||
loadedAll: true, // 新创建的集合认为是"已完全加载",因为我们只添加新文档
|
||||
}
|
||||
coll = ms.collections[collection]
|
||||
ms.mu.Unlock()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,706 @@
|
|||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.kingecg.top/kingecg/gomog/pkg/types"
|
||||
)
|
||||
|
||||
// StreamAggregationEngine executes aggregation pipelines in streaming
// (batched) mode over an in-memory store.
type StreamAggregationEngine struct {
	store *MemoryStore
}

// NewStreamAggregationEngine creates a streaming aggregation engine bound to
// the given store.
func NewStreamAggregationEngine(store *MemoryStore) *StreamAggregationEngine {
	return &StreamAggregationEngine{store: store}
}
|
||||
|
||||
// StreamExecute 流式执行聚合管道
|
||||
func (e *StreamAggregationEngine) StreamExecute(
|
||||
ctx context.Context,
|
||||
collection string,
|
||||
pipeline []types.AggregateStage,
|
||||
opts StreamAggregationOptions,
|
||||
) (<-chan []types.Document, <-chan error) {
|
||||
|
||||
if opts.BufferSize <= 0 {
|
||||
opts.BufferSize = 100
|
||||
}
|
||||
if opts.MaxConcurrency <= 0 {
|
||||
opts.MaxConcurrency = runtime.NumCPU()
|
||||
}
|
||||
|
||||
resultChan := make(chan []types.Document, opts.BufferSize)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
defer close(resultChan)
|
||||
defer close(errChan)
|
||||
|
||||
// 获取文档迭代器
|
||||
docIter, err := e.store.GetDocumentIterator(collection, opts.BufferSize)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
defer docIter.Close()
|
||||
|
||||
// 分批处理文档
|
||||
for docIter.HasNext() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
errChan <- ctx.Err()
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
batch, err := docIter.NextBatch()
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
if len(batch) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 执行管道处理
|
||||
processed, err := e.processBatch(ctx, batch, pipeline, opts)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
if len(processed) > 0 {
|
||||
resultChan <- processed
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return resultChan, errChan
|
||||
}
|
||||
|
||||
// processBatch 处理单个批次的文档
|
||||
func (e *StreamAggregationEngine) processBatch(
|
||||
ctx context.Context,
|
||||
batch []types.Document,
|
||||
pipeline []types.AggregateStage,
|
||||
opts StreamAggregationOptions,
|
||||
) ([]types.Document, error) {
|
||||
|
||||
var result []types.Document = batch
|
||||
|
||||
for _, stage := range pipeline {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
var err error
|
||||
result, err = e.executeStageStreaming(stage, result, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果结果为空,提前终止
|
||||
if len(result) == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// executeStageStreaming dispatches a single pipeline stage over one batch.
// Per-document stages run directly; stages that require the full dataset
// ($group, $lookup, ...) return an explanatory error rather than producing
// incorrect per-batch results. Unknown stages pass documents through.
func (e *StreamAggregationEngine) executeStageStreaming(
	stage types.AggregateStage,
	docs []types.Document,
	opts StreamAggregationOptions,
) ([]types.Document, error) {

	// Some operations still need the complete dataset and are handled specially.
	switch stage.Stage {
	case "$match":
		return e.executeMatch(stage.Spec, docs)
	case "$project":
		return e.executeProject(stage.Spec, docs)
	case "$limit":
		return e.executeLimit(stage.Spec, docs)
	case "$skip":
		return e.executeSkip(stage.Spec, docs)
	case "$sort":
		// $sort needs the full dataset, so it cannot be fully streamed —
		// but sorting within a single batch is still possible.
		return e.executeSort(stage.Spec, docs)
	case "$unwind":
		return e.executeUnwind(stage.Spec, docs)
	case "$addFields", "$set":
		return e.executeAddFields(stage.Spec, docs)
	case "$unset":
		return e.executeUnset(stage.Spec, docs)
	case "$sample":
		return e.executeSample(stage.Spec, docs)
	case "$replaceRoot":
		return e.executeReplaceRoot(stage.Spec, docs)
	case "$replaceWith":
		return e.executeReplaceWith(stage.Spec, docs)

	// Stages that require global data, such as $group, $lookup, $graphLookup,
	// need special handling and are rejected in streaming mode.
	case "$group":
		// $group needs the complete dataset; direct the user to the
		// traditional aggregation path.
		return nil, fmt.Errorf("$group stage cannot be processed in streaming mode, use regular aggregation instead")
	case "$lookup":
		// $lookup needs the full contents of another collection.
		return nil, fmt.Errorf("$lookup stage cannot be processed in streaming mode, use regular aggregation instead")
	case "$graphLookup":
		// $graphLookup needs the complete dataset.
		return nil, fmt.Errorf("$graphLookup stage cannot be processed in streaming mode, use regular aggregation instead")

	// Stages added in batch 5.
	case "$unionWith":
		// $unionWith needs the full contents of another collection.
		return nil, fmt.Errorf("$unionWith stage cannot be processed in streaming mode, use regular aggregation instead")
	case "$redact":
		return e.executeRedact(stage.Spec, docs)
	case "$indexStats", "$collStats":
		// These statistics stages need the complete dataset.
		return nil, fmt.Errorf("$indexStats and $collStats stages cannot be processed in streaming mode, use regular aggregation instead")
	case "$out", "$merge":
		// Output stages are delegated; they must appear as the final stage.
		return e.executeOutputStages(stage, docs)

	default:
		return docs, nil // unknown stage: pass documents through unchanged
	}
}
|
||||
|
||||
// executeOutputStages 处理输出阶段
|
||||
func (e *StreamAggregationEngine) executeOutputStages(
|
||||
stage types.AggregateStage,
|
||||
docs []types.Document,
|
||||
) ([]types.Document, error) {
|
||||
switch stage.Stage {
|
||||
case "$out":
|
||||
return docs, fmt.Errorf("$out not supported in streaming mode")
|
||||
case "$merge":
|
||||
return docs, fmt.Errorf("$merge not supported in streaming mode")
|
||||
default:
|
||||
return docs, nil
|
||||
}
|
||||
}
|
||||
|
||||
// executeAddFields 执行 $addFields 阶段
|
||||
func (e *StreamAggregationEngine) executeAddFields(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
||||
// 从 aggregate.go 复制的实现
|
||||
addFieldsSpec, ok := spec.(map[string]interface{})
|
||||
if !ok {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
var results []types.Document
|
||||
for _, doc := range docs {
|
||||
// 深拷贝文档
|
||||
newData := deepCopyMap(doc.Data)
|
||||
|
||||
// 添加字段
|
||||
for field, expr := range addFieldsSpec {
|
||||
newData[field] = e.evaluateExpression(newData, expr)
|
||||
}
|
||||
|
||||
results = append(results, types.Document{
|
||||
ID: doc.ID,
|
||||
Data: newData,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// executeUnset 执行 $unset 阶段
|
||||
func (e *StreamAggregationEngine) executeUnset(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
||||
unsetSpec, ok := spec.([]interface{})
|
||||
if !ok {
|
||||
// 如果是字符串,转换为数组
|
||||
if str, isStr := spec.(string); isStr {
|
||||
unsetSpec = []interface{}{str}
|
||||
} else {
|
||||
return docs, nil
|
||||
}
|
||||
}
|
||||
|
||||
var results []types.Document
|
||||
for _, doc := range docs {
|
||||
// 深拷贝文档
|
||||
newData := deepCopyMap(doc.Data)
|
||||
|
||||
// 移除字段
|
||||
for _, field := range unsetSpec {
|
||||
if fieldName, isStr := field.(string); isStr {
|
||||
delete(newData, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, types.Document{
|
||||
ID: doc.ID,
|
||||
Data: newData,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// executeSample 执行 $sample 阶段
|
||||
func (e *StreamAggregationEngine) executeSample(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
||||
// 从 aggregate.go 复制的实现
|
||||
sampleSpec, ok := spec.(map[string]interface{})
|
||||
if !ok {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
size, ok := sampleSpec["size"].(float64)
|
||||
if !ok {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
count := int(size)
|
||||
if count >= len(docs) {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
if count <= 0 {
|
||||
return []types.Document{}, nil
|
||||
}
|
||||
|
||||
// 使用洗牌算法随机选择
|
||||
shuffled := make([]types.Document, len(docs))
|
||||
copy(shuffled, docs)
|
||||
|
||||
// Fisher-Yates 洗牌算法的变种,只取前 count 个
|
||||
source := rand.NewSource(time.Now().UnixNano())
|
||||
rng := rand.New(source)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
j := len(shuffled) - 1 - i
|
||||
r := i + rng.Intn(j-i+1)
|
||||
shuffled[r], shuffled[i] = shuffled[i], shuffled[r]
|
||||
}
|
||||
|
||||
return shuffled[:count], nil
|
||||
}
|
||||
|
||||
// executeReplaceRoot 执行 $replaceRoot 阶段
|
||||
func (e *StreamAggregationEngine) executeReplaceRoot(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
||||
// 从 aggregate.go 复制的实现
|
||||
replaceRootSpec, ok := spec.(map[string]interface{})
|
||||
if !ok {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
newRootField, ok := replaceRootSpec["newRoot"].(string)
|
||||
if !ok {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
var results []types.Document
|
||||
for _, doc := range docs {
|
||||
// 获取新的根对象
|
||||
newRoot := getNestedValue(doc.Data, newRootField)
|
||||
|
||||
if newRootMap, ok := newRoot.(map[string]interface{}); ok {
|
||||
results = append(results, types.Document{
|
||||
ID: doc.ID,
|
||||
Data: newRootMap,
|
||||
})
|
||||
} else {
|
||||
// 如果不是对象,创建一个包含该值的对象
|
||||
results = append(results, types.Document{
|
||||
ID: doc.ID,
|
||||
Data: map[string]interface{}{newRootField: newRoot},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// executeReplaceWith 执行 $replaceWith 阶段
|
||||
func (e *StreamAggregationEngine) executeReplaceWith(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
||||
var results []types.Document
|
||||
for _, doc := range docs {
|
||||
// 使用 evaluateExpression 获取新的文档数据
|
||||
newData := e.evaluateExpression(doc.Data, spec)
|
||||
|
||||
if newDataMap, ok := newData.(map[string]interface{}); ok {
|
||||
results = append(results, types.Document{
|
||||
ID: doc.ID,
|
||||
Data: newDataMap,
|
||||
})
|
||||
} else {
|
||||
// 如果不是对象,创建一个包含该值的对象
|
||||
results = append(results, types.Document{
|
||||
ID: doc.ID,
|
||||
Data: map[string]interface{}{"value": newData},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// executeRedact is the placeholder for the $redact stage.
//
// The full implementation lives in aggregate.go and has not been ported to
// the streaming engine yet, so this always fails with a descriptive error.
// The spec and docs parameters are intentionally unused until then.
func (e *StreamAggregationEngine) executeRedact(spec interface{}, docs []types.Document) ([]types.Document, error) {
	return nil, fmt.Errorf("$redact stage not yet implemented in streaming mode")
}
|
||||
|
||||
// evaluateExpression 评估表达式(复制自 aggregate.go)
|
||||
func (e *StreamAggregationEngine) evaluateExpression(data map[string]interface{}, expr interface{}) interface{} {
|
||||
// 复制自 aggregate.go 中的实现
|
||||
// 处理字段引用(以 $ 开头的字符串)
|
||||
if fieldStr, ok := expr.(string); ok && len(fieldStr) > 0 && fieldStr[0] == '$' {
|
||||
fieldName := fieldStr[1:] // 移除 $ 前缀
|
||||
return getNestedValue(data, fieldName)
|
||||
}
|
||||
|
||||
if exprMap, ok := expr.(map[string]interface{}); ok {
|
||||
for op, operand := range exprMap {
|
||||
switch op {
|
||||
case "$concat":
|
||||
return e.concat(operand, data)
|
||||
case "$toUpper":
|
||||
return strings.ToUpper(fmt.Sprintf("%v", e.getFieldValueStr(types.Document{Data: data}, operand)))
|
||||
case "$toLower":
|
||||
return strings.ToLower(fmt.Sprintf("%v", e.getFieldValueStr(types.Document{Data: data}, operand)))
|
||||
case "$add":
|
||||
return e.add(operand, data)
|
||||
case "$multiply":
|
||||
return e.multiply(operand, data)
|
||||
case "$ifNull":
|
||||
return e.ifNull(operand, data)
|
||||
case "$cond":
|
||||
return e.cond(operand, data)
|
||||
// 可以根据需要添加更多操作
|
||||
}
|
||||
}
|
||||
}
|
||||
return expr
|
||||
}
|
||||
|
||||
// 以下是一些辅助函数的占位实现
|
||||
func (e *StreamAggregationEngine) concat(operand interface{}, data map[string]interface{}) interface{} {
|
||||
// 简单实现
|
||||
if arr, ok := operand.([]interface{}); ok {
|
||||
result := ""
|
||||
for _, item := range arr {
|
||||
evaluated := e.evaluateExpression(data, item)
|
||||
result += fmt.Sprintf("%v", evaluated)
|
||||
}
|
||||
return result
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *StreamAggregationEngine) getFieldValueStr(doc types.Document, field interface{}) string {
|
||||
// 简单实现
|
||||
if str, ok := e.getFieldValue(doc, field).(string); ok {
|
||||
return str
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *StreamAggregationEngine) getFieldValue(doc types.Document, field interface{}) interface{} {
|
||||
switch f := field.(type) {
|
||||
case string:
|
||||
if len(f) > 0 && f[0] == '$' {
|
||||
return getNestedValue(doc.Data, f[1:])
|
||||
}
|
||||
return f
|
||||
default:
|
||||
return field
|
||||
}
|
||||
}
|
||||
|
||||
func (e *StreamAggregationEngine) add(operand interface{}, data map[string]interface{}) interface{} {
|
||||
if arr, ok := operand.([]interface{}); ok {
|
||||
sum := 0.0
|
||||
for _, item := range arr {
|
||||
evaluated := e.evaluateExpression(data, item)
|
||||
sum += toFloat64(evaluated)
|
||||
}
|
||||
return sum
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (e *StreamAggregationEngine) multiply(operand interface{}, data map[string]interface{}) interface{} {
|
||||
if arr, ok := operand.([]interface{}); ok {
|
||||
result := 1.0
|
||||
for _, item := range arr {
|
||||
evaluated := e.evaluateExpression(data, item)
|
||||
result *= toFloat64(evaluated)
|
||||
}
|
||||
return result
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (e *StreamAggregationEngine) ifNull(operand interface{}, data map[string]interface{}) interface{} {
|
||||
if arr, ok := operand.([]interface{}); ok && len(arr) == 2 {
|
||||
evaluatedFirst := e.evaluateExpression(data, arr[0])
|
||||
if evaluatedFirst != nil {
|
||||
return evaluatedFirst
|
||||
}
|
||||
return e.evaluateExpression(data, arr[1])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *StreamAggregationEngine) cond(operand interface{}, data map[string]interface{}) interface{} {
|
||||
if condMap, ok := operand.(map[string]interface{}); ok {
|
||||
ifCond, hasIf := condMap["if"]
|
||||
thenVal, hasThen := condMap["then"]
|
||||
elseVal, hasElse := condMap["else"]
|
||||
|
||||
if hasIf && hasThen && hasElse {
|
||||
ifVal := e.evaluateExpression(data, ifCond)
|
||||
if isTrue(ifVal) {
|
||||
return e.evaluateExpression(data, thenVal)
|
||||
}
|
||||
return e.evaluateExpression(data, elseVal)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeMatch 执行 $match 阶段
|
||||
func (e *StreamAggregationEngine) executeMatch(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
||||
// 从 aggregate.go 复制的实现
|
||||
var filter map[string]interface{}
|
||||
if f, ok := spec.(types.Filter); ok {
|
||||
filter = f
|
||||
} else if f, ok := spec.(map[string]interface{}); ok {
|
||||
filter = f
|
||||
} else {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
var results []types.Document
|
||||
for _, doc := range docs {
|
||||
if MatchFilter(doc.Data, filter) {
|
||||
results = append(results, doc)
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// executeProject 执行 $project 阶段
|
||||
func (e *StreamAggregationEngine) executeProject(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
||||
// 从 aggregate.go 复制的实现
|
||||
projectSpec, ok := spec.(map[string]interface{})
|
||||
if !ok {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
var results []types.Document
|
||||
for _, doc := range docs {
|
||||
projected := e.projectDocument(doc.Data, projectSpec)
|
||||
results = append(results, types.Document{
|
||||
ID: doc.ID,
|
||||
Data: projected,
|
||||
})
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// projectDocument 投影文档
|
||||
func (e *StreamAggregationEngine) projectDocument(data map[string]interface{}, spec map[string]interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
|
||||
for field, include := range spec {
|
||||
if field == "_id" {
|
||||
// 特殊处理 _id
|
||||
if isFalse(include) {
|
||||
// 排除 _id
|
||||
} else {
|
||||
result["_id"] = data["_id"]
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if isTrue(include) {
|
||||
// 包含字段
|
||||
result[field] = getNestedValue(data, field)
|
||||
} else if isFalse(include) {
|
||||
// 排除字段(在包含模式下不处理)
|
||||
continue
|
||||
} else {
|
||||
// 表达式
|
||||
result[field] = e.evaluateExpression(data, include)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// executeLimit 执行 $limit 阶段
|
||||
func (e *StreamAggregationEngine) executeLimit(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
||||
// 从 aggregate.go 复制的实现
|
||||
limit := 0
|
||||
switch l := spec.(type) {
|
||||
case int:
|
||||
limit = l
|
||||
case int64:
|
||||
limit = int(l)
|
||||
case float64:
|
||||
limit = int(l)
|
||||
}
|
||||
|
||||
if limit <= 0 || limit >= len(docs) {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
return docs[:limit], nil
|
||||
}
|
||||
|
||||
// executeSkip 执行 $skip 阶段
|
||||
func (e *StreamAggregationEngine) executeSkip(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
||||
// 从 aggregate.go 复制的实现
|
||||
skip := 0
|
||||
switch s := spec.(type) {
|
||||
case int:
|
||||
skip = s
|
||||
case int64:
|
||||
skip = int(s)
|
||||
case float64:
|
||||
skip = int(s)
|
||||
}
|
||||
|
||||
if skip <= 0 {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
if skip >= len(docs) {
|
||||
return []types.Document{}, nil
|
||||
}
|
||||
|
||||
return docs[skip:], nil
|
||||
}
|
||||
|
||||
// executeUnwind implements the $unwind stage: each document is expanded into
// one output document per element of the array found at the given path.
//
// The spec may be the path string itself ("$field") or a map with "path"
// and an optional "preserveNullAndEmptyArrays" flag. A path that is empty
// or lacks the '$' prefix passes documents through unchanged.
//
// Missing, non-array, or empty-array values drop the document unless
// preserveNullAndEmptyArrays is set, in which case it is kept as-is. Every
// expanded document keeps the source document's ID.
func (e *StreamAggregationEngine) executeUnwind(spec interface{}, docs []types.Document) ([]types.Document, error) {
	// Implementation copied from aggregate.go.
	var path string
	var preserveNull bool

	switch s := spec.(type) {
	case string:
		path = s
	case map[string]interface{}:
		if p, ok := s["path"].(string); ok {
			path = p
		}
		if pn, ok := s["preserveNullAndEmptyArrays"].(bool); ok {
			preserveNull = pn
		}
	}

	if path == "" || path[0] != '$' {
		return docs, nil
	}

	// Strip the '$' prefix to obtain the data-map field path.
	fieldPath := path[1:]
	var results []types.Document

	for _, doc := range docs {
		arr := getNestedValue(doc.Data, fieldPath)

		if arr == nil {
			// Field absent (or nil): keep the document only when requested.
			if preserveNull {
				results = append(results, doc)
			}
			continue
		}

		array, ok := arr.([]interface{})
		if !ok || len(array) == 0 {
			// Non-array or empty array: same preserve-null handling.
			if preserveNull {
				results = append(results, doc)
			}
			continue
		}

		// One output document per array element, each built on its own deep
		// copy so expanded documents never share nested state.
		for _, item := range array {
			newData := deepCopyMap(doc.Data)
			setNestedValue(newData, fieldPath, item)
			results = append(results, types.Document{
				ID: doc.ID,
				Data: newData,
			})
		}
	}

	return results, nil
}
|
||||
|
||||
// executeSort 执行 $sort 阶段
|
||||
func (e *StreamAggregationEngine) executeSort(spec interface{}, docs []types.Document) ([]types.Document, error) {
|
||||
// 从 aggregate.go 复制的实现
|
||||
sortSpec, ok := spec.(map[string]interface{})
|
||||
if !ok {
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
// 转换为排序字段映射
|
||||
sortFields := make(map[string]int)
|
||||
for field, direction := range sortSpec {
|
||||
dir := 1
|
||||
switch d := direction.(type) {
|
||||
case int:
|
||||
dir = d
|
||||
case int64:
|
||||
dir = int(d)
|
||||
case float64:
|
||||
dir = int(d)
|
||||
}
|
||||
sortFields[field] = dir
|
||||
}
|
||||
|
||||
// 创建可排序的副本
|
||||
sorted := make([]types.Document, len(docs))
|
||||
copy(sorted, docs)
|
||||
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return e.compareDocs(sorted[i], sorted[j], sortFields)
|
||||
})
|
||||
|
||||
return sorted, nil
|
||||
}
|
||||
|
||||
// compareDocs 比较两个文档
|
||||
func (e *StreamAggregationEngine) compareDocs(a, b types.Document, sortFields map[string]int) bool {
|
||||
for field, dir := range sortFields {
|
||||
valA := getNestedValue(a.Data, field)
|
||||
valB := getNestedValue(b.Data, field)
|
||||
|
||||
cmp := compareValues(valA, valB)
|
||||
if cmp != 0 {
|
||||
if dir < 0 {
|
||||
return cmp > 0
|
||||
}
|
||||
return cmp < 0
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Loading…
Reference in New Issue