3
0

Refactor messages repo, fix channels

This commit is contained in:
Denis Arh 2019-10-21 11:25:45 +02:00
parent f422867469
commit 4cd4077f2d
9 changed files with 209 additions and 248 deletions

View File

@ -42,6 +42,6 @@ func provisionConfig(ctx context.Context, cmd *cobra.Command, c *cli.Config) err
// Provision ONLY when there are no channels (even if we find delete channels we abort provisioning
func isProvisioned(ctx context.Context) (bool, error) {
_, f, err := service.DefaultChannel.With(ctx).Find(types.ChannelFilter{IncludeDeleted: true})
return f.Count > 0, err
cc, _, err := service.DefaultChannel.With(ctx).Find(types.ChannelFilter{IncludeDeleted: true})
return len(cc) > 0, err
}

View File

@ -123,6 +123,10 @@ func (r channel) Find(filter types.ChannelFilter) (set types.ChannelSet, f types
query = query.Where(squirrel.Eq{"c.deleted_at": nil})
}
if len(f.ChannelID) > 0 {
query = query.Where(squirrel.Eq{"c.id": f.ChannelID})
}
if f.Query != "" {
q := "%" + strings.ToLower(f.Query) + "%"
query = query.Where(squirrel.Like{"LOWER(name)": q})
@ -143,11 +147,7 @@ func (r channel) Find(filter types.ChannelFilter) (set types.ChannelSet, f types
query = query.OrderBy(orderBy...)
}
if f.Count, err = rh.Count(r.db(), query); err != nil || f.Count == 0 {
return
}
return set, f, rh.FetchPaged(r.db(), query, f.Page, f.PerPage, &set)
return set, f, rh.FetchAll(r.db(), query, &set)
}
func (r channel) Create(mod *types.Channel) (*types.Channel, error) {

View File

@ -7,7 +7,6 @@ import (
"time"
"github.com/Masterminds/squirrel"
"github.com/jmoiron/sqlx"
"github.com/titpetric/factory"
"github.com/cortezaproject/corteza-server/messaging/types"
@ -19,8 +18,8 @@ type (
With(ctx context.Context, db *factory.DB) MessageRepository
FindByID(id uint64) (*types.Message, error)
Find(filter *types.MessageFilter) (types.MessageSet, error)
FindThreads(filter *types.MessageFilter) (types.MessageSet, error)
Find(types.MessageFilter) (types.MessageSet, types.MessageFilter, error)
FindThreads(types.MessageFilter) (types.MessageSet, types.MessageFilter, error)
CountFromMessageID(channelID, threadID, messageID uint64) (uint32, error)
LastMessageID(channelID, threadID uint64) (uint64, error)
PrefillThreadParticipants(mm types.MessageSet) error
@ -43,54 +42,6 @@ type (
const (
MESSAGES_MAX_LIMIT = 100
// subquery that filters out all channels that current user has access to as a member
// or via channel type (public channels)
sqlChannelAccess = ` (
SELECT id
FROM messaging_channel c
LEFT OUTER JOIN messaging_channel_member AS m ON (c.id = m.rel_channel)
WHERE rel_user = ?
UNION
SELECT id
FROM messaging_channel c
WHERE c.type = ?
)`
sqlMessageColumns = "id, " +
"COALESCE(type,'') AS type, " +
"message, " +
"rel_user, " +
"rel_channel, " +
"reply_to, " +
"replies, " +
"created_at, " +
"updated_at, " +
"deleted_at"
sqlMessageScope = "deleted_at IS NULL"
sqlMessagesSelect = `SELECT ` + sqlMessageColumns + `
FROM messaging_message
WHERE ` + sqlMessageScope
sqlMessagesThreads = "WITH originals AS (" +
" SELECT id AS original_id " +
" FROM messaging_message " +
" WHERE " + sqlMessageScope +
" AND rel_channel IN " + sqlChannelAccess +
" AND reply_to = 0 " +
" AND replies > 0 " +
// for finding only threads we've created or replied to
" AND (rel_user = ? OR id IN (SELECT DISTINCT reply_to FROM messaging_message WHERE rel_user = ?))" +
" ORDER BY id DESC " +
" LIMIT ? " +
")" +
" SELECT " + sqlMessageColumns +
" FROM messaging_message, originals " +
" WHERE " + sqlMessageScope +
" AND original_id IN (id, reply_to)"
sqlThreadParticipantsByMessageID = "SELECT DISTINCT reply_to, rel_user FROM messaging_message WHERE reply_to IN (?)"
sqlCountFromMessageID = "SELECT COUNT(*) AS count " +
"FROM messaging_message " +
"WHERE rel_channel = ? " +
@ -125,135 +76,154 @@ func (r message) table() string {
return "messaging_message"
}
func (r *message) FindByID(id uint64) (*types.Message, error) {
mod := &types.Message{}
sql := sqlMessagesSelect + " AND id = ?"
return mod, rh.IsFound(r.db().Get(mod, sql, id), mod.ID > 0, ErrMessageNotFound)
func (r message) columns() []string {
return []string{
"m.id",
"COALESCE(m.type,'') AS type",
"m.message",
"m.rel_user",
"m.rel_channel",
"m.reply_to",
"m.replies",
"m.created_at",
"m.updated_at",
"m.deleted_at",
}
}
func (r *message) Find(filter *types.MessageFilter) (types.MessageSet, error) {
r.sanitizeFilter(filter)
func (r message) query() squirrel.SelectBuilder {
return squirrel.
Select(r.columns()...).
From(r.table() + " AS m").
Where(squirrel.Eq{"m.deleted_at": nil})
}
params := make([]interface{}, 0)
rval := make(types.MessageSet, 0)
func (r message) FindByID(id uint64) (*types.Message, error) {
return r.findOneBy(squirrel.Eq{"m.id": id})
}
sql := sqlMessagesSelect
func (r message) findOneBy(cnd squirrel.Sqlizer) (*types.Message, error) {
var (
ch = &types.Message{}
if filter.Query != "" {
sql += " AND message LIKE ?"
params = append(params, "%"+filter.Query+"%")
q = r.query().
Where(cnd)
err = rh.FetchOne(r.db(), q, ch)
)
if err != nil {
return nil, err
} else if ch.ID == 0 {
return nil, ErrMessageNotFound
}
if len(filter.ChannelID) > 0 {
sql += " AND rel_channel IN (" + strings.Repeat(",?", len(filter.ChannelID))[1:] + ")"
for _, id := range filter.ChannelID {
params = append(params, id)
}
return ch, nil
}
func (r message) Find(filter types.MessageFilter) (set types.MessageSet, f types.MessageFilter, err error) {
f = r.sanitizeFilter(filter)
query := r.query()
if f.Query != "" {
q := "%" + strings.ToLower(f.Query) + "%"
query = query.Where(squirrel.Like{"LOWER(m.message)": q})
}
if len(filter.UserID) > 0 {
sql += " AND rel_user IN (" + strings.Repeat(",?", len(filter.UserID))[1:] + ")"
for _, id := range filter.UserID {
params = append(params, id)
}
if len(f.ChannelID) > 0 {
query = query.Where(squirrel.Eq{"m.rel_channel": f.ChannelID})
}
if len(filter.ThreadID) > 0 {
sql += " AND reply_to IN (" + strings.Repeat(",?", len(filter.ThreadID))[1:] + ")"
for _, id := range filter.ThreadID {
params = append(params, id)
}
if len(f.UserID) > 0 {
query = query.Where(squirrel.Eq{"m.rel_user": f.UserID})
}
if len(f.ThreadID) > 0 {
query = query.Where(squirrel.Eq{"m.reply_to": f.ThreadID})
} else {
sql += " AND reply_to = 0 "
query = query.Where(squirrel.Eq{"m.reply_to": 0})
}
if len(filter.Type) > 0 {
sql += " AND type IN (" + strings.Repeat(",?", len(filter.Type))[1:] + ")"
for _, id := range filter.Type {
params = append(params, id)
if f.AttachmentsOnly {
// Override Type filter
f.Type = []string{
types.MessageTypeAttachment.String(),
types.MessageTypeInlineImage.String(),
}
}
if len(f.Type) > 0 {
query = query.Where(squirrel.Eq{"m.type": f.Type})
}
// first, exclusive
if filter.AfterID > 0 {
sql += " AND id > ? "
params = append(params, filter.AfterID)
if f.AfterID > 0 {
query = query.Where(squirrel.Gt{"m.id": f.AfterID})
}
// from, inclusive
if filter.FromID > 0 {
sql += " AND id >= ? "
params = append(params, filter.FromID)
if f.FromID > 0 {
query = query.Where(squirrel.GtOrEq{"m.id": f.FromID})
}
// last, exclusive
if filter.BeforeID > 0 {
sql += " AND id < ? "
params = append(params, filter.BeforeID)
if f.BeforeID > 0 {
query = query.Where(squirrel.Lt{"m.id": f.BeforeID})
}
// to, inclusive
if filter.ToID > 0 {
sql += " AND id <= ? "
params = append(params, filter.ToID)
if f.ToID > 0 {
query = query.Where(squirrel.LtOrEq{"m.id": f.ToID})
}
if filter.BookmarkedOnly || filter.PinnedOnly {
sql += " AND id IN (SELECT rel_message FROM messaging_message_flag WHERE flag = ?) "
if filter.PinnedOnly {
params = append(params, types.MessageFlagPinnedToChannel)
} else {
params = append(params, types.MessageFlagBookmarkedMessage)
if f.BookmarkedOnly || f.PinnedOnly {
flag := types.MessageFlagBookmarkedMessage
if f.PinnedOnly {
flag = types.MessageFlagPinnedToChannel
}
query = query.
Where(squirrel.ConcatExpr("m.id IN(", (messageFlag{}).queryMessagesWithFlags(flag), ")"))
}
if filter.AttachmentsOnly {
sql += " AND type IN (?, ?) "
params = append(
params,
types.MessageTypeAttachment,
types.MessageTypeInlineImage,
)
}
query = query.
OrderBy("id DESC").
Limit(uint64(f.Limit))
sql += " AND rel_channel IN " + sqlChannelAccess
params = append(params, filter.CurrentUserID, types.ChannelTypePublic)
sql += " ORDER BY id DESC"
sql += " LIMIT ? "
params = append(params, filter.Limit)
return rval, r.db().Select(&rval, sql, params...)
return set, f, rh.FetchAll(r.db(), query, &set)
}
func (r *message) FindThreads(filter *types.MessageFilter) (types.MessageSet, error) {
r.sanitizeFilter(filter)
func (r *message) FindThreads(filter types.MessageFilter) (set types.MessageSet, f types.MessageFilter, err error) {
f = r.sanitizeFilter(filter)
params := make([]interface{}, 0)
rval := make(types.MessageSet, 0)
// Selecting first valid (deleted_at IS NULL) messages in threads (replies > 0 && reply_to = 0)
// that belong to filtered channels and we've contributed to (or stated it)
originals := squirrel.
Select("id AS original_id").
From(r.table()).
Where(squirrel.And{
squirrel.Eq{
"deleted_at": nil,
"rel_channel": f.ChannelID,
"reply_to": 0,
},
squirrel.Gt{"replies": 0},
squirrel.Or{
squirrel.Eq{"rel_user": filter.CurrentUserID},
squirrel.Expr(
"id IN (SELECT DISTINCT reply_to FROM messaging_message WHERE rel_user = ?)",
filter.CurrentUserID),
},
})
// for sqlChannelAccess
params = append(params, filter.CurrentUserID, types.ChannelTypePublic)
// Prepare the actual message selector
query := r.query().Join("originals ON (original_id IN (id, reply_to))")
// for finding only threads we've created or replied to
params = append(params, filter.CurrentUserID, filter.CurrentUserID)
// And create CTE
cte := squirrel.ConcatExpr("WITH originals AS (", originals, ") ", query)
// for sqlMessagesThreads
params = append(params, filter.Limit)
sql := sqlMessagesThreads
if len(filter.ChannelID) > 0 {
sql += " AND rel_channel IN (" + strings.Repeat(",?", len(filter.ChannelID))[1:] + ")"
for _, id := range filter.ChannelID {
params = append(params, id)
}
}
return rval, r.db().Select(&rval, sql, params...)
return set, f, rh.FetchAll(r.db(), cte, &set)
}
func (r *message) CountFromMessageID(channelID, threadID, lastReadMessageID uint64) (uint32, error) {
@ -281,11 +251,11 @@ func (r *message) LastMessageID(channelID, threadID uint64) (uint64, error) {
)
}
func (r *message) PrefillThreadParticipants(mm types.MessageSet) error {
var rval = []struct {
func (r *message) PrefillThreadParticipants(mm types.MessageSet) (err error) {
var rval []struct {
ReplyTo uint64 `db:"reply_to"`
UserID uint64 `db:"rel_user"`
}{}
}
// Filter out only relevant messages -- ones with replies
mm, _ = mm.Filter(func(m *types.Message) (b bool, e error) {
@ -296,27 +266,29 @@ func (r *message) PrefillThreadParticipants(mm types.MessageSet) error {
return nil
}
if sql, args, err := sqlx.In(sqlThreadParticipantsByMessageID, mm.IDs()); err != nil {
return err
} else if err = r.db().Select(&rval, sql, args...); err != nil {
return err
} else {
for _, p := range rval {
mm.FindByID(p.ReplyTo).RepliesFrom = append(mm.FindByID(p.ReplyTo).RepliesFrom, p.UserID)
}
query := squirrel.
Select("reply_to", "rel_user").
From(r.table()).
Where(squirrel.Eq{"reply_to": mm.IDs()})
err = rh.FetchAll(r.db(), query, &rval)
if err != nil {
return
}
for _, p := range rval {
mm.FindByID(p.ReplyTo).RepliesFrom = append(mm.FindByID(p.ReplyTo).RepliesFrom, p.UserID)
}
return nil
}
func (r *message) sanitizeFilter(filter *types.MessageFilter) {
if filter == nil {
filter = &types.MessageFilter{}
func (r *message) sanitizeFilter(f types.MessageFilter) types.MessageFilter {
if f.Limit == 0 || f.Limit > MESSAGES_MAX_LIMIT {
f.Limit = MESSAGES_MAX_LIMIT
}
if filter.Limit == 0 || filter.Limit > MESSAGES_MAX_LIMIT {
filter.Limit = MESSAGES_MAX_LIMIT
}
return f
}
func (r *message) Create(mod *types.Message) (*types.Message, error) {
@ -332,7 +304,7 @@ func (r *message) Update(mod *types.Message) (*types.Message, error) {
return mod, r.db().Replace("messaging_message", mod)
}
func (svc *message) BindAvatar(in *types.Message, avatar io.Reader) (*types.Message, error) {
func (r *message) BindAvatar(in *types.Message, avatar io.Reader) (*types.Message, error) {
// @todo: implement setting avatar on a message
in.Meta.Avatar = ""
return in, nil

View File

@ -56,6 +56,13 @@ func (r messageFlag) query() squirrel.SelectBuilder {
From(r.table() + " AS mf")
}
func (r messageFlag) queryMessagesWithFlags(flags ...string) squirrel.SelectBuilder {
return squirrel.
Select("mf.rel_message").
From(r.table() + " AS mf").
Where(squirrel.Eq{"flag": flags})
}
func (r messageFlag) With(ctx context.Context, db *factory.DB) MessageFlagRepository {
return &messageFlag{
repository: r.repository.With(ctx, db),

View File

@ -28,7 +28,7 @@ func (Search) New() *Search {
}
func (ctrl *Search) Messages(ctx context.Context, r *request.SearchMessages) (interface{}, error) {
return ctrl.wrapSet(ctx)(ctrl.svc.msg.With(ctx).Find(&types.MessageFilter{
mm, _, err := ctrl.svc.msg.With(ctx).Find(types.MessageFilter{
ChannelID: payload.ParseUInt64s(r.ChannelID),
AfterID: r.AfterMessageID,
BeforeID: r.BeforeMessageID,
@ -42,24 +42,27 @@ func (ctrl *Search) Messages(ctx context.Context, r *request.SearchMessages) (in
Limit: r.Limit,
Query: r.Query,
}))
})
return ctrl.wrapSet(ctx, mm, err)
}
func (ctrl *Search) Threads(ctx context.Context, r *request.SearchThreads) (interface{}, error) {
return ctrl.wrapSet(ctx)(ctrl.svc.msg.With(ctx).FindThreads(&types.MessageFilter{
mm, _, err := ctrl.svc.msg.With(ctx).FindThreads(types.MessageFilter{
ChannelID: payload.ParseUInt64s(r.ChannelID),
Limit: r.Limit,
Query: r.Query,
}))
})
return ctrl.wrapSet(ctx, mm, err)
}
func (ctrl *Search) wrapSet(ctx context.Context) func(mm types.MessageSet, err error) (*outgoing.MessageSet, error) {
return func(mm types.MessageSet, err error) (*outgoing.MessageSet, error) {
if err != nil {
return nil, err
} else {
return payload.Messages(ctx, mm), nil
}
func (ctrl *Search) wrapSet(ctx context.Context, mm types.MessageSet, err error) (*outgoing.MessageSet, error) {
if err != nil {
return nil, err
} else {
return payload.Messages(ctx, mm), nil
}
}

View File

@ -133,6 +133,10 @@ func (svc *channel) Find(filter types.ChannelFilter) (set types.ChannelSet, f ty
err = svc.preloadExtras(set)
}
set, err = set.Filter(func(c *types.Channel) (b bool, e error) {
return svc.ac.CanReadChannel(svc.ctx, c), nil
})
return
}

View File

@ -24,8 +24,9 @@ type (
logger *zap.Logger
ac messageAccessController
channel ChannelService
attachment repository.AttachmentRepository
channel repository.ChannelRepository
cmember repository.ChannelMemberRepository
unread repository.UnreadRepository
message repository.MessageRepository
@ -47,8 +48,8 @@ type (
MessageService interface {
With(ctx context.Context) MessageService
Find(filter *types.MessageFilter) (types.MessageSet, error)
FindThreads(filter *types.MessageFilter) (types.MessageSet, error)
Find(types.MessageFilter) (types.MessageSet, types.MessageFilter, error)
FindThreads(types.MessageFilter) (types.MessageSet, types.MessageFilter, error)
Create(messages *types.Message) (*types.Message, error)
Update(messages *types.Message) (*types.Message, error)
@ -82,6 +83,9 @@ var (
func Message(ctx context.Context) MessageService {
return (&message{
logger: DefaultLogger.Named("message"),
ac: DefaultAccessControl,
channel: DefaultChannel,
}).With(ctx)
}
@ -91,12 +95,13 @@ func (svc message) With(ctx context.Context) MessageService {
db: db,
ctx: ctx,
logger: svc.logger,
ac: DefaultAccessControl,
ac: svc.ac,
channel: svc.channel,
event: Event(ctx),
attachment: repository.Attachment(ctx, db),
channel: repository.Channel(ctx, db),
cmember: repository.ChannelMember(ctx, db),
unread: repository.Unread(ctx, db),
message: repository.Message(ctx, db),
@ -110,46 +115,34 @@ func (svc message) log(ctx context.Context, fields ...zapcore.Field) *zap.Logger
return logger.AddRequestID(ctx, svc.logger).With(fields...)
}
func (svc message) Find(filter *types.MessageFilter) (mm types.MessageSet, err error) {
filter.CurrentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
if err = svc.channelAccessCheck(filter.ChannelID...); err != nil {
func (svc message) Find(filter types.MessageFilter) (mm types.MessageSet, f types.MessageFilter, err error) {
f = filter
f.CurrentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
if f.ChannelID, err = svc.readableChannels(f); err != nil {
return
}
mm, err = svc.message.Find(filter)
mm, f, err = svc.message.Find(f)
if err != nil {
return nil, err
return
}
if len(filter.ChannelID) == 0 {
// If no channel check was done prior message loading,
// we do it now, by inspecting the actual payload we got
mm = svc.filterMessagesByAccessibleChannels(mm)
}
return mm, svc.preload(mm)
return mm, f, svc.preload(mm)
}
func (svc message) FindThreads(filter *types.MessageFilter) (mm types.MessageSet, err error) {
filter.CurrentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
if err = svc.channelAccessCheck(filter.ChannelID...); err != nil {
func (svc message) FindThreads(filter types.MessageFilter) (mm types.MessageSet, f types.MessageFilter, err error) {
f = filter
f.CurrentUserID = auth.GetIdentityFromContext(svc.ctx).Identity()
if f.ChannelID, err = svc.readableChannels(f); err != nil {
return
}
mm, err = svc.message.FindThreads(filter)
mm, f, err = svc.message.FindThreads(f)
if err != nil {
return nil, err
return
}
if len(filter.ChannelID) == 0 {
// If no channel check was done prior message loading,
// we do it now, by inspecting the actual payload we got
mm = svc.filterMessagesByAccessibleChannels(mm)
}
return mm, svc.preload(mm)
return mm, f, svc.preload(mm)
}
func (svc message) CreateWithAvatar(in *types.Message, avatar io.Reader) (*types.Message, error) {
@ -157,38 +150,26 @@ func (svc message) CreateWithAvatar(in *types.Message, avatar io.Reader) (*types
return svc.Create(in)
}
func (svc message) channelAccessCheck(IDs ...uint64) error {
for _, ID := range IDs {
if ID > 0 {
if ch, err := svc.findChannelByID(ID); err != nil {
return err
} else if !svc.ac.CanReadChannel(svc.ctx, ch) {
return ErrNoPermissions.withStack()
}
}
}
return nil
}
// Filter message set by accessible channels
func (svc message) filterMessagesByAccessibleChannels(mm types.MessageSet) types.MessageSet {
// Remember channels that were already checked.
chk := map[uint64]bool{}
mm, _ = mm.Filter(func(m *types.Message) (b bool, e error) {
if !chk[m.ChannelID] {
chk[m.ChannelID] = true
if ch, err := svc.findChannelByID(m.ChannelID); err != nil || !svc.ac.CanReadChannel(svc.ctx, ch) {
return false, nil
}
}
return true, nil
// Returns list of readable channels
//
// Either all (when len(f.ChannelID) == 0) or subset of channel IDs (from f.ChannelID)
func (svc message) readableChannels(f types.MessageFilter) ([]uint64, error) {
cc, _, err := svc.channel.With(svc.ctx).Find(types.ChannelFilter{
CurrentUserID: f.CurrentUserID,
ChannelID: f.ChannelID,
IncludeDeleted: true,
})
return mm
if err != nil {
return nil, err
}
if len(cc) == 0 {
// None of the channels requested were returned as accessible
return nil, ErrNoPermissions.withStack()
}
return cc.IDs(), nil
}
func (svc message) Create(in *types.Message) (m *types.Message, err error) {

View File

@ -6,7 +6,6 @@ import (
"github.com/jmoiron/sqlx/types"
"github.com/cortezaproject/corteza-server/pkg/permissions"
"github.com/cortezaproject/corteza-server/pkg/rh"
)
type (
@ -54,6 +53,8 @@ type (
ChannelFilter struct {
Query string
ChannelID []uint64
// Only return channels accessible by this user
CurrentUserID uint64
@ -61,9 +62,6 @@ type (
IncludeDeleted bool
Sort string `json:"sort"`
// Standard paging fields & helpers
rh.PageFilter
}
ChannelMembershipPolicy string

View File

@ -59,15 +59,11 @@ func FetchPaged(db *factory.DB, q squirrel.SelectBuilder, page, perPage uint, se
q = q.Offset(offset)
}
if sqlSelect, argsSelect, err := q.ToSql(); err != nil {
return err
} else {
return db.Select(set, sqlSelect, argsSelect...)
}
return FetchAll(db, q, set)
}
// FetchPaged fetches paged rows
func FetchAll(db *factory.DB, q squirrel.SelectBuilder, set interface{}) error {
func FetchAll(db *factory.DB, q squirrel.Sqlizer, set interface{}) error {
if sqlSelect, argsSelect, err := q.ToSql(); err != nil {
return err
} else {