Refactor & improve record filtering, proper handling of bool fields
This commit is contained in:
parent
9f3a00c0e2
commit
c2f4765fb0
@ -155,40 +155,23 @@ func (r record) Export(module *types.Module, filter types.RecordFilter) (set typ
|
||||
}
|
||||
|
||||
func (r record) buildQuery(module *types.Module, f types.RecordFilter) (query squirrel.SelectBuilder, err error) {
|
||||
// Create query for fetching and counting records.
|
||||
query = r.query().
|
||||
Where("r.module_id = ?", module.ID).
|
||||
Where("r.rel_namespace = ?", module.NamespaceID)
|
||||
|
||||
var joinedFields = []string{}
|
||||
var alreadyJoined = func(f string) bool {
|
||||
for _, a := range joinedFields {
|
||||
if a == f {
|
||||
return true
|
||||
var (
|
||||
joinedFields = []string{}
|
||||
alreadyJoined = func(f string) bool {
|
||||
for _, a := range joinedFields {
|
||||
if a == f {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
joinedFields = append(joinedFields, f)
|
||||
return false
|
||||
}
|
||||
|
||||
joinedFields = append(joinedFields, f)
|
||||
return false
|
||||
}
|
||||
|
||||
// Inc/exclude deleted records according to filter settings
|
||||
query = rh.FilterNullByState(query, "r.deleted_at", f.Deleted)
|
||||
|
||||
// Parse filters.
|
||||
if f.Query != "" {
|
||||
var (
|
||||
// Filter parser
|
||||
fp = ql.NewParser()
|
||||
|
||||
// Filter node
|
||||
fn ql.ASTNode
|
||||
)
|
||||
|
||||
// Make a nice wrapper that will translate module fields to subqueries
|
||||
fp.OnIdent = func(i ql.Ident) (ql.Ident, error) {
|
||||
identResolver = func(i ql.Ident) (ql.Ident, error) {
|
||||
var is bool
|
||||
if i.Value, is = isRealRecordCol(i.Value); is {
|
||||
i.Value += " "
|
||||
return i, nil
|
||||
}
|
||||
|
||||
@ -203,11 +186,46 @@ func (r record) buildQuery(module *types.Module, f types.RecordFilter) (query sq
|
||||
), i.Value)
|
||||
}
|
||||
|
||||
// @todo switch value for ref when doing Record/Owner lookup
|
||||
i.Value = fmt.Sprintf("rv_%s.value", i.Value)
|
||||
field := module.Fields.FindByName(i.Value)
|
||||
|
||||
switch true {
|
||||
case field.IsBoolean():
|
||||
i.Value = fmt.Sprintf("(rv_%s.value NOT IN ('', '0', 'false', 'f', 'FALSE', 'F', false))", i.Value)
|
||||
case field.IsNumeric():
|
||||
i.Value = fmt.Sprintf("CAST(rv_%s.value AS SIGNED)", i.Value)
|
||||
case field.IsDateTime():
|
||||
i.Value = fmt.Sprintf("CAST(rv_%s.value AS DATETIME)", i.Value)
|
||||
case field.IsRef():
|
||||
i.Value = fmt.Sprintf("rv_%s.ref ", i.Value)
|
||||
default:
|
||||
i.Value = fmt.Sprintf("rv_%s.value ", i.Value)
|
||||
}
|
||||
|
||||
return i, nil
|
||||
}
|
||||
)
|
||||
|
||||
// Create query for fetching and counting records.
|
||||
query = r.query().
|
||||
Where("r.module_id = ?", module.ID).
|
||||
Where("r.rel_namespace = ?", module.NamespaceID)
|
||||
|
||||
// Inc/exclude deleted records according to filter settings
|
||||
query = rh.FilterNullByState(query, "r.deleted_at", f.Deleted)
|
||||
|
||||
// Parse filters.
|
||||
if f.Query != "" {
|
||||
var (
|
||||
// Filter parser
|
||||
fp = ql.NewParser()
|
||||
|
||||
// Filter node
|
||||
fn ql.ASTNode
|
||||
)
|
||||
|
||||
// Resolve all identifiers found in the query
|
||||
// into their table/column counterparts
|
||||
fp.OnIdent = identResolver
|
||||
|
||||
if fn, err = fp.ParseExpression(f.Query); err != nil {
|
||||
return
|
||||
@ -227,43 +245,12 @@ func (r record) buildQuery(module *types.Module, f types.RecordFilter) (query sq
|
||||
sc ql.Columns
|
||||
)
|
||||
|
||||
sp.OnIdent = func(i ql.Ident) (ql.Ident, error) {
|
||||
var is bool
|
||||
if i.Value, is = isRealRecordCol(i.Value); is {
|
||||
i.Value += " "
|
||||
return i, nil
|
||||
}
|
||||
|
||||
if !module.Fields.HasName(i.Value) {
|
||||
return i, errors.Errorf("unknown field %q", i.Value)
|
||||
}
|
||||
|
||||
field := module.Fields.FindByName(i.Value)
|
||||
|
||||
if !alreadyJoined(i.Value) {
|
||||
query = query.LeftJoin(fmt.Sprintf(
|
||||
"compose_record_value AS rv_%s ON (rv_%s.record_id = r.id AND rv_%s.name = ? AND rv_%s.deleted_at IS NULL)",
|
||||
i.Value, i.Value, i.Value, i.Value,
|
||||
), i.Value)
|
||||
}
|
||||
|
||||
switch true {
|
||||
case field.IsRef():
|
||||
i.Value = fmt.Sprintf("rv_%s.ref ", i.Value)
|
||||
case field.IsNumeric():
|
||||
i.Value = fmt.Sprintf("CAST(rv_%s.value AS SIGNED)", i.Value)
|
||||
case field.IsDateTime():
|
||||
i.Value = fmt.Sprintf("CAST(rv_%s.value AS DATETIME)", i.Value)
|
||||
default:
|
||||
i.Value = fmt.Sprintf("rv_%s.value ", i.Value)
|
||||
}
|
||||
|
||||
return i, nil
|
||||
}
|
||||
// Resolve all identifiers found in sort
|
||||
// into their table/column counterparts
|
||||
sp.OnIdent = identResolver
|
||||
|
||||
if sc, err = sp.ParseColumns(f.Sort); err != nil {
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
query = query.OrderBy(sc.Strings()...)
|
||||
|
||||
@ -19,10 +19,12 @@ func TestRecordFinder(t *testing.T) {
|
||||
Fields: types.ModuleFieldSet{
|
||||
&types.ModuleField{Name: "foo"},
|
||||
&types.ModuleField{Name: "bar"},
|
||||
&types.ModuleField{Name: "booly", Kind: "Bool"},
|
||||
},
|
||||
}
|
||||
|
||||
ttc := []struct {
|
||||
name string
|
||||
f types.RecordFilter
|
||||
match []string
|
||||
noMatch []string
|
||||
@ -30,6 +32,7 @@ func TestRecordFinder(t *testing.T) {
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "default filter",
|
||||
match: []string{
|
||||
"SELECT r.id, r.module_id, r.rel_namespace, r.owned_by, r.created_at, " +
|
||||
"r.created_by, r.updated_at, r.updated_by, r.deleted_at, r.deleted_by " +
|
||||
@ -38,14 +41,16 @@ func TestRecordFinder(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
f: types.RecordFilter{Query: "id = 5 AND foo = 7"},
|
||||
name: "simple query",
|
||||
f: types.RecordFilter{Query: "id = 5 AND foo = 7"},
|
||||
match: []string{
|
||||
"r.id = 5",
|
||||
"rv_foo.value = 7"},
|
||||
"r.id = 5",
|
||||
"rv_foo.value = 7"},
|
||||
args: []interface{}{"foo"},
|
||||
},
|
||||
{
|
||||
f: types.RecordFilter{Sort: "id ASC, bar DESC"},
|
||||
name: "sorting",
|
||||
f: types.RecordFilter{Sort: "id ASC, bar DESC"},
|
||||
match: []string{
|
||||
" r.id ASC",
|
||||
" rv_bar.value DESC",
|
||||
@ -53,45 +58,56 @@ func TestRecordFinder(t *testing.T) {
|
||||
args: []interface{}{"bar"},
|
||||
},
|
||||
{
|
||||
name: "exclude deleted records (def. behaviour)",
|
||||
f: types.RecordFilter{Deleted: rh.FilterStateExcluded},
|
||||
match: []string{" r.deleted_at IS "},
|
||||
},
|
||||
{
|
||||
name: "include deleted records",
|
||||
f: types.RecordFilter{Deleted: rh.FilterStateInclusive},
|
||||
noMatch: []string{" r.deleted_at IS "},
|
||||
noMatch: []string{" r.deleted_at IS NULL "},
|
||||
},
|
||||
{
|
||||
name: "only deleted record",
|
||||
f: types.RecordFilter{Deleted: rh.FilterStateExclusive},
|
||||
match: []string{" r.deleted_at IS NOT NULL"},
|
||||
},
|
||||
{
|
||||
name: "boolean",
|
||||
f: types.RecordFilter{Query: "booly"},
|
||||
match: []string{"(rv_booly.value NOT IN ("},
|
||||
args: []interface{}{"booly"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range ttc {
|
||||
sb, err := r.buildQuery(m, tc.f)
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
sb, err := r.buildQuery(m, tc.f)
|
||||
|
||||
if tc.err != nil {
|
||||
require.True(t, tc.err.Error() == fmt.Sprintf("%v", err), "buildQuery(%+v) did not return an expected error %q but %q", tc.f, tc.err, err)
|
||||
} else {
|
||||
require.True(t, err == nil, "buildQuery(%+v) returned an unexpected error: %v", tc.f, err)
|
||||
}
|
||||
if tc.err != nil {
|
||||
require.True(t, tc.err.Error() == fmt.Sprintf("%v", err), "buildQuery(%+v) did not return an expected error %q but %q", tc.f, tc.err, err)
|
||||
} else {
|
||||
require.True(t, err == nil, "buildQuery(%+v) returned an unexpected error: %v", tc.f, err)
|
||||
}
|
||||
|
||||
sql, args, err := sb.ToSql()
|
||||
sql, args, err := sb.ToSql()
|
||||
|
||||
for _, m := range tc.match {
|
||||
require.True(t, strings.Contains(sql, m),
|
||||
"assertion failed; query %q \n "+
|
||||
" did not contain %q", sql, m)
|
||||
}
|
||||
for _, m := range tc.match {
|
||||
require.True(t, strings.Contains(sql, m),
|
||||
"assertion failed; query %q \n "+
|
||||
" did not contain %q", sql, m)
|
||||
}
|
||||
|
||||
for _, m := range tc.noMatch {
|
||||
require.False(t, strings.Contains(sql, m),
|
||||
"assertion failed; query %q \n "+
|
||||
" must not contain %q", sql, m)
|
||||
}
|
||||
for _, m := range tc.noMatch {
|
||||
require.False(t, strings.Contains(sql, m),
|
||||
"assertion failed; query %q \n "+
|
||||
" must not contain %q", sql, m)
|
||||
}
|
||||
|
||||
tc.args = append(tc.args, m.ID, m.NamespaceID)
|
||||
require.True(t, fmt.Sprintf("%+v", args) == fmt.Sprintf("%+v", tc.args),
|
||||
"assertion failed; args %+v \n "+
|
||||
" do not match expected %+v", args, tc.args)
|
||||
tc.args = append(tc.args, m.ID, m.NamespaceID)
|
||||
require.True(t, fmt.Sprintf("%+v", args) == fmt.Sprintf("%+v", tc.args),
|
||||
"assertion failed; args %+v \n "+
|
||||
" do not match expected %+v", args, tc.args)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -106,6 +106,10 @@ func (set ModuleFieldSet) Swap(i, j int) {
|
||||
set[i], set[j] = set[j], set[i]
|
||||
}
|
||||
|
||||
func (f ModuleField) IsBoolean() bool {
|
||||
return f.Kind == "Bool"
|
||||
}
|
||||
|
||||
func (f ModuleField) IsNumeric() bool {
|
||||
return f.Kind == "Number"
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user