diff --git a/store/rdbms/messaging_channel_members.go b/store/rdbms/messaging_channel_members.go index 1d0dc8a9b..edc4fc827 100644 --- a/store/rdbms/messaging_channel_members.go +++ b/store/rdbms/messaging_channel_members.go @@ -5,6 +5,7 @@ import ( "github.com/cortezaproject/corteza-server/messaging/types" "sort" "strconv" + "strings" ) func (s Store) convertMessagingChannelMemberFilter(f types.ChannelMemberFilter) (query squirrel.SelectBuilder, err error) { @@ -21,30 +22,52 @@ func (s Store) convertMessagingChannelMemberFilter(f types.ChannelMemberFilter) return query, nil } -func (s Store) getMessagingChannelMembersQuery(memberIDs ...uint64) squirrel.SelectBuilder { +// perfectly, we would be able to solve this per-rdbms implementation with query directly +func (s Store) getMessagingChannelMembersQuery(cnd squirrel.Sqlizer, memberIDs ...uint64) squirrel.Sqlizer { if len(memberIDs) == 0 { return squirrel. Select("null") } - // Make sure members are sorted - sort.Slice(memberIDs, func(i, j int) bool { - return memberIDs[i] < memberIDs[j] - }) + if strings.HasPrefix(s.config.DriverName, "mysql") { + // Make sure members are sorted + sort.Slice(memberIDs, func(i, j int) bool { + return memberIDs[i] < memberIDs[j] + }) - // Concatenating members fore - list := "" - for i := range memberIDs { - if i > 0 { - list += "," + // Concatenating members fore + list := "" + for i := range memberIDs { + if i > 0 { + list += "," + } + list += strconv.FormatUint(memberIDs[i], 10) } - list += strconv.FormatUint(memberIDs[i], 10) + + return s.SelectBuilder(s.messagingChannelMemberTable("mcm"), "mcm.rel_channel"). + GroupBy("mcm.rel_channel"). + Having(squirrel.Eq{ + "COUNT(*)": len(memberIDs), + "GROUP_CONCAT(mcm.rel_user ORDER BY 1 ASC SEPARATOR ',')": list, + }) } - return s.SelectBuilder(s.messagingChannelMemberTable("mcm"), "mcm.rel_channel"). - GroupBy("mcm.rel_channel"). - Having(squirrel.Eq{ - "COUNT(*)": len(memberIDs), - "GROUP_CONCAT(mcm.rel_user ORDER BY 1 ASC SEPARATOR ',')": list, - }) + var ( + base = s.SelectBuilder(s.messagingChannelMemberTable("mcm"), "mcm.rel_channel"). + PlaceholderFormat(squirrel.Question) + + // construct SQLs with fitting number of members + counter = base. + Where(cnd). + GroupBy("mcm.rel_channel"). + Having(squirrel.Eq{"COUNT(*)": len(memberIDs)}) + ) + + for _, memberID := range memberIDs { + sql, args, _ := base.Where(squirrel.Eq{"mcm.rel_user": memberID}).ToSql() + counter = counter.Suffix(" INTERSECT "+sql, args...) + } + + return counter + } diff --git a/store/rdbms/messaging_channels.go b/store/rdbms/messaging_channels.go index 609a77beb..3ecb42d22 100644 --- a/store/rdbms/messaging_channels.go +++ b/store/rdbms/messaging_channels.go @@ -11,8 +11,7 @@ import ( // CountReplies counts unread thread info func (s Store) LookupMessagingChannelByMemberSet(ctx context.Context, memberIDs ...uint64) (ch *types.Channel, err error) { // prepare subquery that merges - mcmq := s.getMessagingChannelMembersQuery(memberIDs...). - Where("mch.id = mcm.rel_channel") + mcmq := s.getMessagingChannelMembersQuery(squirrel.Expr("mch.id = mcm.rel_channel"), memberIDs...) if sql, args, err := mcmq.ToSql(); err != nil { return nil, err diff --git a/store/tests/messaging_channel_members_test.go b/store/tests/messaging_channel_members_test.go index 8325a1c95..07bf7c41a 100644 --- a/store/tests/messaging_channel_members_test.go +++ b/store/tests/messaging_channel_members_test.go @@ -50,14 +50,14 @@ func testMessagingChannelMembers(t *testing.T, s store.MessagingChannelMembers) t.Run("update", func(t *testing.T) { req, messagingChannelMember := truncAndCreate(t) messagingChannelMember.Type = types.ChannelMembershipType("member") - + req.NoError(s.UpdateMessagingChannelMember(ctx, messagingChannelMember)) - + set, _, err := s.SearchMessagingChannelMembers(ctx, types.ChannelMemberFilter{ChannelID: []uint64{messagingChannelMember.ChannelID}, MemberID: []uint64{messagingChannelMember.UserID}}) req.NoError(err) req.Equal(types.ChannelMembershipType("member"), set[0].Type) }) - + t.Run("upsert", func(t *testing.T) { t.Run("existing", func(t *testing.T) { req, messagingChannelMember := truncAndCreate(t) diff --git a/store/tests/messaging_channels_test.go b/store/tests/messaging_channels_test.go index 6463ebfd3..e5556bbe4 100644 --- a/store/tests/messaging_channels_test.go +++ b/store/tests/messaging_channels_test.go @@ -11,7 +11,7 @@ import ( "time" ) -func testMessagingChannels(t *testing.T, s store.MessagingChannels) { +func testMessagingChannels(t *testing.T, s store.Storer) { var ( ctx = context.Background() req = require.New(t) @@ -124,4 +124,44 @@ func testMessagingChannels(t *testing.T, s store.MessagingChannels) { _ = f // dummy }) + + t.Run("lookup by member set", func(t *testing.T) { + var ( + req = require.New(t) + + ch1, ch2, ch3, ch4 = makeNew("one"), makeNew("two"), makeNew("three"), makeNew("four") + ) + + ch1.Type = types.ChannelTypeGroup + ch2.Type = types.ChannelTypeGroup + ch3.Type = types.ChannelTypeGroup + ch4.Type = types.ChannelTypeGroup + + req.NoError(store.TruncateMessagingChannels(ctx, s)) + req.NoError(store.TruncateMessagingChannelMembers(ctx, s)) + req.NoError(store.CreateMessagingChannel(ctx, s, ch1, ch2, ch3)) + req.NoError(store.CreateMessagingChannelMember(ctx, s, + // fits + &types.ChannelMember{CreatedAt: time.Time{}, ChannelID: ch1.ID, UserID: 1000}, + &types.ChannelMember{CreatedAt: time.Time{}, ChannelID: ch1.ID, UserID: 2000}, + + // one to many + &types.ChannelMember{CreatedAt: time.Time{}, ChannelID: ch2.ID, UserID: 1000}, + &types.ChannelMember{CreatedAt: time.Time{}, ChannelID: ch2.ID, UserID: 2000}, + &types.ChannelMember{CreatedAt: time.Time{}, ChannelID: ch2.ID, UserID: 3000}, + + // no diff + &types.ChannelMember{CreatedAt: time.Time{}, ChannelID: ch3.ID, UserID: 1000}, + &types.ChannelMember{CreatedAt: time.Time{}, ChannelID: ch3.ID, UserID: 5000}, + + // one only + &types.ChannelMember{CreatedAt: time.Time{}, ChannelID: ch4.ID, UserID: 1000}, + &types.ChannelMember{CreatedAt: time.Time{}, ChannelID: ch4.ID, UserID: 4000}, + &types.ChannelMember{CreatedAt: time.Time{}, ChannelID: ch4.ID, UserID: 5000}, + )) + + ch, err := s.LookupMessagingChannelByMemberSet(ctx, 1000, 2000) + req.NoError(err) + req.Equal(ch.ID, ch1.ID) + }) }