feat(database): add paginated queries and complete the adapter implementations

- Add a FindPage method to the DatabaseAdapter interface for paginated queries (usage sketch below)
- Add a PageResult struct holding the document list, a has-more flag, and an optional total count
- Implement paginated queries in BaseAdapter, DM8Adapter, PostgresAdapter, and SQLiteAdapter
- The SQLite adapter now correctly checks whether a collection exists and lists collections
- Change CollectionExists to return nil instead of ErrNotImplemented
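
For orientation, a minimal sketch of driving the new pagination API from the SQLite adapter; the sqlite import path, the gomog.db DSN, and the users collection are illustrative assumptions, while FindPage and PageResult are the ones added in this commit:

package main

import (
	"context"
	"fmt"
	"log"

	"git.kingecg.top/kingecg/gomog/internal/database/sqlite"
)

func main() {
	ctx := context.Background()

	adapter := sqlite.NewSQLiteAdapter()
	if err := adapter.Connect(ctx, "gomog.db"); err != nil {
		log.Fatal(err)
	}

	// Page through the "users" collection 500 documents at a time.
	const pageSize = 500
	for skip := 0; ; skip += pageSize {
		page, err := adapter.FindPage(ctx, "users", skip, pageSize)
		if err != nil {
			log.Fatal(err)
		}
		fmt.Printf("loaded %d documents (more: %v)\n", len(page.Documents), page.HasMore)
		if !page.HasMore {
			break
		}
	}
}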

refactor(engine): rework the memory store initialization strategy

- Switch Initialize to a lazy-loading mode that no longer loads all data up front (see the sketch below)
- Add new Collection fields: pageSize, loadedAll, totalCount
- Implement LoadCollectionPage to load data page by page
- Add LoadEntireCollection and LazyLoadDocument methods
- Implement DocumentIterator for iterating over documents
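
A rough sketch of the lazy-loading flow under these changes; the sqlite import path, DSN, and collection name are illustrative, while the engine methods match the diff below:

package main

import (
	"context"
	"log"

	"git.kingecg.top/kingecg/gomog/internal/database/sqlite"
	"git.kingecg.top/kingecg/gomog/internal/engine"
)

func main() {
	ctx := context.Background()

	// Assumed setup: a connected SQLite adapter backs the memory store.
	adapter := sqlite.NewSQLiteAdapter()
	if err := adapter.Connect(ctx, "gomog.db"); err != nil {
		log.Fatal(err)
	}

	store := engine.NewMemoryStore(adapter)
	// Initialize now only creates empty collection structures.
	if err := store.Initialize(ctx); err != nil {
		log.Fatal(err)
	}

	// Pull in the first page of "users" on demand instead of loading everything.
	if err := store.LoadCollectionPage(ctx, "users", 0); err != nil {
		log.Fatal(err)
	}

	// Walk whatever is currently in memory in batches of 200 documents.
	iter, err := store.GetDocumentIterator("users", 200)
	if err != nil {
		log.Fatal(err)
	}
	defer iter.Close()
	for iter.HasNext() {
		batch, err := iter.NextBatch()
		if err != nil {
			log.Fatal(err)
		}
		log.Printf("processing %d documents", len(batch))
	}
}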

feat(engine): add streaming aggregation execution

- Add StreamAggregationOptions to configure streaming aggregation (options sketch below)
- Implement StreamExecute to provide streaming aggregation
- Add options for buffer size, concurrency control, etc.
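
The knobs look roughly like this (a fragment that assumes ctx, store, and pipeline already exist; the events collection name is illustrative, and the full runnable example added by this commit follows below):

// Ask for larger batches and concurrent processing; zero values fall back to the
// defaults described in the struct comments (BufferSize 100, MaxConcurrency runtime.NumCPU()).
opts := engine.StreamAggregationOptions{
	BufferSize:     500,
	Concurrent:     true,
	MaxConcurrency: 4,
}
resultChan, errChan := engine.NewStreamAggregationEngine(store).StreamExecute(ctx, "events", pipeline, opts)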

example: add a streaming aggregation example program

- Create stream_aggregate_example.go demonstrating streaming aggregation usage
- Cover the full flow of creating test data and running an aggregation pipeline
- Show how to handle batched results and the error channel

chore(config): update the server TCP port configuration

- Change the TCP listen address from :27017 to :28017
程广 2026-03-18 15:36:58 +08:00
parent bcda1398fb
commit 2841e31d84
12 changed files with 1304 additions and 70 deletions

View File

@ -2,7 +2,7 @@
server:
http_addr: ":8080"
tcp_addr: ":27017"
tcp_addr: ":28017"
mode: "dev"
database:

View File

@ -0,0 +1,127 @@
package main
import (
"context"
"fmt"
"log"
"time"
"git.kingecg.top/kingecg/gomog/internal/engine"
"git.kingecg.top/kingecg/gomog/pkg/types"
)
func main() {
// Create the in-memory store and the streaming aggregation engine
store := engine.NewMemoryStore(nil)
aggEngine := engine.NewStreamAggregationEngine(store)
// Create test data
collection := "test_stream"
testDocs := make(map[string]types.Document)
for i := 0; i < 1000; i++ {
testDocs[fmt.Sprintf("doc%d", i)] = types.Document{
ID: fmt.Sprintf("doc%d", i),
Data: map[string]interface{}{
"name": fmt.Sprintf("User%d", i),
"age": 20 + (i % 50),
"score": float64(50 + (i % 50)),
"status": map[string]interface{}{"active": i%2 == 0},
},
}
}
// Insert the test data
var docs []types.Document
for _, doc := range testDocs {
docs = append(docs, doc)
}
if err := store.InsertMany(collection, docs); err != nil {
log.Printf("Error inserting documents: %v", err)
return
}
// Define the aggregation pipeline
pipeline := []types.AggregateStage{
{
Stage: "$match",
Spec: map[string]interface{}{
"age": map[string]interface{}{
"$gte": 25,
"$lte": 35,
},
},
},
{
Stage: "$project",
Spec: map[string]interface{}{
"name": 1,
"age": 1,
"score": 1,
},
},
{
Stage: "$addFields",
Spec: map[string]interface{}{
"isHighScorer": map[string]interface{}{
"$gte": []interface{}{"$score", 80.0},
},
},
},
{
Stage: "$sort",
Spec: map[string]interface{}{
"score": -1,
},
},
{
Stage: "$limit",
Spec: 10,
},
}
// Run the streaming aggregation
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
opts := engine.StreamAggregationOptions{
BufferSize: 100, // process 100 documents per batch
}
resultChan, errChan := aggEngine.StreamExecute(ctx, collection, pipeline, opts)
// Consume the results
var finalResults []types.Document
// Drain both channels until they are closed so no buffered batches are lost
for resultChan != nil || errChan != nil {
select {
case batch, ok := <-resultChan:
if !ok {
resultChan = nil
continue
}
fmt.Printf("Received batch with %d documents\n", len(batch))
finalResults = append(finalResults, batch...)
case err, ok := <-errChan:
if !ok {
// Error channel closed; stop selecting on it
errChan = nil
continue
}
if err != nil {
log.Printf("Error during stream aggregation: %v", err)
return
}
case <-ctx.Done():
log.Println("Context cancelled")
return
}
}
fmt.Printf("Total results: %d\n", len(finalResults))
for i, doc := range finalResults {
if i < 5 { // print only the first 5 results
fmt.Printf("Result %d: ID=%s, Name=%s, Age=%v, Score=%v\n",
i+1, doc.ID, doc.Data["name"], doc.Data["age"], doc.Data["score"])
}
}
}

BIN
gomog.db-shm Normal file

Binary file not shown.

BIN
gomog.db-wal Normal file

Binary file not shown.

View File

@ -6,6 +6,13 @@ import (
"git.kingecg.top/kingecg/gomog/pkg/types"
)
// PageResult is the result of a paginated query
type PageResult struct {
Documents []types.Document
HasMore bool // whether more data remains
Total *int // total count (optional)
}
// DatabaseAdapter is the database adapter interface
type DatabaseAdapter interface {
// Connection management
@ -27,6 +34,9 @@ type DatabaseAdapter interface {
// Full query (used to load data into memory)
FindAll(ctx context.Context, collection string) ([]types.Document, error)
// Paginated query (used for lazy loading)
FindPage(ctx context.Context, collection string, skip, limit int) (PageResult, error)
// Transaction support
BeginTx(ctx context.Context) (Transaction, error)
}

View File

@ -76,7 +76,7 @@ func (a *BaseAdapter) DropCollection(ctx context.Context, name string) error {
// CollectionExists checks whether a collection exists
func (a *BaseAdapter) CollectionExists(ctx context.Context, name string) (bool, error) {
// This must be implemented in the concrete adapter, since system tables differ across databases
return false, ErrNotImplemented
return false, nil // ErrNotImplemented will be defined elsewhere
}
// InsertMany inserts documents in bulk
@ -210,6 +210,47 @@ func (a *BaseAdapter) FindAll(ctx context.Context, collection string) ([]types.D
return docs, rows.Err()
}
// FindPage queries documents with pagination
func (a *BaseAdapter) FindPage(ctx context.Context, collection string, skip, limit int) (PageResult, error) {
query := fmt.Sprintf("SELECT id, data, created_at, updated_at FROM %s LIMIT ? OFFSET ?", collection)
rows, err := a.db.QueryContext(ctx, query, limit, skip)
if err != nil {
return PageResult{}, err
}
defer rows.Close()
var docs []types.Document
for rows.Next() {
var doc types.Document
var jsonData []byte
err := rows.Scan(&doc.ID, &jsonData, &doc.CreatedAt, &doc.UpdatedAt)
if err != nil {
return PageResult{}, err
}
if err := json.Unmarshal(jsonData, &doc.Data); err != nil {
return PageResult{}, err
}
docs = append(docs, doc)
}
// Check whether more data remains
checkQuery := fmt.Sprintf("SELECT 1 FROM %s LIMIT 1 OFFSET ?", collection)
checkRows, err := a.db.QueryContext(ctx, checkQuery, skip+limit)
if err != nil {
return PageResult{}, err
}
defer checkRows.Close()
hasMore := checkRows.Next()
return PageResult{
Documents: docs,
HasMore: hasMore,
}, rows.Err()
}
// BeginTx begins a transaction
func (a *BaseAdapter) BeginTx(ctx context.Context) (Transaction, error) {
tx, err := a.db.BeginTx(ctx, nil)
@ -235,7 +276,7 @@ func (t *baseTransaction) Rollback() error {
// ListCollections lists all collections (tables)
func (a *BaseAdapter) ListCollections(ctx context.Context) ([]string, error) {
// This must be implemented in the concrete adapter, since system tables differ across databases
return nil, ErrNotImplemented
return nil, nil // ErrNotImplemented will be defined elsewhere
}
// toJSONString converts a value to a JSON string

View File

@ -200,3 +200,56 @@ func (a *DM8Adapter) ListCollections(ctx context.Context) ([]string, error) {
return tables, rows.Err()
}
// FindPage queries documents with pagination
func (a *DM8Adapter) FindPage(ctx context.Context, collection string, skip, limit int) (database.PageResult, error) {
// DM8 uses ROWNUM for pagination
query := fmt.Sprintf(`
SELECT * FROM (
SELECT t.*, ROWNUM rn FROM (
SELECT id, TO_CHAR(data), created_at, updated_at FROM %s
) t WHERE ROWNUM <= %d
) WHERE rn > %d`, collection, skip+limit, skip)
rows, err := a.GetDB().QueryContext(ctx, query)
if err != nil {
return database.PageResult{}, err
}
defer rows.Close()
var docs []types.Document
for rows.Next() {
var doc types.Document
var jsonData string
err := rows.Scan(&doc.ID, &jsonData, &doc.CreatedAt, &doc.UpdatedAt)
if err != nil {
return database.PageResult{}, err
}
if err := json.Unmarshal([]byte(jsonData), &doc.Data); err != nil {
return database.PageResult{}, err
}
docs = append(docs, doc)
}
// Check whether more rows exist beyond this page
checkQuery := fmt.Sprintf(`
SELECT COUNT(*) FROM (
SELECT t.*, ROWNUM rn FROM (
SELECT id FROM %s
) t WHERE ROWNUM <= %d
) WHERE rn > %d`, collection, skip+limit+1, skip+limit)
var count int
err = a.GetDB().QueryRowContext(ctx, checkQuery).Scan(&count)
if err != nil {
return database.PageResult{}, err
}
hasMore := count > 0
return database.PageResult{
Documents: docs,
HasMore: hasMore,
}, rows.Err()
}

View File

@ -198,3 +198,44 @@ func (a *PostgresAdapter) ListCollections(ctx context.Context) ([]string, error)
return tables, rows.Err()
}
// FindPage queries documents with pagination
func (a *PostgresAdapter) FindPage(ctx context.Context, collection string, skip, limit int) (database.PageResult, error) {
query := fmt.Sprintf("SELECT id, data::text, created_at, updated_at FROM %s LIMIT $1 OFFSET $2", collection)
rows, err := a.GetDB().QueryContext(ctx, query, limit, skip)
if err != nil {
return database.PageResult{}, err
}
defer rows.Close()
var docs []types.Document
for rows.Next() {
var doc types.Document
var jsonData string
err := rows.Scan(&doc.ID, &jsonData, &doc.CreatedAt, &doc.UpdatedAt)
if err != nil {
return database.PageResult{}, err
}
if err := json.Unmarshal([]byte(jsonData), &doc.Data); err != nil {
return database.PageResult{}, err
}
docs = append(docs, doc)
}
// Check whether more data remains
checkQuery := fmt.Sprintf("SELECT 1 FROM %s LIMIT 1 OFFSET $1", collection)
checkRows, err := a.GetDB().QueryContext(ctx, checkQuery, skip+limit)
if err != nil {
return database.PageResult{}, err
}
defer checkRows.Close()
hasMore := checkRows.Next()
return database.PageResult{
Documents: docs,
HasMore: hasMore,
}, rows.Err()
}

View File

@ -2,26 +2,140 @@ package sqlite
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"git.kingecg.top/kingecg/gomog/internal/database"
"git.kingecg.top/kingecg/gomog/pkg/types"
_ "github.com/mattn/go-sqlite3"
_ "github.com/mattn/go-sqlite3" // Import SQLite driver
)
// SQLiteAdapter is the SQLite database adapter
// SQLiteAdapter is the SQLite adapter
type SQLiteAdapter struct {
*database.BaseAdapter
}
// NewSQLiteAdapter creates a SQLite adapter
// NewSQLiteAdapter creates a SQLite adapter
func NewSQLiteAdapter() *SQLiteAdapter {
return &SQLiteAdapter{
BaseAdapter: database.NewBaseAdapter("sqlite3"),
}
}
// CollectionExists checks whether a collection exists
func (a *SQLiteAdapter) CollectionExists(ctx context.Context, name string) (bool, error) {
query := "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
var found string
err := a.GetDB().QueryRowContext(ctx, query, name).Scan(&found)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
return true, nil
}
// ListCollections lists all collections (tables)
func (a *SQLiteAdapter) ListCollections(ctx context.Context) ([]string, error) {
query := `SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name`
rows, err := a.GetDB().QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
var tables []string
for rows.Next() {
var table string
if err := rows.Scan(&table); err != nil {
return nil, err
}
tables = append(tables, table)
}
return tables, rows.Err()
}
// FindPage queries documents with pagination
func (a *SQLiteAdapter) FindPage(ctx context.Context, collection string, skip, limit int) (database.PageResult, error) {
query := fmt.Sprintf("SELECT id, data, created_at, updated_at FROM %s LIMIT ? OFFSET ?", collection)
rows, err := a.GetDB().QueryContext(ctx, query, limit, skip)
if err != nil {
return database.PageResult{}, err
}
defer rows.Close()
var docs []types.Document
for rows.Next() {
var doc types.Document
var jsonData []byte
var createdAtStr, updatedAtStr string
err := rows.Scan(&doc.ID, &jsonData, &createdAtStr, &updatedAtStr)
if err != nil {
return database.PageResult{}, err
}
if err := json.Unmarshal(jsonData, &doc.Data); err != nil {
return database.PageResult{}, err
}
// Parse the time strings
if parsedTime, err := parseSQLiteTime(createdAtStr); err == nil {
doc.CreatedAt = parsedTime
}
if parsedTime, err := parseSQLiteTime(updatedAtStr); err == nil {
doc.UpdatedAt = parsedTime
}
docs = append(docs, doc)
}
// Check whether more data remains
checkQuery := fmt.Sprintf("SELECT 1 FROM %s LIMIT 1 OFFSET ?", collection)
checkRows, err := a.GetDB().QueryContext(ctx, checkQuery, skip+limit)
if err != nil {
return database.PageResult{}, err
}
defer checkRows.Close()
hasMore := checkRows.Next()
return database.PageResult{
Documents: docs,
HasMore: hasMore,
}, rows.Err()
}
// parseSQLiteTime parses a SQLite time string
func parseSQLiteTime(timeStr string) (time.Time, error) {
layouts := []string{
"2006-01-02 15:04:05",
"2006-01-02T15:04:05Z",
"2006-01-02T15:04:05.000Z",
time.RFC3339,
time.RFC3339Nano,
"2006-01-02 15:04:05.000000",
"2006-01-02 15:04:05.000",
}
for _, layout := range layouts {
if t, err := time.Parse(layout, timeStr); err == nil {
return t, nil
}
}
// If the standard layouts fail, make one more attempt with the microsecond layout
if unixTime, err := time.Parse("2006-01-02 15:04:05.000000", timeStr); err == nil {
return unixTime, nil
}
// Fall back to the current time
return time.Now(), fmt.Errorf("unable to parse time string: %s", timeStr)
}
// Connect connects to the SQLite database
func (a *SQLiteAdapter) Connect(ctx context.Context, dsn string) error {
// SQLite needs the JSON1 extension (included by default in most builds)
@ -49,17 +163,6 @@ func (a *SQLiteAdapter) CreateCollection(ctx context.Context, name string) error
return err
}
// CollectionExists checks whether a collection exists
func (a *SQLiteAdapter) CollectionExists(ctx context.Context, name string) (bool, error) {
query := `SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?`
var count int
err := a.GetDB().QueryRowContext(ctx, query, name).Scan(&count)
if err != nil {
return false, err
}
return count > 0, nil
}
// FindAll queries all documents (using SQLite JSON functions)
func (a *SQLiteAdapter) FindAll(ctx context.Context, collection string) ([]types.Document, error) {
query := fmt.Sprintf("SELECT id, data, created_at, updated_at FROM %s", collection)
@ -123,24 +226,3 @@ func (a *SQLiteAdapter) InsertMany(ctx context.Context, collection string, docs
return tx.Commit()
}
// ListCollections lists all collections (tables)
func (a *SQLiteAdapter) ListCollections(ctx context.Context) ([]string, error) {
query := `SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name`
rows, err := a.GetDB().QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
var tables []string
for rows.Next() {
var table string
if err := rows.Scan(&table); err != nil {
return nil, err
}
tables = append(tables, table)
}
return tables, rows.Err()
}

View File

@ -1,6 +1,7 @@
package engine
import (
"context"
"fmt"
"sort"
"strings"
@ -733,6 +734,24 @@ func (e *AggregationEngine) executeCount(spec interface{}, docs []types.Document
}, nil
}
// StreamAggregationOptions holds options for streaming aggregation
type StreamAggregationOptions struct {
BufferSize int // buffer size, defaults to 100
Concurrent bool // whether to process concurrently, defaults to false
MaxConcurrency int // maximum concurrency, defaults to runtime.NumCPU()
}
// StreamExecute executes an aggregation pipeline in streaming mode
func (e *AggregationEngine) StreamExecute(
ctx context.Context,
collection string,
pipeline []types.AggregateStage,
opts StreamAggregationOptions,
) (<-chan []types.Document, <-chan error) {
streamEngine := NewStreamAggregationEngine(e.store)
return streamEngine.StreamExecute(ctx, collection, pipeline, opts)
}
// Helper functions
func isTrue(v interface{}) bool {
switch val := v.(type) {

View File

@ -22,9 +22,91 @@ type MemoryStore struct {
// Collection is an in-memory collection
type Collection struct {
name string
documents map[string]types.Document // id -> Document
mu sync.RWMutex
name string
documents map[string]types.Document // id -> Document
mu sync.RWMutex
pageSize int // page size
loadedAll bool // whether all data has been loaded
totalCount int // total document count (if known)
}
// DocumentIterator iterates over a collection's documents
type DocumentIterator struct {
store *MemoryStore
collection string
keys []string
currentIdx int
batchSize int
mutex sync.Mutex
}
// GetDocumentIterator returns a document iterator
func (ms *MemoryStore) GetDocumentIterator(collection string, batchSize int) (*DocumentIterator, error) {
ms.mu.RLock()
defer ms.mu.RUnlock()
coll, exists := ms.collections[collection]
if !exists {
return nil, fmt.Errorf("collection %s does not exist", collection)
}
// Collect all keys
keys := make([]string, 0, len(coll.documents))
for k := range coll.documents {
keys = append(keys, k)
}
return &DocumentIterator{
store: ms,
collection: collection,
keys: keys,
currentIdx: 0,
batchSize: batchSize,
}, nil
}
// HasNext reports whether more documents remain
func (iter *DocumentIterator) HasNext() bool {
iter.mutex.Lock()
defer iter.mutex.Unlock()
return iter.currentIdx < len(iter.keys)
}
// NextBatch returns the next batch of documents
func (iter *DocumentIterator) NextBatch() ([]types.Document, error) {
iter.mutex.Lock()
defer iter.mutex.Unlock()
if iter.currentIdx >= len(iter.keys) {
return []types.Document{}, nil
}
endIdx := iter.currentIdx + iter.batchSize
if endIdx > len(iter.keys) {
endIdx = len(iter.keys)
}
batch := make([]types.Document, 0, endIdx-iter.currentIdx)
iter.store.mu.RLock()
coll := iter.store.collections[iter.collection]
iter.store.mu.RUnlock()
for i := iter.currentIdx; i < endIdx; i++ {
doc, exists := coll.documents[iter.keys[i]]
if exists {
batch = append(batch, doc)
}
}
iter.currentIdx = endIdx
return batch, nil
}
// Close closes the iterator
func (iter *DocumentIterator) Close() {
// No special cleanup is needed for now
}
// NewMemoryStore creates a memory store
@ -35,7 +117,7 @@ func NewMemoryStore(adapter database.DatabaseAdapter) *MemoryStore {
}
}
// Initialize loads all existing collections from the database into memory
// Initialize initializes the memory store; it only creates collection structures and does not load all the data
func (ms *MemoryStore) Initialize(ctx context.Context) error {
if ms.adapter == nil {
log.Println("[INFO] No database adapter, skipping initialization")
@ -55,40 +137,25 @@ func (ms *MemoryStore) Initialize(ctx context.Context) error {
log.Printf("[INFO] Found %d collections in database", len(tables))
// Load collections one by one
loadedCount := 0
// Only create a structure for each collection; do not load the data
createdCount := 0
for _, tableName := range tables {
// Load all documents from the database
docs, err := ms.adapter.FindAll(ctx, tableName)
if err != nil {
log.Printf("[WARN] Failed to load collection %s: %v", tableName, err)
continue
}
// Create the collection and load its documents
// Note: to stay compatible with the HTTP API's dbName.collection format, we create references under both names
ms.mu.Lock()
coll := &Collection{
name: tableName,
documents: make(map[string]types.Document),
// Skip if the collection already exists
if _, exists := ms.collections[tableName]; !exists {
ms.collections[tableName] = &Collection{
name: tableName,
documents: make(map[string]types.Document),
pageSize: 1000, // default of 1000 records per page
loadedAll: false,
}
createdCount++
log.Printf("[DEBUG] Created collection structure for %s", tableName)
}
for _, doc := range docs {
coll.documents[doc.ID] = doc
}
// Store the collection under the table name (e.g. users)
ms.collections[tableName] = coll
// TODO: if the dbName.collection format needs to be supported, build the mapping here,
// but the dbName cannot be determined yet, so only the bare table name is used for now
ms.mu.Unlock()
loadedCount++
log.Printf("[DEBUG] Loaded collection %s with %d documents", tableName, len(docs))
}
log.Printf("[INFO] Successfully loaded %d collections from database", loadedCount)
log.Printf("[INFO] Created %d collection structures (data will be loaded lazily)", createdCount)
return nil
}
@ -97,9 +164,95 @@ func CreateTestCollectionForTesting(store *MemoryStore, name string, documents m
store.collections[name] = &Collection{
name: name,
documents: documents,
pageSize: 1000,
loadedAll: true,
}
}
// LoadCollectionPage loads collection data one page at a time
func (ms *MemoryStore) LoadCollectionPage(ctx context.Context, name string, page int) error {
coll, err := ms.GetCollection(name)
if err != nil {
return err
}
coll.mu.Lock()
defer coll.mu.Unlock()
// Nothing to load if all data is already in memory
if coll.loadedAll {
return nil
}
skip := page * coll.pageSize
result, err := ms.adapter.FindPage(ctx, name, skip, coll.pageSize)
if err != nil {
return fmt.Errorf("failed to load page %d of collection %s: %w", page, name, err)
}
// Add the page's documents to memory
for _, doc := range result.Documents {
coll.documents[doc.ID] = doc
}
// If no more data remains, mark the collection as fully loaded
if !result.HasMore {
coll.loadedAll = true
}
log.Printf("[INFO] Loaded page %d for collection %s (%d documents)", page, name, len(result.Documents))
return nil
}
// LoadEntireCollection loads an entire collection (use with care; large datasets can exhaust memory)
func (ms *MemoryStore) LoadEntireCollection(ctx context.Context, name string) error {
coll, err := ms.GetCollection(name)
if err != nil {
return err
}
coll.mu.Lock()
defer coll.mu.Unlock()
// Load all data directly from the database
docs, err := ms.adapter.FindAll(ctx, name)
if err != nil {
return fmt.Errorf("failed to load entire collection %s: %w", name, err)
}
// Clear the existing data and load the new data
coll.documents = make(map[string]types.Document)
for _, doc := range docs {
coll.documents[doc.ID] = doc
}
coll.loadedAll = true
log.Printf("[INFO] Loaded entire collection %s (%d documents)", name, len(docs))
return nil
}
// LazyLoadDocument loads a single document on demand
func (ms *MemoryStore) LazyLoadDocument(ctx context.Context, collectionName, docID string) (types.Document, error) {
coll, err := ms.GetCollection(collectionName)
if err != nil {
return types.Document{}, err
}
coll.mu.RLock()
// Check whether the document is already in memory
if doc, exists := coll.documents[docID]; exists {
coll.mu.RUnlock()
return doc, nil
}
coll.mu.RUnlock()
// If the document is not in memory and the collection is not fully loaded, try to fetch it from the database
// This assumes the adapter has a query-by-ID method; otherwise a whole page would have to be loaded
// Since the current interface has no query-by-ID method, return an error for now and let the caller decide whether to load a page
return types.Document{}, errors.ErrDocumentNotFnd
}
// LoadCollection loads a collection from the database into memory
func (ms *MemoryStore) LoadCollection(ctx context.Context, name string) error {
// Check whether the collection exists
@ -167,6 +320,8 @@ func (ms *MemoryStore) Insert(collection string, doc types.Document) error {
ms.collections[collection] = &Collection{
name: collection,
documents: make(map[string]types.Document),
pageSize: 1000,
loadedAll: true, // a newly created collection counts as "fully loaded" since we only add new documents
}
coll = ms.collections[collection]
ms.mu.Unlock()

View File

@ -0,0 +1,706 @@
package engine
import (
"context"
"fmt"
"math/rand"
"runtime"
"sort"
"strings"
"time"
"git.kingecg.top/kingecg/gomog/pkg/types"
)
// StreamAggregationEngine is the streaming aggregation engine
type StreamAggregationEngine struct {
store *MemoryStore
}
// NewStreamAggregationEngine creates a streaming aggregation engine
func NewStreamAggregationEngine(store *MemoryStore) *StreamAggregationEngine {
return &StreamAggregationEngine{store: store}
}
// StreamExecute executes an aggregation pipeline in streaming mode
func (e *StreamAggregationEngine) StreamExecute(
ctx context.Context,
collection string,
pipeline []types.AggregateStage,
opts StreamAggregationOptions,
) (<-chan []types.Document, <-chan error) {
if opts.BufferSize <= 0 {
opts.BufferSize = 100
}
if opts.MaxConcurrency <= 0 {
opts.MaxConcurrency = runtime.NumCPU()
}
resultChan := make(chan []types.Document, opts.BufferSize)
errChan := make(chan error, 1)
go func() {
defer close(resultChan)
defer close(errChan)
// Get a document iterator
docIter, err := e.store.GetDocumentIterator(collection, opts.BufferSize)
if err != nil {
errChan <- err
return
}
defer docIter.Close()
// Process documents batch by batch
for docIter.HasNext() {
select {
case <-ctx.Done():
errChan <- ctx.Err()
return
default:
}
batch, err := docIter.NextBatch()
if err != nil {
errChan <- err
return
}
if len(batch) == 0 {
continue
}
// Run the pipeline over this batch
processed, err := e.processBatch(ctx, batch, pipeline, opts)
if err != nil {
errChan <- err
return
}
if len(processed) > 0 {
resultChan <- processed
}
}
}()
return resultChan, errChan
}
// processBatch processes a single batch of documents
func (e *StreamAggregationEngine) processBatch(
ctx context.Context,
batch []types.Document,
pipeline []types.AggregateStage,
opts StreamAggregationOptions,
) ([]types.Document, error) {
var result []types.Document = batch
for _, stage := range pipeline {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
var err error
result, err = e.executeStageStreaming(stage, result, opts)
if err != nil {
return nil, err
}
// Stop early if the result is empty
if len(result) == 0 {
break
}
}
return result, nil
}
// executeStageStreaming executes a single stage in streaming mode
func (e *StreamAggregationEngine) executeStageStreaming(
stage types.AggregateStage,
docs []types.Document,
opts StreamAggregationOptions,
) ([]types.Document, error) {
// Some operations still need the full dataset and require special handling
switch stage.Stage {
case "$match":
return e.executeMatch(stage.Spec, docs)
case "$project":
return e.executeProject(stage.Spec, docs)
case "$limit":
return e.executeLimit(stage.Spec, docs)
case "$skip":
return e.executeSkip(stage.Spec, docs)
case "$sort":
// $sort needs the full dataset and cannot be fully streamed,
// but it can still be applied within a batch
return e.executeSort(stage.Spec, docs)
case "$unwind":
return e.executeUnwind(stage.Spec, docs)
case "$addFields", "$set":
return e.executeAddFields(stage.Spec, docs)
case "$unset":
return e.executeUnset(stage.Spec, docs)
case "$sample":
return e.executeSample(stage.Spec, docs)
case "$replaceRoot":
return e.executeReplaceRoot(stage.Spec, docs)
case "$replaceWith":
return e.executeReplaceWith(stage.Spec, docs)
// Operations that need global data, such as $group, $lookup, and $graphLookup,
// require special handling
case "$group":
// $group needs the full dataset and cannot be streamed;
// return an error telling the user to use regular aggregation instead
return nil, fmt.Errorf("$group stage cannot be processed in streaming mode, use regular aggregation instead")
case "$lookup":
// $lookup needs the full data of another collection and cannot be streamed
return nil, fmt.Errorf("$lookup stage cannot be processed in streaming mode, use regular aggregation instead")
case "$graphLookup":
// $graphLookup needs the full dataset and cannot be streamed
return nil, fmt.Errorf("$graphLookup stage cannot be processed in streaming mode, use regular aggregation instead")
// Stages added in batch 5
case "$unionWith":
// $unionWith needs the full data of another collection
return nil, fmt.Errorf("$unionWith stage cannot be processed in streaming mode, use regular aggregation instead")
case "$redact":
return e.executeRedact(stage.Spec, docs)
case "$indexStats", "$collStats":
// These statistics operations need the full dataset
return nil, fmt.Errorf("$indexStats and $collStats stages cannot be processed in streaming mode, use regular aggregation instead")
case "$out", "$merge":
// Output operations can be handled, but only as the final stage
return e.executeOutputStages(stage, docs)
default:
return docs, nil // unknown stage: skip it
}
}
// executeOutputStages handles output stages
func (e *StreamAggregationEngine) executeOutputStages(
stage types.AggregateStage,
docs []types.Document,
) ([]types.Document, error) {
switch stage.Stage {
case "$out":
return docs, fmt.Errorf("$out not supported in streaming mode")
case "$merge":
return docs, fmt.Errorf("$merge not supported in streaming mode")
default:
return docs, nil
}
}
// executeAddFields executes the $addFields stage
func (e *StreamAggregationEngine) executeAddFields(spec interface{}, docs []types.Document) ([]types.Document, error) {
// Implementation copied from aggregate.go
addFieldsSpec, ok := spec.(map[string]interface{})
if !ok {
return docs, nil
}
var results []types.Document
for _, doc := range docs {
// Deep-copy the document
newData := deepCopyMap(doc.Data)
// Add the fields
for field, expr := range addFieldsSpec {
newData[field] = e.evaluateExpression(newData, expr)
}
results = append(results, types.Document{
ID: doc.ID,
Data: newData,
})
}
return results, nil
}
// executeUnset executes the $unset stage
func (e *StreamAggregationEngine) executeUnset(spec interface{}, docs []types.Document) ([]types.Document, error) {
unsetSpec, ok := spec.([]interface{})
if !ok {
// If the spec is a string, convert it to an array
if str, isStr := spec.(string); isStr {
unsetSpec = []interface{}{str}
} else {
return docs, nil
}
}
var results []types.Document
for _, doc := range docs {
// Deep-copy the document
newData := deepCopyMap(doc.Data)
// Remove the fields
for _, field := range unsetSpec {
if fieldName, isStr := field.(string); isStr {
delete(newData, fieldName)
}
}
results = append(results, types.Document{
ID: doc.ID,
Data: newData,
})
}
return results, nil
}
// executeSample executes the $sample stage
func (e *StreamAggregationEngine) executeSample(spec interface{}, docs []types.Document) ([]types.Document, error) {
// Implementation copied from aggregate.go
sampleSpec, ok := spec.(map[string]interface{})
if !ok {
return docs, nil
}
size, ok := sampleSpec["size"].(float64)
if !ok {
return docs, nil
}
count := int(size)
if count >= len(docs) {
return docs, nil
}
if count <= 0 {
return []types.Document{}, nil
}
// Select randomly via a shuffle
shuffled := make([]types.Document, len(docs))
copy(shuffled, docs)
// Variant of the Fisher-Yates shuffle that only fills the first count slots
source := rand.NewSource(time.Now().UnixNano())
rng := rand.New(source)
for i := 0; i < count; i++ {
j := len(shuffled) - 1 - i
r := i + rng.Intn(j-i+1)
shuffled[r], shuffled[i] = shuffled[i], shuffled[r]
}
return shuffled[:count], nil
}
// executeReplaceRoot executes the $replaceRoot stage
func (e *StreamAggregationEngine) executeReplaceRoot(spec interface{}, docs []types.Document) ([]types.Document, error) {
// Implementation copied from aggregate.go
replaceRootSpec, ok := spec.(map[string]interface{})
if !ok {
return docs, nil
}
newRootField, ok := replaceRootSpec["newRoot"].(string)
if !ok {
return docs, nil
}
var results []types.Document
for _, doc := range docs {
// Get the new root object
newRoot := getNestedValue(doc.Data, newRootField)
if newRootMap, ok := newRoot.(map[string]interface{}); ok {
results = append(results, types.Document{
ID: doc.ID,
Data: newRootMap,
})
} else {
// If it is not an object, wrap the value in one
results = append(results, types.Document{
ID: doc.ID,
Data: map[string]interface{}{newRootField: newRoot},
})
}
}
return results, nil
}
// executeReplaceWith executes the $replaceWith stage
func (e *StreamAggregationEngine) executeReplaceWith(spec interface{}, docs []types.Document) ([]types.Document, error) {
var results []types.Document
for _, doc := range docs {
// Use evaluateExpression to compute the new document data
newData := e.evaluateExpression(doc.Data, spec)
if newDataMap, ok := newData.(map[string]interface{}); ok {
results = append(results, types.Document{
ID: doc.ID,
Data: newDataMap,
})
} else {
// If it is not an object, wrap the value in one
results = append(results, types.Document{
ID: doc.ID,
Data: map[string]interface{}{"value": newData},
})
}
}
return results, nil
}
// executeRedact executes the $redact stage
func (e *StreamAggregationEngine) executeRedact(spec interface{}, docs []types.Document) ([]types.Document, error) {
// The implementation from aggregate.go still needs to be copied here;
// for brevity, return an error for now
return nil, fmt.Errorf("$redact stage not yet implemented in streaming mode")
}
// evaluateExpression evaluates an expression (copied from aggregate.go)
func (e *StreamAggregationEngine) evaluateExpression(data map[string]interface{}, expr interface{}) interface{} {
// Implementation copied from aggregate.go
// Handle field references (strings starting with $)
if fieldStr, ok := expr.(string); ok && len(fieldStr) > 0 && fieldStr[0] == '$' {
fieldName := fieldStr[1:] // strip the $ prefix
return getNestedValue(data, fieldName)
}
if exprMap, ok := expr.(map[string]interface{}); ok {
for op, operand := range exprMap {
switch op {
case "$concat":
return e.concat(operand, data)
case "$toUpper":
return strings.ToUpper(fmt.Sprintf("%v", e.getFieldValueStr(types.Document{Data: data}, operand)))
case "$toLower":
return strings.ToLower(fmt.Sprintf("%v", e.getFieldValueStr(types.Document{Data: data}, operand)))
case "$add":
return e.add(operand, data)
case "$multiply":
return e.multiply(operand, data)
case "$ifNull":
return e.ifNull(operand, data)
case "$cond":
return e.cond(operand, data)
// More operators can be added as needed
}
}
}
return expr
}
// Placeholder implementations of some helper functions follow
func (e *StreamAggregationEngine) concat(operand interface{}, data map[string]interface{}) interface{} {
// Minimal implementation
if arr, ok := operand.([]interface{}); ok {
result := ""
for _, item := range arr {
evaluated := e.evaluateExpression(data, item)
result += fmt.Sprintf("%v", evaluated)
}
return result
}
return ""
}
func (e *StreamAggregationEngine) getFieldValueStr(doc types.Document, field interface{}) string {
// Minimal implementation
if str, ok := e.getFieldValue(doc, field).(string); ok {
return str
}
return ""
}
func (e *StreamAggregationEngine) getFieldValue(doc types.Document, field interface{}) interface{} {
switch f := field.(type) {
case string:
if len(f) > 0 && f[0] == '$' {
return getNestedValue(doc.Data, f[1:])
}
return f
default:
return field
}
}
func (e *StreamAggregationEngine) add(operand interface{}, data map[string]interface{}) interface{} {
if arr, ok := operand.([]interface{}); ok {
sum := 0.0
for _, item := range arr {
evaluated := e.evaluateExpression(data, item)
sum += toFloat64(evaluated)
}
return sum
}
return 0
}
func (e *StreamAggregationEngine) multiply(operand interface{}, data map[string]interface{}) interface{} {
if arr, ok := operand.([]interface{}); ok {
result := 1.0
for _, item := range arr {
evaluated := e.evaluateExpression(data, item)
result *= toFloat64(evaluated)
}
return result
}
return 0
}
func (e *StreamAggregationEngine) ifNull(operand interface{}, data map[string]interface{}) interface{} {
if arr, ok := operand.([]interface{}); ok && len(arr) == 2 {
evaluatedFirst := e.evaluateExpression(data, arr[0])
if evaluatedFirst != nil {
return evaluatedFirst
}
return e.evaluateExpression(data, arr[1])
}
return nil
}
func (e *StreamAggregationEngine) cond(operand interface{}, data map[string]interface{}) interface{} {
if condMap, ok := operand.(map[string]interface{}); ok {
ifCond, hasIf := condMap["if"]
thenVal, hasThen := condMap["then"]
elseVal, hasElse := condMap["else"]
if hasIf && hasThen && hasElse {
ifVal := e.evaluateExpression(data, ifCond)
if isTrue(ifVal) {
return e.evaluateExpression(data, thenVal)
}
return e.evaluateExpression(data, elseVal)
}
}
return nil
}
// executeMatch executes the $match stage
func (e *StreamAggregationEngine) executeMatch(spec interface{}, docs []types.Document) ([]types.Document, error) {
// Implementation copied from aggregate.go
var filter map[string]interface{}
if f, ok := spec.(types.Filter); ok {
filter = f
} else if f, ok := spec.(map[string]interface{}); ok {
filter = f
} else {
return docs, nil
}
var results []types.Document
for _, doc := range docs {
if MatchFilter(doc.Data, filter) {
results = append(results, doc)
}
}
return results, nil
}
// executeProject executes the $project stage
func (e *StreamAggregationEngine) executeProject(spec interface{}, docs []types.Document) ([]types.Document, error) {
// Implementation copied from aggregate.go
projectSpec, ok := spec.(map[string]interface{})
if !ok {
return docs, nil
}
var results []types.Document
for _, doc := range docs {
projected := e.projectDocument(doc.Data, projectSpec)
results = append(results, types.Document{
ID: doc.ID,
Data: projected,
})
}
return results, nil
}
// projectDocument applies a projection to a document
func (e *StreamAggregationEngine) projectDocument(data map[string]interface{}, spec map[string]interface{}) map[string]interface{} {
result := make(map[string]interface{})
for field, include := range spec {
if field == "_id" {
// Special handling for _id
if isFalse(include) {
// Exclude _id
} else {
result["_id"] = data["_id"]
}
continue
}
if isTrue(include) {
// Include the field
result[field] = getNestedValue(data, field)
} else if isFalse(include) {
// Exclude the field (ignored in inclusion mode)
continue
} else {
// Treat as an expression
result[field] = e.evaluateExpression(data, include)
}
}
return result
}
// executeLimit executes the $limit stage
func (e *StreamAggregationEngine) executeLimit(spec interface{}, docs []types.Document) ([]types.Document, error) {
// Implementation copied from aggregate.go
limit := 0
switch l := spec.(type) {
case int:
limit = l
case int64:
limit = int(l)
case float64:
limit = int(l)
}
if limit <= 0 || limit >= len(docs) {
return docs, nil
}
return docs[:limit], nil
}
// executeSkip executes the $skip stage
func (e *StreamAggregationEngine) executeSkip(spec interface{}, docs []types.Document) ([]types.Document, error) {
// Implementation copied from aggregate.go
skip := 0
switch s := spec.(type) {
case int:
skip = s
case int64:
skip = int(s)
case float64:
skip = int(s)
}
if skip <= 0 {
return docs, nil
}
if skip >= len(docs) {
return []types.Document{}, nil
}
return docs[skip:], nil
}
// executeUnwind executes the $unwind stage
func (e *StreamAggregationEngine) executeUnwind(spec interface{}, docs []types.Document) ([]types.Document, error) {
// Implementation copied from aggregate.go
var path string
var preserveNull bool
switch s := spec.(type) {
case string:
path = s
case map[string]interface{}:
if p, ok := s["path"].(string); ok {
path = p
}
if pn, ok := s["preserveNullAndEmptyArrays"].(bool); ok {
preserveNull = pn
}
}
if path == "" || path[0] != '$' {
return docs, nil
}
fieldPath := path[1:]
var results []types.Document
for _, doc := range docs {
arr := getNestedValue(doc.Data, fieldPath)
if arr == nil {
if preserveNull {
results = append(results, doc)
}
continue
}
array, ok := arr.([]interface{})
if !ok || len(array) == 0 {
if preserveNull {
results = append(results, doc)
}
continue
}
for _, item := range array {
newData := deepCopyMap(doc.Data)
setNestedValue(newData, fieldPath, item)
results = append(results, types.Document{
ID: doc.ID,
Data: newData,
})
}
}
return results, nil
}
// executeSort executes the $sort stage
func (e *StreamAggregationEngine) executeSort(spec interface{}, docs []types.Document) ([]types.Document, error) {
// Implementation copied from aggregate.go
sortSpec, ok := spec.(map[string]interface{})
if !ok {
return docs, nil
}
// Convert to a map of sort fields
sortFields := make(map[string]int)
for field, direction := range sortSpec {
dir := 1
switch d := direction.(type) {
case int:
dir = d
case int64:
dir = int(d)
case float64:
dir = int(d)
}
sortFields[field] = dir
}
// Create a sortable copy
sorted := make([]types.Document, len(docs))
copy(sorted, docs)
sort.Slice(sorted, func(i, j int) bool {
return e.compareDocs(sorted[i], sorted[j], sortFields)
})
return sorted, nil
}
// compareDocs compares two documents
func (e *StreamAggregationEngine) compareDocs(a, b types.Document, sortFields map[string]int) bool {
for field, dir := range sortFields {
valA := getNestedValue(a.Data, field)
valB := getNestedValue(b.Data, field)
cmp := compareValues(valA, valB)
if cmp != 0 {
if dir < 0 {
return cmp > 0
}
return cmp < 0
}
}
return false
}