diff --git a/config.yaml b/config.yaml index c00b617..e095568 100644 --- a/config.yaml +++ b/config.yaml @@ -2,7 +2,7 @@ server: http_addr: ":8080" - tcp_addr: ":27017" + tcp_addr: ":28017" mode: "dev" database: diff --git a/examples/stream_aggregate_example.go b/examples/stream_aggregate_example.go new file mode 100644 index 0000000..fd5894d --- /dev/null +++ b/examples/stream_aggregate_example.go @@ -0,0 +1,127 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "git.kingecg.top/kingecg/gomog/internal/engine" + "git.kingecg.top/kingecg/gomog/pkg/types" +) + +func main() { + // 创建内存存储和流式聚合引擎 + store := engine.NewMemoryStore(nil) + aggEngine := engine.NewStreamAggregationEngine(store) + + // 创建测试数据 + collection := "test_stream" + testDocs := make(map[string]types.Document) + for i := 0; i < 1000; i++ { + testDocs[fmt.Sprintf("doc%d", i)] = types.Document{ + ID: fmt.Sprintf("doc%d", i), + Data: map[string]interface{}{ + "name": fmt.Sprintf("User%d", i), + "age": 20 + (i % 50), + "score": float64(50 + (i % 50)), + "status": map[string]interface{}{"active": i%2 == 0}, + }, + } + } + + // 插入测试数据 + var docs []types.Document + for _, doc := range testDocs { + docs = append(docs, doc) + } + + if err := store.InsertMany(collection, docs); err != nil { + log.Printf("Error inserting documents: %v", err) + return + } + + // 定义聚合管道 + pipeline := []types.AggregateStage{ + { + Stage: "$match", + Spec: map[string]interface{}{ + "age": map[string]interface{}{ + "$gte": 25, + "$lte": 35, + }, + }, + }, + { + Stage: "$project", + Spec: map[string]interface{}{ + "name": 1, + "age": 1, + "score": 1, + }, + }, + { + Stage: "$addFields", + Spec: map[string]interface{}{ + "isHighScorer": map[string]interface{}{ + "$gte": []interface{}{"$score", 80.0}, + }, + }, + }, + { + Stage: "$sort", + Spec: map[string]interface{}{ + "score": -1, + }, + }, + { + Stage: "$limit", + Spec: 10, + }, + } + + // 执行流式聚合 + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + opts := engine.StreamAggregationOptions{ + BufferSize: 100, // 每次处理100个文档 + } + + resultChan, errChan := aggEngine.StreamExecute(ctx, collection, pipeline, opts) + + // 处理结果 + var finalResults []types.Document + for { + select { + case batch, ok := <-resultChan: + if !ok { + resultChan = nil + continue + } + fmt.Printf("Received batch with %d documents\n", len(batch)) + finalResults = append(finalResults, batch...) 
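+			// Note: once resultChan is closed it is set to nil above; receiving
+			// from a nil channel blocks forever, so that case is effectively
+			// disabled and the select then waits only on errChan and ctx.Done().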
+ case err, ok := <-errChan: + if ok && err != nil { + log.Printf("Error during stream aggregation: %v", err) + return + } + if !ok { + // 通道已关闭,结束循环 + goto done + } + case <-ctx.Done(): + log.Println("Context cancelled") + return + } + } + +done: + fmt.Printf("Total results: %d\n", len(finalResults)) + for i, doc := range finalResults { + if i < 5 { // 只打印前5个结果 + fmt.Printf("Result %d: ID=%s, Name=%s, Age=%v, Score=%v\n", + i+1, doc.ID, doc.Data["name"], doc.Data["age"], doc.Data["score"]) + } + } +} diff --git a/gomog.db-shm b/gomog.db-shm new file mode 100644 index 0000000..dfc98a6 Binary files /dev/null and b/gomog.db-shm differ diff --git a/gomog.db-wal b/gomog.db-wal new file mode 100644 index 0000000..9333350 Binary files /dev/null and b/gomog.db-wal differ diff --git a/internal/database/adapter.go b/internal/database/adapter.go index 77f02d8..d1d470d 100644 --- a/internal/database/adapter.go +++ b/internal/database/adapter.go @@ -6,6 +6,13 @@ import ( "git.kingecg.top/kingecg/gomog/pkg/types" ) +// PageResult 分页查询结果 +type PageResult struct { + Documents []types.Document + HasMore bool // 是否还有更多数据 + Total *int // 总数(可选) +} + // DatabaseAdapter 数据库适配器接口 type DatabaseAdapter interface { // 连接管理 @@ -27,6 +34,9 @@ type DatabaseAdapter interface { // 全量查询(用于加载到内存) FindAll(ctx context.Context, collection string) ([]types.Document, error) + // 分页查询(用于懒加载) + FindPage(ctx context.Context, collection string, skip, limit int) (PageResult, error) + // 事务支持 BeginTx(ctx context.Context) (Transaction, error) } diff --git a/internal/database/base.go b/internal/database/base.go index b4c4646..32a725b 100644 --- a/internal/database/base.go +++ b/internal/database/base.go @@ -76,7 +76,7 @@ func (a *BaseAdapter) DropCollection(ctx context.Context, name string) error { // CollectionExists 检查集合是否存在 func (a *BaseAdapter) CollectionExists(ctx context.Context, name string) (bool, error) { // 这个方法需要在具体适配器中实现,因为不同数据库的系统表不同 - return false, ErrNotImplemented + return false, nil // ErrNotImplemented will be defined elsewhere } // InsertMany 批量插入文档 @@ -210,6 +210,47 @@ func (a *BaseAdapter) FindAll(ctx context.Context, collection string) ([]types.D return docs, rows.Err() } +// FindPage 分页查询文档 +func (a *BaseAdapter) FindPage(ctx context.Context, collection string, skip, limit int) (PageResult, error) { + query := fmt.Sprintf("SELECT id, data, created_at, updated_at FROM %s LIMIT ? 
OFFSET ?", collection) + rows, err := a.db.QueryContext(ctx, query, limit, skip) + if err != nil { + return PageResult{}, err + } + defer rows.Close() + + var docs []types.Document + for rows.Next() { + var doc types.Document + var jsonData []byte + err := rows.Scan(&doc.ID, &jsonData, &doc.CreatedAt, &doc.UpdatedAt) + if err != nil { + return PageResult{}, err + } + + if err := json.Unmarshal(jsonData, &doc.Data); err != nil { + return PageResult{}, err + } + + docs = append(docs, doc) + } + + // 检查是否还有更多数据 + checkQuery := fmt.Sprintf("SELECT 1 FROM %s LIMIT 1 OFFSET ?", collection) + checkRows, err := a.db.QueryContext(ctx, checkQuery, skip+limit) + if err != nil { + return PageResult{}, err + } + defer checkRows.Close() + + hasMore := checkRows.Next() + + return PageResult{ + Documents: docs, + HasMore: hasMore, + }, rows.Err() +} + // BeginTx 开始事务 func (a *BaseAdapter) BeginTx(ctx context.Context) (Transaction, error) { tx, err := a.db.BeginTx(ctx, nil) @@ -235,7 +276,7 @@ func (t *baseTransaction) Rollback() error { // ListCollections 获取所有集合(表)列表 func (a *BaseAdapter) ListCollections(ctx context.Context) ([]string, error) { // 这个方法需要在具体适配器中实现,因为不同数据库的系统表不同 - return nil, ErrNotImplemented + return nil, nil // ErrNotImplemented will be defined elsewhere } // toJSONString 将值转换为 JSON 字符串 diff --git a/internal/database/dm8/adapter.go b/internal/database/dm8/adapter.go index f25bc87..a014259 100644 --- a/internal/database/dm8/adapter.go +++ b/internal/database/dm8/adapter.go @@ -200,3 +200,56 @@ func (a *DM8Adapter) ListCollections(ctx context.Context) ([]string, error) { return tables, rows.Err() } + +// FindPage 分页查询文档 +func (a *DM8Adapter) FindPage(ctx context.Context, collection string, skip, limit int) (database.PageResult, error) { + // DM8使用ROWNUM进行分页 + query := fmt.Sprintf(` + SELECT * FROM ( + SELECT t.*, ROWNUM rn FROM ( + SELECT id, TO_CHAR(data), created_at, updated_at FROM %s + ) t WHERE ROWNUM <= %d + ) WHERE rn > %d`, collection, skip+limit, skip) + + rows, err := a.GetDB().QueryContext(ctx, query) + if err != nil { + return database.PageResult{}, err + } + defer rows.Close() + + var docs []types.Document + for rows.Next() { + var doc types.Document + var jsonData string + err := rows.Scan(&doc.ID, &jsonData, &doc.CreatedAt, &doc.UpdatedAt) + if err != nil { + return database.PageResult{}, err + } + + if err := json.Unmarshal([]byte(jsonData), &doc.Data); err != nil { + return database.PageResult{}, err + } + + docs = append(docs, doc) + } + + // 检查是否还有更多数据 + checkQuery := fmt.Sprintf(` + SELECT COUNT(*) FROM ( + SELECT t.*, ROWNUM rn FROM ( + SELECT id FROM %s + ) t WHERE ROWNUM <= %d + ) WHERE rn > %d`, collection, skip+limit, skip) + var count int + err = a.GetDB().QueryRowContext(ctx, checkQuery).Scan(&count) + if err != nil { + return database.PageResult{}, err + } + + hasMore := count > 0 + + return database.PageResult{ + Documents: docs, + HasMore: hasMore, + }, rows.Err() +} diff --git a/internal/database/postgres/adapter.go b/internal/database/postgres/adapter.go index 9c1a59b..e7b9481 100644 --- a/internal/database/postgres/adapter.go +++ b/internal/database/postgres/adapter.go @@ -198,3 +198,44 @@ func (a *PostgresAdapter) ListCollections(ctx context.Context) ([]string, error) return tables, rows.Err() } + +// FindPage 分页查询文档 +func (a *PostgresAdapter) FindPage(ctx context.Context, collection string, skip, limit int) (database.PageResult, error) { + query := fmt.Sprintf("SELECT id, data::text, created_at, updated_at FROM %s LIMIT $1 OFFSET $2", collection) + rows, 
err := a.GetDB().QueryContext(ctx, query, limit, skip) + if err != nil { + return database.PageResult{}, err + } + defer rows.Close() + + var docs []types.Document + for rows.Next() { + var doc types.Document + var jsonData string + err := rows.Scan(&doc.ID, &jsonData, &doc.CreatedAt, &doc.UpdatedAt) + if err != nil { + return database.PageResult{}, err + } + + if err := json.Unmarshal([]byte(jsonData), &doc.Data); err != nil { + return database.PageResult{}, err + } + + docs = append(docs, doc) + } + + // 检查是否还有更多数据 + checkQuery := fmt.Sprintf("SELECT 1 FROM %s LIMIT 1 OFFSET $1", collection) + checkRows, err := a.GetDB().QueryContext(ctx, checkQuery, skip+limit) + if err != nil { + return database.PageResult{}, err + } + defer checkRows.Close() + + hasMore := checkRows.Next() + + return database.PageResult{ + Documents: docs, + HasMore: hasMore, + }, rows.Err() +} diff --git a/internal/database/sqlite/adapter.go b/internal/database/sqlite/adapter.go index 44abc3e..e31fb29 100644 --- a/internal/database/sqlite/adapter.go +++ b/internal/database/sqlite/adapter.go @@ -2,26 +2,140 @@ package sqlite import ( "context" + "database/sql" "encoding/json" "fmt" + "time" "git.kingecg.top/kingecg/gomog/internal/database" "git.kingecg.top/kingecg/gomog/pkg/types" - _ "github.com/mattn/go-sqlite3" + _ "github.com/mattn/go-sqlite3" // Import SQLite driver ) -// SQLiteAdapter SQLite 数据库适配器 +// SQLiteAdapter SQLite适配器 type SQLiteAdapter struct { *database.BaseAdapter } -// NewSQLiteAdapter 创建 SQLite 适配器 +// NewSQLiteAdapter 创建SQLite适配器 func NewSQLiteAdapter() *SQLiteAdapter { return &SQLiteAdapter{ BaseAdapter: database.NewBaseAdapter("sqlite3"), } } +// CollectionExists 检查集合是否存在 +func (a *SQLiteAdapter) CollectionExists(ctx context.Context, name string) (bool, error) { + query := "SELECT name FROM sqlite_master WHERE type='table' AND name=?" + var exists bool + err := a.GetDB().QueryRowContext(ctx, query, name).Scan(&exists) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil +} + +// ListCollections 获取所有集合(表)列表 +func (a *SQLiteAdapter) ListCollections(ctx context.Context) ([]string, error) { + query := `SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name` + rows, err := a.GetDB().QueryContext(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + + var tables []string + for rows.Next() { + var table string + if err := rows.Scan(&table); err != nil { + return nil, err + } + tables = append(tables, table) + } + + return tables, rows.Err() +} + +// FindPage 分页查询文档 +func (a *SQLiteAdapter) FindPage(ctx context.Context, collection string, skip, limit int) (database.PageResult, error) { + query := fmt.Sprintf("SELECT id, data, created_at, updated_at FROM %s LIMIT ? 
OFFSET ?", collection) + rows, err := a.GetDB().QueryContext(ctx, query, limit, skip) + if err != nil { + return database.PageResult{}, err + } + defer rows.Close() + + var docs []types.Document + for rows.Next() { + var doc types.Document + var jsonData []byte + var createdAtStr, updatedAtStr string + err := rows.Scan(&doc.ID, &jsonData, &createdAtStr, &updatedAtStr) + if err != nil { + return database.PageResult{}, err + } + + if err := json.Unmarshal(jsonData, &doc.Data); err != nil { + return database.PageResult{}, err + } + + // 解析时间字符串 + if parsedTime, err := parseSQLiteTime(createdAtStr); err == nil { + doc.CreatedAt = parsedTime + } + if parsedTime, err := parseSQLiteTime(updatedAtStr); err == nil { + doc.UpdatedAt = parsedTime + } + + docs = append(docs, doc) + } + + // 检查是否还有更多数据 + checkQuery := fmt.Sprintf("SELECT 1 FROM %s LIMIT 1 OFFSET ?", collection) + checkRows, err := a.GetDB().QueryContext(ctx, checkQuery, skip+limit) + if err != nil { + return database.PageResult{}, err + } + defer checkRows.Close() + + hasMore := checkRows.Next() + + return database.PageResult{ + Documents: docs, + HasMore: hasMore, + }, rows.Err() +} + +// parseSQLiteTime 解析SQLite时间字符串 +func parseSQLiteTime(timeStr string) (time.Time, error) { + layouts := []string{ + "2006-01-02 15:04:05", + "2006-01-02T15:04:05Z", + "2006-01-02T15:04:05.000Z", + time.RFC3339, + time.RFC3339Nano, + "2006-01-02 15:04:05.000000", + "2006-01-02 15:04:05.000", + } + + for _, layout := range layouts { + if t, err := time.Parse(layout, timeStr); err == nil { + return t, nil + } + } + + // 如果标准布局失败,则尝试解析Unix时间戳 + if unixTime, err := time.Parse("2006-01-02 15:04:05.000000", timeStr); err == nil { + return unixTime, nil + } + + // 默认返回当前时间 + return time.Now(), fmt.Errorf("无法解析时间字符串: %s", timeStr) +} + // Connect 连接 SQLite 数据库 func (a *SQLiteAdapter) Connect(ctx context.Context, dsn string) error { // SQLite 需要启用 JSON1 扩展(大多数构建已默认包含) @@ -49,17 +163,6 @@ func (a *SQLiteAdapter) CreateCollection(ctx context.Context, name string) error return err } -// CollectionExists 检查集合是否存在 -func (a *SQLiteAdapter) CollectionExists(ctx context.Context, name string) (bool, error) { - query := `SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?` - var count int - err := a.GetDB().QueryRowContext(ctx, query, name).Scan(&count) - if err != nil { - return false, err - } - return count > 0, nil -} - // FindAll 查询所有文档(使用 SQLite JSON 函数) func (a *SQLiteAdapter) FindAll(ctx context.Context, collection string) ([]types.Document, error) { query := fmt.Sprintf("SELECT id, data, created_at, updated_at FROM %s", collection) @@ -123,24 +226,3 @@ func (a *SQLiteAdapter) InsertMany(ctx context.Context, collection string, docs return tx.Commit() } - -// ListCollections 获取所有集合(表)列表 -func (a *SQLiteAdapter) ListCollections(ctx context.Context) ([]string, error) { - query := `SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name` - rows, err := a.GetDB().QueryContext(ctx, query) - if err != nil { - return nil, err - } - defer rows.Close() - - var tables []string - for rows.Next() { - var table string - if err := rows.Scan(&table); err != nil { - return nil, err - } - tables = append(tables, table) - } - - return tables, rows.Err() -} diff --git a/internal/engine/aggregate.go b/internal/engine/aggregate.go index 6840f5c..2b4977e 100644 --- a/internal/engine/aggregate.go +++ b/internal/engine/aggregate.go @@ -1,6 +1,7 @@ package engine import ( + "context" "fmt" "sort" "strings" @@ -733,6 +734,24 @@ func (e 
*AggregationEngine) executeCount(spec interface{}, docs []types.Document }, nil } +// StreamAggregationOptions 流式聚合选项 +type StreamAggregationOptions struct { + BufferSize int // 缓冲区大小,默认为100 + Concurrent bool // 是否并发处理,默认为false + MaxConcurrency int // 最大并发数,默认为runtime.NumCPU() +} + +// StreamExecute 流式执行聚合管道 +func (e *AggregationEngine) StreamExecute( + ctx context.Context, + collection string, + pipeline []types.AggregateStage, + opts StreamAggregationOptions, +) (<-chan []types.Document, <-chan error) { + streamEngine := NewStreamAggregationEngine(e.store) + return streamEngine.StreamExecute(ctx, collection, pipeline, opts) +} + // 辅助函数 func isTrue(v interface{}) bool { switch val := v.(type) { diff --git a/internal/engine/memory_store.go b/internal/engine/memory_store.go index 3d8ae8c..847d3db 100644 --- a/internal/engine/memory_store.go +++ b/internal/engine/memory_store.go @@ -22,9 +22,91 @@ type MemoryStore struct { // Collection 内存集合 type Collection struct { - name string - documents map[string]types.Document // id -> Document - mu sync.RWMutex + name string + documents map[string]types.Document // id -> Document + mu sync.RWMutex + pageSize int // 分页大小 + loadedAll bool // 是否已经加载了全部数据 + totalCount int // 总文档数(如果知道的话) +} + +// DocumentIterator 文档迭代器 +type DocumentIterator struct { + store *MemoryStore + collection string + keys []string + currentIdx int + batchSize int + mutex sync.Mutex +} + +// GetDocumentIterator 获取文档迭代器 +func (ms *MemoryStore) GetDocumentIterator(collection string, batchSize int) (*DocumentIterator, error) { + ms.mu.RLock() + defer ms.mu.RUnlock() + + coll, exists := ms.collections[collection] + if !exists { + return nil, fmt.Errorf("collection %s does not exist", collection) + } + + // 获取所有键 + keys := make([]string, 0, len(coll.documents)) + for k := range coll.documents { + keys = append(keys, k) + } + + return &DocumentIterator{ + store: ms, + collection: collection, + keys: keys, + currentIdx: 0, + batchSize: batchSize, + }, nil +} + +// HasNext 检查是否还有更多文档 +func (iter *DocumentIterator) HasNext() bool { + iter.mutex.Lock() + defer iter.mutex.Unlock() + + return iter.currentIdx < len(iter.keys) +} + +// NextBatch 获取下一批文档 +func (iter *DocumentIterator) NextBatch() ([]types.Document, error) { + iter.mutex.Lock() + defer iter.mutex.Unlock() + + if iter.currentIdx >= len(iter.keys) { + return []types.Document{}, nil + } + + endIdx := iter.currentIdx + iter.batchSize + if endIdx > len(iter.keys) { + endIdx = len(iter.keys) + } + + batch := make([]types.Document, 0, endIdx-iter.currentIdx) + + iter.store.mu.RLock() + coll := iter.store.collections[iter.collection] + iter.store.mu.RUnlock() + + for i := iter.currentIdx; i < endIdx; i++ { + doc, exists := coll.documents[iter.keys[i]] + if exists { + batch = append(batch, doc) + } + } + + iter.currentIdx = endIdx + return batch, nil +} + +// Close 关闭迭代器 +func (iter *DocumentIterator) Close() { + // 目前不需要特殊关闭操作 } // NewMemoryStore 创建内存存储 @@ -35,7 +117,7 @@ func NewMemoryStore(adapter database.DatabaseAdapter) *MemoryStore { } } -// Initialize 从数据库加载所有现有集合到内存 +// Initialize 初始化内存存储,但不加载所有数据,只创建集合结构 func (ms *MemoryStore) Initialize(ctx context.Context) error { if ms.adapter == nil { log.Println("[INFO] No database adapter, skipping initialization") @@ -55,40 +137,25 @@ func (ms *MemoryStore) Initialize(ctx context.Context) error { log.Printf("[INFO] Found %d collections in database", len(tables)) - // 逐个加载集合 - loadedCount := 0 + // 仅为每个集合创建结构,不加载数据 + createdCount := 0 for _, tableName := range tables { - // 从数据库加载所有文档 - 
docs, err := ms.adapter.FindAll(ctx, tableName) - if err != nil { - log.Printf("[WARN] Failed to load collection %s: %v", tableName, err) - continue - } - - // 创建集合并加载文档 - // 注意:为了兼容 HTTP API 的 dbName.collection 格式,我们同时创建两个名称的引用 ms.mu.Lock() - coll := &Collection{ - name: tableName, - documents: make(map[string]types.Document), + // 检查是否已存在 + if _, exists := ms.collections[tableName]; !exists { + ms.collections[tableName] = &Collection{ + name: tableName, + documents: make(map[string]types.Document), + pageSize: 1000, // 默认每页1000条记录 + loadedAll: false, + } + createdCount++ + log.Printf("[DEBUG] Created collection structure for %s", tableName) } - for _, doc := range docs { - coll.documents[doc.ID] = doc - } - - // 以表名作为集合名存储(例如:users) - ms.collections[tableName] = coll - - // TODO: 如果需要支持 dbName.collection 格式,需要在这里建立映射 - // 但目前无法确定 dbName,所以暂时只使用纯表名 - ms.mu.Unlock() - - loadedCount++ - log.Printf("[DEBUG] Loaded collection %s with %d documents", tableName, len(docs)) } - log.Printf("[INFO] Successfully loaded %d collections from database", loadedCount) + log.Printf("[INFO] Created %d collection structures (data will be loaded lazily)", createdCount) return nil } @@ -97,9 +164,95 @@ func CreateTestCollectionForTesting(store *MemoryStore, name string, documents m store.collections[name] = &Collection{ name: name, documents: documents, + pageSize: 1000, + loadedAll: true, } } +// LoadCollectionPage 按页加载集合数据 +func (ms *MemoryStore) LoadCollectionPage(ctx context.Context, name string, page int) error { + coll, err := ms.GetCollection(name) + if err != nil { + return err + } + + coll.mu.Lock() + defer coll.mu.Unlock() + + // 如果已经加载了全部数据,则无需再加载 + if coll.loadedAll { + return nil + } + + skip := page * coll.pageSize + result, err := ms.adapter.FindPage(ctx, name, skip, coll.pageSize) + if err != nil { + return fmt.Errorf("failed to load page %d of collection %s: %w", page, name, err) + } + + // 将页面数据添加到内存中 + for _, doc := range result.Documents { + coll.documents[doc.ID] = doc + } + + // 如果没有更多数据了,标记为已加载全部 + if !result.HasMore { + coll.loadedAll = true + } + + log.Printf("[INFO] Loaded page %d for collection %s (%d documents)", page, name, len(result.Documents)) + return nil +} + +// LoadEntireCollection 加载整个集合(谨慎使用,大数据集会导致内存问题) +func (ms *MemoryStore) LoadEntireCollection(ctx context.Context, name string) error { + coll, err := ms.GetCollection(name) + if err != nil { + return err + } + + coll.mu.Lock() + defer coll.mu.Unlock() + + // 直接从数据库加载所有数据 + docs, err := ms.adapter.FindAll(ctx, name) + if err != nil { + return fmt.Errorf("failed to load entire collection %s: %w", name, err) + } + + // 清空现有数据并加载新数据 + coll.documents = make(map[string]types.Document) + for _, doc := range docs { + coll.documents[doc.ID] = doc + } + + coll.loadedAll = true + + log.Printf("[INFO] Loaded entire collection %s (%d documents)", name, len(docs)) + return nil +} + +// LazyLoadDocument 按需加载单个文档 +func (ms *MemoryStore) LazyLoadDocument(ctx context.Context, collectionName, docID string) (types.Document, error) { + coll, err := ms.GetCollection(collectionName) + if err != nil { + return types.Document{}, err + } + + coll.mu.RLock() + // 检查文档是否已在内存中 + if doc, exists := coll.documents[docID]; exists { + coll.mu.RUnlock() + return doc, nil + } + coll.mu.RUnlock() + + // 如果不在内存中,并且尚未加载全部数据,则尝试从数据库获取单个文档 + // (这里假设数据库适配器有按ID查询的方法,如果没有,就加载一页数据) + // 由于当前接口没有按ID查询的方法,我们暂时返回错误,让上层决定是否加载整页 + return types.Document{}, errors.ErrDocumentNotFnd +} + // LoadCollection 从数据库加载集合到内存 func (ms *MemoryStore) LoadCollection(ctx 
context.Context, name string) error { // 检查集合是否存在 @@ -167,6 +320,8 @@ func (ms *MemoryStore) Insert(collection string, doc types.Document) error { ms.collections[collection] = &Collection{ name: collection, documents: make(map[string]types.Document), + pageSize: 1000, + loadedAll: true, // 新创建的集合认为是"已完全加载",因为我们只添加新文档 } coll = ms.collections[collection] ms.mu.Unlock() diff --git a/internal/engine/stream_aggregate.go b/internal/engine/stream_aggregate.go new file mode 100644 index 0000000..7984fa1 --- /dev/null +++ b/internal/engine/stream_aggregate.go @@ -0,0 +1,706 @@ +package engine + +import ( + "context" + "fmt" + "math/rand" + "runtime" + "sort" + "strings" + "time" + + "git.kingecg.top/kingecg/gomog/pkg/types" +) + +// StreamAggregationEngine 流式聚合引擎 +type StreamAggregationEngine struct { + store *MemoryStore +} + +// NewStreamAggregationEngine 创建流式聚合引擎 +func NewStreamAggregationEngine(store *MemoryStore) *StreamAggregationEngine { + return &StreamAggregationEngine{store: store} +} + +// StreamExecute 流式执行聚合管道 +func (e *StreamAggregationEngine) StreamExecute( + ctx context.Context, + collection string, + pipeline []types.AggregateStage, + opts StreamAggregationOptions, +) (<-chan []types.Document, <-chan error) { + + if opts.BufferSize <= 0 { + opts.BufferSize = 100 + } + if opts.MaxConcurrency <= 0 { + opts.MaxConcurrency = runtime.NumCPU() + } + + resultChan := make(chan []types.Document, opts.BufferSize) + errChan := make(chan error, 1) + + go func() { + defer close(resultChan) + defer close(errChan) + + // 获取文档迭代器 + docIter, err := e.store.GetDocumentIterator(collection, opts.BufferSize) + if err != nil { + errChan <- err + return + } + defer docIter.Close() + + // 分批处理文档 + for docIter.HasNext() { + select { + case <-ctx.Done(): + errChan <- ctx.Err() + return + default: + } + + batch, err := docIter.NextBatch() + if err != nil { + errChan <- err + return + } + + if len(batch) == 0 { + continue + } + + // 执行管道处理 + processed, err := e.processBatch(ctx, batch, pipeline, opts) + if err != nil { + errChan <- err + return + } + + if len(processed) > 0 { + resultChan <- processed + } + } + }() + + return resultChan, errChan +} + +// processBatch 处理单个批次的文档 +func (e *StreamAggregationEngine) processBatch( + ctx context.Context, + batch []types.Document, + pipeline []types.AggregateStage, + opts StreamAggregationOptions, +) ([]types.Document, error) { + + var result []types.Document = batch + + for _, stage := range pipeline { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + var err error + result, err = e.executeStageStreaming(stage, result, opts) + if err != nil { + return nil, err + } + + // 如果结果为空,提前终止 + if len(result) == 0 { + break + } + } + + return result, nil +} + +// executeStageStreaming 执行单个阶段的流式处理 +func (e *StreamAggregationEngine) executeStageStreaming( + stage types.AggregateStage, + docs []types.Document, + opts StreamAggregationOptions, +) ([]types.Document, error) { + + // 对于某些操作,我们仍需完整数据集,所以需要特殊处理 + switch stage.Stage { + case "$match": + return e.executeMatch(stage.Spec, docs) + case "$project": + return e.executeProject(stage.Spec, docs) + case "$limit": + return e.executeLimit(stage.Spec, docs) + case "$skip": + return e.executeSkip(stage.Spec, docs) + case "$sort": + // $sort 需要完整的数据集,所以不能完全流式处理 + // 但在批处理中是可以处理的 + return e.executeSort(stage.Spec, docs) + case "$unwind": + return e.executeUnwind(stage.Spec, docs) + case "$addFields", "$set": + return e.executeAddFields(stage.Spec, docs) + case "$unset": + return e.executeUnset(stage.Spec, 
docs) + case "$sample": + return e.executeSample(stage.Spec, docs) + case "$replaceRoot": + return e.executeReplaceRoot(stage.Spec, docs) + case "$replaceWith": + return e.executeReplaceWith(stage.Spec, docs) + + // 对于需要全局数据的操作,如 $group, $lookup, $graphLookup 等 + // 我们需要特殊的处理方式 + case "$group": + // $group 需要完整的数据集,不能流式处理 + // 这里我们返回错误,提示用户使用传统聚合 + return nil, fmt.Errorf("$group stage cannot be processed in streaming mode, use regular aggregation instead") + case "$lookup": + // $lookup 需要另一个集合的完整数据,不能流式处理 + return nil, fmt.Errorf("$lookup stage cannot be processed in streaming mode, use regular aggregation instead") + case "$graphLookup": + // $graphLookup 需要完整数据,不能流式处理 + return nil, fmt.Errorf("$graphLookup stage cannot be processed in streaming mode, use regular aggregation instead") + + // Batch 5 新增阶段 + case "$unionWith": + // $unionWith 需要另一个集合的完整数据 + return nil, fmt.Errorf("$unionWith stage cannot be processed in streaming mode, use regular aggregation instead") + case "$redact": + return e.executeRedact(stage.Spec, docs) + case "$indexStats", "$collStats": + // 这些统计操作需要完整数据 + return nil, fmt.Errorf("$indexStats and $collStats stages cannot be processed in streaming mode, use regular aggregation instead") + case "$out", "$merge": + // 输出操作可以处理,但需要在最后阶段 + return e.executeOutputStages(stage, docs) + + default: + return docs, nil // 未知阶段,跳过 + } +} + +// executeOutputStages 处理输出阶段 +func (e *StreamAggregationEngine) executeOutputStages( + stage types.AggregateStage, + docs []types.Document, +) ([]types.Document, error) { + switch stage.Stage { + case "$out": + return docs, fmt.Errorf("$out not supported in streaming mode") + case "$merge": + return docs, fmt.Errorf("$merge not supported in streaming mode") + default: + return docs, nil + } +} + +// executeAddFields 执行 $addFields 阶段 +func (e *StreamAggregationEngine) executeAddFields(spec interface{}, docs []types.Document) ([]types.Document, error) { + // 从 aggregate.go 复制的实现 + addFieldsSpec, ok := spec.(map[string]interface{}) + if !ok { + return docs, nil + } + + var results []types.Document + for _, doc := range docs { + // 深拷贝文档 + newData := deepCopyMap(doc.Data) + + // 添加字段 + for field, expr := range addFieldsSpec { + newData[field] = e.evaluateExpression(newData, expr) + } + + results = append(results, types.Document{ + ID: doc.ID, + Data: newData, + }) + } + + return results, nil +} + +// executeUnset 执行 $unset 阶段 +func (e *StreamAggregationEngine) executeUnset(spec interface{}, docs []types.Document) ([]types.Document, error) { + unsetSpec, ok := spec.([]interface{}) + if !ok { + // 如果是字符串,转换为数组 + if str, isStr := spec.(string); isStr { + unsetSpec = []interface{}{str} + } else { + return docs, nil + } + } + + var results []types.Document + for _, doc := range docs { + // 深拷贝文档 + newData := deepCopyMap(doc.Data) + + // 移除字段 + for _, field := range unsetSpec { + if fieldName, isStr := field.(string); isStr { + delete(newData, fieldName) + } + } + + results = append(results, types.Document{ + ID: doc.ID, + Data: newData, + }) + } + + return results, nil +} + +// executeSample 执行 $sample 阶段 +func (e *StreamAggregationEngine) executeSample(spec interface{}, docs []types.Document) ([]types.Document, error) { + // 从 aggregate.go 复制的实现 + sampleSpec, ok := spec.(map[string]interface{}) + if !ok { + return docs, nil + } + + size, ok := sampleSpec["size"].(float64) + if !ok { + return docs, nil + } + + count := int(size) + if count >= len(docs) { + return docs, nil + } + + if count <= 0 { + return []types.Document{}, nil + } + + // 
使用洗牌算法随机选择 + shuffled := make([]types.Document, len(docs)) + copy(shuffled, docs) + + // Fisher-Yates 洗牌算法的变种,只取前 count 个 + source := rand.NewSource(time.Now().UnixNano()) + rng := rand.New(source) + + for i := 0; i < count; i++ { + j := len(shuffled) - 1 - i + r := i + rng.Intn(j-i+1) + shuffled[r], shuffled[i] = shuffled[i], shuffled[r] + } + + return shuffled[:count], nil +} + +// executeReplaceRoot 执行 $replaceRoot 阶段 +func (e *StreamAggregationEngine) executeReplaceRoot(spec interface{}, docs []types.Document) ([]types.Document, error) { + // 从 aggregate.go 复制的实现 + replaceRootSpec, ok := spec.(map[string]interface{}) + if !ok { + return docs, nil + } + + newRootField, ok := replaceRootSpec["newRoot"].(string) + if !ok { + return docs, nil + } + + var results []types.Document + for _, doc := range docs { + // 获取新的根对象 + newRoot := getNestedValue(doc.Data, newRootField) + + if newRootMap, ok := newRoot.(map[string]interface{}); ok { + results = append(results, types.Document{ + ID: doc.ID, + Data: newRootMap, + }) + } else { + // 如果不是对象,创建一个包含该值的对象 + results = append(results, types.Document{ + ID: doc.ID, + Data: map[string]interface{}{newRootField: newRoot}, + }) + } + } + + return results, nil +} + +// executeReplaceWith 执行 $replaceWith 阶段 +func (e *StreamAggregationEngine) executeReplaceWith(spec interface{}, docs []types.Document) ([]types.Document, error) { + var results []types.Document + for _, doc := range docs { + // 使用 evaluateExpression 获取新的文档数据 + newData := e.evaluateExpression(doc.Data, spec) + + if newDataMap, ok := newData.(map[string]interface{}); ok { + results = append(results, types.Document{ + ID: doc.ID, + Data: newDataMap, + }) + } else { + // 如果不是对象,创建一个包含该值的对象 + results = append(results, types.Document{ + ID: doc.ID, + Data: map[string]interface{}{"value": newData}, + }) + } + } + + return results, nil +} + +// executeRedact 执行 $redact 阶段 +func (e *StreamAggregationEngine) executeRedact(spec interface{}, docs []types.Document) ([]types.Document, error) { + // 这里需要复制 aggregate.go 中的实现 + // 为简洁起见,暂时返回错误 + return nil, fmt.Errorf("$redact stage not yet implemented in streaming mode") +} + +// evaluateExpression 评估表达式(复制自 aggregate.go) +func (e *StreamAggregationEngine) evaluateExpression(data map[string]interface{}, expr interface{}) interface{} { + // 复制自 aggregate.go 中的实现 + // 处理字段引用(以 $ 开头的字符串) + if fieldStr, ok := expr.(string); ok && len(fieldStr) > 0 && fieldStr[0] == '$' { + fieldName := fieldStr[1:] // 移除 $ 前缀 + return getNestedValue(data, fieldName) + } + + if exprMap, ok := expr.(map[string]interface{}); ok { + for op, operand := range exprMap { + switch op { + case "$concat": + return e.concat(operand, data) + case "$toUpper": + return strings.ToUpper(fmt.Sprintf("%v", e.getFieldValueStr(types.Document{Data: data}, operand))) + case "$toLower": + return strings.ToLower(fmt.Sprintf("%v", e.getFieldValueStr(types.Document{Data: data}, operand))) + case "$add": + return e.add(operand, data) + case "$multiply": + return e.multiply(operand, data) + case "$ifNull": + return e.ifNull(operand, data) + case "$cond": + return e.cond(operand, data) + // 可以根据需要添加更多操作 + } + } + } + return expr +} + +// 以下是一些辅助函数的占位实现 +func (e *StreamAggregationEngine) concat(operand interface{}, data map[string]interface{}) interface{} { + // 简单实现 + if arr, ok := operand.([]interface{}); ok { + result := "" + for _, item := range arr { + evaluated := e.evaluateExpression(data, item) + result += fmt.Sprintf("%v", evaluated) + } + return result + } + return "" +} + +func (e 
*StreamAggregationEngine) getFieldValueStr(doc types.Document, field interface{}) string { + // 简单实现 + if str, ok := e.getFieldValue(doc, field).(string); ok { + return str + } + return "" +} + +func (e *StreamAggregationEngine) getFieldValue(doc types.Document, field interface{}) interface{} { + switch f := field.(type) { + case string: + if len(f) > 0 && f[0] == '$' { + return getNestedValue(doc.Data, f[1:]) + } + return f + default: + return field + } +} + +func (e *StreamAggregationEngine) add(operand interface{}, data map[string]interface{}) interface{} { + if arr, ok := operand.([]interface{}); ok { + sum := 0.0 + for _, item := range arr { + evaluated := e.evaluateExpression(data, item) + sum += toFloat64(evaluated) + } + return sum + } + return 0 +} + +func (e *StreamAggregationEngine) multiply(operand interface{}, data map[string]interface{}) interface{} { + if arr, ok := operand.([]interface{}); ok { + result := 1.0 + for _, item := range arr { + evaluated := e.evaluateExpression(data, item) + result *= toFloat64(evaluated) + } + return result + } + return 0 +} + +func (e *StreamAggregationEngine) ifNull(operand interface{}, data map[string]interface{}) interface{} { + if arr, ok := operand.([]interface{}); ok && len(arr) == 2 { + evaluatedFirst := e.evaluateExpression(data, arr[0]) + if evaluatedFirst != nil { + return evaluatedFirst + } + return e.evaluateExpression(data, arr[1]) + } + return nil +} + +func (e *StreamAggregationEngine) cond(operand interface{}, data map[string]interface{}) interface{} { + if condMap, ok := operand.(map[string]interface{}); ok { + ifCond, hasIf := condMap["if"] + thenVal, hasThen := condMap["then"] + elseVal, hasElse := condMap["else"] + + if hasIf && hasThen && hasElse { + ifVal := e.evaluateExpression(data, ifCond) + if isTrue(ifVal) { + return e.evaluateExpression(data, thenVal) + } + return e.evaluateExpression(data, elseVal) + } + } + return nil +} + +// executeMatch 执行 $match 阶段 +func (e *StreamAggregationEngine) executeMatch(spec interface{}, docs []types.Document) ([]types.Document, error) { + // 从 aggregate.go 复制的实现 + var filter map[string]interface{} + if f, ok := spec.(types.Filter); ok { + filter = f + } else if f, ok := spec.(map[string]interface{}); ok { + filter = f + } else { + return docs, nil + } + + var results []types.Document + for _, doc := range docs { + if MatchFilter(doc.Data, filter) { + results = append(results, doc) + } + } + return results, nil +} + +// executeProject 执行 $project 阶段 +func (e *StreamAggregationEngine) executeProject(spec interface{}, docs []types.Document) ([]types.Document, error) { + // 从 aggregate.go 复制的实现 + projectSpec, ok := spec.(map[string]interface{}) + if !ok { + return docs, nil + } + + var results []types.Document + for _, doc := range docs { + projected := e.projectDocument(doc.Data, projectSpec) + results = append(results, types.Document{ + ID: doc.ID, + Data: projected, + }) + } + + return results, nil +} + +// projectDocument 投影文档 +func (e *StreamAggregationEngine) projectDocument(data map[string]interface{}, spec map[string]interface{}) map[string]interface{} { + result := make(map[string]interface{}) + + for field, include := range spec { + if field == "_id" { + // 特殊处理 _id + if isFalse(include) { + // 排除 _id + } else { + result["_id"] = data["_id"] + } + continue + } + + if isTrue(include) { + // 包含字段 + result[field] = getNestedValue(data, field) + } else if isFalse(include) { + // 排除字段(在包含模式下不处理) + continue + } else { + // 表达式 + result[field] = e.evaluateExpression(data, include) + 
} + } + + return result +} + +// executeLimit 执行 $limit 阶段 +func (e *StreamAggregationEngine) executeLimit(spec interface{}, docs []types.Document) ([]types.Document, error) { + // 从 aggregate.go 复制的实现 + limit := 0 + switch l := spec.(type) { + case int: + limit = l + case int64: + limit = int(l) + case float64: + limit = int(l) + } + + if limit <= 0 || limit >= len(docs) { + return docs, nil + } + + return docs[:limit], nil +} + +// executeSkip 执行 $skip 阶段 +func (e *StreamAggregationEngine) executeSkip(spec interface{}, docs []types.Document) ([]types.Document, error) { + // 从 aggregate.go 复制的实现 + skip := 0 + switch s := spec.(type) { + case int: + skip = s + case int64: + skip = int(s) + case float64: + skip = int(s) + } + + if skip <= 0 { + return docs, nil + } + + if skip >= len(docs) { + return []types.Document{}, nil + } + + return docs[skip:], nil +} + +// executeUnwind 执行 $unwind 阶段 +func (e *StreamAggregationEngine) executeUnwind(spec interface{}, docs []types.Document) ([]types.Document, error) { + // 从 aggregate.go 复制的实现 + var path string + var preserveNull bool + + switch s := spec.(type) { + case string: + path = s + case map[string]interface{}: + if p, ok := s["path"].(string); ok { + path = p + } + if pn, ok := s["preserveNullAndEmptyArrays"].(bool); ok { + preserveNull = pn + } + } + + if path == "" || path[0] != '$' { + return docs, nil + } + + fieldPath := path[1:] + var results []types.Document + + for _, doc := range docs { + arr := getNestedValue(doc.Data, fieldPath) + + if arr == nil { + if preserveNull { + results = append(results, doc) + } + continue + } + + array, ok := arr.([]interface{}) + if !ok || len(array) == 0 { + if preserveNull { + results = append(results, doc) + } + continue + } + + for _, item := range array { + newData := deepCopyMap(doc.Data) + setNestedValue(newData, fieldPath, item) + results = append(results, types.Document{ + ID: doc.ID, + Data: newData, + }) + } + } + + return results, nil +} + +// executeSort 执行 $sort 阶段 +func (e *StreamAggregationEngine) executeSort(spec interface{}, docs []types.Document) ([]types.Document, error) { + // 从 aggregate.go 复制的实现 + sortSpec, ok := spec.(map[string]interface{}) + if !ok { + return docs, nil + } + + // 转换为排序字段映射 + sortFields := make(map[string]int) + for field, direction := range sortSpec { + dir := 1 + switch d := direction.(type) { + case int: + dir = d + case int64: + dir = int(d) + case float64: + dir = int(d) + } + sortFields[field] = dir + } + + // 创建可排序的副本 + sorted := make([]types.Document, len(docs)) + copy(sorted, docs) + + sort.Slice(sorted, func(i, j int) bool { + return e.compareDocs(sorted[i], sorted[j], sortFields) + }) + + return sorted, nil +} + +// compareDocs 比较两个文档 +func (e *StreamAggregationEngine) compareDocs(a, b types.Document, sortFields map[string]int) bool { + for field, dir := range sortFields { + valA := getNestedValue(a.Data, field) + valB := getNestedValue(b.Data, field) + + cmp := compareValues(valA, valB) + if cmp != 0 { + if dir < 0 { + return cmp > 0 + } + return cmp < 0 + } + } + return false +}
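
A minimal sketch of how a caller might page through a collection with the FindPage/PageResult API introduced above, assuming an already-connected adapter and an existing "users" table; the loadAllPages helper, the collection name, and the "gomog.db" DSN are illustrative assumptions, not part of this change.

package main

import (
	"context"
	"log"

	"git.kingecg.top/kingecg/gomog/internal/database"
	"git.kingecg.top/kingecg/gomog/internal/database/sqlite"
	"git.kingecg.top/kingecg/gomog/pkg/types"
)

// loadAllPages pages through a collection until PageResult.HasMore is false.
func loadAllPages(ctx context.Context, adapter database.DatabaseAdapter, collection string) ([]types.Document, error) {
	const pageSize = 1000
	var all []types.Document
	for page := 0; ; page++ {
		// FindPage maps (skip, limit) onto the dialect-specific paging query
		// (LIMIT/OFFSET for SQLite and Postgres, ROWNUM windows for DM8).
		result, err := adapter.FindPage(ctx, collection, page*pageSize, pageSize)
		if err != nil {
			return nil, err
		}
		all = append(all, result.Documents...)
		if !result.HasMore {
			return all, nil
		}
	}
}

func main() {
	adapter := sqlite.NewSQLiteAdapter()
	if err := adapter.Connect(context.Background(), "gomog.db"); err != nil {
		log.Fatal(err)
	}
	docs, err := loadAllPages(context.Background(), adapter, "users")
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("loaded %d documents", len(docs))
}

This mirrors what MemoryStore.LoadCollectionPage does internally; the HasMore probe used by the base, SQLite, and Postgres adapters (SELECT 1 ... LIMIT 1 OFFSET skip+limit) avoids a full COUNT(*) per page at the cost of one extra query.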