3
0

Channel membership improved

This commit is contained in:
Denis Arh
2018-10-07 18:42:54 +02:00
parent c3457807e7
commit 910b5589dd
14 changed files with 322 additions and 113 deletions

View File

@@ -28,7 +28,7 @@ type (
With(ctx context.Context) UserService
FindByID(id uint64) (*types.User, error)
Find(filter *types.UserFilter) ([]*types.User, error)
Find(filter *types.UserFilter) (types.UserSet, error)
Create(input *types.User) (*types.User, error)
Update(mod *types.User) (*types.User, error)
@@ -73,7 +73,7 @@ func (svc *user) FindByID(id uint64) (*types.User, error) {
return svc.user.FindUserByID(id)
}
func (svc *user) Find(filter *types.UserFilter) ([]*types.User, error) {
func (svc *user) Find(filter *types.UserFilter) (types.UserSet, error) {
return svc.user.FindUsers(filter)
}

View File

@@ -62,3 +62,13 @@ func (uu UserSet) Walk(w func(*User) error) (err error) {
return
}
func (uu UserSet) FindById(ID uint64) *User {
for i := range uu {
if uu[i].ID == ID {
return uu[i]
}
}
return nil
}

View File

@@ -65,6 +65,24 @@ func Channels(channels sam.ChannelSet) *outgoing.ChannelSet {
return &retval
}
func ChannelMember(m *sam.ChannelMember) *outgoing.ChannelMember {
return &outgoing.ChannelMember{
User: User(m.User),
Type: string(m.Type),
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func ChannelMembers(members sam.ChannelMemberSet) *outgoing.ChannelMemberSet {
mm := make([]*outgoing.ChannelMember, len(members))
for k, c := range members {
mm[k] = ChannelMember(c)
}
retval := outgoing.ChannelMemberSet(mm)
return &retval
}
func User(user *auth.User) *outgoing.User {
if user == nil {
return nil

View File

@@ -0,0 +1,26 @@
package outgoing
import (
"encoding/json"
"time"
)
type (
ChannelMember struct {
// Channel to part (nil) for ALL channels
User *User `json:"user"`
Type string `json:"type""`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt *time.Time `json:"updatedAt,omitempty"`
}
ChannelMemberSet []*ChannelMember
)
func (p *ChannelMember) EncodeMessage() ([]byte, error) {
return json.Marshal(Payload{ChannelMember: p})
}
func (p *ChannelMemberSet) EncodeMessage() ([]byte, error) {
return json.Marshal(Payload{ChannelMemberSet: p})
}

View File

@@ -16,6 +16,9 @@ type (
*Channel `json:"channel,omitempty"`
*ChannelSet `json:"channels,omitempty"`
*ChannelMember `json:"channelMember,omitempty"`
*ChannelMemberSet `json:"channelMembers,omitempty"`
*User `json:"user,omitempty"`
*UserSet `json:"users,omitempty"`

View File

@@ -16,13 +16,9 @@ type (
FindChannelByID(id uint64) (*types.Channel, error)
FindDirectChannelByUserID(fromUserID, toUserID uint64) (*types.Channel, error)
FindChannels(filter *types.ChannelFilter) ([]*types.Channel, error)
FindMembers(userID uint64) (types.ChannelMemberSet, error)
CreateChannel(mod *types.Channel) (*types.Channel, error)
UpdateChannel(mod *types.Channel) (*types.Channel, error)
FindChannelsMembershipsByMemberId(memberId uint64) ([]*types.ChannelMember, error)
AddChannelMember(mod *types.ChannelMember) (*types.ChannelMember, error)
RemoveChannelMember(channelID, userID uint64) error
ArchiveChannelByID(id uint64) error
UnarchiveChannelByID(id uint64) error
DeleteChannelByID(id uint64) error
@@ -42,11 +38,7 @@ const (
FROM channels AS c
WHERE ` + sqlChannelValidOnly
sqlChannelMemberSelect = `SELECT m.*
FROM channel_members AS m
INNER JOIN channels AS c ON (m.rel_channel = c.id)
WHERE ` + sqlChannelValidOnly
// Returns channel (group) with exactly 2 members
sqlChannelDirect = `SELECT *
FROM channels AS c
WHERE c.type = ?
@@ -57,10 +49,6 @@ const (
AND MIN(rel_user) = ?
AND MAX(rel_user) = ?)`
sqlChannelMemberships = `SELECT *
FROM channel_members AS cm
WHERE true`
// subquery that filters out all channels that current user has access to as a member
// or via channel type (public chans)
sqlChannelAccess = ` (
@@ -109,7 +97,7 @@ func (r *channel) FindDirectChannelByUserID(fromUserID, toUserID uint64) (*types
toUserID, fromUserID = fromUserID, toUserID
}
return mod, isFound(r.db().Get(mod, sqlChannelDirect, types.ChannelTypeDirect, fromUserID, toUserID), mod.ID > 0, ErrChannelNotFound)
return mod, isFound(r.db().Get(mod, sqlChannelDirect, types.ChannelTypeGroup, fromUserID, toUserID), mod.ID > 0, ErrChannelNotFound)
}
func (r *channel) FindChannels(filter *types.ChannelFilter) ([]*types.Channel, error) {
@@ -138,22 +126,6 @@ func (r *channel) FindChannels(filter *types.ChannelFilter) ([]*types.Channel, e
return rval, r.db().Select(&rval, sql, params...)
}
// Returns member ids of all channels that user has access to
func (r *channel) FindMembers(userID uint64) (types.ChannelMemberSet, error) {
params := make([]interface{}, 0)
rval := types.ChannelMemberSet{}
sql := sqlChannelMemberSelect
if userID > 0 {
// scope: only channels we have access to
sql += " AND m.rel_channel IN " + sqlChannelAccess
params = append(params, userID, types.ChannelTypePublic)
}
return rval, r.db().Select(&rval, sql, params...)
}
func (r *channel) CreateChannel(mod *types.Channel) (*types.Channel, error) {
mod.ID = factory.Sonyflake.NextID()
mod.CreatedAt = time.Now()
@@ -180,24 +152,6 @@ func (r *channel) UpdateChannel(mod *types.Channel) (*types.Channel, error) {
UpdatePartial("channels", mod, whitelist, "id")
}
func (r *channel) FindChannelsMembershipsByMemberId(memberId uint64) ([]*types.ChannelMember, error) {
var rval = make([]*types.ChannelMember, 0)
return rval, r.db().Select(&rval, sqlChannelMemberships+" AND cm.rel_user = ? ", memberId)
}
func (r *channel) AddChannelMember(mod *types.ChannelMember) (*types.ChannelMember, error) {
sql := `INSERT INTO channel_members (rel_channel, rel_user) VALUES (?, ?)`
mod.CreatedAt = time.Now()
return mod, exec(r.db().Exec(sql, mod.ChannelID, mod.UserID))
}
func (r *channel) RemoveChannelMember(channelID, userID uint64) error {
sql := `DELETE FROM channel_members WHERE rel_channel = ? AND rel_user = ?`
return exec(r.db().Exec(sql, channelID, userID))
}
func (r *channel) ArchiveChannelByID(id uint64) error {
return r.updateColumnByID("channels", "archived_at", time.Now(), id)
}

View File

@@ -0,0 +1,114 @@
package repository
import (
"context"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/titpetric/factory"
"github.com/crusttech/crust/sam/types"
)
type (
// ChannelMemberRepository interface to channel member repository
ChannelMemberRepository interface {
With(ctx context.Context, db *factory.DB) ChannelMemberRepository
Find(filter *types.ChannelMemberFilter) (types.ChannelMemberSet, error)
Create(mod *types.ChannelMember) (*types.ChannelMember, error)
Delete(channelMemberID, userID uint64) error
}
channelMember struct {
*repository
}
)
const (
// Copy definitions to make it more obvious that we're reusing channel-scope sql
sqlChannelMemberChannelValidOnly = sqlChannelValidOnly
// Copy definitions to make it more obvious that we're reusing channel-scope sql
sqlChannelMemberChannelAccess = sqlChannelAccess
// Fetching channel members of all channels a specific user has access to
sqlChannelMemberSelect = `SELECT m.*
FROM channel_members AS m
INNER JOIN channels AS c ON (m.rel_channel = c.id)
WHERE ` + sqlChannelMemberChannelValidOnly
// Selects all user's memberships
sqlChannelMemberships = `SELECT *
FROM channel_members AS cm
WHERE true`
)
// ChannelMember creates new instance of channel member repository
func ChannelMember(ctx context.Context, db *factory.DB) ChannelMemberRepository {
return (&channelMember{}).With(ctx, db)
}
// With context...
func (r *channelMember) With(ctx context.Context, db *factory.DB) ChannelMemberRepository {
return &channelMember{
repository: r.repository.With(ctx, db),
}
}
// FindMembers fetches membership info
//
// If channelID > 0 it returns members of a specific channel
// If userID > 0 it returns members of all channels this user is member of
func (r *channelMember) Find(filter *types.ChannelMemberFilter) (types.ChannelMemberSet, error) {
params := make([]interface{}, 0)
mm := types.ChannelMemberSet{}
sql := sqlChannelMemberSelect
if filter != nil {
if filter.ComembersOf > 0 {
// scope: only channel we have access to
sql += " AND m.rel_channel IN " + sqlChannelMemberChannelAccess
params = append(params, filter.ComembersOf, types.ChannelTypePublic)
}
if filter.MemberID > 0 {
sql += " AND m.rel_user = ?"
params = append(params, filter.MemberID)
}
if filter.ChannelID > 0 {
sql += " AND m.rel_channel = ?"
params = append(params, filter.ChannelID)
}
}
spew.Dump(filter, sql, params)
return mm, r.db().Select(&mm, sql, params...)
}
// Create adds channel membership record
func (r *channelMember) Create(mod *types.ChannelMember) (*types.ChannelMember, error) {
mod.CreatedAt = time.Now()
mod.UpdatedAt = nil
return mod, r.db().Insert("channel_members", mod)
}
// Update modifies existing channel membership record
func (r *channelMember) Update(mod *types.ChannelMember) (*types.ChannelMember, error) {
mod.UpdatedAt = timeNowPtr()
whitelist := []string{"type", "updated_at"}
return mod, r.db().UpdatePartial("channel_members", mod, whitelist, "rel_channel", "rel_user")
}
// Delete removes existing channel membership record
func (r *channelMember) Delete(channelMemberID, userID uint64) error {
sql := `DELETE FROM channel_members WHERE rel_channelMember = ? AND rel_user = ?`
return exec(r.db().Exec(sql, channelMemberID, userID))
}

View File

@@ -63,7 +63,7 @@ func (ctrl *Channel) List(ctx context.Context, r *request.ChannelList) (interfac
}
func (ctrl *Channel) Members(ctx context.Context, r *request.ChannelMembers) (interface{}, error) {
return nil, nil
return ctrl.wrapMemberSet(ctrl.svc.ch.With(ctx).FindMembers(r.ChannelID))
}
func (ctrl *Channel) Join(ctx context.Context, r *request.ChannelJoin) (interface{}, error) {
@@ -116,3 +116,11 @@ func (ctrl *Channel) wrapSet(cc types.ChannelSet, err error) (*outgoing.ChannelS
return payload.Channels(cc), nil
}
}
func (ctrl *Channel) wrapMemberSet(mm types.ChannelMemberSet, err error) (*outgoing.ChannelMemberSet, error) {
if err != nil {
return nil, err
} else {
return payload.ChannelMembers(mm), nil
}
}

View File

@@ -73,18 +73,6 @@ func (mr *MockAttachmentServiceMockRecorder) Create(channelId, name, size, fh in
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockAttachmentService)(nil).Create), channelId, name, size, fh)
}
// LoadFromMessages mocks base method
func (m *MockAttachmentService) LoadFromMessages(mm types.MessageSet) error {
ret := m.ctrl.Call(m, "LoadFromMessages", mm)
ret0, _ := ret[0].(error)
return ret0
}
// LoadFromMessages indicates an expected call of LoadFromMessages
func (mr *MockAttachmentServiceMockRecorder) LoadFromMessages(mm interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadFromMessages", reflect.TypeOf((*MockAttachmentService)(nil).LoadFromMessages), mm)
}
// OpenOriginal mocks base method
func (m *MockAttachmentService) OpenOriginal(att *types.Attachment) (io.ReadSeeker, error) {
ret := m.ctrl.Call(m, "OpenOriginal", att)

View File

@@ -4,10 +4,11 @@ import (
"context"
"fmt"
"github.com/crusttech/crust/internal/auth"
"github.com/pkg/errors"
"github.com/titpetric/factory"
"github.com/crusttech/crust/internal/auth"
authService "github.com/crusttech/crust/auth/service"
"github.com/crusttech/crust/sam/repository"
"github.com/crusttech/crust/sam/types"
@@ -22,6 +23,7 @@ type (
evl EventService
channel repository.ChannelRepository
cmember repository.ChannelMemberRepository
message repository.MessageRepository
}
@@ -31,6 +33,7 @@ type (
FindByID(channelID uint64) (*types.Channel, error)
Find(filter *types.ChannelFilter) (types.ChannelSet, error)
FindByMembership() (rval []*types.Channel, err error)
FindMembers(channelID uint64) (types.ChannelMemberSet, error)
Create(channel *types.Channel) (*types.Channel, error)
Update(channel *types.Channel) (*types.Channel, error)
@@ -62,6 +65,7 @@ func (svc *channel) With(ctx context.Context) ChannelService {
evl: svc.evl.With(ctx),
channel: repository.Channel(ctx, db),
cmember: repository.ChannelMember(ctx, db),
message: repository.Message(ctx, db),
}
}
@@ -92,7 +96,7 @@ func (svc *channel) Find(filter *types.ChannelFilter) (types.ChannelSet, error)
func (svc *channel) preloadMembers(cc types.ChannelSet) error {
var userID = auth.GetIdentityFromContext(svc.ctx).Identity()
if mm, err := svc.channel.FindMembers(userID); err != nil {
if mm, err := svc.cmember.Find(&types.ChannelMemberFilter{ComembersOf: userID}); err != nil {
return err
} else {
cc.Walk(func(ch *types.Channel) error {
@@ -104,13 +108,39 @@ func (svc *channel) preloadMembers(cc types.ChannelSet) error {
return nil
}
// FindMembers loads all members (and full users) for a specific channel
func (svc *channel) FindMembers(channelID uint64) (out types.ChannelMemberSet, err error) {
var userID = auth.GetIdentityFromContext(svc.ctx).Identity()
// @todo [SECURITY] check if we can return members on this channel
_ = channelID
_ = userID
return out, svc.db.Transaction(func() (err error) {
out, err = svc.cmember.Find(&types.ChannelMemberFilter{ChannelID: channelID})
if err != nil {
return err
}
if uu, err := svc.usr.Find(nil); err != nil {
return err
} else {
return out.Walk(func(member *types.ChannelMember) error {
member.User = uu.FindById(member.UserID)
return nil
})
}
})
}
// Returns all channels with membership info
func (svc *channel) FindByMembership() (rval []*types.Channel, err error) {
return rval, svc.db.Transaction(func() error {
var chMemberId = repository.Identity(svc.ctx)
var mm []*types.ChannelMember
if mm, err = svc.channel.FindChannelsMembershipsByMemberId(chMemberId); err != nil {
if mm, err = svc.cmember.Find(&types.ChannelMemberFilter{MemberID: chMemberId}); err != nil {
return err
}
@@ -176,7 +206,7 @@ func (svc *channel) Create(in *types.Channel) (out *types.Channel, err error) {
}
// Join current user as an member & owner
_, err = svc.channel.AddChannelMember(&types.ChannelMember{
_, err = svc.cmember.Create(&types.ChannelMember{
ChannelID: out.ID,
UserID: chCreatorID,
Type: types.ChannelMembershipTypeOwner,
@@ -412,7 +442,29 @@ func (svc *channel) Unarchive(id uint64) error {
return svc.evl.Channel(ch)
})
}
func (svc *channel) AddMember(m *types.ChannelMember) (out *types.ChannelMember, err error) {
return out, svc.db.Transaction(func() (err error) {
var userID = repository.Identity(svc.ctx)
var ch *types.Channel
// @todo [SECURITY] can user access this channel?
if ch, err = svc.channel.FindChannelByID(m.ChannelID); err != nil {
return
}
// @todo [SECURITY] can user add members to this channel?
msg, err := svc.message.CreateMessage(svc.makeSystemMessage(ch, "@%d added a new member to this channel: @%d", userID, m.UserID))
if err != nil {
return
}
svc.sendMessageEvent(msg)
return err
})
}
func (svc *channel) makeSystemMessage(ch *types.Channel, format string, a ...interface{}) *types.Message {

View File

@@ -60,9 +60,9 @@ func (mr *MockChannelServiceMockRecorder) FindByID(channelID interface{}) *gomoc
}
// Find mocks base method
func (m *MockChannelService) Find(filter *types.ChannelFilter) ([]*types.Channel, error) {
func (m *MockChannelService) Find(filter *types.ChannelFilter) (types.ChannelSet, error) {
ret := m.ctrl.Call(m, "Find", filter)
ret0, _ := ret[0].([]*types.Channel)
ret0, _ := ret[0].(types.ChannelSet)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -85,6 +85,19 @@ func (mr *MockChannelServiceMockRecorder) FindByMembership() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByMembership", reflect.TypeOf((*MockChannelService)(nil).FindByMembership))
}
// FindMembers mocks base method
func (m *MockChannelService) FindMembers(channelID uint64) (types.ChannelMemberSet, error) {
ret := m.ctrl.Call(m, "FindMembers", channelID)
ret0, _ := ret[0].(types.ChannelMemberSet)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FindMembers indicates an expected call of FindMembers
func (mr *MockChannelServiceMockRecorder) FindMembers(channelID interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindMembers", reflect.TypeOf((*MockChannelService)(nil).FindMembers), channelID)
}
// Create mocks base method
func (m *MockChannelService) Create(channel *types.Channel) (*types.Channel, error) {
ret := m.ctrl.Call(m, "Create", channel)

View File

@@ -18,6 +18,7 @@ type (
attachment repository.AttachmentRepository
channel repository.ChannelRepository
cmember repository.ChannelMemberRepository
message repository.MessageRepository
reaction repository.ReactionRepository
@@ -64,8 +65,9 @@ func (svc *message) With(ctx context.Context) MessageService {
usr: svc.usr.With(ctx),
evl: svc.evl.With(ctx),
channel: repository.Channel(ctx, db),
attachment: repository.Attachment(ctx, db),
channel: repository.Channel(ctx, db),
cmember: repository.ChannelMember(ctx, db),
message: repository.Message(ctx, db),
reaction: repository.Reaction(ctx, db),
}
@@ -133,7 +135,7 @@ func (svc *message) Direct(recipientID uint64, in *types.Message) (out *types.Me
dch, err := svc.channel.FindDirectChannelByUserID(currentUserID, recipientID)
if err == repository.ErrChannelNotFound {
dch, err = svc.channel.CreateChannel(&types.Channel{
Type: types.ChannelTypeDirect,
Type: types.ChannelTypeGroup,
})
if err != nil {
@@ -143,12 +145,12 @@ func (svc *message) Direct(recipientID uint64, in *types.Message) (out *types.Me
membership := &types.ChannelMember{ChannelID: dch.ID, Type: types.ChannelMembershipTypeOwner}
membership.UserID = currentUserID
if _, err = svc.channel.AddChannelMember(membership); err != nil {
if _, err = svc.cmember.Create(membership); err != nil {
return
}
membership.UserID = recipientID
if _, err = svc.channel.AddChannelMember(membership); err != nil {
if _, err = svc.cmember.Create(membership); err != nil {
return
}

View File

@@ -28,16 +28,6 @@ type (
Members []uint64 `json:"-" db:"-"`
}
ChannelMember struct {
ChannelID uint64 `db:"rel_channel"`
UserID uint64 `db:"rel_user"`
Type ChannelMembershipType `db:"type"`
CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"`
UpdatedAt *time.Time `json:"updatedAt,omitempty" db:"updated_at"`
}
ChannelFilter struct {
Query string
@@ -47,11 +37,9 @@ type (
IncludeMembers bool
}
ChannelMembershipType string
ChannelType string
ChannelType string
ChannelSet []*Channel
ChannelMemberSet []*ChannelMember
ChannelSet []*Channel
)
// Scope returns permissions group that for this type
@@ -79,34 +67,8 @@ func (cc ChannelSet) Walk(w func(*Channel) error) (err error) {
return
}
func (mm ChannelMemberSet) Walk(w func(*ChannelMember) error) (err error) {
for i := range mm {
if err = w(mm[i]); err != nil {
return
}
}
return
}
func (mm ChannelMemberSet) MembersOf(channelID uint64) []uint64 {
var mmof = make([]uint64, 0)
for i := range mm {
if mm[i].ChannelID == channelID {
mmof = append(mmof, mm[i].UserID)
}
}
return mmof
}
const (
ChannelMembershipTypeOwner ChannelMembershipType = "owner"
ChannelMembershipTypeMember = "member"
ChannelTypePublic ChannelType = "public"
ChannelTypePrivate = "private"
ChannelTypeGroup = "group"
ChannelTypeDirect = "direct"
)

View File

@@ -0,0 +1,59 @@
package types
import (
"time"
authTypes "github.com/crusttech/crust/auth/types"
)
type (
ChannelMember struct {
ChannelID uint64 `db:"rel_channel"`
UserID uint64 `db:"rel_user"`
User *authTypes.User `db:"-"`
Type ChannelMembershipType `db:"type"`
CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"`
UpdatedAt *time.Time `json:"updatedAt,omitempty" db:"updated_at"`
}
ChannelMemberFilter struct {
ComembersOf uint64
ChannelID uint64
MemberID uint64
}
ChannelMembershipType string
ChannelMemberSet []*ChannelMember
)
func (mm ChannelMemberSet) Walk(w func(*ChannelMember) error) (err error) {
for i := range mm {
if err = w(mm[i]); err != nil {
return
}
}
return
}
func (mm ChannelMemberSet) MembersOf(channelID uint64) []uint64 {
var mmof = make([]uint64, 0)
for i := range mm {
if mm[i].ChannelID == channelID {
mmof = append(mmof, mm[i].UserID)
}
}
return mmof
}
const (
ChannelMembershipTypeOwner ChannelMembershipType = "owner"
ChannelMembershipTypeMember = "member"
ChannelMembershipTypeInvitee = "invitee"
)