diff --git a/pkg/dal/def_aggregate.go b/pkg/dal/def_aggregate.go index 3a3532cac..6241d846d 100644 --- a/pkg/dal/def_aggregate.go +++ b/pkg/dal/def_aggregate.go @@ -52,6 +52,7 @@ type ( Label string Expression *ql.ASTNode Type Type + Store Codec } ) diff --git a/store/adapters/rdbms/dal/iterator.go b/store/adapters/rdbms/dal/iterator.go index 8090f6a63..eafb7d825 100644 --- a/store/adapters/rdbms/dal/iterator.go +++ b/store/adapters/rdbms/dal/iterator.go @@ -16,20 +16,28 @@ import ( type ( iterator struct { - ms *model + // source model; how data we are reading from is shaped + src *model + + // destination model; how data we are reading into is shaped + // this is used to create scan buffer + // when not doing plain selection from one table final results might + // require a different list of columns and scanning needs to be adjusted + dst *model + + // buffer for scanned rows + scanBuf []any + + // results from the last read rows *sql.Rows + // last error err error query *goqu.SelectDataset sorting filter.SortExprSet cursor *filter.PagingCursor limit uint - - // @todo should filter also be here? - - // buffer for scanned rows - scanBuf []any } ) @@ -85,7 +93,7 @@ func (i *iterator) fetch(ctx context.Context) (rows *sql.Rows, err error) { // we're going to init scan buffer only once // and rely on the query.Rows.Scan function to // fill it up with fresh values! - i.scanBuf = i.ms.table.MakeScanBuffer() + i.scanBuf = i.dst.table.MakeScanBuffer() } var ( @@ -109,14 +117,14 @@ func (i *iterator) fetch(ctx context.Context) (rows *sql.Rows, err error) { // @todo this needs to work with embedded attributes (non physical columns) as well! tmp, err = rdbms.CursorExpression( cur, - func(ident string) (exp.LiteralExpression, error) { return i.ms.table.AttributeExpression(ident) }, + func(ident string) (exp.LiteralExpression, error) { return i.src.table.AttributeExpression(ident) }, func(ident string, val any) (exp.LiteralExpression, error) { - attr := i.ms.model.Attributes.FindByIdent(ident) + attr := i.dst.model.Attributes.FindByIdent(ident) if attr == nil { panic("unknown attribute " + ident + " used in cursor expression cast callback") } - enc, err := i.ms.dialect.TypeWrap(attr.Type).Encode(val) + enc, err := i.dst.dialect.TypeWrap(attr.Type).Encode(val) if err != nil { return nil, err } @@ -141,7 +149,7 @@ func (i *iterator) fetch(ctx context.Context) (rows *sql.Rows, err error) { innerSort.Reverse() // Wrap the fil & ordered sub-query with cursor-conditions - sqlExpr = i.ms.dialect.GOQU().From(sqlExpr.Order(i.orderByExp(innerSort)...).As(i.ms.model.Ident)) + sqlExpr = i.src.dialect.GOQU().From(sqlExpr.Order(i.orderByExp(innerSort)...).As(i.src.model.Ident)) // make sure we don't reverse it again } else { @@ -167,7 +175,7 @@ func (i *iterator) fetch(ctx context.Context) (rows *sql.Rows, err error) { return nil, err } - rows, err = i.ms.conn.QueryContext(ctx, query, args...) + rows, err = i.src.conn.QueryContext(ctx, query, args...) if errors.Is(err, sql.ErrNoRows) { // no rows, no error return nil, nil @@ -180,7 +188,7 @@ func (i *iterator) fetch(ctx context.Context) (rows *sql.Rows, err error) { func (i *iterator) orderByExp(sort filter.SortExprSet) (oe []exp.OrderedExpression) { for _, s := range sort { // assuming all columns were pre-validated! - tmp, _ := i.ms.table.AttributeExpression(s.Column) + tmp, _ := i.src.table.AttributeExpression(s.Column) if s.Descending { oe = append(oe, exp.NewOrderedExpression(tmp, exp.DescSortDir, exp.NoNullsSortType)) @@ -201,7 +209,7 @@ func (i *iterator) Scan(r dal.ValueSetter) (err error) { return err } - if err = i.ms.table.Decode(i.scanBuf, r); err != nil { + if err = i.dst.table.Decode(i.scanBuf, r); err != nil { return } @@ -250,7 +258,7 @@ func (i *iterator) collectCursorValues(r dal.ValueGetter) (_ *filter.PagingCurso pKeys = make(map[string]bool) ) - for _, c := range i.ms.table.Columns() { + for _, c := range i.dst.table.Columns() { if c.IsPrimaryKey() { attrIdent := c.Attribute().Ident pKeys[attrIdent] = true diff --git a/store/adapters/rdbms/dal/main_test.go b/store/adapters/rdbms/dal/main_test.go index d30d56ed6..fbca60fff 100644 --- a/store/adapters/rdbms/dal/main_test.go +++ b/store/adapters/rdbms/dal/main_test.go @@ -82,8 +82,12 @@ func (r kv) String() string { // build string by iterating over sorted keys and appending values var out string - for _, k := range keys { - out += fmt.Sprintf("%s=%v ", k, r[k]) + for i, k := range keys { + if i > 0 { + out += " " + } + + out += fmt.Sprintf("%s=%v", k, r[k]) } return out diff --git a/store/adapters/rdbms/dal/model.go b/store/adapters/rdbms/dal/model.go index dd4edf390..35d1d9b6f 100644 --- a/store/adapters/rdbms/dal/model.go +++ b/store/adapters/rdbms/dal/model.go @@ -224,22 +224,22 @@ func (d *model) Search(f filter.Filter) (i *iterator, err error) { orderBy = append(orderBy, &filter.SortExpr{Column: attrIdent, Descending: orderBy.LastDescending()}) } - var ( - q *goqu.SelectDataset - ) + i = &iterator{ + // source and destination is the same + src: d, + dst: d, - q = d.searchSql(f) - if err = q.Error(); err != nil { - return - } - - return &iterator{ - ms: d, - query: q, sorting: orderBy, cursor: f.Cursor(), limit: f.Limit(), - }, nil + } + + i.query = d.searchSql(f) + if err = i.query.Error(); err != nil { + return + } + + return } // Aggregate constructs SELECT sql with group-by and an optional having CLAUSE @@ -262,25 +262,35 @@ func (d *model) Aggregate(f filter.Filter, groupBy []*dal.AggregateAttr, aggrExp } i = &iterator{ - ms: &model{}, sorting: f.OrderBy(), limit: f.Limit(), } var ( - dalModel = &dal.Model{} - attr *dal.Attribute + // source model; how data we are reading from is shaped + srcModel = &dal.Model{} + + // destination model; how data we are reading into is shaped + dstModel = &dal.Model{} + + srcAttr, dstAttr *dal.Attribute ) // prepare a bit modified module that // describes aggregated columns (prepending attributes used for group-by) for _, c := range append(groupBy, aggrExpr...) { - attr = &dal.Attribute{ + srcAttr = &dal.Attribute{ + Ident: c.Identifier, + Type: c.Type, + Store: c.Store, + } + srcModel.Attributes = append(srcModel.Attributes, srcAttr) + + dstAttr = &dal.Attribute{ Ident: c.Identifier, Type: c.Type, } - - dalModel.Attributes = append(dalModel.Attributes, attr) + dstModel.Attributes = append(dstModel.Attributes, dstAttr) } i.query = d.aggregateSql(f, groupBy, aggrExpr, having) @@ -288,7 +298,8 @@ func (d *model) Aggregate(f filter.Filter, groupBy []*dal.AggregateAttr, aggrExp return } - i.ms = Model(dalModel, d.conn, d.dialect) + i.src = Model(srcModel, d.conn, d.dialect) + i.dst = Model(dstModel, d.conn, d.dialect) return } @@ -494,47 +505,43 @@ func (d *model) aggregateSql(f filter.Filter, groupBy []*dal.AggregateAttr, out expr exp.Expression selected []any + + field = func(c *dal.AggregateAttr) (expr exp.Expression, err error) { + switch { + case len(c.RawExpr) > 0: + // @todo could probably be removed since RawExpr is only a temporary solution? + return d.parseQuery(c.RawExpr) + case c.Expression != nil: + return d.convertQuery(c.Expression) + } + + return d.table.AttributeExpression(c.Identifier) + } ) - for _, c := range groupBy { - if len(c.RawExpr) > 0 { - // @todo could probably be removed since RawExpr is only a temporary solution? - if expr, err = d.parseQuery(c.RawExpr); err != nil { - return q.SetError(err) - } - } else if c.Expression != nil { - if expr, err = d.convertQuery(c.Expression); err != nil { - return q.SetError(err) - } - } else { - expr = d.table.Ident().Col(c.Identifier) + for i, c := range groupBy { + if expr, err = field(c); err != nil { + return q.SetError(err) } // Add all group-by columns at the start - q = q.GroupByAppend(expr) + alias := fmt.Sprintf("group_by_%d_%s", i, c.Identifier) + expr = exp.NewAliasExpression(expr, alias) selected = append(selected, expr) + + // grouping by selected + q = q.GroupByAppend(alias) } q = q.Select(selected...) - for _, c := range out { - if len(c.RawExpr) > 0 { - // @todo could probably be removed since RawExpr is only a temporary solution? - if expr, err = d.parseQuery(c.RawExpr); err != nil { - return q.SetError(err) - } - } else { - if c.Expression == nil { - // expecting expression - return q.SetError(fmt.Errorf("expecting expression for aggregation")) - } - - if expr, err = d.convertQuery(c.Expression); err != nil { - return q.SetError(err) - } + for i, c := range out { + if expr, err = field(c); err != nil { + return q.SetError(err) } - // Add all group-by columns at the start + alias := fmt.Sprintf("aggr_%d_%s", i, c.Identifier) + expr = exp.NewAliasExpression(expr, alias) q = q.SelectAppend(expr) } diff --git a/store/adapters/rdbms/dal/model_test.go b/store/adapters/rdbms/dal/model_test.go index a1dafc49e..ad130095a 100644 --- a/store/adapters/rdbms/dal/model_test.go +++ b/store/adapters/rdbms/dal/model_test.go @@ -13,21 +13,25 @@ import ( "time" ) -func TestModel_Aggregate(t *testing.T) { +func TestModel_Search(t *testing.T) { _ = logger.Default() + const ( + items = 1000 + ) + var ( req = require.New(t) ctx = context.Background() baseModel = &dal.Model{ - Ident: "test_dal_Aggregation", + Ident: "test_dal_select", Attributes: []*dal.Attribute{ {Ident: "item", Type: &dal.TypeText{}}, - {Ident: "group", Type: &dal.TypeText{}}, + {Ident: "group", Type: &dal.TypeText{}, Store: &dal.CodecRecordValueSetJSON{Ident: "values"}, Sortable: true}, {Ident: "price", Type: &dal.TypeNumber{}, Filterable: true}, - {Ident: "published", Type: &dal.TypeBoolean{}}, + {Ident: "published", Type: &dal.TypeBoolean{}, Filterable: true}, }, } @@ -41,47 +45,33 @@ func TestModel_Aggregate(t *testing.T) { //ctx = logger.ContextWithValue(context.Background(), logger.MakeDebugLogger()) + t.Logf("Creating temporary table %q", table.Ident) table.Temporary = true req.NoError(s.DataDefiner.TableCreate(ctx, table)) - t.Log("Inserting test data") bm := time.Now() ctx = context.Background() // no need to log inserts count := 0 - for i := 1; i <= 1000; i++ { - for g := 1; g <= 5; g++ { - req.NoError(m.Create(ctx, &kv{ - "item": fmt.Sprintf("i%d", i), - "group": fmt.Sprintf("g%d", g), - "price": (100000 * g) + i, - "published": i%2 == 0, - })) - count++ - } + for i := 1; i <= items; i++ { + req.NoError(m.Create(ctx, &kv{ + "item": fmt.Sprintf("i%d", i), + "group": fmt.Sprintf("g%d", i%1000), + "price": i, + "published": i%2 == 0, + })) + count++ } - t.Logf("inserted %d entries in %v", count, time.Now().Sub(bm)) - t.Log("Aggregating all records, calculating min & max price per group") - i, err = m.Aggregate( - filter.Generic( - filter.WithOrderBy(filter.SortExprSet{ - &filter.SortExpr{Column: "group", Descending: true}, - }), - ), - // group-by - []*dal.AggregateAttr{ - {Identifier: "group", Type: &dal.TypeText{}}, - }, - // aggregation expressions - []*dal.AggregateAttr{ - {Identifier: "count", RawExpr: "COUNT(*)", Type: &dal.TypeNumber{}}, - {Identifier: "max", RawExpr: "MAX(price)", Type: &dal.TypeNumber{}}, - {Identifier: "min", RawExpr: "MIN(price)", Type: &dal.TypeNumber{}}, - {Identifier: "avg", RawExpr: "AVG(price)", Type: &dal.TypeNumber{}}, - {Identifier: "sum", RawExpr: "SUM(price)", Type: &dal.TypeNumber{}}, - }, - "", // <== here be having condition - ) + t.Logf("Inserted %d entries in %v", count, time.Now().Sub(bm)) + + t.Log("Search through records, filter out published") + i, err = m.Search(filter.Generic( + filter.WithExpression("published"), + filter.WithOrderBy(filter.SortExprSet{ + &filter.SortExpr{Column: "group"}, + }), + )) + req.NoError(err) req.NotNil(i) @@ -91,6 +81,92 @@ func TestModel_Aggregate(t *testing.T) { t.Log("Iterating over results") rows := make([]kv, 0, 5) + for i.Next(ctx) { + row = kv{} + req.NoError(i.Scan(row)) + rows = append(rows, row) + } + + req.NoError(i.Err()) + req.Len(rows, items/2) + req.Equal("group=g0 item=i1000 price=1000 published=1", rows[0].String()) +} + +func TestModel_Aggregate(t *testing.T) { + _ = logger.Default() + + var ( + req = require.New(t) + + ctx = context.Background() + + baseModel = &dal.Model{ + Ident: "test_dal_aggregation", + Attributes: []*dal.Attribute{ + {Ident: "item", Type: &dal.TypeText{}}, + {Ident: "date", Type: &dal.TypeDate{}, Store: &dal.CodecRecordValueSetJSON{Ident: "values"}, Sortable: true}, + {Ident: "group", Type: &dal.TypeText{}, Store: &dal.CodecRecordValueSetJSON{Ident: "values"}, Sortable: true}, + {Ident: "quantity", Type: &dal.TypeNumber{}, Store: &dal.CodecRecordValueSetJSON{Ident: "values"}, Filterable: true}, + {Ident: "price", Type: &dal.TypeNumber{}, Filterable: true}, + {Ident: "published", Type: &dal.TypeBoolean{}, Filterable: true}, + }, + } + + m = Model(baseModel, s.DB, s.Dialect) + + i dal.Iterator + + table, err = s.DataDefiner.ConvertModel(baseModel) + row kv + ) + + ctx = logger.ContextWithValue(context.Background(), logger.MakeDebugLogger()) + + t.Logf("Creating temporary table %q", table.Ident) + table.Temporary = true + req.NoError(s.DataDefiner.TableCreate(ctx, table)) + + req.NoError(m.Create(ctx, &kv{"item": "i1", "date": "2022-10-06", "group": "g1", "price": "1000", "quantity": "10", "published": true})) + req.NoError(m.Create(ctx, &kv{"item": "i2", "date": "2022-10-06", "group": "g1", "price": "3000", "quantity": "30", "published": true})) + req.NoError(m.Create(ctx, &kv{"item": "i3", "date": "2022-10-06", "group": "g2", "price": "4000", "quantity": "40", "published": false})) + req.NoError(m.Create(ctx, &kv{"item": "i4", "date": "2022-10-06", "group": "g2", "price": "1000", "quantity": "10", "published": true})) + req.NoError(m.Create(ctx, &kv{"item": "i5", "date": "2022-10-07", "group": "g2", "price": "1000", "quantity": "10", "published": false})) + req.NoError(m.Create(ctx, &kv{"item": "i6", "date": "2022-10-07", "group": "g2", "price": "5000", "quantity": "50", "published": true})) + + t.Log("Aggregating all records, calculating min & max price per group") + i, err = m.Aggregate( + filter.Generic( + filter.WithExpression("published"), + filter.WithOrderBy(filter.SortExprSet{ + &filter.SortExpr{Column: "group", Descending: true}, + &filter.SortExpr{Column: "date", Descending: false}, + }), + ), + // group-by + []*dal.AggregateAttr{ + {Identifier: "date", Type: &dal.TypeDate{}, Store: &dal.CodecRecordValueSetJSON{Ident: "values"}}, + {Identifier: "group", Type: &dal.TypeText{}, Store: &dal.CodecRecordValueSetJSON{Ident: "values"}}, + }, + // aggregation expressions + []*dal.AggregateAttr{ + {Identifier: "count", RawExpr: "COUNT(*)", Type: &dal.TypeNumber{}}, + {Identifier: "max", RawExpr: "MAX(price)", Type: &dal.TypeNumber{}}, + {Identifier: "min", RawExpr: "MIN(price)", Type: &dal.TypeNumber{}}, + {Identifier: "avg", RawExpr: "AVG(price)", Type: &dal.TypeNumber{}}, + {Identifier: "stock", RawExpr: "SUM(quantity)", Type: &dal.TypeNumber{}}, + }, + "", // <== here be having condition + ) + req.NoError(err) + req.NotNil(i) + + defer req.NoError(i.Close()) + + // uncomment to se generated query + ctx = logger.ContextWithValue(context.Background(), logger.MakeDebugLogger()) + + t.Log("Iterating over results") + rows := make([]kv, 0, 3) for i.Next(ctx) { row = kv{} req.NoError(i.Scan(row)) @@ -98,14 +174,14 @@ func TestModel_Aggregate(t *testing.T) { // due to difference of number of decimal digits in different DBs, we need to do this // to make sure we get the same result row["avg"] = fmt.Sprintf("%.2f", cast.ToFloat64(row["avg"])) + row["stock"] = fmt.Sprintf("%.2f", cast.ToFloat64(row["stock"])) rows = append(rows, row) } - req.Len(rows, 5) - req.Equal("avg=500500.50 count=1000 group=g5 max=501000 min=500001 sum=500500500 ", rows[0].String()) - req.Equal("avg=400500.50 count=1000 group=g4 max=401000 min=400001 sum=400500500 ", rows[1].String()) - req.Equal("avg=300500.50 count=1000 group=g3 max=301000 min=300001 sum=300500500 ", rows[2].String()) - req.Equal("avg=200500.50 count=1000 group=g2 max=201000 min=200001 sum=200500500 ", rows[3].String()) - req.Equal("avg=100500.50 count=1000 group=g1 max=101000 min=100001 sum=100500500 ", rows[4].String()) + req.NoError(i.Err()) + req.Len(rows, 3) + req.Equal("avg=1000.00 count=1 date=2022-10-06 group=g2 max=1000 min=1000 stock=10.00", rows[0].String()) + req.Equal("avg=5000.00 count=1 date=2022-10-07 group=g2 max=5000 min=5000 stock=50.00", rows[1].String()) + req.Equal("avg=2000.00 count=2 date=2022-10-06 group=g1 max=3000 min=1000 stock=40.00", rows[2].String()) }