Channel membership improved
This commit is contained in:
@@ -28,7 +28,7 @@ type (
|
||||
With(ctx context.Context) UserService
|
||||
|
||||
FindByID(id uint64) (*types.User, error)
|
||||
Find(filter *types.UserFilter) ([]*types.User, error)
|
||||
Find(filter *types.UserFilter) (types.UserSet, error)
|
||||
|
||||
Create(input *types.User) (*types.User, error)
|
||||
Update(mod *types.User) (*types.User, error)
|
||||
@@ -73,7 +73,7 @@ func (svc *user) FindByID(id uint64) (*types.User, error) {
|
||||
return svc.user.FindUserByID(id)
|
||||
}
|
||||
|
||||
func (svc *user) Find(filter *types.UserFilter) ([]*types.User, error) {
|
||||
func (svc *user) Find(filter *types.UserFilter) (types.UserSet, error) {
|
||||
return svc.user.FindUsers(filter)
|
||||
}
|
||||
|
||||
|
||||
@@ -62,3 +62,13 @@ func (uu UserSet) Walk(w func(*User) error) (err error) {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (uu UserSet) FindById(ID uint64) *User {
|
||||
for i := range uu {
|
||||
if uu[i].ID == ID {
|
||||
return uu[i]
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -65,6 +65,24 @@ func Channels(channels sam.ChannelSet) *outgoing.ChannelSet {
|
||||
return &retval
|
||||
}
|
||||
|
||||
func ChannelMember(m *sam.ChannelMember) *outgoing.ChannelMember {
|
||||
return &outgoing.ChannelMember{
|
||||
User: User(m.User),
|
||||
Type: string(m.Type),
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func ChannelMembers(members sam.ChannelMemberSet) *outgoing.ChannelMemberSet {
|
||||
mm := make([]*outgoing.ChannelMember, len(members))
|
||||
for k, c := range members {
|
||||
mm[k] = ChannelMember(c)
|
||||
}
|
||||
retval := outgoing.ChannelMemberSet(mm)
|
||||
return &retval
|
||||
}
|
||||
|
||||
func User(user *auth.User) *outgoing.User {
|
||||
if user == nil {
|
||||
return nil
|
||||
|
||||
26
internal/payload/outgoing/channel_member.go
Normal file
26
internal/payload/outgoing/channel_member.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package outgoing
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
type (
|
||||
ChannelMember struct {
|
||||
// Channel to part (nil) for ALL channels
|
||||
User *User `json:"user"`
|
||||
Type string `json:"type""`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt *time.Time `json:"updatedAt,omitempty"`
|
||||
}
|
||||
|
||||
ChannelMemberSet []*ChannelMember
|
||||
)
|
||||
|
||||
func (p *ChannelMember) EncodeMessage() ([]byte, error) {
|
||||
return json.Marshal(Payload{ChannelMember: p})
|
||||
}
|
||||
|
||||
func (p *ChannelMemberSet) EncodeMessage() ([]byte, error) {
|
||||
return json.Marshal(Payload{ChannelMemberSet: p})
|
||||
}
|
||||
@@ -16,6 +16,9 @@ type (
|
||||
*Channel `json:"channel,omitempty"`
|
||||
*ChannelSet `json:"channels,omitempty"`
|
||||
|
||||
*ChannelMember `json:"channelMember,omitempty"`
|
||||
*ChannelMemberSet `json:"channelMembers,omitempty"`
|
||||
|
||||
*User `json:"user,omitempty"`
|
||||
*UserSet `json:"users,omitempty"`
|
||||
|
||||
|
||||
@@ -16,13 +16,9 @@ type (
|
||||
FindChannelByID(id uint64) (*types.Channel, error)
|
||||
FindDirectChannelByUserID(fromUserID, toUserID uint64) (*types.Channel, error)
|
||||
FindChannels(filter *types.ChannelFilter) ([]*types.Channel, error)
|
||||
FindMembers(userID uint64) (types.ChannelMemberSet, error)
|
||||
CreateChannel(mod *types.Channel) (*types.Channel, error)
|
||||
UpdateChannel(mod *types.Channel) (*types.Channel, error)
|
||||
|
||||
FindChannelsMembershipsByMemberId(memberId uint64) ([]*types.ChannelMember, error)
|
||||
AddChannelMember(mod *types.ChannelMember) (*types.ChannelMember, error)
|
||||
RemoveChannelMember(channelID, userID uint64) error
|
||||
ArchiveChannelByID(id uint64) error
|
||||
UnarchiveChannelByID(id uint64) error
|
||||
DeleteChannelByID(id uint64) error
|
||||
@@ -42,11 +38,7 @@ const (
|
||||
FROM channels AS c
|
||||
WHERE ` + sqlChannelValidOnly
|
||||
|
||||
sqlChannelMemberSelect = `SELECT m.*
|
||||
FROM channel_members AS m
|
||||
INNER JOIN channels AS c ON (m.rel_channel = c.id)
|
||||
WHERE ` + sqlChannelValidOnly
|
||||
|
||||
// Returns channel (group) with exactly 2 members
|
||||
sqlChannelDirect = `SELECT *
|
||||
FROM channels AS c
|
||||
WHERE c.type = ?
|
||||
@@ -57,10 +49,6 @@ const (
|
||||
AND MIN(rel_user) = ?
|
||||
AND MAX(rel_user) = ?)`
|
||||
|
||||
sqlChannelMemberships = `SELECT *
|
||||
FROM channel_members AS cm
|
||||
WHERE true`
|
||||
|
||||
// subquery that filters out all channels that current user has access to as a member
|
||||
// or via channel type (public chans)
|
||||
sqlChannelAccess = ` (
|
||||
@@ -109,7 +97,7 @@ func (r *channel) FindDirectChannelByUserID(fromUserID, toUserID uint64) (*types
|
||||
toUserID, fromUserID = fromUserID, toUserID
|
||||
}
|
||||
|
||||
return mod, isFound(r.db().Get(mod, sqlChannelDirect, types.ChannelTypeDirect, fromUserID, toUserID), mod.ID > 0, ErrChannelNotFound)
|
||||
return mod, isFound(r.db().Get(mod, sqlChannelDirect, types.ChannelTypeGroup, fromUserID, toUserID), mod.ID > 0, ErrChannelNotFound)
|
||||
}
|
||||
|
||||
func (r *channel) FindChannels(filter *types.ChannelFilter) ([]*types.Channel, error) {
|
||||
@@ -138,22 +126,6 @@ func (r *channel) FindChannels(filter *types.ChannelFilter) ([]*types.Channel, e
|
||||
return rval, r.db().Select(&rval, sql, params...)
|
||||
}
|
||||
|
||||
// Returns member ids of all channels that user has access to
|
||||
func (r *channel) FindMembers(userID uint64) (types.ChannelMemberSet, error) {
|
||||
params := make([]interface{}, 0)
|
||||
rval := types.ChannelMemberSet{}
|
||||
|
||||
sql := sqlChannelMemberSelect
|
||||
|
||||
if userID > 0 {
|
||||
// scope: only channels we have access to
|
||||
sql += " AND m.rel_channel IN " + sqlChannelAccess
|
||||
params = append(params, userID, types.ChannelTypePublic)
|
||||
}
|
||||
|
||||
return rval, r.db().Select(&rval, sql, params...)
|
||||
}
|
||||
|
||||
func (r *channel) CreateChannel(mod *types.Channel) (*types.Channel, error) {
|
||||
mod.ID = factory.Sonyflake.NextID()
|
||||
mod.CreatedAt = time.Now()
|
||||
@@ -180,24 +152,6 @@ func (r *channel) UpdateChannel(mod *types.Channel) (*types.Channel, error) {
|
||||
UpdatePartial("channels", mod, whitelist, "id")
|
||||
}
|
||||
|
||||
func (r *channel) FindChannelsMembershipsByMemberId(memberId uint64) ([]*types.ChannelMember, error) {
|
||||
var rval = make([]*types.ChannelMember, 0)
|
||||
|
||||
return rval, r.db().Select(&rval, sqlChannelMemberships+" AND cm.rel_user = ? ", memberId)
|
||||
}
|
||||
|
||||
func (r *channel) AddChannelMember(mod *types.ChannelMember) (*types.ChannelMember, error) {
|
||||
sql := `INSERT INTO channel_members (rel_channel, rel_user) VALUES (?, ?)`
|
||||
mod.CreatedAt = time.Now()
|
||||
|
||||
return mod, exec(r.db().Exec(sql, mod.ChannelID, mod.UserID))
|
||||
}
|
||||
|
||||
func (r *channel) RemoveChannelMember(channelID, userID uint64) error {
|
||||
sql := `DELETE FROM channel_members WHERE rel_channel = ? AND rel_user = ?`
|
||||
return exec(r.db().Exec(sql, channelID, userID))
|
||||
}
|
||||
|
||||
func (r *channel) ArchiveChannelByID(id uint64) error {
|
||||
return r.updateColumnByID("channels", "archived_at", time.Now(), id)
|
||||
}
|
||||
|
||||
114
sam/repository/channel_member.go
Normal file
114
sam/repository/channel_member.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/sam/types"
|
||||
)
|
||||
|
||||
type (
|
||||
// ChannelMemberRepository interface to channel member repository
|
||||
ChannelMemberRepository interface {
|
||||
With(ctx context.Context, db *factory.DB) ChannelMemberRepository
|
||||
|
||||
Find(filter *types.ChannelMemberFilter) (types.ChannelMemberSet, error)
|
||||
|
||||
Create(mod *types.ChannelMember) (*types.ChannelMember, error)
|
||||
Delete(channelMemberID, userID uint64) error
|
||||
}
|
||||
|
||||
channelMember struct {
|
||||
*repository
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
// Copy definitions to make it more obvious that we're reusing channel-scope sql
|
||||
sqlChannelMemberChannelValidOnly = sqlChannelValidOnly
|
||||
|
||||
// Copy definitions to make it more obvious that we're reusing channel-scope sql
|
||||
sqlChannelMemberChannelAccess = sqlChannelAccess
|
||||
|
||||
// Fetching channel members of all channels a specific user has access to
|
||||
sqlChannelMemberSelect = `SELECT m.*
|
||||
FROM channel_members AS m
|
||||
INNER JOIN channels AS c ON (m.rel_channel = c.id)
|
||||
WHERE ` + sqlChannelMemberChannelValidOnly
|
||||
|
||||
// Selects all user's memberships
|
||||
sqlChannelMemberships = `SELECT *
|
||||
FROM channel_members AS cm
|
||||
WHERE true`
|
||||
)
|
||||
|
||||
// ChannelMember creates new instance of channel member repository
|
||||
func ChannelMember(ctx context.Context, db *factory.DB) ChannelMemberRepository {
|
||||
return (&channelMember{}).With(ctx, db)
|
||||
}
|
||||
|
||||
// With context...
|
||||
func (r *channelMember) With(ctx context.Context, db *factory.DB) ChannelMemberRepository {
|
||||
return &channelMember{
|
||||
repository: r.repository.With(ctx, db),
|
||||
}
|
||||
}
|
||||
|
||||
// FindMembers fetches membership info
|
||||
//
|
||||
// If channelID > 0 it returns members of a specific channel
|
||||
// If userID > 0 it returns members of all channels this user is member of
|
||||
func (r *channelMember) Find(filter *types.ChannelMemberFilter) (types.ChannelMemberSet, error) {
|
||||
params := make([]interface{}, 0)
|
||||
mm := types.ChannelMemberSet{}
|
||||
|
||||
sql := sqlChannelMemberSelect
|
||||
|
||||
if filter != nil {
|
||||
if filter.ComembersOf > 0 {
|
||||
// scope: only channel we have access to
|
||||
sql += " AND m.rel_channel IN " + sqlChannelMemberChannelAccess
|
||||
params = append(params, filter.ComembersOf, types.ChannelTypePublic)
|
||||
}
|
||||
|
||||
if filter.MemberID > 0 {
|
||||
sql += " AND m.rel_user = ?"
|
||||
params = append(params, filter.MemberID)
|
||||
}
|
||||
|
||||
if filter.ChannelID > 0 {
|
||||
sql += " AND m.rel_channel = ?"
|
||||
params = append(params, filter.ChannelID)
|
||||
}
|
||||
}
|
||||
|
||||
spew.Dump(filter, sql, params)
|
||||
|
||||
return mm, r.db().Select(&mm, sql, params...)
|
||||
}
|
||||
|
||||
// Create adds channel membership record
|
||||
func (r *channelMember) Create(mod *types.ChannelMember) (*types.ChannelMember, error) {
|
||||
mod.CreatedAt = time.Now()
|
||||
mod.UpdatedAt = nil
|
||||
|
||||
return mod, r.db().Insert("channel_members", mod)
|
||||
}
|
||||
|
||||
// Update modifies existing channel membership record
|
||||
func (r *channelMember) Update(mod *types.ChannelMember) (*types.ChannelMember, error) {
|
||||
mod.UpdatedAt = timeNowPtr()
|
||||
|
||||
whitelist := []string{"type", "updated_at"}
|
||||
|
||||
return mod, r.db().UpdatePartial("channel_members", mod, whitelist, "rel_channel", "rel_user")
|
||||
}
|
||||
|
||||
// Delete removes existing channel membership record
|
||||
func (r *channelMember) Delete(channelMemberID, userID uint64) error {
|
||||
sql := `DELETE FROM channel_members WHERE rel_channelMember = ? AND rel_user = ?`
|
||||
return exec(r.db().Exec(sql, channelMemberID, userID))
|
||||
}
|
||||
@@ -63,7 +63,7 @@ func (ctrl *Channel) List(ctx context.Context, r *request.ChannelList) (interfac
|
||||
}
|
||||
|
||||
func (ctrl *Channel) Members(ctx context.Context, r *request.ChannelMembers) (interface{}, error) {
|
||||
return nil, nil
|
||||
return ctrl.wrapMemberSet(ctrl.svc.ch.With(ctx).FindMembers(r.ChannelID))
|
||||
}
|
||||
|
||||
func (ctrl *Channel) Join(ctx context.Context, r *request.ChannelJoin) (interface{}, error) {
|
||||
@@ -116,3 +116,11 @@ func (ctrl *Channel) wrapSet(cc types.ChannelSet, err error) (*outgoing.ChannelS
|
||||
return payload.Channels(cc), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (ctrl *Channel) wrapMemberSet(mm types.ChannelMemberSet, err error) (*outgoing.ChannelMemberSet, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return payload.ChannelMembers(mm), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,18 +73,6 @@ func (mr *MockAttachmentServiceMockRecorder) Create(channelId, name, size, fh in
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockAttachmentService)(nil).Create), channelId, name, size, fh)
|
||||
}
|
||||
|
||||
// LoadFromMessages mocks base method
|
||||
func (m *MockAttachmentService) LoadFromMessages(mm types.MessageSet) error {
|
||||
ret := m.ctrl.Call(m, "LoadFromMessages", mm)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// LoadFromMessages indicates an expected call of LoadFromMessages
|
||||
func (mr *MockAttachmentServiceMockRecorder) LoadFromMessages(mm interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadFromMessages", reflect.TypeOf((*MockAttachmentService)(nil).LoadFromMessages), mm)
|
||||
}
|
||||
|
||||
// OpenOriginal mocks base method
|
||||
func (m *MockAttachmentService) OpenOriginal(att *types.Attachment) (io.ReadSeeker, error) {
|
||||
ret := m.ctrl.Call(m, "OpenOriginal", att)
|
||||
|
||||
@@ -4,10 +4,11 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/crusttech/crust/internal/auth"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/crusttech/crust/internal/auth"
|
||||
|
||||
authService "github.com/crusttech/crust/auth/service"
|
||||
"github.com/crusttech/crust/sam/repository"
|
||||
"github.com/crusttech/crust/sam/types"
|
||||
@@ -22,6 +23,7 @@ type (
|
||||
evl EventService
|
||||
|
||||
channel repository.ChannelRepository
|
||||
cmember repository.ChannelMemberRepository
|
||||
message repository.MessageRepository
|
||||
}
|
||||
|
||||
@@ -31,6 +33,7 @@ type (
|
||||
FindByID(channelID uint64) (*types.Channel, error)
|
||||
Find(filter *types.ChannelFilter) (types.ChannelSet, error)
|
||||
FindByMembership() (rval []*types.Channel, err error)
|
||||
FindMembers(channelID uint64) (types.ChannelMemberSet, error)
|
||||
|
||||
Create(channel *types.Channel) (*types.Channel, error)
|
||||
Update(channel *types.Channel) (*types.Channel, error)
|
||||
@@ -62,6 +65,7 @@ func (svc *channel) With(ctx context.Context) ChannelService {
|
||||
evl: svc.evl.With(ctx),
|
||||
|
||||
channel: repository.Channel(ctx, db),
|
||||
cmember: repository.ChannelMember(ctx, db),
|
||||
message: repository.Message(ctx, db),
|
||||
}
|
||||
}
|
||||
@@ -92,7 +96,7 @@ func (svc *channel) Find(filter *types.ChannelFilter) (types.ChannelSet, error)
|
||||
func (svc *channel) preloadMembers(cc types.ChannelSet) error {
|
||||
var userID = auth.GetIdentityFromContext(svc.ctx).Identity()
|
||||
|
||||
if mm, err := svc.channel.FindMembers(userID); err != nil {
|
||||
if mm, err := svc.cmember.Find(&types.ChannelMemberFilter{ComembersOf: userID}); err != nil {
|
||||
return err
|
||||
} else {
|
||||
cc.Walk(func(ch *types.Channel) error {
|
||||
@@ -104,13 +108,39 @@ func (svc *channel) preloadMembers(cc types.ChannelSet) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindMembers loads all members (and full users) for a specific channel
|
||||
func (svc *channel) FindMembers(channelID uint64) (out types.ChannelMemberSet, err error) {
|
||||
var userID = auth.GetIdentityFromContext(svc.ctx).Identity()
|
||||
|
||||
// @todo [SECURITY] check if we can return members on this channel
|
||||
_ = channelID
|
||||
_ = userID
|
||||
|
||||
return out, svc.db.Transaction(func() (err error) {
|
||||
out, err = svc.cmember.Find(&types.ChannelMemberFilter{ChannelID: channelID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if uu, err := svc.usr.Find(nil); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return out.Walk(func(member *types.ChannelMember) error {
|
||||
member.User = uu.FindById(member.UserID)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Returns all channels with membership info
|
||||
func (svc *channel) FindByMembership() (rval []*types.Channel, err error) {
|
||||
return rval, svc.db.Transaction(func() error {
|
||||
var chMemberId = repository.Identity(svc.ctx)
|
||||
|
||||
var mm []*types.ChannelMember
|
||||
|
||||
if mm, err = svc.channel.FindChannelsMembershipsByMemberId(chMemberId); err != nil {
|
||||
if mm, err = svc.cmember.Find(&types.ChannelMemberFilter{MemberID: chMemberId}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -176,7 +206,7 @@ func (svc *channel) Create(in *types.Channel) (out *types.Channel, err error) {
|
||||
}
|
||||
|
||||
// Join current user as an member & owner
|
||||
_, err = svc.channel.AddChannelMember(&types.ChannelMember{
|
||||
_, err = svc.cmember.Create(&types.ChannelMember{
|
||||
ChannelID: out.ID,
|
||||
UserID: chCreatorID,
|
||||
Type: types.ChannelMembershipTypeOwner,
|
||||
@@ -412,7 +442,29 @@ func (svc *channel) Unarchive(id uint64) error {
|
||||
|
||||
return svc.evl.Channel(ch)
|
||||
})
|
||||
}
|
||||
|
||||
func (svc *channel) AddMember(m *types.ChannelMember) (out *types.ChannelMember, err error) {
|
||||
return out, svc.db.Transaction(func() (err error) {
|
||||
var userID = repository.Identity(svc.ctx)
|
||||
|
||||
var ch *types.Channel
|
||||
|
||||
// @todo [SECURITY] can user access this channel?
|
||||
if ch, err = svc.channel.FindChannelByID(m.ChannelID); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// @todo [SECURITY] can user add members to this channel?
|
||||
|
||||
msg, err := svc.message.CreateMessage(svc.makeSystemMessage(ch, "@%d added a new member to this channel: @%d", userID, m.UserID))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
svc.sendMessageEvent(msg)
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (svc *channel) makeSystemMessage(ch *types.Channel, format string, a ...interface{}) *types.Message {
|
||||
|
||||
@@ -60,9 +60,9 @@ func (mr *MockChannelServiceMockRecorder) FindByID(channelID interface{}) *gomoc
|
||||
}
|
||||
|
||||
// Find mocks base method
|
||||
func (m *MockChannelService) Find(filter *types.ChannelFilter) ([]*types.Channel, error) {
|
||||
func (m *MockChannelService) Find(filter *types.ChannelFilter) (types.ChannelSet, error) {
|
||||
ret := m.ctrl.Call(m, "Find", filter)
|
||||
ret0, _ := ret[0].([]*types.Channel)
|
||||
ret0, _ := ret[0].(types.ChannelSet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -85,6 +85,19 @@ func (mr *MockChannelServiceMockRecorder) FindByMembership() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindByMembership", reflect.TypeOf((*MockChannelService)(nil).FindByMembership))
|
||||
}
|
||||
|
||||
// FindMembers mocks base method
|
||||
func (m *MockChannelService) FindMembers(channelID uint64) (types.ChannelMemberSet, error) {
|
||||
ret := m.ctrl.Call(m, "FindMembers", channelID)
|
||||
ret0, _ := ret[0].(types.ChannelMemberSet)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindMembers indicates an expected call of FindMembers
|
||||
func (mr *MockChannelServiceMockRecorder) FindMembers(channelID interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindMembers", reflect.TypeOf((*MockChannelService)(nil).FindMembers), channelID)
|
||||
}
|
||||
|
||||
// Create mocks base method
|
||||
func (m *MockChannelService) Create(channel *types.Channel) (*types.Channel, error) {
|
||||
ret := m.ctrl.Call(m, "Create", channel)
|
||||
|
||||
@@ -18,6 +18,7 @@ type (
|
||||
|
||||
attachment repository.AttachmentRepository
|
||||
channel repository.ChannelRepository
|
||||
cmember repository.ChannelMemberRepository
|
||||
message repository.MessageRepository
|
||||
reaction repository.ReactionRepository
|
||||
|
||||
@@ -64,8 +65,9 @@ func (svc *message) With(ctx context.Context) MessageService {
|
||||
usr: svc.usr.With(ctx),
|
||||
evl: svc.evl.With(ctx),
|
||||
|
||||
channel: repository.Channel(ctx, db),
|
||||
attachment: repository.Attachment(ctx, db),
|
||||
channel: repository.Channel(ctx, db),
|
||||
cmember: repository.ChannelMember(ctx, db),
|
||||
message: repository.Message(ctx, db),
|
||||
reaction: repository.Reaction(ctx, db),
|
||||
}
|
||||
@@ -133,7 +135,7 @@ func (svc *message) Direct(recipientID uint64, in *types.Message) (out *types.Me
|
||||
dch, err := svc.channel.FindDirectChannelByUserID(currentUserID, recipientID)
|
||||
if err == repository.ErrChannelNotFound {
|
||||
dch, err = svc.channel.CreateChannel(&types.Channel{
|
||||
Type: types.ChannelTypeDirect,
|
||||
Type: types.ChannelTypeGroup,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -143,12 +145,12 @@ func (svc *message) Direct(recipientID uint64, in *types.Message) (out *types.Me
|
||||
membership := &types.ChannelMember{ChannelID: dch.ID, Type: types.ChannelMembershipTypeOwner}
|
||||
|
||||
membership.UserID = currentUserID
|
||||
if _, err = svc.channel.AddChannelMember(membership); err != nil {
|
||||
if _, err = svc.cmember.Create(membership); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
membership.UserID = recipientID
|
||||
if _, err = svc.channel.AddChannelMember(membership); err != nil {
|
||||
if _, err = svc.cmember.Create(membership); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -28,16 +28,6 @@ type (
|
||||
Members []uint64 `json:"-" db:"-"`
|
||||
}
|
||||
|
||||
ChannelMember struct {
|
||||
ChannelID uint64 `db:"rel_channel"`
|
||||
UserID uint64 `db:"rel_user"`
|
||||
|
||||
Type ChannelMembershipType `db:"type"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"`
|
||||
UpdatedAt *time.Time `json:"updatedAt,omitempty" db:"updated_at"`
|
||||
}
|
||||
|
||||
ChannelFilter struct {
|
||||
Query string
|
||||
|
||||
@@ -47,11 +37,9 @@ type (
|
||||
IncludeMembers bool
|
||||
}
|
||||
|
||||
ChannelMembershipType string
|
||||
ChannelType string
|
||||
ChannelType string
|
||||
|
||||
ChannelSet []*Channel
|
||||
ChannelMemberSet []*ChannelMember
|
||||
ChannelSet []*Channel
|
||||
)
|
||||
|
||||
// Scope returns permissions group that for this type
|
||||
@@ -79,34 +67,8 @@ func (cc ChannelSet) Walk(w func(*Channel) error) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (mm ChannelMemberSet) Walk(w func(*ChannelMember) error) (err error) {
|
||||
for i := range mm {
|
||||
if err = w(mm[i]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (mm ChannelMemberSet) MembersOf(channelID uint64) []uint64 {
|
||||
var mmof = make([]uint64, 0)
|
||||
|
||||
for i := range mm {
|
||||
if mm[i].ChannelID == channelID {
|
||||
mmof = append(mmof, mm[i].UserID)
|
||||
}
|
||||
}
|
||||
|
||||
return mmof
|
||||
}
|
||||
|
||||
const (
|
||||
ChannelMembershipTypeOwner ChannelMembershipType = "owner"
|
||||
ChannelMembershipTypeMember = "member"
|
||||
|
||||
ChannelTypePublic ChannelType = "public"
|
||||
ChannelTypePrivate = "private"
|
||||
ChannelTypeGroup = "group"
|
||||
ChannelTypeDirect = "direct"
|
||||
)
|
||||
|
||||
59
sam/types/channel_member.go
Normal file
59
sam/types/channel_member.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
authTypes "github.com/crusttech/crust/auth/types"
|
||||
)
|
||||
|
||||
type (
|
||||
ChannelMember struct {
|
||||
ChannelID uint64 `db:"rel_channel"`
|
||||
|
||||
UserID uint64 `db:"rel_user"`
|
||||
User *authTypes.User `db:"-"`
|
||||
|
||||
Type ChannelMembershipType `db:"type"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"`
|
||||
UpdatedAt *time.Time `json:"updatedAt,omitempty" db:"updated_at"`
|
||||
}
|
||||
|
||||
ChannelMemberFilter struct {
|
||||
ComembersOf uint64
|
||||
ChannelID uint64
|
||||
MemberID uint64
|
||||
}
|
||||
|
||||
ChannelMembershipType string
|
||||
|
||||
ChannelMemberSet []*ChannelMember
|
||||
)
|
||||
|
||||
func (mm ChannelMemberSet) Walk(w func(*ChannelMember) error) (err error) {
|
||||
for i := range mm {
|
||||
if err = w(mm[i]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (mm ChannelMemberSet) MembersOf(channelID uint64) []uint64 {
|
||||
var mmof = make([]uint64, 0)
|
||||
|
||||
for i := range mm {
|
||||
if mm[i].ChannelID == channelID {
|
||||
mmof = append(mmof, mm[i].UserID)
|
||||
}
|
||||
}
|
||||
|
||||
return mmof
|
||||
}
|
||||
|
||||
const (
|
||||
ChannelMembershipTypeOwner ChannelMembershipType = "owner"
|
||||
ChannelMembershipTypeMember = "member"
|
||||
ChannelMembershipTypeInvitee = "invitee"
|
||||
)
|
||||
Reference in New Issue
Block a user