diff --git a/pkg/dal/exec_join_left_test.go b/pkg/dal/exec_join_left_test.go index 235ef92f4..e944f441e 100644 --- a/pkg/dal/exec_join_left_test.go +++ b/pkg/dal/exec_join_left_test.go @@ -1392,3 +1392,123 @@ func TestStepJoin_paging(t *testing.T) { }) } } + +func TestStepJoin_multiValueFields(t *testing.T) { + basicLocalAttrs := []simpleAttribute{ + {ident: "l_pk", t: TypeID{}}, + {ident: "l_ref", t: TypeID{}}, + {ident: "l_val", t: TypeText{}}, + } + basicForeignAttrs := []simpleAttribute{ + {ident: "f_pk", t: TypeID{}}, + {ident: "f_fk", t: TypeRef{}}, + {ident: "f_val", t: TypeText{}}, + } + + basicAttrs := append(basicLocalAttrs, basicForeignAttrs...) + + tcc := []struct { + name string + + outAttributes []simpleAttribute + leftAttributes []simpleAttribute + rightAttributes []simpleAttribute + joinPred JoinPredicate + + lIn []*Row + fIn []*Row + out []*Row + }{ + { + name: "multiple left keys", + + outAttributes: basicAttrs, + leftAttributes: basicLocalAttrs, + rightAttributes: basicForeignAttrs, + joinPred: JoinPredicate{Left: "l_ref", Right: "f_fk"}, + + lIn: []*Row{ + (&Row{}). + WithValue("l_pk", 0, 1). + WithValue("l_ref", 0, 1). + WithValue("l_ref", 1, 2). + WithValue("l_val", 0, "l1 v1"), + }, + fIn: []*Row{ + (&Row{}). + WithValue("f_pk", 0, 1). + WithValue("f_fk", 0, 1). + WithValue("f_val", 0, "f1 v1"), + + (&Row{}). + WithValue("f_pk", 0, 2). + WithValue("f_fk", 0, 2). + WithValue("f_val", 0, "f1 v2"), + }, + + out: []*Row{ + (&Row{}). + WithValue("l_pk", 0, 1). + WithValue("l_ref", 0, 1). + WithValue("l_ref", 1, 2). + WithValue("l_val", 0, "l1 v1"). + // ... + WithValue("f_pk", 0, 1). + WithValue("f_fk", 0, 1). + WithValue("f_val", 0, "f1 v1"), + + (&Row{}). + WithValue("l_pk", 0, 1). + WithValue("l_ref", 0, 1). + WithValue("l_ref", 1, 2). + WithValue("l_val", 0, "l1 v1"). + // ... + WithValue("f_pk", 0, 2). + WithValue("f_fk", 0, 2). + WithValue("f_val", 0, "f1 v2"), + }, + }, + } + + ctx := context.Background() + for _, tc := range tcc { + t.Run(tc.name, func(t *testing.T) { + l := InMemoryBuffer() + for _, r := range tc.lIn { + require.NoError(t, l.Add(ctx, r)) + } + + f := InMemoryBuffer() + for _, r := range tc.fIn { + require.NoError(t, f.Add(ctx, r)) + } + + def := Join{ + Ident: "foo", + On: tc.joinPred, + OutAttributes: saToMapping(tc.outAttributes...), + LeftAttributes: saToMapping(tc.leftAttributes...), + RightAttributes: saToMapping(tc.rightAttributes...), + Filter: filter.Generic(filter.WithOrderBy(filter.SortExprSet{{Column: "l_pk"}, {Column: "f_pk"}})), + + plan: joinPlan{}, + } + + xs, err := def.iterator(ctx, l, f) + require.NoError(t, err) + + i := 0 + for xs.Next(ctx) { + require.NoError(t, xs.Err()) + out := &Row{} + require.NoError(t, xs.Err()) + require.NoError(t, xs.Scan(out)) + + require.Equal(t, tc.out[i], out) + + i++ + } + require.Equal(t, len(tc.out), i) + }) + } +}