Refactor messages repo, fix channels
This commit is contained in:
parent
f422867469
commit
4cd4077f2d
@ -42,6 +42,6 @@ func provisionConfig(ctx context.Context, cmd *cobra.Command, c *cli.Config) err
|
||||
|
||||
// Provision ONLY when there are no channels (even if we find delete channels we abort provisioning
|
||||
func isProvisioned(ctx context.Context) (bool, error) {
|
||||
_, f, err := service.DefaultChannel.With(ctx).Find(types.ChannelFilter{IncludeDeleted: true})
|
||||
return f.Count > 0, err
|
||||
cc, _, err := service.DefaultChannel.With(ctx).Find(types.ChannelFilter{IncludeDeleted: true})
|
||||
return len(cc) > 0, err
|
||||
}
|
||||
|
||||
@ -123,6 +123,10 @@ func (r channel) Find(filter types.ChannelFilter) (set types.ChannelSet, f types
|
||||
query = query.Where(squirrel.Eq{"c.deleted_at": nil})
|
||||
}
|
||||
|
||||
if len(f.ChannelID) > 0 {
|
||||
query = query.Where(squirrel.Eq{"c.id": f.ChannelID})
|
||||
}
|
||||
|
||||
if f.Query != "" {
|
||||
q := "%" + strings.ToLower(f.Query) + "%"
|
||||
query = query.Where(squirrel.Like{"LOWER(name)": q})
|
||||
@ -143,11 +147,7 @@ func (r channel) Find(filter types.ChannelFilter) (set types.ChannelSet, f types
|
||||
query = query.OrderBy(orderBy...)
|
||||
}
|
||||
|
||||
if f.Count, err = rh.Count(r.db(), query); err != nil || f.Count == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
return set, f, rh.FetchPaged(r.db(), query, f.Page, f.PerPage, &set)
|
||||
return set, f, rh.FetchAll(r.db(), query, &set)
|
||||
}
|
||||
|
||||
func (r channel) Create(mod *types.Channel) (*types.Channel, error) {
|
||||
|
||||
@ -7,7 +7,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/titpetric/factory"
|
||||
|
||||
"github.com/cortezaproject/corteza-server/messaging/types"
|
||||
@ -19,8 +18,8 @@ type (
|
||||
With(ctx context.Context, db *factory.DB) MessageRepository
|
||||
|
||||
FindByID(id uint64) (*types.Message, error)
|
||||
Find(filter *types.MessageFilter) (types.MessageSet, error)
|
||||
FindThreads(filter *types.MessageFilter) (types.MessageSet, error)
|
||||
Find(types.MessageFilter) (types.MessageSet, types.MessageFilter, error)
|
||||
FindThreads(types.MessageFilter) (types.MessageSet, types.MessageFilter, error)
|
||||
CountFromMessageID(channelID, threadID, messageID uint64) (uint32, error)
|
||||
LastMessageID(channelID, threadID uint64) (uint64, error)
|
||||
PrefillThreadParticipants(mm types.MessageSet) error
|
||||
@ -43,54 +42,6 @@ type (
|
||||
const (
|
||||
MESSAGES_MAX_LIMIT = 100
|
||||
|
||||
// subquery that filters out all channels that current user has access to as a member
|
||||
// or via channel type (public channels)
|
||||
sqlChannelAccess = ` (
|
||||
SELECT id
|
||||
FROM messaging_channel c
|
||||
LEFT OUTER JOIN messaging_channel_member AS m ON (c.id = m.rel_channel)
|
||||
WHERE rel_user = ?
|
||||
UNION
|
||||
SELECT id
|
||||
FROM messaging_channel c
|
||||
WHERE c.type = ?
|
||||
)`
|
||||
|
||||
sqlMessageColumns = "id, " +
|
||||
"COALESCE(type,'') AS type, " +
|
||||
"message, " +
|
||||
"rel_user, " +
|
||||
"rel_channel, " +
|
||||
"reply_to, " +
|
||||
"replies, " +
|
||||
"created_at, " +
|
||||
"updated_at, " +
|
||||
"deleted_at"
|
||||
sqlMessageScope = "deleted_at IS NULL"
|
||||
|
||||
sqlMessagesSelect = `SELECT ` + sqlMessageColumns + `
|
||||
FROM messaging_message
|
||||
WHERE ` + sqlMessageScope
|
||||
|
||||
sqlMessagesThreads = "WITH originals AS (" +
|
||||
" SELECT id AS original_id " +
|
||||
" FROM messaging_message " +
|
||||
" WHERE " + sqlMessageScope +
|
||||
" AND rel_channel IN " + sqlChannelAccess +
|
||||
" AND reply_to = 0 " +
|
||||
" AND replies > 0 " +
|
||||
// for finding only threads we've created or replied to
|
||||
" AND (rel_user = ? OR id IN (SELECT DISTINCT reply_to FROM messaging_message WHERE rel_user = ?))" +
|
||||
" ORDER BY id DESC " +
|
||||
" LIMIT ? " +
|
||||
")" +
|
||||
" SELECT " + sqlMessageColumns +
|
||||
" FROM messaging_message, originals " +
|
||||
" WHERE " + sqlMessageScope +
|
||||
" AND original_id IN (id, reply_to)"
|
||||
|
||||
sqlThreadParticipantsByMessageID = "SELECT DISTINCT reply_to, rel_user FROM messaging_message WHERE reply_to IN (?)"
|
||||
|
||||
sqlCountFromMessageID = "SELECT COUNT(*) AS count " +
|
||||
"FROM messaging_message " +
|
||||
"WHERE rel_channel = ? " +
|
||||
@ -125,135 +76,154 @@ func (r message) table() string {
|
||||
return "messaging_message"
|
||||
}
|
||||
|
||||
func (r *message) FindByID(id uint64) (*types.Message, error) {
|
||||
mod := &types.Message{}
|
||||
sql := sqlMessagesSelect + " AND id = ?"
|
||||
|
||||
return mod, rh.IsFound(r.db().Get(mod, sql, id), mod.ID > 0, ErrMessageNotFound)
|
||||
func (r message) columns() []string {
|
||||
return []string{
|
||||
"m.id",
|
||||
"COALESCE(m.type,'') AS type",
|
||||
"m.message",
|
||||
"m.rel_user",
|
||||
"m.rel_channel",
|
||||
"m.reply_to",
|
||||
"m.replies",
|
||||
"m.created_at",
|
||||
"m.updated_at",
|
||||
"m.deleted_at",
|
||||
}
|
||||
}
|
||||
|
||||
func (r *message) Find(filter *types.MessageFilter) (types.MessageSet, error) {
|
||||
r.sanitizeFilter(filter)
|
||||
func (r message) query() squirrel.SelectBuilder {
|
||||
return squirrel.
|
||||
Select(r.columns()...).
|
||||
From(r.table() + " AS m").
|
||||
Where(squirrel.Eq{"m.deleted_at": nil})
|
||||
}
|
||||
|
||||
params := make([]interface{}, 0)
|
||||
rval := make(types.MessageSet, 0)
|
||||
func (r message) FindByID(id uint64) (*types.Message, error) {
|
||||
return r.findOneBy(squirrel.Eq{"m.id": id})
|
||||
}
|
||||
|
||||
sql := sqlMessagesSelect
|
||||
func (r message) findOneBy(cnd squirrel.Sqlizer) (*types.Message, error) {
|
||||
var (
|
||||
ch = &types.Message{}
|
||||
|
||||
if filter.Query != "" {
|
||||
sql += " AND message LIKE ?"
|
||||
params = append(params, "%"+filter.Query+"%")
|
||||
q = r.query().
|
||||
Where(cnd)
|
||||
|
||||
err = rh.FetchOne(r.db(), q, ch)
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if ch.ID == 0 {
|
||||
return nil, ErrMessageNotFound
|
||||
}
|
||||
|
||||
if len(filter.ChannelID) > 0 {
|
||||
sql += " AND rel_channel IN (" + strings.Repeat(",?", len(filter.ChannelID))[1:] + ")"
|
||||
for _, id := range filter.ChannelID {
|
||||
params = append(params, id)
|
||||
}
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (r message) Find(filter types.MessageFilter) (set types.MessageSet, f types.MessageFilter, err error) {
|
||||
f = r.sanitizeFilter(filter)
|
||||
|
||||
query := r.query()
|
||||
|
||||
if f.Query != "" {
|
||||
q := "%" + strings.ToLower(f.Query) + "%"
|
||||
query = query.Where(squirrel.Like{"LOWER(m.message)": q})
|
||||
}
|
||||
|
||||
if len(filter.UserID) > 0 {
|
||||
sql += " AND rel_user IN (" + strings.Repeat(",?", len(filter.UserID))[1:] + ")"
|
||||
for _, id := range filter.UserID {
|
||||
params = append(params, id)
|
||||
}
|
||||
if len(f.ChannelID) > 0 {
|
||||
query = query.Where(squirrel.Eq{"m.rel_channel": f.ChannelID})
|
||||
}
|
||||
|
||||
if len(filter.ThreadID) > 0 {
|
||||
sql += " AND reply_to IN (" + strings.Repeat(",?", len(filter.ThreadID))[1:] + ")"
|
||||
for _, id := range filter.ThreadID {
|
||||
params = append(params, id)
|
||||
}
|
||||
if len(f.UserID) > 0 {
|
||||
query = query.Where(squirrel.Eq{"m.rel_user": f.UserID})
|
||||
}
|
||||
|
||||
if len(f.ThreadID) > 0 {
|
||||
query = query.Where(squirrel.Eq{"m.reply_to": f.ThreadID})
|
||||
} else {
|
||||
sql += " AND reply_to = 0 "
|
||||
query = query.Where(squirrel.Eq{"m.reply_to": 0})
|
||||
}
|
||||
|
||||
if len(filter.Type) > 0 {
|
||||
sql += " AND type IN (" + strings.Repeat(",?", len(filter.Type))[1:] + ")"
|
||||
for _, id := range filter.Type {
|
||||
params = append(params, id)
|
||||
if f.AttachmentsOnly {
|
||||
// Override Type filter
|
||||
f.Type = []string{
|
||||
types.MessageTypeAttachment.String(),
|
||||
types.MessageTypeInlineImage.String(),
|
||||
}
|
||||
}
|
||||
|
||||
if len(f.Type) > 0 {
|
||||
query = query.Where(squirrel.Eq{"m.type": f.Type})
|
||||
}
|
||||
|
||||
// first, exclusive
|
||||
if filter.AfterID > 0 {
|
||||
sql += " AND id > ? "
|
||||
params = append(params, filter.AfterID)
|
||||
if f.AfterID > 0 {
|
||||
query = query.Where(squirrel.Gt{"m.id": f.AfterID})
|
||||
}
|
||||
|
||||
// from, inclusive
|
||||
if filter.FromID > 0 {
|
||||
sql += " AND id >= ? "
|
||||
params = append(params, filter.FromID)
|
||||
if f.FromID > 0 {
|
||||
query = query.Where(squirrel.GtOrEq{"m.id": f.FromID})
|
||||
}
|
||||
|
||||
// last, exclusive
|
||||
if filter.BeforeID > 0 {
|
||||
sql += " AND id < ? "
|
||||
params = append(params, filter.BeforeID)
|
||||
if f.BeforeID > 0 {
|
||||
query = query.Where(squirrel.Lt{"m.id": f.BeforeID})
|
||||
}
|
||||
|
||||
// to, inclusive
|
||||
if filter.ToID > 0 {
|
||||
sql += " AND id <= ? "
|
||||
params = append(params, filter.ToID)
|
||||
if f.ToID > 0 {
|
||||
query = query.Where(squirrel.LtOrEq{"m.id": f.ToID})
|
||||
}
|
||||
|
||||
if filter.BookmarkedOnly || filter.PinnedOnly {
|
||||
sql += " AND id IN (SELECT rel_message FROM messaging_message_flag WHERE flag = ?) "
|
||||
|
||||
if filter.PinnedOnly {
|
||||
params = append(params, types.MessageFlagPinnedToChannel)
|
||||
} else {
|
||||
params = append(params, types.MessageFlagBookmarkedMessage)
|
||||
if f.BookmarkedOnly || f.PinnedOnly {
|
||||
flag := types.MessageFlagBookmarkedMessage
|
||||
if f.PinnedOnly {
|
||||
flag = types.MessageFlagPinnedToChannel
|
||||
}
|
||||
|
||||
query = query.
|
||||
Where(squirrel.ConcatExpr("m.id IN(", (messageFlag{}).queryMessagesWithFlags(flag), ")"))
|
||||
}
|
||||
|
||||
if filter.AttachmentsOnly {
|
||||
sql += " AND type IN (?, ?) "
|
||||
params = append(
|
||||
params,
|
||||
types.MessageTypeAttachment,
|
||||
types.MessageTypeInlineImage,
|
||||
)
|
||||
}
|
||||
query = query.
|
||||
OrderBy("id DESC").
|
||||
Limit(uint64(f.Limit))
|
||||
|
||||
sql += " AND rel_channel IN " + sqlChannelAccess
|
||||
params = append(params, filter.CurrentUserID, types.ChannelTypePublic)
|
||||
|
||||
sql += " ORDER BY id DESC"
|
||||
|
||||
sql += " LIMIT ? "
|
||||
params = append(params, filter.Limit)
|
||||
|
||||
return rval, r.db().Select(&rval, sql, params...)
|
||||
return set, f, rh.FetchAll(r.db(), query, &set)
|
||||
}
|
||||
|
||||
func (r *message) FindThreads(filter *types.MessageFilter) (types.MessageSet, error) {
|
||||
r.sanitizeFilter(filter)
|
||||
func (r *message) FindThreads(filter types.MessageFilter) (set types.MessageSet, f types.MessageFilter, err error) {
|
||||
f = r.sanitizeFilter(filter)
|
||||
|
||||
params := make([]interface{}, 0)
|
||||
rval := make(types.MessageSet, 0)
|
||||
// Selecting first valid (deleted_at IS NULL) messages in threads (replies > 0 && reply_to = 0)
|
||||
// that belong to filtered channels and we've contributed to (or stated it)
|
||||
originals := squirrel.
|
||||
Select("id AS original_id").
|
||||
From(r.table()).
|
||||
Where(squirrel.And{
|
||||
squirrel.Eq{
|
||||
"deleted_at": nil,
|
||||
"rel_channel": f.ChannelID,
|
||||
"reply_to": 0,
|
||||
},
|
||||
squirrel.Gt{"replies": 0},
|
||||
squirrel.Or{
|
||||
squirrel.Eq{"rel_user": filter.CurrentUserID},
|
||||
squirrel.Expr(
|
||||
"id IN (SELECT DISTINCT reply_to FROM messaging_message WHERE rel_user = ?)",
|
||||
filter.CurrentUserID),
|
||||
},
|
||||
})
|
||||
|
||||
// for sqlChannelAccess
|
||||
params = append(params, filter.CurrentUserID, types.ChannelTypePublic)
|
||||
// Prepare the actual message selector
|
||||
query := r.query().Join("originals ON (original_id IN (id, reply_to))")
|
||||
|
||||
// for finding only threads we've created or replied to
|
||||
params = append(params, filter.CurrentUserID, filter.CurrentUserID)
|
||||
// And create CTE
|
||||
cte := squirrel.ConcatExpr("WITH originals AS (", originals, ") ", query)
|
||||
|
||||
// for sqlMessagesThreads
|
||||
params = append(params, filter.Limit)
|
||||
|
||||
sql := sqlMessagesThreads
|
||||
|
||||
if len(filter.ChannelID) > 0 {
|
||||
sql += " AND rel_channel IN (" + strings.Repeat(",?", len(filter.ChannelID))[1:] + ")"
|
||||
for _, id := range filter.ChannelID {
|
||||
params = append(params, id)
|
||||
}
|
||||
}
|
||||
|
||||
return rval, r.db().Select(&rval, sql, params...)
|
||||
return set, f, rh.FetchAll(r.db(), cte, &set)
|
||||
}
|
||||
|
||||
func (r *message) CountFromMessageID(channelID, threadID, lastReadMessageID uint64) (uint32, error) {
|
||||
@ -281,11 +251,11 @@ func (r *message) LastMessageID(channelID, threadID uint64) (uint64, error) {
|
||||
)
|
||||
}
|
||||
|
||||
func (r *message) PrefillThreadParticipants(mm types.MessageSet) error {
|
||||
var rval = []struct {
|
||||
func (r *message) PrefillThreadParticipants(mm types.MessageSet) (err error) {
|
||||
var rval []struct {
|
||||
ReplyTo uint64 `db:"reply_to"`
|
||||
UserID uint64 `db:"rel_user"`
|
||||
}{}
|
||||
}
|
||||
|
||||
// Filter out only relevant messages -- ones with replies
|
||||
mm, _ = mm.Filter(func(m *types.Message) (b bool, e error) {
|
||||
@ -296,27 +266,29 @@ func (r *message) PrefillThreadParticipants(mm types.MessageSet) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if sql, args, err := sqlx.In(sqlThreadParticipantsByMessageID, mm.IDs()); err != nil {
|
||||
return err
|
||||
} else if err = r.db().Select(&rval, sql, args...); err != nil {
|
||||
return err
|
||||
} else {
|
||||
for _, p := range rval {
|
||||
mm.FindByID(p.ReplyTo).RepliesFrom = append(mm.FindByID(p.ReplyTo).RepliesFrom, p.UserID)
|
||||
}
|
||||
query := squirrel.
|
||||
Select("reply_to", "rel_user").
|
||||
From(r.table()).
|
||||
Where(squirrel.Eq{"reply_to": mm.IDs()})
|
||||
|
||||
err = rh.FetchAll(r.db(), query, &rval)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, p := range rval {
|
||||
mm.FindByID(p.ReplyTo).RepliesFrom = append(mm.FindByID(p.ReplyTo).RepliesFrom, p.UserID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *message) sanitizeFilter(filter *types.MessageFilter) {
|
||||
if filter == nil {
|
||||
filter = &types.MessageFilter{}
|
||||
func (r *message) sanitizeFilter(f types.MessageFilter) types.MessageFilter {
|
||||
if f.Limit == 0 || f.Limit > MESSAGES_MAX_LIMIT {
|
||||
f.Limit = MESSAGES_MAX_LIMIT
|
||||
}
|
||||
|
||||
if filter.Limit == 0 || filter.Limit > MESSAGES_MAX_LIMIT {
|
||||
filter.Limit = MESSAGES_MAX_LIMIT
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func (r *message) Create(mod *types.Message) (*types.Message, error) {
|
||||
@ -332,7 +304,7 @@ func (r *message) Update(mod *types.Message) (*types.Message, error) {
|
||||
return mod, r.db().Replace("messaging_message", mod)
|
||||
}
|
||||
|
||||
func (svc *message) BindAvatar(in *types.Message, avatar io.Reader) (*types.Message, error) {
|
||||
func (r *message) BindAvatar(in *types.Message, avatar io.Reader) (*types.Message, error) {
|
||||
// @todo: implement setting avatar on a message
|
||||
in.Meta.Avatar = ""
|
||||
return in, nil
|
||||
|
||||
@ -56,6 +56,13 @@ func (r messageFlag) query() squirrel.SelectBuilder {
|
||||
From(r.table() + " AS mf")
|
||||
}
|
||||
|
||||
func (r messageFlag) queryMessagesWithFlags(flags ...string) squirrel.SelectBuilder {
|
||||
return squirrel.
|
||||
Select("mf.rel_message").
|
||||
From(r.table() + " AS mf").
|
||||
Where(squirrel.Eq{"flag": flags})
|
||||
}
|
||||
|
||||
func (r messageFlag) With(ctx context.Context, db *factory.DB) MessageFlagRepository {
|
||||
return &messageFlag{
|
||||
repository: r.repository.With(ctx, db),
|
||||
|
||||
@ -28,7 +28,7 @@ func (Search) New() *Search {
|
||||
}
|
||||
|
||||
func (ctrl *Search) Messages(ctx context.Context, r *request.SearchMessages) (interface{}, error) {
|
||||
return ctrl.wrapSet(ctx)(ctrl.svc.msg.With(ctx).Find(&types.MessageFilter{
|
||||
mm, _, err := ctrl.svc.msg.With(ctx).Find(types.MessageFilter{
|
||||
ChannelID: payload.ParseUInt64s(r.ChannelID),
|
||||
AfterID: r.AfterMessageID,
|
||||
BeforeID: r.BeforeMessageID,
|
||||
@ -42,24 +42,27 @@ func (ctrl *Search) Messages(ctx context.Context, r *request.SearchMessages) (in
|
||||
Limit: r.Limit,
|
||||
|
||||
Query: r.Query,
|
||||
}))
|
||||
})
|
||||
|
||||
return ctrl.wrapSet(ctx, mm, err)
|
||||
}
|
||||
|
||||
func (ctrl *Search) Threads(ctx context.Context, r *request.SearchThreads) (interface{}, error) {
|
||||
return ctrl.wrapSet(ctx)(ctrl.svc.msg.With(ctx).FindThreads(&types.MessageFilter{
|
||||
mm, _, err := ctrl.svc.msg.With(ctx).FindThreads(types.MessageFilter{
|
||||
ChannelID: payload.ParseUInt64s(r.ChannelID),
|
||||
Limit: r.Limit,
|
||||
|
||||
Query: r.Query,
|
||||
}))
|
||||
})
|
||||
|
||||
return ctrl.wrapSet(ctx, mm, err)
|
||||
|
||||
}
|
||||
|
||||
func (ctrl *Search) wrapSet(ctx context.Context) func(mm types.MessageSet, err error) (*outgoing.MessageSet, error) {
|
||||
return func(mm types.MessageSet, err error) (*outgoing.MessageSet, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return payload.Messages(ctx, mm), nil
|
||||
}
|
||||
func (ctrl *Search) wrapSet(ctx context.Context, mm types.MessageSet, err error) (*outgoing.MessageSet, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return payload.Messages(ctx, mm), nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -133,6 +133,10 @@ func (svc *channel) Find(filter types.ChannelFilter) (set types.ChannelSet, f ty
|
||||
err = svc.preloadExtras(set)
|
||||
}
|
||||
|
||||
set, err = set.Filter(func(c *types.Channel) (b bool, e error) {
|
||||
return svc.ac.CanReadChannel(svc.ctx, c), nil
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -24,8 +24,9 @@ type (
|
||||
logger *zap.Logger
|
||||
ac messageAccessController
|
||||
|
||||
channel ChannelService
|
||||
|
||||
attachment repository.AttachmentRepository
|
||||
channel repository.ChannelRepository
|
||||
cmember repository.ChannelMemberRepository
|
||||
unread repository.UnreadRepository
|
||||
message repository.MessageRepository
|
||||
@ -47,8 +48,8 @@ type (
|
||||
MessageService interface {
|
||||
With(ctx context.Context) MessageService
|
||||
|
||||
Find(filter *types.MessageFilter) (types.MessageSet, error)
|
||||
FindThreads(filter *types.MessageFilter) (types.MessageSet, error)
|
||||
Find(types.MessageFilter) (types.MessageSet, types.MessageFilter, error)
|
||||
FindThreads(types.MessageFilter) (types.MessageSet, types.MessageFilter, error)
|
||||
|
||||
Create(messages *types.Message) (*types.Message, error)
|
||||
Update(messages *types.Message) (*types.Message, error)
|
||||
@ -82,6 +83,9 @@ var (
|
||||
func Message(ctx context.Context) MessageService {
|
||||
return (&message{
|
||||
logger: DefaultLogger.Named("message"),
|
||||
|
||||
ac: DefaultAccessControl,
|
||||
channel: DefaultChannel,
|
||||
}).With(ctx)
|
||||
}
|
||||
|
||||
@ -91,12 +95,13 @@ func (svc message) With(ctx context.Context) MessageService {
|
||||
db: db,
|
||||
ctx: ctx,
|
||||
logger: svc.logger,
|
||||
ac: DefaultAccessControl,
|
||||
|
||||
ac: svc.ac,
|
||||
channel: svc.channel,
|
||||
|
||||
event: Event(ctx),
|
||||
|
||||
attachment: repository.Attachment(ctx, db),
|
||||
channel: repository.Channel(ctx, db),
|
||||
cmember: repository.ChannelMember(ctx, db),
|
||||
unread: repository.Unread(ctx, db),
|
||||
message: repository.Message(ctx, db),
|
||||
@ -110,46 +115,34 @@ func (svc message) log(ctx context.Context, fields ...zapcore.Field) *zap.Logger
|
||||
return logger.AddRequestID(ctx, svc.logger).With(fields...)
|
||||
}
|
||||
|
||||
func (svc message) Find(filter *types.MessageFilter) (mm types.MessageSet, err error) {
|
||||
filter.CurrentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
|
||||
|
||||
if err = svc.channelAccessCheck(filter.ChannelID...); err != nil {
|
||||
func (svc message) Find(filter types.MessageFilter) (mm types.MessageSet, f types.MessageFilter, err error) {
|
||||
f = filter
|
||||
f.CurrentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
|
||||
if f.ChannelID, err = svc.readableChannels(f); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
mm, err = svc.message.Find(filter)
|
||||
mm, f, err = svc.message.Find(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return
|
||||
}
|
||||
|
||||
if len(filter.ChannelID) == 0 {
|
||||
// If no channel check was done prior message loading,
|
||||
// we do it now, by inspecting the actual payload we got
|
||||
mm = svc.filterMessagesByAccessibleChannels(mm)
|
||||
}
|
||||
|
||||
return mm, svc.preload(mm)
|
||||
return mm, f, svc.preload(mm)
|
||||
}
|
||||
|
||||
func (svc message) FindThreads(filter *types.MessageFilter) (mm types.MessageSet, err error) {
|
||||
filter.CurrentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
|
||||
|
||||
if err = svc.channelAccessCheck(filter.ChannelID...); err != nil {
|
||||
func (svc message) FindThreads(filter types.MessageFilter) (mm types.MessageSet, f types.MessageFilter, err error) {
|
||||
f = filter
|
||||
f.CurrentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
|
||||
if f.ChannelID, err = svc.readableChannels(f); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
mm, err = svc.message.FindThreads(filter)
|
||||
mm, f, err = svc.message.FindThreads(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return
|
||||
}
|
||||
|
||||
if len(filter.ChannelID) == 0 {
|
||||
// If no channel check was done prior message loading,
|
||||
// we do it now, by inspecting the actual payload we got
|
||||
mm = svc.filterMessagesByAccessibleChannels(mm)
|
||||
}
|
||||
|
||||
return mm, svc.preload(mm)
|
||||
return mm, f, svc.preload(mm)
|
||||
}
|
||||
|
||||
func (svc message) CreateWithAvatar(in *types.Message, avatar io.Reader) (*types.Message, error) {
|
||||
@ -157,38 +150,26 @@ func (svc message) CreateWithAvatar(in *types.Message, avatar io.Reader) (*types
|
||||
return svc.Create(in)
|
||||
}
|
||||
|
||||
func (svc message) channelAccessCheck(IDs ...uint64) error {
|
||||
for _, ID := range IDs {
|
||||
if ID > 0 {
|
||||
if ch, err := svc.findChannelByID(ID); err != nil {
|
||||
return err
|
||||
} else if !svc.ac.CanReadChannel(svc.ctx, ch) {
|
||||
return ErrNoPermissions.withStack()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Filter message set by accessible channels
|
||||
func (svc message) filterMessagesByAccessibleChannels(mm types.MessageSet) types.MessageSet {
|
||||
// Remember channels that were already checked.
|
||||
chk := map[uint64]bool{}
|
||||
|
||||
mm, _ = mm.Filter(func(m *types.Message) (b bool, e error) {
|
||||
if !chk[m.ChannelID] {
|
||||
chk[m.ChannelID] = true
|
||||
|
||||
if ch, err := svc.findChannelByID(m.ChannelID); err != nil || !svc.ac.CanReadChannel(svc.ctx, ch) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
// Returns list of readable channels
|
||||
//
|
||||
// Either all (when len(f.ChannelID) == 0) or subset of channel IDs (from f.ChannelID)
|
||||
func (svc message) readableChannels(f types.MessageFilter) ([]uint64, error) {
|
||||
cc, _, err := svc.channel.With(svc.ctx).Find(types.ChannelFilter{
|
||||
CurrentUserID: f.CurrentUserID,
|
||||
ChannelID: f.ChannelID,
|
||||
IncludeDeleted: true,
|
||||
})
|
||||
|
||||
return mm
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(cc) == 0 {
|
||||
// None of the channels requested were returned as accessible
|
||||
return nil, ErrNoPermissions.withStack()
|
||||
}
|
||||
|
||||
return cc.IDs(), nil
|
||||
}
|
||||
|
||||
func (svc message) Create(in *types.Message) (m *types.Message, err error) {
|
||||
|
||||
@ -6,7 +6,6 @@ import (
|
||||
"github.com/jmoiron/sqlx/types"
|
||||
|
||||
"github.com/cortezaproject/corteza-server/pkg/permissions"
|
||||
"github.com/cortezaproject/corteza-server/pkg/rh"
|
||||
)
|
||||
|
||||
type (
|
||||
@ -54,6 +53,8 @@ type (
|
||||
ChannelFilter struct {
|
||||
Query string
|
||||
|
||||
ChannelID []uint64
|
||||
|
||||
// Only return channels accessible by this user
|
||||
CurrentUserID uint64
|
||||
|
||||
@ -61,9 +62,6 @@ type (
|
||||
IncludeDeleted bool
|
||||
|
||||
Sort string `json:"sort"`
|
||||
|
||||
// Standard paging fields & helpers
|
||||
rh.PageFilter
|
||||
}
|
||||
|
||||
ChannelMembershipPolicy string
|
||||
|
||||
@ -59,15 +59,11 @@ func FetchPaged(db *factory.DB, q squirrel.SelectBuilder, page, perPage uint, se
|
||||
q = q.Offset(offset)
|
||||
}
|
||||
|
||||
if sqlSelect, argsSelect, err := q.ToSql(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return db.Select(set, sqlSelect, argsSelect...)
|
||||
}
|
||||
return FetchAll(db, q, set)
|
||||
}
|
||||
|
||||
// FetchPaged fetches paged rows
|
||||
func FetchAll(db *factory.DB, q squirrel.SelectBuilder, set interface{}) error {
|
||||
func FetchAll(db *factory.DB, q squirrel.Sqlizer, set interface{}) error {
|
||||
if sqlSelect, argsSelect, err := q.ToSql(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user