diff --git a/compose/repository/record.go b/compose/repository/record.go index 70ae66996..67648eac1 100644 --- a/compose/repository/record.go +++ b/compose/repository/record.go @@ -155,40 +155,23 @@ func (r record) Export(module *types.Module, filter types.RecordFilter) (set typ } func (r record) buildQuery(module *types.Module, f types.RecordFilter) (query squirrel.SelectBuilder, err error) { - // Create query for fetching and counting records. - query = r.query(). - Where("r.module_id = ?", module.ID). - Where("r.rel_namespace = ?", module.NamespaceID) - - var joinedFields = []string{} - var alreadyJoined = func(f string) bool { - for _, a := range joinedFields { - if a == f { - return true + var ( + joinedFields = []string{} + alreadyJoined = func(f string) bool { + for _, a := range joinedFields { + if a == f { + return true + } } + + joinedFields = append(joinedFields, f) + return false } - joinedFields = append(joinedFields, f) - return false - } - - // Inc/exclude deleted records according to filter settings - query = rh.FilterNullByState(query, "r.deleted_at", f.Deleted) - - // Parse filters. - if f.Query != "" { - var ( - // Filter parser - fp = ql.NewParser() - - // Filter node - fn ql.ASTNode - ) - - // Make a nice wrapper that will translate module fields to subqueries - fp.OnIdent = func(i ql.Ident) (ql.Ident, error) { + identResolver = func(i ql.Ident) (ql.Ident, error) { var is bool if i.Value, is = isRealRecordCol(i.Value); is { + i.Value += " " return i, nil } @@ -203,11 +186,46 @@ func (r record) buildQuery(module *types.Module, f types.RecordFilter) (query sq ), i.Value) } - // @todo switch value for ref when doing Record/Owner lookup - i.Value = fmt.Sprintf("rv_%s.value", i.Value) + field := module.Fields.FindByName(i.Value) + + switch true { + case field.IsBoolean(): + i.Value = fmt.Sprintf("(rv_%s.value NOT IN ('', '0', 'false', 'f', 'FALSE', 'F', false))", i.Value) + case field.IsNumeric(): + i.Value = fmt.Sprintf("CAST(rv_%s.value AS SIGNED)", i.Value) + case field.IsDateTime(): + i.Value = fmt.Sprintf("CAST(rv_%s.value AS DATETIME)", i.Value) + case field.IsRef(): + i.Value = fmt.Sprintf("rv_%s.ref ", i.Value) + default: + i.Value = fmt.Sprintf("rv_%s.value ", i.Value) + } return i, nil } + ) + + // Create query for fetching and counting records. + query = r.query(). + Where("r.module_id = ?", module.ID). + Where("r.rel_namespace = ?", module.NamespaceID) + + // Inc/exclude deleted records according to filter settings + query = rh.FilterNullByState(query, "r.deleted_at", f.Deleted) + + // Parse filters. + if f.Query != "" { + var ( + // Filter parser + fp = ql.NewParser() + + // Filter node + fn ql.ASTNode + ) + + // Resolve all identifiers found in the query + // into their table/column counterparts + fp.OnIdent = identResolver if fn, err = fp.ParseExpression(f.Query); err != nil { return @@ -227,43 +245,12 @@ func (r record) buildQuery(module *types.Module, f types.RecordFilter) (query sq sc ql.Columns ) - sp.OnIdent = func(i ql.Ident) (ql.Ident, error) { - var is bool - if i.Value, is = isRealRecordCol(i.Value); is { - i.Value += " " - return i, nil - } - - if !module.Fields.HasName(i.Value) { - return i, errors.Errorf("unknown field %q", i.Value) - } - - field := module.Fields.FindByName(i.Value) - - if !alreadyJoined(i.Value) { - query = query.LeftJoin(fmt.Sprintf( - "compose_record_value AS rv_%s ON (rv_%s.record_id = r.id AND rv_%s.name = ? AND rv_%s.deleted_at IS NULL)", - i.Value, i.Value, i.Value, i.Value, - ), i.Value) - } - - switch true { - case field.IsRef(): - i.Value = fmt.Sprintf("rv_%s.ref ", i.Value) - case field.IsNumeric(): - i.Value = fmt.Sprintf("CAST(rv_%s.value AS SIGNED)", i.Value) - case field.IsDateTime(): - i.Value = fmt.Sprintf("CAST(rv_%s.value AS DATETIME)", i.Value) - default: - i.Value = fmt.Sprintf("rv_%s.value ", i.Value) - } - - return i, nil - } + // Resolve all identifiers found in sort + // into their table/column counterparts + sp.OnIdent = identResolver if sc, err = sp.ParseColumns(f.Sort); err != nil { return - } query = query.OrderBy(sc.Strings()...) diff --git a/compose/repository/record_test.go b/compose/repository/record_test.go index a5214084b..332bd23ff 100644 --- a/compose/repository/record_test.go +++ b/compose/repository/record_test.go @@ -19,10 +19,12 @@ func TestRecordFinder(t *testing.T) { Fields: types.ModuleFieldSet{ &types.ModuleField{Name: "foo"}, &types.ModuleField{Name: "bar"}, + &types.ModuleField{Name: "booly", Kind: "Bool"}, }, } ttc := []struct { + name string f types.RecordFilter match []string noMatch []string @@ -30,6 +32,7 @@ func TestRecordFinder(t *testing.T) { err error }{ { + name: "default filter", match: []string{ "SELECT r.id, r.module_id, r.rel_namespace, r.owned_by, r.created_at, " + "r.created_by, r.updated_at, r.updated_by, r.deleted_at, r.deleted_by " + @@ -38,14 +41,16 @@ func TestRecordFinder(t *testing.T) { }, }, { - f: types.RecordFilter{Query: "id = 5 AND foo = 7"}, + name: "simple query", + f: types.RecordFilter{Query: "id = 5 AND foo = 7"}, match: []string{ - "r.id = 5", - "rv_foo.value = 7"}, + "r.id = 5", + "rv_foo.value = 7"}, args: []interface{}{"foo"}, }, { - f: types.RecordFilter{Sort: "id ASC, bar DESC"}, + name: "sorting", + f: types.RecordFilter{Sort: "id ASC, bar DESC"}, match: []string{ " r.id ASC", " rv_bar.value DESC", @@ -53,45 +58,56 @@ func TestRecordFinder(t *testing.T) { args: []interface{}{"bar"}, }, { + name: "exclude deleted records (def. behaviour)", f: types.RecordFilter{Deleted: rh.FilterStateExcluded}, match: []string{" r.deleted_at IS "}, }, { + name: "include deleted records", f: types.RecordFilter{Deleted: rh.FilterStateInclusive}, - noMatch: []string{" r.deleted_at IS "}, + noMatch: []string{" r.deleted_at IS NULL "}, }, { + name: "only deleted record", f: types.RecordFilter{Deleted: rh.FilterStateExclusive}, match: []string{" r.deleted_at IS NOT NULL"}, }, + { + name: "boolean", + f: types.RecordFilter{Query: "booly"}, + match: []string{"(rv_booly.value NOT IN ("}, + args: []interface{}{"booly"}, + }, } for _, tc := range ttc { - sb, err := r.buildQuery(m, tc.f) + t.Run(tc.name, func(t *testing.T) { + sb, err := r.buildQuery(m, tc.f) - if tc.err != nil { - require.True(t, tc.err.Error() == fmt.Sprintf("%v", err), "buildQuery(%+v) did not return an expected error %q but %q", tc.f, tc.err, err) - } else { - require.True(t, err == nil, "buildQuery(%+v) returned an unexpected error: %v", tc.f, err) - } + if tc.err != nil { + require.True(t, tc.err.Error() == fmt.Sprintf("%v", err), "buildQuery(%+v) did not return an expected error %q but %q", tc.f, tc.err, err) + } else { + require.True(t, err == nil, "buildQuery(%+v) returned an unexpected error: %v", tc.f, err) + } - sql, args, err := sb.ToSql() + sql, args, err := sb.ToSql() - for _, m := range tc.match { - require.True(t, strings.Contains(sql, m), - "assertion failed; query %q \n "+ - " did not contain %q", sql, m) - } + for _, m := range tc.match { + require.True(t, strings.Contains(sql, m), + "assertion failed; query %q \n "+ + " did not contain %q", sql, m) + } - for _, m := range tc.noMatch { - require.False(t, strings.Contains(sql, m), - "assertion failed; query %q \n "+ - " must not contain %q", sql, m) - } + for _, m := range tc.noMatch { + require.False(t, strings.Contains(sql, m), + "assertion failed; query %q \n "+ + " must not contain %q", sql, m) + } - tc.args = append(tc.args, m.ID, m.NamespaceID) - require.True(t, fmt.Sprintf("%+v", args) == fmt.Sprintf("%+v", tc.args), - "assertion failed; args %+v \n "+ - " do not match expected %+v", args, tc.args) + tc.args = append(tc.args, m.ID, m.NamespaceID) + require.True(t, fmt.Sprintf("%+v", args) == fmt.Sprintf("%+v", tc.args), + "assertion failed; args %+v \n "+ + " do not match expected %+v", args, tc.args) + }) } } diff --git a/compose/types/module_field.go b/compose/types/module_field.go index 3014efcc0..3c6aca91b 100644 --- a/compose/types/module_field.go +++ b/compose/types/module_field.go @@ -106,6 +106,10 @@ func (set ModuleFieldSet) Swap(i, j int) { set[i], set[j] = set[j], set[i] } +func (f ModuleField) IsBoolean() bool { + return f.Kind == "Bool" +} + func (f ModuleField) IsNumeric() bool { return f.Kind == "Number" }